强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧。
文档目录

OCaml 教程 / OCaml 函子(Functors)

OCaml 函子(Functors)

函子(Functor)是 OCaml 中以模块为参数的"函数",是模块系统中最强大也最独特的特性。它允许我们编写参数化的模块,实现高度可复用的泛型组件。

什么是函子

在 OCaml 中,普通函数以值为参数,而函子以模块为参数并返回一个新模块。

(** Signature of a module whose values can be rendered as strings. *)
module type SHOWABLE = sig
  type t
  (** The type of values that can be shown. *)

  val to_string : t -> string
  (** [to_string v] renders [v] as a human-readable string. *)
end

(** [Printer (S)] builds printing helpers for values of type [S.t].
    It is a functor: it takes a module as its argument and returns a new
    module. *)
module Printer (S : SHOWABLE) = struct
  (** [print v] writes [S.to_string v] and a newline to stdout. *)
  let print v =
    let line = S.to_string v in
    print_endline line

  (** [print_list vs] prints every element of [vs], one per line, in order. *)
  let print_list vs = List.iter print vs
end

函子的应用

(** [SHOWABLE] instance for [int]: values are rendered in decimal. *)
module IntShowable : SHOWABLE with type t = int = struct
  type t = int
  let to_string n = Int.to_string n
end

(** [SHOWABLE] instance for [string]: the value is wrapped in double quotes.
    Quotes or control characters inside the string are not escaped. *)
module StringShowable : SHOWABLE with type t = string = struct
  type t = string
  let to_string s = String.concat "" [ "\""; s; "\"" ]
end

(* Apply the functor to obtain concrete printer modules. *)
module IntPrinter = Printer (IntShowable)
module StringPrinter = Printer (StringShowable)

let () =
  let () = IntPrinter.print 42 in
  let () = StringPrinter.print "hello" in
  (* Prints 1 through 5, one per line. *)
  IntPrinter.print_list (List.init 5 (fun i -> i + 1))

输出:

42
"hello"
1
2
3
4
5

Set.Make 函子

Set.Make 是标准库中最常用的函子之一,它接受一个可比较的模块,返回一个集合模块。

(* Integer and string sets. The argument structs spell out what
   [Set.Make] requires: a type [t] and a total order [compare].
   Type-specific comparators are used instead of the polymorphic
   [Stdlib.compare]: they give the same ordering but avoid the generic
   runtime structural comparison. (The stdlib [Int] and [String] modules
   could also be passed directly, e.g. [Set.Make (Int)].) *)
module IntSet = Set.Make (struct
  type t = int
  let compare = Int.compare
end)

module StringSet = Set.Make (struct
  type t = string
  let compare = String.compare
end)

let () =
  (* 2 is added twice; a set keeps a single copy. *)
  let s =
    List.fold_left (fun acc x -> IntSet.add x acc) IntSet.empty [1; 2; 3; 2]
  in
  Printf.printf "Size: %d\n" (IntSet.cardinal s);  (* 3 *)
  Printf.printf "Has 2: %b\n" (IntSet.mem 2 s);     (* true *)
  Printf.printf "Has 5: %b\n" (IntSet.mem 5 s)      (* false *)

let () =
  (* Duplicates collapse and iteration is in ascending order, so this
     prints "hello" then "world", one per line. *)
  ["hello"; "world"; "hello"]
  |> StringSet.of_list
  |> StringSet.iter print_endline

Set.Make 的签名要求

| 要求的签名 | 说明 |
|---|---|
| `type t` | 元素类型 |
| `val compare : t -> t -> int` | 比较函数,用于内部排序;须为全序:返回负数、0、正数分别表示小于、等于、大于,返回 0 的两个元素被视为同一元素 |

Map.Make 函子

Map.Make 创建按键有序的不可变映射。

module IntMap = Map.Make (Int)

(** A small map binding 1, 2 and 3 to their English names. *)
let demo_map =
  List.fold_left
    (fun acc (key, name) -> IntMap.add key name acc)
    IntMap.empty
    [ (1, "one"); (2, "two"); (3, "three") ]

let () =
  (* Bindings are visited in increasing key order. *)
  IntMap.iter (fun key name -> Printf.printf "%d -> %s\n" key name) demo_map;
  (* [update] receives the current binding (if any) and returns the new one. *)
  let updated =
    IntMap.update 2
      (fun current ->
        Some (match current with None -> "TWO" | Some _ -> "TWO_UPDATED"))
      demo_map
  in
  Printf.printf "Updated: %s\n" (IntMap.find 2 updated)

自定义有序键

(** Strings ordered case-insensitively (ASCII letters only): two keys that
    differ only in letter case compare as equal. *)
module CaseInsensitive = struct
  type t = string

  let compare x y =
    let fold = String.lowercase_ascii in
    String.compare (fold x) (fold y)
end

(** Map whose keys are compared with [CaseInsensitive.compare]. *)
module CIMap = Map.Make (CaseInsensitive)

let () =
  (* "Hello" and "hello" collide, so the second insert replaces the first. *)
  let m =
    List.fold_left
      (fun acc (key, v) -> CIMap.add key v acc)
      CIMap.empty
      [ ("Hello", 1); ("hello", 2); ("WORLD", 3) ]
  in
  Printf.printf "Size: %d\n" (CIMap.cardinal m);     (* 2 *)
  Printf.printf "hello: %d\n" (CIMap.find "HELLO" m)  (* 2 *)

⚠️ 注意:在 CIMap 中,"Hello" 和 "hello" 被视为同一个键,后插入的值会覆盖前者。

MakeComparable 模式

定义一个标准化的可比较接口,用于各种数据结构。

(** Types with a total order: [compare a b] is negative, zero or positive
    as [a] is less than, equal to or greater than [b]. *)
module type COMPARABLE = sig
  type t
  val compare : t -> t -> int
end

(** [MakeCollection (C)] builds a keyed collection over keys of type [C.t].
    A collection tracks its keys in a [Set] and the associated values in a
    [Map]; [add] updates both, so they always hold the same keys. *)
module MakeCollection (C : COMPARABLE) = struct
  module Set = Set.Make (C)
  module Map = Map.Make (C)

  type 'a collection = {
    set : Set.t;  (* keys present in the collection *)
    data : 'a Map.t;  (* key -> value; re-adding a key replaces its value *)
  }

  (** The collection with no keys. *)
  let empty = { data = Map.empty; set = Set.empty }

  (** [add key value c] returns [c] with [key] bound to [value]. *)
  let add key value { set; data } =
    { set = Set.add key set; data = Map.add key value data }

  (** [find key c] is [Some v] if [key] is bound to [v], otherwise [None]. *)
  let find key { data; _ } = Map.find_opt key data

  (** [mem key c] tests whether [key] is present in [c]. *)
  let mem key { set; _ } = Set.mem key set

  (** [size c] is the number of distinct keys in [c]. *)
  let size { set; _ } = Set.cardinal set
end

(* Usage example: a collection keyed by [int] ids. *)
module IntCollection = MakeCollection (Int)

let () =
  let people = [ (1, "Alice"); (2, "Bob"); (3, "Charlie") ] in
  let c =
    List.fold_left
      (fun acc (id, name) -> IntCollection.add id name acc)
      IntCollection.empty people
  in
  Printf.printf "Size: %d\n" (IntCollection.size c);
  match IntCollection.find 2 c with
  | None -> print_endline "Not found"
  | Some name -> Printf.printf "Found: %s\n" name

函子与依赖注入

函子天然支持依赖注入模式,模块的依赖可以在应用函子时指定。

(** Minimal logging interface injected into services. *)
module type LOGGER = sig
  val info : string -> unit
  val error : string -> unit
end

(** Database interface: [query conn sql] runs [sql] on [conn] and returns
    the results as a list of strings. *)
module type DATABASE = sig
  type t
  val connect : string -> t
  val query : t -> string -> string list
  val close : t -> unit
end

(** Business logic parameterised over its dependencies: the logger and
    the database are supplied when the functor is applied. *)
module UserService (Log : LOGGER) (DB : DATABASE) = struct
  (** [find_user db user_id] returns the first result of the lookup query
      for [user_id], or [None] when the query returns no results. Logs the
      lookup and its outcome. *)
  let find_user db user_id =
    Log.info (Printf.sprintf "Looking up user %d" user_id);
    (* [user_id] is an [int] formatted with %d, so it cannot inject SQL. *)
    let sql = Printf.sprintf "SELECT name FROM users WHERE id = %d" user_id in
    match DB.query db sql with
    | name :: _ ->
      Log.info (Printf.sprintf "Found user: %s" name);
      Some name
    | [] ->
      Log.error (Printf.sprintf "User %d not found" user_id);
      None
end

(* Production logger: informational messages go to stdout and errors to
   stderr, each prefixed with its level tag. *)
module ProdLogger : LOGGER = struct
  let info msg = Printf.fprintf stdout "[INFO] %s\n" msg
  let error msg = Printf.fprintf stderr "[ERROR] %s\n" msg
end

(* Test logger: records messages in memory instead of printing them.
   Its signature extends [LOGGER] with [get_logs]. Sealing it with plain
   [LOGGER] would hide [get_logs] and leave tests no way to inspect what
   was logged. The module still satisfies [LOGGER], so it can be passed
   to [UserService] as before. *)
module TestLogger : sig
  include LOGGER

  val get_logs : unit -> (string * string) list
  (** All recorded [(level, message)] pairs, oldest first. *)
end = struct
  (* Newest entry first; [get_logs] reverses into chronological order. *)
  let logs = ref []
  let info msg = logs := ("INFO", msg) :: !logs
  let error msg = logs := ("ERROR", msg) :: !logs
  let get_logs () = List.rev !logs
end

💡 提示:函子实现的依赖注入在编译时完成类型检查,依赖关系在应用函子时静态确定,无需运行时查找或反射。函子应用只在模块初始化时执行一次。通过函子参数调用函数通常是一次间接调用,开销很小。启用 flambda 优化时,这类调用往往可以被内联消除。

嵌套函子

(** Totally ordered types, as required by [Set.Make]. *)
module type ORDERED = sig
  type t
  val compare : t -> t -> int
end

(** [MakePairSet (A) (B)] is a set of [(A.t, B.t)] pairs. Pairs are ordered
    lexicographically: first by the [A] component, then by the [B]
    component. The representation is hidden behind the abstract [t]. *)
module MakePairSet (A : ORDERED) (B : ORDERED) : sig
  type t
  val empty : t
  val add : A.t -> B.t -> t -> t
  val mem : A.t -> B.t -> t -> bool
end = struct
  (* A functor applied inside a functor: [Set.Make] over the pair order. *)
  module S = Set.Make (struct
    type t = A.t * B.t

    let compare (a1, b1) (a2, b2) =
      match A.compare a1 a2 with
      | 0 -> B.compare b1 b2
      | c -> c
  end)

  type t = S.t

  let empty = S.empty
  let add a b = S.add (a, b)
  let mem a b = S.mem (a, b)
end

(* A set of (int, string) pairs. *)
module IntPairSet = MakePairSet (Int) (String)

let () =
  let entries = [ (1, "hello"); (2, "world"); (1, "hello") ] in
  let s =
    List.fold_left
      (fun acc (n, word) -> IntPairSet.add n word acc)
      IntPairSet.empty entries
  in
  (* The duplicate (1, "hello") is stored once. *)
  Printf.printf "Mem: %b\n" (IntPairSet.mem 1 "hello" s)  (* true *)

标准库中的函子汇总

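| 函子 | 参数签名 | 产出 | 用途 |
|---|---|---|---|
| `Set.Make` | `Set.OrderedType` | 集合模块 | 有序集合 |
| `Map.Make` | `Map.OrderedType` | 映射模块 | 有序键值映射 |
| `Hashtbl.Make` | `Hashtbl.HashedType` | 哈希表模块 | 自定义哈希与相等 |
| `Weak.Make` | `Hashtbl.HashedType` | 弱哈希集合模块 | 缓存(元素不阻止 GC 回收) |
| `MoreLabels.Set.Make` | `OrderedType` | 标签化集合模块 | 带标签参数 |
| `MoreLabels.Map.Make` | `OrderedType` | 标签化映射模块 | 带标签参数 |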

实际业务场景:插件系统

(** A plugin: a named, versioned string transformation. *)
module type PLUGIN = sig
  val name : string
  val version : string
  val execute : string -> string
end

(** Global, mutable registry of plugins keyed by their [name]. *)
module PluginRegistry = struct
  let registry : (string, (module PLUGIN)) Hashtbl.t =
    Hashtbl.create 16

  (** [register p] makes [p] available under its name. Registering another
      plugin with the same name replaces the previous one. *)
  let register (module P : PLUGIN) =
    (* [replace], not [add]: [add] would stack a hidden duplicate binding on
       each re-registration, keeping old plugins reachable and requiring
       one [Hashtbl.remove] per registration to unregister a name. *)
    Hashtbl.replace registry P.name (module P : PLUGIN)

  (** [execute plugin_name input] runs the plugin registered under
      [plugin_name] on [input]. Returns [Error] if no such plugin exists. *)
  let execute plugin_name input =
    match Hashtbl.find_opt registry plugin_name with
    | None -> Error (Printf.sprintf "Plugin %s not found" plugin_name)
    | Some (module P : PLUGIN) -> Ok (P.execute input)
end

(** Plugin that upper-cases ASCII letters; all other bytes are unchanged. *)
module UpperPlugin : PLUGIN = struct
  let name = "upper"
  let version = "1.0"
  let execute input = String.map Char.uppercase_ascii input
end

(** Plugin that reverses its input byte by byte. Multi-byte UTF-8
    characters are reversed byte-wise too, so non-ASCII text comes out
    garbled. *)
module ReversePlugin : PLUGIN = struct
  let name = "reverse"
  let version = "1.0"
  let execute s =
    let last = String.length s - 1 in
    String.init (last + 1) (fun i -> String.get s (last - i))
end

let () =
  List.iter PluginRegistry.register
    [ (module UpperPlugin : PLUGIN); (module ReversePlugin : PLUGIN) ];
  let outcome = PluginRegistry.execute "upper" "hello" in
  match outcome with
  | Error msg -> Printf.printf "Error: %s\n" msg
  | Ok result -> Printf.printf "Result: %s\n" result

扩展阅读