11 - 代理与中间件
第 11 章:代理与中间件
11.1 MySQL 代理概述
MySQL 代理(Proxy)位于客户端与服务器之间,拦截并转发 MySQL 协议数据包。代理可以实现连接池、读写分离、负载均衡、SQL 审计等功能。
代理架构
┌──────────┐   MySQL 协议   ┌──────────┐   MySQL 协议   ┌──────────┐
│  Client  │ ─────────────→ │  Proxy   │ ─────────────→ │  MySQL   │
│          │ ←───────────── │          │ ←───────────── │  Server  │
└──────────┘                └──────────┘                └──────────┘
代理可以:
- 修改数据包
- 路由到不同后端
- 缓存结果
- 记录 SQL
主流 MySQL 代理/中间件
| 产品 | 语言 | 主要功能 |
|---|---|---|
| ProxySQL | C++ | 连接池、读写分离、查询缓存、故障转移 |
| MyCat | Java | 分库分表、读写分离、数据路由 |
| MySQL Router | C++ | 官方路由组件、InnoDB Cluster 集成 |
| MaxScale | C++ | MariaDB 代理、读写分离、过滤器 |
| Vitess | Go | YouTube 开源、水平扩展、分片 |
| Kingshard | Go | 轻量级分库分表代理 |
11.2 代理的基本工作原理
代理的两种模式
模式一:透传代理(Transparent Proxy)
代理不解析 MySQL 协议,仅转发 TCP 数据包。类似于 TCP 负载均衡器。
优点: 低延迟、实现简单
缺点: 无法理解 SQL、无法做读写分离
场景: 简单的故障转移、连接转发
模式二:协议解析代理(Protocol-Aware Proxy)
代理完全解析 MySQL 协议,理解每个数据包的含义。
优点: 可以做 SQL 路由、缓存、审计
缺点: 实现复杂、增加延迟
场景: 读写分离、分库分表、SQL 审计
11.3 协议解析代理实现
代理的核心流程
"""
mysql_proxy_basic.py
MySQL 代理的基本框架
"""
import socket
import struct
import select
import threading
class MySQLProxy:
    """Protocol-aware MySQL proxy (teaching sketch).

    Listens on ``listen_host:listen_port`` and, for every client, opens a
    dedicated connection to ``backend_host:backend_port``. The connection
    phase is relayed packet by packet, then command/response packets are
    relayed in both directions while COM_QUERY text is logged.

    Packets are forwarded with their original sequence ids: a result set is
    many packets with increasing ids, and rewriting them makes the client
    reject the response.

    Limitations: this proxy cannot terminate TLS, so the CLIENT_SSL
    capability is hidden from clients (they fall back to plaintext, or fail
    fast if they require TLS). One thread serves each client connection.
    """

    # CLIENT_SSL bit in the lower 16 capability flags of HandshakeV10.
    CLIENT_SSL = 0x0800

    def __init__(self, listen_host='0.0.0.0', listen_port=3307,
                 backend_host='127.0.0.1', backend_port=3306):
        self.listen_host = listen_host
        self.listen_port = listen_port
        self.backend_host = backend_host
        self.backend_port = backend_port

    def start(self):
        """Accept clients forever, serving each one on a daemon thread."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock:
            server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server_sock.bind((self.listen_host, self.listen_port))
            server_sock.listen(128)
            print(f"[+] 代理监听 {self.listen_host}:{self.listen_port}")
            print(f"[+] 后端服务器 {self.backend_host}:{self.backend_port}")
            while True:
                client_sock, client_addr = server_sock.accept()
                print(f"[+] 新连接: {client_addr}")
                thread = threading.Thread(
                    target=self.handle_connection,
                    args=(client_sock, client_addr),
                    daemon=True,
                )
                thread.start()

    def handle_connection(self, client_sock, client_addr):
        """Serve one client end to end; both sockets are always closed."""
        backend_sock = None
        try:
            backend_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            backend_sock.connect((self.backend_host, self.backend_port))
            # Only relay commands once the backend has accepted the login.
            if self.proxy_handshake(client_sock, backend_sock):
                self.relay_loop(client_sock, backend_sock)
        except Exception as e:
            # Top-level boundary of the per-connection thread: log and drop.
            print(f"[-] 连接 {client_addr} 错误: {e}")
        finally:
            client_sock.close()
            if backend_sock:
                backend_sock.close()
            print(f"[-] 连接 {client_addr} 关闭")

    def proxy_handshake(self, client_sock, backend_sock) -> bool:
        """Relay the connection phase between client and backend.

        Packets keep their original sequence ids. Extra authentication round
        trips (AuthSwitchRequest 0xFE, AuthMoreData 0x01) are relayed until
        the backend answers OK or ERR.

        Returns:
            True if the backend sent OK, False if it sent ERR.

        Raises:
            ConnectionError: if either side closes during the handshake.
        """
        # 1. The backend speaks first: HandshakeV10 (or ERR, e.g. host blocked).
        backend_handshake = self._read_required(backend_sock)
        print(f" [握手] 收到后端 HandshakeV10 ({len(backend_handshake)} 字节)")
        if backend_handshake[4:5] == b'\xff':
            client_sock.sendall(backend_handshake)
            print(" [握手] 认证失败 ✗")
            return False
        # 2. Forward it without CLIENT_SSL: an SSLRequest would otherwise be
        #    followed by a TLS handshake this proxy cannot take part in.
        client_sock.sendall(self._strip_ssl_capability(backend_handshake))
        # 3. The client's HandshakeResponse.
        client_response = self._read_required(client_sock)
        print(f" [握手] 收到客户端认证响应 ({len(client_response)} 字节)")
        # 4. Credentials could be rewritten here, e.g.
        # client_response = modify_auth(client_response)
        backend_sock.sendall(client_response)
        # 5. Relay until the backend reports the final result.
        while True:
            reply = self._read_required(backend_sock)
            client_sock.sendall(reply)
            status = reply[4:5]
            if status == b'\x00':
                print(" [握手] 认证成功 ✓")
                return True
            if status == b'\xff':
                print(" [握手] 认证失败 ✗")
                return False
            if status == b'\x01' and reply[5:6] == b'\x03':
                # caching_sha2_password fast-auth success: the server sends
                # OK next without waiting for the client.
                continue
            # AuthSwitchRequest / AuthMoreData: the client must answer.
            backend_sock.sendall(self._read_required(client_sock))

    def relay_loop(self, client_sock, backend_sock):
        """Relay packets both ways until a side closes or 30 s pass idle."""
        while True:
            readable, _, _ = select.select(
                [client_sock, backend_sock], [], [], 30
            )
            if not readable:
                print(" [超时] 连接空闲超时")
                break
            for sock in readable:
                if sock is client_sock:
                    # Client -> backend
                    packet = self.read_packet(client_sock)
                    if not packet:
                        return
                    self._log_command(packet)
                    # SQL routing / filtering / auditing hooks go here, e.g.
                    # if should_route_to_slave(sql):
                    #     backend_sock = slave_sock
                    backend_sock.sendall(packet)
                elif sock is backend_sock:
                    # Backend -> client, sequence id untouched.
                    packet = self.read_packet(backend_sock)
                    if not packet:
                        return
                    client_sock.sendall(packet)

    def _log_command(self, packet: bytes):
        """Log the command in a client packet (byte 4 follows the 4-byte header)."""
        if len(packet) < 5:
            return
        cmd_type = packet[4]
        if cmd_type == 0x03:  # COM_QUERY: the SQL text follows the command byte
            sql = packet[5:].decode('utf-8', errors='replace')
            print(f" [查询] {sql[:80]}")
        elif cmd_type == 0x01:  # COM_QUIT
            print(" [断开] 客户端请求断开")
        else:
            print(f" [命令] {self.get_command_name(cmd_type)}")

    def read_packet(self, sock) -> bytes:
        """Read one MySQL packet; return header + payload, or b'' on EOF."""
        header = b''
        while len(header) < 4:
            chunk = sock.recv(4 - len(header))
            if not chunk:
                return b''
            header += chunk
        pkt_len = struct.unpack('<I', header[0:3] + b'\x00')[0]
        payload = b''
        while len(payload) < pkt_len:
            chunk = sock.recv(pkt_len - len(payload))
            if not chunk:
                return b''
            payload += chunk
        return header + payload

    def _read_required(self, sock) -> bytes:
        """Like read_packet, but raise ConnectionError on EOF."""
        packet = self.read_packet(sock)
        if not packet:
            raise ConnectionError("connection closed during handshake")
        return packet

    def send_packet(self, sock, packet: bytes, seq: int = 0):
        """Send a packet with its sequence id replaced by ``seq``."""
        if len(packet) >= 4:
            new_header = packet[0:3] + struct.pack('B', seq & 0xFF)
            sock.sendall(new_header + packet[4:])
        else:
            sock.sendall(packet)

    def _strip_ssl_capability(self, packet: bytes) -> bytes:
        """Return a HandshakeV10 packet with CLIENT_SSL cleared.

        Layout after the 4-byte header: protocol version (0x0A), NUL-terminated
        server version, 4-byte thread id, 8 bytes of auth data, 1 filler byte,
        then the lower 2 capability bytes. Malformed input is returned as-is.
        """
        if packet[4:5] != b'\x0a':
            return packet
        nul = packet.find(b'\x00', 5)
        if nul < 0:
            return packet
        cap_offset = nul + 1 + 4 + 8 + 1
        if len(packet) < cap_offset + 2:
            return packet
        (caps,) = struct.unpack_from('<H', packet, cap_offset)
        patched = struct.pack('<H', caps & ~self.CLIENT_SSL)
        return packet[:cap_offset] + patched + packet[cap_offset + 2:]

    def get_command_name(self, cmd_type: int) -> str:
        """Return the COM_* name for a command byte, or UNKNOWN(0x..)."""
        names = {
            0x00: "COM_SLEEP", 0x01: "COM_QUIT", 0x02: "COM_INIT_DB",
            0x03: "COM_QUERY", 0x04: "COM_FIELD_LIST", 0x07: "COM_REFRESH",
            0x08: "COM_STATISTICS", 0x0E: "COM_PING", 0x11: "COM_CHANGE_USER",
            0x16: "COM_STMT_PREPARE", 0x17: "COM_STMT_EXECUTE",
            0x18: "COM_STMT_SEND_LONG_DATA", 0x19: "COM_STMT_CLOSE",
            0x1A: "COM_STMT_RESET", 0x1C: "COM_STMT_FETCH",
            0x1F: "COM_RESET_CONNECTION",
        }
        return names.get(cmd_type, f"UNKNOWN(0x{cmd_type:02X})")
if __name__ == '__main__':
    # Demo entry point: accept clients on port 3307 and forward them to the
    # local MySQL server on port 3306.
    proxy = MySQLProxy(listen_port=3307, backend_host='127.0.0.1', backend_port=3306)
    proxy.start()
11.4 连接池实现
连接池的核心原理
┌──────────────────────────────────────────────┐
│ Connection Pool │
│ │
│ ┌────┐ ┌────┐ ┌────┐ ┌────┐ ┌────┐ │
│ │ C1 │ │ C2 │ │ C3 │ │ C4 │ │ C5 │ ← 空闲 │
│ └──┬─┘ └────┘ └────┘ └────┘ └────┘ │
│ │ │
│ ↓ 使用中 │
│ ┌────┐ │
│ │ C6 │ ← 被借出 │
│ └────┘ │
│ │
│ min_idle: 5 max_size: 20 │
└──────────────────────────────────────────────┘
Python 简易连接池
"""
mysql_connection_pool.py
MySQL 连接池实现
"""
import socket
import struct
import threading
import time
from queue import Queue, Empty
from contextlib import contextmanager
class PooledConnection:
    """A backend socket owned by the pool, plus bookkeeping fields."""

    def __init__(self, sock, connection_id):
        self.sock = sock
        self.connection_id = connection_id
        self.created_at = time.time()
        self.last_used = time.time()
        self.in_use = False

    def _recv_exact(self, n: int) -> bytes:
        """Read exactly ``n`` bytes; TCP recv may return fewer per call.

        Raises:
            ConnectionError: if the server closes the socket first.
        """
        buf = b''
        while len(buf) < n:
            chunk = self.sock.recv(n - len(buf))
            if not chunk:
                raise ConnectionError("connection closed by server")
            buf += chunk
        return buf

    def _simple_command(self, command: int) -> bytes:
        """Send a payload-less command packet (sequence id 0) and return the
        payload of the single response packet."""
        self.sock.sendall(struct.pack('<I', 1)[:3] + b'\x00' + bytes([command]))
        header = self._recv_exact(4)
        pkt_len = struct.unpack('<I', header[0:3] + b'\x00')[0]
        return self._recv_exact(pkt_len)

    def is_alive(self) -> bool:
        """Send COM_PING; True only if the server answers with an OK packet.

        Any socket error, including a closed peer, counts as "not alive".
        """
        try:
            response = self._simple_command(0x0E)  # COM_PING
        except OSError:
            return False
        return response[:1] == b'\x00'

    def reset(self) -> bool:
        """Send COM_RESET_CONNECTION; return True if the server replied OK.

        Best-effort: socket errors are reported as False rather than raised.
        """
        try:
            response = self._simple_command(0x1F)  # COM_RESET_CONNECTION
        except OSError:
            return False
        return response[:1] == b'\x00'
class MySQLConnectionPool:
    """Pool of authenticated MySQL connections (teaching sketch).

    ``min_idle`` connections are opened eagerly. More are opened on demand
    while fewer than ``max_size`` exist; after that, borrowers wait up to
    5 seconds for a connection to be returned. ``max_idle_time`` is stored
    but no idle eviction is implemented here.

    NOTE(review): ``_build_auth`` is a placeholder, so connection creation
    fails authentication against a real server until it is implemented.
    """

    def __init__(self, host, port, user, password, database,
                 min_idle=5, max_size=20, max_idle_time=300):
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.min_idle = min_idle
        self.max_size = max_size
        self.max_idle_time = max_idle_time
        self._pool = Queue(maxsize=max_size)  # idle connections
        self._all_connections = []  # every open connection, idle or borrowed
        self._lock = threading.Lock()  # guards _size, _all_connections, _closed
        # Open connections plus creations in flight. A slot is reserved before
        # connecting so concurrent borrowers cannot exceed max_size.
        self._size = 0
        self._closed = False
        # Open the minimum number of connections up front.
        self._init_pool()

    def _init_pool(self):
        """Open up to ``min_idle`` connections and park them as idle."""
        for _ in range(self.min_idle):
            conn = self._create_connection()
            if conn:
                self._pool.put(conn)

    @staticmethod
    def _recv_exact(sock, n: int) -> bytes:
        """Read exactly ``n`` bytes or raise ConnectionError on EOF."""
        buf = b''
        while len(buf) < n:
            chunk = sock.recv(n - len(buf))
            if not chunk:
                raise ConnectionError("connection closed by server")
            buf += chunk
        return buf

    def _read_payload(self, sock) -> bytes:
        """Read one MySQL packet and return its payload (header stripped)."""
        header = self._recv_exact(sock, 4)
        pkt_len = struct.unpack('<I', header[0:3] + b'\x00')[0]
        return self._recv_exact(sock, pkt_len)

    @staticmethod
    def _parse_connection_id(handshake: bytes) -> int:
        """Extract the connection id from a HandshakeV10 payload (0 if malformed).

        Layout: protocol version (1 byte), NUL-terminated server version,
        then the 4-byte little-endian connection id.
        """
        nul = handshake.find(b'\x00', 1)
        if nul < 0 or len(handshake) < nul + 5:
            return 0
        return struct.unpack_from('<I', handshake, nul + 1)[0]

    def _reserve_slot(self) -> bool:
        """Claim capacity for one new connection; False if full or closed."""
        with self._lock:
            if self._closed or self._size >= self.max_size:
                return False
            self._size += 1
            return True

    def _release_slot(self):
        """Give back a slot reserved by a creation attempt that failed."""
        with self._lock:
            self._size -= 1

    def _create_connection(self) -> PooledConnection:
        """Open and authenticate a new connection, or return None.

        Returns None when the pool is full or closed, or when connecting or
        authenticating fails. On failure the error is logged and the socket
        is closed.
        """
        if not self._reserve_slot():
            return None
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect((self.host, self.port))
            # Read the server greeting (HandshakeV10).
            handshake = self._read_payload(sock)
            # Send the authentication response (simplified; sequence id 1).
            auth = self._build_auth(handshake)
            sock.sendall(struct.pack('<I', len(auth))[:3] + b'\x01' + auth)
            # Read the authentication result.
            result = self._read_payload(sock)
            if result[:1] != b'\x00':
                raise ConnectionError("认证失败")
            conn_id = self._parse_connection_id(handshake)
        except Exception as e:
            # Best-effort creation: log, clean up, and let the caller decide.
            print(f" [池] 创建连接失败: {e}")
            if sock is not None:
                sock.close()
            self._release_slot()
            return None
        conn = PooledConnection(sock, conn_id)
        with self._lock:
            self._all_connections.append(conn)
            total = self._size
        print(f" [池] 创建连接 #{conn_id} (总: {total})")
        return conn

    def _build_auth(self, handshake):
        """Build the authentication response (placeholder).

        A real implementation is covered in an earlier chapter.
        """
        return b'\x00' * 32

    def _discard(self, conn):
        """Close a connection and forget it, freeing its slot exactly once."""
        with self._lock:
            if conn in self._all_connections:
                self._all_connections.remove(conn)
                self._size -= 1
        try:
            conn.sock.close()
        except OSError:
            pass

    def _acquire(self):
        """Return an idle connection, open a new one if allowed, else wait 5 s."""
        try:
            return self._pool.get_nowait()
        except Empty:
            pass
        conn = self._create_connection()
        if conn is not None:
            return conn
        return self._pool.get(timeout=5)

    def _release(self, conn):
        """Return a borrowed connection, replacing it if the health check fails."""
        if not self._closed and conn.is_alive():
            self._pool.put(conn)
            return
        self._discard(conn)
        if not self._closed:
            new_conn = self._create_connection()
            if new_conn:
                self._pool.put(new_conn)

    @contextmanager
    def get_connection(self):
        """Borrow a connection for the duration of a ``with`` block.

        Raises:
            queue.Empty: if no connection became available within 5 seconds.
        """
        conn = self._acquire()
        conn.in_use = True
        conn.last_used = time.time()
        try:
            yield conn
        finally:
            conn.in_use = False
            self._release(conn)

    def execute(self, sql: str):
        """Send ``sql`` as COM_QUERY and return the first response payload.

        NOTE(review): for result sets only the first packet (column count) is
        read. The remaining packets stay on the socket, so the health check
        on return will most likely discard and replace the connection.

        Raises:
            ValueError: if the query does not fit in a single MySQL packet.
        """
        payload = b'\x03' + sql.encode('utf-8')
        if len(payload) >= 0xFFFFFF:
            raise ValueError("query too large for a single MySQL packet")
        with self.get_connection() as conn:
            conn.sock.sendall(
                struct.pack('<I', len(payload))[:3] + b'\x00' + payload
            )
            return self._read_payload(conn.sock)

    def close(self):
        """Close the pool.

        Idle connections are closed now. Borrowed ones are closed when they
        are returned.
        """
        with self._lock:
            self._closed = True
        while True:
            try:
                conn = self._pool.get_nowait()
            except Empty:
                break
            self._discard(conn)
        print("[池] 连接池已关闭")
# 使用示例
def pool_example():
    """Demonstrate borrowing a connection from MySQLConnectionPool.

    The pool is closed in ``finally`` so its sockets are released even if the
    body raises.
    """
    pool = MySQLConnectionPool(
        host='127.0.0.1', port=3306,
        user='app', password='password',
        database='mydb', min_idle=5, max_size=20
    )
    try:
        # Borrow a connection; it goes back to the pool when the block exits.
        with pool.get_connection() as conn:
            # Run queries on `conn` here.
            pass
    finally:
        pool.close()
11.5 读写分离
读写分离原理
"""
mysql_read_write_split.py
读写分离路由实现
"""
import re
class ReadWriteRouter:
    """Route SQL statements to a master or a replica pool by statement type.

    Writes, locking reads, session-dependent reads and anything unrecognised
    go to the master. Plain reads go to the replica.
    """

    # Statements that only read data (matched at the start of the SQL).
    READ_PATTERNS = [
        re.compile(r'^\s*SELECT\b', re.IGNORECASE),
        re.compile(r'^\s*SHOW\b', re.IGNORECASE),
        re.compile(r'^\s*DESCRIBE\b', re.IGNORECASE),
        re.compile(r'^\s*DESC\b', re.IGNORECASE),
        re.compile(r'^\s*EXPLAIN\b', re.IGNORECASE),
    ]

    # Statements that modify data or schema.
    WRITE_PATTERNS = [
        re.compile(r'^\s*INSERT\b', re.IGNORECASE),
        re.compile(r'^\s*UPDATE\b', re.IGNORECASE),
        re.compile(r'^\s*DELETE\b', re.IGNORECASE),
        re.compile(r'^\s*REPLACE\b', re.IGNORECASE),
        re.compile(r'^\s*CREATE\b', re.IGNORECASE),
        re.compile(r'^\s*ALTER\b', re.IGNORECASE),
        re.compile(r'^\s*DROP\b', re.IGNORECASE),
        re.compile(r'^\s*TRUNCATE\b', re.IGNORECASE),
        re.compile(r'^\s*GRANT\b', re.IGNORECASE),
        re.compile(r'^\s*REVOKE\b', re.IGNORECASE),
    ]

    # Reads that must still run on the master: row-locking reads take locks
    # that only matter on the primary, and LAST_INSERT_ID() is per-connection.
    MASTER_ONLY_READ_PATTERNS = [
        re.compile(r'\bFOR\s+(?:UPDATE|SHARE)\b', re.IGNORECASE),
        re.compile(r'\bLOCK\s+IN\s+SHARE\s+MODE\b', re.IGNORECASE),
        re.compile(r'\bLAST_INSERT_ID\s*\(', re.IGNORECASE),
    ]

    def __init__(self, master_pool, slave_pool):
        self.master_pool = master_pool
        self.slave_pool = slave_pool

    def route_query(self, sql: str):
        """Choose the pool for ``sql``.

        Returns:
            A ``(pool, sql)`` tuple; the SQL text is returned unchanged.

        NOTE: transaction state is not tracked. A real proxy must pin every
        statement inside BEGIN ... COMMIT to the master.
        """
        if any(p.match(sql) for p in self.WRITE_PATTERNS):
            return self.master_pool, sql
        if any(p.match(sql) for p in self.READ_PATTERNS):
            if any(p.search(sql) for p in self.MASTER_ONLY_READ_PATTERNS):
                return self.master_pool, sql
            return self.slave_pool, sql
        # Unknown statement types go to the master (safe default).
        return self.master_pool, sql

    def execute(self, sql: str):
        """Route ``sql`` and execute it on the chosen pool."""
        pool, routed_sql = self.route_query(sql)
        backend = "主库" if pool is self.master_pool else "从库"
        print(f" [路由] {backend}: {sql[:60]}")
        return pool.execute(routed_sql)
# 使用示例
def rw_split_example():
    """Route a few statements through ReadWriteRouter (one master, one replica).

    The original listing passed ``...`` after keyword arguments, which is a
    SyntaxError; explicit placeholder credentials are used instead. Both pools
    are closed even if a query fails.
    """
    credentials = dict(user='app', password='password', database='mydb')
    master = MySQLConnectionPool(host='master.db', port=3306, **credentials)
    slave = MySQLConnectionPool(host='slave.db', port=3306, **credentials)
    try:
        router = ReadWriteRouter(master, slave)
        # Automatic routing
        router.execute("SELECT * FROM users")           # → replica
        router.execute("INSERT INTO orders VALUES()")   # → master
        router.execute("UPDATE users SET name='x'")     # → master
    finally:
        master.close()
        slave.close()
11.6 SQL 审计与过滤
class SQLAuditor:
    """SQL audit middleware: log every statement, reject blocklisted ones."""

    def __init__(self):
        # One dict per audited statement, in arrival order (unbounded).
        self.query_log = []
        self.blocked_patterns = [
            # Word boundaries keep identifiers like `backdrop` / `tables`
            # from triggering the rule.
            re.compile(r'\bDROP\s+TABLE\b', re.IGNORECASE),
            # TRUNCATE [TABLE] <name>. This must not match the TRUNCATE(x, d)
            # numeric function or identifiers such as `truncated_at`.
            re.compile(r'\bTRUNCATE\s+(?:TABLE\s+)?[`\w]', re.IGNORECASE),
        ]

    def audit(self, sql: str, user: str, client_addr: tuple) -> tuple:
        """Record ``sql`` in the audit log and check it against the blocklist.

        Returns:
            ``(allowed, reason)``: ``(True, "OK")``, or ``False`` with a
            message naming the matching pattern.
        """
        self.query_log.append({
            'timestamp': time.time(),
            'user': user,
            'client': client_addr,
            'sql': sql,
        })
        for pattern in self.blocked_patterns:
            if pattern.search(sql):
                return False, f"语句被安全策略阻止: {pattern.pattern}"
        # Slow queries could be flagged later, once execution time is known.
        return True, "OK"
11.7 查询缓存
class QueryCache:
    """Query-result cache keyed by SQL text, with TTL expiry and LRU eviction.

    Entries expire ``ttl`` seconds after they are stored. When the cache is
    full, the least recently used entry is evicted. ``max_size <= 0``
    disables caching.
    """

    def __init__(self, max_size=1000, ttl=60):
        # sql -> {'result': ..., 'time': ...}. Dict insertion order doubles as
        # recency order: the oldest entry comes first.
        self.cache = {}
        self.max_size = max_size
        self.ttl = ttl
        self.hits = 0
        self.misses = 0

    def get(self, sql: str):
        """Return the cached result for ``sql``, or None on a miss or expiry."""
        entry = self.cache.get(sql)
        if entry is not None:
            if time.time() - entry['time'] < self.ttl:
                # Re-insert to mark this entry as most recently used.
                del self.cache[sql]
                self.cache[sql] = entry
                self.hits += 1
                return entry['result']
            del self.cache[sql]  # expired
        self.misses += 1
        return None

    def put(self, sql: str, result):
        """Store ``result`` for ``sql``, evicting least-recently-used entries."""
        if self.max_size <= 0:
            return
        if sql in self.cache:
            # Overwriting does not grow the cache, so nothing is evicted.
            del self.cache[sql]
        else:
            while len(self.cache) >= self.max_size:
                del self.cache[next(iter(self.cache))]
        self.cache[sql] = {
            'result': result,
            'time': time.time(),
        }

    def invalidate(self, table: str):
        """Drop every entry whose SQL mentions ``table`` (case-insensitive).

        Substring matching may over-invalidate (``user`` also hits ``users``),
        which is the safe direction for a cache.
        """
        needle = table.lower()
        for sql in [key for key in self.cache if needle in key.lower()]:
            del self.cache[sql]
11.8 MyCat 风格的分库分表
class ShardingRouter:
    """Hash-modulo shard router in the style of MyCat."""

    def __init__(self, shards: dict):
        """
        shards: {
            'shard_0': {'host': '10.0.0.1', 'port': 3306},
            'shard_1': {'host': '10.0.0.2', 'port': 3306},
            'shard_2': {'host': '10.0.0.3', 'port': 3306},
        }
        """
        self.shards = shards
        self.shard_count = len(shards)
        # Sharding column per table (informational; route() takes the value).
        self.shard_keys = {
            'orders': 'user_id',
            'users': 'id',
        }

    @staticmethod
    def _stable_hash(value) -> int:
        """Hash ``value`` identically in every process.

        Built-in ``hash()`` of str/bytes is randomized per interpreter
        (PYTHONHASHSEED), so it would send the same key to different shards
        from different processes. CRC32 is used for those types instead. Ints
        keep ``hash()``, which is deterministic for them.
        """
        import zlib
        if isinstance(value, str):
            value = value.encode('utf-8')
        if isinstance(value, (bytes, bytearray)):
            return zlib.crc32(value)
        return hash(value)

    def route(self, table: str, shard_key_value=None) -> list:
        """Return the list of shard names to send the query to.

        With no shard-key value, the query is broadcast to all shards.

        Raises:
            ValueError: if a key is given but no shards are configured.
        """
        if shard_key_value is None:
            # Full-table query: broadcast to every shard.
            return list(self.shards.keys())
        if self.shard_count == 0:
            raise ValueError("no shards configured")
        shard_index = self._stable_hash(shard_key_value) % self.shard_count
        shard_name = f'shard_{shard_index}'
        if shard_name not in self.shards:
            # Shards are not named shard_0..shard_{n-1}: use configuration order.
            shard_name = list(self.shards)[shard_index]
        return [shard_name]

    def rewrite_sql(self, sql: str, target_shard: str) -> str:
        """Rewrite SQL for a shard (placeholder: returned unchanged).

        A real implementation needs an SQL parser.
        """
        return sql
11.9 协议解析的挑战
挑战一:分包重组
一次 TCP 读取可能只拿到半个 MySQL 包,也可能同时拿到多个包;而长度达到 16MB(0xFFFFFF 字节)的消息会被拆成多个连续的 MySQL 包,直到出现一个长度小于 0xFFFFFF 的包为止。代理必须按包头中的长度字段正确重组:
class PacketBuffer:
    """Reassemble MySQL messages from arbitrarily split TCP chunks.

    A message is one or more packets. Packets with a 0xFFFFFF-byte payload
    are continued by the next packet, and the message ends at the first
    shorter packet. Completed messages are handed out one at a time by
    ``get_message``.
    """

    def __init__(self):
        self.buffer = b''            # raw bytes not yet assigned to a message
        self.current_message = b''   # packets (with headers) of the message in progress
        self.message_complete = False

    def feed(self, data: bytes):
        """Append received bytes and parse as much as possible."""
        self.buffer += data
        self._try_parse()

    def _try_parse(self):
        """Move whole packets from the buffer into the current message.

        Parsing stops once a message is complete. Later packets belong to the
        next message and must stay buffered until ``get_message`` consumes
        this one; otherwise two messages would be glued together.
        """
        while not self.message_complete and len(self.buffer) >= 4:
            pkt_len = struct.unpack('<I', self.buffer[0:3] + b'\x00')[0]
            end = 4 + pkt_len
            if len(self.buffer) < end:
                break  # packet not fully received yet
            self.current_message += self.buffer[:end]
            self.buffer = self.buffer[end:]
            # A payload shorter than the maximum ends the message.
            if pkt_len < 0xFFFFFF:
                self.message_complete = True

    def get_message(self) -> bytes:
        """Return the next complete message (packets with headers), or b''."""
        if not self.message_complete:
            return b''
        msg = self.current_message
        self.current_message = b''
        self.message_complete = False
        # The next message may already be sitting in the buffer.
        self._try_parse()
        return msg
挑战二:序列号管理
代理在转发数据包时可能需要修改序列号:
class SequenceManager:
    """Keep independent sequence-id counters for each relay direction."""

    def __init__(self):
        self.client_seq = 0
        self.backend_seq = 0

    def reset(self):
        """Restart both counters at 0; call this when a new command begins."""
        self.client_seq = self.backend_seq = 0

    def rewrite_packet(self, packet: bytes, direction: str) -> bytes:
        """Stamp the next sequence id for ``direction`` into byte 3 of the header.

        ``'to_backend'`` uses the backend counter; any other value uses the
        client counter. Counters wrap at 256. Packets shorter than a header
        are returned untouched and consume no id.
        """
        if len(packet) < 4:
            return packet
        to_backend = direction == 'to_backend'
        current = self.backend_seq if to_backend else self.client_seq
        following = (current + 1) & 0xFF
        if to_backend:
            self.backend_seq = following
        else:
            self.client_seq = following
        return b''.join((packet[:3], struct.pack('B', current), packet[4:]))
挑战三:多结果集处理
代理必须正确处理 SERVER_MORE_RESULTS_EXISTS 标志,将所有结果集转发完毕后再等待下一个命令。
11.10 注意事项
重要提醒
连接状态跟踪:代理必须跟踪每个连接的状态(是否在事务中、当前数据库等),以便做出正确的路由决策。
预处理语句路由:预处理语句的句柄只在创建它的连接上有效。代理不能将 COM_STMT_EXECUTE 路由到与 COM_STMT_PREPARE 不同的后端。
会话变量:SET 语句设置的会话变量只对当前连接生效。代理需要确保路由一致性,即同一会话的后续语句落在同一个后端连接上。
错误处理:后端返回错误时,代理需要正确转发错误包,并清理内部状态。
超时管理:代理需要管理客户端和后端的超时,避免资源泄漏。
SSL/TLS:如果客户端使用 SSL,代理只有自行终结 TLS(对客户端充当 TLS 服务端、对后端充当 TLS 客户端)后才能解析 MySQL 协议;做不到时,只能退化为纯 TCP 透传,或在转发握手包时清除 CLIENT_SSL 能力位,让客户端改用明文连接。
11.11 业务场景
场景一:应用层读写分离
使用 ProxySQL 实现自动读写分离:
-- ProxySQL configuration example (run against the ProxySQL admin interface)
-- Add the backend MySQL servers
INSERT INTO mysql_servers (hostgroup_id, hostname, port)
VALUES (10, 'master.db', 3306); -- writer hostgroup
INSERT INTO mysql_servers (hostgroup_id, hostname, port)
VALUES (20, 'slave1.db', 3306); -- reader hostgroup
INSERT INTO mysql_servers (hostgroup_id, hostname, port)
VALUES (20, 'slave2.db', 3306); -- reader hostgroup
-- Routing rules are evaluated in rule_id order. A rule is ignored unless
-- active=1, and apply=1 stops evaluation at the first match; without it the
-- catch-all rule 3 would send every SELECT back to the writer hostgroup.
INSERT INTO mysql_query_rules (rule_id, active, match_pattern, destination_hostgroup, apply)
VALUES (1, 1, '^SELECT.*FOR UPDATE', 10, 1); -- locking reads → writer
INSERT INTO mysql_query_rules (rule_id, active, match_pattern, destination_hostgroup, apply)
VALUES (2, 1, '^SELECT', 20, 1); -- other SELECTs → reader
INSERT INTO mysql_query_rules (rule_id, active, match_pattern, destination_hostgroup, apply)
VALUES (3, 1, '.*', 10, 1); -- everything else → writer
-- Activate the new configuration and persist it.
LOAD MYSQL SERVERS TO RUNTIME;
SAVE MYSQL SERVERS TO DISK;
LOAD MYSQL QUERY RULES TO RUNTIME;
SAVE MYSQL QUERY RULES TO DISK;
场景二:SQL 审计平台
通过代理拦截所有 SQL,写入审计日志:
Client → Proxy → MySQL
↓
审计日志 (ES/Kafka)
场景三:灰度发布
在代理层根据用户 ID 或请求特征,将部分流量路由到新版本数据库。
11.12 扩展阅读
- ProxySQL Documentation
- MyCat Documentation
- Vitess Documentation
- MySQL Router Documentation
- MaxScale Documentation