|
|
# P2P Network Communication - Message Handler
|
|
|
"""
|
|
|
消息处理器模块
|
|
|
负责消息的序列化、反序列化、校验和路由
|
|
|
"""
|
|
|
|
|
|
import json
|
|
|
import struct
|
|
|
import hashlib
|
|
|
import logging
|
|
|
from typing import Callable, Dict, Optional, Any
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
from shared.models import Message, MessageType
|
|
|
|
|
|
|
|
|
# 设置日志
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class MessageValidationError(Exception):
|
|
|
"""消息验证错误"""
|
|
|
pass
|
|
|
|
|
|
|
|
|
class MessageSerializationError(Exception):
|
|
|
"""消息序列化错误"""
|
|
|
pass
|
|
|
|
|
|
|
|
|
class MessageRoutingError(Exception):
|
|
|
"""消息路由错误"""
|
|
|
pass
|
|
|
|
|
|
|
|
|
# 消息处理回调类型
|
|
|
MessageCallback = Callable[[Message], None]
|
|
|
|
|
|
|
|
|
class MessageHandler:
|
|
|
"""
|
|
|
消息处理器
|
|
|
|
|
|
负责处理和路由各类消息,包括:
|
|
|
- 消息序列化/反序列化
|
|
|
- 消息校验
|
|
|
- 消息路由到对应处理器
|
|
|
|
|
|
需求: 3.1, 3.2
|
|
|
"""
|
|
|
|
|
|
# 消息头格式: 4字节长度 + 4字节版本
|
|
|
HEADER_FORMAT = "!II" # Network byte order, 2 unsigned ints
|
|
|
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
|
|
|
PROTOCOL_VERSION = 1
|
|
|
|
|
|
def __init__(self):
|
|
|
"""初始化消息处理器"""
|
|
|
self._handlers: Dict[MessageType, MessageCallback] = {}
|
|
|
self._default_handler: Optional[MessageCallback] = None
|
|
|
|
|
|
def serialize(self, message: Message) -> bytes:
|
|
|
"""
|
|
|
序列化消息为字节流
|
|
|
|
|
|
格式: [4字节长度][4字节版本][JSON数据]
|
|
|
|
|
|
Args:
|
|
|
message: 要序列化的消息对象
|
|
|
|
|
|
Returns:
|
|
|
序列化后的字节流
|
|
|
|
|
|
Raises:
|
|
|
MessageSerializationError: 序列化失败时抛出
|
|
|
"""
|
|
|
try:
|
|
|
# 将消息转换为JSON字节流
|
|
|
json_data = json.dumps(message.to_dict(), ensure_ascii=False).encode('utf-8')
|
|
|
|
|
|
# 构建消息头
|
|
|
header = struct.pack(
|
|
|
self.HEADER_FORMAT,
|
|
|
len(json_data),
|
|
|
self.PROTOCOL_VERSION
|
|
|
)
|
|
|
|
|
|
return header + json_data
|
|
|
|
|
|
except (TypeError, ValueError, json.JSONEncodeError) as e:
|
|
|
raise MessageSerializationError(f"Failed to serialize message: {e}")
|
|
|
|
|
|
def deserialize(self, data: bytes) -> Message:
|
|
|
"""
|
|
|
反序列化字节流为消息
|
|
|
|
|
|
Args:
|
|
|
data: 包含消息头和消息体的字节流
|
|
|
|
|
|
Returns:
|
|
|
反序列化后的消息对象
|
|
|
|
|
|
Raises:
|
|
|
MessageSerializationError: 反序列化失败时抛出
|
|
|
"""
|
|
|
try:
|
|
|
if len(data) < self.HEADER_SIZE:
|
|
|
raise MessageSerializationError(
|
|
|
f"Data too short: expected at least {self.HEADER_SIZE} bytes, got {len(data)}"
|
|
|
)
|
|
|
|
|
|
# 解析消息头
|
|
|
header = data[:self.HEADER_SIZE]
|
|
|
payload_length, version = struct.unpack(self.HEADER_FORMAT, header)
|
|
|
|
|
|
# 版本检查
|
|
|
if version != self.PROTOCOL_VERSION:
|
|
|
raise MessageSerializationError(
|
|
|
f"Unsupported protocol version: {version}, expected {self.PROTOCOL_VERSION}"
|
|
|
)
|
|
|
|
|
|
# 检查数据长度
|
|
|
expected_length = self.HEADER_SIZE + payload_length
|
|
|
if len(data) < expected_length:
|
|
|
raise MessageSerializationError(
|
|
|
f"Incomplete message: expected {expected_length} bytes, got {len(data)}"
|
|
|
)
|
|
|
|
|
|
# 解析JSON数据
|
|
|
json_data = data[self.HEADER_SIZE:self.HEADER_SIZE + payload_length]
|
|
|
message_dict = json.loads(json_data.decode('utf-8'))
|
|
|
|
|
|
return Message.from_dict(message_dict)
|
|
|
|
|
|
except json.JSONDecodeError as e:
|
|
|
raise MessageSerializationError(f"Failed to parse JSON: {e}")
|
|
|
except (KeyError, ValueError) as e:
|
|
|
raise MessageSerializationError(f"Invalid message format: {e}")
|
|
|
except struct.error as e:
|
|
|
raise MessageSerializationError(f"Failed to unpack header: {e}")
|
|
|
|
|
|
def validate(self, message: Message) -> bool:
|
|
|
"""
|
|
|
验证消息完整性
|
|
|
|
|
|
检查项:
|
|
|
- 校验和是否正确
|
|
|
- 必要字段是否存在
|
|
|
- 消息类型是否有效
|
|
|
|
|
|
Args:
|
|
|
message: 要验证的消息对象
|
|
|
|
|
|
Returns:
|
|
|
验证通过返回True,否则返回False
|
|
|
|
|
|
Raises:
|
|
|
MessageValidationError: 验证失败时抛出(包含详细错误信息)
|
|
|
"""
|
|
|
errors = []
|
|
|
|
|
|
# 检查必要字段
|
|
|
if not message.sender_id:
|
|
|
errors.append("sender_id is required")
|
|
|
|
|
|
if not message.receiver_id:
|
|
|
errors.append("receiver_id is required")
|
|
|
|
|
|
if message.timestamp <= 0:
|
|
|
errors.append("timestamp must be positive")
|
|
|
|
|
|
# 检查消息类型
|
|
|
if not isinstance(message.msg_type, MessageType):
|
|
|
errors.append(f"Invalid message type: {message.msg_type}")
|
|
|
|
|
|
# 对于文件传输和语音通话相关消息,跳过校验和验证(因为payload在序列化过程中会变化)
|
|
|
skip_checksum_types = {
|
|
|
MessageType.FILE_REQUEST,
|
|
|
MessageType.FILE_CHUNK,
|
|
|
MessageType.FILE_COMPLETE,
|
|
|
MessageType.IMAGE,
|
|
|
MessageType.VOICE_CALL_REQUEST,
|
|
|
MessageType.VOICE_CALL_ACCEPT,
|
|
|
MessageType.VOICE_CALL_REJECT,
|
|
|
MessageType.VOICE_CALL_END,
|
|
|
MessageType.VOICE_DATA,
|
|
|
MessageType.AUDIO_STREAM,
|
|
|
MessageType.VIDEO_STREAM,
|
|
|
}
|
|
|
|
|
|
# 验证校验和(跳过文件传输消息)
|
|
|
if message.msg_type not in skip_checksum_types:
|
|
|
if not message.verify_checksum():
|
|
|
errors.append("Checksum verification failed")
|
|
|
|
|
|
if errors:
|
|
|
raise MessageValidationError("; ".join(errors))
|
|
|
|
|
|
return True
|
|
|
|
|
|
def register_handler(self, msg_type: MessageType, handler: MessageCallback) -> None:
|
|
|
"""
|
|
|
注册消息处理器
|
|
|
|
|
|
Args:
|
|
|
msg_type: 消息类型
|
|
|
handler: 处理回调函数
|
|
|
"""
|
|
|
self._handlers[msg_type] = handler
|
|
|
logger.debug(f"Registered handler for message type: {msg_type.value}")
|
|
|
|
|
|
def unregister_handler(self, msg_type: MessageType) -> None:
|
|
|
"""
|
|
|
注销消息处理器
|
|
|
|
|
|
Args:
|
|
|
msg_type: 消息类型
|
|
|
"""
|
|
|
if msg_type in self._handlers:
|
|
|
del self._handlers[msg_type]
|
|
|
logger.debug(f"Unregistered handler for message type: {msg_type.value}")
|
|
|
|
|
|
def set_default_handler(self, handler: Optional[MessageCallback]) -> None:
|
|
|
"""
|
|
|
设置默认处理器(用于未注册类型的消息)
|
|
|
|
|
|
Args:
|
|
|
handler: 默认处理回调函数,None表示清除
|
|
|
"""
|
|
|
self._default_handler = handler
|
|
|
|
|
|
def route(self, message: Message) -> None:
|
|
|
"""
|
|
|
路由消息到对应处理器
|
|
|
|
|
|
根据消息类型将消息分发到已注册的处理器。
|
|
|
如果没有找到对应的处理器,则使用默认处理器。
|
|
|
如果没有默认处理器,则抛出异常。
|
|
|
|
|
|
Args:
|
|
|
message: 要路由的消息
|
|
|
|
|
|
Raises:
|
|
|
MessageRoutingError: 找不到处理器时抛出
|
|
|
MessageValidationError: 消息验证失败时抛出
|
|
|
"""
|
|
|
# 先验证消息
|
|
|
self.validate(message)
|
|
|
|
|
|
# 查找处理器
|
|
|
handler = self._handlers.get(message.msg_type)
|
|
|
|
|
|
if handler is None:
|
|
|
handler = self._default_handler
|
|
|
|
|
|
if handler is None:
|
|
|
raise MessageRoutingError(
|
|
|
f"No handler registered for message type: {message.msg_type.value}"
|
|
|
)
|
|
|
|
|
|
# 调用处理器
|
|
|
try:
|
|
|
handler(message)
|
|
|
logger.debug(f"Routed message {message.message_id} to handler for {message.msg_type.value}")
|
|
|
except Exception as e:
|
|
|
logger.error(f"Handler error for message {message.message_id}: {e}")
|
|
|
raise
|
|
|
|
|
|
def has_handler(self, msg_type: MessageType) -> bool:
|
|
|
"""
|
|
|
检查是否有指定类型的处理器
|
|
|
|
|
|
Args:
|
|
|
msg_type: 消息类型
|
|
|
|
|
|
Returns:
|
|
|
如果有处理器返回True,否则返回False
|
|
|
"""
|
|
|
return msg_type in self._handlers or self._default_handler is not None
|
|
|
|
|
|
def get_registered_types(self) -> list:
|
|
|
"""
|
|
|
获取所有已注册处理器的消息类型
|
|
|
|
|
|
Returns:
|
|
|
消息类型列表
|
|
|
"""
|
|
|
return list(self._handlers.keys())
|