# P2P Network Communication - Message Handler Tests
"""Tests for the basic functionality of the message handler."""

import struct

import pytest

from shared.models import Message, MessageType
from shared.message_handler import (
    MessageHandler,
    MessageValidationError,
    MessageSerializationError,
    MessageRoutingError,
)


class TestMessageHandlerSerialization:
    """Serialization and deserialization tests for the message handler."""

    def setup_method(self):
        """Create a fresh handler before each test method."""
        self.handler = MessageHandler()

    def test_serialize_text_message(self):
        """Serializing a text message yields bytes longer than the header."""
        message = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"Hello, World!",
        )

        encoded = self.handler.serialize(message)

        assert isinstance(encoded, bytes)
        assert len(encoded) > MessageHandler.HEADER_SIZE

    def test_deserialize_text_message(self):
        """A serialized text message deserializes back to equal fields."""
        sent = Message(
            msg_type=MessageType.TEXT,
            sender_id="sender",
            receiver_id="receiver",
            timestamp=9876543210.0,
            payload=b"Test payload",
        )

        received = self.handler.deserialize(self.handler.serialize(sent))

        # Compare field by field, in the same order for every run.
        compared_fields = (
            "msg_type",
            "sender_id",
            "receiver_id",
            "timestamp",
            "payload",
        )
        for field_name in compared_fields:
            assert getattr(received, field_name) == getattr(sent, field_name)

    def test_deserialize_short_data_raises_error(self):
        """Deserializing data that is too short raises an error."""
        with pytest.raises(MessageSerializationError):
            self.handler.deserialize(b"short")

    def test_deserialize_invalid_json_raises_error(self):
        """Deserializing a body that is not valid JSON raises an error."""
        # Build data with a valid header but an invalid JSON body.
        body = b"not valid json"
        frame = struct.pack("!II", len(body), 1) + body

        with pytest.raises(MessageSerializationError):
            self.handler.deserialize(frame)


class TestMessageHandlerValidation:
    """Validation tests for the message handler."""

    def setup_method(self):
        """Create a fresh handler before each test method."""
        self.handler = MessageHandler()

    def test_validate_valid_message(self):
        """Validating a well-formed message returns True."""
        message = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"Valid message",
        )

        assert self.handler.validate(message) is True

    def test_validate_empty_sender_raises_error(self):
        """An empty sender ID raises a validation error."""
        message = Message(
            msg_type=MessageType.TEXT,
            sender_id="",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"Test",
        )

        with pytest.raises(MessageValidationError):
            self.handler.validate(message)

    def test_validate_empty_receiver_raises_error(self):
        """An empty receiver ID raises a validation error."""
        message = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="",
            timestamp=1234567890.0,
            payload=b"Test",
        )

        with pytest.raises(MessageValidationError):
            self.handler.validate(message)

    def test_validate_invalid_timestamp_raises_error(self):
        """An invalid (negative) timestamp raises a validation error."""
        message = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=-1.0,
            payload=b"Test",
        )

        with pytest.raises(MessageValidationError):
            self.handler.validate(message)


class TestMessageHandlerRouting:
    """Routing tests for the message handler."""

    def setup_method(self):
        """Create a fresh handler and an empty capture list before each test."""
        self.handler = MessageHandler()
        self.received_messages = []

    def test_register_and_route_handler(self):
        """A registered handler receives the message routed to its type."""

        def collect(message):
            self.received_messages.append(message)

        self.handler.register_handler(MessageType.TEXT, collect)
        outgoing = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"Routed message",
        )

        self.handler.route(outgoing)

        # Exactly one delivery, carrying the routed payload.
        delivered_payloads = [m.payload for m in self.received_messages]
        assert delivered_payloads == [b"Routed message"]

    def test_route_without_handler_raises_error(self):
        """Routing with no handler registered raises a routing error."""
        unhandled = Message(
            msg_type=MessageType.FILE_REQUEST,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"File request",
        )

        with pytest.raises(MessageRoutingError):
            self.handler.route(unhandled)

    def test_default_handler(self):
        """The default handler receives a message of an unregistered type."""

        def collect(message):
            self.received_messages.append(message)

        self.handler.set_default_handler(collect)
        heartbeat = Message(
            msg_type=MessageType.HEARTBEAT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"",
        )

        self.handler.route(heartbeat)

        assert len(self.received_messages) == 1

    def test_unregister_handler(self):
        """After unregistering, the handler is no longer reported for the type."""

        def collect(message):
            self.received_messages.append(message)

        self.handler.register_handler(MessageType.TEXT, collect)
        self.handler.unregister_handler(MessageType.TEXT)

        assert not self.handler.has_handler(MessageType.TEXT)

    def test_get_registered_types(self):
        """The registered types are exactly the two that were registered."""
        expected_types = (MessageType.TEXT, MessageType.FILE_REQUEST)
        for message_type in expected_types:
            self.handler.register_handler(message_type, lambda m: None)

        registered = self.handler.get_registered_types()

        for message_type in expected_types:
            assert message_type in registered
        assert len(registered) == 2