Julia 教程 / 宏实战:DSL 构建
宏实战:DSL 构建
宏(Macro)是 Julia 元编程的核心工具。本文深入探讨宏的高级用法,展示如何利用宏构建领域特定语言(DSL),实现编译期代码生成和优化。
1. 宏展开与 @macroexpand
1.1 理解宏展开
# 宏在编译期将代码转换为另一种代码
macro sayhello(name)
return :(println("Hello, ", $name, "!"))
end
@sayhello "World"
# 展开后等价于: println("Hello, ", "World", "!")
1.2 使用 @macroexpand 查看展开结果
# 查看宏展开后的代码
ex = @macroexpand @sayhello "Julia"
println(ex)
# :((println)("Hello, ", "Julia", "!"))
# 更复杂的宏
macro timed(expr)
return quote
local t0 = time()
local val = $expr
local t1 = time()
println("耗时: ", t1 - t0, " 秒")
val
end
end
ex = @macroexpand @timed sum(1:1000)
# 展开后会看到完整的 time() 调用和计时逻辑
1.3 宏展开步骤
源代码 → 解析为 AST → 宏展开 → Lowering → 类型推断 → 代码生成
↑
宏在此阶段工作
2. 表达式构造
2.1 quote 与 :( )
# quote 块创建带行号的表达式
ex = quote
x = 1
y = x + 2
println(y)
end
# :( ) 创建单行表达式(无行号)
ex = :(x + y * z)
# 等价于
ex = Expr(:call, :+, :x, Expr(:call, :*, :y, :z))
2.2 Expr 对象详解
# 手动构造 Expr
ex = Expr(:call, :sin, :x)
# 等价于 :(sin(x))
# 函数定义
ex = Expr(:function,
Expr(:call, :foo, :x), # 函数签名
Expr(:block, # 函数体
Expr(:(=), :y, Expr(:call, :*, :x, :x)),
:y
)
)
# 等价于 :(function foo(x); y = x * x; y; end)
# 条件表达式
ex = Expr(:if, :(x > 0), :(x), :(-x))
# 等价于 :(if x > 0; x; else; -x; end)
2.3 在表达式中插值
# 在 quote 中使用 $ 插值
function make_assignment(name, value)
:($name = $value)
end
ex = make_assignment(:x, 42)
# :(x = 42)
# 插值表达式
body = :(x^2 + 1)
ex = :(function f(x); return $body; end)
2.4 splatting 展开
# 在宏中使用 ... 展开参数
macro call_many(f, args...)
return :($f($(args...)))
end
@call_many sin 1.0 2.0 3.0
# :(sin(1.0, 2.0, 3.0))
3. 符号操作与遍历
3.1 遍历表达式树
# 递归遍历 AST
function walk_expr(f, ex::Expr)
f(ex) # 对当前节点调用 f
for arg in ex.args
if arg isa Expr
walk_expr(f, arg)
end
end
end
# 收集所有符号
function collect_symbols(ex::Expr)
symbols = Symbol[]
walk_expr(ex) do node
if node isa Symbol
push!(symbols, node)
elseif node isa Expr && node.head == :(::)
# 类型标注
push!(symbols, node.args[1])
end
end
return unique(symbols)
end
ex = :(function foo(x::Int, y::Float64)
z = x + y
return z * 2
end)
collect_symbols(ex) # [:foo, :x, :y, :z]
3.2 MacroTools.jl
using MacroTools
# 模式匹配
ex = :(f(x) = x + 1)
@capture(ex, f_(args__) = body_)
# f = :f, args = [:x], body = :(x + 1)
# 表达式替换
postwalk(ex) do x
if x === :x
return :y # 把所有 x 替换为 y
end
return x
end
# :(f(y) = y + 1)
# rmlines 移除行号节点
clean_ex = rmlines(:(begin
x = 1
y = 2
end))
3.3 自定义遍历函数
# prewalk: 先处理父节点,再处理子节点
# postwalk: 先处理子节点,再处理父节点
function replace_op(ex::Expr, old_op::Symbol, new_op::Symbol)
MacroTools.postwalk(ex) do x
if x isa Expr && x.head == :call && x.args[1] == old_op
x.args[1] = new_op
end
return x
end
end
# 把 + 替换为 *
replace_op(:(a + b * c), :+, :*)
# :(a * b * c)
4. 自定义 DSL
4.1 查询 DSL
# SQL-like 查询 DSL
macro query(source, condition)
esc(quote
filter(row -> $condition, $source)
end)
end
# 使用
data = [
(name="Alice", age=25, city="Beijing"),
(name="Bob", age=30, city="Shanghai"),
(name="Charlie", age=35, city="Beijing"),
(name="Diana", age=28, city="Guangzhou"),
]
result = @query data begin
row.age > 25 && row.city == "Beijing"
end
# [(name="Charlie", age=35, city="Beijing")]
4.2 配置 DSL
# 配置文件 DSL
macro config(ex)
# 将配置块转换为 Dict
items = Pair{String,Any}[]
for arg in ex.args
if arg isa Expr && arg.head == :(=)
key = string(arg.args[1])
val = arg.args[2]
push!(items, key => val)
end
end
esc(:(Dict($([:($k => $v) for (k, v) in items]...))))
end
cfg = @config begin
host = "localhost"
port = 8080
debug = true
max_connections = 100
end
# Dict("host" => "localhost", "port" => 8080, ...)
4.3 数值计算 DSL
# 向量化运算 DSL
macro vec(expr)
# 自动将标量运算转换为向量化
esc(MacroTools.postwalk(expr) do x
if x isa Expr && x.head == :call
func = x.args[1]
# 用广播替换普通调用
return :($func.($(x.args[2:end]...)))
end
return x
end)
end
x = [1.0, 2.0, 3.0]
y = [4.0, 5.0, 6.0]
result = @vec sin(x) + cos(y) * 2
# :(sin.(x) .+ cos.(y) .* 2)
5. @generated 函数
5.1 基本用法
# @generated 在编译期根据参数类型生成代码
@generated function type_size(::T) where T
if isbitstype(T)
return :(sizeof(T))
else
return :(sizeof(Ptr{Nothing}))
end
end
type_size(Int64) # 8
type_size(Float32) # 4
type_size(String) # 8(指针大小)
5.2 生成特化代码
@generated function dot_product(a::NTuple{N,T}, b::NTuple{N,T}) where {N,T}
# 编译期展开循环
expr = :(+)
for i in 1:N
expr = :($expr + a[$i] * b[$i])
end
return expr
end
dot_product((1, 2, 3), (4, 5, 6))
# 编译期展开为: 1*4 + 2*5 + 3*6 = 32
5.3 @generated 与普通宏的区别
| 特性 | 宏 | @generated |
|---|---|---|
| 执行时机 | 解析时 | 编译时(类型确定后) |
| 输入 | AST | 类型信息 |
| 输出 | AST | AST |
| 特化 | 否 | 是(每种类型组合一次) |
| 用途 | 代码变换 | 类型驱动的代码生成 |
6. 编译期计算
6.1 Val{} 模式
# 使用 Val 将值提升到类型层面
function process(::Val{:fast}, data)
# 快速路径
return sum(data)
end
function process(::Val{:accurate}, data)
# 精确路径
return sum(big.(data))
end
process(Val(:fast), [1, 2, 3]) # 使用快速路径
process(Val(:accurate), [1, 2, 3]) # 使用精确路径
6.2 编译期消除分支
# 条件编译
@generated function compute(::Val{mode}, x) where mode
if mode == :add
return :(x + 1)
elseif mode == :mul
return :(x * 2)
else
return :(x)
end
end
# 运行时无分支开销
compute(Val(:add), 5) # 6
compute(Val(:mul), 5) # 10
6.3 类型级编程
# 编译期计算斐波那契数列
struct Fibonacci{N} end
fib(::Fibonacci{0}) = 0
fib(::Fibonacci{1}) = 1
fib(::Fibonacci{N}) where N = fib(Fibonacci{N-1}()) + fib(Fibonacci{N-2}())
# 在编译期求值
const fib_10 = Fibonacci{10}()
# fib(fib_10) = 55,这个计算在编译期完成
7. 宏组合子
7.1 宏叠加
# 将多个宏组合为一个
macro compose(ex, macros...)
result = ex
for m in reverse(macros)
result = :($m $result)
end
return esc(result)
end
# 使用
@compose :(x = rand()) @assert @inbounds
# 先应用 @inbounds,再应用 @assert
7.2 条件宏
# 根据条件选择不同的代码生成
macro optimize(flag, expr)
if flag == :fast
esc(:(@fastmath $expr))
elseif flag == :simd
esc(:(@simd $expr))
else
esc(expr)
end
end
@optimize :fast sin(x) + cos(y)
@optimize :simd for i in 1:n
a[i] = b[i] + c[i]
end
7.3 装饰器宏
# 类似 Python 装饰器的宏
macro memoize(ex)
# 给函数添加缓存
MacroTools.isdef(ex) || error("需要函数定义")
fname = ex.args[1].args[1]
fargs = ex.args[1].args[2:end]
body = ex.args[2]
cache_name = Symbol(:_cache_, fname)
esc(quote
const $cache_name = Dict()
function $fname($(fargs...))
key = ($(fargs...),)
if haskey($cache_name, key)
return $cache_name[key]
end
result = begin $body end
$cache_name[key] = result
return result
end
end)
end
@memoize function fibonacci(n::Int)
if n <= 1
return n
end
return fibonacci(n-1) + fibonacci(n-2)
end
fibonacci(100) # 瞬间返回,带缓存
8. 宏测试策略
8.1 测试宏展开
using Test
macro my_macro(x)
:($x + 1)
end
@testset "my_macro" begin
# 测试展开结果
ex = @macroexpand @my_macro 5
@test ex == :(5 + 1)
# 测试运行结果
@test @my_macro(5) == 6
@test @my_macro(0) == 1
@test @my_macro(-1) == 0
end
8.2 测试生成的代码
macro make_struct(name, fields...)
# 生成结构体定义
field_defs = [:($(esc(f.args[1]))::$(esc(f.args[2]))) for f in fields]
esc(quote
struct $name
$(field_defs...)
end
end)
end
@testset "make_struct" begin
@macroexpand @make_struct Point x::Float64 y::Float64
# 实际测试
@make_struct TestPoint x::Float64 y::Float64
p = TestPoint(1.0, 2.0)
@test p.x == 1.0
@test p.y == 2.0
end
8.3 测试边界条件
@testset "宏边界条件" begin
# 空输入
@test @my_macro(0) == 1
# 类型测试
@test @my_macro(1.0) == 2.0
@test @my_macro(big"1") == big"2"
# 错误输入
@test_throws MethodError @my_macro("string")
end
9. 实际案例:构建 SQL 查询 DSL
module SQLDSL
export @sql, @select, @from, @where, @order, @limit
# 查询结构体
struct SQLQuery
select::Vector{Symbol}
from::Symbol
where::Vector{String}
order::Vector{Pair{Symbol,Symbol}}
limit::Union{Int,Nothing}
end
SQLQuery(; select=Symbol[], from=:_, where=String[],
order=Pair{Symbol,Symbol}[], limit=nothing) =
SQLQuery(select, from, where, order, limit)
# @from 宏 — 指定表
macro from(source)
:(SQLQuery(from=$(QuoteNode(source))))
end
# @select 宏 — 选择列
macro select(q, cols...)
quote
let query = $q
SQLQuery(
select=Symbol[$(QuoteNode.(cols)...)],
from=query.from,
where=query.where,
order=query.order,
limit=query.limit
)
end
end |> esc
end
# @where 宏 — 添加条件
macro where(q, condition)
quote
let query = $q
SQLQuery(
select=query.select,
from=query.from,
where=vcat(query.where, string($(QuoteNode(condition)))),
order=query.order,
limit=query.limit
)
end
end |> esc
end
# @order 宏 — 排序
macro order(q, col, dir=:asc)
quote
let query = $q
SQLQuery(
select=query.select,
from=query.from,
where=query.where,
order=vcat(query.order, $(QuoteNode(col)) => $(QuoteNode(dir))),
limit=query.limit
)
end
end |> esc
end
# @limit 宏 — 限制行数
macro limit(q, n)
quote
let query = $q
SQLQuery(
select=query.select,
from=query.from,
where=query.where,
order=query.order,
limit=$n
)
end
end |> esc
end
# 生成 SQL 字符串
function to_sql(q::SQLQuery)
parts = String[]
# SELECT
if isempty(q.select)
push!(parts, "SELECT *")
else
push!(parts, "SELECT " * join(string.(q.select), ", "))
end
# FROM
push!(parts, "FROM $(q.from)")
# WHERE
if !isempty(q.where)
push!(parts, "WHERE " * join(q.where, " AND "))
end
# ORDER BY
if !isempty(q.order)
order_clauses = ["$(col) $(dir)" for (col, dir) in q.order]
push!(parts, "ORDER BY " * join(order_clauses, ", "))
end
# LIMIT
if q.limit !== nothing
push!(parts, "LIMIT $(q.limit)")
end
return join(parts, " ")
end
# @sql 宏 — 完整的链式查询
macro sql(ex)
esc(ex)
end
end
# 使用
using .SQLDSL
q = @from users
q = @select q name age city
q = @where q age > 18
q = @where q city == "Beijing"
q = @order q age desc
q = @limit q 10
to_sql(q)
# SELECT name, age, city FROM users WHERE age > 18 AND city == Beijing ORDER BY age desc LIMIT 10
10. 宏最佳实践
| 实践 | 说明 |
|---|---|
使用 esc() | 避免变量名冲突,除非性需要宏卫生 |
| 最小化宏作用域 | 宏只做代码变换,复杂逻辑放在普通函数中 |
| 充分测试 | 使用 @macroexpand 测试展开结果 |
| 提供清晰报错 | 在宏中使用 error() 提供有意义的错误信息 |
| 文档化 | 记录宏的输入输出和预期行为 |
| 避免过度宏化 | 能用函数解决的问题不要用宏 |
# ❌ 错误:宏中没有使用 esc
macro bad_macro(name)
:(println("Hello, ", $name)) # name 可能发生冲突
end
# ✅ 正确:使用 esc 保护用户代码
macro good_macro(name)
esc(:(println("Hello, ", $name)))
end
业务场景
场景一:测试框架 DSL
构建自定义测试框架,使用宏提供类似 @test, @assert 的简洁语法。宏在编译期注入文件名和行号信息,用于定位测试失败。
场景二:ORM 查询生成
为数据库操作构建 SQL DSL,使用链式宏调用构建查询。编译期检查表名和列名的合法性。
场景三:配置管理
使用 DSL 宏定义应用配置,编译期验证配置项类型,生成类型安全的访问器。
总结
| 主题 | 关键要点 |
|---|---|
| 宏展开 | @macroexpand 查看展开结果 |
| 表达式构造 | quote / :( ) / Expr() |
| 符号操作 | MacroTools.jl 的 postwalk / @capture |
| DSL | 宏可以创建自然的领域特定语法 |
| @generated | 基于类型信息编译期生成代码 |
| 宏测试 | 使用 @macroexpand 和运行时测试 |
| 宏卫生 | 使用 esc() 避免变量捕获 |