OCaml 教程 / OCaml 闭包与柯里化
OCaml 闭包与柯里化
闭包(Closure)和柯里化(Currying)是函数式编程的核心概念。OCaml 原生支持柯里化和闭包,理解它们的原理对于编写高效的函数式代码至关重要。
闭包原理
闭包是一个函数与其创建时环境的组合。当函数引用了外部变量,这些变量就被"捕获"在闭包中。
(* The simplest closure: the function returned by [make_adder n]
   still remembers [n] after [make_adder] itself has returned. *)
let make_adder n =
  let add_n x = x + n in (* [n] is captured in the closure [add_n] *)
  add_n
let add5 = make_adder 5
let add10 = make_adder 10
let () =
  let five_plus_three = add5 3 in
  let ten_plus_three = add10 3 in
  Printf.printf "add5(3) = %d\n" five_plus_three; (* 8 *)
  Printf.printf "add10(3) = %d\n" ten_plus_three (* 13 *)
闭包捕获可变状态
(* Each call to [make_counter] allocates a fresh [ref] cell; the returned
   closure bumps that private cell and returns the new count. *)
let make_counter () =
  let state = ref 0 in
  let next () =
    state := !state + 1;
    !state
  in
  next
let () =
  let first = make_counter () in
  let second = make_counter () in
  Printf.printf "Counter1: %d\n" (first ()); (* 1 *)
  Printf.printf "Counter1: %d\n" (first ()); (* 2 *)
  Printf.printf "Counter2: %d\n" (second ()); (* 1 -- independent closure *)
  Printf.printf "Counter1: %d\n" (first ()) (* 3 *)
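⚠️ 注意:OCaml 的 let 绑定本身不可变,闭包捕获的是绑定的值。对于 ref 这类堆上的可变数据,捕获的是指向它的指针,而不是数据的副本。因此,捕获同一个
ref 的多个闭包共享对同一可变单元的访问。每次调用 make_counter () 都会创建一个新的 ref,所以 counter1 和 counter2 互不影响。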
词法作用域
OCaml 使用词法作用域(Lexical Scoping),闭包在定义时确定变量绑定,而非调用时。
let x = 10
(* [f] returns a closure over its own local [x]: with lexical scoping the
   closure sees the binding visible where the closure was written. *)
let f () =
  let x = 20 in
  let read_x () = x in (* captures the local x = 20, not the top-level x = 10 *)
  read_x
let () =
  let g = f () in
  let value = g () in
  Printf.printf "g() = %d\n" value (* 20 *)
词法作用域 vs 动态作用域
let scope = "global"
(* The closure returned by [outer] refers to the [scope] binding in force
   where the closure is written, i.e. the local "outer". *)
let outer () =
  let scope = "outer" in
  let read_scope () = scope in (* lexical scoping: yields "outer" *)
  read_scope
let () =
  let inner = outer () in
  let seen = inner () in
  Printf.printf "scope: %s\n" seen
(* Lexical scoping prints: "outer" *)
(* Dynamic scoping would print "global" -- but OCaml is not dynamically scoped *)
柯里化 vs 非柯里化
OCaml 中所有多参数函数默认都是柯里化的——每个参数产生一个新函数。
(* Curried function (the OCaml default).
   Type: int -> int -> int, which really means int -> (int -> int). *)
let add = fun x y -> x + y
(* Uncurried function taking one tuple argument.
   Type: int * int -> int *)
let add_uncurried pair =
  let (x, y) = pair in
  x + y
let () =
  (* Curried: can be partially applied. *)
  let plus_five = add 5 in
  Printf.printf "add5(3) = %d\n" (plus_five 3);
  (* Uncurried: all arguments must be supplied at once, as a tuple. *)
  let total = add_uncurried (3, 4) in
  Printf.printf "add(3,4) = %d\n" total
(* The definition of [add] is equivalent to this explicitly nested form. *)
let add_explicit =
  fun x ->
    fun y -> x + y
柯里化与非柯里化对比
| 特性 | 柯里化 f x y | 非柯里化 f (x, y) |
|---|---|---|
| 部分应用 | ✅ 天然支持 | ❌ 需要包装 |
| 类型 | int -> int -> int | int * int -> int |
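| FFI 互操作 | ✅ `external` 声明通常写成柯里化形式,C 桩函数按参数逐个接收 | ⚠️ 元组作为单个装箱值传给 C 桩函数,并不更接近 C 调用约定 |
| 性能 | 完全应用(一次给齐所有参数)通常被 ocamlopt 编译为一次多参数调用;部分应用才会分配闭包 | 单次函数调用,但可能需要分配元组(编译器常能优化掉) |
| 模式匹配 | `function` 一次只匹配一个参数;可用 `match x, y with` 同时匹配多个参数 | 可在参数位置直接解构元组,匹配所有分量 |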
部分应用(Partial Application)
部分应用是柯里化最直接的好处:只提供部分参数,得到等待剩余参数的新函数。
(* Partially applied list functions: each binding below is still waiting
   for its final list argument. *)
let double_each = List.map (fun n -> 2 * n)
let filter_positives = List.filter (fun n -> 0 < n)
let sum = List.fold_left (fun acc n -> acc + n) 0
let () =
  let input = [-1; 2; -3; 4; 5; -6] in
  let total =
    input
    |> filter_positives (* [2; 4; 5] *)
    |> double_each      (* [4; 8; 10] *)
    |> sum              (* 22 *)
  in
  Printf.printf "Result: %d\n" total
(* Partially applying [Printf.printf] to a format with a single [%s]
   yields a [string -> unit] logger with a fixed prefix. *)
let debug : string -> unit = Printf.printf "[DEBUG] %s\n"
let error : string -> unit = Printf.printf "[ERROR] %s\n"
let () =
  List.iter
    (fun (log, message) -> log message)
    [ (debug, "Application started");
      (error, "Connection failed");
      (debug, "Retrying...") ]
配置化函数
(* [format_date sep year month day] renders year, month and day joined by
   [sep]. The year is zero-padded to at least four digits (ISO 8601 uses a
   four-digit year), and month/day to two digits. Partially applying only
   [sep] yields a reusable formatter of type int -> int -> int -> string. *)
let format_date sep year month day =
  Printf.sprintf "%04d%s%02d%s%02d" year sep month sep day
(* Partial application produces specific formats. *)
let format_iso = format_date "-"     (* 2026-05-11 (ISO 8601) *)
(* NOTE: this yields YYYY/MM/DD, not the US MM/DD/YYYY order; changing the
   field order needs a different function, not just another separator. *)
let format_us = format_date "/"      (* 2026/05/11 *)
let format_compact = format_date ""  (* 20260511 *)
let () =
  Printf.printf "ISO: %s\n" (format_iso 2026 5 11);
  Printf.printf "US: %s\n" (format_us 2026 5 11);
  Printf.printf "Compact: %s\n" (format_compact 2026 5 11)
闭包与回调
闭包最常见的用途之一是作为回调函数。
(* [with_retry ~max_retries ~on_error f] returns [f ()], retrying when [f]
   raises. At most [max_retries] retries are made, i.e. up to
   [max_retries + 1] calls in total, and each retry prints a progress line.
   If the final attempt fails, [on_error exn n] is called with the
   exception and the number of retries performed. The exception is then
   re-raised together with the backtrace of the failing call (backtraces
   are only available when recording is enabled, e.g. OCAMLRUNPARAM=b).
   Note: every exception raised by [f] is treated as retryable. *)
let with_retry ~max_retries ~on_error f =
  let rec attempt n =
    try f () with
    | exn ->
      (* Grab the backtrace first: the code below ([on_error], [Printf])
         may raise and catch exceptions internally, replacing it. *)
      let bt = Printexc.get_raw_backtrace () in
      if n >= max_retries then begin
        on_error exn n;
        Printexc.raise_with_backtrace exn bt
      end else begin
        Printf.printf "Retry %d/%d...\n" (n + 1) max_retries;
        attempt (n + 1)
      end
  in
  attempt 0
let () =
  (* The callback closes over [counter], so it fails on the first two
     calls and succeeds on the third. *)
  let counter = ref 0 in
  let result = with_retry
      ~max_retries:3
      ~on_error:(fun exn n ->
          Printf.printf "Failed after %d retries: %s\n"
            n (Printexc.to_string exn))
      (fun () ->
         incr counter;
         if !counter < 3 then
           failwith "temporary failure"
         else
           "success")
  in
  Printf.printf "Result: %s\n" result
闭包内存模型
理解闭包的内存布局有助于优化性能。
(* OCaml closures store only the free variables their body actually uses,
   not the whole enclosing environment. *)

(* Bad: [data] itself is a free variable of the returned function, so the
   whole list is stored in the closure and cannot be collected while the
   closure is alive; it is also re-summed on every call. *)
let bad_example data =
  fun x -> x + List.fold_left (+) 0 data

(* Good: compute what is needed up front. The closure's only free variable
   is [total], so [data] is not captured and can be collected once the
   caller drops it. No trick such as [let _ignore = data in] is needed:
   such a binding does not affect what the closure captures. *)
let good_example data =
  let total = List.fold_left (+) 0 data in
  fun x -> x + total

(* Same shape as [good_example]. [total] is an [int], an immediate value,
   so it is stored directly in the closure rather than as a pointer to
   another heap block. *)
let best_example data =
  let total = List.fold_left (+) 0 data in
  fun x -> x + total
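⚠️ 注意:OCaml 闭包只保存函数体中实际用到的自由变量,而不是整个作用域环境。被捕获的值会随闭包一起存活:如果一个长期存活的闭包引用了大数组、大列表等数据,这些资源就无法被 GC 回收。在性能敏感的代码中,应只让闭包引用必要的值,例如先计算出结果,再捕获这个结果。另外,用 `let rec ... and ...` 定义的相互递归函数共享同一个闭包块,会一起保留所有成员捕获的变量。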
curry 与 uncurry 转换函数
有时需要在柯里化和非柯里化之间转换。
(* Convert between curried and tupled (uncurried) two-argument functions. *)
let curry f x y = f (x, y)
let uncurry f (x, y) = f x y

(* Three-argument variants, matching the conversion table. *)
let curry3 f x y z = f (x, y, z)
let uncurry3 f (x, y, z) = f x y z

(* Usage examples *)
let add_curried x y = x + y
let add_uncurried = uncurry add_curried
(* Apply [f] to both components of a pair. *)
let pair_apply f (x, y) = (f x, f y)
let () =
  (* Round trip: [curry] turns the tupled [add_uncurried] back into a
     curried function, so this actually exercises [curry]. *)
  Printf.printf "curry: %d\n" (curry add_uncurried 3 4);
  Printf.printf "uncurry: %d\n" (add_uncurried (3, 4));
  let (a, b) = pair_apply (fun x -> x * 2) (3, 4) in
  Printf.printf "pair_apply: (%d, %d)\n" a b
Tuple 和 curry 的转换表
| 原始类型 | 转换后类型 | 转换函数 |
|---|---|---|
| `'a * 'b -> 'c` | `'a -> 'b -> 'c` | `curry` |
| `'a -> 'b -> 'c` | `'a * 'b -> 'c` | `uncurry` |
| `'a * 'b * 'c -> 'd` | `'a -> 'b -> 'c -> 'd` | `curry3` |
| `'a -> 'b -> 'c -> 'd` | `'a * 'b * 'c -> 'd` | `uncurry3` |
实际应用:中间件与过滤器链
(* An incoming request: a path, a list of (name, value) header pairs, and
   a raw body. *)
type request = {
  path : string;
  headers : (string * string) list;
  body : string;
}
(* An outgoing response: a status code and a body.
   NOTE: [body] is also a field of [request]. The later definition here
   shadows the earlier label, so code that relies on [request]'s [body]
   may need a type annotation to disambiguate. *)
type response = {
  status : int;
  body : string;
}
(* A middleware receives the request and the next handler in the chain.
   It may call [next] (possibly inspecting or wrapping the result) or
   return a response itself without calling [next]. *)
type middleware = request -> (request -> response) -> response
(* Logging middleware: prints the request path, runs the rest of the
   chain, then prints the response status. [request] has no HTTP method
   field, so no method is logged; a hard-coded "GET" would be wrong for
   every other method. *)
let logging_middleware : middleware = fun req next ->
  Printf.printf "[LOG] %s\n" req.path;
  let resp = next req in
  Printf.printf "[LOG] Response: %d\n" resp.status;
  resp
(* Authentication middleware: forwards the request only when an
   Authorization header is present, and otherwise short-circuits with 401.
   HTTP header names are case-insensitive, so "authorization" or
   "AUTHORIZATION" must also be accepted. Only the header's presence is
   checked, not the validity of the token. *)
let auth_middleware : middleware = fun req next ->
  let is_auth_header (name, _) =
    String.equal (String.lowercase_ascii name) "authorization"
  in
  if List.exists is_auth_header req.headers then
    next req
  else
    { status = 401; body = "Unauthorized" }
(* CORS middleware placeholder. Real CORS handling adds
   Access-Control-Allow-* headers to the response, but [response] has no
   headers field, so all this middleware can do is pass the downstream
   response through unchanged. Extend [response] with headers to make it
   functional. *)
let cors_middleware : middleware = fun req next ->
  next req
(* Wrap [handler] in [middlewares], with the first list element outermost:
   for [[m1; m2]] a request flows m1 -> m2 -> handler. *)
let rec compose_middlewares middlewares handler =
  match middlewares with
  | [] -> handler
  | mw :: rest ->
    let inner = compose_middlewares rest handler in
    fun req -> mw req inner
(* Terminal handler: answers every request with status 200 and a body that
   echoes the request path. *)
let final_handler : request -> response = fun { path; _ } ->
  let body = Printf.sprintf "Hello from %s" path in
  { status = 200; body }
(* Demo: build the chain (logging outermost, then auth, then CORS, then
   [final_handler]) and send one authorized and one anonymous request. *)
let () =
  let app =
    final_handler
    |> compose_middlewares
         [logging_middleware; auth_middleware; cors_middleware]
  in
  let report (resp : response) =
    Printf.printf "Status: %d, Body: %s\n" resp.status resp.body
  in
  let authorized = {
    path = "/api/users";
    headers = [("Authorization", "Bearer token123")];
    body = "";
  } in
  report (app authorized);
  let anonymous = { authorized with headers = [] } in
  report (app anonymous)