You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

212 lines
6.6 KiB

# P2P Network Communication - Message Handler Tests
"""
测试消息处理器的基本功能
"""
import pytest
from shared.models import Message, MessageType
from shared.message_handler import (
MessageHandler,
MessageValidationError,
MessageSerializationError,
MessageRoutingError
)
class TestMessageHandlerSerialization:
    """Tests for MessageHandler's serialize/deserialize behavior and error paths."""

    def setup_method(self):
        """Create a fresh MessageHandler before each test method."""
        self.handler = MessageHandler()

    def test_serialize_text_message(self):
        """Serializing a TEXT message yields bytes longer than the header alone."""
        msg = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"Hello, World!",
        )
        serialized = self.handler.serialize(msg)
        assert isinstance(serialized, bytes)
        # Anything beyond HEADER_SIZE bytes is the encoded message body.
        assert len(serialized) > MessageHandler.HEADER_SIZE

    def test_deserialize_text_message(self):
        """deserialize(serialize(msg)) restores every field of a TEXT message."""
        original = Message(
            msg_type=MessageType.TEXT,
            sender_id="sender",
            receiver_id="receiver",
            timestamp=9876543210.0,
            payload=b"Test payload",
        )
        serialized = self.handler.serialize(original)
        restored = self.handler.deserialize(serialized)
        assert restored.msg_type == original.msg_type
        assert restored.sender_id == original.sender_id
        assert restored.receiver_id == original.receiver_id
        assert restored.timestamp == original.timestamp
        assert restored.payload == original.payload

    def test_deserialize_short_data_raises_error(self):
        """Data too short to contain a header raises MessageSerializationError.

        Covers both an empty buffer (edge case) and a truncated one.
        """
        for data in (b"", b"short"):
            with pytest.raises(MessageSerializationError):
                self.handler.deserialize(data)

    def test_deserialize_invalid_json_raises_error(self):
        """A well-formed header followed by a non-JSON body raises MessageSerializationError."""
        import struct

        # Build a header that is valid on its own, so the only thing wrong
        # with the input is the body. NOTE(review): assumes the header is two
        # big-endian unsigned 32-bit ints (body length, then a second field
        # set to 1) -- confirm against MessageHandler's wire format.
        invalid_json = b"not valid json"
        header = struct.pack("!II", len(invalid_json), 1)
        # Guard against this test passing for the wrong reason: if the
        # hand-built header does not match the handler's header size, the
        # error below would come from header parsing, not from JSON decoding.
        assert len(header) == MessageHandler.HEADER_SIZE
        with pytest.raises(MessageSerializationError):
            self.handler.deserialize(header + invalid_json)
class TestMessageHandlerValidation:
    """Tests for MessageHandler.validate on well-formed and malformed messages."""

    def setup_method(self):
        """Give every test method its own MessageHandler instance."""
        self.handler = MessageHandler()

    @staticmethod
    def _build(
        sender_id="user1",
        receiver_id="user2",
        timestamp=1234567890.0,
        payload=b"Test",
    ):
        """Return a TEXT Message whose fields default to a valid baseline."""
        return Message(
            msg_type=MessageType.TEXT,
            sender_id=sender_id,
            receiver_id=receiver_id,
            timestamp=timestamp,
            payload=payload,
        )

    def test_validate_valid_message(self):
        """A fully populated message passes validation and returns True."""
        well_formed = self._build(payload=b"Valid message")
        assert self.handler.validate(well_formed) is True

    def test_validate_empty_sender_raises_error(self):
        """An empty sender ID is rejected with MessageValidationError."""
        no_sender = self._build(sender_id="")
        with pytest.raises(MessageValidationError):
            self.handler.validate(no_sender)

    def test_validate_empty_receiver_raises_error(self):
        """An empty receiver ID is rejected with MessageValidationError."""
        no_receiver = self._build(receiver_id="")
        with pytest.raises(MessageValidationError):
            self.handler.validate(no_receiver)

    def test_validate_invalid_timestamp_raises_error(self):
        """A negative timestamp is rejected with MessageValidationError."""
        negative_time = self._build(timestamp=-1.0)
        with pytest.raises(MessageValidationError):
            self.handler.validate(negative_time)
class TestMessageHandlerRouting:
    """Tests for MessageHandler's handler registry and message dispatch."""

    def setup_method(self):
        """Create a fresh handler and an empty list collecting delivered messages."""
        self.handler = MessageHandler()
        self.received_messages = []

    def test_register_and_route_handler(self):
        """A handler registered for TEXT receives a routed TEXT message."""
        def text_handler(msg):
            self.received_messages.append(msg)

        self.handler.register_handler(MessageType.TEXT, text_handler)
        msg = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"Routed message",
        )
        self.handler.route(msg)
        assert len(self.received_messages) == 1
        assert self.received_messages[0].payload == b"Routed message"

    def test_route_without_handler_raises_error(self):
        """Routing a type with no handler and no default raises MessageRoutingError."""
        msg = Message(
            msg_type=MessageType.FILE_REQUEST,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"File request",
        )
        with pytest.raises(MessageRoutingError):
            self.handler.route(msg)

    def test_default_handler(self):
        """A message with no type-specific handler is delivered to the default handler."""
        def default_handler(msg):
            self.received_messages.append(msg)

        self.handler.set_default_handler(default_handler)
        msg = Message(
            msg_type=MessageType.HEARTBEAT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"",
        )
        self.handler.route(msg)
        assert len(self.received_messages) == 1
        # Check that the routed HEARTBEAT itself was delivered, not merely
        # that some message reached the default handler.
        assert self.received_messages[0].msg_type == MessageType.HEARTBEAT

    def test_unregister_handler(self):
        """Unregistering a handler removes it, so the type is no longer routable."""
        def handler(msg):
            self.received_messages.append(msg)

        self.handler.register_handler(MessageType.TEXT, handler)
        # Precondition: without it, this test would pass even if
        # register_handler silently did nothing.
        assert self.handler.has_handler(MessageType.TEXT)

        self.handler.unregister_handler(MessageType.TEXT)
        assert not self.handler.has_handler(MessageType.TEXT)

        # No default handler is set, so routing TEXT must now fail, and the
        # removed handler must not be invoked.
        msg = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"After unregister",
        )
        with pytest.raises(MessageRoutingError):
            self.handler.route(msg)
        assert self.received_messages == []

    def test_get_registered_types(self):
        """get_registered_types reports exactly the types that have handlers."""
        self.handler.register_handler(MessageType.TEXT, lambda m: None)
        self.handler.register_handler(MessageType.FILE_REQUEST, lambda m: None)
        registered = self.handler.get_registered_types()
        assert MessageType.TEXT in registered
        assert MessageType.FILE_REQUEST in registered
        assert len(registered) == 2