OCaml 教程 / OCaml 高阶函数
OCaml 高阶函数
高阶函数(Higher-Order Function)是函数式编程的核心:函数可以作为参数传递,也可以作为返回值。OCaml 对高阶函数提供了原生支持,使其成为编写简洁、可复用代码的利器。
函数作为参数
最简单的高阶函数——接受一个函数作为参数:
(* 对列表中每个元素应用函数并打印 *)
let apply_and_print f xs =
List.iter (fun x -> Printf.printf "%d " (f x)) xs;
print_newline ()
let () =
apply_and_print (fun x -> x * 2) [1; 2; 3; 4; 5];
(* 输出: 2 4 6 8 10 *)
apply_and_print (fun x -> x * x) [1; 2; 3; 4; 5]
(* 输出: 1 4 9 16 25 *)
判定函数作为参数
let count_if pred xs =
List.fold_left (fun acc x -> if pred x then acc + 1 else acc) 0 xs
let () =
let xs = [1; 2; 3; 4; 5; 6; 7; 8; 9; 10] in
Printf.printf "Even: %d\n" (count_if (fun x -> x mod 2 = 0) xs);
Printf.printf "Large: %d\n" (count_if (fun x -> x > 5) xs)
函数作为返回值
(* 返回一个乘以 n 的函数 *)
let multiplier n =
fun x -> x * n
(* 返回一个检查是否大于 n 的函数 *)
let greater_than n =
fun x -> x > n
let () =
let double = multiplier 2 in
let triple = multiplier 3 in
Printf.printf "Double 5: %d\n" (double 5); (* 10 *)
Printf.printf "Triple 5: %d\n" (triple 5); (* 15 *)
let is_positive = greater_than 0 in
Printf.printf "3 > 0: %b\n" (is_positive 3); (* true *)
Printf.printf "-1 > 0: %b\n" (is_positive (-1)) (* false *)
工厂模式
let make_validator ~min ~max =
fun x -> x >= min && x <= max
let make_range_checker ?(step=1) ~start ~stop =
fun () ->
let current = ref start in
fun () ->
if !current > stop then None
else begin
let v = !current in
current := !current + step;
Some v
end
let () =
let valid_age = make_validator ~min:0 ~max:150 in
Printf.printf "Age 25 valid: %b\n" (valid_age 25); (* true *)
Printf.printf "Age -1 valid: %b\n" (valid_age (-1)) (* false *)
List.map 深入
List.map 将函数应用于列表的每个元素,返回新列表。
(* 基础用法 *)
let () =
let nums = [1; 2; 3; 4; 5] in
let squares = List.map (fun x -> x * x) nums in
List.iter (Printf.printf "%d ") squares;
print_newline ()
(* 输出: 1 4 9 16 25 *)
(* 字符串处理 *)
let () =
let names = ["alice"; "bob"; "charlie"] in
let capitalized = List.map String.capitalize_ascii names in
List.iter (Printf.printf "%s ") capitalized;
print_newline ()
(* 输出: Alice Bob Charlie *)
(* 嵌套 map *)
let matrix = [[1; 2]; [3; 4]; [5; 6]]
let doubled = List.map (List.map (fun x -> x * 2)) matrix
(* [[2; 4]; [6; 8]; [10; 12]] *)
List.filter 深入
(* 过滤偶数 *)
let evens = List.filter (fun x -> x mod 2 = 0) [1; 2; 3; 4; 5; 6]
(* [2; 4; 6] *)
(* 过滤非空字符串 *)
let non_empty = List.filter (fun s -> String.length s > 0) [""; "hello"; ""; "world"]
(* ["hello"; "world"] *)
(* 组合 filter 和 map *)
let even_squares =
[1; 2; 3; 4; 5; 6; 7; 8; 9; 10]
|> List.filter (fun x -> x mod 2 = 0)
|> List.map (fun x -> x * x)
(* [4; 16; 36; 64; 100] *)
List.fold_left / fold_right 深入
fold 是列表操作的"瑞士军刀",几乎所有列表操作都可以用 fold 实现。
(* 求和 *)
let sum = List.fold_left (+) 0 [1; 2; 3; 4; 5]
(* 15 *)
(* 求积 *)
let product = List.fold_left ( * ) 1 [1; 2; 3; 4; 5]
(* 120 *)
(* 找最大值 *)
let max_list xs =
List.fold_left max (List.hd xs) (List.tl xs)
(* 字符串拼接 *)
let join sep xs =
List.fold_left
(fun acc s -> if acc = "" then s else acc ^ sep ^ s)
"" xs
(* 统计字符频率 *)
let char_freq s =
let tbl = Hashtbl.create 16 in
String.iter (fun c ->
let count = Hashtbl.find_opt tbl c |> Option.value ~default:0 in
Hashtbl.replace tbl c (count + 1)
) s;
tbl
(* fold 实现 map *)
let map_via_fold f xs =
List.fold_right (fun x acc -> f x :: acc) xs []
(* fold 实现 filter *)
let filter_via_fold pred xs =
List.fold_right (fun x acc ->
if pred x then x :: acc else acc
) xs []
let () =
Printf.printf "Sum: %d\n" sum;
Printf.printf "Join: %s\n" (join ", " ["a"; "b"; "c"]);
List.iter (Printf.printf "%d ") (map_via_fold (fun x -> x * 2) [1; 2; 3]);
print_newline ()
管道操作符 |>
管道操作符 |> 将左侧的值作为右侧函数的最后一个参数,让代码读起来像数据流。
(* 不使用管道 *)
let result1 =
String.uppercase_ascii
(String.trim
(String.concat " " ["hello"; " "; "world"; " "]))
(* 使用管道 *)
let result2 =
["hello"; " "; "world"; " "]
|> String.concat " "
|> String.trim
|> String.uppercase_ascii
(* 管道让数据流清晰可见 *)
let process_text text =
text
|> String.split_on_char '\n'
|> List.map String.trim
|> List.filter (fun s -> s <> "")
|> List.map String.capitalize_ascii
|> String.concat "\n"
let () =
let input = " hello\n \n world\n ocaml " in
print_endline (process_text input)
💡 提示:管道操作符
|>是标准库中最常用的操作符之一。它的定义非常简单:let (|>) x f = f x。核心价值在于让代码从左到右、从上到下阅读,而非嵌套。
函数组合
(* 函数组合操作符 *)
let (>>) f g x = g (f x)
let (<<) f g x = f (g x)
let is_even_and_positive =
(fun x -> x mod 2 = 0) >> (fun b -> b && true)
(* 更实用的组合 *)
let process =
String.trim
>> String.lowercase_ascii
>> String.split_on_char ' '
>> List.filter (fun s -> String.length s > 2)
let () =
let result = process " Hello World Of OCaml " in
List.iter (Printf.printf "%s\n") result
策略模式实现
type sort_strategy = int list -> int list
let bubble_sort xs =
(* 简化实现 *)
let arr = Array.of_list xs in
let n = Array.length arr in
for i = 0 to n - 2 do
for j = 0 to n - i - 2 do
if arr.(j) > arr.(j + 1) then begin
let tmp = arr.(j) in
arr.(j) <- arr.(j + 1);
arr.(j + 1) <- tmp
end
done
done;
Array.to_list arr
let quick_sort xs =
let rec sort = function
| [] -> []
| pivot :: rest ->
let left = List.filter (fun x -> x < pivot) rest in
let right = List.filter (fun x -> x >= pivot) rest in
sort left @ [pivot] @ sort right
in
sort xs
let stdlib_sort xs = List.sort Stdlib.compare xs
(* 使用策略 *)
let sort_with_strategy (strategy : sort_strategy) xs =
strategy xs
let () =
let data = [5; 3; 8; 1; 9; 2; 7; 4; 6] in
Printf.printf "Bubble: ";
List.iter (Printf.printf "%d ") (sort_with_strategy bubble_sort data);
print_newline ();
Printf.printf "Quick: ";
List.iter (Printf.printf "%d ") (sort_with_strategy quick_sort data);
print_newline ()
事件处理系统
module EventBus = struct
type 'a handler = 'a -> unit
type 'a t = {
mutable handlers : 'a handler list;
}
let create () = { handlers = [] }
let subscribe bus handler =
bus.handlers <- handler :: bus.handlers
let emit bus event =
List.iter (fun h -> h event) bus.handlers
let clear bus =
bus.handlers <- []
end
type event =
| Login of string
| Logout of string
| Message of string * string
let () =
let bus = EventBus.create () in
(* 订阅事件 *)
EventBus.subscribe bus (function
| Login user -> Printf.printf "[LOG] %s logged in\n" user
| Logout user -> Printf.printf "[LOG] %s logged out\n" user
| Message (from, msg) -> Printf.printf "[LOG] %s: %s\n" from msg
);
EventBus.subscribe bus (function
| Login user -> Printf.printf "[WELCOME] Welcome, %s!\n" user
| _ -> ()
);
(* 触发事件 *)
EventBus.emit bus (Login "Alice");
EventBus.emit bus (Message ("Alice", "Hello everyone!"));
EventBus.emit bus (Logout "Alice")
高阶函数与泛型编程
| 函数 | 类型签名 | 用途 |
|---|---|---|
List.map | ('a -> 'b) -> 'a list -> 'b list | 变换每个元素 |
List.filter | ('a -> bool) -> 'a list -> 'a list | 过滤元素 |
List.fold_left | ('b -> 'a -> 'b) -> 'b -> 'a list -> 'b | 累积计算 |
List.exists | ('a -> bool) -> 'a list -> bool | 存在性检查 |
List.for_all | ('a -> bool) -> 'a list -> bool | 全称检查 |
List.find | ('a -> bool) -> 'a list -> 'a | 查找第一个 |
List.partition | ('a -> bool) -> 'a list -> 'a list * 'a list | 分组 |
List.sort | ('a -> 'a -> int) -> 'a list -> 'a list | 排序 |