强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧.
文档目录

Julia 教程 / Julia 性能优化指南

19. 性能优化指南

Julia 的性能优势来自类型推断和 JIT 编译。掌握性能优化技巧,可以让 Julia 代码接近 C/Fortran 的速度。

19.1 基准测试工具

@time

# 基本计时
@time sum(1:1_000_000)
# 0.001234 seconds (2 allocations: 16 bytes)

# 第一次运行包含编译时间,需多次运行取稳定值
@time sum(1:1_000_000)  # 可能较慢(编译)
@time sum(1:1_000_000)  # 稳定值
@time sum(1:1_000_000)  # 稳定值

@elapsed 与 @allocated

# 只返回时间
t = @elapsed sin(1.0)

# 只返回内存分配
bytes = @allocated sin(1.0)
println("分配了 $bytes 字节")

BenchmarkTools.jl

using BenchmarkTools

# @btime 自动多次运行,报告最小值
@btime sum(1:1_000_000)
# 245.833 μs (0 allocations: 0 bytes)

# @benchmark 详细统计
@benchmark sin(1.0) evals=1000

# 参数传递:用 $ 避免全局变量污染
data = rand(1000)
@btime sum($data)   # ✅ 正确
@btime sum(data)    # ❌ 可能不准确

# 比较不同实现
@btime [sin(x) for x in 1:1000]
@btime sin.(1:1000)

💡 基准测试原则:使用 BenchmarkTools.jl 而非 @time@btime 的最小值最能反映真实性能。

19.2 类型稳定性

@code_warntype 诊断

# ❌ 类型不稳定
function unstable(x)
    if x > 0
        return x
    else
        return "negative"
    end
end

@code_warntype unstable(1)
# 返回类型显示为 Union{Int64, String} — 红色警告

# ✅ 类型稳定
function stable(x::Float64)
    if x > 0
        return x
    else
        return -x
    end
end

@code_warntype stable(1.0)
# 返回类型显示为 Float64 — 全绿

常见类型不稳定模式

# 1. 条件返回不同类型
function bad1(x)
    if x > 0
        return 1      # Int
    else
        return 1.0    # Float64
    end
end

# 修复:统一类型
function good1(x)
    if x > 0
        return Float64(1)
    else
        return Float64(1)
    end
end

# 2. 容器元素类型不确定
function bad2()
    x = []       # Vector{Any}
    push!(x, 1)
    push!(x, 2.0)
    return x
end

# 修复:指定类型
function good2()
    x = Float64[]
    push!(x, 1)
    push!(x, 2.0)
    return x
end

# 3. 函数参数类型不确定
function bad3(items)
    s = 0  # 可能推断不出 items 元素类型
    for item in items
        s += item
    end
    return s
end

# 修复:参数类型限定
function good3(items::Vector{T}) where T <: Number
    s = zero(T)
    for item in items
        s += item
    end
    return s
end

⚠️ 类型不稳定是性能杀手:一个类型不稳定的表达式会导致整个函数退化为动态派发。

19.3 避免全局变量

# ❌ 慢:全局变量
const GLOBAL_DATA = rand(1000)  # const 有助于编译器

function sum_global()
    s = 0.0
    for i in 1:length(GLOBAL_DATA)
        s += GLOBAL_DATA[i]
    end
    return s
end

# ✅ 快:参数传递
function sum_local(data::Vector{Float64})
    s = 0.0
    for i in 1:length(data)
        s += data[i]
    end
    return s
end

# ✅ 使用 const 全局变量
const DATA = rand(1000)
function sum_const()
    s = 0.0
    for i in 1:length(DATA)
        s += DATA[i]
    end
    return s
end

@btime sum_local($DATA)   # 最快
@btime sum_const()        # 接近最快(const)
# sum_global()            # 较慢(非 const 全局变量)

💡 全局变量规则:避免使用非 const 全局变量。必须用时加上 const 声明。

19.4 内存分配优化

预分配输出数组

# ❌ 每次循环分配内存
function bad_multiply(A, B)
    result = A * B  # 分配新矩阵
    return result
end

# ✅ 预分配
function good_multiply!(C, A, B)
    mul!(C, A, B)  # 原地写入
    return C
end

A = rand(100, 100)
B = rand(100, 100)
C = similar(A)  # 预分配

@btime bad_multiply($A, $B)       # 有分配
@btime good_multiply!($C, $A, $B) # 无额外分配

避免不必要的数组创建

# ❌ 创建临时数组
function bad_sum_squares(x)
    return sum(x .^ 2)  # x.^2 创建新数组
end

# ✅ 使用生成器
function good_sum_squares(x)
    return sum(xi^2 for xi in x)
end

# ✅ 使用 mapreduce
function best_sum_squares(x)
    return mapreduce(xi -> xi^2, +, x)
end

x = rand(10000)
@btime bad_sum_squares($x)    # ~15μs(分配)
@btime good_sum_squares($x)   # ~5μs
@btime best_sum_squares($x)   # ~5μs

使用 bang 函数(!)

# Julia 约定:! 结尾的函数会修改参数

# 排序
data = rand(1000)
sorted = sort(data)      # 返回新数组(分配)
sort!(data)              # 原地排序(无分配)

# 去重
unique!(data)            # 原地去重

# 拼接
a = [1, 2, 3]
b = [4, 5, 6]
append!(a, b)            # 原地拼接

# 常见 bang 函数:push!, pop!, insert!, deleteat!, resize!, fill!, copyto!

19.5 视图 vs 拷贝

A = rand(1000, 1000)

# ❌ 切片创建拷贝
col = A[:, 1]            # 分配新向量
sub = A[1:100, 1:100]    # 分配新矩阵

# ✅ 视图不分配内存
col_view = @view A[:, 1]
sub_view = @view A[1:100, 1:100]

# @views 宏:批量替换
@views function process_matrix(A)
    row1 = A[1, :]       # 自动变为视图
    col1 = A[:, 1]       # 自动变为视图
    return row1' * col1
end

@btime $A[:, 1]             # ~7.9μs(分配)
@btime @view $A[:, 1]       # ~0.05μs(无分配)

⚠️ 视图注意事项:视图与原数组共享内存。修改视图会影响原数组。视图可能比拷贝慢(非连续内存)。

19.6 @inbounds / @simd / @fastmath

@inbounds:关闭边界检查

function sum_with_bounds(a)
    s = zero(eltype(a))
    for i in eachindex(a)
        s += a[i]  # 每次访问都检查边界
    end
    return s
end

function sum_no_bounds(a)
    s = zero(eltype(a))
    @inbounds for i in eachindex(a)
        s += a[i]  # 跳过边界检查
    end
    return s
end

a = rand(10000)
@btime sum_with_bounds($a)
@btime sum_no_bounds($a)   # 稍快

@simd:向量化提示

function simd_sum(a)
    s = zero(eltype(a))
    @simd for i in eachindex(a)
        s += a[i]
    end
    return s
end

# @simd 允许 SIMD 指令(如 AVX),提升吞吐量
# 注意:浮点运算可能产生微小误差(不同运算顺序)

@fastmath:激进数学优化

function precise_sum(a)
    s = 0.0
    for x in a
        s += x
    end
    return s
end

function fast_sum(a)
    s = 0.0
    @fastmath for x in a
        s += x
    end
    return s
end

# @fastmath 允许:
# - 重新排序运算
# - 使用近似函数
# - 假设 NaN/Inf 不会出现
# 代价:可能损失精度
作用风险
@inbounds跳过边界检查越界导致崩溃
@simd向量化优化浮点结果微变
@fastmath激进数学优化精度降低

19.7 @inline 与 @noinline

@inline:内联函数

# 小函数建议内联
@inline function fast_add(a, b)
    return a + b
end

# 内联消除了函数调用开销
# Julia 编译器通常会自动内联小函数
# 但在复杂场景中手动提示有帮助

@noinline:禁止内联

# 大函数不应内联(增加代码体积)
@noinline function large_setup()
    # 大量初始化代码
    # ...
end

# 冷路径代码
@noinline function handle_error(msg)
    error(msg)
end

💡 内联原则:让编译器自动决定。只在 @code_warntype 或性能分析表明函数调用是瓶颈时手动使用 @inline

19.8 函数屏障(Function Barrier)

函数屏障通过将类型不稳定的代码隔离到单独的函数中,确保核心计算类型稳定。

# ❌ 无屏障:类型不稳定影响整个循环
function process_bad(dict::Dict{String, Any})
    for (key, val) in dict
        result = val * 2  # val 是 Any,类型不稳定
        println("$key: $result")
    end
end

# ✅ 函数屏障:核心计算类型稳定
@inline compute(x::Int) = x * 2
@inline compute(x::Float64) = x * 2.0
@inline compute(x::String) = x * 2

function process_good(dict::Dict{String, Any})
    for (key, val) in dict
        result = compute(val)  # 编译器知道 val 的具体类型
        println("$key: $result")
    end
end

实际应用:绘图库

# 类型不稳定的输入
draw(shape::Circle) = _draw_impl(shape)
draw(shape::Rectangle) = _draw_impl(shape)
draw(shape::Triangle) = _draw_impl(shape)

# 每种形状有类型稳定的实现
function _draw_impl(c::Circle)
    # 快速:所有类型已知
    for angle in 0:360
        x = c.center[1] + c.radius * cosd(angle)
        y = c.center[2] + c.radius * sind(angle)
        plot_pixel(x, y)
    end
end

19.9 并行性能

线程设置

# 启动时指定线程数
# JULIA_NUM_THREADS=8 julia

println("线程数: ", Threads.nthreads())

Threads.@threads

function parallel_sum(data)
    n = length(data)
    results = zeros(Threads.nthreads())
    
    Threads.@threads for i in 1:n
        tid = Threads.threadid()
        results[tid] += data[i]
    end
    
    return sum(results)
end

data = rand(10_000_000)
@btime parallel_sum($data)
@btime sum($data)  # 内置函数通常更快

@spawn 细粒度并行

function parallel_map(f, data; ntasks=Threads.nthreads())
    n = length(data)
    results = Vector{eltype(data)}(undef, n)
    chunk_size = cld(n, ntasks)
    
    tasks = map(1:ntasks) do t
        Threads.@spawn begin
            start = (t-1) * chunk_size + 1
            stop = min(t * chunk_size, n)
            for i in start:stop
                results[i] = f(data[i])
            end
        end
    end
    
    fetch.(tasks)
    return results
end

19.10 性能分析 Profile.jl

基本分析

using Profile

# 要分析的代码
function workload()
    data = rand(1000, 1000)
    result = data * data'
    eigvals(result)
end

# 预热
workload()

# 采样分析
Profile.clear()
@profile workload()

# 查看结果
Profile.print()

# 保存为火焰图
# using ProfileSVG
# ProfileSVG.save("profile.svg")

分析结果解读

# Profile.print() 输出解读:
# - 第一列:该函数及其子函数占用的时间比例
# - 第二列:该函数自身占用的时间
# - 第三列:代码位置

# 关注自身占用高的函数(优化目标)

19.11 实战优化案例

案例:矩阵行标准化

# 版本1:朴素实现(慢)
function normalize_rows_bad(A)
    n, m = size(A)
    result = zeros(n, m)
    for i in 1:n
        row_sum = 0.0
        for j in 1:m
            row_sum += A[i, j]
        end
        for j in 1:m
            result[i, j] = A[i, j] / row_sum
        end
    end
    return result
end

# 版本2:预分配 + @inbounds
function normalize_rows_better(A)
    n, m = size(A)
    result = similar(A)
    @inbounds for i in 1:n
        row_sum = 0.0
        for j in 1:m
            row_sum += A[i, j]
        end
        inv_sum = 1.0 / row_sum
        for j in 1:m
            result[i, j] = A[i, j] * inv_sum
        end
    end
    return result
end

# 版本3:向量化
function normalize_rows_best(A)
    sums = sum(A; dims=2)
    return A ./ sums
end

A = rand(1000, 1000)
@btime normalize_rows_bad($A)
@btime normalize_rows_better($A)
@btime normalize_rows_best($A)

案例:字符串解析

# ❌ 慢:正则 + 字符串操作
function parse_data_slow(lines)
    results = []
    for line in lines
        m = match(r"(\w+):(\d+)", line)
        if m !== nothing
            push!(results, (m.captures[1], parse(Int, m.captures[2])))
        end
    end
    return results
end

# ✅ 快:手动解析 + 类型标注
function parse_data_fast(lines::Vector{String})
    results = Tuple{String, Int}[]
    for line in lines
        colon_idx = findfirst(':', line)
        colon_idx === nothing && continue
        name = line[1:colon_idx-1]
        value = tryparse(Int, line[colon_idx+1:end])
        value === nothing && continue
        push!(results, (name, value))
    end
    return results
end

lines = ["alice:100", "bob:200", "charlie:300"]
@btime parse_data_slow($lines)
@btime parse_data_fast($lines)

优化检查清单

步骤检查项工具
1基准测试@btime
2类型稳定性@code_warntype
3内存分配@allocated
4热点函数Profile.@profile
5全局变量检查是否 const
6预分配similar / zeros
7视图@view / @views
8边界检查@inbounds
9向量化@simd
10并行化Threads.@threads

19.12 性能反模式

反模式问题修复
非 const 全局变量类型不确定const 或传参数
容器类型 Any无法特化指定具体类型
条件返回不同类型Union 类型统一返回类型
频繁小分配GC 压力预分配 / bang 函数
字符串拼接不可变,每次分配IOBuffer / join
过度抽象失去特化机会函数屏障

19.13 扩展阅读

资源链接
Julia 官方 - Performance Tipshttps://docs.julialang.org/en/v1/manual/performance-tips/
BenchmarkTools.jlhttps://github.com/JuliaCI/BenchmarkTools.jl
Profile.jlhttps://docs.julialang.org/en/v1/manual/profile/
JET.jl (静态分析)https://github.com/aviatesk/JET.jl
PProf.jl (火焰图)https://github.com/JuliaPerf/PProf.jl

19.14 本章小结

主题要点
基准测试@btime 而非 @time
类型稳定@code_warntype 检查,避免 Union
全局变量避免或用 const
内存分配预分配、视图、bang 函数
边界优化@inbounds / @simd / @fastmath
函数屏障隔离类型不稳定代码
分析工具Profile.jl 定位热点