强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧.
文档目录

Julia 教程 / Julia 多重派发深入

11. 多重派发深入

多重派发(Multiple Dispatch)是 Julia 最核心的特性之一,它使得 Julia 能够以接近 C 的速度运行,同时保持高级语言的灵活性。

11.1 什么是多重派发

在传统面向对象语言中,方法调用根据一个对象的类型来决定(单派发)。Julia 则根据所有参数的类型来选择最匹配的方法。

# 单派发 (Python/Java 风格)
# obj.method(args)  → 仅根据 obj 的类型选择方法

# 多重派发 (Julia 风格)
# func(a, b)  → 根据 a 和 b 的类型选择方法

基本示例

# 定义不同类型之间的交互
struct Rock end
struct Paper end
struct Scissors end

# 所有可能的组合都定义方法
play(::Rock, ::Rock) = "平局"
play(::Rock, ::Paper) = "纸胜"
play(::Rock, ::Scissors) = "石胜"
play(::Paper, ::Paper) = "平局"
play(::Paper, ::Rock) = "纸胜"
play(::Paper, ::Scissors) = "剪刀胜"
play(::Scissors, ::Scissors) = "平局"
play(::Scissors, ::Rock) = "石胜"
play(::Scissors, ::Paper) = "剪刀胜"

println(play(Rock(), Paper()))    # 纸胜
println(play(Scissors(), Rock())) # 石胜

11.2 方法优先级与歧义解决

Julia 选择方法时遵循最具体优先原则。当两个方法都匹配时,更具体的胜出。

f(x::Int, y::Int) = "两个都是 Int"
f(x::Number, y::Number) = "两个都是 Number"
f(x::Any, y::Any) = "任意类型"

println(f(1, 2))       # 两个都是 Int
println(f(1.0, 2.0))   # 两个都是 Number
println(f("a", "b"))   # 任意类型

歧义问题

g(::Int, ::Float64) = "Int, Float64"
g(::Float64, ::Int) = "Float64, Int"

# 如果传入 (Int, Int),两个方法都不比另一个更具体
# g(1, 2)  # ❌ MethodError: ambiguous

解决歧义

# 方法1:添加更具体的方法
g(::Int, ::Int) = "两个都是 Int"

println(g(1, 2))  # 两个都是 Int

# 方法2:使用 Union 类型
h(::Union{Int, Float64}, ::Union{Int, Float64}) = "数值组合"
h(::Int, ::Float64) = "Int, Float64"
h(::Float64, ::Int) = "Float64, Int"

⚠️ 歧义警告:当 Julia 遇到歧义时会抛出 MethodError,必须手动解决。避免设计容易产生歧义的类型层次。

11.3 调试工具:@which 与 @code_warntype

@which:查看方法选择

println(@which 1 + 2)        # +(x::Int64, y::Int64) in Base at int.jl:87
println(@which 1.0 + 2.0)    # +(x::Float64, y::Float64) in Base at float.jl:406
println(@which [1,2] + [3,4]) # +(x::AbstractArray, y::AbstractArray)

@code_warntype:类型稳定性检查

function unstable_example(x)
    if x > 0
        return x
    else
        return "negative"
    end
end

# 红色标注表示类型不稳定
@code_warntype unstable_example(1)

function stable_example(x::Float64)
    if x > 0
        return x
    else
        return -x
    end
end

@code_warntype stable_example(1.0)  # 全绿,类型稳定

methods 函数

# 查看某个函数的所有方法
println(methods(+))   # 列出 + 的所有定义
println(methods(sin)) # 列出 sin 的所有定义

💡 调试技巧:写新方法后,用 @which 验证是否选择了预期的方法。用 @code_warntype 检查类型稳定性。

11.4 抽象类型派发

abstract type Shape end
struct Circle <: Shape
    radius::Float64
end
struct Rectangle <: Shape
    width::Float64
    height::Float64
end

# 对抽象类型定义方法
area(s::Shape) = error("请为 $(typeof(s)) 实现 area 方法")
area(c::Circle) = π * c.radius^2
area(r::Rectangle) = r.width * r.height

# 更高抽象层次的函数
perimeter(s::Circle) = 2π * s.radius
perimeter(r::Rectangle) = 2(r.width + r.height)

# 通用函数,接受任意 Shape
describe(s::Shape) = "$(typeof(s)) 的面积为 $(area(s))"

println(describe(Circle(5.0)))      # Circle 的面积为 78.53981633974483
println(describe(Rectangle(3.0, 4.0))) # Rectangle 的面积为 12.0

11.5 参数化类型派发

# 对容器类型进行派发
function sum_first(container::Vector{T}) where T <: Number
    return container[1]
end

function sum_first(container::Vector{Vector{T}}) where T
    return sum(v[1] for v in container)
end

println(sum_first([1, 2, 3]))           # 1
println(sum_first([[1, 2], [3, 4]]))    # 4

元组类型派发

# 对元组进行模式匹配式派发
process(t::Tuple{Int, String}) = "Int-String: $(t[1]) $(t[2])"
process(t::Tuple{String, Int}) = "String-Int: $(t[1]) $(t[2])"
process(t::Tuple{Vararg{Int}}) = "全 Int 元组: $t"

println(process((1, "hello")))   # Int-String: 1 hello
println(process(("hello", 1)))   # String-Int: hello 1
println(process((1, 2, 3)))      # 全 Int 元组: (1, 2, 3)

11.6 Val 分发(编译期计算)

Val 是一种将值提升为类型的技术,实现编译期分发。

# 运行时根据值选择不同实现
# ❌ 低效:运行时分支
function compute_slow(x, mode)
    if mode == :sin
        return sin(x)
    elseif mode == :cos
        return cos(x)
    end
end

# ✅ 高效:编译期分发
compute_fast(x, ::Val{:sin}) = sin(x)
compute_fast(x, ::Val{:cos}) = cos(x)
compute_fast(x, ::Val{:tan}) = tan(x)

# 调用时使用 Val 包装
println(compute_fast(π/4, Val(:sin)))  # 0.7071067811865475
println(compute_fast(π/4, Val(:cos)))  # 0.7071067811865476

实际应用:维度感知的数组操作

# 根据维度参数选择不同的实现
function reduce_along(A::AbstractArray, ::Val{dims}) where dims
    return sum(A; dims=dims)
end

A = rand(3, 4, 5)
println(size(reduce_along(A, Val(1))))  # (1, 4, 5)
println(size(reduce_along(A, Val(2))))  # (3, 1, 5)
println(size(reduce_along(A, Val(3))))  # (3, 4, 1)

💡 Val 使用原则:仅在编译期已知的小值(如符号、整数常量)上使用 Val。不要对运行时变量使用 Val,否则会丧失灵活性。

11.7 多重派发设计模式

模式一:类型分层

abstract type Transport end
struct Car <: Transport
    fuel_efficiency::Float64  # km/L
end
struct Train <: Transport
    passenger_count::Int
end
struct Airplane <: Transport
    range_km::Float64
end

# 每种交通工具有自己的碳排放计算
carbon_footprint(t::Car, distance_km) = distance_km / t.fuel_efficiency * 2.3
carbon_footprint(t::Train, distance_km) = distance_km * 0.041 / t.passenger_count
carbon_footprint(t::Airplane, distance_km) = distance_km * 0.255

# 通用比较函数
function greener(t1::Transport, t2::Transport, distance_km)
    cf1 = carbon_footprint(t1, distance_km)
    cf2 = carbon_footprint(t2, distance_km)
    return cf1 < cf2 ? t1 : t2
end

car = Car(15.0)
train = Train(200)
plane = Airplane(5000.0)

println("更环保: ", typeof(greener(car, train, 100.0)))

模式二:双分发实现二元操作

# 矩阵-向量乘法的多重派发
struct DenseMatrix
    data::Matrix{Float64}
end
struct SparseMatrix
    values::Vector{Float64}
    rowind::Vector{Int}
    colptr::Vector{Int}
    size::Tuple{Int,Int}
end

# 密集矩阵 × 向量
multiply(A::DenseMatrix, x::Vector{Float64}) = A.data * x

# 稀疏矩阵 × 向量
function multiply(A::SparseMatrix, x::Vector{Float64})
    m, n = A.size
    y = zeros(m)
    for j in 1:n
        for idx in A.colptr[j]:A.colptr[j+1]-1
            i = A.rowind[idx]
            y[i] += A.values[idx] * x[j]
        end
    end
    return y
end

11.8 Visitor 模式 vs 多重派发

特性Visitor 模式 (OOP)多重派发 (Julia)
添加新类型需修改所有 Visitor直接定义新方法
添加新操作定义新 Visitor直接定义新函数
分发依据单个对象类型所有参数类型
代码组织类内部 + Visitor独立函数
类型检查运行时编译时

Visitor 模式的 Julia 等价

# 不需要 Visitor 类,多重派发天然支持
abstract type ExprNode end
struct Num <: ExprNode
    value::Float64
end
struct Add <: ExprNode
    left::ExprNode
    right::ExprNode
end
struct Mul <: ExprNode
    left::ExprNode
    right::ExprNode
end

# 求值
evaluate(n::Num) = n.value
evaluate(a::Add) = evaluate(a.left) + evaluate(a.right)
evaluate(m::Mul) = evaluate(m.left) * evaluate(m.right)

# 打印
to_string(n::Num) = string(n.value)
to_string(a::Add) = "($(to_string(a.left)) + $(to_string(a.right)))"
to_string(m::Mul) = "($(to_string(m.left)) * $(to_string(m.right)))"

expr = Add(Num(1), Mul(Num(2), Num(3)))
println(to_string(expr))  # (1.0 + (2.0 * 3.0))
println(evaluate(expr))    # 7.0

11.9 实际案例:图形渲染系统

abstract type Renderable end
abstract type Color end
struct Red <: Color end
struct Blue <: Color end

struct Pixel <: Renderable
    x::Int
    y::Int
    color::Color
end
struct Line <: Renderable
    start::Pixel
    stop::Pixel
end
struct Circle <: Renderable
    center::Pixel
    radius::Float64
end

# 渲染到不同后端
abstract type RenderBackend end
struct OpenGL <: RenderBackend end
struct Vulkan <: RenderBackend end

# 双重分发:形状 × 后端
render(::OpenGL, p::Pixel) = "OpenGL: 绘制像素 ($(p.x), $(p.y))"
render(::Vulkan, p::Pixel) = "Vulkan: 绘制像素 ($(p.x), $(p.y))"
render(::OpenGL, l::Line) = "OpenGL: 绘制线段"
render(::Vulkan, l::Line) = "Vulkan: 绘制线段(优化)"
render(::OpenGL, c::Circle) = "OpenGL: 绘制圆形 r=$(c.radius)"
render(::Vulkan, c::Circle) = "Vulkan: 绘制圆形(GPU加速)"

# 批量渲染
function render_all(backend::RenderBackend, objects::Vector{<:Renderable})
    for obj in objects
        render(backend, obj)
    end
end

objects = [
    Pixel(10, 20, Red()),
    Line(Pixel(0,0,Red()), Pixel(100,100,Blue())),
    Circle(Pixel(50,50,Red()), 25.0)
]

render_all(OpenGL(), objects)

11.10 实际案例:物理模拟碰撞检测

abstract type Shape2D end
struct Disk <: Shape2D
    center::Tuple{Float64, Float64}
    radius::Float64
end
struct AABB <: Shape2D  # 轴对齐包围盒
    min::Tuple{Float64, Float64}
    max::Tuple{Float64, Float64}
end

distance(a::Tuple, b::Tuple) = sqrt((a[1]-b[1])^2 + (a[2]-b[2])^2)

# Disk-Disk 碰撞
collides(a::Disk, b::Disk) = distance(a.center, b.center) < a.radius + b.radius

# AABB-AABB 碰撞
function collides(a::AABB, b::AABB)
    a.min[1] < b.max[1] && a.max[1] > b.min[1] &&
    a.min[2] < b.max[2] && a.max[2] > b.min[2]
end

# Disk-AABB 碰撞(交叉类型)
function collides(d::Disk, b::AABB)
    cx = clamp(d.center[1], b.min[1], b.max[1])
    cy = clamp(d.center[2], b.min[2], b.max[2])
    distance(d.center, (cx, cy)) < d.radius
end
collides(b::AABB, d::Disk) = collides(d, b)  # 利用对称性

d1 = Disk((0.0, 0.0), 5.0)
d2 = Disk((3.0, 0.0), 3.0)
box = AABB((2.0, 2.0), (8.0, 8.0))

println("Disk-Disk: ", collides(d1, d2))   # true
println("Disk-Box:  ", collides(d1, box))   # true

11.11 扩展阅读

资源链接
Julia 官方文档 - Methodshttps://docs.julialang.org/en/v1/manual/methods/
Julia 官方文档 - Typeshttps://docs.julialang.org/en/v1/manual/types/
《Julia 高性能编程》 - 派发章节O’Reilly 出版
多重派发设计模式https://www.youtube.com/watch?v=kc9HwsxE1pY
Stefan Karpinski 的派发演示https://www.youtube.com/watch?v=8Pgj9cCBq1s

11.12 本章小结

主题要点
方法选择最具体优先,有歧义需手动解决
调试工具@which 看选择,@code_warntype 看类型
Val 分发编译期值→类型转换,消除运行时分支
设计模式类型分层、双分发、表达式求值
适用场景二元操作、渲染系统、物理引擎