强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧.
文档目录

Julia 教程 / Julia 迭代器与生成器

13. 迭代器与生成器

迭代器是 Julia 中惰性处理序列数据的基础机制,能显著降低内存使用并提升代码表达力。

13.1 迭代器协议

Julia 的迭代器协议基于 iterate 函数:

# 迭代器协议:实现 iterate 函数
# iterate(iter)         → (first_item, state) 或 nothing
# iterate(iter, state)  → (next_item, new_state) 或 nothing

# 最简单的例子:计数器
struct Counter
    max::Int
end

function Base.iterate(c::Counter)
    c.max < 1 ? nothing : (1, 1)
end

function Base.iterate(c::Counter, state)
    state >= c.max ? nothing : (state + 1, state + 1)
end

# 使用
for i in Counter(5)
    print(i, " ")
end
# 1 2 3 4 5

# 支持所有依赖 iterate 的函数
collect(Counter(3))       # [1, 2, 3]
sum(Counter(100))         # 5050
first(Counter(10))        # 1

实现 Base.length(可选)

Base.length(c::Counter) = c.max
Base.eltype(::Type{Counter}) = Int

# 现在可以用于更多场景
println(length(Counter(10)))  # 10

13.2 生成器表达式

生成器是匿名的惰性迭代器,语法类似列表推导但不分配内存。

# 列表推导(立即分配内存)
squares_list = [x^2 for x in 1:10]

# 生成器表达式(惰性求值)
squares_gen = (x^2 for x in 1:10)

# 生成器不分配内存,按需计算
println(sum(x^2 for x in 1:1_000_000))  # 333333833333500000

# 带条件过滤
even_squares = (x^2 for x in 1:20 if x % 2 == 0)
println(collect(even_squares))  # [4, 16, 36, 64, 100, 144, 196, 256, 324, 400]

# 嵌套生成器
pairs = ((i, j) for i in 1:3 for j in 1:3 if i < j)
println(collect(pairs))  # [(1,2), (1,3), (2,3)]

内存对比

# 列表推导:O(n) 内存
list = [sin(x) for x in 1:10_000_000]
println(Base.summarysize(list))  # ~80 MB

# 生成器:O(1) 内存
gen = (sin(x) for x in 1:10_000_000)
println(Base.summarysize(gen))   # ~16 bytes (固定)

13.3 enumerate 与 zip 迭代

enumerate:同时获取索引和值

fruits = ["苹果", "香蕉", "橘子"]

for (i, fruit) in enumerate(fruits)
    println("$i. $fruit")
end
# 1. 苹果
# 2. 香蕉
# 3. 橘子

# 使用 collect 获取元组数组
enumerated = collect(enumerate(fruits))
println(enumerated)
# [(1, "苹果"), (2, "香蕉"), (3, "橘子")]

zip:并行迭代多个序列

names = ["Alice", "Bob", "Charlie"]
ages = [25, 30, 35]
scores = [95, 87, 92]

# 同时迭代三个数组
for (name, age, score) in zip(names, ages, scores)
    println("$name (年龄$age): $score 分")
end
# Alice (年龄25): 95 分
# Bob (年龄30): 87 分
# Charlie (年龄35): 92 分

# zip 停止在最短的序列
short = [1, 2]
long = [10, 20, 30, 40]
println(collect(zip(short, long)))  # [(1, 10), (2, 20)]

组合使用

data = [3.5, 2.1, 4.8, 1.2, 5.0]

# 找到最大值及其索引
max_idx, max_val = reduce(enumerate(data); init=(0, -Inf)) do (mi, mv), (i, v)
    v > mv ? (i, v) : (mi, mv)
end

println("最大值在索引 $max_idx: $max_val")

13.4 自定义迭代器

斐波那契迭代器

struct Fibonacci
    max_value::Int
end

function Base.iterate(::Fibonacci)
    return (1, (1, 1))
end

function Base.iterate(f::Fibonacci, (a, b))
    b > f.max_value ? nothing : (b, (b, a + b))
end

Base.eltype(::Type{Fibonacci}) = Int

# 使用
fib = Fibonacci(100)
println(collect(fib))
# [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]

# 直接在循环中使用
for (i, f) in enumerate(Fibonacci(1000))
    i > 10 && break
    print("$f ")
end
# 1 1 2 3 5 8 13 21 34 55

滑动窗口迭代器

struct SlidingWindow{T}
    data::Vector{T}
    size::Int
end

function Base.iterate(sw::SlidingWindow, start=1)
    start + sw.size - 1 > length(sw.data) && return nothing
    return (sw.data[start:start+sw.size-1], start + 1)
end

Base.length(sw::SlidingWindow) = max(0, length(sw.data) - sw.size + 1)
Base.eltype(::Type{SlidingWindow{T}}) where T = SubArray{T,1,Vector{T},Tuple{UnitRange{Int}},true}

# 使用
data = [1, 2, 3, 4, 5, 6, 7]
sw = SlidingWindow(data, 3)
for window in sw
    println(window)
end
# [1, 2, 3]
# [2, 3, 4]
# [3, 4, 5]
# [4, 5, 6]
# [5, 6, 7]

日期范围迭代器

using Dates

struct DateRange
    start::Date
    stop::Date
    step::Day
end

DateRange(start::Date, stop::Date) = DateRange(start, stop, Day(1))

function Base.iterate(dr::DateRange, current=dr.start)
    current > dr.stop && return nothing
    return (current, current + dr.step)
end

Base.length(dr::DateRange) = fld(dr.stop - dr.start, dr.step) + 1

# 使用
for d in DateRange(Date(2026, 1, 1), Date(2026, 1, 5))
    println(dayname(d), ": ", d)
end

13.5 惰性求值

Julia 的许多迭代器函数是惰性的,不会立即分配内存。

# 惰性 map
lazy_squares = Iterators.map(x -> x^2, 1:1000000)
# 不分配内存!只在需要时计算

# 惰性 filter
lazy_evens = Iterators.filter(iseven, 1:1000000)

# 链式操作(全部惰性)
result = Iterators.map(x -> x^2,
    Iterators.filter(iseven,
        Iterators.map(x -> x + 1, 1:1000000)
    )
)

# 只在 collect 时才真正计算
println(sum(result))  # 需要结果时才求值

💡 惰性求值原则:尽可能使用惰性迭代器,只在需要全部结果时才 collect

13.6 Channel:生产者-消费者

Channel 是 Julia 中用于协程间通信的管道。

# 生产者函数
function producer(ch::Channel)
    for i in 1:10
        put!(ch, i^2)
        sleep(0.1)  # 模拟耗时操作
    end
end

# 消费者
ch = Channel(producer)
for val in ch
    println("收到: $val")
end

有限 Channel

# 带缓冲的 Channel
ch = Channel{Int}(5)  # 缓冲大小为 5

# 生产者任务
@async for i in 1:20
    put!(ch, i)
    println("生产: $i")
end

# 消费
for i in 1:20
    val = take!(ch)
    println("消费: $val")
end

Channel 作为迭代器

function fibonacci_channel(ch::Channel)
    a, b = 0, 1
    while true
        put!(ch, a)
        a, b = b, a + b
    end
end

ch = Channel(fibonacci_channel)
for (i, fib) in enumerate(ch)
    i > 15 && break
    println("F($i) = $fib")
end

13.7 迭代器组合器

take / drop

# take:取前 n 个
first_five = Iterators.take(1:100, 5)
println(collect(first_five))  # [1, 2, 3, 4, 5]

# drop:跳过前 n 个
after_ten = Iterators.drop(1:20, 10)
println(collect(after_ten))  # [11, 12, ..., 20]

# 组合
first_five_after_ten = Iterators.take(Iterators.drop(1:100, 10), 5)
println(collect(first_five_after_ten))  # [11, 12, 13, 14, 15]

filter 与 map

# 惰性 filter
evens = Iterators.filter(iseven, 1:20)
println(collect(evens))  # [2, 4, 6, ..., 20]

# 惰性 map
doubled = Iterators.map(x -> 2x, 1:10)
println(collect(doubled))  # [2, 4, 6, ..., 20]

flatten 与 partition

# flatten:展平嵌套迭代器
nested = [[1, 2], [3, 4], [5, 6]]
flat = Iterators.flatten(nested)
println(collect(flat))  # [1, 2, 3, 4, 5, 6]

# partition:分组(需要 IterTools.jl)
using IterTools
parts = partition(1:10, 3)
for part in parts
    println(collect(part))
end
# [1, 2, 3]
# [4, 5, 6]
# [7, 8, 9]

cycle / repeated / product

# cycle:无限循环
cycled = Iterators.cycle([1, 2, 3])
println(collect(Iterators.take(cycled, 10)))
# [1, 2, 3, 1, 2, 3, 1, 2, 3, 1]

# repeated:重复值
reps = Iterators.repeated(π, 5)
println(collect(reps))  # [π, π, π, π, π]

# product:笛卡尔积
coords = Iterators.product(1:3, 1:3)
for (x, y) in coords
    print("($x,$y) ")
end
# (1,1) (2,1) (3,1) (1,2) (2,2) (3,2) (1,3) (2,3) (3,3)

13.8 迭代器性能

using BenchmarkTools

data = rand(1_000_000)

# 手动循环
function sum_manual(data)
    s = 0.0
    for x in data
        s += x
    end
    return s
end

# 使用生成器
sum_gen(data) = sum(x for x in data)

# 使用 sum 直接调用
sum_direct(data) = sum(data)

@btime sum_manual($data)   # 最快
@btime sum_direct($data)   # 很快(内部也是迭代)
@btime sum_gen($data)      # 稍慢(生成器开销)

💡 性能提示:对于简单操作,直接使用内置函数(如 sum)通常最快。生成器在复杂链式操作中优势明显。

13.9 内存效率对比

# 任务:计算前 N 个素数的和
N = 100_000

# 方法1:列表推导(高内存)
isprime(n) = n > 1 && all(n % i != 0 for i in 2:isqrt(n))
primes_list = [x for x in 1:N if isprime(x)]
sum_list = sum(primes_list)

# 方法2:生成器(低内存)
sum_gen = sum(x for x in 1:N if isprime(x))

# 结果相同
println(sum_list == sum_gen)  # true

# 内存对比
println("列表内存: ", Base.summarysize(primes_list), " bytes")
println("生成器内存: 几乎为零")

实际场景:大文件处理

# 处理大文件时,避免一次性加载所有内容
function process_file(filename)
    # 惰性逐行处理
    lines = eachline(filename)
    numbers = Iterators.map(line -> parse(Float64, line), lines)
    positive = Iterators.filter(x -> x > 0, numbers)
    return sum(positive)
end

# 内存使用恒定,无论文件多大

13.10 实际业务场景

场景一:数据清洗管道

function clean_data(raw_lines)
    lines = Iterators.filter(!isempty, raw_lines)
    trimmed = Iterators.map(strip, lines)
    non_comment = Iterators.filter(l -> !startswith(l, "#"), trimmed)
    records = Iterators.map(line -> split(line, ","), non_comment)
    return records
end

raw = [
    "# 这是注释",
    "",
    "Alice,25,95",
    "Bob,30,87",
    "# 另一条注释",
    "Charlie,35,92"
]

for record in clean_data(raw)
    println(record)
end

场景二:无限序列与近似计算

# 莱布尼茨公式计算 π
function leibniz_pi()
    return (4 * ((-1)^k / (2k + 1) for k in 0:Inf))
end

# 部分和近似
pi_approx = 4 * sum((-1)^k / (2k + 1) for k in 0:1_000_000)
println("π ≈ $pi_approx (误差: $(abs(pi - pi_approx)))")

场景三:流式聚合

# 滑动平均
function moving_average(data, window_size)
    windows = SlidingWindow(collect(data), window_size)
    return Iterators.map(w -> sum(w) / length(w), windows)
end

data = sin.(0:0.1:10π)
ma = collect(moving_average(data, 20))
println("平滑后的长度: $(length(ma))")

13.11 扩展阅读

资源链接
Julia 官方文档 - Iterationhttps://docs.julialang.org/en/v1/manual/interfaces/#man-interface-iteration
IterTools.jlhttps://github.com/JuliaCollections/IterTools.jl
Iterators.jlhttps://github.com/JuliaCollections/Iterators.jl
Transducers.jlhttps://github.com/JuliaFolds/Transducers.jl
惰性求值最佳实践https://discourse.julialang.org/t/lazy-evaluation-patterns/

13.12 本章小结

主题要点
iterate 协议实现 iterate(iter)iterate(iter, state)
生成器(expr for x in iter) 惰性求值
enumerate/zip同时获取索引、并行迭代
自定义迭代器struct + iterate 实现惰性序列
Channel生产者-消费者通信
组合器take/drop/filter/map/flatten
内存效率惰性迭代器 O(1) 内存