You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

332 lines
9.7 KiB

# P2P Network Communication - Connection Manager Tests
"""
测试客户端连接管理器的基本功能
需求: 1.1, 1.2, 1.3, 1.4, 1.5, 1.6
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime
from client.connection_manager import (
ConnectionManager,
ConnectionState,
ConnectionError,
Connection,
)
from shared.models import (
Message, MessageType, UserInfo, UserStatus,
ConnectionMode, PeerInfo
)
from config import ClientConfig
class TestConnectionManagerInit:
    """Tests for ConnectionManager construction."""

    def test_init_with_default_config(self):
        """A manager built without arguments has a config, is disconnected and has no user id."""
        manager = ConnectionManager()

        assert manager.config is not None
        assert manager.state == ConnectionState.DISCONNECTED
        assert manager.is_connected is False
        assert manager.user_id == ""

    def test_init_with_custom_config(self):
        """A manager built from an explicit ClientConfig exposes those values via ``config``."""
        expected = {
            "server_host": "192.168.1.100",
            "server_port": 9999,
            "heartbeat_interval": 60,
        }
        manager = ConnectionManager(ClientConfig(**expected))

        for field_name, value in expected.items():
            assert getattr(manager.config, field_name) == value
class TestConnectionState:
    """Tests for connection-state tracking and state-change callbacks."""

    def test_initial_state_is_disconnected(self):
        """A freshly constructed manager reports DISCONNECTED."""
        assert ConnectionManager().state == ConnectionState.DISCONNECTED

    def test_state_callback_registration(self):
        """A registered callback receives the (state, reason) pair passed to _set_state."""
        manager = ConnectionManager()
        seen = []

        def record(state, reason):
            seen.append((state, reason))

        manager.add_state_callback(record)
        manager._set_state(ConnectionState.CONNECTING, "Test")

        # Exactly one notification carrying the state and reason we set.
        assert seen == [(ConnectionState.CONNECTING, "Test")]

    def test_state_callback_removal(self):
        """A callback removed before a state change is never invoked."""
        manager = ConnectionManager()
        seen = []

        def record(state, reason):
            seen.append((state, reason))

        manager.add_state_callback(record)
        manager.remove_state_callback(record)
        manager._set_state(ConnectionState.CONNECTING, "Test")

        assert not seen
class TestConnectionMode:
    """Tests for choosing between relay and P2P modes (requirements 1.1, 1.2, 1.3)."""

    def test_default_mode_is_relay(self):
        """A peer the manager knows nothing about falls back to RELAY."""
        manager = ConnectionManager()
        assert manager.get_connection_mode("unknown_peer") == ConnectionMode.RELAY

    def test_mode_is_p2p_for_discovered_peer(self):
        """A peer present in the LAN discovery table is reported as P2P."""
        manager = ConnectionManager()
        # Inject the discovered peer directly instead of running real discovery.
        manager._discovered_peers["peer1"] = PeerInfo(
            peer_id="peer1",
            username="test_peer",
            ip_address="192.168.1.100",
            port=8889,
        )

        assert manager.get_connection_mode("peer1") == ConnectionMode.P2P

    def test_mode_is_p2p_for_connected_peer(self):
        """A peer with an existing P2P Connection is reported as P2P."""
        manager = ConnectionManager()
        # Reader/writer are mocks, so no socket is involved.
        manager._peer_connections["peer1"] = Connection(
            peer_id="peer1",
            reader=MagicMock(),
            writer=MagicMock(),
            mode=ConnectionMode.P2P,
            ip_address="192.168.1.100",
            port=8889,
        )

        assert manager.get_connection_mode("peer1") == ConnectionMode.P2P
class TestReconnectMechanism:
    """Tests for the automatic reconnect controls (requirement 1.6)."""

    def test_reconnect_status_initial(self):
        """A new manager is not reconnecting, has zero attempts and auto-reconnect enabled."""
        status = ConnectionManager().get_reconnect_status()

        assert status["is_reconnecting"] is False
        assert status["reconnect_attempts"] == 0
        assert status["auto_reconnect_enabled"] is True

    def test_enable_disable_reconnect(self):
        """enable_reconnect() switches the internal _should_reconnect flag off and back on."""
        manager = ConnectionManager()
        for flag in (False, True):
            manager.enable_reconnect(flag)
            assert manager._should_reconnect is flag

    def test_set_reconnect_config(self):
        """set_reconnect_config() stores max_attempts and base_delay on the config."""
        manager = ConnectionManager()
        manager.set_reconnect_config(max_attempts=5, base_delay=2.0)

        cfg = manager.config
        assert cfg.reconnect_attempts == 5
        assert cfg.reconnect_delay == 2.0
class TestMessageCallbacks:
    """Tests for registering and removing message callbacks."""

    def test_message_callback_registration(self):
        """A registered callback is stored and receives the message it is invoked with."""
        manager = ConnectionManager()
        inbox = []

        def on_message(msg):
            inbox.append(msg)

        manager.add_message_callback(on_message)

        sample = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"Test",
        )
        # Dispatch by hand over the manager's registered callbacks rather than
        # going through a real receive path.
        for handler in manager._message_callbacks:
            handler(sample)

        assert len(inbox) == 1
        assert inbox[0].payload == b"Test"

    def test_message_callback_removal(self):
        """Removing the only registered callback leaves the callback collection empty."""
        manager = ConnectionManager()
        inbox = []

        def on_message(msg):
            inbox.append(msg)

        manager.add_message_callback(on_message)
        manager.remove_message_callback(on_message)

        assert len(manager._message_callbacks) == 0
class TestConnectionStats:
    """Tests for the summary dict returned by get_connection_stats()."""

    def test_connection_stats_initial(self):
        """A new manager reports no server link, state 'disconnected' and zero peer counters."""
        stats = ConnectionManager().get_connection_stats()

        assert stats["server_connected"] is False
        assert stats["state"] == "disconnected"
        for counter in ("total_peer_connections", "p2p_connections", "discovered_peers"):
            assert stats[counter] == 0

    def test_connection_stats_with_peers(self):
        """One discovered peer plus one P2P connection are each counted once."""
        manager = ConnectionManager()
        # peer1 is only discovered; peer2 has a (mocked) P2P connection.
        manager._discovered_peers["peer1"] = PeerInfo(
            peer_id="peer1",
            username="test",
            ip_address="192.168.1.100",
            port=8889,
        )
        manager._peer_connections["peer2"] = Connection(
            peer_id="peer2",
            reader=MagicMock(),
            writer=MagicMock(),
            mode=ConnectionMode.P2P,
            ip_address="192.168.1.101",
            port=8889,
        )

        stats = manager.get_connection_stats()

        expected_counts = {
            "total_peer_connections": 1,
            "p2p_connections": 1,
            "discovered_peers": 1,
        }
        for key, count in expected_counts.items():
            assert stats[key] == count
class TestPeerConnectionInfo:
    """Tests for get_peer_connection_info() on connected, discovered and unknown peers."""

    def test_get_peer_info_for_connected_peer(self):
        """A connected P2P peer is described with its id, mode 'p2p' and IP address."""
        manager = ConnectionManager()
        manager._peer_connections["peer1"] = Connection(
            peer_id="peer1",
            reader=MagicMock(),
            writer=MagicMock(),
            mode=ConnectionMode.P2P,
            ip_address="192.168.1.100",
            port=8889,
        )

        info = manager.get_peer_connection_info("peer1")

        assert info is not None
        assert info["peer_id"] == "peer1"
        assert info["mode"] == "p2p"
        assert info["ip_address"] == "192.168.1.100"

    def test_get_peer_info_for_discovered_peer(self):
        """A peer that is only discovered is described with mode 'discovered'."""
        manager = ConnectionManager()
        manager._discovered_peers["peer1"] = PeerInfo(
            peer_id="peer1",
            username="test",
            ip_address="192.168.1.100",
            port=8889,
        )

        info = manager.get_peer_connection_info("peer1")

        assert info is not None
        assert info["peer_id"] == "peer1"
        assert info["mode"] == "discovered"

    def test_get_peer_info_for_unknown_peer(self):
        """A peer id the manager has never seen yields None."""
        assert ConnectionManager().get_peer_connection_info("unknown") is None
class TestAllConnectionModes:
    """Tests for get_all_connection_modes()."""

    def test_get_all_connection_modes(self):
        """Both a connected P2P peer and a merely discovered peer map to P2P."""
        manager = ConnectionManager()
        # peer1: existing (mocked) P2P connection; peer2: discovered only.
        manager._peer_connections["peer1"] = Connection(
            peer_id="peer1",
            reader=MagicMock(),
            writer=MagicMock(),
            mode=ConnectionMode.P2P,
            ip_address="192.168.1.100",
            port=8889,
        )
        manager._discovered_peers["peer2"] = PeerInfo(
            peer_id="peer2",
            username="test",
            ip_address="192.168.1.101",
            port=8889,
        )

        modes = manager.get_all_connection_modes()

        for peer_id in ("peer1", "peer2"):
            assert modes[peer_id] == ConnectionMode.P2P