# P2P Network Communication - Message Handler """ 消息处理器模块 负责消息的序列化、反序列化、校验和路由 """ import json import struct import hashlib import logging from typing import Callable, Dict, Optional, Any from dataclasses import dataclass from shared.models import Message, MessageType # 设置日志 logger = logging.getLogger(__name__) class MessageValidationError(Exception): """消息验证错误""" pass class MessageSerializationError(Exception): """消息序列化错误""" pass class MessageRoutingError(Exception): """消息路由错误""" pass # 消息处理回调类型 MessageCallback = Callable[[Message], None] class MessageHandler: """ 消息处理器 负责处理和路由各类消息,包括: - 消息序列化/反序列化 - 消息校验 - 消息路由到对应处理器 需求: 3.1, 3.2 """ # 消息头格式: 4字节长度 + 4字节版本 HEADER_FORMAT = "!II" # Network byte order, 2 unsigned ints HEADER_SIZE = struct.calcsize(HEADER_FORMAT) PROTOCOL_VERSION = 1 def __init__(self): """初始化消息处理器""" self._handlers: Dict[MessageType, MessageCallback] = {} self._default_handler: Optional[MessageCallback] = None def serialize(self, message: Message) -> bytes: """ 序列化消息为字节流 格式: [4字节长度][4字节版本][JSON数据] Args: message: 要序列化的消息对象 Returns: 序列化后的字节流 Raises: MessageSerializationError: 序列化失败时抛出 """ try: # 将消息转换为JSON字节流 json_data = json.dumps(message.to_dict(), ensure_ascii=False).encode('utf-8') # 构建消息头 header = struct.pack( self.HEADER_FORMAT, len(json_data), self.PROTOCOL_VERSION ) return header + json_data except (TypeError, ValueError, json.JSONEncodeError) as e: raise MessageSerializationError(f"Failed to serialize message: {e}") def deserialize(self, data: bytes) -> Message: """ 反序列化字节流为消息 Args: data: 包含消息头和消息体的字节流 Returns: 反序列化后的消息对象 Raises: MessageSerializationError: 反序列化失败时抛出 """ try: if len(data) < self.HEADER_SIZE: raise MessageSerializationError( f"Data too short: expected at least {self.HEADER_SIZE} bytes, got {len(data)}" ) # 解析消息头 header = data[:self.HEADER_SIZE] payload_length, version = struct.unpack(self.HEADER_FORMAT, header) # 版本检查 if version != self.PROTOCOL_VERSION: raise MessageSerializationError( f"Unsupported protocol version: {version}, expected {self.PROTOCOL_VERSION}" ) # 检查数据长度 expected_length = self.HEADER_SIZE + payload_length if len(data) < expected_length: raise MessageSerializationError( f"Incomplete message: expected {expected_length} bytes, got {len(data)}" ) # 解析JSON数据 json_data = data[self.HEADER_SIZE:self.HEADER_SIZE + payload_length] message_dict = json.loads(json_data.decode('utf-8')) return Message.from_dict(message_dict) except json.JSONDecodeError as e: raise MessageSerializationError(f"Failed to parse JSON: {e}") except (KeyError, ValueError) as e: raise MessageSerializationError(f"Invalid message format: {e}") except struct.error as e: raise MessageSerializationError(f"Failed to unpack header: {e}") def validate(self, message: Message) -> bool: """ 验证消息完整性 检查项: - 校验和是否正确 - 必要字段是否存在 - 消息类型是否有效 Args: message: 要验证的消息对象 Returns: 验证通过返回True,否则返回False Raises: MessageValidationError: 验证失败时抛出(包含详细错误信息) """ errors = [] # 检查必要字段 if not message.sender_id: errors.append("sender_id is required") if not message.receiver_id: errors.append("receiver_id is required") if message.timestamp <= 0: errors.append("timestamp must be positive") # 检查消息类型 if not isinstance(message.msg_type, MessageType): errors.append(f"Invalid message type: {message.msg_type}") # 对于文件传输和语音通话相关消息,跳过校验和验证(因为payload在序列化过程中会变化) skip_checksum_types = { MessageType.FILE_REQUEST, MessageType.FILE_CHUNK, MessageType.FILE_COMPLETE, MessageType.IMAGE, MessageType.VOICE_CALL_REQUEST, MessageType.VOICE_CALL_ACCEPT, MessageType.VOICE_CALL_REJECT, MessageType.VOICE_CALL_END, MessageType.VOICE_DATA, MessageType.AUDIO_STREAM, MessageType.VIDEO_STREAM, } # 验证校验和(跳过文件传输消息) if message.msg_type not in skip_checksum_types: if not message.verify_checksum(): errors.append("Checksum verification failed") if errors: raise MessageValidationError("; ".join(errors)) return True def register_handler(self, msg_type: MessageType, handler: MessageCallback) -> None: """ 注册消息处理器 Args: msg_type: 消息类型 handler: 处理回调函数 """ self._handlers[msg_type] = handler logger.debug(f"Registered handler for message type: {msg_type.value}") def unregister_handler(self, msg_type: MessageType) -> None: """ 注销消息处理器 Args: msg_type: 消息类型 """ if msg_type in self._handlers: del self._handlers[msg_type] logger.debug(f"Unregistered handler for message type: {msg_type.value}") def set_default_handler(self, handler: Optional[MessageCallback]) -> None: """ 设置默认处理器(用于未注册类型的消息) Args: handler: 默认处理回调函数,None表示清除 """ self._default_handler = handler def route(self, message: Message) -> None: """ 路由消息到对应处理器 根据消息类型将消息分发到已注册的处理器。 如果没有找到对应的处理器,则使用默认处理器。 如果没有默认处理器,则抛出异常。 Args: message: 要路由的消息 Raises: MessageRoutingError: 找不到处理器时抛出 MessageValidationError: 消息验证失败时抛出 """ # 先验证消息 self.validate(message) # 查找处理器 handler = self._handlers.get(message.msg_type) if handler is None: handler = self._default_handler if handler is None: raise MessageRoutingError( f"No handler registered for message type: {message.msg_type.value}" ) # 调用处理器 try: handler(message) logger.debug(f"Routed message {message.message_id} to handler for {message.msg_type.value}") except Exception as e: logger.error(f"Handler error for message {message.message_id}: {e}") raise def has_handler(self, msg_type: MessageType) -> bool: """ 检查是否有指定类型的处理器 Args: msg_type: 消息类型 Returns: 如果有处理器返回True,否则返回False """ return msg_type in self._handlers or self._default_handler is not None def get_registered_types(self) -> list: """ 获取所有已注册处理器的消息类型 Returns: 消息类型列表 """ return list(self._handlers.keys())