强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧.
文档目录

TCP/UDP 网络协议教程 / 11-TCP 编程实战

11 - TCP 编程实战

11.1 TCP 服务器基础架构

"""完整的 TCP 回显服务器"""
import socket
import threading
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TCPServer:
    def __init__(self, host='0.0.0.0', port=8080):
        self.host = host
        self.port = port
        self.server_sock = None
        self.running = False
    
    def start(self):
        """启动服务器"""
        self.server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.server_sock.bind((self.host, self.port))
        self.server_sock.listen(128)
        self.running = True
        
        logger.info(f"服务器启动: {self.host}:{self.port}")
        
        while self.running:
            try:
                client_sock, addr = self.server_sock.accept()
                logger.info(f"新连接: {addr}")
                
                t = threading.Thread(
                    target=self.handle_client,
                    args=(client_sock, addr),
                    daemon=True
                )
                t.start()
            except OSError:
                break
    
    def handle_client(self, client_sock, addr):
        """处理客户端连接"""
        try:
            while True:
                data = client_sock.recv(4096)
                if not data:
                    break
                client_sock.sendall(data)
        except ConnectionResetError:
            logger.warning(f"连接重置: {addr}")
        except Exception as e:
            logger.error(f"错误: {e}")
        finally:
            client_sock.close()
            logger.info(f"连接关闭: {addr}")
    
    def stop(self):
        """停止服务器"""
        self.running = False
        if self.server_sock:
            self.server_sock.close()

11.2 多线程 vs 异步模型

多线程模型

"""线程池模型"""
from concurrent.futures import ThreadPoolExecutor
import socket

class ThreadPoolTCPServer:
    def __init__(self, host, port, max_workers=100):
        self.host = host
        self.port = port
        self.pool = ThreadPoolExecutor(max_workers=max_workers)
    
    def start(self):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server.bind((self.host, self.port))
            server.listen(128)
            
            while True:
                client, addr = server.accept()
                self.pool.submit(self.handle_client, client, addr)
    
    def handle_client(self, client_sock, addr):
        with client_sock:
            try:
                while True:
                    data = client_sock.recv(4096)
                    if not data:
                        break
                    client_sock.sendall(data)
            except ConnectionResetError:
                pass

asyncio 异步模型

"""asyncio TCP 服务器"""
import asyncio

async def handle_client(reader, writer):
    """处理客户端连接"""
    addr = writer.get_extra_info('peername')
    print(f"新连接: {addr}")
    
    try:
        while True:
            data = await reader.read(4096)
            if not data:
                break
            writer.write(data)
            await writer.drain()
    except ConnectionResetError:
        pass
    finally:
        writer.close()
        print(f"连接关闭: {addr}")

async def main():
    server = await asyncio.start_server(handle_client, '0.0.0.0', 8080)
    async with server:
        await server.serve_forever()

# asyncio.run(main())

11.3 粘包问题

TCP 是字节流,没有消息边界

发送方:
send(b"Hello")
send(b"World")

接收方可能:
recv() → b"HelloWorld"  (粘包)
recv() → b"Hello"       (正常)
recv() → b"World"       (正常)

或者:
recv() → b"Hel"         (拆包)
recv() → b"loWorl"      (拆包)
recv() → b"d"           (拆包)

方案一:固定长度消息

"""固定长度消息协议"""
import socket

FIXED_SIZE = 1024

def send_fixed(sock, data: bytes):
    """发送固定长度消息"""
    if len(data) > FIXED_SIZE:
        raise ValueError("数据过大")
    # 填充到固定长度
    padded = data.ljust(FIXED_SIZE, b'\x00')
    sock.sendall(padded)

def recv_fixed(sock) -> bytes:
    """接收固定长度消息"""
    chunks = []
    received = 0
    while received < FIXED_SIZE:
        chunk = sock.recv(FIXED_SIZE - received)
        if not chunk:
            raise ConnectionError("连接关闭")
        chunks.append(chunk)
        received += len(chunk)
    return b''.join(chunks).rstrip(b'\x00')

# 使用
# send_fixed(sock, b"Hello")
# data = recv_fixed(sock)

方案二:分隔符

"""分隔符消息协议"""
import socket

DELIMITER = b'\r\n'

def send_delimited(sock, data: bytes):
    """发送分隔符消息"""
    # 转义数据中的分隔符
    escaped = data.replace(b'\\', b'\\\\').replace(DELIMITER, b'\\r\\n')
    sock.sendall(escaped + DELIMITER)

class DelimitedReceiver:
    def __init__(self, sock):
        self.sock = sock
        self.buffer = b''
    
    def recv_message(self) -> bytes:
        """接收一条消息"""
        while True:
            # 检查缓冲区
            pos = self.buffer.find(DELIMITER)
            if pos >= 0:
                message = self.buffer[:pos]
                self.buffer = self.buffer[pos + len(DELIMITER):]
                return message.replace(b'\\r\\n', DELIMITER).replace(b'\\\\', b'\\')
            
            # 读取更多数据
            data = self.sock.recv(4096)
            if not data:
                raise ConnectionError("连接关闭")
            self.buffer += data

方案三:长度前缀(推荐)

"""长度前缀消息协议"""
import socket
import struct

def send_message(sock, data: bytes):
    """发送长度前缀消息"""
    # 4 字节头部 + 数据
    header = struct.pack('!I', len(data))
    sock.sendall(header + data)

def recv_message(sock) -> bytes:
    """接收长度前缀消息"""
    # 先接收 4 字节头部
    header = recv_exact(sock, 4)
    length = struct.unpack('!I', header)[0]
    
    if length > 10 * 1024 * 1024:  # 最大 10MB
        raise ValueError(f"消息过大: {length}")
    
    # 接收数据
    return recv_exact(sock, length)

def recv_exact(sock, size: int) -> bytes:
    """精确接收指定字节数"""
    chunks = []
    received = 0
    while received < size:
        chunk = sock.recv(size - received)
        if not chunk:
            raise ConnectionError("连接关闭")
        chunks.append(chunk)
        received += len(chunk)
    return b''.join(chunks)

# 使用示例
# send_message(sock, b"Hello World")
# data = recv_message(sock)

11.4 消息协议设计

TLV 格式

"""TLV (Type-Length-Value) 消息协议"""
import struct
from enum import IntEnum

class MessageType(IntEnum):
    HEARTBEAT = 0x01
    LOGIN = 0x02
    MESSAGE = 0x03
    FILE = 0x04

# 消息格式:
# [类型 1B][长度 4B][数据 NB]

def encode_message(msg_type: int, data: bytes) -> bytes:
    """编码消息"""
    header = struct.pack('!BI', msg_type, len(data))
    return header + data

def decode_message(sock) -> tuple:
    """解码消息"""
    # 读取头部(5 字节)
    header = recv_exact(sock, 5)
    msg_type, length = struct.unpack('!BI', header)
    
    # 读取数据
    data = recv_exact(sock, length)
    return msg_type, data

JSON 协议

"""JSON 消息协议"""
import json
import struct

def send_json(sock, obj: dict):
    """发送 JSON 消息"""
    data = json.dumps(obj, ensure_ascii=False).encode('utf-8')
    header = struct.pack('!I', len(data))
    sock.sendall(header + data)

def recv_json(sock) -> dict:
    """接收 JSON 消息"""
    header = recv_exact(sock, 4)
    length = struct.unpack('!I', header)[0]
    data = recv_exact(sock, length)
    return json.loads(data.decode('utf-8'))

# 使用
# send_json(sock, {"type": "login", "user": "admin", "pass": "123456"})
# msg = recv_json(sock)

11.5 心跳机制

"""心跳机制实现"""
import socket
import time
import threading

class HeartbeatConnection:
    def __init__(self, sock, interval=30, timeout=90):
        self.sock = sock
        self.interval = interval
        self.timeout = timeout
        self.last_recv = time.time()
        self.running = True
    
    def start_heartbeat(self):
        """启动心跳线程"""
        t = threading.Thread(target=self._heartbeat_loop, daemon=True)
        t.start()
    
    def _heartbeat_loop(self):
        while self.running:
            time.sleep(self.interval)
            
            # 检查是否超时
            if time.time() - self.last_recv > self.timeout:
                print("心跳超时")
                self.sock.close()
                break
            
            # 发送心跳
            try:
                self.sock.sendall(b'\x00\x00\x00\x01')  # 心跳包
            except:
                self.sock.close()
                break
    
    def on_data_received(self, data):
        """收到数据时更新时间戳"""
        self.last_recv = time.time()

11.6 连接池

"""TCP 连接池"""
import socket
import queue
import threading

class TCPConnectionPool:
    def __init__(self, host, port, max_size=10):
        self.host = host
        self.port = port
        self.max_size = max_size
        self.pool = queue.Queue(maxsize=max_size)
        self.size = 0
        self.lock = threading.Lock()
    
    def get_connection(self):
        """获取连接"""
        try:
            return self.pool.get_nowait()
        except queue.Empty:
            with self.lock:
                if self.size < self.max_size:
                    self.size += 1
                    return self._create_connection()
            # 等待归还
            return self.pool.get()
    
    def return_connection(self, sock):
        """归还连接"""
        try:
            self.pool.put_nowait(sock)
        except queue.Full:
            sock.close()
            with self.lock:
                self.size -= 1
    
    def _create_connection(self):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect((self.host, self.port))
        return sock
    
    def close_all(self):
        """关闭所有连接"""
        while not self.pool.empty():
            try:
                sock = self.pool.get_nowait()
                sock.close()
            except:
                pass

# 使用
pool = TCPConnectionPool('localhost', 8080, max_size=10)
sock = pool.get_connection()
try:
    sock.sendall(b"request")
    data = sock.recv(4096)
finally:
    pool.return_connection(sock)

11.7 注意事项

⚠️ 必须处理粘包:TCP 没有消息边界,不处理粘包会导致数据解析错误

⚠️ recv 不保证读完:必须循环读取直到收到完整消息

⚠️ sendall 也不保证对端收到:只保证数据进入了发送缓冲区

⚠️ 连接断开检测:recv 返回空 bytes 或抛出 ConnectionResetError

⚠️ 心跳保活:长连接必须有心跳机制,防止 NAT/防火墙断开连接

11.8 扩展阅读


下一章12 - UDP 编程实战 - UDP 收发、广播/多播