TCP/UDP 网络协议教程 / 11-TCP 编程实战
11 - TCP 编程实战
11.1 TCP 服务器基础架构
"""完整的 TCP 回显服务器"""
import socket
import threading
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TCPServer:
def __init__(self, host='0.0.0.0', port=8080):
self.host = host
self.port = port
self.server_sock = None
self.running = False
def start(self):
"""启动服务器"""
self.server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.server_sock.bind((self.host, self.port))
self.server_sock.listen(128)
self.running = True
logger.info(f"服务器启动: {self.host}:{self.port}")
while self.running:
try:
client_sock, addr = self.server_sock.accept()
logger.info(f"新连接: {addr}")
t = threading.Thread(
target=self.handle_client,
args=(client_sock, addr),
daemon=True
)
t.start()
except OSError:
break
def handle_client(self, client_sock, addr):
"""处理客户端连接"""
try:
while True:
data = client_sock.recv(4096)
if not data:
break
client_sock.sendall(data)
except ConnectionResetError:
logger.warning(f"连接重置: {addr}")
except Exception as e:
logger.error(f"错误: {e}")
finally:
client_sock.close()
logger.info(f"连接关闭: {addr}")
def stop(self):
"""停止服务器"""
self.running = False
if self.server_sock:
self.server_sock.close()
11.2 多线程 vs 异步模型
多线程模型
"""线程池模型"""
from concurrent.futures import ThreadPoolExecutor
import socket
class ThreadPoolTCPServer:
def __init__(self, host, port, max_workers=100):
self.host = host
self.port = port
self.pool = ThreadPoolExecutor(max_workers=max_workers)
def start(self):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.bind((self.host, self.port))
server.listen(128)
while True:
client, addr = server.accept()
self.pool.submit(self.handle_client, client, addr)
def handle_client(self, client_sock, addr):
with client_sock:
try:
while True:
data = client_sock.recv(4096)
if not data:
break
client_sock.sendall(data)
except ConnectionResetError:
pass
asyncio 异步模型
"""asyncio TCP 服务器"""
import asyncio
async def handle_client(reader, writer):
"""处理客户端连接"""
addr = writer.get_extra_info('peername')
print(f"新连接: {addr}")
try:
while True:
data = await reader.read(4096)
if not data:
break
writer.write(data)
await writer.drain()
except ConnectionResetError:
pass
finally:
writer.close()
print(f"连接关闭: {addr}")
async def main():
server = await asyncio.start_server(handle_client, '0.0.0.0', 8080)
async with server:
await server.serve_forever()
# asyncio.run(main())
11.3 粘包问题
TCP 是字节流,没有消息边界
发送方:
send(b"Hello")
send(b"World")
接收方可能:
recv() → b"HelloWorld" (粘包)
recv() → b"Hello" (正常)
recv() → b"World" (正常)
或者:
recv() → b"Hel" (拆包)
recv() → b"loWorl" (拆包)
recv() → b"d" (拆包)
方案一:固定长度消息
"""固定长度消息协议"""
import socket
FIXED_SIZE = 1024
def send_fixed(sock, data: bytes):
"""发送固定长度消息"""
if len(data) > FIXED_SIZE:
raise ValueError("数据过大")
# 填充到固定长度
padded = data.ljust(FIXED_SIZE, b'\x00')
sock.sendall(padded)
def recv_fixed(sock) -> bytes:
"""接收固定长度消息"""
chunks = []
received = 0
while received < FIXED_SIZE:
chunk = sock.recv(FIXED_SIZE - received)
if not chunk:
raise ConnectionError("连接关闭")
chunks.append(chunk)
received += len(chunk)
return b''.join(chunks).rstrip(b'\x00')
# 使用
# send_fixed(sock, b"Hello")
# data = recv_fixed(sock)
方案二:分隔符
"""分隔符消息协议"""
import socket
DELIMITER = b'\r\n'
def send_delimited(sock, data: bytes):
"""发送分隔符消息"""
# 转义数据中的分隔符
escaped = data.replace(b'\\', b'\\\\').replace(DELIMITER, b'\\r\\n')
sock.sendall(escaped + DELIMITER)
class DelimitedReceiver:
def __init__(self, sock):
self.sock = sock
self.buffer = b''
def recv_message(self) -> bytes:
"""接收一条消息"""
while True:
# 检查缓冲区
pos = self.buffer.find(DELIMITER)
if pos >= 0:
message = self.buffer[:pos]
self.buffer = self.buffer[pos + len(DELIMITER):]
return message.replace(b'\\r\\n', DELIMITER).replace(b'\\\\', b'\\')
# 读取更多数据
data = self.sock.recv(4096)
if not data:
raise ConnectionError("连接关闭")
self.buffer += data
方案三:长度前缀(推荐)
"""长度前缀消息协议"""
import socket
import struct
def send_message(sock, data: bytes):
"""发送长度前缀消息"""
# 4 字节头部 + 数据
header = struct.pack('!I', len(data))
sock.sendall(header + data)
def recv_message(sock) -> bytes:
"""接收长度前缀消息"""
# 先接收 4 字节头部
header = recv_exact(sock, 4)
length = struct.unpack('!I', header)[0]
if length > 10 * 1024 * 1024: # 最大 10MB
raise ValueError(f"消息过大: {length}")
# 接收数据
return recv_exact(sock, length)
def recv_exact(sock, size: int) -> bytes:
"""精确接收指定字节数"""
chunks = []
received = 0
while received < size:
chunk = sock.recv(size - received)
if not chunk:
raise ConnectionError("连接关闭")
chunks.append(chunk)
received += len(chunk)
return b''.join(chunks)
# 使用示例
# send_message(sock, b"Hello World")
# data = recv_message(sock)
11.4 消息协议设计
TLV 格式
"""TLV (Type-Length-Value) 消息协议"""
import struct
from enum import IntEnum
class MessageType(IntEnum):
HEARTBEAT = 0x01
LOGIN = 0x02
MESSAGE = 0x03
FILE = 0x04
# 消息格式:
# [类型 1B][长度 4B][数据 NB]
def encode_message(msg_type: int, data: bytes) -> bytes:
"""编码消息"""
header = struct.pack('!BI', msg_type, len(data))
return header + data
def decode_message(sock) -> tuple:
"""解码消息"""
# 读取头部(5 字节)
header = recv_exact(sock, 5)
msg_type, length = struct.unpack('!BI', header)
# 读取数据
data = recv_exact(sock, length)
return msg_type, data
JSON 协议
"""JSON 消息协议"""
import json
import struct
def send_json(sock, obj: dict):
"""发送 JSON 消息"""
data = json.dumps(obj, ensure_ascii=False).encode('utf-8')
header = struct.pack('!I', len(data))
sock.sendall(header + data)
def recv_json(sock) -> dict:
"""接收 JSON 消息"""
header = recv_exact(sock, 4)
length = struct.unpack('!I', header)[0]
data = recv_exact(sock, length)
return json.loads(data.decode('utf-8'))
# 使用
# send_json(sock, {"type": "login", "user": "admin", "pass": "123456"})
# msg = recv_json(sock)
11.5 心跳机制
"""心跳机制实现"""
import socket
import time
import threading
class HeartbeatConnection:
def __init__(self, sock, interval=30, timeout=90):
self.sock = sock
self.interval = interval
self.timeout = timeout
self.last_recv = time.time()
self.running = True
def start_heartbeat(self):
"""启动心跳线程"""
t = threading.Thread(target=self._heartbeat_loop, daemon=True)
t.start()
def _heartbeat_loop(self):
while self.running:
time.sleep(self.interval)
# 检查是否超时
if time.time() - self.last_recv > self.timeout:
print("心跳超时")
self.sock.close()
break
# 发送心跳
try:
self.sock.sendall(b'\x00\x00\x00\x01') # 心跳包
except:
self.sock.close()
break
def on_data_received(self, data):
"""收到数据时更新时间戳"""
self.last_recv = time.time()
11.6 连接池
"""TCP 连接池"""
import socket
import queue
import threading
class TCPConnectionPool:
def __init__(self, host, port, max_size=10):
self.host = host
self.port = port
self.max_size = max_size
self.pool = queue.Queue(maxsize=max_size)
self.size = 0
self.lock = threading.Lock()
def get_connection(self):
"""获取连接"""
try:
return self.pool.get_nowait()
except queue.Empty:
with self.lock:
if self.size < self.max_size:
self.size += 1
return self._create_connection()
# 等待归还
return self.pool.get()
def return_connection(self, sock):
"""归还连接"""
try:
self.pool.put_nowait(sock)
except queue.Full:
sock.close()
with self.lock:
self.size -= 1
def _create_connection(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((self.host, self.port))
return sock
def close_all(self):
"""关闭所有连接"""
while not self.pool.empty():
try:
sock = self.pool.get_nowait()
sock.close()
except:
pass
# 使用
pool = TCPConnectionPool('localhost', 8080, max_size=10)
sock = pool.get_connection()
try:
sock.sendall(b"request")
data = sock.recv(4096)
finally:
pool.return_connection(sock)
11.7 注意事项
⚠️ 必须处理粘包:TCP 没有消息边界,不处理粘包会导致数据解析错误
⚠️ recv 不保证读完:必须循环读取直到收到完整消息
⚠️ sendall 也不保证对端收到:只保证数据进入了发送缓冲区
⚠️ 连接断开检测:recv 返回空 bytes 或抛出 ConnectionResetError
⚠️ 心跳保活:长连接必须有心跳机制,防止 NAT/防火墙断开连接
11.8 扩展阅读
- Python asyncio 文档
- Protobuf - 高效序列化
- msgpack - 高效二进制 JSON
下一章:12 - UDP 编程实战 - UDP 收发、广播/多播