OCaml 教程 / OCaml 函子(Functors)
OCaml 函子(Functors)
函子(Functor)是 OCaml 中以模块为参数的"函数",是模块系统中最强大也最独特的特性。它允许我们编写参数化的模块,实现高度可复用的泛型组件。
什么是函子
在 OCaml 中,普通函数以值为参数,而函子以模块为参数并返回一个新模块。
(* 定义一个接受模块作为参数的函子 *)
module type SHOWABLE = sig
type t
val to_string : t -> string
end
module Printer (S : SHOWABLE) = struct
let print x = print_endline (S.to_string x)
let print_list xs = List.iter (fun x -> print_endline (S.to_string x)) xs
end
函子的应用
module IntShowable : SHOWABLE with type t = int = struct
type t = int
let to_string = string_of_int
end
module StringShowable : SHOWABLE with type t = string = struct
type t = string
let to_string s = Printf.sprintf "\"%s\"" s
end
(* 应用函子,生成具体模块 *)
module IntPrinter = Printer(IntShowable)
module StringPrinter = Printer(StringShowable)
let () =
IntPrinter.print 42;
StringPrinter.print "hello";
IntPrinter.print_list [1; 2; 3; 4; 5]
输出:
42
"hello"
1
2
3
4
5
Set.Make 函子
Set.Make 是标准库中最常用的函子之一,它接受一个可比较的模块,返回一个集合模块。
module IntSet = Set.Make(struct
type t = int
let compare = Stdlib.compare
end)
module StringSet = Set.Make(struct
type t = string
let compare = Stdlib.compare
end)
let () =
let s = IntSet.(empty |> add 1 |> add 2 |> add 3 |> add 2) in
Printf.printf "Size: %d\n" (IntSet.cardinal s); (* 3 *)
Printf.printf "Has 2: %b\n" (IntSet.mem 2 s); (* true *)
Printf.printf "Has 5: %b\n" (IntSet.mem 5 s) (* false *)
let () =
let words = StringSet.of_list ["hello"; "world"; "hello"] in
StringSet.iter print_endline words
(* 输出: hello world(去重且排序) *)
Set.Make 的签名要求
| 要求的签名 | 说明 |
|---|---|
type t | 元素类型 |
val compare : t -> t -> int | 比较函数,用于内部排序 |
Map.Make 函子
Map.Make 创建按键有序的不可变映射。
module IntMap = Map.Make(Int)
let demo_map =
let m = IntMap.empty in
let m = IntMap.add 1 "one" m in
let m = IntMap.add 2 "two" m in
let m = IntMap.add 3 "three" m in
m
let () =
IntMap.iter (fun k v ->
Printf.printf "%d -> %s\n" k v
) demo_map;
(* 更新操作 *)
let m = IntMap.update 2 (fun opt ->
match opt with
| None -> Some "TWO"
| Some _ -> Some "TWO_UPDATED"
) demo_map in
Printf.printf "Updated: %s\n" (IntMap.find 2 m)
自定义有序键
module CaseInsensitive = struct
type t = string
let compare a b =
String.compare (String.lowercase_ascii a) (String.lowercase_ascii b)
end
module CIMap = Map.Make(CaseInsensitive)
let () =
let m = CIMap.(empty
|> add "Hello" 1
|> add "hello" 2
|> add "WORLD" 3) in
Printf.printf "Size: %d\n" (CIMap.cardinal m); (* 2 *)
Printf.printf "hello: %d\n" (CIMap.find "HELLO" m) (* 2 *)
⚠️ 注意:
CIMap中"Hello"和"hello"被视为同一个键,后插入的值会覆盖前者。
MakeComparable 模式
定义一个标准化的可比较接口,用于各种数据结构。
module type COMPARABLE = sig
type t
val compare : t -> t -> int
end
(* 泛型集合工厂 *)
module MakeCollection (C : COMPARABLE) = struct
module Set = Set.Make(C)
module Map = Map.Make(C)
type 'a collection = {
set : Set.t;
data : 'a Map.t;
}
let empty = { set = Set.empty; data = Map.empty }
let add key value c = {
set = Set.add key c.set;
data = Map.add key value c.data;
}
let find key c = Map.find_opt key c.data
let mem key c = Set.mem key c.set
let size c = Set.cardinal c.set
end
(* 使用示例 *)
module IntCollection = MakeCollection(Int)
let () =
let c = IntCollection.(empty
|> add 1 "Alice"
|> add 2 "Bob"
|> add 3 "Charlie") in
Printf.printf "Size: %d\n" (IntCollection.size c);
match IntCollection.find 2 c with
| Some name -> Printf.printf "Found: %s\n" name
| None -> print_endline "Not found"
函子与依赖注入
函子天然支持依赖注入模式,模块的依赖可以在应用函子时指定。
module type LOGGER = sig
val info : string -> unit
val error : string -> unit
end
module type DATABASE = sig
type t
val connect : string -> t
val query : t -> string -> string list
val close : t -> unit
end
(* 业务模块通过函子接收依赖 *)
module UserService (Log : LOGGER) (DB : DATABASE) = struct
let find_user db user_id =
Log.info (Printf.sprintf "Looking up user %d" user_id);
let results = DB.query db
(Printf.sprintf "SELECT name FROM users WHERE id = %d" user_id) in
match results with
| [] ->
Log.error (Printf.sprintf "User %d not found" user_id);
None
| name :: _ ->
Log.info (Printf.sprintf "Found user: %s" name);
Some name
end
(* 生产环境注入 *)
module ProdLogger : LOGGER = struct
let info msg = Printf.printf "[INFO] %s\n" msg
let error msg = Printf.eprintf "[ERROR] %s\n" msg
end
(* 测试环境注入 *)
module TestLogger : LOGGER = struct
let logs = ref []
let info msg = logs := ("INFO", msg) :: !logs
let error msg = logs := ("ERROR", msg) :: !logs
let get_logs () = List.rev !logs
end
💡 提示:函子实现的依赖注入是编译时确定的,没有运行时开销,比接口注入更高效。
嵌套函子
module type ORDERED = sig
type t
val compare : t -> t -> int
end
module MakePairSet (A : ORDERED) (B : ORDERED) : sig
type t
val empty : t
val add : A.t -> B.t -> t -> t
val mem : A.t -> B.t -> t -> bool
end = struct
module Pair = struct
type t = A.t * B.t
let compare (a1, b1) (a2, b2) =
let c = A.compare a1 a2 in
if c <> 0 then c else B.compare b1 b2
end
module S = Set.Make(Pair)
type t = S.t
let empty = S.empty
let add a b s = S.add (a, b) s
let mem a b s = S.mem (a, b) s
end
module IntPairSet = MakePairSet(Int)(String)
let () =
let s = IntPairSet.(empty
|> add 1 "hello"
|> add 2 "world"
|> add 1 "hello") in
Printf.printf "Mem: %b\n" (IntPairSet.mem 1 "hello" s) (* true *)
标准库中的函子汇总
| 函子 | 参数签名 | 产出 | 用途 |
|---|---|---|---|
Set.Make | OrderedType | 集合模块 | 有序集合 |
Map.Make | OrderedType | 映射模块 | 有序键值映射 |
Hashtbl.Make | HashedType | 哈希表模块 | 自定义哈希 |
Weak.Make | HashedType | 弱引用模块 | 缓存 |
MoreLabels.Set.Make | OrderedType | 标签化集合 | 带标签参数 |
MoreLabels.Map.Make | OrderedType | 标签化映射 | 带标签参数 |
实际业务场景:插件系统
module type PLUGIN = sig
val name : string
val version : string
val execute : string -> string
end
module PluginRegistry = struct
let registry : (string, (module PLUGIN)) Hashtbl.t =
Hashtbl.create 16
let register (module P : PLUGIN) =
Hashtbl.add registry P.name (module P : PLUGIN)
let execute plugin_name input =
match Hashtbl.find_opt registry plugin_name with
| None -> Error (Printf.sprintf "Plugin %s not found" plugin_name)
| Some (module P : PLUGIN) -> Ok (P.execute input)
end
module UpperPlugin : PLUGIN = struct
let name = "upper"
let version = "1.0"
let execute = String.uppercase_ascii
end
module ReversePlugin : PLUGIN = struct
let name = "reverse"
let version = "1.0"
let execute s =
let len = String.length s in
String.init len (fun i -> s.[len - 1 - i])
end
let () =
PluginRegistry.register (module UpperPlugin);
PluginRegistry.register (module ReversePlugin);
match PluginRegistry.execute "upper" "hello" with
| Ok result -> Printf.printf "Result: %s\n" result
| Error msg -> Printf.printf "Error: %s\n" msg