强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧.
文档目录

OCaml 教程 / OCaml 高阶函数

OCaml 高阶函数

高阶函数(Higher-Order Function)是函数式编程的核心:函数可以作为参数传递,也可以作为返回值。OCaml 对高阶函数提供了原生支持,使其成为编写简洁、可复用代码的利器。

函数作为参数

最简单的高阶函数——接受一个函数作为参数:

(* 对列表中每个元素应用函数并打印 *)
let apply_and_print f xs =
  List.iter (fun x -> Printf.printf "%d " (f x)) xs;
  print_newline ()

let () =
  apply_and_print (fun x -> x * 2) [1; 2; 3; 4; 5];
  (* 输出: 2 4 6 8 10 *)

  apply_and_print (fun x -> x * x) [1; 2; 3; 4; 5]
  (* 输出: 1 4 9 16 25 *)

判定函数作为参数

let count_if pred xs =
  List.fold_left (fun acc x -> if pred x then acc + 1 else acc) 0 xs

let () =
  let xs = [1; 2; 3; 4; 5; 6; 7; 8; 9; 10] in
  Printf.printf "Even: %d\n" (count_if (fun x -> x mod 2 = 0) xs);
  Printf.printf "Large: %d\n" (count_if (fun x -> x > 5) xs)

函数作为返回值

(* 返回一个乘以 n 的函数 *)
let multiplier n =
  fun x -> x * n

(* 返回一个检查是否大于 n 的函数 *)
let greater_than n =
  fun x -> x > n

let () =
  let double = multiplier 2 in
  let triple = multiplier 3 in
  Printf.printf "Double 5: %d\n" (double 5);    (* 10 *)
  Printf.printf "Triple 5: %d\n" (triple 5);    (* 15 *)

  let is_positive = greater_than 0 in
  Printf.printf "3 > 0: %b\n" (is_positive 3);   (* true *)
  Printf.printf "-1 > 0: %b\n" (is_positive (-1)) (* false *)

工厂模式

let make_validator ~min ~max =
  fun x -> x >= min && x <= max

let make_range_checker ?(step=1) ~start ~stop =
  fun () ->
    let current = ref start in
    fun () ->
      if !current > stop then None
      else begin
        let v = !current in
        current := !current + step;
        Some v
      end

let () =
  let valid_age = make_validator ~min:0 ~max:150 in
  Printf.printf "Age 25 valid: %b\n" (valid_age 25);   (* true *)
  Printf.printf "Age -1 valid: %b\n" (valid_age (-1))   (* false *)

List.map 深入

List.map 将函数应用于列表的每个元素,返回新列表。

(* 基础用法 *)
let () =
  let nums = [1; 2; 3; 4; 5] in
  let squares = List.map (fun x -> x * x) nums in
  List.iter (Printf.printf "%d ") squares;
  print_newline ()
  (* 输出: 1 4 9 16 25 *)

(* 字符串处理 *)
let () =
  let names = ["alice"; "bob"; "charlie"] in
  let capitalized = List.map String.capitalize_ascii names in
  List.iter (Printf.printf "%s ") capitalized;
  print_newline ()
  (* 输出: Alice Bob Charlie *)

(* 嵌套 map *)
let matrix = [[1; 2]; [3; 4]; [5; 6]]
let doubled = List.map (List.map (fun x -> x * 2)) matrix
(* [[2; 4]; [6; 8]; [10; 12]] *)

List.filter 深入

(* 过滤偶数 *)
let evens = List.filter (fun x -> x mod 2 = 0) [1; 2; 3; 4; 5; 6]
(* [2; 4; 6] *)

(* 过滤非空字符串 *)
let non_empty = List.filter (fun s -> String.length s > 0) [""; "hello"; ""; "world"]
(* ["hello"; "world"] *)

(* 组合 filter 和 map *)
let even_squares =
  [1; 2; 3; 4; 5; 6; 7; 8; 9; 10]
  |> List.filter (fun x -> x mod 2 = 0)
  |> List.map (fun x -> x * x)
(* [4; 16; 36; 64; 100] *)

List.fold_left / fold_right 深入

fold 是列表操作的"瑞士军刀",几乎所有列表操作都可以用 fold 实现。

(* 求和 *)
let sum = List.fold_left (+) 0 [1; 2; 3; 4; 5]
(* 15 *)

(* 求积 *)
let product = List.fold_left ( * ) 1 [1; 2; 3; 4; 5]
(* 120 *)

(* 找最大值 *)
let max_list xs =
  List.fold_left max (List.hd xs) (List.tl xs)

(* 字符串拼接 *)
let join sep xs =
  List.fold_left
    (fun acc s -> if acc = "" then s else acc ^ sep ^ s)
    "" xs

(* 统计字符频率 *)
let char_freq s =
  let tbl = Hashtbl.create 16 in
  String.iter (fun c ->
    let count = Hashtbl.find_opt tbl c |> Option.value ~default:0 in
    Hashtbl.replace tbl c (count + 1)
  ) s;
  tbl

(* fold 实现 map *)
let map_via_fold f xs =
  List.fold_right (fun x acc -> f x :: acc) xs []

(* fold 实现 filter *)
let filter_via_fold pred xs =
  List.fold_right (fun x acc ->
    if pred x then x :: acc else acc
  ) xs []

let () =
  Printf.printf "Sum: %d\n" sum;
  Printf.printf "Join: %s\n" (join ", " ["a"; "b"; "c"]);
  List.iter (Printf.printf "%d ") (map_via_fold (fun x -> x * 2) [1; 2; 3]);
  print_newline ()

管道操作符 |>

管道操作符 |> 将左侧的值作为右侧函数的最后一个参数,让代码读起来像数据流。

(* 不使用管道 *)
let result1 =
  String.uppercase_ascii
    (String.trim
      (String.concat " " ["hello"; "  "; "world"; "  "]))

(* 使用管道 *)
let result2 =
  ["hello"; "  "; "world"; "  "]
  |> String.concat " "
  |> String.trim
  |> String.uppercase_ascii

(* 管道让数据流清晰可见 *)
let process_text text =
  text
  |> String.split_on_char '\n'
  |> List.map String.trim
  |> List.filter (fun s -> s <> "")
  |> List.map String.capitalize_ascii
  |> String.concat "\n"

let () =
  let input = "  hello\n  \n  world\n  ocaml  " in
  print_endline (process_text input)

💡 提示:管道操作符 |> 是标准库中最常用的操作符之一。它的定义非常简单:let (|>) x f = f x。核心价值在于让代码从左到右、从上到下阅读,而非嵌套。

函数组合

(* 函数组合操作符 *)
let (>>) f g x = g (f x)
let (<<) f g x = f (g x)

let is_even_and_positive =
  (fun x -> x mod 2 = 0) >> (fun b -> b && true)

(* 更实用的组合 *)
let process =
  String.trim
  >> String.lowercase_ascii
  >> String.split_on_char ' '
  >> List.filter (fun s -> String.length s > 2)

let () =
  let result = process "  Hello World Of OCaml  " in
  List.iter (Printf.printf "%s\n") result

策略模式实现

type sort_strategy = int list -> int list

let bubble_sort xs =
  (* 简化实现 *)
  let arr = Array.of_list xs in
  let n = Array.length arr in
  for i = 0 to n - 2 do
    for j = 0 to n - i - 2 do
      if arr.(j) > arr.(j + 1) then begin
        let tmp = arr.(j) in
        arr.(j) <- arr.(j + 1);
        arr.(j + 1) <- tmp
      end
    done
  done;
  Array.to_list arr

let quick_sort xs =
  let rec sort = function
    | [] -> []
    | pivot :: rest ->
      let left = List.filter (fun x -> x < pivot) rest in
      let right = List.filter (fun x -> x >= pivot) rest in
      sort left @ [pivot] @ sort right
  in
  sort xs

let stdlib_sort xs = List.sort Stdlib.compare xs

(* 使用策略 *)
let sort_with_strategy (strategy : sort_strategy) xs =
  strategy xs

let () =
  let data = [5; 3; 8; 1; 9; 2; 7; 4; 6] in
  Printf.printf "Bubble: ";
  List.iter (Printf.printf "%d ") (sort_with_strategy bubble_sort data);
  print_newline ();
  Printf.printf "Quick:  ";
  List.iter (Printf.printf "%d ") (sort_with_strategy quick_sort data);
  print_newline ()

事件处理系统

module EventBus = struct
  type 'a handler = 'a -> unit

  type 'a t = {
    mutable handlers : 'a handler list;
  }

  let create () = { handlers = [] }

  let subscribe bus handler =
    bus.handlers <- handler :: bus.handlers

  let emit bus event =
    List.iter (fun h -> h event) bus.handlers

  let clear bus =
    bus.handlers <- []
end

type event =
  | Login of string
  | Logout of string
  | Message of string * string

let () =
  let bus = EventBus.create () in

  (* 订阅事件 *)
  EventBus.subscribe bus (function
    | Login user -> Printf.printf "[LOG] %s logged in\n" user
    | Logout user -> Printf.printf "[LOG] %s logged out\n" user
    | Message (from, msg) -> Printf.printf "[LOG] %s: %s\n" from msg
  );

  EventBus.subscribe bus (function
    | Login user -> Printf.printf "[WELCOME] Welcome, %s!\n" user
    | _ -> ()
  );

  (* 触发事件 *)
  EventBus.emit bus (Login "Alice");
  EventBus.emit bus (Message ("Alice", "Hello everyone!"));
  EventBus.emit bus (Logout "Alice")

高阶函数与泛型编程

函数类型签名用途
List.map('a -> 'b) -> 'a list -> 'b list变换每个元素
List.filter('a -> bool) -> 'a list -> 'a list过滤元素
List.fold_left('b -> 'a -> 'b) -> 'b -> 'a list -> 'b累积计算
List.exists('a -> bool) -> 'a list -> bool存在性检查
List.for_all('a -> bool) -> 'a list -> bool全称检查
List.find('a -> bool) -> 'a list -> 'a查找第一个
List.partition('a -> bool) -> 'a list -> 'a list * 'a list分组
List.sort('a -> 'a -> int) -> 'a list -> 'a list排序

扩展阅读