15 - 扩展
15 - 扩展:加载扩展、自定义函数与自定义聚合
15.1 SQLite 扩展概述
SQLite 支持通过**扩展(Extension)**机制添加新功能:自定义 SQL 函数、聚合、排序规则,甚至虚拟表。
| 扩展类型 | 说明 | 示例 |
|---|---|---|
| 标量函数 | 每行返回一个值 | my_func(col) |
| 聚合函数 | 多行返回一个值 | my_agg(col) |
| 排序规则 | 自定义排序逻辑 | COLLATE my_collation |
| 虚拟表 | 自定义表引擎 | CREATE VIRTUAL TABLE |
| 预编译扩展 | C 编写的共享库 | .load ./ext.so |
15.2 加载预编译扩展
15.2.1 CLI 加载
# 启动时加载
sqlite3 -cmd ".load ./my_extension" mydb.db
# 在 Shell 中加载
sqlite3 mydb.db
.load ./my_extension
# 启用扩展加载(需要编译支持)
# sqlite3 -cmd "PRAGMA enable_load_extension = ON" mydb.db
15.2.2 Python 加载
import sqlite3
conn = sqlite3.connect(':memory:')
# 启用扩展加载
conn.enable_load_extension(True)
# 加载扩展
conn.load_extension('./my_extension')
# 或使用 SQL
conn.execute("SELECT load_extension('./my_extension')")
15.2.3 常用开源扩展
| 扩展 | 说明 | 地址 |
|---|---|---|
| sqlite-vec | 向量搜索(AI) | github.com/asg017/sqlite-vec |
| sqlite-http | HTTP 请求 | github.com/asg017/sqlite-http |
| sqlite-html | HTML 解析 | github.com/asg017/sqlite-html |
| sqlite-regex | 正则表达式 | github.com/asg017/sqlite-regex |
| cr-sqlite | CRDT 复制 | github.com/nicholasgasior/cr-sqlite |
| sqlean | 工具函数集合 | github.com/nicholasgasior/sqlean |
| sqlite-lines | 按行读文件 | github.com/asg017/sqlite-lines |
15.3 自定义标量函数
15.3.1 Python 实现
import sqlite3
import hashlib
import json
import re
from datetime import datetime
conn = sqlite3.connect(':memory:')
# 注册自定义标量函数
conn.create_function('md5', 1, lambda s: hashlib.md5(s.encode()).hexdigest() if s else None)
conn.create_function('sha256', 1, lambda s: hashlib.sha256(s.encode()).hexdigest() if s else None)
conn.create_function('regexp', 2, lambda pattern, text: 1 if re.search(pattern, text or '') else 0)
conn.create_function('json_extract_path', 2, lambda data, path: json.loads(data).get(path) if data else None)
conn.create_function('now', 0, lambda: datetime.now().isoformat())
conn.create_function('coalesce_many', -1, lambda *args: next((a for a in args if a is not None), None))
# 使用
conn.execute("SELECT md5('hello')") # 5d41402abc4b2a76b9719d911017c592
conn.execute("SELECT regexp('\\d+', 'abc123def')") # 1
conn.execute("SELECT now()") # 2026-05-10T12:00:00.000000
# 创建带上下文的函数(使用类)
class AggConcat:
def __init__(self):
self.items = []
def step(self, value, sep=','):
if value is not None:
self.items.append(str(value))
self._sep = sep
def finalize(self):
return self._sep.join(self.items)
conn.create_aggregate('agg_concat', 2, AggConcat)
15.3.2 Go 实现
// 使用 github.com/mattn/go-sqlite3
package main
import (
"crypto/md5"
"database/sql"
"fmt"
_ "github.com/mattn/go-sqlite3"
)
func main() {
sql.Register("sqlite3_custom", &sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
// 注册自定义函数
conn.RegisterFunc("md5", func(s string) string {
hash := md5.Sum([]byte(s))
return fmt.Sprintf("%x", hash)
}, true)
return nil
},
})
db, _ := sql.Open("sqlite3_custom", ":memory:")
var result string
db.QueryRow("SELECT md5('hello')").Scan(&result)
fmt.Println(result)
}
15.3.3 Node.js (better-sqlite3) 实现
const Database = require('better-sqlite3');
const crypto = require('crypto');
const db = new Database(':memory:');
// 注册自定义函数
db.function('md5', (text) => {
if (text === null) return null;
return crypto.createHash('md5').update(text).digest('hex');
});
db.function('regexp', (pattern, text) => {
if (!pattern || !text) return 0;
return new RegExp(pattern).test(text) ? 1 : 0;
});
// 使用
const row = db.prepare("SELECT md5('hello') AS hash").get();
console.log(row.hash); // 5d41402abc4b2a76b9719d911017c592
15.3.4 Rust (rusqlite) 实现
use rusqlite::{Connection, Result, functions::FunctionFlags};
use md5;
fn main() -> Result<()> {
let conn = Connection::open_in_memory()?;
conn.create_scalar_function("md5", 1, FunctionFlags::SQLITE_UTF8, |ctx| {
let text = ctx.get::<String>(0)?;
Ok(format!("{:x}", md5::compute(text)))
})?;
let hash: String = conn.query_row("SELECT md5('hello')", [], |row| row.get(0))?;
println!("{}", hash);
Ok(())
}
15.4 自定义聚合函数
15.4.1 Python 聚合函数
import sqlite3
import statistics
conn = sqlite3.connect(':memory:')
# 自定义聚合:标准差
class StdDev:
def __init__(self):
self.values = []
def step(self, value):
if value is not None:
self.values.append(float(value))
def finalize(self):
if len(self.values) < 2:
return None
return statistics.stdev(self.values)
# 自定义聚合:中位数
class Median:
def __init__(self):
self.values = []
def step(self, value):
if value is not None:
self.values.append(float(value))
def finalize(self):
if not self.values:
return None
return statistics.median(self.values)
# 自定义聚合:JSON 数组收集
class JsonArrayAgg:
def __init__(self):
self.items = []
def step(self, value):
self.items.append(value)
def finalize(self):
import json
return json.dumps(self.items)
conn.create_aggregate('stddev', 1, StdDev)
conn.create_aggregate('median', 1, Median)
conn.create_aggregate('json_array_agg', 1, JsonArrayAgg)
# 使用
conn.execute("CREATE TABLE scores (id INTEGER, score REAL)")
conn.executemany("INSERT INTO scores VALUES (?, ?)", [(1, 85), (2, 92), (3, 78), (4, 95), (5, 88)])
row = conn.execute("SELECT stddev(score), median(score) FROM scores").fetchone()
print(f"标准差: {row[0]:.2f}, 中位数: {row[1]}")
row = conn.execute("SELECT json_array_agg(score) FROM scores").fetchone()
print(f"JSON 数组: {row[0]}")
15.4.2 自定义排序规则
import sqlite3
conn = sqlite3.connect(':memory:')
# 自定义排序:中文拼音排序(需要 pypinyin 库)
def pinyin_collation(s1, s2):
try:
from pypinyin import lazy_pinyin
p1 = ''.join(lazy_pinyin(s1))
p2 = ''.join(lazy_pinyin(s2))
return (p1 > p2) - (p1 < p2)
except ImportError:
return (s1 > s2) - (s1 < s2)
conn.create_collation('pinyin', pinyin_collation)
# 使用
conn.execute("CREATE TABLE names (name TEXT)")
conn.executemany("INSERT INTO names VALUES (?)", [('张三',), ('李四',), ('王五',), ('赵六',)])
rows = conn.execute("SELECT * FROM names ORDER BY name COLLATE pinyin").fetchall()
print([r[0] for r in rows])
15.5 虚拟表(概述)
15.5.1 什么是虚拟表
虚拟表是一种不存储实际数据的表——数据由自定义的 C 代码或扩展提供。
-- 虚拟表的使用方式和普通表相同
CREATE VIRTUAL TABLE my_vtable USING my_module(args);
SELECT * FROM my_vtable WHERE ...;
15.5.2 内置虚拟表
| 虚拟表 | 说明 |
|---|---|
fts5 | 全文搜索 |
fts3/fts4 | 旧版全文搜索 |
rtree | 空间索引 |
json_each/json_tree | JSON 展开 |
generate_series | 数字序列生成 |
-- generate_series 生成数字序列
SELECT * FROM generate_series(1, 10);
-- 配合其他查询
SELECT value, value * value AS square FROM generate_series(1, 10);
-- rtree 空间索引
CREATE VIRTUAL TABLE locations USING rtree(
id,
min_x, max_x,
min_y, max_y
);
INSERT INTO locations VALUES (1, 100, 200, 300, 400);
INSERT INTO locations VALUES (2, 150, 250, 350, 450);
-- 查询包含某点的区域
SELECT * FROM locations
WHERE min_x <= 180 AND max_x >= 180
AND min_y <= 350 AND max_y >= 350;
15.5.3 Python 虚拟表(使用 sqlite-vtab)
pip install sqlite-vtab
# Python 虚拟表示例框架
# 需要使用 apsw 或类似库来实现完整的虚拟表
# 简单场景建议使用内置的 generate_series 或 FTS5
15.6 C 语言扩展开发
15.6.1 基本扩展框架
// my_extension.c
#include <sqlite3ext.h>
SQLITE_EXTENSION_INIT1
// 自定义标量函数
static void my_upper_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
const unsigned char *text = sqlite3_value_text(argv[0]);
if (!text) {
sqlite3_result_null(ctx);
return;
}
int len = sqlite3_value_bytes(argv[0]);
char *result = sqlite3_malloc(len + 1);
for (int i = 0; i <= len; i++) {
result[i] = (text[i] >= 'a' && text[i] <= 'z') ? text[i] - 32 : text[i];
}
sqlite3_result_text(ctx, result, len, sqlite3_free);
}
// 扩展入口
#ifdef _WIN32
__declspec(dllexport)
#endif
int sqlite3_myextension_init(
sqlite3 *db,
char **pzErrMsg,
const sqlite3_api_routines *pApi
) {
SQLITE_EXTENSION_INIT2(pApi);
sqlite3_create_function(db, "my_upper", 1, SQLITE_UTF8 | SQLITE_DETERMINISTIC,
NULL, my_upper_func, NULL, NULL);
return SQLITE_OK;
}
15.6.2 编译扩展
# Linux
gcc -shared -fPIC -o my_extension.so my_extension.c -lsqlite3 -I/usr/include
# macOS
gcc -shared -fPIC -o my_extension.dylib my_extension.c -lsqlite3
# Windows (MinGW)
gcc -shared -o my_extension.dll my_extension.c -lsqlite3
⚠️ 注意事项
- 扩展加载默认禁用——CLI 和程序中都需要显式启用
- 自定义函数应标记
DETERMINISTIC——允许优化器缓存结果 - 注意函数中的 NULL 处理——每个参数都可能是 NULL
- 聚合函数的 finalize 必须返回值——即使没有数据也要返回 NULL
- 扩展的二进制兼容性——需要匹配 SQLite 版本和平台
enable_load_extension有安全风险——不要暴露给不信任的输入
💡 技巧
- Python 的
create_function是最简单的扩展方式 - 使用
deterministic=True提升确定性函数的性能 create_aggregate适合实现统计类函数create_collation可以实现自定义排序(如中文拼音)- sqlean 扩展包 提供了大量实用函数——避免重复造轮子
📌 业务场景
场景一:数据脱敏函数
conn.create_function('mask_phone', 1, lambda p: p[:3] + '****' + p[-4:] if p and len(p) >= 7 else p)
conn.create_function('mask_email', 1, lambda e: e[:2] + '***@' + e.split('@')[1] if e and '@' in e else e)
SELECT mask_phone(phone), mask_email(email) FROM users;
场景二:地理距离计算
import math
def haversine(lat1, lon1, lat2, lon2):
R = 6371
dlat = math.radians(lat2 - lat1)
dlon = math.radians(lon2 - lon1)
a = math.sin(dlat/2)**2 + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon/2)**2
return R * 2 * math.asin(math.sqrt(a))
conn.create_function('distance_km', 4, haversine)
SELECT name, distance_km(lat, lon, 39.9, 116.4) AS dist_km FROM stores ORDER BY dist_km LIMIT 10;
🔗 扩展阅读
📖 下一章:16 - 性能调优 —— WAL 优化、批量操作、内存映射