parent
d6b92b67dc
commit
842aa282cc
@ -0,0 +1,365 @@
|
||||
# P2P Network Communication - Relay Server Tests
|
||||
"""
|
||||
中转服务器单元测试
|
||||
测试服务器核心功能、用户管理和消息转发
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import time
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
from server.relay_server import (
|
||||
RelayServer, ClientConnection, ServerError
|
||||
)
|
||||
from shared.models import (
|
||||
Message, MessageType, UserInfo, UserStatus
|
||||
)
|
||||
from config import ServerConfig
|
||||
|
||||
|
||||
class TestRelayServerInit:
|
||||
"""测试服务器初始化"""
|
||||
|
||||
def test_init_with_default_config(self):
|
||||
"""测试使用默认配置初始化"""
|
||||
server = RelayServer()
|
||||
assert server.host == "0.0.0.0"
|
||||
assert server.port == 8888
|
||||
assert server.max_connections == 1000
|
||||
assert not server.is_running
|
||||
assert server.connection_count == 0
|
||||
|
||||
def test_init_with_custom_config(self):
|
||||
"""测试使用自定义配置初始化"""
|
||||
config = ServerConfig(
|
||||
host="127.0.0.1",
|
||||
port=9999,
|
||||
max_connections=500
|
||||
)
|
||||
server = RelayServer(config)
|
||||
assert server.host == "127.0.0.1"
|
||||
assert server.port == 9999
|
||||
assert server.max_connections == 500
|
||||
|
||||
|
||||
class TestUserManagement:
|
||||
"""测试用户管理功能"""
|
||||
|
||||
@pytest.fixture
|
||||
def server(self):
|
||||
"""创建服务器实例"""
|
||||
return RelayServer()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connection(self):
|
||||
"""创建模拟连接"""
|
||||
reader = AsyncMock()
|
||||
writer = AsyncMock()
|
||||
writer.close = MagicMock()
|
||||
writer.wait_closed = AsyncMock()
|
||||
return ClientConnection(
|
||||
user_id="test_user",
|
||||
reader=reader,
|
||||
writer=writer,
|
||||
ip_address="192.168.1.100",
|
||||
port=12345
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def user_info(self):
|
||||
"""创建用户信息"""
|
||||
return UserInfo(
|
||||
user_id="test_user",
|
||||
username="testuser",
|
||||
display_name="Test User",
|
||||
status=UserStatus.ONLINE
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_user(self, server, mock_connection, user_info):
|
||||
"""测试用户注册"""
|
||||
success = await server._register_user_async(
|
||||
"test_user", user_info, mock_connection
|
||||
)
|
||||
assert success
|
||||
assert server.is_user_online("test_user")
|
||||
assert server.connection_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_user_reconnect(self, server, mock_connection, user_info):
|
||||
"""测试用户重连(替换旧连接)"""
|
||||
# 第一次注册
|
||||
await server._register_user_async("test_user", user_info, mock_connection)
|
||||
|
||||
# 创建新连接
|
||||
new_reader = AsyncMock()
|
||||
new_writer = AsyncMock()
|
||||
new_writer.close = MagicMock()
|
||||
new_writer.wait_closed = AsyncMock()
|
||||
new_connection = ClientConnection(
|
||||
user_id="test_user",
|
||||
reader=new_reader,
|
||||
writer=new_writer,
|
||||
ip_address="192.168.1.101",
|
||||
port=12346
|
||||
)
|
||||
|
||||
# 重新注册
|
||||
success = await server._register_user_async("test_user", user_info, new_connection)
|
||||
assert success
|
||||
assert server.connection_count == 1 # 仍然只有一个连接
|
||||
|
||||
# 旧连接应该被关闭
|
||||
mock_connection.writer.close.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_user(self, server, mock_connection, user_info):
|
||||
"""测试用户注销"""
|
||||
await server._register_user_async("test_user", user_info, mock_connection)
|
||||
assert server.is_user_online("test_user")
|
||||
|
||||
await server._unregister_user_async("test_user")
|
||||
assert not server.is_user_online("test_user")
|
||||
assert server.connection_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_online_users(self, server, mock_connection, user_info):
|
||||
"""测试获取在线用户列表"""
|
||||
# 初始为空
|
||||
assert len(server.get_online_users()) == 0
|
||||
|
||||
# 注册用户
|
||||
await server._register_user_async("test_user", user_info, mock_connection)
|
||||
|
||||
online_users = server.get_online_users()
|
||||
assert len(online_users) == 1
|
||||
assert online_users[0].user_id == "test_user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_info(self, server, mock_connection, user_info):
|
||||
"""测试获取用户信息"""
|
||||
# 用户不存在
|
||||
assert server.get_user_info("test_user") is None
|
||||
|
||||
# 注册用户
|
||||
await server._register_user_async("test_user", user_info, mock_connection)
|
||||
|
||||
info = server.get_user_info("test_user")
|
||||
assert info is not None
|
||||
assert info.user_id == "test_user"
|
||||
assert info.username == "testuser"
|
||||
|
||||
|
||||
class TestMessageRelay:
|
||||
"""测试消息转发功能"""
|
||||
|
||||
@pytest.fixture
|
||||
def server(self):
|
||||
"""创建服务器实例"""
|
||||
return RelayServer()
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_connection(self):
|
||||
"""创建模拟连接的工厂函数"""
|
||||
def _create(user_id):
|
||||
reader = AsyncMock()
|
||||
writer = AsyncMock()
|
||||
writer.close = MagicMock()
|
||||
writer.wait_closed = AsyncMock()
|
||||
writer.write = MagicMock()
|
||||
writer.drain = AsyncMock()
|
||||
return ClientConnection(
|
||||
user_id=user_id,
|
||||
reader=reader,
|
||||
writer=writer,
|
||||
ip_address="192.168.1.100",
|
||||
port=12345
|
||||
)
|
||||
return _create
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relay_message_to_online_user(self, server, create_mock_connection):
|
||||
"""测试转发消息给在线用户"""
|
||||
# 注册发送者和接收者
|
||||
sender_conn = create_mock_connection("sender")
|
||||
receiver_conn = create_mock_connection("receiver")
|
||||
|
||||
sender_info = UserInfo(
|
||||
user_id="sender", username="sender",
|
||||
display_name="Sender", status=UserStatus.ONLINE
|
||||
)
|
||||
receiver_info = UserInfo(
|
||||
user_id="receiver", username="receiver",
|
||||
display_name="Receiver", status=UserStatus.ONLINE
|
||||
)
|
||||
|
||||
await server._register_user_async("sender", sender_info, sender_conn)
|
||||
await server._register_user_async("receiver", receiver_info, receiver_conn)
|
||||
|
||||
# 创建消息
|
||||
message = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="sender",
|
||||
receiver_id="receiver",
|
||||
timestamp=time.time(),
|
||||
payload=b"Hello, receiver!"
|
||||
)
|
||||
|
||||
# 转发消息
|
||||
success = await server.relay_message(message)
|
||||
assert success
|
||||
|
||||
# 验证消息被发送到接收者
|
||||
receiver_conn.writer.write.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relay_message_to_offline_user(self, server, create_mock_connection):
|
||||
"""测试转发消息给离线用户(缓存)"""
|
||||
# 只注册发送者
|
||||
sender_conn = create_mock_connection("sender")
|
||||
sender_info = UserInfo(
|
||||
user_id="sender", username="sender",
|
||||
display_name="Sender", status=UserStatus.ONLINE
|
||||
)
|
||||
await server._register_user_async("sender", sender_info, sender_conn)
|
||||
|
||||
# 创建消息
|
||||
message = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="sender",
|
||||
receiver_id="offline_user",
|
||||
timestamp=time.time(),
|
||||
payload=b"Hello, offline user!"
|
||||
)
|
||||
|
||||
# 转发消息(应该被缓存)
|
||||
success = await server.relay_message(message)
|
||||
assert success
|
||||
|
||||
# 验证消息被缓存
|
||||
assert server.get_offline_message_count("offline_user") == 1
|
||||
|
||||
def test_cache_offline_message(self, server):
|
||||
"""测试离线消息缓存"""
|
||||
message = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="sender",
|
||||
receiver_id="receiver",
|
||||
timestamp=time.time(),
|
||||
payload=b"Test message"
|
||||
)
|
||||
|
||||
server.cache_offline_message("receiver", message)
|
||||
assert server.get_offline_message_count("receiver") == 1
|
||||
|
||||
# 缓存多条消息
|
||||
server.cache_offline_message("receiver", message)
|
||||
server.cache_offline_message("receiver", message)
|
||||
assert server.get_offline_message_count("receiver") == 3
|
||||
|
||||
def test_clear_offline_messages(self, server):
|
||||
"""测试清除离线消息"""
|
||||
message = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="sender",
|
||||
receiver_id="receiver",
|
||||
timestamp=time.time(),
|
||||
payload=b"Test message"
|
||||
)
|
||||
|
||||
server.cache_offline_message("receiver", message)
|
||||
assert server.get_offline_message_count("receiver") == 1
|
||||
|
||||
server.clear_offline_messages("receiver")
|
||||
assert server.get_offline_message_count("receiver") == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deliver_offline_messages(self, server, create_mock_connection):
|
||||
"""测试投递离线消息"""
|
||||
# 缓存离线消息
|
||||
message1 = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="sender",
|
||||
receiver_id="receiver",
|
||||
timestamp=time.time(),
|
||||
payload=b"Message 1"
|
||||
)
|
||||
message2 = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="sender",
|
||||
receiver_id="receiver",
|
||||
timestamp=time.time() + 1,
|
||||
payload=b"Message 2"
|
||||
)
|
||||
|
||||
server.cache_offline_message("receiver", message1)
|
||||
server.cache_offline_message("receiver", message2)
|
||||
assert server.get_offline_message_count("receiver") == 2
|
||||
|
||||
# 用户上线
|
||||
receiver_conn = create_mock_connection("receiver")
|
||||
receiver_info = UserInfo(
|
||||
user_id="receiver", username="receiver",
|
||||
display_name="Receiver", status=UserStatus.ONLINE
|
||||
)
|
||||
await server._register_user_async("receiver", receiver_info, receiver_conn)
|
||||
|
||||
# 投递离线消息
|
||||
await server._deliver_offline_messages("receiver")
|
||||
|
||||
# 验证消息被投递
|
||||
assert receiver_conn.writer.write.call_count >= 2
|
||||
assert server.get_offline_message_count("receiver") == 0
|
||||
|
||||
|
||||
class TestHelperMethods:
|
||||
"""测试辅助方法"""
|
||||
|
||||
@pytest.fixture
|
||||
def server(self):
|
||||
"""创建服务器实例"""
|
||||
return RelayServer()
|
||||
|
||||
def test_create_error_message(self, server):
|
||||
"""测试创建错误消息"""
|
||||
error_msg = server._create_error_message(
|
||||
"server", "user1", "Test error"
|
||||
)
|
||||
assert error_msg.msg_type == MessageType.ERROR
|
||||
assert error_msg.sender_id == "server"
|
||||
assert error_msg.receiver_id == "user1"
|
||||
assert error_msg.payload == b"Test error"
|
||||
|
||||
def test_create_ack_message(self, server):
|
||||
"""测试创建确认消息"""
|
||||
ack_msg = server._create_ack_message(
|
||||
"server", "user1", "Test ack"
|
||||
)
|
||||
assert ack_msg.msg_type == MessageType.ACK
|
||||
assert ack_msg.sender_id == "server"
|
||||
assert ack_msg.receiver_id == "user1"
|
||||
assert ack_msg.payload == b"Test ack"
|
||||
|
||||
|
||||
class TestClientConnection:
|
||||
"""测试客户端连接类"""
|
||||
|
||||
def test_connection_is_alive(self):
|
||||
"""测试连接存活检查"""
|
||||
reader = AsyncMock()
|
||||
writer = AsyncMock()
|
||||
|
||||
conn = ClientConnection(
|
||||
user_id="test",
|
||||
reader=reader,
|
||||
writer=writer,
|
||||
last_heartbeat=time.time()
|
||||
)
|
||||
assert conn.is_alive
|
||||
|
||||
# 模拟超时
|
||||
conn.last_heartbeat = time.time() - 120 # 2分钟前
|
||||
assert not conn.is_alive
|
||||
Loading…
Reference in new issue