第09章 代理与分布式
当单台 Memcached 无法满足需求时,分布式架构和代理层成为必然选择。本章深入讲解分片策略、连接池、一致性哈希和负载均衡的实现。
9.1 分布式架构概述
为什么需要分布式
| 问题 | 说明 |
|---|---|
| 内存限制 | 单机内存有限,无法缓存所有数据 |
| 连接限制 | 单机连接数有上限 |
| 带宽瓶颈 | 单机网络带宽有限 |
| 高可用 | 单点故障会导致缓存雪崩 |
分布式架构模式
模式一:客户端分片(推荐)
┌────────┐
│ Client │──hash(key)──▶┌──────────┐
│ │ │ Server 1 │
└────────┘ ├──────────┤
│ Server 2 │
├──────────┤
│ Server 3 │
└──────────┘
模式二:代理层分片
┌────────┐ ┌────────┐ ┌──────────┐
│ Client │───▶│ Proxy │───▶│ Server 1 │
│ │ │ │ ─┤ Server 2 │
└────────┘ └────────┘ ──┤ Server 3 │
└──────────┘
模式三:服务端复制集群(如使用 repcached 实现主从同步;注意 ketama 是客户端一致性哈希算法,并非服务端方案)
┌────────┐ ┌──────────┐ ┌──────────┐
│ Client │───▶│ Server 1 │◀──▶│ Server 2 │
└────────┘ └──────────┘ └──────────┘
▲ ▲
│ 同步 │
▼ ▼
┌──────────┐
│ Server 3 │
└──────────┘
9.2 一致性哈希
哈希取模的问题
最简单的分片方式是 hash(key) % N(N 为服务器数量),但存在严重问题:
当 N=3 时:
key="user:1001" → hash=7 → 7%3=1 → Server 2
当新增一台服务器(N=4)时:
key="user:1001" → hash=7 → 7%4=3 → Server 4
结论:几乎所有 key 的映射都会改变!
一致性哈希原理
一致性哈希将所有节点映射到一个虚拟的哈希环上:
0
│
┌────┴────┐
│ Node A │
│ (100) │
└────┬────┘
│
120° │ 240°
┌────┴────┐
│ Node B │
│ (200) │
└────┬────┘
│
│
┌────┴────┐
│ Node C │
│ (300) │
└────┬────┘
│
360°
Key "user:1001" → hash=150 → 落在 Node A(100) 和 Node B(200) 之间
→ 顺时针找到 Node B
Python 实现
#!/usr/bin/env python3
"""consistent_hash.py — 一致性哈希实现"""
import hashlib
from bisect import bisect_right
from typing import List, Optional
class ConsistentHash:
    """Consistent-hash ring that maps cache keys onto physical nodes.

    Every physical node is projected onto the ring as `virtual_nodes`
    points, which smooths out the key distribution between nodes.
    """

    def __init__(self, nodes: List[str], virtual_nodes: int = 150):
        """
        nodes: physical node list (e.g. ['server1:11211', 'server2:11211'])
        virtual_nodes: number of virtual points per physical node
        """
        self.virtual_nodes = virtual_nodes
        self.ring = {}          # ring position (hash) -> physical node
        self.sorted_keys = []   # all ring positions, kept sorted
        for physical in nodes:
            self.add_node(physical)

    def _hash(self, key: str) -> int:
        """Map a string to an integer position on the ring (MD5-based,
        deterministic across processes)."""
        digest = hashlib.md5(key.encode()).hexdigest()
        return int(digest, 16)

    def add_node(self, node: str):
        """Place `node` on the ring via its virtual points."""
        positions = [self._hash(f"{node}#{i}")
                     for i in range(self.virtual_nodes)]
        for pos in positions:
            self.ring[pos] = node
            self.sorted_keys.append(pos)
        self.sorted_keys.sort()

    def remove_node(self, node: str):
        """Remove all of `node`'s virtual points from the ring."""
        for i in range(self.virtual_nodes):
            pos = self._hash(f"{node}#{i}")
            if pos in self.ring:
                del self.ring[pos]
                self.sorted_keys.remove(pos)

    def get_node(self, key: str) -> Optional[str]:
        """Return the node owning `key`, or None if the ring is empty."""
        if not self.ring:
            return None
        idx = bisect_right(self.sorted_keys, self._hash(key))
        if idx == len(self.sorted_keys):
            idx = 0  # wrap around to the start of the ring
        return self.ring[self.sorted_keys[idx]]

    def get_nodes(self, key: str, count: int = 1) -> List[str]:
        """Return up to `count` distinct nodes clockwise from `key`
        (useful for placing replicas)."""
        if not self.ring:
            return []
        found: List[str] = []
        visited = set()
        idx = bisect_right(self.sorted_keys, self._hash(key))
        for _ in range(len(self.sorted_keys)):
            if idx >= len(self.sorted_keys):
                idx = 0
            owner = self.ring[self.sorted_keys[idx]]
            if owner not in visited:
                visited.add(owner)
                found.append(owner)
                if len(found) >= count:
                    break
            idx += 1
        return found
# --- Demo: key distribution and migration on node addition ---
ring = ConsistentHash([
    "server1:11211",
    "server2:11211",
    "server3:11211"
])

# Distribution of 1000 keys across the ring; remember the old mapping so
# we can measure how many keys actually migrate when a node is added.
distribution = {}
old_mapping = {}
for i in range(1000):
    key = f"user:{i}"
    node = ring.get_node(key)
    old_mapping[key] = node
    distribution[node] = distribution.get(node, 0) + 1
print("数据分布:")
for node, count in sorted(distribution.items()):
    print(f"  {node}: {count} keys ({count/10:.1f}%)")

# Add a new node and count real migrations (the original code declared
# `moved = 0` but never counted anything and printed a hard-coded 25%).
# With consistent hashing only ~1/N of the keys should move (N=4 here).
ring.add_node("server4:11211")
moved = sum(1 for key, old in old_mapping.items()
            if ring.get_node(key) != old)
print(f"\n添加新节点后 {moved/10:.1f}% 的 key 发生迁移")
虚拟节点数量选择
| 虚拟节点数 | 均匀性 | 内存开销 | 建议场景 |
|---|---|---|---|
| 50 | 一般 | 低 | 节点数 > 100 |
| 150 | 良好 | 中 | 通用场景(推荐) |
| 250 | 优秀 | 较高 | 节点数 < 10 |
| 500 | 极好 | 高 | 测试/基准 |
9.3 连接池
为什么需要连接池
| 问题 | 说明 |
|---|---|
| TCP 连接开销 | 每次建立连接需要 3 次握手 |
| 文件描述符限制 | 每个连接占用一个 fd |
| 连接数限制 | 服务端有最大连接数限制 |
| 复用效率 | 连接池可显著降低延迟 |
连接池实现
#!/usr/bin/env python3
"""connection_pool.py — Memcached 连接池"""
import socket
import threading
import time
from queue import Queue, Empty
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional
@dataclass
class PooledConnection:
    # Pairs a raw TCP socket with the bookkeeping metadata the pool needs.
    socket: socket.socket   # underlying connection to the memcached server
    created_at: float       # epoch seconds when the connection was opened
    last_used: float        # epoch seconds of the most recent checkout/return
    in_use: bool = False    # True while checked out of the pool
class MemcachedConnectionPool:
def __init__(self, host: str = '127.0.0.1', port: int = 11211,
min_size: int = 5, max_size: int = 20,
max_idle_time: int = 300):
self.host = host
self.port = port
self.min_size = min_size
self.max_size = max_size
self.max_idle_time = max_idle_time
self._pool = Queue(maxsize=max_size)
self._all_connections = []
self._lock = threading.Lock()
self._current_size = 0
# 预创建最小连接数
for _ in range(min_size):
conn = self._create_connection()
if conn:
self._pool.put(conn)
def _create_connection(self) -> Optional[PooledConnection]:
"""创建新连接"""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((self.host, self.port))
conn = PooledConnection(
socket=sock,
created_at=time.time(),
last_used=time.time()
)
with self._lock:
self._all_connections.append(conn)
self._current_size += 1
return conn
except Exception as e:
print(f"连接创建失败: {e}")
return None
@contextmanager
def acquire(self):
"""获取连接(上下文管理器)"""
conn = None
try:
# 尝试从池中获取
try:
conn = self._pool.get_nowait()
# 检查连接是否过期
if time.time() - conn.last_used > self.max_idle_time:
self._close_connection(conn)
conn = None
except Empty:
pass
# 池中无可用连接,尝试创建新连接
if conn is None:
with self._lock:
if self._current_size < self.max_size:
conn = self._create_connection()
# 等待可用连接
if conn is None:
try:
conn = self._pool.get(timeout=5)
except Empty:
raise TimeoutError("获取连接超时")
conn.in_use = True
conn.last_used = time.time()
yield conn.socket
finally:
if conn:
conn.in_use = False
conn.last_used = time.time()
# 验证连接是否仍然可用
try:
conn.socket.sendall(b"version\r\n")
resp = conn.socket.recv(1024)
if b"VERSION" in resp:
self._pool.put(conn)
else:
self._close_connection(conn)
except Exception:
self._close_connection(conn)
def _close_connection(self, conn: PooledConnection):
"""关闭连接"""
try:
conn.socket.sendall(b"quit\r\n")
conn.socket.close()
except Exception:
pass
finally:
with self._lock:
if conn in self._all_connections:
self._all_connections.remove(conn)
self._current_size -= 1
def close_all(self):
"""关闭所有连接"""
while not self._pool.empty():
try:
conn = self._pool.get_nowait()
self._close_connection(conn)
except Empty:
break
@property
def stats(self) -> dict:
"""连接池统计"""
return {
'total': self._current_size,
'available': self._pool.qsize(),
'in_use': self._current_size - self._pool.qsize()
}
# --- Usage example ---
pool = MemcachedConnectionPool(host='127.0.0.1', port=11211,
                               min_size=5, max_size=20)

# Borrow a socket, issue one store command, and let the context manager
# hand the connection back to the pool on exit.
with pool.acquire() as sock:
    sock.sendall(b"set pool:test 0 60 5\r\nhello\r\n")
    reply = sock.recv(1024)
    print(reply.decode())

print(f"连接池状态: {pool.stats}")
pool.close_all()
9.4 带一致性哈希的分布式客户端
#!/usr/bin/env python3
"""distributed_client.py — 带一致性哈希的分布式 Memcached 客户端"""
import socket
import json
from typing import List, Optional
class DistributedMemcached:
    """Distributed memcached client that shards keys with consistent hashing.

    Relies on the ConsistentHash class defined earlier in this chapter.
    One persistent TCP connection is kept per server.
    """

    def __init__(self, servers: List[str], virtual_nodes: int = 150):
        """
        servers: ['host1:port1', 'host2:port2', ...]
        """
        self.servers = servers
        self.hash_ring = ConsistentHash(servers, virtual_nodes)
        self.connections = {}  # server -> connected socket
        for server in servers:
            host, port = server.split(':')
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect((host, int(port)))
            self.connections[server] = sock

    def _get_connection(self, key: str) -> tuple[socket.socket, str]:
        """Route `key` through the hash ring to its owning server."""
        server = self.hash_ring.get_node(key)
        return self.connections[server], server

    def _recv_until_end(self, conn: socket.socket,
                        bufsize: int = 65536) -> bytes:
        """Read a retrieval reply until its END\\r\\n terminator.

        Raises ConnectionError if the server closes the connection first —
        the previous inline loops spun forever on empty recv() results.
        (Limitation: a value that itself contains b"END\\r\\n" may end the
        read early; fine for a teaching client.)
        """
        buffer = b""
        while b"END\r\n" not in buffer:
            chunk = conn.recv(bufsize)
            if not chunk:  # peer closed the socket
                raise ConnectionError("connection closed by server")
            buffer += chunk
        return buffer

    def set(self, key: str, value: bytes, flags: int = 0,
            exptime: int = 0) -> bool:
        """Store `value` under `key`; returns True when the server says STORED."""
        conn, _ = self._get_connection(key)
        data = value if isinstance(value, bytes) else value.encode()
        cmd = f"set {key} {flags} {exptime} {len(data)}\r\n"
        conn.sendall(cmd.encode() + data + b"\r\n")
        return b"STORED" in conn.recv(1024)

    def get(self, key: str) -> Optional[bytes]:
        """Fetch `key`'s value, or None on a miss."""
        conn, _ = self._get_connection(key)
        conn.sendall(f"get {key}\r\n".encode())
        buffer = self._recv_until_end(conn)
        header, _, rest = buffer.partition(b"\r\n")
        if header.startswith(b"VALUE"):
            # Extract exactly <bytes> octets as declared in the VALUE line,
            # so values containing \r\n are returned intact (the old
            # line-split approach truncated them).
            length = int(header.split()[3])
            return rest[:length]
        return None

    def get_multi(self, keys: List[str]) -> dict[str, bytes]:
        """Batch get: groups keys by owning server, one multi-key `get` each.

        Servers are queried sequentially (the original comment claimed
        "in parallel", which was not true).
        """
        server_keys: dict[str, List[str]] = {}
        for key in keys:
            server = self.hash_ring.get_node(key)
            server_keys.setdefault(server, []).append(key)
        result: dict[str, bytes] = {}
        for server, skeys in server_keys.items():
            conn = self.connections[server]
            conn.sendall(f"get {' '.join(skeys)}\r\n".encode())
            buffer = self._recv_until_end(conn)
            # Parse each VALUE block by its declared byte length.
            while buffer.startswith(b"VALUE "):
                header, _, buffer = buffer.partition(b"\r\n")
                parts = header.split()
                length = int(parts[3])
                result[parts[1].decode()] = buffer[:length]
                buffer = buffer[length + 2:]  # skip the value's trailing \r\n
        return result

    def delete(self, key: str) -> bool:
        """Delete `key`; returns True when the server says DELETED."""
        conn, _ = self._get_connection(key)
        conn.sendall(f"delete {key}\r\n".encode())
        return b"DELETED" in conn.recv(1024)

    def stats(self) -> dict[str, dict]:
        """Collect `stats` output from every server as {server: {name: value}}."""
        result = {}
        for server, conn in self.connections.items():
            conn.sendall(b"stats\r\n")
            buffer = self._recv_until_end(conn, bufsize=8192)
            stats = {}
            for line in buffer.decode().split("\r\n"):
                if line.startswith("STAT"):
                    parts = line.split()
                    stats[parts[1]] = parts[2]
            result[server] = stats
        return result

    def close(self):
        """Politely quit and close every server connection (best effort)."""
        for conn in self.connections.values():
            try:
                conn.sendall(b"quit\r\n")
                conn.close()
            except Exception:
                pass
# --- Usage example (requires a running memcached on 127.0.0.1:11211) ---
client = DistributedMemcached([
    '127.0.0.1:11211',
    # '127.0.0.1:11212',
    # '127.0.0.1:11213',
])
# Store JSON-encoded values.
client.set("user:1001", json.dumps({"name": "Bob"}).encode())
client.set("user:1002", json.dumps({"name": "Alice"}).encode())
# Single read.
data = client.get("user:1001")
if data:
    print(f"user:1001 = {json.loads(data)}")
# Batch read — keys are grouped per server inside get_multi.
results = client.get_multi(["user:1001", "user:1002"])
for key, val in results.items():
    print(f"{key} = {json.loads(val)}")
# Per-server stats; the hit-rate expression guards against division by
# zero with max(1, ...).
stats = client.stats()
for server, s in stats.items():
    print(f"{server}: items={s.get('curr_items', 'N/A')}, "
          f"hit_rate={int(s.get('get_hits', 0)) / max(1, int(s.get('get_hits', 0)) + int(s.get('get_misses', 0))) * 100:.1f}%")
client.close()
9.5 Memcached 代理(mcrouter)
mcrouter 简介
mcrouter 是 Facebook 开发的 Memcached 代理,支持:
- 连接池和多路复用
- 一致性哈希
- 复制和故障转移
- 流量镜像
- 冷热分离
mcrouter 配置示例
{
"pools": {
"A": {
"servers": [
"127.0.0.1:11211",
"127.0.0.1:11212"
]
},
"B": {
"servers": [
"127.0.0.1:11213",
"127.0.0.1:11214"
]
}
},
"route": {
"type": "OperationSelectorRoute",
"default_policy": {
"type": "PoolRoute",
"pool": "A"
},
"operation_policies": {
"get": {
"type": "FailoverRoute",
"children": [
{"type": "PoolRoute", "pool": "A"},
{"type": "PoolRoute", "pool": "B"}
]
}
}
}
}
启动 mcrouter
# 安装(mcrouter 未进入主流发行版仓库,通常从源码编译或使用官方 Docker 镜像,
# 参见 https://github.com/facebook/mcrouter)
# 启动
mcrouter --config-file=/etc/mcrouter/config.json \
--port=11210 \
--listen-addr=127.0.0.1
# 连接到代理(像连接普通 Memcached 一样使用)
echo "version" | nc 127.0.0.1 11210
9.6 负载均衡策略
策略对比
| 策略 | 说明 | 适用场景 |
|---|---|---|
| 一致性哈希 | 相同 key 总是路由到同一节点 | 通用缓存(推荐) |
| 取模哈希 | hash(key) % N | 节点数固定 |
| 范围分片 | 按 key 前缀分配节点 | 多租户系统 |
| 随机 | 随机选择节点 | 无状态数据 |
| 轮询 | Round-Robin | 负载均匀 |
范围分片实现
class RangeSharding:
    """Shard by key prefix, with a deterministic hash fallback.

    Routing rules map a key prefix (e.g. "user:") to a fixed server;
    keys matching no rule fall back to modulo hashing over all servers.
    """

    def __init__(self, servers: List[str]):
        self.servers = servers
        self.prefix_map = {}  # prefix -> server

    def add_rule(self, prefix: str, server: str):
        """Pin every key starting with `prefix` to `server`."""
        self.prefix_map[prefix] = server

    def get_server(self, key: str) -> str:
        """Return the server responsible for `key`."""
        for prefix, server in self.prefix_map.items():
            if key.startswith(prefix):
                return server
        # Fallback: modulo hashing over an MD5 digest.  The builtin hash()
        # used previously is salted per process (PYTHONHASHSEED), so the
        # same key could route to DIFFERENT servers in different processes —
        # a correctness bug for a sharding function.
        digest = int(hashlib.md5(key.encode()).hexdigest(), 16)
        return self.servers[digest % len(self.servers)]
# --- Usage ---
sharding = RangeSharding(['s1:11211', 's2:11211', 's3:11211'])
routing_rules = {
    "user:": "s1:11211",
    "session:": "s2:11211",
    "product:": "s3:11211",
}
for prefix, server in routing_rules.items():
    sharding.add_rule(prefix, server)

# Each probe hits the server pinned by its prefix rule.
for probe in ("user:1001", "session:abc", "product:2001"):
    print(sharding.get_server(probe))
9.7 故障转移
自动故障检测
import socket
import time
from typing import Dict
class HealthChecker:
    """Active TCP health checks for a set of memcached servers."""

    def __init__(self, servers: List[str], check_interval: int = 5):
        self.servers = servers
        self.check_interval = check_interval       # seconds between sweeps (caller-driven)
        self.healthy = {s: True for s in servers}  # optimistic until first probe
        self.last_check = {s: 0 for s in servers}  # epoch seconds of last probe

    def is_healthy(self, server: str) -> bool:
        """Last known health; unknown servers are treated as unhealthy."""
        return self.healthy.get(server, False)

    def check(self, server: str) -> bool:
        """Probe `server` with a `version` command; record and return the result."""
        is_ok = False
        sock = None
        try:
            host, port = server.split(':')
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(2)
            sock.connect((host, int(port)))
            sock.sendall(b"version\r\n")
            is_ok = b"VERSION" in sock.recv(1024)
        except Exception:
            is_ok = False
        finally:
            # Always release the fd — the previous version leaked the socket
            # whenever connect()/recv() raised.
            if sock is not None:
                sock.close()
        self.healthy[server] = is_ok
        self.last_check[server] = time.time()
        return is_ok

    def check_all(self) -> Dict[str, bool]:
        """Probe every server once; return {server: healthy}."""
        return {server: self.check(server) for server in self.servers}

    def get_healthy_servers(self) -> List[str]:
        """Servers currently believed healthy."""
        return [s for s in self.servers if self.healthy[s]]
# 故障转移客户端
class FailoverClient:
    """Memcached client that retries reads on the next ring candidate on failure.

    Relies on ConsistentHash and HealthChecker defined earlier in this chapter.
    """

    def __init__(self, servers: List[str]):
        # BUGFIX: self.servers was never assigned, but _get_connection() and
        # get() read it — every call raised AttributeError.
        self.servers = servers
        self.hash_ring = ConsistentHash(servers)
        self.connections: Dict[str, socket.socket] = {}
        self.health_checker = HealthChecker(servers)
        for server in servers:
            self._connect(server)

    def _connect(self, server: str):
        """Open the TCP connection to `server`; failures are logged, not raised."""
        try:
            host, port = server.split(':')
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect((host, int(port)))
            self.connections[server] = sock
        except Exception as e:
            print(f"连接失败 {server}: {e}")

    def _get_connection(self, key: str, exclude: set = None):
        """Return (socket, server) for the first usable candidate of `key`.

        Skips servers in `exclude` or marked unhealthy; raises
        ConnectionError when no candidate remains.
        """
        if exclude is None:
            exclude = set()
        # Candidate order = clockwise walk of the hash ring from `key`.
        candidates = self.hash_ring.get_nodes(key, count=len(self.servers))
        for server in candidates:
            if server in exclude:
                continue
            if not self.health_checker.is_healthy(server):
                exclude.add(server)
                continue
            # dict.get never raises, so no try/except is needed here
            # (the old one was dead code).
            conn = self.connections.get(server)
            if conn:
                return conn, server
            exclude.add(server)  # never connected — skip it on retries
        raise ConnectionError(f"所有候选节点不可用: {key}")

    def get(self, key: str) -> Optional[bytes]:
        """Read `key`, failing over to up to 3 candidate servers."""
        excluded = set()
        last_error = None
        for _ in range(min(3, len(self.servers))):
            server = None
            try:
                conn, server = self._get_connection(key, excluded)
                conn.settimeout(5)
                conn.sendall(f"get {key}\r\n".encode())
                buffer = b""
                while b"END\r\n" not in buffer:
                    chunk = conn.recv(65536)
                    if not chunk:
                        # Peer closed the socket — the old loop spun forever
                        # appending empty chunks here.
                        raise ConnectionError("connection closed by server")
                    buffer += chunk
                lines = buffer.split(b"\r\n")
                if lines[0].startswith(b"VALUE"):
                    return lines[1]
                return None
            except Exception as e:
                last_error = e
                if server is not None:
                    # Don't retry the server that just failed mid-request —
                    # previously it was never excluded, so all retries could
                    # hit the same dead node.
                    excluded.add(server)
                continue
        raise ConnectionError(f"故障转移失败: {last_error}")
9.8 业务场景
场景一:电商缓存架构
class EcommerceCache:
    """Domain-specific cache facade for an e-commerce site.

    Each business domain (products, sessions, users) gets its own
    DistributedMemcached cluster, isolating traffic and failure domains.
    NOTE(review): server addresses are hard-coded — presumably replaced
    by configuration in production; confirm against deployment setup.
    """
    def __init__(self):
        # One dedicated cache cluster per business domain.
        self.product_cache = DistributedMemcached([
            'cache-product-1:11211',
            'cache-product-2:11211',
        ])
        self.session_cache = DistributedMemcached([
            'cache-session-1:11211',
            'cache-session-2:11211',
        ])
        # No accessor for user_cache is shown below; presumably used by
        # methods outside this listing.
        self.user_cache = DistributedMemcached([
            'cache-user-1:11211',
            'cache-user-2:11211',
        ])
    def get_product(self, product_id: int) -> Optional[dict]:
        """Fetch a cached product dict, or None on a cache miss."""
        data = self.product_cache.get(f"product:{product_id}")
        return json.loads(data) if data else None
    def cache_product(self, product_id: int, data: dict, ttl: int = 600):
        """Cache a product dict as JSON for `ttl` seconds (default 10 min)."""
        self.product_cache.set(
            f"product:{product_id}",
            json.dumps(data).encode(),
            exptime=ttl
        )
    def get_user_session(self, session_id: str) -> Optional[dict]:
        """Fetch a cached session dict, or None on a cache miss."""
        data = self.session_cache.get(f"session:{session_id}")
        return json.loads(data) if data else None
场景二:多级缓存架构
class MultiLevelCache:
    """L1 (in-process dict) + L2 (distributed memcached) two-level cache."""

    def __init__(self, servers: List[str]):
        self.l1_cache = {}   # key -> value bytes (local, per-process)
        self.l1_ttl = {}     # key -> absolute expiry time (epoch seconds)
        self.l2_client = DistributedMemcached(servers)
        self.l1_max_size = 1000
        self.l1_ttl_seconds = 60

    def get(self, key: str) -> Optional[bytes]:
        """Return cached bytes for `key`, or None on a full miss."""
        # L1 hit, if the local copy has not expired.
        if key in self.l1_cache:
            if self.l1_ttl.get(key, 0) > time.time():
                return self.l1_cache[key]
            # Expired locally — drop and fall through to L2.
            del self.l1_cache[key]
            del self.l1_ttl[key]
        # L2 lookup, back-filling L1 on a hit.
        data = self.l2_client.get(key)
        if data:
            if len(self.l1_cache) >= self.l1_max_size:
                self._evict_l1()
            self.l1_cache[key] = data
            self.l1_ttl[key] = time.time() + self.l1_ttl_seconds
        return data

    def set(self, key: str, value: bytes, ttl: int = 300):
        """Write through to L2 and refresh the L1 copy."""
        self.l2_client.set(key, value, exptime=ttl)
        # Enforce the L1 size budget on the write path too — previously
        # only get() evicted, so set() could grow L1 without bound.
        if key not in self.l1_cache and len(self.l1_cache) >= self.l1_max_size:
            self._evict_l1()
        self.l1_cache[key] = value
        self.l1_ttl[key] = time.time() + self.l1_ttl_seconds

    def _evict_l1(self):
        """Drop expired L1 entries, then evict oldest-inserted until under budget.

        This is insertion-order eviction, not true LRU — get() does not
        refresh entry order (the old comment claiming "LRU" was wrong).
        """
        now = time.time()
        for k in [k for k, t in self.l1_ttl.items() if t < now]:
            del self.l1_cache[k]
            del self.l1_ttl[k]
        # Still over budget: evict in insertion order, keeping BOTH maps in
        # sync — the old loop popped only l1_cache and leaked l1_ttl entries.
        while len(self.l1_cache) >= self.l1_max_size:
            oldest = next(iter(self.l1_cache))
            del self.l1_cache[oldest]
            del self.l1_ttl[oldest]
9.9 注意事项
| 编号 | 注意事项 | 说明 |
|---|---|---|
| 1 | 节点增减时的数据迁移 | 一致性哈希只影响 1/N 的 key |
| 2 | 连接池大小设置 | 根据并发量设置,避免连接过多或过少 |
| 3 | 故障检测间隔 | 太频繁影响性能,太稀疏影响可用性 |
| 4 | 热点 key 问题 | 某些 key 访问量极高,考虑本地缓存 |
| 5 | 批量操作的路由 | get_multi 需要按服务器分组 |
| 6 | 时钟同步 | 绝对过期时间依赖服务端时钟(exptime 超过 30 天时按 Unix 时间戳解释);CAS 值本身是服务端单调计数器,与时钟无关 |
9.10 扩展阅读
上一章: 第08章 Meta 协议 下一章: 第10章 最佳实践 — 协议选择、客户端开发、性能优化与安全加固。