|
|
# P2P Network Communication - Connection Manager
|
|
|
"""
|
|
|
客户端连接管理器模块
|
|
|
负责管理网络连接、自动选择通信模式、心跳机制和重连逻辑
|
|
|
|
|
|
需求: 1.1, 1.2, 1.3, 1.4, 1.5, 1.6
|
|
|
"""
|
|
|
|
|
|
import asyncio
|
|
|
import logging
|
|
|
import socket
|
|
|
import struct
|
|
|
import time
|
|
|
import json
|
|
|
from dataclasses import dataclass, field
|
|
|
from datetime import datetime
|
|
|
from enum import Enum
|
|
|
from typing import Callable, Dict, List, Optional, Tuple, Any
|
|
|
|
|
|
from shared.models import (
|
|
|
Message, MessageType, UserInfo, UserStatus,
|
|
|
ConnectionMode, PeerInfo, NetworkQuality
|
|
|
)
|
|
|
from shared.message_handler import (
|
|
|
MessageHandler, MessageSerializationError, MessageValidationError
|
|
|
)
|
|
|
from config import ClientConfig
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class ConnectionState(Enum):
    """Lifecycle states reported by the connection manager."""

    # No server connection (initial state, and the state after disconnect()).
    DISCONNECTED = "disconnected"
    # An initial connection attempt is in progress.
    CONNECTING = "connecting"
    # Registered with the server.
    CONNECTED = "connected"
    # The link was lost and automatic reconnection is running.
    RECONNECTING = "reconnecting"
    # A connection or reconnection attempt failed.
    ERROR = "error"
|
|
|
|
|
|
|
|
|
class ConnectionError(Exception):
    """Connection failure.

    Derives directly from ``Exception`` and shadows the built-in
    ``ConnectionError`` (an ``OSError`` subclass) inside this module.
    """
|
|
|
|
|
|
|
|
|
class ReconnectionError(Exception):
    """Reconnection failure."""
|
|
|
|
|
|
|
|
|
# Callback type definitions.
# State callbacks receive the new state and an optional reason string.
StateChangeCallback = Callable[[ConnectionState, Optional[str]], None]
# Message callbacks receive each message that was received.
MessageReceivedCallback = Callable[[Message], None]
|
|
|
|
|
|
|
|
|
@dataclass
class Connection:
    """Bookkeeping record for one link to a peer (direct or relayed)."""

    # Identifier of the remote peer.
    peer_id: str
    # Stream pair used to talk to the peer (the server streams in relay mode).
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    # How the peer is reached.
    mode: ConnectionMode
    # Creation time of this record (local time from datetime.now).
    connected_at: datetime = field(default_factory=datetime.now)
    # time.time() timestamp of the most recent traffic on this link.
    last_activity: float = field(default_factory=time.time)
    # Remote endpoint address.
    ip_address: str = ""
    port: int = 0
|
|
|
|
|
|
|
|
|
class ConnectionManager:
    """
    Connection manager.

    Responsibilities:
    - Connect to the relay server (requirements 1.1, 1.4, 1.5)
    - Discover other clients on the local network (requirement 1.2)
    - Automatically select the communication mode (requirements 1.1, 1.2, 1.3)
    - Keep the connection alive with a heartbeat (requirements 1.4, 1.5)
    - Reconnect after the network is lost (requirement 1.6)
    """

    # Magic markers of the LAN discovery protocol.
    LAN_DISCOVERY_MAGIC = b"P2P_DISCOVER"
    LAN_RESPONSE_MAGIC = b"P2P_RESPONSE"
|
|
|
|
|
|
def __init__(self, config: Optional[ClientConfig] = None):
    """
    Set up a manager that starts out disconnected.

    Args:
        config: Client configuration; when None (or otherwise falsy) a
            default ``ClientConfig()`` is created.
    """
    self.config = config or ClientConfig()

    # Relay-server link: the stream pair plus a "registered" flag.
    self._server_reader: Optional[asyncio.StreamReader] = None
    self._server_writer: Optional[asyncio.StreamWriter] = None
    self._server_connected = False

    # Identity of the local user.
    self._user_info: Optional[UserInfo] = None
    self._user_id: str = ""

    # Direct peer links and peers found via LAN discovery, keyed by peer id.
    self._peer_connections: Dict[str, Connection] = {}
    self._discovered_peers: Dict[str, PeerInfo] = {}

    # Current state and the observers to notify.
    self._state = ConnectionState.DISCONNECTED
    self._state_callbacks: List[StateChangeCallback] = []
    self._message_callbacks: List[MessageReceivedCallback] = []

    # Message (de)serialization.
    self._message_handler = MessageHandler()

    # Background asyncio tasks.
    self._heartbeat_task: Optional[asyncio.Task] = None
    self._receive_task: Optional[asyncio.Task] = None
    self._lan_discovery_task: Optional[asyncio.Task] = None
    self._lan_listener_task: Optional[asyncio.Task] = None

    # Reconnection control.
    self._reconnect_attempts = 0
    self._should_reconnect = True
    self._reconnecting = False

    # Guards mutations of the peer-connection table.
    self._lock = asyncio.Lock()

    # UDP socket used to answer LAN discovery broadcasts.
    self._lan_socket: Optional[socket.socket] = None
    self._lan_listen_port: int = 0

    # TCP server accepting direct P2P connections.
    self._p2p_server: Optional[asyncio.AbstractServer] = None
    self._p2p_listen_port: int = 0

    logger.info("ConnectionManager initialized")
|
|
|
|
|
|
# ==================== 状态管理 ====================
|
|
|
|
|
|
@property
|
|
|
def state(self) -> ConnectionState:
|
|
|
"""获取当前连接状态"""
|
|
|
return self._state
|
|
|
|
|
|
@property
def is_connected(self) -> bool:
    """True when registered with the server and the state is CONNECTED."""
    if not self._server_connected:
        return False
    return self._state == ConnectionState.CONNECTED
|
|
|
|
|
|
@property
|
|
|
def user_id(self) -> str:
|
|
|
"""获取当前用户ID"""
|
|
|
return self._user_id
|
|
|
|
|
|
def _set_state(self, state: ConnectionState, reason: Optional[str] = None) -> None:
    """
    Record a new connection state and notify the registered callbacks.

    Nothing is logged and no callback runs when the state is unchanged.
    NOTE(review): the source's indentation was lost; this assumes callbacks
    were meant to fire only on an actual transition - confirm.

    Args:
        state: The new state.
        reason: Optional explanation, passed on to the callbacks.
    """
    previous = self._state
    self._state = state
    if previous == state:
        return

    suffix = f" ({reason})" if reason else ""
    logger.info(f"Connection state changed: {previous.value} -> {state.value}{suffix}")

    # Iterate over a snapshot so a callback may add or remove callbacks
    # (including itself) without disturbing this loop.
    for callback in list(self._state_callbacks):
        try:
            callback(state, reason)
        except Exception as e:
            # One faulty observer must not prevent the others from running.
            logger.exception(f"Error in state callback: {e}")
|
|
|
|
|
|
def add_state_callback(self, callback: StateChangeCallback) -> None:
|
|
|
"""
|
|
|
添加状态变更回调
|
|
|
|
|
|
Args:
|
|
|
callback: 回调函数
|
|
|
"""
|
|
|
self._state_callbacks.append(callback)
|
|
|
|
|
|
def remove_state_callback(self, callback: StateChangeCallback) -> None:
|
|
|
"""
|
|
|
移除状态变更回调
|
|
|
|
|
|
Args:
|
|
|
callback: 回调函数
|
|
|
"""
|
|
|
if callback in self._state_callbacks:
|
|
|
self._state_callbacks.remove(callback)
|
|
|
|
|
|
def add_message_callback(self, callback: MessageReceivedCallback) -> None:
|
|
|
"""
|
|
|
添加消息接收回调
|
|
|
|
|
|
Args:
|
|
|
callback: 回调函数
|
|
|
"""
|
|
|
self._message_callbacks.append(callback)
|
|
|
|
|
|
def remove_message_callback(self, callback: MessageReceivedCallback) -> None:
|
|
|
"""
|
|
|
移除消息接收回调
|
|
|
|
|
|
Args:
|
|
|
callback: 回调函数
|
|
|
"""
|
|
|
if callback in self._message_callbacks:
|
|
|
self._message_callbacks.remove(callback)
|
|
|
|
|
|
# ==================== 服务器连接 (需求 1.1, 1.4, 1.5) ====================
|
|
|
|
|
|
async def connect_to_server(self, user_info: UserInfo) -> bool:
    """
    Connect to the relay server and register the local user.

    Opens a TCP connection (requirement 1.4), sends a USER_REGISTER message
    and waits for the server's ACK.  On success the heartbeat task, the
    receive task, the LAN discovery listener and the P2P TCP server are
    started.

    Args:
        user_info: Information about the local user; its ``user_id`` is used
            as the sender id of outgoing messages.

    Returns:
        True once registered (or immediately when already connected).

    Raises:
        ConnectionError: On timeout, network failure, or when the server
            rejects or does not answer the registration.
    """
    if self._server_connected:
        logger.warning("Already connected to server")
        return True

    self._user_info = user_info
    self._user_id = user_info.user_id
    self._set_state(ConnectionState.CONNECTING, "Connecting to server")

    try:
        # Establish the TCP connection.
        self._server_reader, self._server_writer = await asyncio.wait_for(
            asyncio.open_connection(
                self.config.server_host,
                self.config.server_port
            ),
            timeout=self.config.connection_timeout
        )

        # Send the registration message.
        register_msg = Message(
            msg_type=MessageType.USER_REGISTER,
            sender_id=self._user_id,
            receiver_id="server",
            timestamp=time.time(),
            payload=user_info.serialize()
        )
        # _send_to_server reports failure through its return value.
        if not await self._send_to_server(register_msg):
            raise ConnectionError("Failed to send registration message")

        # Wait for the registration reply.  Bounded by the same timeout as
        # the connect itself so a silent server cannot hang the caller.
        response = await asyncio.wait_for(
            self._receive_from_server(),
            timeout=self.config.connection_timeout
        )

        if response is None:
            raise ConnectionError("No response from server")

        if response.msg_type == MessageType.ERROR:
            error_msg = response.payload.decode('utf-8', errors='replace')
            raise ConnectionError(f"Registration failed: {error_msg}")

        if response.msg_type != MessageType.ACK:
            raise ConnectionError(f"Unexpected response: {response.msg_type}")

        # Registered.
        self._server_connected = True
        self._reconnect_attempts = 0
        self._set_state(ConnectionState.CONNECTED, "Connected to server")

        # Background work that depends on a live connection.
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        self._receive_task = asyncio.create_task(self._receive_loop())
        await self._start_lan_listener()
        await self._start_p2p_server()

        logger.info(f"Connected to server {self.config.server_host}:{self.config.server_port}")
        return True

    except Exception as e:
        # Classify in the original precedence: timeout, network, other.
        if isinstance(e, asyncio.TimeoutError):
            state_reason = "Connection timeout"
            failure = "Connection timeout"
        elif isinstance(e, OSError):
            state_reason = f"Network error: {e}"
            failure = f"Failed to connect: {e}"
        else:
            state_reason = str(e)
            failure = f"Connection error: {e}"

        # Release the half-open socket so a failed attempt does not leak it.
        writer = self._server_writer
        self._server_reader = None
        self._server_writer = None
        self._server_connected = False
        if writer is not None:
            try:
                writer.close()
                await writer.wait_closed()
            except Exception as close_error:
                logger.debug(f"Error closing failed server connection: {close_error}")

        self._set_state(ConnectionState.ERROR, state_reason)
        raise ConnectionError(failure) from e
|
|
|
|
|
|
async def disconnect(self) -> None:
    """
    Tear down every connection owned by this manager.

    Disables auto-reconnect, cancels the background tasks, sends a
    USER_UNREGISTER message, then closes the server link, all P2P links,
    the LAN discovery socket and the P2P TCP server.
    """
    # Keep the background loops from scheduling a reconnect during teardown.
    self._should_reconnect = False

    # Cancel the background tasks one by one and wait for each to finish.
    background = (
        self._heartbeat_task,
        self._receive_task,
        self._lan_discovery_task,
        self._lan_listener_task,
    )
    for pending in background:
        if not pending or pending.done():
            continue
        pending.cancel()
        try:
            await pending
        except asyncio.CancelledError:
            pass

    # Tell the server we are leaving (best effort).
    if self._server_connected and self._server_writer:
        try:
            goodbye = Message(
                msg_type=MessageType.USER_UNREGISTER,
                sender_id=self._user_id,
                receiver_id="server",
                timestamp=time.time(),
                payload=b""
            )
            await self._send_to_server(goodbye)
        except Exception as e:
            logger.error(f"Error sending unregister message: {e}")

    # Close the server link.
    if self._server_writer:
        try:
            self._server_writer.close()
            await self._server_writer.wait_closed()
        except Exception as e:
            logger.error(f"Error closing server connection: {e}")

    self._server_reader = None
    self._server_writer = None
    self._server_connected = False

    # Close every direct peer link.
    async with self._lock:
        for peer_id, link in list(self._peer_connections.items()):
            try:
                link.writer.close()
                await link.writer.wait_closed()
            except Exception as e:
                logger.error(f"Error closing peer connection {peer_id}: {e}")
        self._peer_connections.clear()

    # Close the LAN discovery socket.
    if self._lan_socket:
        try:
            self._lan_socket.close()
        except Exception:
            pass
        self._lan_socket = None

    # Stop the P2P TCP server.
    if self._p2p_server:
        try:
            self._p2p_server.close()
            await self._p2p_server.wait_closed()
        except Exception:
            pass
        self._p2p_server = None
        self._p2p_listen_port = 0

    self._set_state(ConnectionState.DISCONNECTED, "Disconnected")
    logger.info("Disconnected from all connections")
|
|
|
|
|
|
# ==================== 心跳机制 (需求 1.4, 1.5) ====================
|
|
|
|
|
|
async def _heartbeat_loop(self) -> None:
    """
    Periodically send HEARTBEAT messages while connected to the server.

    Sleeps ``config.heartbeat_interval`` between beats.  When a beat cannot
    be sent the loop ends and, if auto-reconnect is enabled, connection-loss
    handling is scheduled as a separate task.
    """
    while self._server_connected:
        try:
            await asyncio.sleep(self.config.heartbeat_interval)

            # The connection may have gone away while sleeping.
            if not self._server_connected:
                break

            heartbeat_msg = Message(
                msg_type=MessageType.HEARTBEAT,
                sender_id=self._user_id,
                receiver_id="server",
                timestamp=time.time(),
                payload=b""
            )

            # _send_to_server swallows exceptions and returns False, so the
            # result must be checked; otherwise a dead link goes unnoticed.
            if not await self._send_to_server(heartbeat_msg):
                raise ConnectionError("Heartbeat could not be sent")
            logger.debug("Heartbeat sent")

        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.error(f"Heartbeat error: {e}")
            # Heartbeat failed: hand over to the reconnection logic.
            if self._should_reconnect:
                asyncio.create_task(self._handle_connection_lost("Heartbeat failed"))
            break
|
|
|
|
|
|
async def _receive_loop(self) -> None:
    """
    Keep reading messages from the server and dispatch each one.

    Ends when the receive helper reports a closed connection (None) or an
    error occurs; in both cases connection-loss handling is scheduled if
    auto-reconnect is enabled.
    """
    while self._server_connected:
        try:
            incoming = await self._receive_from_server()

            if incoming is not None:
                await self._handle_server_message(incoming)
                continue

            # None means the connection is gone.
            if self._should_reconnect:
                asyncio.create_task(self._handle_connection_lost("Connection closed by server"))
            break

        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.error(f"Receive error: {e}")
            if self._should_reconnect:
                asyncio.create_task(self._handle_connection_lost(f"Receive error: {e}"))
            break
|
|
|
|
|
|
async def _handle_server_message(self, message: Message) -> None:
    """
    Dispatch a message received from the server to the message callbacks.

    HEARTBEAT messages are logged and otherwise ignored.

    Args:
        message: The received message.
    """
    logger.debug(f"Received message: {message.msg_type.value} from {message.sender_id}")

    if message.msg_type != MessageType.HEARTBEAT:
        for notify in self._message_callbacks:
            try:
                notify(message)
            except Exception as e:
                # A failing observer must not block the remaining ones.
                logger.error(f"Error in message callback: {e}")
|
|
|
|
|
|
# ==================== 消息发送和接收 ====================
|
|
|
|
|
|
async def _send_to_server(self, message: Message) -> bool:
|
|
|
"""
|
|
|
发送消息到服务器
|
|
|
|
|
|
Args:
|
|
|
message: 要发送的消息
|
|
|
|
|
|
Returns:
|
|
|
发送成功返回True,否则返回False
|
|
|
"""
|
|
|
if not self._server_writer:
|
|
|
logger.error("Not connected to server")
|
|
|
return False
|
|
|
|
|
|
try:
|
|
|
data = self._message_handler.serialize(message)
|
|
|
self._server_writer.write(data)
|
|
|
await self._server_writer.drain()
|
|
|
return True
|
|
|
except Exception as e:
|
|
|
logger.error(f"Failed to send message: {e}")
|
|
|
return False
|
|
|
|
|
|
async def _receive_from_server(self) -> Optional[Message]:
|
|
|
"""
|
|
|
从服务器接收消息
|
|
|
|
|
|
Returns:
|
|
|
接收到的消息,如果连接断开则返回None
|
|
|
"""
|
|
|
if not self._server_reader:
|
|
|
return None
|
|
|
|
|
|
try:
|
|
|
# 读取消息头
|
|
|
header = await self._server_reader.readexactly(
|
|
|
self._message_handler.HEADER_SIZE
|
|
|
)
|
|
|
|
|
|
if not header:
|
|
|
return None
|
|
|
|
|
|
# 解析消息长度
|
|
|
payload_length, version = struct.unpack(
|
|
|
self._message_handler.HEADER_FORMAT, header
|
|
|
)
|
|
|
|
|
|
# 读取消息体
|
|
|
payload = await self._server_reader.readexactly(payload_length)
|
|
|
|
|
|
# 反序列化消息
|
|
|
full_data = header + payload
|
|
|
return self._message_handler.deserialize(full_data)
|
|
|
|
|
|
except asyncio.IncompleteReadError:
|
|
|
return None
|
|
|
except MessageSerializationError as e:
|
|
|
logger.error(f"Failed to deserialize message: {e}")
|
|
|
return None
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error receiving message: {e}")
|
|
|
return None
|
|
|
|
|
|
async def send_message(self, peer_id: str, message: Message) -> bool:
    """
    Deliver a message to a peer, preferring a direct P2P link.

    Order of attempts: an existing P2P connection; a newly established P2P
    connection when the peer was discovered on the LAN; finally relay via
    the server (which sets ``message.receiver_id`` to ``peer_id``).

    Args:
        peer_id: Identifier of the target peer.
        message: The message to send.

    Returns:
        True when the message was sent, otherwise False.
    """
    # 1) Existing direct link.
    direct = self._peer_connections.get(peer_id)
    if direct is not None:
        try:
            encoded = self._message_handler.serialize(message)
            direct.writer.write(encoded)
            await direct.writer.drain()
            direct.last_activity = time.time()
            logger.debug(f"Message sent to {peer_id} via P2P")
            return True
        except Exception as e:
            logger.error(f"Failed to send P2P message to {peer_id}: {e}")
            # The link is probably dead; drop it.
            await self._close_p2p_connection(peer_id)

    # 2) Try to set up a direct link to a LAN-discovered peer.
    can_dial = peer_id in self._discovered_peers and peer_id not in self._peer_connections
    if can_dial:
        logger.info(f"Attempting to establish P2P connection with {peer_id}")
        if await self.establish_p2p_connection(peer_id) and peer_id in self._peer_connections:
            fresh = self._peer_connections[peer_id]
            try:
                encoded = self._message_handler.serialize(message)
                fresh.writer.write(encoded)
                await fresh.writer.drain()
                fresh.last_activity = time.time()
                logger.debug(f"Message sent to {peer_id} via newly established P2P")
                return True
            except Exception as e:
                logger.error(f"Failed to send via new P2P connection: {e}")

    # 3) Relay through the server.
    if not self._server_connected:
        logger.error(f"Cannot send message to {peer_id}: no connection available")
        return False

    message.receiver_id = peer_id
    relayed = await self._send_to_server(message)
    if relayed:
        logger.debug(f"Message sent to {peer_id} via relay server")
    return relayed
|
|
|
|
|
|
# ==================== 局域网发现 (需求 1.2) ====================
|
|
|
|
|
|
async def discover_lan_peers(self) -> List[PeerInfo]:
    """
    Broadcast a discovery request and collect replies from LAN peers.

    A JSON datagram carrying the discovery magic, the local user id and
    username and the local P2P TCP port is broadcast to
    ``config.lan_broadcast_port``.  Replies are collected for
    ``config.lan_discovery_timeout`` seconds without blocking the event loop.

    Stored information about already-known peers is refreshed (their
    address or port may have changed), but only peers that were not known
    before this call are returned.

    Returns:
        Peers newly discovered by this call; empty on error.
    """
    discovered: List[PeerInfo] = []
    sock: Optional[socket.socket] = None

    try:
        # Non-blocking UDP broadcast socket.
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setblocking(False)

        # Discovery request, advertising our P2P TCP port.
        discovery_data = {
            "magic": self.LAN_DISCOVERY_MAGIC.decode('utf-8'),
            "user_id": self._user_id,
            "username": self._user_info.username if self._user_info else "",
            "port": self._p2p_listen_port
        }
        request_data = json.dumps(discovery_data).encode('utf-8')

        broadcast_addr = ('<broadcast>', self.config.lan_broadcast_port)
        loop = asyncio.get_running_loop()

        await loop.sock_sendto(sock, request_data, broadcast_addr)
        logger.debug(f"LAN discovery broadcast sent to port {self.config.lan_broadcast_port}")

        response_magic = self.LAN_RESPONSE_MAGIC.decode('utf-8')
        # Monotonic clock: immune to wall-clock adjustments.
        deadline = time.monotonic() + self.config.lan_discovery_timeout

        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break

            try:
                # Awaiting the loop's recvfrom keeps the event loop free;
                # a blocking recvfrom here would stall every other task.
                data, addr = await asyncio.wait_for(
                    loop.sock_recvfrom(sock, 1024),
                    timeout=remaining
                )
            except asyncio.TimeoutError:
                break
            except OSError as e:
                logger.debug(f"Error receiving discovery response: {e}")
                continue

            try:
                response = json.loads(data.decode('utf-8'))
            except ValueError as e:
                # Covers both invalid UTF-8 and invalid JSON.
                logger.debug(f"Error receiving discovery response: {e}")
                continue

            if not isinstance(response, dict) or response.get("magic") != response_magic:
                continue

            peer_id = response.get("user_id")
            # Ignore malformed replies and our own.
            if not peer_id or peer_id == self._user_id:
                continue

            peer_info = PeerInfo(
                peer_id=peer_id,
                username=response.get("username", ""),
                ip_address=addr[0],
                port=response.get("port", 0)
            )

            is_new = peer_id not in self._discovered_peers
            # Always store the latest endpoint.
            self._discovered_peers[peer_id] = peer_info
            if is_new:
                discovered.append(peer_info)
                logger.info(f"Discovered LAN peer: {peer_id} at {addr[0]}:{peer_info.port}")

    except Exception as e:
        logger.error(f"LAN discovery error: {e}")
    finally:
        # Release the socket on every path, including errors.
        if sock is not None:
            sock.close()

    return discovered
|
|
|
|
|
|
async def _start_lan_listener(self) -> None:
    """
    Start answering LAN discovery requests from other clients.

    Binds a UDP socket to ``config.lan_broadcast_port`` and starts the
    listen loop as a background task.  Calling it again reuses an existing
    socket and only restarts the loop if it is not running.  Failures are
    logged, not raised.
    """
    try:
        if self._lan_socket is None:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
                sock.bind(('', self.config.lan_broadcast_port))
                sock.setblocking(False)
            except Exception:
                # Do not leak (or keep a reference to) a half-configured socket.
                sock.close()
                raise

            self._lan_socket = sock
            # Port actually bound.
            self._lan_listen_port = sock.getsockname()[1]

        # Start the listen loop unless one is already running.
        listener = self._lan_listener_task
        if listener is None or listener.done():
            self._lan_listener_task = asyncio.create_task(self._lan_listen_loop())

        logger.info(f"LAN listener started on port {self.config.lan_broadcast_port}")

    except Exception as e:
        logger.error(f"Failed to start LAN listener: {e}")
|
|
|
|
|
|
async def _lan_listen_loop(self) -> None:
    """
    Answer LAN discovery requests while connected to the server.

    Each valid request from another user is answered with a JSON datagram
    carrying the response magic, the local user id and username and the
    local P2P TCP port.  The loop ends when the server connection flag is
    cleared, the task is cancelled, or the socket has been closed.
    """
    loop = asyncio.get_running_loop()
    discovery_magic = self.LAN_DISCOVERY_MAGIC.decode('utf-8')

    while self._server_connected:
        sock = self._lan_socket
        if sock is None:
            # Socket was torn down; nothing left to listen on.
            break

        try:
            try:
                # Short timeout so the loop condition is re-checked regularly.
                data, addr = await asyncio.wait_for(
                    loop.sock_recvfrom(sock, 1024),
                    timeout=1.0
                )
            except asyncio.TimeoutError:
                continue

            request = json.loads(data.decode('utf-8'))
            if not isinstance(request, dict) or request.get("magic") != discovery_magic:
                continue

            # Ignore our own broadcast.
            if request.get("user_id") == self._user_id:
                continue

            response_data = {
                "magic": self.LAN_RESPONSE_MAGIC.decode('utf-8'),
                "user_id": self._user_id,
                "username": self._user_info.username if self._user_info else "",
                "port": self._p2p_listen_port
            }
            response = json.dumps(response_data).encode('utf-8')

            await loop.sock_sendto(sock, response, addr)
            logger.debug(f"Responded to LAN discovery from {addr}, P2P port: {self._p2p_listen_port}")

        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.debug(f"LAN listen error: {e}")
            # A closed socket reports fileno() == -1 and can never recover;
            # leave instead of spinning on the same error.
            if sock.fileno() == -1:
                break
            continue
|
|
|
|
|
|
async def _start_p2p_server(self) -> None:
    """
    Start the TCP server that accepts direct connections from LAN peers.

    Listens on all interfaces on a system-assigned port and records that
    port in ``_p2p_listen_port``.  Does nothing when a server is already
    serving, so repeated calls do not leak listeners.  Failures are logged
    and leave the port at 0.
    """
    current = self._p2p_server
    if current is not None and current.is_serving():
        logger.debug(f"P2P TCP server already running on port {self._p2p_listen_port}")
        return

    try:
        self._p2p_server = await asyncio.start_server(
            self._handle_p2p_connection,
            '0.0.0.0',
            0  # let the OS pick a free port
        )

        # Port actually assigned.
        bound = self._p2p_server.sockets[0].getsockname()
        self._p2p_listen_port = bound[1]

        logger.info(f"P2P TCP server started on port {self._p2p_listen_port}")

    except Exception as e:
        logger.error(f"Failed to start P2P server: {e}")
        self._p2p_listen_port = 0
|
|
|
|
|
|
# ==================== 通信模式选择 (需求 1.1, 1.2, 1.3) ====================
|
|
|
|
|
|
def get_connection_mode(self, peer_id: str) -> ConnectionMode:
    """
    Report how the given peer would be reached.

    Args:
        peer_id: Identifier of the peer.

    Returns:
        ``ConnectionMode.P2P`` when a direct connection exists or the peer
        was discovered on the LAN; otherwise ``ConnectionMode.RELAY``.
    """
    reachable_directly = (
        peer_id in self._peer_connections
        or peer_id in self._discovered_peers
    )
    return ConnectionMode.P2P if reachable_directly else ConnectionMode.RELAY
|
|
|
|
|
|
async def connect_to_peer(self, peer_id: str) -> Optional[Connection]:
    """
    Obtain a connection to a peer, direct if possible, relayed otherwise.

    For a LAN-discovered peer a TCP connection is opened and the
    USER_REGISTER handshake expected by the accepting side
    (``_handle_p2p_connection``) is performed before the link is registered.
    If that fails and the server is connected, a relay-mode ``Connection``
    wrapping the server streams is returned (it is not stored).

    NOTE(review): this method does not start a receive loop for the outgoing
    link - confirm that callers read from the returned connection.

    Args:
        peer_id: Identifier of the peer.

    Returns:
        A connection object, or None when no route is available.
    """
    # Reuse an existing direct link.
    existing = self._peer_connections.get(peer_id)
    if existing is not None:
        return existing

    # Try a direct link to a LAN-discovered peer.
    peer_info = self._discovered_peers.get(peer_id)
    if peer_info is not None:
        writer: Optional[asyncio.StreamWriter] = None
        try:
            reader, writer = await asyncio.wait_for(
                asyncio.open_connection(
                    peer_info.ip_address,
                    peer_info.port
                ),
                timeout=self.config.connection_timeout
            )

            # The acceptor rejects any connection whose first message is not
            # USER_REGISTER, so identify ourselves before registering the link.
            handshake = Message(
                msg_type=MessageType.USER_REGISTER,
                sender_id=self._user_id,
                receiver_id=peer_id,
                timestamp=time.time(),
                payload=b""
            )
            if not await self._send_p2p_message(writer, handshake):
                raise ConnectionError("Failed to send P2P handshake")

            reply = await asyncio.wait_for(
                self._receive_p2p_message(reader),
                timeout=self.config.connection_timeout
            )
            if reply is None or reply.msg_type != MessageType.ACK:
                raise ConnectionError("P2P handshake was not acknowledged")

            conn = Connection(
                peer_id=peer_id,
                reader=reader,
                writer=writer,
                mode=ConnectionMode.P2P,
                ip_address=peer_info.ip_address,
                port=peer_info.port
            )

            async with self._lock:
                self._peer_connections[peer_id] = conn

            logger.info(f"P2P connection established with {peer_id}")
            return conn

        except Exception as e:
            logger.warning(f"Failed to establish P2P connection with {peer_id}: {e}")
            # Do not leak the socket of a failed attempt; fall back to relay.
            if writer is not None:
                writer.close()

    # Relay through the server (no extra connection needed).
    if self._server_connected:
        # Placeholder connection object representing relay mode.
        return Connection(
            peer_id=peer_id,
            reader=self._server_reader,
            writer=self._server_writer,
            mode=ConnectionMode.RELAY,
            ip_address=self.config.server_host,
            port=self.config.server_port
        )

    return None
|
|
|
|
|
|
# ==================== 网络重连机制 (需求 1.6) ====================
|
|
|
|
|
|
async def _handle_connection_lost(self, reason: str) -> None:
    """
    React to a lost server connection by attempting to reconnect.

    Marks the manager as RECONNECTING, closes the old server streams and
    runs the reconnect procedure.  Does nothing when a reconnection is
    already running or auto-reconnect is disabled.

    Args:
        reason: Why the connection was lost; passed to state callbacks.
    """
    if self._reconnecting or not self._should_reconnect:
        return

    self._reconnecting = True
    success = False
    try:
        self._server_connected = False
        self._set_state(ConnectionState.RECONNECTING, reason)

        logger.info(f"Connection lost: {reason}. Starting reconnection...")

        # Detach the old streams first so nobody writes to a closing socket.
        stale_writer = self._server_writer
        self._server_reader = None
        self._server_writer = None
        if stale_writer:
            try:
                stale_writer.close()
                await stale_writer.wait_closed()
            except Exception as e:
                logger.debug(f"Error closing stale server connection: {e}")

        success = await self._reconnect()
    finally:
        # Always clear the flag; otherwise an exception or cancellation in
        # _reconnect would block every future reconnection attempt.
        self._reconnecting = False

    # Only report ERROR if nobody changed the state meanwhile (for example
    # disconnect() setting DISCONNECTED while we were retrying).
    if not success and self._state == ConnectionState.RECONNECTING:
        self._set_state(ConnectionState.ERROR, "Reconnection failed")
|
|
|
|
|
|
async def _reconnect(self) -> bool:
    """
    Try to re-establish and re-register the server connection.

    Makes up to ``config.reconnect_attempts`` attempts, waiting
    ``config.reconnect_delay * 2 ** attempt`` seconds before each one
    (exponential backoff).  On success the heartbeat and receive tasks are
    restarted, and the LAN listen loop is restarted if it has stopped.

    Returns:
        True when reconnected; False when all attempts failed, there is no
        user info, or auto-reconnect was disabled meanwhile.
    """
    if not self._user_info:
        logger.error("Cannot reconnect: no user info")
        return False

    max_attempts = self.config.reconnect_attempts
    base_delay = self.config.reconnect_delay

    for attempt in range(max_attempts):
        if not self._should_reconnect:
            return False

        self._reconnect_attempts = attempt + 1
        delay = base_delay * (2 ** attempt)  # exponential backoff

        logger.info(f"Reconnection attempt {attempt + 1}/{max_attempts} "
                    f"(delay: {delay}s)")

        await asyncio.sleep(delay)

        if not self._should_reconnect:
            return False

        try:
            # Open a fresh connection.
            self._server_reader, self._server_writer = await asyncio.wait_for(
                asyncio.open_connection(
                    self.config.server_host,
                    self.config.server_port
                ),
                timeout=self.config.connection_timeout
            )

            # Register again.
            register_msg = Message(
                msg_type=MessageType.USER_REGISTER,
                sender_id=self._user_id,
                receiver_id="server",
                timestamp=time.time(),
                payload=self._user_info.serialize()
            )
            # _send_to_server reports failure through its return value.
            if not await self._send_to_server(register_msg):
                raise ConnectionError("Failed to send registration message")

            response = await asyncio.wait_for(
                self._receive_from_server(),
                timeout=self.config.connection_timeout
            )

            if not (response and response.msg_type == MessageType.ACK):
                raise ConnectionError("Invalid response from server")

            # disconnect() may have run while this attempt was in flight;
            # do not resurrect a connection the user asked to close.
            if not self._should_reconnect:
                raise ConnectionError("Reconnection cancelled")

            # Reconnected.
            self._server_connected = True
            self._reconnect_attempts = 0
            self._set_state(ConnectionState.CONNECTED, "Reconnected")

            # Restart the background tasks.
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
            self._receive_task = asyncio.create_task(self._receive_loop())

            # The LAN listen loop exits once the connected flag is cleared,
            # so bring it back if its socket is still open.
            listener = self._lan_listener_task
            if self._lan_socket is not None and (listener is None or listener.done()):
                self._lan_listener_task = asyncio.create_task(self._lan_listen_loop())

            logger.info("Reconnection successful")
            return True

        except Exception as e:
            logger.warning(f"Reconnection attempt {attempt + 1} failed: {e}")

            # Clean up the failed connection before the next attempt.
            if self._server_writer:
                try:
                    self._server_writer.close()
                    await self._server_writer.wait_closed()
                except Exception as close_error:
                    logger.debug(f"Error closing failed connection: {close_error}")

            self._server_reader = None
            self._server_writer = None

    logger.error(f"Reconnection failed after {max_attempts} attempts")
    return False
|
|
|
|
|
|
def enable_reconnect(self, enabled: bool = True) -> None:
    """
    Switch automatic reconnection on or off.

    Args:
        enabled: True to enable auto-reconnect, False to disable it.
    """
    self._should_reconnect = enabled
    status = 'enabled' if enabled else 'disabled'
    logger.info(f"Auto-reconnect {status}")
|
|
|
|
|
|
# ==================== 用户列表获取 ====================
|
|
|
|
|
|
async def get_online_users(self) -> List[UserInfo]:
    """
    Ask the server for the list of online users.

    Only the USER_LIST_REQUEST is sent here; the reply arrives through the
    message callbacks.  This simplified implementation therefore always
    returns an empty list (a proper version would await the matching reply,
    for example via a Future).

    Returns:
        An empty list.
    """
    if self._server_connected:
        await self._send_to_server(
            Message(
                msg_type=MessageType.USER_LIST_REQUEST,
                sender_id=self._user_id,
                receiver_id="server",
                timestamp=time.time(),
                payload=b""
            )
        )

    return []
|
|
|
|
|
|
def get_discovered_peers(self) -> List[PeerInfo]:
|
|
|
"""
|
|
|
获取已发现的LAN对等端列表
|
|
|
|
|
|
Returns:
|
|
|
对等端列表
|
|
|
"""
|
|
|
return list(self._discovered_peers.values())
|
|
|
|
|
|
# ==================== 网络质量检测 ====================
|
|
|
|
|
|
async def get_network_quality(self, peer_id: str) -> NetworkQuality:
    """
    Estimate the network quality towards a peer.

    Sends a HEARTBEAT message with payload ``b"ping"`` and grades the time
    the send took.  This is a simplification: only the send duration is
    measured, not a round trip.

    Args:
        peer_id: Identifier of the peer.

    Returns:
        BAD when the send fails; otherwise EXCELLENT (< 50 ms),
        GOOD (< 100 ms), FAIR (< 200 ms), POOR (< 300 ms) or BAD.
    """
    started = time.time()

    probe = Message(
        msg_type=MessageType.HEARTBEAT,
        sender_id=self._user_id,
        receiver_id=peer_id,
        timestamp=started,
        payload=b"ping"
    )

    if not await self.send_message(peer_id, probe):
        return NetworkQuality.BAD

    # Elapsed send time in milliseconds.
    latency = (time.time() - started) * 1000

    # Upper bounds in milliseconds, checked in ascending order.
    grades = (
        (50, NetworkQuality.EXCELLENT),
        (100, NetworkQuality.GOOD),
        (200, NetworkQuality.FAIR),
        (300, NetworkQuality.POOR),
    )
    for upper_bound, grade in grades:
        if latency < upper_bound:
            return grade
    return NetworkQuality.BAD
|
|
|
|
|
|
# ==================== P2P直连监听 (需求 1.2) ====================
|
|
|
|
|
|
async def start_p2p_listener(self, port: int = 0) -> int:
    """
    Start (or reuse) the TCP listener for direct P2P connections.

    The server object and its port are stored in ``_p2p_server`` and
    ``_p2p_listen_port`` so that discovery advertises the right port and
    ``disconnect()`` can close the listener.  When a listener is already
    serving and ``port`` is 0 or equal to its port, it is reused.  Otherwise
    a new listener is started and the previous one stops accepting.

    Args:
        port: Port to listen on; 0 lets the system choose.

    Returns:
        The port actually listened on, or 0 on failure.
    """
    try:
        current = self._p2p_server
        if (
            current is not None
            and current.is_serving()
            and port in (0, self._p2p_listen_port)
        ):
            return self._p2p_listen_port

        server = await asyncio.start_server(
            self._handle_p2p_connection,
            '0.0.0.0',
            port
        )

        # Stop accepting on the superseded listener; close() does not touch
        # connections that were already accepted.
        if current is not None:
            current.close()

        self._p2p_server = server
        self._p2p_listen_port = server.sockets[0].getsockname()[1]

        logger.info(f"P2P listener started on port {self._p2p_listen_port}")
        return self._p2p_listen_port

    except Exception as e:
        logger.error(f"Failed to start P2P listener: {e}")
        return 0
|
|
|
|
|
|
async def _handle_p2p_connection(
    self,
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter
) -> None:
    """
    Serve one incoming P2P connection.

    The first message must be USER_REGISTER and must arrive within
    ``config.connection_timeout`` seconds.  Its sender id identifies the
    peer, the connection is registered, an ACK is sent and messages are
    received until the link ends.  The socket is closed on every exit path.

    Args:
        reader: Stream reader of the accepted connection.
        writer: Stream writer of the accepted connection.
    """
    addr = writer.get_extra_info('peername')
    ip_address = addr[0] if addr else "unknown"
    port = addr[1] if addr else 0

    logger.info(f"Incoming P2P connection from {ip_address}:{port}")

    peer_id: Optional[str] = None
    conn: Optional[Connection] = None

    try:
        # Wait for the peer to identify itself, but not forever: a silent
        # client would otherwise hold this handler and its socket open.
        try:
            message = await asyncio.wait_for(
                self._receive_p2p_message(reader),
                timeout=self.config.connection_timeout
            )
        except asyncio.TimeoutError:
            message = None

        if message and message.msg_type == MessageType.USER_REGISTER:
            peer_id = message.sender_id

            conn = Connection(
                peer_id=peer_id,
                reader=reader,
                writer=writer,
                mode=ConnectionMode.P2P,
                ip_address=ip_address,
                port=port
            )

            async with self._lock:
                self._peer_connections[peer_id] = conn

            # Acknowledge the handshake.
            ack_msg = Message(
                msg_type=MessageType.ACK,
                sender_id=self._user_id,
                receiver_id=peer_id,
                timestamp=time.time(),
                payload=b"P2P connection accepted"
            )
            await self._send_p2p_message(writer, ack_msg)

            logger.info(f"P2P connection established with {peer_id}")

            # Receive until the link ends.
            await self._p2p_receive_loop(peer_id, reader)
        else:
            logger.warning(f"Invalid P2P handshake from {ip_address}:{port}")

    except Exception as e:
        logger.error(f"Error handling P2P connection: {e}")
    finally:
        if conn is not None and self._peer_connections.get(peer_id) is conn:
            # Our connection is still the registered one: unregister and close.
            await self._close_p2p_connection(peer_id)
        else:
            # Never registered, or replaced by a newer connection for the
            # same peer: close only our own socket.
            try:
                writer.close()
                await writer.wait_closed()
            except Exception as e:
                logger.debug(f"Error closing P2P socket from {ip_address}:{port}: {e}")
|
|
|
|
|
|
async def _p2p_receive_loop(self, peer_id: str, reader: asyncio.StreamReader) -> None:
    """
    Receive messages from a peer and pass them to the message callbacks.

    Runs while the peer is registered in the connection table; ends when
    the receive helper returns None or an error occurs.

    Args:
        peer_id: Identifier of the peer.
        reader: Stream reader of the peer connection.
    """
    while peer_id in self._peer_connections:
        try:
            incoming = await self._receive_p2p_message(reader)
            if incoming is None:
                break

            # Refresh the activity timestamp if the link is still registered.
            link = self._peer_connections.get(peer_id)
            if link is not None:
                link.last_activity = time.time()

            for notify in self._message_callbacks:
                try:
                    notify(incoming)
                except Exception as e:
                    logger.error(f"Error in message callback: {e}")

        except Exception as e:
            logger.error(f"P2P receive error from {peer_id}: {e}")
            break
|
|
|
|
|
|
async def _receive_p2p_message(self, reader: asyncio.StreamReader) -> Optional[Message]:
    """
    Read one framed message from a P2P connection.

    A frame is a header of ``HEADER_SIZE`` bytes, unpacked with
    ``HEADER_FORMAT`` into ``(payload_length, version)``, followed by
    ``payload_length`` payload bytes. Header and payload are passed together
    to the message handler for deserialization.

    Args:
        reader: Stream reader of the P2P connection.

    Returns:
        The deserialized message, or None if the stream ended before a full
        frame arrived or any other error occurred (other errors are logged).
    """
    try:
        handler = self._message_handler
        header = await reader.readexactly(handler.HEADER_SIZE)
        if not header:
            return None

        # Only the length is needed here; the version stays inside the raw
        # header bytes that are forwarded to deserialize().
        payload_length, _version = struct.unpack(handler.HEADER_FORMAT, header)
        body = await reader.readexactly(payload_length)
        return handler.deserialize(header + body)
    except Exception as e:
        # A truncated stream is treated as a plain disconnect and is not
        # logged; every other failure is.
        if not isinstance(e, asyncio.IncompleteReadError):
            logger.error(f"Error receiving P2P message: {e}")
        return None
|
|
|
|
|
|
async def _send_p2p_message(self, writer: asyncio.StreamWriter, message: Message) -> bool:
    """
    Serialize a message and write it to a P2P connection.

    Args:
        writer: Stream writer of the P2P connection.
        message: Message to send.

    Returns:
        True once the data has been written and drained, False if
        serialization or the write raised (the error is logged).
    """
    try:
        writer.write(self._message_handler.serialize(message))
        await writer.drain()
    except Exception as e:
        logger.error(f"Failed to send P2P message: {e}")
        return False
    return True
|
|
|
|
|
|
async def _close_p2p_connection(self, peer_id: str) -> None:
|
|
|
"""
|
|
|
关闭P2P连接
|
|
|
|
|
|
Args:
|
|
|
peer_id: 对等端ID
|
|
|
"""
|
|
|
async with self._lock:
|
|
|
if peer_id in self._peer_connections:
|
|
|
conn = self._peer_connections[peer_id]
|
|
|
try:
|
|
|
conn.writer.close()
|
|
|
await conn.writer.wait_closed()
|
|
|
except Exception:
|
|
|
pass
|
|
|
del self._peer_connections[peer_id]
|
|
|
logger.info(f"P2P connection closed with {peer_id}")
|
|
|
|
|
|
async def establish_p2p_connection(self, peer_id: str) -> bool:
    """
    Actively open a P2P connection to a discovered LAN peer (requirement 1.2).

    Handshake: connect to the peer's discovered address, send a
    USER_REGISTER message carrying this client's identity, then wait for an
    ACK. On success the connection is registered in
    ``self._peer_connections`` and a background receive loop is started.

    Every failure after the socket was opened (send failure, timeout,
    unexpected reply or any exception) closes that socket before returning,
    so a failed handshake does not leak the connection.

    Args:
        peer_id: ID of the peer to connect to.

    Returns:
        True if a connection already exists or was established, otherwise
        False.
    """
    if peer_id in self._peer_connections:
        return True

    if peer_id not in self._discovered_peers:
        logger.warning(f"Peer {peer_id} not discovered in LAN")
        return False

    peer_info = self._discovered_peers[peer_id]
    writer = None
    established = False

    try:
        reader, writer = await asyncio.wait_for(
            asyncio.open_connection(
                peer_info.ip_address,
                peer_info.port
            ),
            timeout=self.config.connection_timeout
        )

        # Identify ourselves to the peer.
        register_msg = Message(
            msg_type=MessageType.USER_REGISTER,
            sender_id=self._user_id,
            receiver_id=peer_id,
            timestamp=time.time(),
            payload=self._user_info.serialize() if self._user_info else b""
        )

        if not await self._send_p2p_message(writer, register_msg):
            # Without a delivered registration no ACK can arrive, so do not
            # wait a full timeout for it.
            return False

        # Wait for the peer's acknowledgement.
        response = await asyncio.wait_for(
            self._receive_p2p_message(reader),
            timeout=self.config.connection_timeout
        )

        if not (response and response.msg_type == MessageType.ACK):
            return False

        conn = Connection(
            peer_id=peer_id,
            reader=reader,
            writer=writer,
            mode=ConnectionMode.P2P,
            ip_address=peer_info.ip_address,
            port=peer_info.port
        )

        async with self._lock:
            self._peer_connections[peer_id] = conn
        established = True

        # Start receiving in the background.
        asyncio.create_task(self._p2p_receive_loop(peer_id, reader))

        logger.info(f"P2P connection established with {peer_id}")
        return True

    except Exception as e:
        logger.error(f"Failed to establish P2P connection with {peer_id}: {e}")
        return False

    finally:
        # Release the socket on every path that did not register it.
        if writer is not None and not established:
            try:
                writer.close()
                await writer.wait_closed()
            except Exception as e:
                logger.debug(f"Error closing failed P2P socket to {peer_id}: {e}")
|
|
|
|
|
|
# ==================== 自动模式切换 (需求 1.1, 1.2, 1.3) ====================
|
|
|
|
|
|
async def auto_select_connection_mode(self, peer_id: str) -> ConnectionMode:
    """
    Pick the best available connection mode for a peer.

    Covers requirements 1.1-1.3: detect the environment automatically, use a
    direct P2P link when the peer is on the same LAN, and fall back to the
    relay server otherwise.

    Order of preference:
        1. An existing P2P connection.
        2. A new P2P connection to an already discovered peer.
        3. A new P2P connection after a fresh LAN discovery round.
        4. The relay server, if it is connected.

    Args:
        peer_id: ID of the remote peer.

    Returns:
        The selected mode; ``ConnectionMode.UNKNOWN`` when neither a P2P link
        nor the relay server is available.
    """
    # 1. Reuse an existing P2P link.
    if (
        peer_id in self._peer_connections
        and self._peer_connections[peer_id].mode == ConnectionMode.P2P
    ):
        logger.debug(f"Using existing P2P connection for {peer_id}")
        return ConnectionMode.P2P

    # 2. Peer already known on the LAN: try a direct connection.
    if peer_id in self._discovered_peers and await self.establish_p2p_connection(peer_id):
        logger.info(f"Established P2P connection with {peer_id}")
        return ConnectionMode.P2P

    # 3. Refresh the discovery table and try once more.
    await self.discover_lan_peers()
    if peer_id in self._discovered_peers and await self.establish_p2p_connection(peer_id):
        logger.info(f"Discovered and connected to {peer_id} via P2P")
        return ConnectionMode.P2P

    # 4. Fall back to the relay server, if there is one.
    if not self._server_connected:
        logger.warning(f"No connection available for {peer_id}")
        return ConnectionMode.UNKNOWN

    logger.info(f"Using relay server for {peer_id}")
    return ConnectionMode.RELAY
|
|
|
|
|
|
async def switch_to_p2p(self, peer_id: str) -> bool:
    """
    Try to move the link with a peer onto a direct P2P connection.

    Args:
        peer_id: ID of the remote peer.

    Returns:
        True if the peer is already in P2P mode or a P2P connection was
        established, otherwise False.
    """
    if self.get_connection_mode(peer_id) == ConnectionMode.P2P:
        return True

    # Run a discovery round, then connect only if the peer showed up.
    await self.discover_lan_peers()
    if peer_id not in self._discovered_peers:
        return False

    return await self.establish_p2p_connection(peer_id)
|
|
|
|
|
|
async def switch_to_relay(self, peer_id: str) -> bool:
    """
    Move the link with a peer onto the relay server.

    Any existing P2P connection to the peer is closed first.

    Args:
        peer_id: ID of the remote peer.

    Returns:
        The current value of ``self._server_connected``, i.e. whether the
        relay path is usable.
    """
    has_direct_link = peer_id in self._peer_connections
    if has_direct_link:
        await self._close_p2p_connection(peer_id)

    # Relay mode only works while the server connection is up.
    return self._server_connected
|
|
|
|
|
|
def get_all_connection_modes(self) -> Dict[str, ConnectionMode]:
    """
    Report the connection mode for every known peer.

    Peers with an entry in ``self._peer_connections`` report that
    connection's mode. Peers that were only discovered on the LAN are
    reported as ``ConnectionMode.P2P`` because a direct link could be opened.

    Returns:
        Mapping of peer ID to connection mode.
    """
    # Connected peers first.
    modes = {peer_id: conn.mode for peer_id, conn in self._peer_connections.items()}

    # Then discovered-but-unconnected peers, without overriding the above.
    for peer_id in self._discovered_peers:
        if peer_id in modes:
            continue
        modes[peer_id] = ConnectionMode.P2P  # a P2P link can be established

    return modes
|
|
|
|
|
|
async def optimize_connections(self) -> None:
    """
    Try to upgrade connections to P2P.

    Runs a LAN discovery round, then attempts a direct connection to every
    discovered peer that has no entry in ``self._peer_connections`` yet.
    Failures for individual peers are logged at debug level and do not stop
    the remaining attempts.
    """
    await self.discover_lan_peers()

    # Snapshot the IDs before the loop starts awaiting connection attempts.
    candidates = list(self._discovered_peers.keys())
    for peer_id in candidates:
        if peer_id in self._peer_connections:
            continue
        try:
            await self.establish_p2p_connection(peer_id)
        except Exception as e:
            logger.debug(f"Failed to optimize connection for {peer_id}: {e}")
|
|
|
|
|
|
# ==================== 连接信息查询 ====================
|
|
|
|
|
|
def get_peer_connection_info(self, peer_id: str) -> Optional[Dict[str, Any]]:
    """
    Describe what is known about the link to a peer.

    A registered connection takes precedence over a mere discovery record.

    Args:
        peer_id: ID of the remote peer.

    Returns:
        For a connected peer: its mode, address, connect time (ISO format)
        and last activity timestamp. For a peer that was only discovered:
        mode ``"discovered"``, its address and discovery time (ISO format).
        None when the peer is unknown.
    """
    connections = self._peer_connections
    if peer_id in connections:
        link = connections[peer_id]
        info = {
            "peer_id": peer_id,
            "mode": link.mode.value,
            "ip_address": link.ip_address,
            "port": link.port,
            "connected_at": link.connected_at.isoformat(),
            "last_activity": link.last_activity
        }
        return info

    discovered = self._discovered_peers
    if peer_id not in discovered:
        return None

    record = discovered[peer_id]
    return {
        "peer_id": peer_id,
        "mode": "discovered",
        "ip_address": record.ip_address,
        "port": record.port,
        "discovered_at": record.discovered_at.isoformat()
    }
|
|
|
|
|
|
def get_connection_stats(self) -> Dict[str, Any]:
    """
    Summarize the current connection state.

    Returns:
        Dictionary with the server connection flag, the state value, the
        number of peer connections (total and P2P-mode only), the number of
        discovered peers and the reconnect attempt counter.
    """
    connections = self._peer_connections
    p2p_links = [link for link in connections.values() if link.mode == ConnectionMode.P2P]

    stats = {}
    stats["server_connected"] = self._server_connected
    stats["state"] = self._state.value
    stats["total_peer_connections"] = len(connections)
    stats["p2p_connections"] = len(p2p_links)
    stats["discovered_peers"] = len(self._discovered_peers)
    stats["reconnect_attempts"] = self._reconnect_attempts
    return stats
|
|
|
|
|
|
# ==================== 重连状态通知增强 (需求 1.6) ====================
|
|
|
|
|
|
def get_reconnect_status(self) -> Dict[str, Any]:
    """
    Report the reconnection state.

    Supports requirement 1.6 (tell the user the current state when the
    network connection drops).

    Returns:
        Dictionary with the reconnecting flag, attempts made so far, the
        configured maximum number of attempts, whether auto-reconnect is
        enabled, and the current state value.
    """
    status = {}
    status["is_reconnecting"] = self._reconnecting
    status["reconnect_attempts"] = self._reconnect_attempts
    status["max_attempts"] = self.config.reconnect_attempts
    status["auto_reconnect_enabled"] = self._should_reconnect
    status["current_state"] = self._state.value
    return status
|
|
|
|
|
|
async def manual_reconnect(self) -> bool:
    """
    Trigger a reconnection on user request.

    Intended for use after automatic reconnection has given up. Re-enables
    reconnecting and resets the attempt counter before delegating to
    ``self._reconnect()``.

    Returns:
        True if already connected; False if a reconnection is already in
        progress or no user info is available; otherwise the result of
        ``self._reconnect()``.
    """
    if self._server_connected:
        logger.info("Already connected, no need to reconnect")
        outcome = True
    elif self._reconnecting:
        logger.warning("Reconnection already in progress")
        outcome = False
    elif not self._user_info:
        logger.error("Cannot reconnect: no user info available")
        outcome = False
    else:
        # Start over with a clean attempt counter.
        self._should_reconnect = True
        self._reconnect_attempts = 0
        outcome = await self._reconnect()
    return outcome
|
|
|
|
|
|
def set_reconnect_config(
    self,
    max_attempts: Optional[int] = None,
    base_delay: Optional[float] = None
) -> None:
    """
    Update the reconnection parameters.

    Only arguments that are not None are applied; the resulting settings
    are logged.

    Args:
        max_attempts: Maximum number of reconnect attempts.
        base_delay: Base delay between reconnect attempts, in seconds.
    """
    updates = (
        ("reconnect_attempts", max_attempts),
        ("reconnect_delay", base_delay),
    )
    for attr_name, new_value in updates:
        if new_value is None:
            continue
        setattr(self.config, attr_name, new_value)

    logger.info(f"Reconnect config updated: max_attempts={self.config.reconnect_attempts}, "
                f"base_delay={self.config.reconnect_delay}s")
|
|
|
|
|
|
# ==================== P2P连接重连 ====================
|
|
|
|
|
|
async def reconnect_p2p(self, peer_id: str) -> bool:
    """
    Re-establish the P2P connection to a peer.

    Closes any existing connection, runs a LAN discovery round and, if the
    peer is then listed as discovered, connects to it again.

    Args:
        peer_id: ID of the remote peer.

    Returns:
        The result of the connection attempt, or False when the peer is not
        among the discovered peers.
    """
    # Drop the old link first.
    stale_link_present = peer_id in self._peer_connections
    if stale_link_present:
        await self._close_p2p_connection(peer_id)

    # Run discovery again before reconnecting.
    await self.discover_lan_peers()

    if peer_id not in self._discovered_peers:
        return False
    return await self.establish_p2p_connection(peer_id)
|
|
|
|
|
|
# ==================== 连接健康检查 ====================
|
|
|
|
|
|
async def check_connection_health(self, max_idle_seconds: float = 60) -> Dict[str, Any]:
    """
    Check the health of the server link and of every P2P connection.

    The server is probed by sending a heartbeat message; it is healthy when
    that send reports success. A P2P connection is healthy when it has been
    idle for less than ``max_idle_seconds``.

    Args:
        max_idle_seconds: Idle time in seconds below which a P2P connection
            counts as healthy. Defaults to 60.

    Returns:
        Report with a ``"server"`` section (connected, state, healthy), a
        ``"p2p_connections"`` section keyed by peer ID (mode, idle_seconds,
        healthy) and ``"overall_healthy"``, which is True when the server or
        at least one P2P connection is healthy.
    """
    report = {
        "server": {
            "connected": self._server_connected,
            "state": self._state.value,
            "healthy": False
        },
        "p2p_connections": {},
        "overall_healthy": False
    }

    # Probe the server link with a heartbeat.
    if self._server_connected:
        try:
            heartbeat = Message(
                msg_type=MessageType.HEARTBEAT,
                sender_id=self._user_id,
                receiver_id="server",
                timestamp=time.time(),
                payload=b""
            )
            report["server"]["healthy"] = await self._send_to_server(heartbeat)
        except Exception as e:
            # A failed probe simply means "unhealthy", but keep a trace of why.
            logger.debug(f"Server health probe failed: {e}")
            report["server"]["healthy"] = False

    # Rate every P2P connection by how long it has been idle. Iterate over
    # a snapshot in case the mapping changed during the awaited probe.
    now = time.time()
    for peer_id, conn in list(self._peer_connections.items()):
        idle_time = now - conn.last_activity
        report["p2p_connections"][peer_id] = {
            "mode": conn.mode.value,
            "idle_seconds": idle_time,
            "healthy": idle_time < max_idle_seconds
        }

    # Healthy overall if at least one usable path exists.
    report["overall_healthy"] = (
        report["server"]["healthy"] or
        any(c["healthy"] for c in report["p2p_connections"].values())
    )

    return report
|
|
|
|
|
|
async def cleanup_stale_connections(self, max_idle_seconds: float = 300) -> int:
    """
    Close P2P connections that have been idle for too long.

    Idle time is measured against a single timestamp taken when the call
    starts.

    Args:
        max_idle_seconds: Idle time in seconds above which a connection is
            closed.

    Returns:
        Number of connections that were closed.
    """
    started_at = time.time()
    removed = 0

    # Work on a snapshot of the IDs because closing mutates the mapping.
    for peer_id in list(self._peer_connections.keys()):
        link = self._peer_connections.get(peer_id)
        is_stale = bool(link) and (started_at - link.last_activity) > max_idle_seconds
        if not is_stale:
            continue

        await self._close_p2p_connection(peer_id)
        removed += 1
        logger.info(f"Cleaned stale P2P connection: {peer_id}")

    return removed
|