强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧.
文档目录

Julia 教程 / 宏实战:DSL 构建

宏实战:DSL 构建

宏(Macro)是 Julia 元编程的核心工具。本文深入探讨宏的高级用法,展示如何利用宏构建领域特定语言(DSL),实现编译期代码生成和优化。


1. 宏展开与 @macroexpand

1.1 理解宏展开

# 宏在编译期将代码转换为另一种代码
macro sayhello(name)
    return :(println("Hello, ", $name, "!"))
end

@sayhello "World"
# 展开后等价于: println("Hello, ", "World", "!")

1.2 使用 @macroexpand 查看展开结果

# 查看宏展开后的代码
ex = @macroexpand @sayhello "Julia"
println(ex)
# :((println)("Hello, ", "Julia", "!"))

# 更复杂的宏
macro timed(expr)
    return quote
        local t0 = time()
        local val = $expr
        local t1 = time()
        println("耗时: ", t1 - t0, " 秒")
        val
    end
end

ex = @macroexpand @timed sum(1:1000)
# 展开后会看到完整的 time() 调用和计时逻辑

1.3 宏展开步骤

源代码 → 解析为 AST → 宏展开 → Lowering → 类型推断 → 代码生成
                  ↑
            宏在此阶段工作

2. 表达式构造

2.1 quote 与 :( )

# quote 块创建带行号的表达式
ex = quote
    x = 1
    y = x + 2
    println(y)
end

# :( ) 创建单行表达式(无行号)
ex = :(x + y * z)

# 等价于
ex = Expr(:call, :+, :x, Expr(:call, :*, :y, :z))

2.2 Expr 对象详解

# 手动构造 Expr
ex = Expr(:call, :sin, :x)
# 等价于 :(sin(x))

# 函数定义
ex = Expr(:function,
    Expr(:call, :foo, :x),           # 函数签名
    Expr(:block,                      # 函数体
        Expr(:(=), :y, Expr(:call, :*, :x, :x)),
        :y
    )
)
# 等价于 :(function foo(x); y = x * x; y; end)

# 条件表达式
ex = Expr(:if, :(x > 0), :(x), :(-x))
# 等价于 :(if x > 0; x; else; -x; end)

2.3 在表达式中插值

# 在 quote 中使用 $ 插值
function make_assignment(name, value)
    :($name = $value)
end

ex = make_assignment(:x, 42)
# :(x = 42)

# 插值表达式
body = :(x^2 + 1)
ex = :(function f(x); return $body; end)

2.4 splatting 展开

# 在宏中使用 ... 展开参数
macro call_many(f, args...)
    return :($f($(args...)))
end

@call_many sin 1.0 2.0 3.0
# :(sin(1.0, 2.0, 3.0))

3. 符号操作与遍历

3.1 遍历表达式树

# 递归遍历 AST
function walk_expr(f, ex::Expr)
    f(ex)  # 对当前节点调用 f
    for arg in ex.args
        if arg isa Expr
            walk_expr(f, arg)
        end
    end
end

# 收集所有符号
function collect_symbols(ex::Expr)
    symbols = Symbol[]
    walk_expr(ex) do node
        if node isa Symbol
            push!(symbols, node)
        elseif node isa Expr && node.head == :(::)
            # 类型标注
            push!(symbols, node.args[1])
        end
    end
    return unique(symbols)
end

ex = :(function foo(x::Int, y::Float64)
    z = x + y
    return z * 2
end)

collect_symbols(ex)  # [:foo, :x, :y, :z]

3.2 MacroTools.jl

using MacroTools

# 模式匹配
ex = :(f(x) = x + 1)

@capture(ex, f_(args__) = body_)
# f = :f, args = [:x], body = :(x + 1)

# 表达式替换
postwalk(ex) do x
    if x === :x
        return :y  # 把所有 x 替换为 y
    end
    return x
end
# :(f(y) = y + 1)

# rmlines 移除行号节点
clean_ex = rmlines(:(begin
    x = 1
    y = 2
end))

3.3 自定义遍历函数

# prewalk: 先处理父节点,再处理子节点
# postwalk: 先处理子节点,再处理父节点

function replace_op(ex::Expr, old_op::Symbol, new_op::Symbol)
    MacroTools.postwalk(ex) do x
        if x isa Expr && x.head == :call && x.args[1] == old_op
            x.args[1] = new_op
        end
        return x
    end
end

# 把 + 替换为 *
replace_op(:(a + b * c), :+, :*)
# :(a * b * c)

4. 自定义 DSL

4.1 查询 DSL

# SQL-like 查询 DSL
macro query(source, condition)
    esc(quote
        filter(row -> $condition, $source)
    end)
end

# 使用
data = [
    (name="Alice", age=25, city="Beijing"),
    (name="Bob", age=30, city="Shanghai"),
    (name="Charlie", age=35, city="Beijing"),
    (name="Diana", age=28, city="Guangzhou"),
]

result = @query data begin
    row.age > 25 && row.city == "Beijing"
end
# [(name="Charlie", age=35, city="Beijing")]

4.2 配置 DSL

# 配置文件 DSL
macro config(ex)
    # 将配置块转换为 Dict
    items = Pair{String,Any}[]
    
    for arg in ex.args
        if arg isa Expr && arg.head == :(=)
            key = string(arg.args[1])
            val = arg.args[2]
            push!(items, key => val)
        end
    end
    
    esc(:(Dict($([:($k => $v) for (k, v) in items]...))))
end

cfg = @config begin
    host = "localhost"
    port = 8080
    debug = true
    max_connections = 100
end
# Dict("host" => "localhost", "port" => 8080, ...)

4.3 数值计算 DSL

# 向量化运算 DSL
macro vec(expr)
    # 自动将标量运算转换为向量化
    esc(MacroTools.postwalk(expr) do x
        if x isa Expr && x.head == :call
            func = x.args[1]
            # 用广播替换普通调用
            return :($func.($(x.args[2:end]...)))
        end
        return x
    end)
end

x = [1.0, 2.0, 3.0]
y = [4.0, 5.0, 6.0]

result = @vec sin(x) + cos(y) * 2
# :(sin.(x) .+ cos.(y) .* 2)

5. @generated 函数

5.1 基本用法

# @generated 在编译期根据参数类型生成代码
@generated function type_size(::T) where T
    if isbitstype(T)
        return :(sizeof(T))
    else
        return :(sizeof(Ptr{Nothing}))
    end
end

type_size(Int64)      # 8
type_size(Float32)    # 4
type_size(String)     # 8(指针大小)

5.2 生成特化代码

@generated function dot_product(a::NTuple{N,T}, b::NTuple{N,T}) where {N,T}
    # 编译期展开循环
    expr = :(+)
    for i in 1:N
        expr = :($expr + a[$i] * b[$i])
    end
    return expr
end

dot_product((1, 2, 3), (4, 5, 6))
# 编译期展开为: 1*4 + 2*5 + 3*6 = 32

5.3 @generated 与普通宏的区别

特性@generated
执行时机解析时编译时(类型确定后)
输入AST类型信息
输出ASTAST
特化是(每种类型组合一次)
用途代码变换类型驱动的代码生成

6. 编译期计算

6.1 Val{} 模式

# 使用 Val 将值提升到类型层面
function process(::Val{:fast}, data)
    # 快速路径
    return sum(data)
end

function process(::Val{:accurate}, data)
    # 精确路径
    return sum(big.(data))
end

process(Val(:fast), [1, 2, 3])      # 使用快速路径
process(Val(:accurate), [1, 2, 3])  # 使用精确路径

6.2 编译期消除分支

# 条件编译
@generated function compute(::Val{mode}, x) where mode
    if mode == :add
        return :(x + 1)
    elseif mode == :mul
        return :(x * 2)
    else
        return :(x)
    end
end

# 运行时无分支开销
compute(Val(:add), 5)  # 6
compute(Val(:mul), 5)  # 10

6.3 类型级编程

# 编译期计算斐波那契数列
struct Fibonacci{N} end

fib(::Fibonacci{0}) = 0
fib(::Fibonacci{1}) = 1
fib(::Fibonacci{N}) where N = fib(Fibonacci{N-1}()) + fib(Fibonacci{N-2}())

# 在编译期求值
const fib_10 = Fibonacci{10}()
# fib(fib_10) = 55,这个计算在编译期完成

7. 宏组合子

7.1 宏叠加

# 将多个宏组合为一个
macro compose(ex, macros...)
    result = ex
    for m in reverse(macros)
        result = :($m $result)
    end
    return esc(result)
end

# 使用
@compose :(x = rand()) @assert @inbounds
# 先应用 @inbounds,再应用 @assert

7.2 条件宏

# 根据条件选择不同的代码生成
macro optimize(flag, expr)
    if flag == :fast
        esc(:(@fastmath $expr))
    elseif flag == :simd
        esc(:(@simd $expr))
    else
        esc(expr)
    end
end

@optimize :fast sin(x) + cos(y)
@optimize :simd for i in 1:n
    a[i] = b[i] + c[i]
end

7.3 装饰器宏

# 类似 Python 装饰器的宏
macro memoize(ex)
    # 给函数添加缓存
    MacroTools.isdef(ex) || error("需要函数定义")
    
    fname = ex.args[1].args[1]
    fargs = ex.args[1].args[2:end]
    body = ex.args[2]
    
    cache_name = Symbol(:_cache_, fname)
    
    esc(quote
        const $cache_name = Dict()
        
        function $fname($(fargs...))
            key = ($(fargs...),)
            if haskey($cache_name, key)
                return $cache_name[key]
            end
            result = begin $body end
            $cache_name[key] = result
            return result
        end
    end)
end

@memoize function fibonacci(n::Int)
    if n <= 1
        return n
    end
    return fibonacci(n-1) + fibonacci(n-2)
end

fibonacci(100)  # 瞬间返回,带缓存

8. 宏测试策略

8.1 测试宏展开

using Test

macro my_macro(x)
    :($x + 1)
end

@testset "my_macro" begin
    # 测试展开结果
    ex = @macroexpand @my_macro 5
    @test ex == :(5 + 1)
    
    # 测试运行结果
    @test @my_macro(5) == 6
    @test @my_macro(0) == 1
    @test @my_macro(-1) == 0
end

8.2 测试生成的代码

macro make_struct(name, fields...)
    # 生成结构体定义
    field_defs = [:($(esc(f.args[1]))::$(esc(f.args[2]))) for f in fields]
    
    esc(quote
        struct $name
            $(field_defs...)
        end
    end)
end

@testset "make_struct" begin
    @macroexpand @make_struct Point x::Float64 y::Float64
    
    # 实际测试
    @make_struct TestPoint x::Float64 y::Float64
    p = TestPoint(1.0, 2.0)
    @test p.x == 1.0
    @test p.y == 2.0
end

8.3 测试边界条件

@testset "宏边界条件" begin
    # 空输入
    @test @my_macro(0) == 1
    
    # 类型测试
    @test @my_macro(1.0) == 2.0
    @test @my_macro(big"1") == big"2"
    
    # 错误输入
    @test_throws MethodError @my_macro("string")
end

9. 实际案例:构建 SQL 查询 DSL

module SQLDSL

export @sql, @select, @from, @where, @order, @limit

# 查询结构体
struct SQLQuery
    select::Vector{Symbol}
    from::Symbol
    where::Vector{String}
    order::Vector{Pair{Symbol,Symbol}}
    limit::Union{Int,Nothing}
end

SQLQuery(; select=Symbol[], from=:_, where=String[],
         order=Pair{Symbol,Symbol}[], limit=nothing) =
    SQLQuery(select, from, where, order, limit)

# @from 宏 — 指定表
macro from(source)
    :(SQLQuery(from=$(QuoteNode(source))))
end

# @select 宏 — 选择列
macro select(q, cols...)
    quote
        let query = $q
            SQLQuery(
                select=Symbol[$(QuoteNode.(cols)...)],
                from=query.from,
                where=query.where,
                order=query.order,
                limit=query.limit
            )
        end
    end |> esc
end

# @where 宏 — 添加条件
macro where(q, condition)
    quote
        let query = $q
            SQLQuery(
                select=query.select,
                from=query.from,
                where=vcat(query.where, string($(QuoteNode(condition)))),
                order=query.order,
                limit=query.limit
            )
        end
    end |> esc
end

# @order 宏 — 排序
macro order(q, col, dir=:asc)
    quote
        let query = $q
            SQLQuery(
                select=query.select,
                from=query.from,
                where=query.where,
                order=vcat(query.order, $(QuoteNode(col)) => $(QuoteNode(dir))),
                limit=query.limit
            )
        end
    end |> esc
end

# @limit 宏 — 限制行数
macro limit(q, n)
    quote
        let query = $q
            SQLQuery(
                select=query.select,
                from=query.from,
                where=query.where,
                order=query.order,
                limit=$n
            )
        end
    end |> esc
end

# 生成 SQL 字符串
function to_sql(q::SQLQuery)
    parts = String[]
    
    # SELECT
    if isempty(q.select)
        push!(parts, "SELECT *")
    else
        push!(parts, "SELECT " * join(string.(q.select), ", "))
    end
    
    # FROM
    push!(parts, "FROM $(q.from)")
    
    # WHERE
    if !isempty(q.where)
        push!(parts, "WHERE " * join(q.where, " AND "))
    end
    
    # ORDER BY
    if !isempty(q.order)
        order_clauses = ["$(col) $(dir)" for (col, dir) in q.order]
        push!(parts, "ORDER BY " * join(order_clauses, ", "))
    end
    
    # LIMIT
    if q.limit !== nothing
        push!(parts, "LIMIT $(q.limit)")
    end
    
    return join(parts, " ")
end

# @sql 宏 — 完整的链式查询
macro sql(ex)
    esc(ex)
end

end

# 使用
using .SQLDSL

q = @from users
q = @select q name age city
q = @where q age > 18
q = @where q city == "Beijing"
q = @order q age desc
q = @limit q 10

to_sql(q)
# SELECT name, age, city FROM users WHERE age > 18 AND city == Beijing ORDER BY age desc LIMIT 10

10. 宏最佳实践

实践说明
使用 esc()避免变量名冲突,除非性需要宏卫生
最小化宏作用域宏只做代码变换,复杂逻辑放在普通函数中
充分测试使用 @macroexpand 测试展开结果
提供清晰报错在宏中使用 error() 提供有意义的错误信息
文档化记录宏的输入输出和预期行为
避免过度宏化能用函数解决的问题不要用宏
# ❌ 错误:宏中没有使用 esc
macro bad_macro(name)
    :(println("Hello, ", $name))  # name 可能发生冲突
end

# ✅ 正确:使用 esc 保护用户代码
macro good_macro(name)
    esc(:(println("Hello, ", $name)))
end

业务场景

场景一:测试框架 DSL

构建自定义测试框架,使用宏提供类似 @test, @assert 的简洁语法。宏在编译期注入文件名和行号信息,用于定位测试失败。

场景二:ORM 查询生成

为数据库操作构建 SQL DSL,使用链式宏调用构建查询。编译期检查表名和列名的合法性。

场景三:配置管理

使用 DSL 宏定义应用配置,编译期验证配置项类型,生成类型安全的访问器。


总结

主题关键要点
宏展开@macroexpand 查看展开结果
表达式构造quote / :( ) / Expr()
符号操作MacroTools.jlpostwalk / @capture
DSL宏可以创建自然的领域特定语法
@generated基于类型信息编译期生成代码
宏测试使用 @macroexpand 和运行时测试
宏卫生使用 esc() 避免变量捕获

扩展阅读