
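Chapter 11: Proxies and Middleware

11.1 MySQL Proxy Overview

A MySQL proxy sits between the client and the server, intercepting and forwarding MySQL protocol packets. A proxy can provide connection pooling, read/write splitting, load balancing, SQL auditing, and similar features.

Proxy architecture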

┌──────────┐   MySQL protocol    ┌──────────┐   MySQL protocol    ┌──────────┐
│  Client  │ ────────────────→  │   Proxy  │ ────────────────→  │  MySQL   │
│          │ ←────────────────  │          │ ←────────────────  │  Server  │
└──────────┘                    └──────────┘                    └──────────┘
                                The proxy can:
                                - modify packets
                                - route to different backends
                                - cache results
                                - log SQL

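Popular MySQL proxies and middleware

Product        Language   Main features
ProxySQL       C++        Connection pooling, read/write splitting, query caching, failover
MyCat          Java       Database/table sharding, read/write splitting, data routing
MySQL Router   C++        Official routing component, InnoDB Cluster integration
MaxScale       C++        MariaDB proxy, read/write splitting, filters
Vitess         Go         Open-sourced by YouTube, horizontal scaling, sharding
Kingshard      Go         Lightweight sharding proxy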

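11.2 How a Proxy Works

Two proxy modes

Mode 1: Transparent proxy

The proxy does not parse the MySQL protocol; it only forwards TCP data, much like a TCP load balancer.

Pros: low latency, simple to implement
Cons: cannot understand SQL, so it cannot do read/write splitting
Use cases: simple failover, connection forwarding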
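
As a rough illustration of this mode, the sketch below copies bytes in one direction without looking at them. The helper name `pump` and the buffer size are assumptions made for this example; a real transparent proxy would run one such loop per direction for every connection.

```python
import socket


def pump(src: socket.socket, dst: socket.socket, bufsize: int = 4096) -> int:
    """Copy raw bytes from src to dst until src reaches EOF.

    Nothing is parsed: the MySQL protocol is invisible at this level.
    Returns the number of bytes forwarded.
    """
    total = 0
    while True:
        chunk = src.recv(bufsize)
        if not chunk:          # the peer closed its sending side
            break
        dst.sendall(chunk)
        total += len(chunk)
    return total
```

Because the loop never inspects the payload, it cannot tell a SELECT from an UPDATE, which is exactly the limitation listed above.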

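Mode 2: Protocol-aware proxy

The proxy fully parses the MySQL protocol and understands what each packet means.

Pros: enables SQL routing, caching, and auditing
Cons: complex to implement, adds latency
Use cases: read/write splitting, sharding, SQL auditing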
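
The starting point for any protocol-aware proxy is the packet header: a 3-byte little-endian payload length followed by a 1-byte sequence id. The sketch below splits one packet into those parts; `parse_packet` is a name chosen for this example, not part of any library.

```python
import struct


def parse_packet(data: bytes):
    """Split one MySQL packet into (payload_length, sequence_id, payload)."""
    if len(data) < 4:
        raise ValueError("need at least a 4-byte header")
    # Pad the 3-byte length to 4 bytes so it can be unpacked as uint32
    length = struct.unpack('<I', data[0:3] + b'\x00')[0]
    seq = data[3]
    payload = data[4:4 + length]
    if len(payload) < length:
        raise ValueError("truncated payload")
    return length, seq, payload


# A COM_QUERY packet: command byte 0x03 followed by the SQL text
sql = b"SELECT 1"
raw = struct.pack('<I', len(sql) + 1)[:3] + b'\x00' + b'\x03' + sql
length, seq, payload = parse_packet(raw)
print(length, seq, payload[0], payload[1:].decode())  # 9 0 3 SELECT 1
```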

11.3 Implementing a Protocol-Aware Proxy

Core proxy flow

"""
mysql_proxy_basic.py
Basic skeleton of a MySQL proxy
"""
import socket
import struct
import select
import threading


class MySQLProxy:
    """MySQL protocol proxy"""

    def __init__(self, listen_host='0.0.0.0', listen_port=3307,
                 backend_host='127.0.0.1', backend_port=3306):
        self.listen_host = listen_host
        self.listen_port = listen_port
        self.backend_host = backend_host
        self.backend_port = backend_port

    def start(self):
        """Start the proxy"""
        server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_sock.bind((self.listen_host, self.listen_port))
        server_sock.listen(128)

        print(f"[+] Proxy listening on {self.listen_host}:{self.listen_port}")
        print(f"[+] Backend server {self.backend_host}:{self.backend_port}")

        while True:
            client_sock, client_addr = server_sock.accept()
            print(f"[+] New connection: {client_addr}")
            thread = threading.Thread(
                target=self.handle_connection,
                args=(client_sock, client_addr),
                daemon=True
            )
            thread.start()

    def handle_connection(self, client_sock, client_addr):
        """Handle one client connection"""
        backend_sock = None
        try:
            # Connect to the backend
            backend_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            backend_sock.connect((self.backend_host, self.backend_port))

            # Relay the handshake
            self.proxy_handshake(client_sock, backend_sock)

            # Enter the command relay loop
            self.relay_loop(client_sock, backend_sock)

        except Exception as e:
            print(f"[-] Connection {client_addr} error: {e}")
        finally:
            client_sock.close()
            if backend_sock:
                backend_sock.close()
            print(f"[-] Connection {client_addr} closed")

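    def proxy_handshake(self, client_sock, backend_sock):
        """Relay the connection-phase handshake.

        Packets keep their original sequence numbers. An SSLRequest
        (TLS upgrade) from the client is not handled in this sketch.
        """
        # 1. Read HandshakeV10 from the backend
        backend_handshake = self.read_packet(backend_sock)
        if not backend_handshake:
            raise ConnectionError("backend closed during handshake")
        print(f"  [handshake] HandshakeV10 from backend ({len(backend_handshake)} bytes)")

        # 2. Forward it to the client
        self.send_packet(client_sock, backend_handshake)

        while True:
            # 3. Read HandshakeResponse (or follow-up auth data) from the client
            client_response = self.read_packet(client_sock)
            if not client_response:
                raise ConnectionError("client closed during handshake")
            print(f"  [handshake] Auth response from client ({len(client_response)} bytes)")

            # 4. The authentication data could be modified here
            # client_response = modify_auth(client_response)

            # 5. Forward it to the backend
            self.send_packet(backend_sock, client_response)

            # 6. Read the backend's answer and 7. forward it to the client
            while True:
                auth_result = self.read_packet(backend_sock)
                if not auth_result:
                    raise ConnectionError("backend closed during handshake")
                self.send_packet(client_sock, auth_result)
                payload = auth_result[4:]  # skip the 4-byte packet header
                # caching_sha2_password fast-auth success (0x01 0x03) is
                # followed directly by another backend packet
                if payload[:2] != b'\x01\x03':
                    break

            if payload[:1] == b'\x00':
                print("  [handshake] Authentication succeeded ✓")
                return
            if payload[:1] == b'\xff':
                print("  [handshake] Authentication failed ✗")
                return
            # Otherwise (auth switch 0xFE, more auth data 0x01) the
            # backend expects another packet from the client

    def relay_loop(self, client_sock, backend_sock):
        """Command relay loop"""
        while True:
            # Wait for data from the client or the backend
            readable, _, _ = select.select(
                [client_sock, backend_sock], [], [], 30
            )

            if not readable:
                print("  [timeout] Connection idle timeout")
                break

            for sock in readable:
                if sock is client_sock:
                    # Client → backend
                    packet = self.read_packet(client_sock)
                    if not packet:
                        return

                    # The command byte is the first payload byte, right
                    # after the 4-byte packet header
                    payload = packet[4:]
                    cmd_type = payload[0] if payload else 0
                    cmd_name = self.get_command_name(cmd_type)

                    if cmd_type == 0x03:  # COM_QUERY
                        sql = payload[1:].decode('utf-8', errors='replace')
                        print(f"  [query] {sql[:80]}")

                        # SQL routing, filtering, and auditing could go here
                        # if should_route_to_slave(sql):
                        #     backend_sock = slave_sock

                    elif cmd_type == 0x01:  # COM_QUIT
                        print("  [quit] Client requested disconnect")
                    else:
                        print(f"  [command] {cmd_name}")

                    self.send_packet(backend_sock, packet)

                elif sock is backend_sock:
                    # Backend → client; the sequence number is passed
                    # through unchanged so multi-packet results stay valid
                    packet = self.read_packet(backend_sock)
                    if not packet:
                        return
                    self.send_packet(client_sock, packet)

    def read_packet(self, sock) -> bytes:
        """Read one complete MySQL packet (header + payload); b'' on EOF"""
        header = b''
        while len(header) < 4:
            chunk = sock.recv(4 - len(header))
            if not chunk:
                return b''
            header += chunk

        pkt_len = struct.unpack('<I', header[0:3] + b'\x00')[0]
        payload = b''
        while len(payload) < pkt_len:
            chunk = sock.recv(pkt_len - len(payload))
            if not chunk:
                return b''
            payload += chunk

        return header + payload

    def send_packet(self, sock, packet: bytes, seq=None):
        """Send a packet; pass seq to override its sequence number"""
        if seq is not None and len(packet) >= 4:
            packet = packet[0:3] + struct.pack('B', seq & 0xFF) + packet[4:]
        # sendall() keeps writing until the whole packet is sent
        sock.sendall(packet)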

    def get_command_name(self, cmd_type: int) -> str:
        """Return the name of a command byte"""
        names = {
            0x00: "COM_SLEEP", 0x01: "COM_QUIT", 0x02: "COM_INIT_DB",
            0x03: "COM_QUERY", 0x04: "COM_FIELD_LIST", 0x07: "COM_REFRESH",
            0x09: "COM_STATISTICS", 0x0E: "COM_PING", 0x11: "COM_CHANGE_USER",
            0x16: "COM_STMT_PREPARE", 0x17: "COM_STMT_EXECUTE",
            0x18: "COM_STMT_SEND_LONG_DATA", 0x19: "COM_STMT_CLOSE",
            0x1A: "COM_STMT_RESET", 0x1C: "COM_STMT_FETCH",
            0x1F: "COM_RESET_CONNECTION",
        }
        return names.get(cmd_type, f"UNKNOWN(0x{cmd_type:02X})")


if __name__ == '__main__':
    proxy = MySQLProxy(
        listen_port=3307,
        backend_host='127.0.0.1',
        backend_port=3306
    )
    proxy.start()

11.4 Implementing a Connection Pool

How a connection pool works

┌──────────────────────────────────────────────┐
│                 Connection Pool              │
│                                              │
│  ┌────┐ ┌────┐ ┌────┐ ┌────┐ ┌────┐        │
│  │ C1 │ │ C2 │ │ C3 │ │ C4 │ │ C5 │ ← idle │
│  └──┬─┘ └────┘ └────┘ └────┘ └────┘        │
│     │                                        │
│     ↓ in use                                  │
│  ┌────┐                                      │
│  │ C6 │ ← loaned                              │
│  └────┘                                      │
│                                              │
│  min_idle: 5  max_size: 20                   │
└──────────────────────────────────────────────┘

A simple connection pool in Python

"""
mysql_connection_pool.py
MySQL connection pool implementation
"""
import socket
import struct
import threading
import time
from queue import Queue, Empty
from contextlib import contextmanager


def recv_exact(sock, n: int) -> bytes:
    """Read exactly n bytes; a single recv() may return fewer"""
    data = b''
    while len(data) < n:
        chunk = sock.recv(n - len(data))
        if not chunk:
            raise ConnectionError("connection closed by peer")
        data += chunk
    return data


def read_payload(sock) -> bytes:
    """Read one MySQL packet and return its payload"""
    header = recv_exact(sock, 4)
    pkt_len = struct.unpack('<I', header[0:3] + b'\x00')[0]
    return recv_exact(sock, pkt_len)


class PooledConnection:
    """A pooled database connection"""
    def __init__(self, sock, connection_id):
        self.sock = sock
        self.connection_id = connection_id
        self.created_at = time.time()
        self.last_used = time.time()
        self.in_use = False

    def is_alive(self) -> bool:
        """Check whether the connection is still alive"""
        try:
            # Send COM_PING
            payload = b'\x0E'
            packet = struct.pack('<I', 1)[:3] + b'\x00' + payload
            self.sock.sendall(packet)

            # An OK packet (0x00) means the server answered the ping
            response = read_payload(self.sock)
            return response[:1] == b'\x00'
        except Exception:
            return False

    def reset(self):
        """Reset the session state of the connection"""
        try:
            payload = b'\x1F'  # COM_RESET_CONNECTION
            packet = struct.pack('<I', 1)[:3] + b'\x00' + payload
            self.sock.sendall(packet)
            read_payload(self.sock)
        except Exception:
            pass


class MySQLConnectionPool:
    """MySQL connection pool"""

    def __init__(self, host, port, user, password, database,
                 min_idle=5, max_size=20, max_idle_time=300):
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.min_idle = min_idle
        self.max_size = max_size
        self.max_idle_time = max_idle_time

        self._pool = Queue(maxsize=max_size)
        self._all_connections = []
        self._lock = threading.Lock()
        self._size = 0

        # Create the minimum number of connections up front
        self._init_pool()

    def _init_pool(self):
        """Fill the pool with min_idle connections"""
        for _ in range(self.min_idle):
            conn = self._create_connection()
            if conn:
                self._pool.put(conn)

    def _create_connection(self) -> "PooledConnection | None":
        """Create a new database connection; returns None on failure"""
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect((self.host, self.port))

            # Read the handshake packet
            handshake = read_payload(sock)

            # Send the auth response (simplified)
            auth = self._build_auth(handshake)
            packet = struct.pack('<I', len(auth))[:3] + b'\x01' + auth
            sock.sendall(packet)

            # Read the auth result
            result = read_payload(sock)
            if result[:1] != b'\x00':
                raise Exception("authentication failed")

            # HandshakeV10 payload: protocol version (1 byte), server
            # version (NUL-terminated string), then a 4-byte connection id
            version_end = handshake.index(b'\x00', 1)
            conn_id = struct.unpack(
                '<I', handshake[version_end + 1:version_end + 5])[0]

            conn = PooledConnection(sock, conn_id)
            with self._lock:
                self._size += 1
                self._all_connections.append(conn)
            print(f"  [pool] Created connection #{conn_id} (total: {self._size})")
            return conn

        except Exception as e:
            if sock is not None:
                sock.close()
            print(f"  [pool] Failed to create connection: {e}")
            return None

    def _build_auth(self, handshake):
        """Build the auth response (simplified placeholder)"""
        # See the earlier chapters for a real implementation
        return b'\x00' * 32

    @contextmanager
    def get_connection(self):
        """Borrow a connection (context manager)"""
        conn = self._pool.get(timeout=5)
        conn.in_use = True
        conn.last_used = time.time()
        try:
            yield conn
        finally:
            conn.in_use = False
            # Check liveness before returning the connection to the pool
            if conn.is_alive():
                self._pool.put(conn)
            else:
                # The connection is dead: close it and replace it
                conn.sock.close()
                with self._lock:
                    self._size -= 1
                    if conn in self._all_connections:
                        self._all_connections.remove(conn)
                new_conn = self._create_connection()
                if new_conn:
                    self._pool.put(new_conn)

    def execute(self, sql: str):
        """Run a query on a pooled connection"""
        with self.get_connection() as conn:
            payload = b'\x03' + sql.encode('utf-8')
            packet = struct.pack('<I', len(payload))[:3] + b'\x00' + payload
            conn.sock.sendall(packet)

            # Read the response (simplified): only the first packet is
            # read, so a multi-packet result set leaves data on the socket
            return read_payload(conn.sock)

    def close(self):
        """Close the idle connections in the pool"""
        while not self._pool.empty():
            try:
                conn = self._pool.get_nowait()
                conn.sock.close()
            except Empty:
                break
        print("[pool] Connection pool closed")


# Usage example
def pool_example():
    pool = MySQLConnectionPool(
        host='127.0.0.1', port=3306,
        user='app', password='password',
        database='mydb', min_idle=5, max_size=20
    )

    # Borrow a connection
    with pool.get_connection() as conn:
        # Run queries here
        pass

    pool.close()

11.5 Read/Write Splitting

How read/write splitting works

"""
mysql_read_write_split.py
Read/write splitting router
"""
import re

from mysql_connection_pool import MySQLConnectionPool  # defined in section 11.4


class ReadWriteRouter:
    """Read/write splitting router"""

    # Patterns that match read-only statements
    READ_PATTERNS = [
        re.compile(r'^\s*SELECT\b', re.IGNORECASE),
        re.compile(r'^\s*SHOW\b', re.IGNORECASE),
        re.compile(r'^\s*DESCRIBE\b', re.IGNORECASE),
        re.compile(r'^\s*EXPLAIN\b', re.IGNORECASE),
    ]

    # Patterns that match write statements
    WRITE_PATTERNS = [
        re.compile(r'^\s*INSERT\b', re.IGNORECASE),
        re.compile(r'^\s*UPDATE\b', re.IGNORECASE),
        re.compile(r'^\s*DELETE\b', re.IGNORECASE),
        re.compile(r'^\s*REPLACE\b', re.IGNORECASE),
        re.compile(r'^\s*CREATE\b', re.IGNORECASE),
        re.compile(r'^\s*ALTER\b', re.IGNORECASE),
        re.compile(r'^\s*DROP\b', re.IGNORECASE),
        re.compile(r'^\s*TRUNCATE\b', re.IGNORECASE),
        re.compile(r'^\s*GRANT\b', re.IGNORECASE),
        re.compile(r'^\s*REVOKE\b', re.IGNORECASE),
    ]

    def __init__(self, master_pool, slave_pool):
        self.master_pool = master_pool
        self.slave_pool = slave_pool

    def route_query(self, sql: str):
        """
        Route a statement to the master or a replica based on its type

        Returns: a (pool, sql) tuple
        """
        # Check whether a transaction is open
        # (a real implementation has to track session state)

        # Writes → master
        for pattern in self.WRITE_PATTERNS:
            if pattern.match(sql):
                return self.master_pool, sql

        # Read-only statements → replica
        for pattern in self.READ_PATTERNS:
            if pattern.match(sql):
                return self.slave_pool, sql

        # Default → master (the safe choice)
        return self.master_pool, sql

    def execute(self, sql: str):
        """Route and execute a query"""
        pool, routed_sql = self.route_query(sql)
        backend = "master" if pool is self.master_pool else "replica"
        print(f"  [route] {backend}: {sql[:60]}")
        return pool.execute(routed_sql)


# Usage example
def rw_split_example():
    master = MySQLConnectionPool(host='master.db', port=3306, ...)
    slave = MySQLConnectionPool(host='slave.db', port=3306, ...)

    router = ReadWriteRouter(master, slave)

    # Automatic routing
    router.execute("SELECT * FROM users")          # → replica
    router.execute("INSERT INTO orders VALUES()")   # → master
    router.execute("UPDATE users SET name='x'")    # → master
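
The router above leaves transaction tracking as a note in a comment, and it would send a locking read such as SELECT ... FOR UPDATE to a replica. One possible way to cover both cases is sketched below. `SessionRouter` and its regular expressions are assumptions made for this example; it returns the strings 'master' and 'replica' instead of pools, and it ignores sessions that run with autocommit disabled.

```python
import re

_BEGIN = re.compile(r'^\s*(BEGIN|START\s+TRANSACTION)\b', re.IGNORECASE)
# ROLLBACK TO <savepoint> does not end the transaction, so it is excluded
_END = re.compile(r'^\s*(COMMIT|ROLLBACK(?!\s+TO\b))\b', re.IGNORECASE)
_READ = re.compile(r'^\s*(SELECT|SHOW|DESCRIBE|EXPLAIN)\b', re.IGNORECASE)
_LOCKING = re.compile(
    r'\bFOR\s+(UPDATE|SHARE)\b|\bLOCK\s+IN\s+SHARE\s+MODE\b', re.IGNORECASE)


class SessionRouter:
    """Per-session routing decision (one instance per client connection)."""

    def __init__(self):
        self.in_transaction = False

    def target(self, sql: str) -> str:
        """Return 'master' or 'replica' for one statement."""
        if _BEGIN.match(sql):
            self.in_transaction = True
            return 'master'
        if _END.match(sql):
            self.in_transaction = False
            return 'master'
        if self.in_transaction:
            return 'master'          # keep the whole transaction together
        if _READ.match(sql) and not _LOCKING.search(sql):
            return 'replica'
        return 'master'
```

Keeping one `SessionRouter` per client connection means the decision depends on what that session has already sent, not only on the current statement.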

11.6 SQL Auditing and Filtering

import re
import time


class SQLAuditor:
    """SQL auditing middleware"""

    def __init__(self):
        self.query_log = []
        self.blocked_patterns = [
            re.compile(r'DROP\s+TABLE', re.IGNORECASE),
            re.compile(r'TRUNCATE', re.IGNORECASE),
        ]

    def audit(self, sql: str, user: str, client_addr: tuple) -> tuple:
        """
        Audit a SQL statement

        Returns: (allowed: bool, reason: str)
        """
        # Record the statement
        entry = {
            'timestamp': time.time(),
            'user': user,
            'client': client_addr,
            'sql': sql,
        }
        self.query_log.append(entry)

        # Check the statement against the blocklist
        for pattern in self.blocked_patterns:
            if pattern.search(sql):
                return False, f"Statement blocked by security policy: {pattern.pattern}"

        # Slow-query check (could be flagged later from the execution time)

        return True, "OK"
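
Pattern matching on raw SQL is easy to sidestep: `DROP/**/TABLE users` contains no whitespace between the two keywords, so `DROP\s+TABLE` does not match it. A normalization pass before the blocklist check narrows that gap. The sketch below is an assumption-laden illustration (`normalize_sql` is a hypothetical helper): it does not understand string literals, and it deliberately keeps `/*! ... */` comments because MySQL executes their content.

```python
import re

# Ordinary block comments; /*! ... */ is kept because MySQL runs its body
_BLOCK_COMMENT = re.compile(r'/\*(?!!).*?\*/', re.DOTALL)
# "-- " and "#" comments run to the end of the line
_LINE_COMMENT = re.compile(r'--[ \t][^\n]*|#[^\n]*')
_WHITESPACE = re.compile(r'\s+')


def normalize_sql(sql: str) -> str:
    """Replace comments with a space and collapse runs of whitespace."""
    sql = _BLOCK_COMMENT.sub(' ', sql)
    sql = _LINE_COMMENT.sub(' ', sql)
    return _WHITESPACE.sub(' ', sql).strip()


blocked = re.compile(r'DROP\s+TABLE', re.IGNORECASE)
raw = "DROP/**/TABLE users"
print(bool(blocked.search(raw)))                 # False
print(bool(blocked.search(normalize_sql(raw))))  # True
```

A production filter would more likely work on a parsed statement than on regular expressions; this only shows why the raw text should not be matched directly.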

11.7 Query Cache

import time


class QueryCache:
    """A simple query cache"""

    def __init__(self, max_size=1000, ttl=60):
        self.cache = {}
        self.max_size = max_size
        self.ttl = ttl
        self.hits = 0
        self.misses = 0

    def get(self, sql: str):
        """Return the cached result, or None on a miss"""
        if sql in self.cache:
            entry = self.cache[sql]
            if time.time() - entry['time'] < self.ttl:
                self.hits += 1
                return entry['result']
            else:
                del self.cache[sql]
        self.misses += 1
        return None

    def put(self, sql: str, result):
        """Store a query result"""
        if len(self.cache) >= self.max_size:
            # Evict the entry that was stored longest ago (not a true
            # LRU: reads do not refresh an entry's timestamp)
            oldest_key = min(self.cache, key=lambda k: self.cache[k]['time'])
            del self.cache[oldest_key]

        self.cache[sql] = {
            'result': result,
            'time': time.time(),
        }

    def invalidate(self, table: str):
        """Invalidate cached entries that mention the given table"""
        keys_to_delete = []
        for sql in self.cache:
            # Substring match: may also drop queries on similarly named tables
            if table.lower() in sql.lower():
                keys_to_delete.append(sql)
        for key in keys_to_delete:
            del self.cache[key]
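
Eviction in `put` above scans every entry and is keyed on when an entry was stored, not on when it was last read. If least-recently-used eviction is wanted, `collections.OrderedDict` gives it in constant time. The class below is a sketch under that assumption; `LRUQueryCache` and its injectable `clock` parameter are names invented for this example.

```python
import time
from collections import OrderedDict


class LRUQueryCache:
    """Query cache with TTL expiry and least-recently-used eviction."""

    def __init__(self, max_size=1000, ttl=60.0, clock=time.monotonic):
        self._entries = OrderedDict()   # sql -> (stored_at, result)
        self.max_size = max_size
        self.ttl = ttl
        self._clock = clock             # injectable so tests can control time

    def get(self, sql):
        entry = self._entries.get(sql)
        if entry is None:
            return None
        stored_at, result = entry
        if self._clock() - stored_at >= self.ttl:
            del self._entries[sql]      # expired
            return None
        self._entries.move_to_end(sql)  # mark as most recently used
        return result

    def put(self, sql, result):
        self._entries[sql] = (self._clock(), result)
        self._entries.move_to_end(sql)
        while len(self._entries) > self.max_size:
            self._entries.popitem(last=False)  # drop the least recently used
```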

11.8 MyCat-Style Sharding

import zlib


class ShardingRouter:
    """Sharding router"""

    def __init__(self, shards: dict):
        """
        shards: {
            'shard_0': {'host': '10.0.0.1', 'port': 3306},
            'shard_1': {'host': '10.0.0.2', 'port': 3306},
            'shard_2': {'host': '10.0.0.3', 'port': 3306},
        }
        """
        self.shards = shards
        self.shard_count = len(shards)
        self.shard_keys = {
            'orders': 'user_id',
            'users': 'id',
        }

    def route(self, table: str, shard_key_value=None) -> list:
        """Return the names of the shards that hold the given shard key"""
        if shard_key_value is None:
            # Full-table query → broadcast to every shard
            return list(self.shards.keys())

        # Hash-modulo sharding. zlib.crc32 is used instead of the built-in
        # hash(), whose value for strings differs between processes
        digest = zlib.crc32(str(shard_key_value).encode('utf-8'))
        shard_index = digest % self.shard_count
        shard_name = f'shard_{shard_index}'
        return [shard_name]

    def rewrite_sql(self, sql: str, target_shard: str) -> str:
        """Rewrite the SQL (for example, add a shard condition)"""
        # A real implementation needs a SQL parser
        return sql
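
When `route` returns every shard (the broadcast case), the proxy also has to combine the per-shard results before answering the client. For an ORDER BY ... LIMIT query, each shard can return its rows already sorted and the proxy can merge them lazily. The sketch below assumes rows are plain dictionaries and that each shard applied the same ordering; `merge_sorted` is a hypothetical helper.

```python
import heapq
from itertools import islice


def merge_sorted(shard_results, key, limit=None):
    """Merge per-shard row lists that are each already sorted by key.

    heapq.merge is lazy, so with a limit only the first rows are consumed.
    """
    merged = heapq.merge(*shard_results, key=key)
    return list(islice(merged, limit))


shard_0 = [{'id': 1}, {'id': 4}]
shard_1 = [{'id': 2}, {'id': 3}]
top3 = merge_sorted([shard_0, shard_1], key=lambda row: row['id'], limit=3)
print([row['id'] for row in top3])  # [1, 2, 3]
```

Aggregates need different handling (for instance, summing per-shard COUNT values), which is one reason sharding proxies rely on a SQL parser.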

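11.9 Challenges in Protocol Parsing

Challenge 1: Packet reassembly

A MySQL message may span several TCP segments, or one TCP segment may carry several MySQL packets. The proxy must reassemble them correctly:

import struct


class PacketBuffer:
    """MySQL packet buffer"""

    def __init__(self):
        self.buffer = b''
        self.current_message = b''
        self.message_complete = False

    def feed(self, data: bytes):
        """Append data to the buffer"""
        self.buffer += data
        self._try_parse()

    def _try_parse(self):
        """Try to parse one complete message from the buffer"""
        # Stop once a message is complete so it is not merged with the
        # next one before get_message() has collected it
        while not self.message_complete and len(self.buffer) >= 4:
            pkt_len = struct.unpack('<I', self.buffer[0:3] + b'\x00')[0]

            if len(self.buffer) < 4 + pkt_len:
                break  # incomplete data

            # Extract one packet (header included)
            packet = self.buffer[:4 + pkt_len]
            self.buffer = self.buffer[4 + pkt_len:]

            self.current_message += packet

            # A packet shorter than the maximum length ends the message
            if pkt_len < 0xFFFFFF:
                self.message_complete = True

    def get_message(self) -> bytes:
        """Return a complete message, or b'' if none is ready"""
        if self.message_complete:
            msg = self.current_message
            self.current_message = b''
            self.message_complete = False
            # Parse any further message already sitting in the buffer
            self._try_parse()
            return msg
        return b''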
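
The same length rule applies in the other direction when a proxy builds packets itself: a payload of 0xFFFFFF bytes or more is split across packets, and a payload that is an exact multiple of the maximum is followed by an empty packet so the receiver can see where it ends. The sketch below illustrates this; `frame_payload` is a hypothetical helper and its `max_len` parameter exists only so the behavior can be shown with small inputs.

```python
import struct

MAX_PACKET = 0xFFFFFF  # largest payload length a single packet can carry


def frame_payload(payload: bytes, seq: int = 0, max_len: int = MAX_PACKET):
    """Split a payload into wire packets (4-byte header + chunk each)."""
    packets = []
    offset = 0
    while True:
        chunk = payload[offset:offset + max_len]
        header = struct.pack('<I', len(chunk))[:3] + bytes([seq & 0xFF])
        packets.append(header + chunk)
        seq += 1
        offset += len(chunk)
        # A chunk shorter than max_len (possibly empty) ends the message
        if len(chunk) < max_len:
            return packets


# With max_len=3, a 6-byte payload needs a trailing empty packet
for pkt in frame_payload(b"abcdef", max_len=3):
    print(pkt)
```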

Challenge 2: Sequence number management

When forwarding packets, the proxy may need to rewrite their sequence numbers:

import struct


class SequenceManager:
    """Sequence number manager"""

    def __init__(self):
        self.client_seq = 0
        self.backend_seq = 0

    def reset(self):
        """Reset at the start of each new command"""
        self.client_seq = 0
        self.backend_seq = 0

    def rewrite_packet(self, packet: bytes, direction: str) -> bytes:
        """Rewrite the sequence number of a packet"""
        if len(packet) < 4:
            return packet

        if direction == 'to_backend':
            seq = self.backend_seq
            self.backend_seq = (self.backend_seq + 1) & 0xFF
        else:
            seq = self.client_seq
            self.client_seq = (self.client_seq + 1) & 0xFF

        return packet[0:3] + struct.pack('B', seq) + packet[4:]

Challenge 3: Multiple result sets

The proxy must handle the SERVER_MORE_RESULTS_EXISTS flag correctly and forward every result set before it waits for the next command.
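
The flag lives in the 2-byte status field of the OK packet, which follows two length-encoded integers (affected rows and last insert id). The sketch below reads it from an OK packet payload; it assumes the 4.1 protocol and an OK packet with header byte 0x00, and does not cover EOF packets or the 0xFE-headed OK packets used when CLIENT_DEPRECATE_EOF is negotiated. The function names are chosen for this example.

```python
import struct

SERVER_MORE_RESULTS_EXISTS = 0x0008


def read_lenenc_int(buf: bytes, pos: int):
    """Decode a length-encoded integer; return (value, next_position)."""
    first = buf[pos]
    if first < 0xFB:
        return first, pos + 1
    if first == 0xFC:
        return struct.unpack_from('<H', buf, pos + 1)[0], pos + 3
    if first == 0xFD:
        return int.from_bytes(buf[pos + 1:pos + 4], 'little'), pos + 4
    if first == 0xFE:
        return struct.unpack_from('<Q', buf, pos + 1)[0], pos + 9
    raise ValueError("not a length-encoded integer")


def ok_has_more_results(payload: bytes) -> bool:
    """Return True if an OK packet announces another result set."""
    if not payload or payload[0] != 0x00:
        raise ValueError("not an OK packet")
    _, pos = read_lenenc_int(payload, 1)      # affected_rows
    _, pos = read_lenenc_int(payload, pos)    # last_insert_id
    status = struct.unpack_from('<H', payload, pos)[0]
    return bool(status & SERVER_MORE_RESULTS_EXISTS)


# affected_rows=1, last_insert_id=0, status=0x000A, warnings=0
print(ok_has_more_results(b"\x00\x01\x00\x0a\x00\x00\x00"))  # True
```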


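11.10 Caveats

Important reminders

  1. Connection state tracking: the proxy must track the state of every connection (whether a transaction is open, the current database, and so on) to make correct routing decisions.

  2. Prepared statement routing: a prepared statement handle is valid only on the connection that created it. The proxy must not route COM_STMT_EXECUTE to a different backend.

  3. Session variables: variables set with a SET statement apply only to the current connection. The proxy must keep routing consistent for that session.

  4. Error handling: when the backend returns an error, the proxy must forward the error packet correctly and clean up its internal state.

  5. Timeout management: the proxy must manage both client and backend timeouts to avoid resource leaks.

  6. SSL/TLS: if the client uses SSL, the proxy must complete the TLS upgrade before it can parse the MySQL protocol.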
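
Item 2 can be approximated by pinning a session to its backend while it holds prepared statements. The sketch below is one hedged way to do that: `StickySession` is a hypothetical helper, it counts prepare requests instead of confirmed successes, and it treats COM_RESET_CONNECTION as releasing all statements.

```python
COM_STMT_PREPARE = 0x16
COM_STMT_CLOSE = 0x19
COM_RESET_CONNECTION = 0x1F


class StickySession:
    """Tracks whether a client session must stay on its current backend."""

    def __init__(self):
        self.open_statements = 0

    def observe(self, command: int) -> None:
        """Update the counter from the command byte of a client packet."""
        if command == COM_STMT_PREPARE:
            self.open_statements += 1
        elif command == COM_STMT_CLOSE and self.open_statements > 0:
            self.open_statements -= 1
        elif command == COM_RESET_CONNECTION:
            self.open_statements = 0

    @property
    def pinned(self) -> bool:
        """True while at least one prepared statement may still be open."""
        return self.open_statements > 0
```

A router would consult `pinned` before considering a different backend; the same idea extends to open transactions and session variables from items 1 and 3.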


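11.11 Business Scenarios

Scenario 1: Application-level read/write splitting

Automatic read/write splitting with ProxySQL:

-- ProxySQL configuration example (run in the ProxySQL admin interface)
-- Add the backend MySQL servers
INSERT INTO mysql_servers (hostgroup_id, hostname, port)
VALUES (10, 'master.db', 3306);   -- writer group
INSERT INTO mysql_servers (hostgroup_id, hostname, port)
VALUES (20, 'slave1.db', 3306);   -- reader group
INSERT INTO mysql_servers (hostgroup_id, hostname, port)
VALUES (20, 'slave2.db', 3306);   -- reader group

-- Configure the routing rules. A rule is ignored unless active = 1;
-- apply = 1 stops rule evaluation at the first match.
INSERT INTO mysql_query_rules (rule_id, active, match_pattern, destination_hostgroup, apply)
VALUES (1, 1, '^SELECT', 20, 1);  -- SELECT → reader group
INSERT INTO mysql_query_rules (rule_id, active, match_pattern, destination_hostgroup, apply)
VALUES (2, 1, '.*', 10, 1);       -- everything else → writer group

-- Activate the configuration and persist it
LOAD MYSQL SERVERS TO RUNTIME;
SAVE MYSQL SERVERS TO DISK;
LOAD MYSQL QUERY RULES TO RUNTIME;
SAVE MYSQL QUERY RULES TO DISK;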

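Scenario 2: SQL auditing platform

Intercept all SQL at the proxy and write it to an audit log:

Client → Proxy → MySQL
           ↓
     Audit log (ES/Kafka)

Scenario 3: Canary release

At the proxy layer, route part of the traffic to the new database version based on the user ID or other request characteristics.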
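
One simple way to pick "part of the traffic" is to hash the user ID into 100 buckets and send the lowest buckets to the new version, so a given user always lands on the same side. The sketch below assumes that approach; `pick_backend` and the labels 'canary' and 'stable' are illustrative names.

```python
import zlib


def pick_backend(user_id, canary_percent: int) -> str:
    """Return 'canary' for roughly canary_percent% of users, else 'stable'."""
    # crc32 is stable across processes, unlike the built-in hash() for str
    bucket = zlib.crc32(str(user_id).encode('utf-8')) % 100
    return 'canary' if bucket < canary_percent else 'stable'
```

Raising `canary_percent` step by step only moves users from 'stable' to 'canary', never back, because each user's bucket is fixed.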

