You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
332 lines
9.7 KiB
332 lines
9.7 KiB
# P2P Network Communication - Connection Manager Tests
|
|
"""
|
|
测试客户端连接管理器的基本功能
|
|
|
|
需求: 1.1, 1.2, 1.3, 1.4, 1.5, 1.6
|
|
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from datetime import datetime
|
|
|
|
from client.connection_manager import (
|
|
ConnectionManager,
|
|
ConnectionState,
|
|
ConnectionError,
|
|
Connection,
|
|
)
|
|
from shared.models import (
|
|
Message, MessageType, UserInfo, UserStatus,
|
|
ConnectionMode, PeerInfo
|
|
)
|
|
from config import ClientConfig
|
|
|
|
|
|
class TestConnectionManagerInit:
    """Tests for ConnectionManager construction."""

    def test_init_with_default_config(self):
        """A manager built with no arguments has a config, is disconnected, and has an empty user id."""
        manager = ConnectionManager()

        assert manager.config is not None
        assert manager.state == ConnectionState.DISCONNECTED
        assert manager.is_connected is False
        assert manager.user_id == ""

    def test_init_with_custom_config(self):
        """A manager built from an explicit ClientConfig exposes that config's values."""
        expected_host = "192.168.1.100"
        expected_port = 9999
        expected_heartbeat = 60

        manager = ConnectionManager(
            ClientConfig(
                server_host=expected_host,
                server_port=expected_port,
                heartbeat_interval=expected_heartbeat,
            )
        )

        assert manager.config.server_host == expected_host
        assert manager.config.server_port == expected_port
        assert manager.config.heartbeat_interval == expected_heartbeat
|
|
|
|
|
|
class TestConnectionState:
    """Tests for connection-state tracking and state-change callbacks."""

    def test_initial_state_is_disconnected(self):
        """A freshly constructed manager reports DISCONNECTED."""
        cm = ConnectionManager()
        assert cm.state == ConnectionState.DISCONNECTED

    def test_state_callback_registration(self):
        """A registered callback receives (state, reason) when the state changes."""
        cm = ConnectionManager()
        callback_called = []

        def state_callback(state, reason):
            callback_called.append((state, reason))

        cm.add_state_callback(state_callback)
        cm._set_state(ConnectionState.CONNECTING, "Test")

        assert len(callback_called) == 1
        assert callback_called[0][0] == ConnectionState.CONNECTING
        assert callback_called[0][1] == "Test"

    def test_state_callback_removal(self):
        """A removed callback is no longer invoked, while other callbacks still are.

        A second callback stays registered as a control: if registration or
        dispatch were broken, the control would not fire and the test fails,
        so the "removed callback was not called" check cannot pass vacuously.
        """
        cm = ConnectionManager()
        removed_calls = []
        kept_calls = []

        def removed_callback(state, reason):
            removed_calls.append((state, reason))

        def kept_callback(state, reason):
            kept_calls.append((state, reason))

        cm.add_state_callback(removed_callback)
        cm.add_state_callback(kept_callback)
        cm.remove_state_callback(removed_callback)
        cm._set_state(ConnectionState.CONNECTING, "Test")

        assert len(removed_calls) == 0
        # Control: proves callbacks are dispatched at all, so the check above is meaningful.
        assert kept_calls == [(ConnectionState.CONNECTING, "Test")]
|
|
|
|
|
|
class TestConnectionMode:
    """Tests for connection-mode selection (requirements 1.1, 1.2, 1.3)."""

    # Shared fixture values for the single simulated LAN peer.
    PEER_ID = "peer1"
    PEER_IP = "192.168.1.100"
    PEER_PORT = 8889

    def test_default_mode_is_relay(self):
        """A peer the manager knows nothing about falls back to RELAY."""
        manager = ConnectionManager()
        assert manager.get_connection_mode("unknown_peer") == ConnectionMode.RELAY

    def test_mode_is_p2p_for_discovered_peer(self):
        """A peer present in the discovered-peer table is reported as P2P."""
        manager = ConnectionManager()
        # Simulate LAN discovery by seeding the manager's discovered-peer table.
        manager._discovered_peers[self.PEER_ID] = PeerInfo(
            peer_id=self.PEER_ID,
            username="test_peer",
            ip_address=self.PEER_IP,
            port=self.PEER_PORT,
        )

        assert manager.get_connection_mode(self.PEER_ID) == ConnectionMode.P2P

    def test_mode_is_p2p_for_connected_peer(self):
        """A peer with a registered P2P connection is reported as P2P."""
        manager = ConnectionManager()
        # Simulate an established P2P link; mocks stand in for the stream objects.
        manager._peer_connections[self.PEER_ID] = Connection(
            peer_id=self.PEER_ID,
            reader=MagicMock(),
            writer=MagicMock(),
            mode=ConnectionMode.P2P,
            ip_address=self.PEER_IP,
            port=self.PEER_PORT,
        )

        assert manager.get_connection_mode(self.PEER_ID) == ConnectionMode.P2P
|
|
|
|
|
|
class TestReconnectMechanism:
    """Tests for the automatic reconnection settings (requirement 1.6)."""

    def test_reconnect_status_initial(self):
        """A new manager is not reconnecting, has made no attempts, and has auto-reconnect on."""
        status = ConnectionManager().get_reconnect_status()

        assert status["is_reconnecting"] is False
        assert status["reconnect_attempts"] == 0
        assert status["auto_reconnect_enabled"] is True

    def test_enable_disable_reconnect(self):
        """enable_reconnect() sets the internal auto-reconnect flag off, then back on."""
        manager = ConnectionManager()

        for flag in (False, True):
            manager.enable_reconnect(flag)
            assert manager._should_reconnect is flag

    def test_set_reconnect_config(self):
        """set_reconnect_config() is reflected in the manager's config values."""
        manager = ConnectionManager()
        manager.set_reconnect_config(max_attempts=5, base_delay=2.0)

        # Read the config only after the call, in case the call replaces it.
        cfg = manager.config
        assert cfg.reconnect_attempts == 5
        assert cfg.reconnect_delay == 2.0
|
|
|
|
|
|
class TestMessageCallbacks:
    """Tests for message-callback registration and removal."""

    def test_message_callback_registration(self):
        """A registered message callback is stored and receives messages passed to it."""
        cm = ConnectionManager()
        messages_received = []

        def msg_callback(msg):
            messages_received.append(msg)

        cm.add_message_callback(msg_callback)
        assert msg_callback in cm._message_callbacks

        # Simulate an incoming message.
        test_msg = Message(
            msg_type=MessageType.TEXT,
            sender_id="user1",
            receiver_id="user2",
            timestamp=1234567890.0,
            payload=b"Test",
        )

        # Invoke the registered callbacks directly. This checks registration only;
        # it does not exercise the manager's own receive/dispatch path.
        for callback in cm._message_callbacks:
            callback(test_msg)

        assert len(messages_received) == 1
        assert messages_received[0].payload == b"Test"

    def test_message_callback_removal(self):
        """A callback that was registered is gone after remove_message_callback()."""
        cm = ConnectionManager()
        messages_received = []

        def msg_callback(msg):
            messages_received.append(msg)

        cm.add_message_callback(msg_callback)
        # Precondition: without this, the final assertions would also pass if
        # add_message_callback() silently registered nothing.
        assert msg_callback in cm._message_callbacks

        cm.remove_message_callback(msg_callback)

        assert msg_callback not in cm._message_callbacks
        assert len(cm._message_callbacks) == 0
|
|
|
|
|
|
class TestConnectionStats:
    """Tests for the connection statistics summary."""

    def test_connection_stats_initial(self):
        """A new manager reports no server link, 'disconnected' state, and zero peer counts."""
        stats = ConnectionManager().get_connection_stats()

        assert stats["server_connected"] is False
        assert stats["state"] == "disconnected"
        for counter in ("total_peer_connections", "p2p_connections", "discovered_peers"):
            assert stats[counter] == 0

    def test_connection_stats_with_peers(self):
        """One discovered peer plus one P2P connection are each counted once."""
        manager = ConnectionManager()

        # One peer known only through discovery...
        manager._discovered_peers["peer1"] = PeerInfo(
            peer_id="peer1",
            username="test",
            ip_address="192.168.1.100",
            port=8889,
        )

        # ...and a different peer with a registered P2P connection.
        manager._peer_connections["peer2"] = Connection(
            peer_id="peer2",
            reader=MagicMock(),
            writer=MagicMock(),
            mode=ConnectionMode.P2P,
            ip_address="192.168.1.101",
            port=8889,
        )

        stats = manager.get_connection_stats()

        expected_counts = {
            "total_peer_connections": 1,
            "p2p_connections": 1,
            "discovered_peers": 1,
        }
        for key, count in expected_counts.items():
            assert stats[key] == count
|
|
|
|
|
|
class TestPeerConnectionInfo:
    """Tests for per-peer connection information lookups."""

    def test_get_peer_info_for_connected_peer(self):
        """A connected peer's info reports its id, 'p2p' mode and IP address."""
        manager = ConnectionManager()
        manager._peer_connections["peer1"] = Connection(
            peer_id="peer1",
            reader=MagicMock(),
            writer=MagicMock(),
            mode=ConnectionMode.P2P,
            ip_address="192.168.1.100",
            port=8889,
        )

        info = manager.get_peer_connection_info("peer1")

        assert info is not None
        assert {key: info[key] for key in ("peer_id", "mode", "ip_address")} == {
            "peer_id": "peer1",
            "mode": "p2p",
            "ip_address": "192.168.1.100",
        }

    def test_get_peer_info_for_discovered_peer(self):
        """A peer known only through discovery reports mode 'discovered'."""
        manager = ConnectionManager()
        manager._discovered_peers["peer1"] = PeerInfo(
            peer_id="peer1",
            username="test",
            ip_address="192.168.1.100",
            port=8889,
        )

        info = manager.get_peer_connection_info("peer1")

        assert info is not None
        assert {key: info[key] for key in ("peer_id", "mode")} == {
            "peer_id": "peer1",
            "mode": "discovered",
        }

    def test_get_peer_info_for_unknown_peer(self):
        """Looking up a peer the manager has never seen yields None."""
        assert ConnectionManager().get_peer_connection_info("unknown") is None
|
|
|
|
|
|
class TestAllConnectionModes:
    """Tests for the bulk peer-to-connection-mode mapping."""

    def test_get_all_connection_modes(self):
        """Both a connected peer and a merely discovered peer map to P2P mode."""
        manager = ConnectionManager()

        # peer1: a registered P2P connection.
        manager._peer_connections["peer1"] = Connection(
            peer_id="peer1",
            reader=MagicMock(),
            writer=MagicMock(),
            mode=ConnectionMode.P2P,
            ip_address="192.168.1.100",
            port=8889,
        )

        # peer2: known only through discovery.
        manager._discovered_peers["peer2"] = PeerInfo(
            peer_id="peer2",
            username="test",
            ip_address="192.168.1.101",
            port=8889,
        )

        modes = manager.get_all_connection_modes()

        for peer_id in ("peer1", "peer2"):
            assert modes[peer_id] == ConnectionMode.P2P
|