Julia 教程 / 分布式计算
分布式计算
Julia 内置了强大的分布式计算支持。从简单的多进程并行到集群级别的分布式计算,Julia 提供了完整的工具链。本文介绍如何利用分布式 Julia 解决大规模计算问题。
1. 分布式 Julia 基础
1.1 添加工作进程
using Distributed
# 查看当前进程数
nprocs() # 默认 1(主进程)
# 添加工作进程(本地)
addprocs(4) # 添加 4 个工作进程
nprocs() # 5(1 主 + 4 工作)
nworkers() # 4(仅工作进程)
# 查看进程信息
workers() # [2, 3, 4, 5]
myid() # 当前进程 ID(主进程为 1)
1.2 启动时指定进程
# 启动 Julia 时添加 4 个工作进程
julia -p 4
# 使用指定机器(需 SSH 配置)
julia -p 4 --machinefile machines.txt
1.3 工作进程上加载包
# 在所有工作进程上加载包
@everywhere using LinearAlgebra, Statistics
# 在所有进程上定义函数
@everywhere function heavy_compute(x)
s = 0.0
for i in 1:10_000
s += sin(x * i) / i
end
return s
end
2. Distributed.jl 模块
2.1 核心 API
| 函数 | 说明 |
|---|---|
addprocs(n) | 添加 n 个工作进程 |
rmprocs(pids...) | 移除工作进程 |
workers() | 获取所有工作进程 ID |
nworkers() | 工作进程数 |
nprocs() | 总进程数 |
myid() | 当前进程 ID |
@everywhere expr | 在所有进程上执行 expr |
remotecall(f, pid, args...) | 在指定进程上调用函数 |
fetch(ref) | 获取远程调用的结果 |
@spawnat pid expr | 在指定进程上执行表达式 |
@spawn expr | 在任意进程上执行表达式 |
pmap(f, collection) | 并行 map |
remotecall_eval(m, pid, expr) | 在远程进程上 eval |
2.2 远程调用基础
using Distributed
addprocs(4)
# @spawnat: 在指定进程上执行
r = @spawnat 2 begin
println("我在进程 $(myid()) 上")
42
end
result = fetch(r) # 42
# @spawn: 自动选择进程
r = @spawn myid()
fetch(r) # 返回某个工作进程的 ID
# remotecall: 显式远程调用
ref = remotecall(sqrt, 3, 16.0)
fetch(ref) # 4.0
3. @spawnat / remotecall / fetch
3.1 异步远程调用
# remotecall 是异步的,立即返回 Future
refs = [remotecall(heavy_compute, w, Float64(i)) for (w, i) in zip(workers(), 1:4)]
# 在等待结果的同时,主进程可以做其他事
println("计算中...")
# 收集结果
results = [fetch(r) for r in refs]
3.2 错误处理
# fetch 可能抛出远程进程上的异常
try
result = fetch(@spawnat 2 error("测试错误"))
catch e
println("捕获错误: ", e)
end
# 使用 remotecall_fetch 同步获取(阻塞)
result = remotecall_fetch(sqrt, 2, 25.0) # 5.0
3.3 低级通信
# RemoteChannel: 进程间通信通道
const channel = RemoteChannel(() -> Channel{Tuple}(32))
# 生产者(工作进程)
@spawnat 2 begin
for i in 1:10
put!(channel, (myid(), i * i))
end
end
# 消费者(主进程)
for _ in 1:10
pid, val = take!(channel)
println("来自进程 $pid: $val")
end
4. 分布式数组(DArrays)
4.1 创建 DArray
using Distributed
addprocs(4)
# 创建分布式数组
d = drand(1000, 1000) # 1000x1000 随机矩阵,分布在 4 个进程上
# 查看分布情况
d.chunks # 各块的位置信息
# 转为本地数组
local_arr = Array(d)
# 从本地数组创建 DArray
local_data = rand(100, 100)
d = distribute(local_data)
4.2 DArray 操作
d = drand(1000, 1000)
# 基本运算与普通数组一致
result = d .* 2 .+ 1
sum_val = sum(d)
mean_val = mean(d)
# 矩阵乘法
d1 = drand(100, 100)
d2 = drand(100, 100)
d3 = d1 * d2 # 分布式矩阵乘法
# mapreduce
total = mapreduce(x -> x^2, +, d)
4.3 自定义 DArray 分布
# 自定义分块方式
d = DArray(I -> rand(length.(I)...), (1000, 1000), workers(), (2, 2))
# 将 1000x1000 矩阵分为 2x2 = 4 块
# 查看本地部分
localpart(d) # 当前进程负责的数据块
5. 分布式 MapReduce
5.1 pmap — 并行映射
@everywhere function process_item(x)
# 模拟耗时计算
sleep(0.1)
return x^2 + sqrt(x)
end
# 串行
data = collect(1:100)
@time results_serial = map(process_item, data) # ~10 秒
# 并行
@time results_parallel = pmap(process_item, data) # ~1 秒(4 进程)
5.2 批量 pmap
# 对于大量小任务,批量处理减少通信开销
function batch_pmap(f, data; batch_size=100)
n = length(data)
results = Vector{Any}(undef, n)
for start in 1:batch_size:n
stop = min(start + batch_size - 1, n)
batch = data[start:stop]
batch_results = pmap(f, batch)
results[start:stop] = batch_results
end
return results
end
5.3 分布式归约
# 分布式求和
function distributed_sum(f, data)
refs = []
chunk_size = div(length(data), nworkers())
for (i, w) in enumerate(workers())
start_idx = (i-1) * chunk_size + 1
end_idx = min(i * chunk_size, length(data))
chunk = data[start_idx:end_idx]
push!(refs, remotecall(chunk -> sum(f, chunk), w, chunk))
end
return sum(fetch(r) for r in refs)
end
# 使用
data = rand(10_000_000)
total = distributed_sum(sin, data)
6. 集群管理
6.1 ClusterManagers.jl
using Distributed
using ClusterManagers
# SLURM 集群
addprocs(SlurmManager(100; partition="compute", t="01:00:00"))
# PBS/Torque 集群
addprocs(QsubManager(50; queue="batch"))
# SGE 集群
addprocs(SGEManager(30; queue="all.q"))
6.2 自定义机器列表
# machines.txt 内容:
# worker1.example.com
# worker2.example.com
# worker3.example.com
# 使用 machinefile
addprocs(MachineFile("machines.txt"))
# 或手动指定
addprocs([
("worker1.example.com", 4), # 4 个进程
("worker2.example.com", 2), # 2 个进程
]; exename="julia", dir="/home/user/project")
6.3 进程生命周期管理
# 动态添加/移除进程
pid = addprocs(1)[1]
rmprocs(pid)
# 监控进程状态
for w in workers()
try
@spawnat w println("进程 $w 运行正常")
catch
println("进程 $w 无响应")
end
end
7. MPI.jl
7.1 安装与基本使用
using Pkg
Pkg.add("MPI")
using MPI
MPI.Init()
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
size = MPI.Comm_size(comm)
println("进程 $rank / $size")
# 广播
if rank == 0
data = [1.0, 2.0, 3.0]
else
data = zeros(3)
end
MPI.Bcast!(data, 0, comm)
# 归约
local_val = Float64(rank)
global_sum = MPI.Allreduce(local_val, MPI.SUM, comm)
println("进程 $rank: 全局求和 = $global_sum")
MPI.Finalize()
7.2 MPI 通信模式
| 操作 | 说明 | 函数 |
|---|---|---|
| 点对点 | 两个进程间通信 | Send, Recv |
| 广播 | 一对多 | Bcast! |
| 散发 | 分发数据 | Scatter! |
| 收集 | 收集数据 | Gather! |
| 全收集 | 每个进程获得所有数据 | Allgather! |
| 归约 | 聚合操作 | Reduce, Allreduce |
8. 分布式机器学习示例
using Distributed
addprocs(4)
@everywhere begin
using Statistics
# 每个工作进程训练一个子模型
function train_local_model(X_local, y_local; epochs=10)
n_features = size(X_local, 2)
w = zeros(n_features)
lr = 0.01
for _ in 1:epochs
predictions = X_local * w
gradient = X_local' * (predictions - y_local) / size(X_local, 1)
w -= lr * gradient
end
return w
end
end
# 分布式训练
function distributed_train(X, y; n_epochs=10)
n = size(X, 1)
chunk_size = div(n, nworkers())
# 分发数据
refs = []
for (i, w) in enumerate(workers())
start_idx = (i-1) * chunk_size + 1
end_idx = min(i * chunk_size, n)
push!(refs, remotecall(train_local_model, w,
X[start_idx:end_idx, :], y[start_idx:end_idx];
epochs=n_epochs))
end
# 聚合参数(简单平均)
weights = [fetch(r) for r in refs]
return mean(weights)
end
# 使用
X = rand(10000, 5)
y = X * [1.0, 2.0, 3.0, 4.0, 5.0] + randn(10000) * 0.1
global_weights = distributed_train(X, y)
println("学到的权重: ", global_weights)
9. 容错与任务重试
9.1 重试机制
function resilient_remotecall(f, pid, args...; max_retries=3)
for attempt in 1:max_retries
try
return fetch(remotecall(f, pid, args...))
catch e
if attempt == max_retries
rethrow(e)
end
println("重试 $attempt/$max_retries: 进程 $pid 失败")
sleep(1)
end
end
end
9.2 任务超时
function with_timeout(f, timeout_sec)
task = @async f()
timer = Timer(timeout_sec)
result = timedwait(() -> istaskdone(task), timeout_sec)
if result == :ok
return fetch(task)
else
schedule(task, ErrorException("超时"), error=true)
error("任务超时")
end
end
# 使用
try
result = with_timeout(() -> fetch(@spawnat 2 heavy_compute(100)), 30.0)
catch e
println("超时或失败: ", e)
end
10. 实际案例:分布式参数扫描
using Distributed
addprocs(4)
@everywhere begin
using Statistics
function simulate(params)
# 模拟一个物理过程
n_steps = params.steps
dt = params.dt
damping = params.damping
x, v = 1.0, 0.0
trajectory = zeros(n_steps)
for i in 1:n_steps
force = -x - damping * v
v += force * dt
x += v * dt
trajectory[i] = x
end
return (
max_amplitude=maximum(abs.(trajectory)),
final_position=x,
energy=0.5 * v^2 + 0.5 * x^2,
)
end
end
# 生成参数组合
function generate_params(; steps_range, dt_range, damping_range)
params = []
for steps in steps_range
for dt in dt_range
for damping in damping_range
push!(params, (steps=steps, dt=dt, damping=damping))
end
end
end
return params
end
# 分布式扫描
function parameter_sweep()
params = generate_params(
steps_range=[1000, 5000, 10000],
dt_range=[0.001, 0.01, 0.1],
damping_range=[0.0, 0.1, 0.5, 1.0],
)
println("参数组合数: $(length(params))")
# 使用 pmap 并行执行
results = pmap(simulate, params)
# 分析结果
for (p, r) in zip(params, results)
if r.max_amplitude < 0.5
println("dt=$(p.dt), damping=$(p.damping) → 稳定")
end
end
return params, results
end
params, results = parameter_sweep()
性能对比
| 方法 | 适用场景 | 通信开销 | 编程复杂度 |
|---|---|---|---|
pmap | 独立任务 | 低 | 低 |
| DArray | 数组运算 | 中 | 低 |
@spawnat | 细粒度控制 | 高 | 中 |
| MPI | 大规模集群 | 最低 | 高 |
| SharedArray | 共享内存 | 无 | 低 |
业务场景
场景一:蒙特卡洛模拟
使用 pmap 在集群上并行运行百万次蒙特卡洛模拟。每次模拟独立运行,结果通过 fetch 收集。相比单机提速 50 倍。
场景二:参数优化
在超参数搜索中,每个参数组合在不同进程上独立评估。使用 DArray 存储大型数据集,避免内存不足。
场景三:大规模数据处理
使用 MPI.jl 在 HPC 集群上处理 TB 级数据。分布式矩阵分解、大规模特征提取等。
总结
| 主题 | 关键要点 |
|---|---|
| 添加进程 | addprocs(n) 或 julia -p n |
| 远程调用 | remotecall / fetch / @spawnat |
| 并行 map | pmap(f, collection) |
| 分布式数组 | drand() / distribute() |
| 集群管理 | ClusterManagers.jl (SLURM/PBS) |
| MPI | MPI.jl 用于大规模 HPC |
| 容错 | 重试机制和超时控制 |