You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

282 lines
8.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# P2P Network Communication - Message Handler
"""
消息处理器模块
负责消息的序列化、反序列化、校验和路由
"""
import json
import struct
import hashlib
import logging
from typing import Callable, Dict, Optional, Any
from dataclasses import dataclass
from shared.models import Message, MessageType
# 设置日志
logger = logging.getLogger(__name__)
class MessageValidationError(Exception):
"""消息验证错误"""
pass
class MessageSerializationError(Exception):
"""消息序列化错误"""
pass
class MessageRoutingError(Exception):
"""消息路由错误"""
pass
# 消息处理回调类型
MessageCallback = Callable[[Message], None]
class MessageHandler:
"""
消息处理器
负责处理和路由各类消息,包括:
- 消息序列化/反序列化
- 消息校验
- 消息路由到对应处理器
需求: 3.1, 3.2
"""
# 消息头格式: 4字节长度 + 4字节版本
HEADER_FORMAT = "!II" # Network byte order, 2 unsigned ints
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
PROTOCOL_VERSION = 1
def __init__(self):
"""初始化消息处理器"""
self._handlers: Dict[MessageType, MessageCallback] = {}
self._default_handler: Optional[MessageCallback] = None
def serialize(self, message: Message) -> bytes:
"""
序列化消息为字节流
格式: [4字节长度][4字节版本][JSON数据]
Args:
message: 要序列化的消息对象
Returns:
序列化后的字节流
Raises:
MessageSerializationError: 序列化失败时抛出
"""
try:
# 将消息转换为JSON字节流
json_data = json.dumps(message.to_dict(), ensure_ascii=False).encode('utf-8')
# 构建消息头
header = struct.pack(
self.HEADER_FORMAT,
len(json_data),
self.PROTOCOL_VERSION
)
return header + json_data
except (TypeError, ValueError, json.JSONEncodeError) as e:
raise MessageSerializationError(f"Failed to serialize message: {e}")
def deserialize(self, data: bytes) -> Message:
"""
反序列化字节流为消息
Args:
data: 包含消息头和消息体的字节流
Returns:
反序列化后的消息对象
Raises:
MessageSerializationError: 反序列化失败时抛出
"""
try:
if len(data) < self.HEADER_SIZE:
raise MessageSerializationError(
f"Data too short: expected at least {self.HEADER_SIZE} bytes, got {len(data)}"
)
# 解析消息头
header = data[:self.HEADER_SIZE]
payload_length, version = struct.unpack(self.HEADER_FORMAT, header)
# 版本检查
if version != self.PROTOCOL_VERSION:
raise MessageSerializationError(
f"Unsupported protocol version: {version}, expected {self.PROTOCOL_VERSION}"
)
# 检查数据长度
expected_length = self.HEADER_SIZE + payload_length
if len(data) < expected_length:
raise MessageSerializationError(
f"Incomplete message: expected {expected_length} bytes, got {len(data)}"
)
# 解析JSON数据
json_data = data[self.HEADER_SIZE:self.HEADER_SIZE + payload_length]
message_dict = json.loads(json_data.decode('utf-8'))
return Message.from_dict(message_dict)
except json.JSONDecodeError as e:
raise MessageSerializationError(f"Failed to parse JSON: {e}")
except (KeyError, ValueError) as e:
raise MessageSerializationError(f"Invalid message format: {e}")
except struct.error as e:
raise MessageSerializationError(f"Failed to unpack header: {e}")
def validate(self, message: Message) -> bool:
"""
验证消息完整性
检查项:
- 校验和是否正确
- 必要字段是否存在
- 消息类型是否有效
Args:
message: 要验证的消息对象
Returns:
验证通过返回True否则返回False
Raises:
MessageValidationError: 验证失败时抛出(包含详细错误信息)
"""
errors = []
# 检查必要字段
if not message.sender_id:
errors.append("sender_id is required")
if not message.receiver_id:
errors.append("receiver_id is required")
if message.timestamp <= 0:
errors.append("timestamp must be positive")
# 检查消息类型
if not isinstance(message.msg_type, MessageType):
errors.append(f"Invalid message type: {message.msg_type}")
# 对于文件传输相关消息跳过校验和验证因为payload在序列化过程中会变化
skip_checksum_types = {
MessageType.FILE_REQUEST,
MessageType.FILE_CHUNK,
MessageType.FILE_COMPLETE,
MessageType.IMAGE,
}
# 验证校验和(跳过文件传输消息)
if message.msg_type not in skip_checksum_types:
if not message.verify_checksum():
errors.append("Checksum verification failed")
if errors:
raise MessageValidationError("; ".join(errors))
return True
def register_handler(self, msg_type: MessageType, handler: MessageCallback) -> None:
"""
注册消息处理器
Args:
msg_type: 消息类型
handler: 处理回调函数
"""
self._handlers[msg_type] = handler
logger.debug(f"Registered handler for message type: {msg_type.value}")
def unregister_handler(self, msg_type: MessageType) -> None:
"""
注销消息处理器
Args:
msg_type: 消息类型
"""
if msg_type in self._handlers:
del self._handlers[msg_type]
logger.debug(f"Unregistered handler for message type: {msg_type.value}")
def set_default_handler(self, handler: Optional[MessageCallback]) -> None:
"""
设置默认处理器(用于未注册类型的消息)
Args:
handler: 默认处理回调函数None表示清除
"""
self._default_handler = handler
def route(self, message: Message) -> None:
"""
路由消息到对应处理器
根据消息类型将消息分发到已注册的处理器。
如果没有找到对应的处理器,则使用默认处理器。
如果没有默认处理器,则抛出异常。
Args:
message: 要路由的消息
Raises:
MessageRoutingError: 找不到处理器时抛出
MessageValidationError: 消息验证失败时抛出
"""
# 先验证消息
self.validate(message)
# 查找处理器
handler = self._handlers.get(message.msg_type)
if handler is None:
handler = self._default_handler
if handler is None:
raise MessageRoutingError(
f"No handler registered for message type: {message.msg_type.value}"
)
# 调用处理器
try:
handler(message)
logger.debug(f"Routed message {message.message_id} to handler for {message.msg_type.value}")
except Exception as e:
logger.error(f"Handler error for message {message.message_id}: {e}")
raise
def has_handler(self, msg_type: MessageType) -> bool:
"""
检查是否有指定类型的处理器
Args:
msg_type: 消息类型
Returns:
如果有处理器返回True否则返回False
"""
return msg_type in self._handlers or self._default_handler is not None
def get_registered_types(self) -> list:
"""
获取所有已注册处理器的消息类型
Returns:
消息类型列表
"""
return list(self._handlers.keys())