Julia 教程 / Julia 性能优化指南
19. 性能优化指南
Julia 的性能优势来自类型推断和 JIT 编译。掌握性能优化技巧,可以让 Julia 代码接近 C/Fortran 的速度。
19.1 基准测试工具
@time
# 基本计时
@time sum(1:1_000_000)
# 0.001234 seconds (2 allocations: 16 bytes)
# 第一次运行包含编译时间,需多次运行取稳定值
@time sum(1:1_000_000) # 可能较慢(编译)
@time sum(1:1_000_000) # 稳定值
@time sum(1:1_000_000) # 稳定值
@elapsed 与 @allocated
# 只返回时间
t = @elapsed sin(1.0)
# 只返回内存分配
bytes = @allocated sin(1.0)
println("分配了 $bytes 字节")
BenchmarkTools.jl
using BenchmarkTools
# @btime 自动多次运行,报告最小值
@btime sum(1:1_000_000)
# 245.833 μs (0 allocations: 0 bytes)
# @benchmark 详细统计
@benchmark sin(1.0) evals=1000
# 参数传递:用 $ 避免全局变量污染
data = rand(1000)
@btime sum($data) # ✅ 正确
@btime sum(data) # ❌ 可能不准确
# 比较不同实现
@btime [sin(x) for x in 1:1000]
@btime sin.(1:1000)
💡 基准测试原则:使用
BenchmarkTools.jl而非@time。@btime的最小值最能反映真实性能。
19.2 类型稳定性
@code_warntype 诊断
# ❌ 类型不稳定
function unstable(x)
if x > 0
return x
else
return "negative"
end
end
@code_warntype unstable(1)
# 返回类型显示为 Union{Int64, String} — 红色警告
# ✅ 类型稳定
function stable(x::Float64)
if x > 0
return x
else
return -x
end
end
@code_warntype stable(1.0)
# 返回类型显示为 Float64 — 全绿
常见类型不稳定模式
# 1. 条件返回不同类型
function bad1(x)
if x > 0
return 1 # Int
else
return 1.0 # Float64
end
end
# 修复:统一类型
function good1(x)
if x > 0
return Float64(1)
else
return Float64(1)
end
end
# 2. 容器元素类型不确定
function bad2()
x = [] # Vector{Any}
push!(x, 1)
push!(x, 2.0)
return x
end
# 修复:指定类型
function good2()
x = Float64[]
push!(x, 1)
push!(x, 2.0)
return x
end
# 3. 函数参数类型不确定
function bad3(items)
s = 0 # 可能推断不出 items 元素类型
for item in items
s += item
end
return s
end
# 修复:参数类型限定
function good3(items::Vector{T}) where T <: Number
s = zero(T)
for item in items
s += item
end
return s
end
⚠️ 类型不稳定是性能杀手:一个类型不稳定的表达式会导致整个函数退化为动态派发。
19.3 避免全局变量
# ❌ 慢:全局变量
const GLOBAL_DATA = rand(1000) # const 有助于编译器
function sum_global()
s = 0.0
for i in 1:length(GLOBAL_DATA)
s += GLOBAL_DATA[i]
end
return s
end
# ✅ 快:参数传递
function sum_local(data::Vector{Float64})
s = 0.0
for i in 1:length(data)
s += data[i]
end
return s
end
# ✅ 使用 const 全局变量
const DATA = rand(1000)
function sum_const()
s = 0.0
for i in 1:length(DATA)
s += DATA[i]
end
return s
end
@btime sum_local($DATA) # 最快
@btime sum_const() # 接近最快(const)
# sum_global() # 较慢(非 const 全局变量)
💡 全局变量规则:避免使用非
const全局变量。必须用时加上const声明。
19.4 内存分配优化
预分配输出数组
# ❌ 每次循环分配内存
function bad_multiply(A, B)
result = A * B # 分配新矩阵
return result
end
# ✅ 预分配
function good_multiply!(C, A, B)
mul!(C, A, B) # 原地写入
return C
end
A = rand(100, 100)
B = rand(100, 100)
C = similar(A) # 预分配
@btime bad_multiply($A, $B) # 有分配
@btime good_multiply!($C, $A, $B) # 无额外分配
避免不必要的数组创建
# ❌ 创建临时数组
function bad_sum_squares(x)
return sum(x .^ 2) # x.^2 创建新数组
end
# ✅ 使用生成器
function good_sum_squares(x)
return sum(xi^2 for xi in x)
end
# ✅ 使用 mapreduce
function best_sum_squares(x)
return mapreduce(xi -> xi^2, +, x)
end
x = rand(10000)
@btime bad_sum_squares($x) # ~15μs(分配)
@btime good_sum_squares($x) # ~5μs
@btime best_sum_squares($x) # ~5μs
使用 bang 函数(!)
# Julia 约定:! 结尾的函数会修改参数
# 排序
data = rand(1000)
sorted = sort(data) # 返回新数组(分配)
sort!(data) # 原地排序(无分配)
# 去重
unique!(data) # 原地去重
# 拼接
a = [1, 2, 3]
b = [4, 5, 6]
append!(a, b) # 原地拼接
# 常见 bang 函数:push!, pop!, insert!, deleteat!, resize!, fill!, copyto!
19.5 视图 vs 拷贝
A = rand(1000, 1000)
# ❌ 切片创建拷贝
col = A[:, 1] # 分配新向量
sub = A[1:100, 1:100] # 分配新矩阵
# ✅ 视图不分配内存
col_view = @view A[:, 1]
sub_view = @view A[1:100, 1:100]
# @views 宏:批量替换
@views function process_matrix(A)
row1 = A[1, :] # 自动变为视图
col1 = A[:, 1] # 自动变为视图
return row1' * col1
end
@btime $A[:, 1] # ~7.9μs(分配)
@btime @view $A[:, 1] # ~0.05μs(无分配)
⚠️ 视图注意事项:视图与原数组共享内存。修改视图会影响原数组。视图可能比拷贝慢(非连续内存)。
19.6 @inbounds / @simd / @fastmath
@inbounds:关闭边界检查
function sum_with_bounds(a)
s = zero(eltype(a))
for i in eachindex(a)
s += a[i] # 每次访问都检查边界
end
return s
end
function sum_no_bounds(a)
s = zero(eltype(a))
@inbounds for i in eachindex(a)
s += a[i] # 跳过边界检查
end
return s
end
a = rand(10000)
@btime sum_with_bounds($a)
@btime sum_no_bounds($a) # 稍快
@simd:向量化提示
function simd_sum(a)
s = zero(eltype(a))
@simd for i in eachindex(a)
s += a[i]
end
return s
end
# @simd 允许 SIMD 指令(如 AVX),提升吞吐量
# 注意:浮点运算可能产生微小误差(不同运算顺序)
@fastmath:激进数学优化
function precise_sum(a)
s = 0.0
for x in a
s += x
end
return s
end
function fast_sum(a)
s = 0.0
@fastmath for x in a
s += x
end
return s
end
# @fastmath 允许:
# - 重新排序运算
# - 使用近似函数
# - 假设 NaN/Inf 不会出现
# 代价:可能损失精度
| 宏 | 作用 | 风险 |
|---|---|---|
@inbounds | 跳过边界检查 | 越界导致崩溃 |
@simd | 向量化优化 | 浮点结果微变 |
@fastmath | 激进数学优化 | 精度降低 |
19.7 @inline 与 @noinline
@inline:内联函数
# 小函数建议内联
@inline function fast_add(a, b)
return a + b
end
# 内联消除了函数调用开销
# Julia 编译器通常会自动内联小函数
# 但在复杂场景中手动提示有帮助
@noinline:禁止内联
# 大函数不应内联(增加代码体积)
@noinline function large_setup()
# 大量初始化代码
# ...
end
# 冷路径代码
@noinline function handle_error(msg)
error(msg)
end
💡 内联原则:让编译器自动决定。只在
@code_warntype或性能分析表明函数调用是瓶颈时手动使用@inline。
19.8 函数屏障(Function Barrier)
函数屏障通过将类型不稳定的代码隔离到单独的函数中,确保核心计算类型稳定。
# ❌ 无屏障:类型不稳定影响整个循环
function process_bad(dict::Dict{String, Any})
for (key, val) in dict
result = val * 2 # val 是 Any,类型不稳定
println("$key: $result")
end
end
# ✅ 函数屏障:核心计算类型稳定
@inline compute(x::Int) = x * 2
@inline compute(x::Float64) = x * 2.0
@inline compute(x::String) = x * 2
function process_good(dict::Dict{String, Any})
for (key, val) in dict
result = compute(val) # 编译器知道 val 的具体类型
println("$key: $result")
end
end
实际应用:绘图库
# 类型不稳定的输入
draw(shape::Circle) = _draw_impl(shape)
draw(shape::Rectangle) = _draw_impl(shape)
draw(shape::Triangle) = _draw_impl(shape)
# 每种形状有类型稳定的实现
function _draw_impl(c::Circle)
# 快速:所有类型已知
for angle in 0:360
x = c.center[1] + c.radius * cosd(angle)
y = c.center[2] + c.radius * sind(angle)
plot_pixel(x, y)
end
end
19.9 并行性能
线程设置
# 启动时指定线程数
# JULIA_NUM_THREADS=8 julia
println("线程数: ", Threads.nthreads())
Threads.@threads
function parallel_sum(data)
n = length(data)
results = zeros(Threads.nthreads())
Threads.@threads for i in 1:n
tid = Threads.threadid()
results[tid] += data[i]
end
return sum(results)
end
data = rand(10_000_000)
@btime parallel_sum($data)
@btime sum($data) # 内置函数通常更快
@spawn 细粒度并行
function parallel_map(f, data; ntasks=Threads.nthreads())
n = length(data)
results = Vector{eltype(data)}(undef, n)
chunk_size = cld(n, ntasks)
tasks = map(1:ntasks) do t
Threads.@spawn begin
start = (t-1) * chunk_size + 1
stop = min(t * chunk_size, n)
for i in start:stop
results[i] = f(data[i])
end
end
end
fetch.(tasks)
return results
end
19.10 性能分析 Profile.jl
基本分析
using Profile
# 要分析的代码
function workload()
data = rand(1000, 1000)
result = data * data'
eigvals(result)
end
# 预热
workload()
# 采样分析
Profile.clear()
@profile workload()
# 查看结果
Profile.print()
# 保存为火焰图
# using ProfileSVG
# ProfileSVG.save("profile.svg")
分析结果解读
# Profile.print() 输出解读:
# - 第一列:该函数及其子函数占用的时间比例
# - 第二列:该函数自身占用的时间
# - 第三列:代码位置
# 关注自身占用高的函数(优化目标)
19.11 实战优化案例
案例:矩阵行标准化
# 版本1:朴素实现(慢)
function normalize_rows_bad(A)
n, m = size(A)
result = zeros(n, m)
for i in 1:n
row_sum = 0.0
for j in 1:m
row_sum += A[i, j]
end
for j in 1:m
result[i, j] = A[i, j] / row_sum
end
end
return result
end
# 版本2:预分配 + @inbounds
function normalize_rows_better(A)
n, m = size(A)
result = similar(A)
@inbounds for i in 1:n
row_sum = 0.0
for j in 1:m
row_sum += A[i, j]
end
inv_sum = 1.0 / row_sum
for j in 1:m
result[i, j] = A[i, j] * inv_sum
end
end
return result
end
# 版本3:向量化
function normalize_rows_best(A)
sums = sum(A; dims=2)
return A ./ sums
end
A = rand(1000, 1000)
@btime normalize_rows_bad($A)
@btime normalize_rows_better($A)
@btime normalize_rows_best($A)
案例:字符串解析
# ❌ 慢:正则 + 字符串操作
function parse_data_slow(lines)
results = []
for line in lines
m = match(r"(\w+):(\d+)", line)
if m !== nothing
push!(results, (m.captures[1], parse(Int, m.captures[2])))
end
end
return results
end
# ✅ 快:手动解析 + 类型标注
function parse_data_fast(lines::Vector{String})
results = Tuple{String, Int}[]
for line in lines
colon_idx = findfirst(':', line)
colon_idx === nothing && continue
name = line[1:colon_idx-1]
value = tryparse(Int, line[colon_idx+1:end])
value === nothing && continue
push!(results, (name, value))
end
return results
end
lines = ["alice:100", "bob:200", "charlie:300"]
@btime parse_data_slow($lines)
@btime parse_data_fast($lines)
优化检查清单
| 步骤 | 检查项 | 工具 |
|---|---|---|
| 1 | 基准测试 | @btime |
| 2 | 类型稳定性 | @code_warntype |
| 3 | 内存分配 | @allocated |
| 4 | 热点函数 | Profile.@profile |
| 5 | 全局变量 | 检查是否 const |
| 6 | 预分配 | similar / zeros |
| 7 | 视图 | @view / @views |
| 8 | 边界检查 | @inbounds |
| 9 | 向量化 | @simd |
| 10 | 并行化 | Threads.@threads |
19.12 性能反模式
| 反模式 | 问题 | 修复 |
|---|---|---|
| 非 const 全局变量 | 类型不确定 | 加 const 或传参数 |
容器类型 Any | 无法特化 | 指定具体类型 |
| 条件返回不同类型 | Union 类型 | 统一返回类型 |
| 频繁小分配 | GC 压力 | 预分配 / bang 函数 |
| 字符串拼接 | 不可变,每次分配 | IOBuffer / join |
| 过度抽象 | 失去特化机会 | 函数屏障 |
19.13 扩展阅读
| 资源 | 链接 |
|---|---|
| Julia 官方 - Performance Tips | https://docs.julialang.org/en/v1/manual/performance-tips/ |
| BenchmarkTools.jl | https://github.com/JuliaCI/BenchmarkTools.jl |
| Profile.jl | https://docs.julialang.org/en/v1/manual/profile/ |
| JET.jl (静态分析) | https://github.com/aviatesk/JET.jl |
| PProf.jl (火焰图) | https://github.com/JuliaPerf/PProf.jl |
19.14 本章小结
| 主题 | 要点 |
|---|---|
| 基准测试 | 用 @btime 而非 @time |
| 类型稳定 | @code_warntype 检查,避免 Union |
| 全局变量 | 避免或用 const |
| 内存分配 | 预分配、视图、bang 函数 |
| 边界优化 | @inbounds / @simd / @fastmath |
| 函数屏障 | 隔离类型不稳定代码 |
| 分析工具 | Profile.jl 定位热点 |