强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧.
文档目录

Julia 教程 / 分布式计算

分布式计算

Julia 内置了强大的分布式计算支持。从简单的多进程并行到集群级别的分布式计算,Julia 提供了完整的工具链。本文介绍如何利用分布式 Julia 解决大规模计算问题。


1. 分布式 Julia 基础

1.1 添加工作进程

using Distributed

# 查看当前进程数
nprocs()  # 默认 1(主进程)

# 添加工作进程(本地)
addprocs(4)  # 添加 4 个工作进程
nprocs()  # 5(1 主 + 4 工作)
nworkers()  # 4(仅工作进程)

# 查看进程信息
workers()  # [2, 3, 4, 5]
myid()     # 当前进程 ID(主进程为 1)

1.2 启动时指定进程

# 启动 Julia 时添加 4 个工作进程
julia -p 4

# 使用指定机器(需 SSH 配置)
julia -p 4 --machinefile machines.txt

1.3 工作进程上加载包

# 在所有工作进程上加载包
@everywhere using LinearAlgebra, Statistics

# 在所有进程上定义函数
@everywhere function heavy_compute(x)
    s = 0.0
    for i in 1:10_000
        s += sin(x * i) / i
    end
    return s
end

2. Distributed.jl 模块

2.1 核心 API

函数说明
addprocs(n)添加 n 个工作进程
rmprocs(pids...)移除工作进程
workers()获取所有工作进程 ID
nworkers()工作进程数
nprocs()总进程数
myid()当前进程 ID
@everywhere expr在所有进程上执行 expr
remotecall(f, pid, args...)在指定进程上调用函数
fetch(ref)获取远程调用的结果
@spawnat pid expr在指定进程上执行表达式
@spawn expr在任意进程上执行表达式
pmap(f, collection)并行 map
remotecall_eval(m, pid, expr)在远程进程上 eval

2.2 远程调用基础

using Distributed
addprocs(4)

# @spawnat: 在指定进程上执行
r = @spawnat 2 begin
    println("我在进程 $(myid()) 上")
    42
end
result = fetch(r)  # 42

# @spawn: 自动选择进程
r = @spawn myid()
fetch(r)  # 返回某个工作进程的 ID

# remotecall: 显式远程调用
ref = remotecall(sqrt, 3, 16.0)
fetch(ref)  # 4.0

3. @spawnat / remotecall / fetch

3.1 异步远程调用

# remotecall 是异步的,立即返回 Future
refs = [remotecall(heavy_compute, w, Float64(i)) for (w, i) in zip(workers(), 1:4)]

# 在等待结果的同时,主进程可以做其他事
println("计算中...")

# 收集结果
results = [fetch(r) for r in refs]

3.2 错误处理

# fetch 可能抛出远程进程上的异常
try
    result = fetch(@spawnat 2 error("测试错误"))
catch e
    println("捕获错误: ", e)
end

# 使用 remotecall_fetch 同步获取(阻塞)
result = remotecall_fetch(sqrt, 2, 25.0)  # 5.0

3.3 低级通信

# RemoteChannel: 进程间通信通道
const channel = RemoteChannel(() -> Channel{Tuple}(32))

# 生产者(工作进程)
@spawnat 2 begin
    for i in 1:10
        put!(channel, (myid(), i * i))
    end
end

# 消费者(主进程)
for _ in 1:10
    pid, val = take!(channel)
    println("来自进程 $pid: $val")
end

4. 分布式数组(DArrays)

4.1 创建 DArray

using Distributed
addprocs(4)

# 创建分布式数组
d = drand(1000, 1000)  # 1000x1000 随机矩阵,分布在 4 个进程上

# 查看分布情况
d.chunks  # 各块的位置信息

# 转为本地数组
local_arr = Array(d)

# 从本地数组创建 DArray
local_data = rand(100, 100)
d = distribute(local_data)

4.2 DArray 操作

d = drand(1000, 1000)

# 基本运算与普通数组一致
result = d .* 2 .+ 1
sum_val = sum(d)
mean_val = mean(d)

# 矩阵乘法
d1 = drand(100, 100)
d2 = drand(100, 100)
d3 = d1 * d2  # 分布式矩阵乘法

# mapreduce
total = mapreduce(x -> x^2, +, d)

4.3 自定义 DArray 分布

# 自定义分块方式
d = DArray(I -> rand(length.(I)...), (1000, 1000), workers(), (2, 2))
# 将 1000x1000 矩阵分为 2x2 = 4 块

# 查看本地部分
localpart(d)  # 当前进程负责的数据块

5. 分布式 MapReduce

5.1 pmap — 并行映射

@everywhere function process_item(x)
    # 模拟耗时计算
    sleep(0.1)
    return x^2 + sqrt(x)
end

# 串行
data = collect(1:100)
@time results_serial = map(process_item, data)  # ~10 秒

# 并行
@time results_parallel = pmap(process_item, data)  # ~1 秒(4 进程)

5.2 批量 pmap

# 对于大量小任务,批量处理减少通信开销
function batch_pmap(f, data; batch_size=100)
    n = length(data)
    results = Vector{Any}(undef, n)
    
    for start in 1:batch_size:n
        stop = min(start + batch_size - 1, n)
        batch = data[start:stop]
        batch_results = pmap(f, batch)
        results[start:stop] = batch_results
    end
    
    return results
end

5.3 分布式归约

# 分布式求和
function distributed_sum(f, data)
    refs = []
    chunk_size = div(length(data), nworkers())
    
    for (i, w) in enumerate(workers())
        start_idx = (i-1) * chunk_size + 1
        end_idx = min(i * chunk_size, length(data))
        chunk = data[start_idx:end_idx]
        push!(refs, remotecall(chunk -> sum(f, chunk), w, chunk))
    end
    
    return sum(fetch(r) for r in refs)
end

# 使用
data = rand(10_000_000)
total = distributed_sum(sin, data)

6. 集群管理

6.1 ClusterManagers.jl

using Distributed
using ClusterManagers

# SLURM 集群
addprocs(SlurmManager(100; partition="compute", t="01:00:00"))

# PBS/Torque 集群
addprocs(QsubManager(50; queue="batch"))

# SGE 集群
addprocs(SGEManager(30; queue="all.q"))

6.2 自定义机器列表

# machines.txt 内容:
# worker1.example.com
# worker2.example.com
# worker3.example.com

# 使用 machinefile
addprocs(MachineFile("machines.txt"))

# 或手动指定
addprocs([
    ("worker1.example.com", 4),  # 4 个进程
    ("worker2.example.com", 2),  # 2 个进程
]; exename="julia", dir="/home/user/project")

6.3 进程生命周期管理

# 动态添加/移除进程
pid = addprocs(1)[1]
rmprocs(pid)

# 监控进程状态
for w in workers()
    try
        @spawnat w println("进程 $w 运行正常")
    catch
        println("进程 $w 无响应")
    end
end

7. MPI.jl

7.1 安装与基本使用

using Pkg
Pkg.add("MPI")

using MPI

MPI.Init()

comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
size = MPI.Comm_size(comm)

println("进程 $rank / $size")

# 广播
if rank == 0
    data = [1.0, 2.0, 3.0]
else
    data = zeros(3)
end
MPI.Bcast!(data, 0, comm)

# 归约
local_val = Float64(rank)
global_sum = MPI.Allreduce(local_val, MPI.SUM, comm)

println("进程 $rank: 全局求和 = $global_sum")

MPI.Finalize()

7.2 MPI 通信模式

操作说明函数
点对点两个进程间通信Send, Recv
广播一对多Bcast!
散发分发数据Scatter!
收集收集数据Gather!
全收集每个进程获得所有数据Allgather!
归约聚合操作Reduce, Allreduce

8. 分布式机器学习示例

using Distributed
addprocs(4)

@everywhere begin
    using Statistics

    # 每个工作进程训练一个子模型
    function train_local_model(X_local, y_local; epochs=10)
        n_features = size(X_local, 2)
        w = zeros(n_features)
        lr = 0.01

        for _ in 1:epochs
            predictions = X_local * w
            gradient = X_local' * (predictions - y_local) / size(X_local, 1)
            w -= lr * gradient
        end
        return w
    end
end

# 分布式训练
function distributed_train(X, y; n_epochs=10)
    n = size(X, 1)
    chunk_size = div(n, nworkers())

    # 分发数据
    refs = []
    for (i, w) in enumerate(workers())
        start_idx = (i-1) * chunk_size + 1
        end_idx = min(i * chunk_size, n)
        push!(refs, remotecall(train_local_model, w,
              X[start_idx:end_idx, :], y[start_idx:end_idx];
              epochs=n_epochs))
    end

    # 聚合参数(简单平均)
    weights = [fetch(r) for r in refs]
    return mean(weights)
end

# 使用
X = rand(10000, 5)
y = X * [1.0, 2.0, 3.0, 4.0, 5.0] + randn(10000) * 0.1

global_weights = distributed_train(X, y)
println("学到的权重: ", global_weights)

9. 容错与任务重试

9.1 重试机制

function resilient_remotecall(f, pid, args...; max_retries=3)
    for attempt in 1:max_retries
        try
            return fetch(remotecall(f, pid, args...))
        catch e
            if attempt == max_retries
                rethrow(e)
            end
            println("重试 $attempt/$max_retries: 进程 $pid 失败")
            sleep(1)
        end
    end
end

9.2 任务超时

function with_timeout(f, timeout_sec)
    task = @async f()
    timer = Timer(timeout_sec)
    
    result = timedwait(() -> istaskdone(task), timeout_sec)
    
    if result == :ok
        return fetch(task)
    else
        schedule(task, ErrorException("超时"), error=true)
        error("任务超时")
    end
end

# 使用
try
    result = with_timeout(() -> fetch(@spawnat 2 heavy_compute(100)), 30.0)
catch e
    println("超时或失败: ", e)
end

10. 实际案例:分布式参数扫描

using Distributed
addprocs(4)

@everywhere begin
    using Statistics

    function simulate(params)
        # 模拟一个物理过程
        n_steps = params.steps
        dt = params.dt
        damping = params.damping

        x, v = 1.0, 0.0
        trajectory = zeros(n_steps)

        for i in 1:n_steps
            force = -x - damping * v
            v += force * dt
            x += v * dt
            trajectory[i] = x
        end

        return (
            max_amplitude=maximum(abs.(trajectory)),
            final_position=x,
            energy=0.5 * v^2 + 0.5 * x^2,
        )
    end
end

# 生成参数组合
function generate_params(; steps_range, dt_range, damping_range)
    params = []
    for steps in steps_range
        for dt in dt_range
            for damping in damping_range
                push!(params, (steps=steps, dt=dt, damping=damping))
            end
        end
    end
    return params
end

# 分布式扫描
function parameter_sweep()
    params = generate_params(
        steps_range=[1000, 5000, 10000],
        dt_range=[0.001, 0.01, 0.1],
        damping_range=[0.0, 0.1, 0.5, 1.0],
    )

    println("参数组合数: $(length(params))")

    # 使用 pmap 并行执行
    results = pmap(simulate, params)

    # 分析结果
    for (p, r) in zip(params, results)
        if r.max_amplitude < 0.5
            println("dt=$(p.dt), damping=$(p.damping) → 稳定")
        end
    end

    return params, results
end

params, results = parameter_sweep()

性能对比

方法适用场景通信开销编程复杂度
pmap独立任务
DArray数组运算
@spawnat细粒度控制
MPI大规模集群最低
SharedArray共享内存

业务场景

场景一:蒙特卡洛模拟

使用 pmap 在集群上并行运行百万次蒙特卡洛模拟。每次模拟独立运行,结果通过 fetch 收集。相比单机提速 50 倍。

场景二:参数优化

在超参数搜索中,每个参数组合在不同进程上独立评估。使用 DArray 存储大型数据集,避免内存不足。

场景三:大规模数据处理

使用 MPI.jl 在 HPC 集群上处理 TB 级数据。分布式矩阵分解、大规模特征提取等。


总结

主题关键要点
添加进程addprocs(n)julia -p n
远程调用remotecall / fetch / @spawnat
并行 mappmap(f, collection)
分布式数组drand() / distribute()
集群管理ClusterManagers.jl (SLURM/PBS)
MPIMPI.jl 用于大规模 HPC
容错重试机制和超时控制

扩展阅读