Julia 教程 / Julia 迭代器与生成器
13. 迭代器与生成器
迭代器是 Julia 中惰性处理序列数据的基础机制,能显著降低内存使用并提升代码表达力。
13.1 迭代器协议
Julia 的迭代器协议基于 iterate 函数:
# 迭代器协议:实现 iterate 函数
# iterate(iter) → (first_item, state) 或 nothing
# iterate(iter, state) → (next_item, new_state) 或 nothing
# 最简单的例子:计数器
struct Counter
max::Int
end
function Base.iterate(c::Counter)
c.max < 1 ? nothing : (1, 1)
end
function Base.iterate(c::Counter, state)
state >= c.max ? nothing : (state + 1, state + 1)
end
# 使用
for i in Counter(5)
print(i, " ")
end
# 1 2 3 4 5
# 支持所有依赖 iterate 的函数
collect(Counter(3)) # [1, 2, 3]
sum(Counter(100)) # 5050
first(Counter(10)) # 1
实现 Base.length(可选)
Base.length(c::Counter) = c.max
Base.eltype(::Type{Counter}) = Int
# 现在可以用于更多场景
println(length(Counter(10))) # 10
13.2 生成器表达式
生成器是匿名的惰性迭代器,语法类似列表推导但不分配内存。
# 列表推导(立即分配内存)
squares_list = [x^2 for x in 1:10]
# 生成器表达式(惰性求值)
squares_gen = (x^2 for x in 1:10)
# 生成器不分配内存,按需计算
println(sum(x^2 for x in 1:1_000_000)) # 333333833333500000
# 带条件过滤
even_squares = (x^2 for x in 1:20 if x % 2 == 0)
println(collect(even_squares)) # [4, 16, 36, 64, 100, 144, 196, 256, 324, 400]
# 嵌套生成器
pairs = ((i, j) for i in 1:3 for j in 1:3 if i < j)
println(collect(pairs)) # [(1,2), (1,3), (2,3)]
内存对比
# 列表推导:O(n) 内存
list = [sin(x) for x in 1:10_000_000]
println(Base.summarysize(list)) # ~80 MB
# 生成器:O(1) 内存
gen = (sin(x) for x in 1:10_000_000)
println(Base.summarysize(gen)) # ~16 bytes (固定)
13.3 enumerate 与 zip 迭代
enumerate:同时获取索引和值
fruits = ["苹果", "香蕉", "橘子"]
for (i, fruit) in enumerate(fruits)
println("$i. $fruit")
end
# 1. 苹果
# 2. 香蕉
# 3. 橘子
# 使用 collect 获取元组数组
enumerated = collect(enumerate(fruits))
println(enumerated)
# [(1, "苹果"), (2, "香蕉"), (3, "橘子")]
zip:并行迭代多个序列
names = ["Alice", "Bob", "Charlie"]
ages = [25, 30, 35]
scores = [95, 87, 92]
# 同时迭代三个数组
for (name, age, score) in zip(names, ages, scores)
println("$name (年龄$age): $score 分")
end
# Alice (年龄25): 95 分
# Bob (年龄30): 87 分
# Charlie (年龄35): 92 分
# zip 停止在最短的序列
short = [1, 2]
long = [10, 20, 30, 40]
println(collect(zip(short, long))) # [(1, 10), (2, 20)]
组合使用
data = [3.5, 2.1, 4.8, 1.2, 5.0]
# 找到最大值及其索引
max_idx, max_val = reduce(enumerate(data); init=(0, -Inf)) do (mi, mv), (i, v)
v > mv ? (i, v) : (mi, mv)
end
println("最大值在索引 $max_idx: $max_val")
13.4 自定义迭代器
斐波那契迭代器
struct Fibonacci
max_value::Int
end
function Base.iterate(::Fibonacci)
return (1, (1, 1))
end
function Base.iterate(f::Fibonacci, (a, b))
b > f.max_value ? nothing : (b, (b, a + b))
end
Base.eltype(::Type{Fibonacci}) = Int
# 使用
fib = Fibonacci(100)
println(collect(fib))
# [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
# 直接在循环中使用
for (i, f) in enumerate(Fibonacci(1000))
i > 10 && break
print("$f ")
end
# 1 1 2 3 5 8 13 21 34 55
滑动窗口迭代器
struct SlidingWindow{T}
data::Vector{T}
size::Int
end
function Base.iterate(sw::SlidingWindow, start=1)
start + sw.size - 1 > length(sw.data) && return nothing
return (sw.data[start:start+sw.size-1], start + 1)
end
Base.length(sw::SlidingWindow) = max(0, length(sw.data) - sw.size + 1)
Base.eltype(::Type{SlidingWindow{T}}) where T = SubArray{T,1,Vector{T},Tuple{UnitRange{Int}},true}
# 使用
data = [1, 2, 3, 4, 5, 6, 7]
sw = SlidingWindow(data, 3)
for window in sw
println(window)
end
# [1, 2, 3]
# [2, 3, 4]
# [3, 4, 5]
# [4, 5, 6]
# [5, 6, 7]
日期范围迭代器
using Dates
struct DateRange
start::Date
stop::Date
step::Day
end
DateRange(start::Date, stop::Date) = DateRange(start, stop, Day(1))
function Base.iterate(dr::DateRange, current=dr.start)
current > dr.stop && return nothing
return (current, current + dr.step)
end
Base.length(dr::DateRange) = fld(dr.stop - dr.start, dr.step) + 1
# 使用
for d in DateRange(Date(2026, 1, 1), Date(2026, 1, 5))
println(dayname(d), ": ", d)
end
13.5 惰性求值
Julia 的许多迭代器函数是惰性的,不会立即分配内存。
# 惰性 map
lazy_squares = Iterators.map(x -> x^2, 1:1000000)
# 不分配内存!只在需要时计算
# 惰性 filter
lazy_evens = Iterators.filter(iseven, 1:1000000)
# 链式操作(全部惰性)
result = Iterators.map(x -> x^2,
Iterators.filter(iseven,
Iterators.map(x -> x + 1, 1:1000000)
)
)
# 只在 collect 时才真正计算
println(sum(result)) # 需要结果时才求值
💡 惰性求值原则:尽可能使用惰性迭代器,只在需要全部结果时才
collect。
13.6 Channel:生产者-消费者
Channel 是 Julia 中用于协程间通信的管道。
# 生产者函数
function producer(ch::Channel)
for i in 1:10
put!(ch, i^2)
sleep(0.1) # 模拟耗时操作
end
end
# 消费者
ch = Channel(producer)
for val in ch
println("收到: $val")
end
有限 Channel
# 带缓冲的 Channel
ch = Channel{Int}(5) # 缓冲大小为 5
# 生产者任务
@async for i in 1:20
put!(ch, i)
println("生产: $i")
end
# 消费
for i in 1:20
val = take!(ch)
println("消费: $val")
end
Channel 作为迭代器
function fibonacci_channel(ch::Channel)
a, b = 0, 1
while true
put!(ch, a)
a, b = b, a + b
end
end
ch = Channel(fibonacci_channel)
for (i, fib) in enumerate(ch)
i > 15 && break
println("F($i) = $fib")
end
13.7 迭代器组合器
take / drop
# take:取前 n 个
first_five = Iterators.take(1:100, 5)
println(collect(first_five)) # [1, 2, 3, 4, 5]
# drop:跳过前 n 个
after_ten = Iterators.drop(1:20, 10)
println(collect(after_ten)) # [11, 12, ..., 20]
# 组合
first_five_after_ten = Iterators.take(Iterators.drop(1:100, 10), 5)
println(collect(first_five_after_ten)) # [11, 12, 13, 14, 15]
filter 与 map
# 惰性 filter
evens = Iterators.filter(iseven, 1:20)
println(collect(evens)) # [2, 4, 6, ..., 20]
# 惰性 map
doubled = Iterators.map(x -> 2x, 1:10)
println(collect(doubled)) # [2, 4, 6, ..., 20]
flatten 与 partition
# flatten:展平嵌套迭代器
nested = [[1, 2], [3, 4], [5, 6]]
flat = Iterators.flatten(nested)
println(collect(flat)) # [1, 2, 3, 4, 5, 6]
# partition:分组(需要 IterTools.jl)
using IterTools
parts = partition(1:10, 3)
for part in parts
println(collect(part))
end
# [1, 2, 3]
# [4, 5, 6]
# [7, 8, 9]
cycle / repeated / product
# cycle:无限循环
cycled = Iterators.cycle([1, 2, 3])
println(collect(Iterators.take(cycled, 10)))
# [1, 2, 3, 1, 2, 3, 1, 2, 3, 1]
# repeated:重复值
reps = Iterators.repeated(π, 5)
println(collect(reps)) # [π, π, π, π, π]
# product:笛卡尔积
coords = Iterators.product(1:3, 1:3)
for (x, y) in coords
print("($x,$y) ")
end
# (1,1) (2,1) (3,1) (1,2) (2,2) (3,2) (1,3) (2,3) (3,3)
13.8 迭代器性能
using BenchmarkTools
data = rand(1_000_000)
# 手动循环
function sum_manual(data)
s = 0.0
for x in data
s += x
end
return s
end
# 使用生成器
sum_gen(data) = sum(x for x in data)
# 使用 sum 直接调用
sum_direct(data) = sum(data)
@btime sum_manual($data) # 最快
@btime sum_direct($data) # 很快(内部也是迭代)
@btime sum_gen($data) # 稍慢(生成器开销)
💡 性能提示:对于简单操作,直接使用内置函数(如
sum)通常最快。生成器在复杂链式操作中优势明显。
13.9 内存效率对比
# 任务:计算前 N 个素数的和
N = 100_000
# 方法1:列表推导(高内存)
isprime(n) = n > 1 && all(n % i != 0 for i in 2:isqrt(n))
primes_list = [x for x in 1:N if isprime(x)]
sum_list = sum(primes_list)
# 方法2:生成器(低内存)
sum_gen = sum(x for x in 1:N if isprime(x))
# 结果相同
println(sum_list == sum_gen) # true
# 内存对比
println("列表内存: ", Base.summarysize(primes_list), " bytes")
println("生成器内存: 几乎为零")
实际场景:大文件处理
# 处理大文件时,避免一次性加载所有内容
function process_file(filename)
# 惰性逐行处理
lines = eachline(filename)
numbers = Iterators.map(line -> parse(Float64, line), lines)
positive = Iterators.filter(x -> x > 0, numbers)
return sum(positive)
end
# 内存使用恒定,无论文件多大
13.10 实际业务场景
场景一:数据清洗管道
function clean_data(raw_lines)
lines = Iterators.filter(!isempty, raw_lines)
trimmed = Iterators.map(strip, lines)
non_comment = Iterators.filter(l -> !startswith(l, "#"), trimmed)
records = Iterators.map(line -> split(line, ","), non_comment)
return records
end
raw = [
"# 这是注释",
"",
"Alice,25,95",
"Bob,30,87",
"# 另一条注释",
"Charlie,35,92"
]
for record in clean_data(raw)
println(record)
end
场景二:无限序列与近似计算
# 莱布尼茨公式计算 π
function leibniz_pi()
return (4 * ((-1)^k / (2k + 1) for k in 0:Inf))
end
# 部分和近似
pi_approx = 4 * sum((-1)^k / (2k + 1) for k in 0:1_000_000)
println("π ≈ $pi_approx (误差: $(abs(pi - pi_approx)))")
场景三:流式聚合
# 滑动平均
function moving_average(data, window_size)
windows = SlidingWindow(collect(data), window_size)
return Iterators.map(w -> sum(w) / length(w), windows)
end
data = sin.(0:0.1:10π)
ma = collect(moving_average(data, 20))
println("平滑后的长度: $(length(ma))")
13.11 扩展阅读
| 资源 | 链接 |
|---|---|
| Julia 官方文档 - Iteration | https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-iteration |
| IterTools.jl | https://github.com/JuliaCollections/IterTools.jl |
| Iterators.jl | https://github.com/JuliaCollections/Iterators.jl |
| Transducers.jl | https://github.com/JuliaFolds/Transducers.jl |
| 惰性求值最佳实践 | https://discourse.julialang.org/t/lazy-evaluation-patterns/ |
13.12 本章小结
| 主题 | 要点 |
|---|---|
| iterate 协议 | 实现 iterate(iter) 和 iterate(iter, state) |
| 生成器 | (expr for x in iter) 惰性求值 |
| enumerate/zip | 同时获取索引、并行迭代 |
| 自定义迭代器 | struct + iterate 实现惰性序列 |
| Channel | 生产者-消费者通信 |
| 组合器 | take/drop/filter/map/flatten |
| 内存效率 | 惰性迭代器 O(1) 内存 |