You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

366 lines
12 KiB

# P2P Network Communication - Relay Server Tests
"""
中转服务器单元测试
测试服务器核心功能、用户管理和消息转发
"""
import asyncio
import pytest
import time
import json
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime
from server.relay_server import (
RelayServer, ClientConnection, ServerError
)
from shared.models import (
Message, MessageType, UserInfo, UserStatus
)
from config import ServerConfig
class TestRelayServerInit:
    """Server construction: built-in defaults and ServerConfig overrides."""

    def test_init_with_default_config(self):
        """A server built without a config uses the defaults and starts idle."""
        srv = RelayServer()
        assert (srv.host, srv.port, srv.max_connections) == ("0.0.0.0", 8888, 1000)
        # A freshly constructed server is not running and holds no clients.
        assert not srv.is_running
        assert srv.connection_count == 0

    def test_init_with_custom_config(self):
        """Values supplied through ServerConfig override the defaults."""
        cfg = ServerConfig(host="127.0.0.1", port=9999, max_connections=500)
        srv = RelayServer(cfg)
        assert (srv.host, srv.port, srv.max_connections) == ("127.0.0.1", 9999, 500)
class TestUserManagement:
    """User registration, reconnection, unregistration and lookup."""

    @staticmethod
    def _build_connection(user_id="test_user", ip_address="192.168.1.100", port=12345):
        """Return a ClientConnection whose streams are mocks shaped like asyncio's.

        ``asyncio.StreamWriter.write``, ``close`` and ``is_closing`` are plain
        synchronous methods, while ``drain`` and ``wait_closed`` are coroutines.
        Child attributes of an ``AsyncMock`` are themselves ``AsyncMock``, so the
        synchronous methods are replaced explicitly. Otherwise every server-side
        ``writer.write(...)`` would return a never-awaited coroutine, and
        ``writer.is_closing()`` would return a (truthy) coroutine object.
        """
        reader = AsyncMock()
        writer = AsyncMock()
        writer.write = MagicMock()
        writer.close = MagicMock()
        writer.is_closing = MagicMock(return_value=False)
        writer.drain = AsyncMock()
        writer.wait_closed = AsyncMock()
        return ClientConnection(
            user_id=user_id,
            reader=reader,
            writer=writer,
            ip_address=ip_address,
            port=port,
        )

    @pytest.fixture
    def server(self):
        """A fresh RelayServer with default configuration."""
        return RelayServer()

    @pytest.fixture
    def mock_connection(self):
        """A mocked connection for ``test_user`` from 192.168.1.100:12345."""
        return self._build_connection()

    @pytest.fixture
    def user_info(self):
        """Profile of the online user ``test_user``."""
        return UserInfo(
            user_id="test_user",
            username="testuser",
            display_name="Test User",
            status=UserStatus.ONLINE,
        )

    @pytest.mark.asyncio
    async def test_register_user(self, server, mock_connection, user_info):
        """Registering a user marks it online and counts one connection."""
        success = await server._register_user_async(
            "test_user", user_info, mock_connection
        )
        assert success
        assert server.is_user_online("test_user")
        assert server.connection_count == 1

    @pytest.mark.asyncio
    async def test_register_user_reconnect(self, server, mock_connection, user_info):
        """Re-registering an online user replaces (and closes) the old connection."""
        # First registration.
        await server._register_user_async("test_user", user_info, mock_connection)

        # The same user connects again from a different endpoint.
        new_connection = self._build_connection(ip_address="192.168.1.101", port=12346)
        success = await server._register_user_async("test_user", user_info, new_connection)

        assert success
        assert server.is_user_online("test_user")
        # Still a single connection for this user.
        assert server.connection_count == 1
        # The old connection must be closed; the replacement must stay open.
        mock_connection.writer.close.assert_called()
        new_connection.writer.close.assert_not_called()

    @pytest.mark.asyncio
    async def test_unregister_user(self, server, mock_connection, user_info):
        """Unregistering removes the user and its connection."""
        await server._register_user_async("test_user", user_info, mock_connection)
        assert server.is_user_online("test_user")

        await server._unregister_user_async("test_user")
        assert not server.is_user_online("test_user")
        assert server.connection_count == 0

    @pytest.mark.asyncio
    async def test_get_online_users(self, server, mock_connection, user_info):
        """get_online_users starts empty and lists registered users."""
        assert len(server.get_online_users()) == 0

        await server._register_user_async("test_user", user_info, mock_connection)
        online_users = server.get_online_users()
        assert len(online_users) == 1
        assert online_users[0].user_id == "test_user"

    @pytest.mark.asyncio
    async def test_get_user_info(self, server, mock_connection, user_info):
        """get_user_info returns None for unknown users and the profile otherwise."""
        assert server.get_user_info("test_user") is None

        await server._register_user_async("test_user", user_info, mock_connection)
        info = server.get_user_info("test_user")
        assert info is not None
        assert info.user_id == "test_user"
        assert info.username == "testuser"
class TestMessageRelay:
    """Message forwarding between users and the offline-message cache."""

    @staticmethod
    def _text_message(receiver_id, payload, delay=0.0):
        """Build a TEXT message from ``sender`` to *receiver_id*.

        The timestamp is the current wall-clock time shifted by *delay* seconds.
        """
        return Message(
            msg_type=MessageType.TEXT,
            sender_id="sender",
            receiver_id=receiver_id,
            timestamp=time.time() + delay,
            payload=payload,
        )

    @staticmethod
    def _online_user(user_id, display_name):
        """Build an ONLINE UserInfo whose username equals *user_id*."""
        return UserInfo(
            user_id=user_id,
            username=user_id,
            display_name=display_name,
            status=UserStatus.ONLINE,
        )

    @pytest.fixture
    def server(self):
        """A fresh RelayServer with default configuration."""
        return RelayServer()

    @pytest.fixture
    def create_mock_connection(self):
        """Return a factory that builds mock-backed ClientConnections.

        The writer mimics asyncio.StreamWriter: ``write`` and ``close`` are
        plain mocks, while ``drain`` and ``wait_closed`` are awaitable.
        """
        def factory(uid):
            stream_writer = AsyncMock()
            stream_writer.write = MagicMock()
            stream_writer.close = MagicMock()
            stream_writer.drain = AsyncMock()
            stream_writer.wait_closed = AsyncMock()
            return ClientConnection(
                user_id=uid,
                reader=AsyncMock(),
                writer=stream_writer,
                ip_address="192.168.1.100",
                port=12345,
            )

        return factory

    @pytest.mark.asyncio
    async def test_relay_message_to_online_user(self, server, create_mock_connection):
        """A message addressed to an online user is written to that user's stream."""
        sender_conn = create_mock_connection("sender")
        receiver_conn = create_mock_connection("receiver")
        for uid, display, conn in (
            ("sender", "Sender", sender_conn),
            ("receiver", "Receiver", receiver_conn),
        ):
            await server._register_user_async(uid, self._online_user(uid, display), conn)

        delivered = await server.relay_message(
            self._text_message("receiver", b"Hello, receiver!")
        )

        assert delivered
        receiver_conn.writer.write.assert_called()

    @pytest.mark.asyncio
    async def test_relay_message_to_offline_user(self, server, create_mock_connection):
        """A message addressed to an unknown user is cached for later delivery."""
        await server._register_user_async(
            "sender",
            self._online_user("sender", "Sender"),
            create_mock_connection("sender"),
        )

        # "offline_user" was never registered, so the relay should cache the message.
        delivered = await server.relay_message(
            self._text_message("offline_user", b"Hello, offline user!")
        )

        assert delivered
        assert server.get_offline_message_count("offline_user") == 1

    def test_cache_offline_message(self, server):
        """Each cached message increments the per-user offline count."""
        msg = self._text_message("receiver", b"Test message")

        server.cache_offline_message("receiver", msg)
        assert server.get_offline_message_count("receiver") == 1

        for _ in range(2):
            server.cache_offline_message("receiver", msg)
        assert server.get_offline_message_count("receiver") == 3

    def test_clear_offline_messages(self, server):
        """clear_offline_messages drops everything cached for the user."""
        server.cache_offline_message(
            "receiver", self._text_message("receiver", b"Test message")
        )
        assert server.get_offline_message_count("receiver") == 1

        server.clear_offline_messages("receiver")
        assert server.get_offline_message_count("receiver") == 0

    @pytest.mark.asyncio
    async def test_deliver_offline_messages(self, server, create_mock_connection):
        """Cached messages are written to the user once it is online, then cleared."""
        backlog = (
            self._text_message("receiver", b"Message 1"),
            self._text_message("receiver", b"Message 2", delay=1),
        )
        for queued in backlog:
            server.cache_offline_message("receiver", queued)
        assert server.get_offline_message_count("receiver") == 2

        # The recipient comes online.
        receiver_conn = create_mock_connection("receiver")
        await server._register_user_async(
            "receiver", self._online_user("receiver", "Receiver"), receiver_conn
        )
        await server._deliver_offline_messages("receiver")

        # ">=" tolerates extra writes, e.g. if registration also writes to the stream.
        assert receiver_conn.writer.write.call_count >= 2
        assert server.get_offline_message_count("receiver") == 0
class TestHelperMethods:
    """Factories for server-originated ERROR and ACK messages."""

    @pytest.fixture
    def server(self):
        """A fresh RelayServer with default configuration."""
        return RelayServer()

    def test_create_error_message(self, server):
        """_create_error_message sets type, endpoints and a bytes payload."""
        built = server._create_error_message("server", "user1", "Test error")
        observed = (built.msg_type, built.sender_id, built.receiver_id, built.payload)
        assert observed == (MessageType.ERROR, "server", "user1", b"Test error")

    def test_create_ack_message(self, server):
        """_create_ack_message sets type, endpoints and a bytes payload."""
        built = server._create_ack_message("server", "user1", "Test ack")
        observed = (built.msg_type, built.sender_id, built.receiver_id, built.payload)
        assert observed == (MessageType.ACK, "server", "user1", b"Test ack")
class TestClientConnection:
    """Heartbeat-based liveness on ClientConnection."""

    def test_connection_is_alive(self):
        """A fresh heartbeat counts as alive; a two-minute-old one does not."""
        conn = ClientConnection(
            user_id="test",
            reader=AsyncMock(),
            writer=AsyncMock(),
            last_heartbeat=time.time(),
        )
        assert conn.is_alive

        # Backdate the heartbeat by 120 s (assumed to exceed the liveness timeout).
        stale_heartbeat = time.time() - 120
        conn.last_heartbeat = stale_heartbeat
        assert not conn.is_alive