# P2P Network Communication - Message Handler Tests

"""
测试消息处理器的基本功能

Basic functionality tests for the message handler.
"""
|
|
import struct

import pytest

from shared.message_handler import (
    MessageHandler,
    MessageRoutingError,
    MessageSerializationError,
    MessageValidationError,
)
from shared.models import Message, MessageType
|
|
|
|
class TestMessageHandlerSerialization:
    """Tests for MessageHandler serialization and deserialization."""

    def setup_method(self):
        """Create a fresh MessageHandler before each test method."""
        self.handler = MessageHandler()

    def test_serialize_text_message(self):
        """Serializing a TEXT message yields bytes longer than the header."""
        msg = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"Hello, World!",
        )

        serialized = self.handler.serialize(msg)

        assert isinstance(serialized, bytes)
        # A serialized frame must carry data beyond the fixed-size header.
        assert len(serialized) > MessageHandler.HEADER_SIZE

    def test_deserialize_text_message(self):
        """A serialize/deserialize round trip preserves every message field."""
        original = Message(
            msg_type=MessageType.TEXT,
            sender_id="sender",
            receiver_id="receiver",
            timestamp=9876543210.0,
            payload=b"Test payload",
        )

        serialized = self.handler.serialize(original)
        restored = self.handler.deserialize(serialized)

        assert restored.msg_type == original.msg_type
        assert restored.sender_id == original.sender_id
        assert restored.receiver_id == original.receiver_id
        assert restored.timestamp == original.timestamp
        assert restored.payload == original.payload

    def test_deserialize_short_data_raises_error(self):
        """Input too short to hold a header raises MessageSerializationError."""
        short_data = b"short"
        # Guard the premise: the input must really be shorter than the header,
        # otherwise this test would exercise a different failure path.
        assert len(short_data) < MessageHandler.HEADER_SIZE

        with pytest.raises(MessageSerializationError):
            self.handler.deserialize(short_data)

    def test_deserialize_invalid_json_raises_error(self):
        """A well-formed header followed by a non-JSON body raises an error."""
        # Build data with a valid-looking header but an invalid JSON body.
        invalid_json = b"not valid json"
        # NOTE(review): assumes the header is two big-endian unsigned 32-bit
        # ints (body length first, then a second field set to 1) -- confirm
        # against MessageHandler's actual header format.
        header = struct.pack("!II", len(invalid_json), 1)
        # Guard: if the real header size differs from this hand-built one, the
        # data would be rejected for bad framing rather than bad JSON, and the
        # test would pass for the wrong reason.
        assert len(header) == MessageHandler.HEADER_SIZE

        with pytest.raises(MessageSerializationError):
            self.handler.deserialize(header + invalid_json)
|
|
|
|
class TestMessageHandlerValidation:
    """Tests for MessageHandler.validate."""

    def setup_method(self):
        """Give every test method its own MessageHandler instance."""
        self.handler = MessageHandler()

    @staticmethod
    def _build_message(**overrides):
        """Return a TEXT Message with baseline fields, replaced by *overrides*."""
        fields = {
            "msg_type": MessageType.TEXT,
            "sender_id": "user1",
            "receiver_id": "user2",
            "timestamp": 1234567890.0,
            "payload": b"Test",
        }
        fields.update(overrides)
        return Message(**fields)

    def test_validate_valid_message(self):
        """A fully populated message passes validation and returns True."""
        candidate = self._build_message(payload=b"Valid message")
        assert self.handler.validate(candidate) is True

    def test_validate_empty_sender_raises_error(self):
        """An empty sender_id is rejected with MessageValidationError."""
        candidate = self._build_message(sender_id="")
        with pytest.raises(MessageValidationError):
            self.handler.validate(candidate)

    def test_validate_empty_receiver_raises_error(self):
        """An empty receiver_id is rejected with MessageValidationError."""
        candidate = self._build_message(receiver_id="")
        with pytest.raises(MessageValidationError):
            self.handler.validate(candidate)

    def test_validate_invalid_timestamp_raises_error(self):
        """A negative timestamp is rejected with MessageValidationError."""
        candidate = self._build_message(timestamp=-1.0)
        with pytest.raises(MessageValidationError):
            self.handler.validate(candidate)
|
|
|
|
class TestMessageHandlerRouting:
    """Tests for MessageHandler handler registration and message routing."""

    def setup_method(self):
        """Create a fresh handler and an empty capture list before each test."""
        self.handler = MessageHandler()
        # Messages delivered to any test handler are collected here.
        self.received_messages = []

    def test_register_and_route_handler(self):
        """A registered type-specific handler receives the routed message."""
        def text_handler(msg):
            self.received_messages.append(msg)

        self.handler.register_handler(MessageType.TEXT, text_handler)

        msg = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"Routed message",
        )

        self.handler.route(msg)

        assert len(self.received_messages) == 1
        assert self.received_messages[0].payload == b"Routed message"

    def test_route_without_handler_raises_error(self):
        """Routing a type with no handler (and no default) raises an error."""
        msg = Message(
            msg_type=MessageType.FILE_REQUEST,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"File request",
        )

        with pytest.raises(MessageRoutingError):
            self.handler.route(msg)

    def test_default_handler(self):
        """A message with no type-specific handler goes to the default handler."""
        def default_handler(msg):
            self.received_messages.append(msg)

        self.handler.set_default_handler(default_handler)

        msg = Message(
            msg_type=MessageType.HEARTBEAT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"",
        )

        self.handler.route(msg)

        assert len(self.received_messages) == 1
        # Confirm it was this message that reached the default handler.
        assert self.received_messages[0].msg_type == MessageType.HEARTBEAT

    def test_unregister_handler(self):
        """Unregistering removes the handler so it is neither reported nor called."""
        def handler(msg):
            self.received_messages.append(msg)

        self.handler.register_handler(MessageType.TEXT, handler)
        # Precondition: without this, a no-op register_handler would make the
        # post-unregister assertion pass trivially.
        assert self.handler.has_handler(MessageType.TEXT)

        self.handler.unregister_handler(MessageType.TEXT)

        assert not self.handler.has_handler(MessageType.TEXT)

        # With no default handler set, routing the now-unhandled type should
        # fail just like test_route_without_handler_raises_error, and the
        # removed handler must not be invoked.
        msg = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"After unregister",
        )
        with pytest.raises(MessageRoutingError):
            self.handler.route(msg)
        assert self.received_messages == []

    def test_get_registered_types(self):
        """get_registered_types reports exactly the registered message types."""
        self.handler.register_handler(MessageType.TEXT, lambda m: None)
        self.handler.register_handler(MessageType.FILE_REQUEST, lambda m: None)

        registered = self.handler.get_registered_types()

        assert MessageType.TEXT in registered
        assert MessageType.FILE_REQUEST in registered
        assert len(registered) == 2