From f9bd845a113cfa8ede7ae630b3c7289905a1945b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=8D=9A=E6=96=87?= <15549487+FX_YBW@user.noreply.gitee.com> Date: Thu, 25 Dec 2025 21:09:49 +0800 Subject: [PATCH] =?UTF-8?q?=E7=A1=AE=E4=BF=9D=E6=96=87=E4=BB=B6=E4=BC=A0?= =?UTF-8?q?=E8=BE=93=E6=A8=A1=E5=9D=97=E6=B5=8B=E8=AF=95=E9=80=9A=E8=BF=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/__init__.py | 20 + client/connection_manager.py | 1630 ++++++++++++++++++++++++++++++ client/file_transfer.py | 1250 +++++++++++++++++++++++ tests/test_connection_manager.py | 331 ++++++ tests/test_file_transfer.py | 580 +++++++++++ 5 files changed, 3811 insertions(+) create mode 100644 client/connection_manager.py create mode 100644 client/file_transfer.py create mode 100644 tests/test_connection_manager.py create mode 100644 tests/test_file_transfer.py diff --git a/client/__init__.py b/client/__init__.py index 6291d5c..3428136 100644 --- a/client/__init__.py +++ b/client/__init__.py @@ -5,3 +5,23 @@ """ __version__ = "0.1.0" + +from client.connection_manager import ( + ConnectionManager, + ConnectionState, + ConnectionError, + ReconnectionError, + Connection, + StateChangeCallback, + MessageReceivedCallback, +) + +from client.file_transfer import ( + FileTransferModule, + FileTransferError, + FileNotFoundError, + FileIntegrityError, + TransferCancelledError, + TransferState, + ProgressCallback, +) diff --git a/client/connection_manager.py b/client/connection_manager.py new file mode 100644 index 0000000..02b38be --- /dev/null +++ b/client/connection_manager.py @@ -0,0 +1,1630 @@ +# P2P Network Communication - Connection Manager +""" +客户端连接管理器模块 +负责管理网络连接、自动选择通信模式、心跳机制和重连逻辑 + +需求: 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 +""" + +import asyncio +import logging +import socket +import struct +import time +import json +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Callable, Dict, List, Optional, Tuple, Any + +from shared.models import ( + Message, MessageType, UserInfo, UserStatus, + ConnectionMode, PeerInfo, NetworkQuality +) +from shared.message_handler import ( + MessageHandler, MessageSerializationError, MessageValidationError +) +from config import ClientConfig + + +# 设置日志 +logger = logging.getLogger(__name__) + + +class ConnectionState(Enum): + """连接状态枚举""" + DISCONNECTED = "disconnected" + CONNECTING = "connecting" + CONNECTED = "connected" + RECONNECTING = "reconnecting" + ERROR = "error" + + +class ConnectionError(Exception): + """连接错误""" + pass + + +class ReconnectionError(Exception): + """重连错误""" + pass + + +# 回调类型定义 +StateChangeCallback = Callable[[ConnectionState, Optional[str]], None] +MessageReceivedCallback = Callable[[Message], None] + + +@dataclass +class Connection: + """连接信息""" + peer_id: str + reader: asyncio.StreamReader + writer: asyncio.StreamWriter + mode: ConnectionMode + connected_at: datetime = field(default_factory=datetime.now) + last_activity: float = field(default_factory=time.time) + ip_address: str = "" + port: int = 0 + + +class ConnectionManager: + """ + 连接管理器 + + 负责: + - 连接到中转服务器 (需求 1.1, 1.4, 1.5) + - 发现局域网内的其他客户端 (需求 1.2) + - 自动选择通信模式 (需求 1.1, 1.2, 1.3) + - 心跳机制保持连接 (需求 1.4, 1.5) + - 网络重连机制 (需求 1.6) + """ + + # LAN发现协议魔数 + LAN_DISCOVERY_MAGIC = b"P2P_DISCOVER" + LAN_RESPONSE_MAGIC = b"P2P_RESPONSE" + + def __init__(self, config: Optional[ClientConfig] = None): + """ + 初始化连接管理器 + + Args: + config: 客户端配置,如果为None则使用默认配置 + """ + self.config = config or ClientConfig() + + # 服务器连接 + self._server_reader: Optional[asyncio.StreamReader] = None + self._server_writer: Optional[asyncio.StreamWriter] = None + self._server_connected = False + + # 用户信息 + self._user_info: Optional[UserInfo] = None + self._user_id: str = "" + + # P2P连接管理 + self._peer_connections: Dict[str, Connection] = {} + self._discovered_peers: Dict[str, PeerInfo] = {} + + # 连接状态 + self._state = ConnectionState.DISCONNECTED + self._state_callbacks: List[StateChangeCallback] = [] + self._message_callbacks: List[MessageReceivedCallback] = [] + + # 消息处理器 + self._message_handler = MessageHandler() + + # 异步任务 + self._heartbeat_task: Optional[asyncio.Task] = None + self._receive_task: Optional[asyncio.Task] = None + self._lan_discovery_task: Optional[asyncio.Task] = None + self._lan_listener_task: Optional[asyncio.Task] = None + + # 重连控制 + self._reconnect_attempts = 0 + self._should_reconnect = True + self._reconnecting = False + + # 锁 + self._lock = asyncio.Lock() + + # LAN监听 + self._lan_socket: Optional[socket.socket] = None + self._lan_listen_port: int = 0 + + logger.info("ConnectionManager initialized") + + # ==================== 状态管理 ==================== + + @property + def state(self) -> ConnectionState: + """获取当前连接状态""" + return self._state + + @property + def is_connected(self) -> bool: + """是否已连接到服务器""" + return self._server_connected and self._state == ConnectionState.CONNECTED + + @property + def user_id(self) -> str: + """获取当前用户ID""" + return self._user_id + + def _set_state(self, state: ConnectionState, reason: Optional[str] = None) -> None: + """ + 设置连接状态并通知回调 + + Args: + state: 新状态 + reason: 状态变更原因 + """ + old_state = self._state + self._state = state + + if old_state != state: + logger.info(f"Connection state changed: {old_state.value} -> {state.value}" + + (f" ({reason})" if reason else "")) + + # 通知所有回调 + for callback in self._state_callbacks: + try: + callback(state, reason) + except Exception as e: + logger.error(f"Error in state callback: {e}") + + def add_state_callback(self, callback: StateChangeCallback) -> None: + """ + 添加状态变更回调 + + Args: + callback: 回调函数 + """ + self._state_callbacks.append(callback) + + def remove_state_callback(self, callback: StateChangeCallback) -> None: + """ + 移除状态变更回调 + + Args: + callback: 回调函数 + """ + if callback in self._state_callbacks: + self._state_callbacks.remove(callback) + + def add_message_callback(self, callback: MessageReceivedCallback) -> None: + """ + 添加消息接收回调 + + Args: + callback: 回调函数 + """ + self._message_callbacks.append(callback) + + def remove_message_callback(self, callback: MessageReceivedCallback) -> None: + """ + 移除消息接收回调 + + Args: + callback: 回调函数 + """ + if callback in self._message_callbacks: + self._message_callbacks.remove(callback) + + # ==================== 服务器连接 (需求 1.1, 1.4, 1.5) ==================== + + async def connect_to_server(self, user_info: UserInfo) -> bool: + """ + 连接到中转服务器 + + 自动检测当前网络环境并选择合适的通信模式 (需求 1.1) + 使用TCP套接字建立可靠连接 (需求 1.4) + 使用UDP套接字建立快速连接 (需求 1.5) + + Args: + user_info: 用户信息 + + Returns: + 连接成功返回True,否则返回False + + Raises: + ConnectionError: 连接失败时抛出 + """ + if self._server_connected: + logger.warning("Already connected to server") + return True + + self._user_info = user_info + self._user_id = user_info.user_id + self._set_state(ConnectionState.CONNECTING, "Connecting to server") + + try: + # 建立TCP连接 + self._server_reader, self._server_writer = await asyncio.wait_for( + asyncio.open_connection( + self.config.server_host, + self.config.server_port + ), + timeout=self.config.connection_timeout + ) + + # 发送注册消息 + register_msg = Message( + msg_type=MessageType.USER_REGISTER, + sender_id=self._user_id, + receiver_id="server", + timestamp=time.time(), + payload=user_info.serialize() + ) + + await self._send_to_server(register_msg) + + # 等待注册响应 + response = await self._receive_from_server() + + if response is None: + raise ConnectionError("No response from server") + + if response.msg_type == MessageType.ERROR: + error_msg = response.payload.decode('utf-8') + raise ConnectionError(f"Registration failed: {error_msg}") + + if response.msg_type != MessageType.ACK: + raise ConnectionError(f"Unexpected response: {response.msg_type}") + + # 连接成功 + self._server_connected = True + self._reconnect_attempts = 0 + self._set_state(ConnectionState.CONNECTED, "Connected to server") + + # 启动心跳任务 + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + + # 启动消息接收任务 + self._receive_task = asyncio.create_task(self._receive_loop()) + + # 启动LAN监听 + await self._start_lan_listener() + + logger.info(f"Connected to server {self.config.server_host}:{self.config.server_port}") + return True + + except asyncio.TimeoutError: + self._set_state(ConnectionState.ERROR, "Connection timeout") + raise ConnectionError("Connection timeout") + except OSError as e: + self._set_state(ConnectionState.ERROR, f"Network error: {e}") + raise ConnectionError(f"Failed to connect: {e}") + except Exception as e: + self._set_state(ConnectionState.ERROR, str(e)) + raise ConnectionError(f"Connection error: {e}") + + async def disconnect(self) -> None: + """ + 断开所有连接 + + 关闭服务器连接和所有P2P连接 + """ + self._should_reconnect = False + + # 取消所有任务 + tasks_to_cancel = [ + self._heartbeat_task, + self._receive_task, + self._lan_discovery_task, + self._lan_listener_task + ] + + for task in tasks_to_cancel: + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # 发送注销消息 + if self._server_connected and self._server_writer: + try: + unregister_msg = Message( + msg_type=MessageType.USER_UNREGISTER, + sender_id=self._user_id, + receiver_id="server", + timestamp=time.time(), + payload=b"" + ) + await self._send_to_server(unregister_msg) + except Exception as e: + logger.error(f"Error sending unregister message: {e}") + + # 关闭服务器连接 + if self._server_writer: + try: + self._server_writer.close() + await self._server_writer.wait_closed() + except Exception as e: + logger.error(f"Error closing server connection: {e}") + + self._server_reader = None + self._server_writer = None + self._server_connected = False + + # 关闭所有P2P连接 + async with self._lock: + for peer_id, conn in list(self._peer_connections.items()): + try: + conn.writer.close() + await conn.writer.wait_closed() + except Exception as e: + logger.error(f"Error closing peer connection {peer_id}: {e}") + self._peer_connections.clear() + + # 关闭LAN监听socket + if self._lan_socket: + try: + self._lan_socket.close() + except Exception: + pass + self._lan_socket = None + + self._set_state(ConnectionState.DISCONNECTED, "Disconnected") + logger.info("Disconnected from all connections") + + # ==================== 心跳机制 (需求 1.4, 1.5) ==================== + + async def _heartbeat_loop(self) -> None: + """ + 心跳循环,定期发送心跳消息保持连接 + + 使用TCP套接字建立可靠连接 (需求 1.4) + """ + while self._server_connected: + try: + await asyncio.sleep(self.config.heartbeat_interval) + + if not self._server_connected: + break + + # 发送心跳消息 + heartbeat_msg = Message( + msg_type=MessageType.HEARTBEAT, + sender_id=self._user_id, + receiver_id="server", + timestamp=time.time(), + payload=b"" + ) + + await self._send_to_server(heartbeat_msg) + logger.debug("Heartbeat sent") + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Heartbeat error: {e}") + # 心跳失败,触发重连 + if self._should_reconnect: + asyncio.create_task(self._handle_connection_lost("Heartbeat failed")) + break + + async def _receive_loop(self) -> None: + """ + 消息接收循环,持续接收服务器消息 + """ + while self._server_connected: + try: + message = await self._receive_from_server() + + if message is None: + # 连接断开 + if self._should_reconnect: + asyncio.create_task(self._handle_connection_lost("Connection closed by server")) + break + + # 处理消息 + await self._handle_server_message(message) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Receive error: {e}") + if self._should_reconnect: + asyncio.create_task(self._handle_connection_lost(f"Receive error: {e}")) + break + + async def _handle_server_message(self, message: Message) -> None: + """ + 处理从服务器接收的消息 + + Args: + message: 接收到的消息 + """ + logger.debug(f"Received message: {message.msg_type.value} from {message.sender_id}") + + # 心跳响应不需要特殊处理 + if message.msg_type == MessageType.HEARTBEAT: + return + + # 通知所有消息回调 + for callback in self._message_callbacks: + try: + callback(message) + except Exception as e: + logger.error(f"Error in message callback: {e}") + + # ==================== 消息发送和接收 ==================== + + async def _send_to_server(self, message: Message) -> bool: + """ + 发送消息到服务器 + + Args: + message: 要发送的消息 + + Returns: + 发送成功返回True,否则返回False + """ + if not self._server_writer: + logger.error("Not connected to server") + return False + + try: + data = self._message_handler.serialize(message) + self._server_writer.write(data) + await self._server_writer.drain() + return True + except Exception as e: + logger.error(f"Failed to send message: {e}") + return False + + async def _receive_from_server(self) -> Optional[Message]: + """ + 从服务器接收消息 + + Returns: + 接收到的消息,如果连接断开则返回None + """ + if not self._server_reader: + return None + + try: + # 读取消息头 + header = await self._server_reader.readexactly( + self._message_handler.HEADER_SIZE + ) + + if not header: + return None + + # 解析消息长度 + payload_length, version = struct.unpack( + self._message_handler.HEADER_FORMAT, header + ) + + # 读取消息体 + payload = await self._server_reader.readexactly(payload_length) + + # 反序列化消息 + full_data = header + payload + return self._message_handler.deserialize(full_data) + + except asyncio.IncompleteReadError: + return None + except MessageSerializationError as e: + logger.error(f"Failed to deserialize message: {e}") + return None + except Exception as e: + logger.error(f"Error receiving message: {e}") + return None + + async def send_message(self, peer_id: str, message: Message) -> bool: + """ + 发送消息到指定对等端 + + 根据连接模式选择通过服务器中转或直接P2P发送 + + Args: + peer_id: 目标对等端ID + message: 要发送的消息 + + Returns: + 发送成功返回True,否则返回False + """ + # 检查是否有直接P2P连接 + if peer_id in self._peer_connections: + conn = self._peer_connections[peer_id] + try: + data = self._message_handler.serialize(message) + conn.writer.write(data) + await conn.writer.drain() + conn.last_activity = time.time() + logger.debug(f"Message sent to {peer_id} via P2P") + return True + except Exception as e: + logger.error(f"Failed to send P2P message to {peer_id}: {e}") + # P2P发送失败,尝试通过服务器中转 + + # 通过服务器中转 + if self._server_connected: + message.receiver_id = peer_id + success = await self._send_to_server(message) + if success: + logger.debug(f"Message sent to {peer_id} via relay server") + return success + + logger.error(f"Cannot send message to {peer_id}: no connection available") + return False + + # ==================== 局域网发现 (需求 1.2) ==================== + + async def discover_lan_peers(self) -> List[PeerInfo]: + """ + 发现局域网内的其他客户端 + + 两个客户端在同一局域网内时建立直接的点对点连接 (需求 1.2) + + Returns: + 发现的对等端列表 + """ + discovered = [] + + try: + # 创建UDP广播socket + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + # 构建发现请求 + discovery_data = { + "magic": self.LAN_DISCOVERY_MAGIC.decode('utf-8'), + "user_id": self._user_id, + "username": self._user_info.username if self._user_info else "", + "port": self._lan_listen_port + } + request_data = json.dumps(discovery_data).encode('utf-8') + + # 发送广播 + broadcast_addr = ('', self.config.lan_broadcast_port) + loop = asyncio.get_event_loop() + + await loop.sock_sendto(sock, request_data, broadcast_addr) + logger.debug(f"LAN discovery broadcast sent to port {self.config.lan_broadcast_port}") + + # 等待响应 + start_time = time.time() + while time.time() - start_time < self.config.lan_discovery_timeout: + try: + # 设置短超时 + sock.settimeout(0.1) + data, addr = sock.recvfrom(1024) + + # 解析响应 + response = json.loads(data.decode('utf-8')) + + if response.get("magic") == self.LAN_RESPONSE_MAGIC.decode('utf-8'): + peer_id = response.get("user_id") + + # 忽略自己 + if peer_id == self._user_id: + continue + + peer_info = PeerInfo( + peer_id=peer_id, + username=response.get("username", ""), + ip_address=addr[0], + port=response.get("port", 0) + ) + + # 避免重复 + if peer_id not in self._discovered_peers: + self._discovered_peers[peer_id] = peer_info + discovered.append(peer_info) + logger.info(f"Discovered LAN peer: {peer_id} at {addr[0]}:{peer_info.port}") + + except socket.timeout: + continue + except Exception as e: + logger.debug(f"Error receiving discovery response: {e}") + continue + + sock.close() + + except Exception as e: + logger.error(f"LAN discovery error: {e}") + + return discovered + + async def _start_lan_listener(self) -> None: + """ + 启动LAN发现监听器 + + 监听其他客户端的发现请求并响应 + """ + try: + # 创建UDP监听socket + self._lan_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self._lan_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._lan_socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + self._lan_socket.bind(('', self.config.lan_broadcast_port)) + self._lan_socket.setblocking(False) + + # 获取监听端口(用于P2P连接) + self._lan_listen_port = self._lan_socket.getsockname()[1] + + # 启动监听任务 + self._lan_listener_task = asyncio.create_task(self._lan_listen_loop()) + + logger.info(f"LAN listener started on port {self.config.lan_broadcast_port}") + + except Exception as e: + logger.error(f"Failed to start LAN listener: {e}") + + async def _lan_listen_loop(self) -> None: + """ + LAN发现监听循环 + """ + loop = asyncio.get_event_loop() + + while self._server_connected: + try: + # 非阻塞接收 + try: + data, addr = await asyncio.wait_for( + loop.sock_recvfrom(self._lan_socket, 1024), + timeout=1.0 + ) + except asyncio.TimeoutError: + continue + + # 解析请求 + request = json.loads(data.decode('utf-8')) + + if request.get("magic") == self.LAN_DISCOVERY_MAGIC.decode('utf-8'): + peer_id = request.get("user_id") + + # 忽略自己的广播 + if peer_id == self._user_id: + continue + + # 发送响应 + response_data = { + "magic": self.LAN_RESPONSE_MAGIC.decode('utf-8'), + "user_id": self._user_id, + "username": self._user_info.username if self._user_info else "", + "port": self._lan_listen_port + } + response = json.dumps(response_data).encode('utf-8') + + await loop.sock_sendto(self._lan_socket, response, addr) + logger.debug(f"Responded to LAN discovery from {addr}") + + except asyncio.CancelledError: + break + except Exception as e: + logger.debug(f"LAN listen error: {e}") + continue + + # ==================== 通信模式选择 (需求 1.1, 1.2, 1.3) ==================== + + def get_connection_mode(self, peer_id: str) -> ConnectionMode: + """ + 获取与指定对等端的连接模式 + + 自动检测当前网络环境并选择合适的通信模式 (需求 1.1) + + Args: + peer_id: 对等端ID + + Returns: + 连接模式(P2P或中转) + """ + # 检查是否有直接P2P连接 + if peer_id in self._peer_connections: + return ConnectionMode.P2P + + # 检查是否在已发现的LAN对等端中 + if peer_id in self._discovered_peers: + return ConnectionMode.P2P + + # 默认使用服务器中转 + return ConnectionMode.RELAY + + async def connect_to_peer(self, peer_id: str) -> Optional[Connection]: + """ + 连接到指定的对等端 + + 根据网络环境选择P2P直连或服务器中转 (需求 1.1, 1.2, 1.3) + + Args: + peer_id: 对等端ID + + Returns: + 连接对象,如果连接失败则返回None + """ + # 检查是否已有连接 + if peer_id in self._peer_connections: + return self._peer_connections[peer_id] + + # 检查是否可以P2P直连 + if peer_id in self._discovered_peers: + peer_info = self._discovered_peers[peer_id] + + try: + # 尝试建立TCP直连 + reader, writer = await asyncio.wait_for( + asyncio.open_connection( + peer_info.ip_address, + peer_info.port + ), + timeout=self.config.connection_timeout + ) + + conn = Connection( + peer_id=peer_id, + reader=reader, + writer=writer, + mode=ConnectionMode.P2P, + ip_address=peer_info.ip_address, + port=peer_info.port + ) + + async with self._lock: + self._peer_connections[peer_id] = conn + + logger.info(f"P2P connection established with {peer_id}") + return conn + + except Exception as e: + logger.warning(f"Failed to establish P2P connection with {peer_id}: {e}") + # P2P连接失败,将使用服务器中转 + + # 使用服务器中转(不需要建立额外连接) + if self._server_connected: + # 创建一个虚拟连接对象表示中转模式 + conn = Connection( + peer_id=peer_id, + reader=self._server_reader, + writer=self._server_writer, + mode=ConnectionMode.RELAY, + ip_address=self.config.server_host, + port=self.config.server_port + ) + return conn + + return None + + # ==================== 网络重连机制 (需求 1.6) ==================== + + async def _handle_connection_lost(self, reason: str) -> None: + """ + 处理连接丢失 + + 网络连接断开时自动尝试重新连接并通知用户当前状态 (需求 1.6) + + Args: + reason: 断开原因 + """ + if self._reconnecting or not self._should_reconnect: + return + + self._reconnecting = True + self._server_connected = False + self._set_state(ConnectionState.RECONNECTING, reason) + + logger.info(f"Connection lost: {reason}. Starting reconnection...") + + # 清理旧连接 + if self._server_writer: + try: + self._server_writer.close() + await self._server_writer.wait_closed() + except Exception: + pass + + self._server_reader = None + self._server_writer = None + + # 尝试重连 + success = await self._reconnect() + + self._reconnecting = False + + if not success: + self._set_state(ConnectionState.ERROR, "Reconnection failed") + + async def _reconnect(self) -> bool: + """ + 执行重连逻辑 + + 自动尝试重新连接并通知用户当前状态 (需求 1.6) + + Returns: + 重连成功返回True,否则返回False + """ + if not self._user_info: + logger.error("Cannot reconnect: no user info") + return False + + max_attempts = self.config.reconnect_attempts + base_delay = self.config.reconnect_delay + + for attempt in range(max_attempts): + if not self._should_reconnect: + return False + + self._reconnect_attempts = attempt + 1 + delay = base_delay * (2 ** attempt) # 指数退避: 1s, 2s, 4s + + logger.info(f"Reconnection attempt {attempt + 1}/{max_attempts} " + f"(delay: {delay}s)") + + # 等待延迟 + await asyncio.sleep(delay) + + if not self._should_reconnect: + return False + + try: + # 尝试建立连接 + self._server_reader, self._server_writer = await asyncio.wait_for( + asyncio.open_connection( + self.config.server_host, + self.config.server_port + ), + timeout=self.config.connection_timeout + ) + + # 发送注册消息 + register_msg = Message( + msg_type=MessageType.USER_REGISTER, + sender_id=self._user_id, + receiver_id="server", + timestamp=time.time(), + payload=self._user_info.serialize() + ) + + await self._send_to_server(register_msg) + + # 等待响应 + response = await asyncio.wait_for( + self._receive_from_server(), + timeout=self.config.connection_timeout + ) + + if response and response.msg_type == MessageType.ACK: + # 重连成功 + self._server_connected = True + self._reconnect_attempts = 0 + self._set_state(ConnectionState.CONNECTED, "Reconnected") + + # 重启心跳任务 + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + + # 重启消息接收任务 + self._receive_task = asyncio.create_task(self._receive_loop()) + + logger.info("Reconnection successful") + return True + else: + raise ConnectionError("Invalid response from server") + + except Exception as e: + logger.warning(f"Reconnection attempt {attempt + 1} failed: {e}") + + # 清理失败的连接 + if self._server_writer: + try: + self._server_writer.close() + await self._server_writer.wait_closed() + except Exception: + pass + + self._server_reader = None + self._server_writer = None + + logger.error(f"Reconnection failed after {max_attempts} attempts") + return False + + def enable_reconnect(self, enabled: bool = True) -> None: + """ + 启用或禁用自动重连 + + Args: + enabled: 是否启用 + """ + self._should_reconnect = enabled + logger.info(f"Auto-reconnect {'enabled' if enabled else 'disabled'}") + + # ==================== 用户列表获取 ==================== + + async def get_online_users(self) -> List[UserInfo]: + """ + 获取在线用户列表 + + Returns: + 在线用户列表 + """ + if not self._server_connected: + return [] + + # 发送用户列表请求 + request = Message( + msg_type=MessageType.USER_LIST_REQUEST, + sender_id=self._user_id, + receiver_id="server", + timestamp=time.time(), + payload=b"" + ) + + await self._send_to_server(request) + + # 等待响应(通过消息回调处理) + # 这里简化处理,实际应该使用Future等待特定响应 + return [] + + def get_discovered_peers(self) -> List[PeerInfo]: + """ + 获取已发现的LAN对等端列表 + + Returns: + 对等端列表 + """ + return list(self._discovered_peers.values()) + + # ==================== 网络质量检测 ==================== + + async def get_network_quality(self, peer_id: str) -> NetworkQuality: + """ + 获取与指定对等端的网络质量 + + Args: + peer_id: 对等端ID + + Returns: + 网络质量等级 + """ + # 发送ping并测量延迟 + start_time = time.time() + + ping_msg = Message( + msg_type=MessageType.HEARTBEAT, + sender_id=self._user_id, + receiver_id=peer_id, + timestamp=start_time, + payload=b"ping" + ) + + success = await self.send_message(peer_id, ping_msg) + + if not success: + return NetworkQuality.BAD + + # 简化处理:基于发送时间估算 + latency = (time.time() - start_time) * 1000 # 转换为毫秒 + + if latency < 50: + return NetworkQuality.EXCELLENT + elif latency < 100: + return NetworkQuality.GOOD + elif latency < 200: + return NetworkQuality.FAIR + elif latency < 300: + return NetworkQuality.POOR + else: + return NetworkQuality.BAD + + # ==================== P2P直连监听 (需求 1.2) ==================== + + async def start_p2p_listener(self, port: int = 0) -> int: + """ + 启动P2P连接监听器 + + 允许其他客户端直接连接到本客户端 (需求 1.2) + + Args: + port: 监听端口,0表示自动分配 + + Returns: + 实际监听的端口号 + """ + try: + server = await asyncio.start_server( + self._handle_p2p_connection, + '0.0.0.0', + port + ) + + addr = server.sockets[0].getsockname() + self._lan_listen_port = addr[1] + + logger.info(f"P2P listener started on port {self._lan_listen_port}") + return self._lan_listen_port + + except Exception as e: + logger.error(f"Failed to start P2P listener: {e}") + return 0 + + async def _handle_p2p_connection( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter + ) -> None: + """ + 处理传入的P2P连接 + + Args: + reader: 异步读取器 + writer: 异步写入器 + """ + addr = writer.get_extra_info('peername') + ip_address = addr[0] if addr else "unknown" + port = addr[1] if addr else 0 + + logger.info(f"Incoming P2P connection from {ip_address}:{port}") + + peer_id: Optional[str] = None + + try: + # 等待对方发送身份信息 + message = await self._receive_p2p_message(reader) + + if message and message.msg_type == MessageType.USER_REGISTER: + peer_id = message.sender_id + + # 创建连接对象 + conn = Connection( + peer_id=peer_id, + reader=reader, + writer=writer, + mode=ConnectionMode.P2P, + ip_address=ip_address, + port=port + ) + + async with self._lock: + self._peer_connections[peer_id] = conn + + # 发送确认 + ack_msg = Message( + msg_type=MessageType.ACK, + sender_id=self._user_id, + receiver_id=peer_id, + timestamp=time.time(), + payload=b"P2P connection accepted" + ) + await self._send_p2p_message(writer, ack_msg) + + logger.info(f"P2P connection established with {peer_id}") + + # 开始接收消息 + await self._p2p_receive_loop(peer_id, reader) + else: + logger.warning(f"Invalid P2P handshake from {ip_address}:{port}") + writer.close() + await writer.wait_closed() + + except Exception as e: + logger.error(f"Error handling P2P connection: {e}") + finally: + if peer_id: + await self._close_p2p_connection(peer_id) + + async def _p2p_receive_loop(self, peer_id: str, reader: asyncio.StreamReader) -> None: + """ + P2P消息接收循环 + + Args: + peer_id: 对等端ID + reader: 异步读取器 + """ + while peer_id in self._peer_connections: + try: + message = await self._receive_p2p_message(reader) + + if message is None: + break + + # 更新活动时间 + if peer_id in self._peer_connections: + self._peer_connections[peer_id].last_activity = time.time() + + # 通知消息回调 + for callback in self._message_callbacks: + try: + callback(message) + except Exception as e: + logger.error(f"Error in message callback: {e}") + + except Exception as e: + logger.error(f"P2P receive error from {peer_id}: {e}") + break + + async def _receive_p2p_message(self, reader: asyncio.StreamReader) -> Optional[Message]: + """ + 从P2P连接接收消息 + + Args: + reader: 异步读取器 + + Returns: + 接收到的消息,如果连接断开则返回None + """ + try: + header = await reader.readexactly(self._message_handler.HEADER_SIZE) + + if not header: + return None + + payload_length, version = struct.unpack( + self._message_handler.HEADER_FORMAT, header + ) + + payload = await reader.readexactly(payload_length) + full_data = header + payload + + return self._message_handler.deserialize(full_data) + + except asyncio.IncompleteReadError: + return None + except Exception as e: + logger.error(f"Error receiving P2P message: {e}") + return None + + async def _send_p2p_message(self, writer: asyncio.StreamWriter, message: Message) -> bool: + """ + 发送P2P消息 + + Args: + writer: 异步写入器 + message: 要发送的消息 + + Returns: + 发送成功返回True,否则返回False + """ + try: + data = self._message_handler.serialize(message) + writer.write(data) + await writer.drain() + return True + except Exception as e: + logger.error(f"Failed to send P2P message: {e}") + return False + + async def _close_p2p_connection(self, peer_id: str) -> None: + """ + 关闭P2P连接 + + Args: + peer_id: 对等端ID + """ + async with self._lock: + if peer_id in self._peer_connections: + conn = self._peer_connections[peer_id] + try: + conn.writer.close() + await conn.writer.wait_closed() + except Exception: + pass + del self._peer_connections[peer_id] + logger.info(f"P2P connection closed with {peer_id}") + + async def establish_p2p_connection(self, peer_id: str) -> bool: + """ + 主动建立P2P连接 + + 尝试与已发现的LAN对等端建立直接连接 (需求 1.2) + + Args: + peer_id: 对等端ID + + Returns: + 连接成功返回True,否则返回False + """ + if peer_id in self._peer_connections: + return True + + if peer_id not in self._discovered_peers: + logger.warning(f"Peer {peer_id} not discovered in LAN") + return False + + peer_info = self._discovered_peers[peer_id] + + try: + reader, writer = await asyncio.wait_for( + asyncio.open_connection( + peer_info.ip_address, + peer_info.port + ), + timeout=self.config.connection_timeout + ) + + # 发送身份信息 + register_msg = Message( + msg_type=MessageType.USER_REGISTER, + sender_id=self._user_id, + receiver_id=peer_id, + timestamp=time.time(), + payload=self._user_info.serialize() if self._user_info else b"" + ) + + await self._send_p2p_message(writer, register_msg) + + # 等待确认 + response = await asyncio.wait_for( + self._receive_p2p_message(reader), + timeout=self.config.connection_timeout + ) + + if response and response.msg_type == MessageType.ACK: + conn = Connection( + peer_id=peer_id, + reader=reader, + writer=writer, + mode=ConnectionMode.P2P, + ip_address=peer_info.ip_address, + port=peer_info.port + ) + + async with self._lock: + self._peer_connections[peer_id] = conn + + # 启动接收循环 + asyncio.create_task(self._p2p_receive_loop(peer_id, reader)) + + logger.info(f"P2P connection established with {peer_id}") + return True + else: + writer.close() + await writer.wait_closed() + return False + + except Exception as e: + logger.error(f"Failed to establish P2P connection with {peer_id}: {e}") + return False + + # ==================== 自动模式切换 (需求 1.1, 1.2, 1.3) ==================== + + async def auto_select_connection_mode(self, peer_id: str) -> ConnectionMode: + """ + 自动选择最佳连接模式 + + 自动检测当前网络环境并选择合适的通信模式 (需求 1.1) + 两个客户端在同一局域网内时建立直接的点对点连接 (需求 1.2) + 两个客户端不在同一局域网内时通过服务器进行消息中转 (需求 1.3) + + Args: + peer_id: 对等端ID + + Returns: + 选择的连接模式 + """ + # 1. 检查是否已有P2P连接 + if peer_id in self._peer_connections: + conn = self._peer_connections[peer_id] + if conn.mode == ConnectionMode.P2P: + logger.debug(f"Using existing P2P connection for {peer_id}") + return ConnectionMode.P2P + + # 2. 检查是否在已发现的LAN对等端中 + if peer_id in self._discovered_peers: + # 尝试建立P2P连接 + success = await self.establish_p2p_connection(peer_id) + if success: + logger.info(f"Established P2P connection with {peer_id}") + return ConnectionMode.P2P + + # 3. 尝试发现LAN对等端 + await self.discover_lan_peers() + + if peer_id in self._discovered_peers: + success = await self.establish_p2p_connection(peer_id) + if success: + logger.info(f"Discovered and connected to {peer_id} via P2P") + return ConnectionMode.P2P + + # 4. 使用服务器中转 + if self._server_connected: + logger.info(f"Using relay server for {peer_id}") + return ConnectionMode.RELAY + + logger.warning(f"No connection available for {peer_id}") + return ConnectionMode.UNKNOWN + + async def switch_to_p2p(self, peer_id: str) -> bool: + """ + 尝试将连接切换到P2P模式 + + Args: + peer_id: 对等端ID + + Returns: + 切换成功返回True,否则返回False + """ + current_mode = self.get_connection_mode(peer_id) + + if current_mode == ConnectionMode.P2P: + return True + + # 尝试发现并连接 + await self.discover_lan_peers() + + if peer_id in self._discovered_peers: + return await self.establish_p2p_connection(peer_id) + + return False + + async def switch_to_relay(self, peer_id: str) -> bool: + """ + 将连接切换到中转模式 + + Args: + peer_id: 对等端ID + + Returns: + 切换成功返回True,否则返回False + """ + # 关闭P2P连接(如果存在) + if peer_id in self._peer_connections: + await self._close_p2p_connection(peer_id) + + # 确保服务器连接可用 + return self._server_connected + + def get_all_connection_modes(self) -> Dict[str, ConnectionMode]: + """ + 获取所有对等端的连接模式 + + Returns: + 对等端ID到连接模式的映射 + """ + modes = {} + + # P2P连接 + for peer_id, conn in self._peer_connections.items(): + modes[peer_id] = conn.mode + + # 已发现但未连接的LAN对等端 + for peer_id in self._discovered_peers: + if peer_id not in modes: + modes[peer_id] = ConnectionMode.P2P # 可以建立P2P + + return modes + + async def optimize_connections(self) -> None: + """ + 优化所有连接 + + 尝试将中转连接升级为P2P连接以提高性能 + """ + # 发现LAN对等端 + await self.discover_lan_peers() + + # 对于每个已发现的对等端,尝试建立P2P连接 + for peer_id in list(self._discovered_peers.keys()): + if peer_id not in self._peer_connections: + try: + await self.establish_p2p_connection(peer_id) + except Exception as e: + logger.debug(f"Failed to optimize connection for {peer_id}: {e}") + + # ==================== 连接信息查询 ==================== + + def get_peer_connection_info(self, peer_id: str) -> Optional[Dict[str, Any]]: + """ + 获取对等端连接信息 + + Args: + peer_id: 对等端ID + + Returns: + 连接信息字典,如果没有连接则返回None + """ + if peer_id in self._peer_connections: + conn = self._peer_connections[peer_id] + return { + "peer_id": peer_id, + "mode": conn.mode.value, + "ip_address": conn.ip_address, + "port": conn.port, + "connected_at": conn.connected_at.isoformat(), + "last_activity": conn.last_activity + } + + if peer_id in self._discovered_peers: + peer = self._discovered_peers[peer_id] + return { + "peer_id": peer_id, + "mode": "discovered", + "ip_address": peer.ip_address, + "port": peer.port, + "discovered_at": peer.discovered_at.isoformat() + } + + return None + + def get_connection_stats(self) -> Dict[str, Any]: + """ + 获取连接统计信息 + + Returns: + 统计信息字典 + """ + p2p_count = sum( + 1 for conn in self._peer_connections.values() + if conn.mode == ConnectionMode.P2P + ) + + return { + "server_connected": self._server_connected, + "state": self._state.value, + "total_peer_connections": len(self._peer_connections), + "p2p_connections": p2p_count, + "discovered_peers": len(self._discovered_peers), + "reconnect_attempts": self._reconnect_attempts + } + + # ==================== 重连状态通知增强 (需求 1.6) ==================== + + def get_reconnect_status(self) -> Dict[str, Any]: + """ + 获取重连状态信息 + + 网络连接断开时通知用户当前状态 (需求 1.6) + + Returns: + 重连状态信息字典 + """ + return { + "is_reconnecting": self._reconnecting, + "reconnect_attempts": self._reconnect_attempts, + "max_attempts": self.config.reconnect_attempts, + "auto_reconnect_enabled": self._should_reconnect, + "current_state": self._state.value + } + + async def manual_reconnect(self) -> bool: + """ + 手动触发重连 + + 当自动重连失败后,允许用户手动触发重连 + + Returns: + 重连成功返回True,否则返回False + """ + if self._server_connected: + logger.info("Already connected, no need to reconnect") + return True + + if self._reconnecting: + logger.warning("Reconnection already in progress") + return False + + if not self._user_info: + logger.error("Cannot reconnect: no user info available") + return False + + self._should_reconnect = True + self._reconnect_attempts = 0 + + return await self._reconnect() + + def set_reconnect_config( + self, + max_attempts: Optional[int] = None, + base_delay: Optional[float] = None + ) -> None: + """ + 配置重连参数 + + Args: + max_attempts: 最大重连次数 + base_delay: 基础重连延迟(秒) + """ + if max_attempts is not None: + self.config.reconnect_attempts = max_attempts + if base_delay is not None: + self.config.reconnect_delay = base_delay + + logger.info(f"Reconnect config updated: max_attempts={self.config.reconnect_attempts}, " + f"base_delay={self.config.reconnect_delay}s") + + # ==================== P2P连接重连 ==================== + + async def reconnect_p2p(self, peer_id: str) -> bool: + """ + 重连P2P连接 + + Args: + peer_id: 对等端ID + + Returns: + 重连成功返回True,否则返回False + """ + # 先关闭旧连接 + if peer_id in self._peer_connections: + await self._close_p2p_connection(peer_id) + + # 重新发现 + await self.discover_lan_peers() + + # 尝试重连 + if peer_id in self._discovered_peers: + return await self.establish_p2p_connection(peer_id) + + return False + + # ==================== 连接健康检查 ==================== + + async def check_connection_health(self) -> Dict[str, Any]: + """ + 检查所有连接的健康状态 + + Returns: + 健康状态报告 + """ + report = { + "server": { + "connected": self._server_connected, + "state": self._state.value, + "healthy": False + }, + "p2p_connections": {}, + "overall_healthy": False + } + + # 检查服务器连接 + if self._server_connected: + try: + # 发送心跳测试 + heartbeat = Message( + msg_type=MessageType.HEARTBEAT, + sender_id=self._user_id, + receiver_id="server", + timestamp=time.time(), + payload=b"" + ) + success = await self._send_to_server(heartbeat) + report["server"]["healthy"] = success + except Exception: + report["server"]["healthy"] = False + + # 检查P2P连接 + for peer_id, conn in list(self._peer_connections.items()): + idle_time = time.time() - conn.last_activity + is_healthy = idle_time < 60 # 60秒内有活动认为健康 + + report["p2p_connections"][peer_id] = { + "mode": conn.mode.value, + "idle_seconds": idle_time, + "healthy": is_healthy + } + + # 总体健康状态 + report["overall_healthy"] = ( + report["server"]["healthy"] or + any(c["healthy"] for c in report["p2p_connections"].values()) + ) + + return report + + async def cleanup_stale_connections(self, max_idle_seconds: float = 300) -> int: + """ + 清理过期的P2P连接 + + Args: + max_idle_seconds: 最大空闲时间(秒) + + Returns: + 清理的连接数 + """ + cleaned = 0 + current_time = time.time() + + for peer_id in list(self._peer_connections.keys()): + conn = self._peer_connections.get(peer_id) + if conn and (current_time - conn.last_activity) > max_idle_seconds: + await self._close_p2p_connection(peer_id) + cleaned += 1 + logger.info(f"Cleaned stale P2P connection: {peer_id}") + + return cleaned diff --git a/client/file_transfer.py b/client/file_transfer.py new file mode 100644 index 0000000..3786320 --- /dev/null +++ b/client/file_transfer.py @@ -0,0 +1,1250 @@ +# P2P Network Communication - File Transfer Module +""" +文件传输模块 +负责文件的分块传输、断点续传和完整性校验 + +需求: 4.2, 4.4, 4.5 +""" + +import asyncio +import hashlib +import json +import logging +import os +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Callable, Dict, List, Optional, Any + +from shared.models import ( + Message, MessageType, FileChunk, TransferProgress, + FileTransferRecord, TransferStatus +) +from shared.message_handler import MessageHandler +from config import ClientConfig + + +# 设置日志 +logger = logging.getLogger(__name__) + + +class FileTransferError(Exception): + """文件传输错误""" + pass + + +class FileNotFoundError(FileTransferError): + """文件不存在错误""" + pass + + +class FileIntegrityError(FileTransferError): + """文件完整性校验错误""" + pass + + +class TransferCancelledError(FileTransferError): + """传输取消错误""" + pass + + +# 进度回调类型 +ProgressCallback = Callable[[TransferProgress], None] + + +@dataclass +class TransferState: + """传输状态(用于断点续传)""" + file_id: str + file_path: str + file_name: str + file_size: int + file_hash: str + total_chunks: int + completed_chunks: List[int] = field(default_factory=list) + status: TransferStatus = TransferStatus.PENDING + sender_id: str = "" + receiver_id: str = "" + start_time: datetime = field(default_factory=datetime.now) + last_update: datetime = field(default_factory=datetime.now) + save_path: str = "" # 接收方保存路径 + + def to_dict(self) -> dict: + """转换为字典(用于持久化)""" + return { + "file_id": self.file_id, + "file_path": self.file_path, + "file_name": self.file_name, + "file_size": self.file_size, + "file_hash": self.file_hash, + "total_chunks": self.total_chunks, + "completed_chunks": self.completed_chunks, + "status": self.status.value, + "sender_id": self.sender_id, + "receiver_id": self.receiver_id, + "start_time": self.start_time.isoformat(), + "last_update": self.last_update.isoformat(), + "save_path": self.save_path, + } + + @classmethod + def from_dict(cls, data: dict) -> "TransferState": + """从字典创建TransferState对象""" + return cls( + file_id=data["file_id"], + file_path=data["file_path"], + file_name=data["file_name"], + file_size=data["file_size"], + file_hash=data["file_hash"], + total_chunks=data["total_chunks"], + completed_chunks=data.get("completed_chunks", []), + status=TransferStatus(data.get("status", "pending")), + sender_id=data.get("sender_id", ""), + receiver_id=data.get("receiver_id", ""), + start_time=datetime.fromisoformat(data["start_time"]), + last_update=datetime.fromisoformat(data["last_update"]), + save_path=data.get("save_path", ""), + ) + + @property + def progress_percent(self) -> float: + """获取进度百分比""" + if self.total_chunks == 0: + return 0.0 + return (len(self.completed_chunks) / self.total_chunks) * 100 + + @property + def transferred_size(self) -> int: + """获取已传输大小""" + chunk_size = FileTransferModule.CHUNK_SIZE + full_chunks = len(self.completed_chunks) + + if full_chunks == 0: + return 0 + + # 最后一个块可能不是完整大小 + if self.total_chunks in self.completed_chunks: + last_chunk_size = self.file_size % chunk_size + if last_chunk_size == 0: + last_chunk_size = chunk_size + return (full_chunks - 1) * chunk_size + last_chunk_size + + return full_chunks * chunk_size + + + +class FileTransferModule: + """ + 文件传输模块 + + 负责: + - 文件分块传输 (需求 4.2) + - 传输进度回调 (需求 4.3) + - 文件完整性校验 (需求 4.4) + - 断点续传 (需求 4.5) + """ + + # 每个块的大小: 64KB + CHUNK_SIZE = 64 * 1024 + + # 传输状态持久化目录 + STATE_DIR = "transfer_states" + + def __init__(self, config: Optional[ClientConfig] = None, + send_message_func: Optional[Callable] = None): + """ + 初始化文件传输模块 + + Args: + config: 客户端配置 + send_message_func: 发送消息的函数(由ConnectionManager提供) + """ + self.config = config or ClientConfig() + self._send_message = send_message_func + + # 传输状态管理 + self._active_transfers: Dict[str, TransferState] = {} + self._cancelled_transfers: set = set() + + # 接收缓冲区: file_id -> {chunk_index: data} + self._receive_buffers: Dict[str, Dict[int, bytes]] = {} + + # 进度回调 + self._progress_callbacks: Dict[str, ProgressCallback] = {} + + # 消息处理器 + self._message_handler = MessageHandler() + + # 确保状态目录存在 + self._state_dir = Path(self.config.data_dir) / self.STATE_DIR + self._state_dir.mkdir(parents=True, exist_ok=True) + + # 加载未完成的传输状态 + self._load_transfer_states() + + logger.info("FileTransferModule initialized") + + def set_send_message_func(self, func: Callable) -> None: + """ + 设置发送消息函数 + + Args: + func: 发送消息的异步函数 + """ + self._send_message = func + + # ==================== 文件哈希计算 (需求 4.4) ==================== + + def calculate_file_hash(self, file_path: str, algorithm: str = "sha256") -> str: + """ + 计算文件哈希值 + + 实现 MD5/SHA256 哈希计算 (需求 4.4) + + Args: + file_path: 文件路径 + algorithm: 哈希算法 ("md5" 或 "sha256") + + Returns: + 文件哈希值(十六进制字符串) + + Raises: + FileNotFoundError: 文件不存在 + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + if algorithm == "md5": + hasher = hashlib.md5() + else: + hasher = hashlib.sha256() + + with open(file_path, 'rb') as f: + while True: + data = f.read(self.CHUNK_SIZE) + if not data: + break + hasher.update(data) + + return hasher.hexdigest() + + def calculate_chunk_hash(self, data: bytes) -> str: + """ + 计算数据块哈希值 + + Args: + data: 数据块 + + Returns: + MD5哈希值 + """ + return hashlib.md5(data).hexdigest() + + def verify_file_integrity(self, file_path: str, expected_hash: str, + algorithm: str = "sha256") -> bool: + """ + 验证文件完整性 + + 实现传输完成后的校验逻辑 (需求 4.4) + + Args: + file_path: 文件路径 + expected_hash: 期望的哈希值 + algorithm: 哈希算法 + + Returns: + 校验通过返回True,否则返回False + """ + try: + actual_hash = self.calculate_file_hash(file_path, algorithm) + return actual_hash == expected_hash + except Exception as e: + logger.error(f"File integrity verification failed: {e}") + return False + + # ==================== 文件分块 (需求 4.2) ==================== + + def _get_total_chunks(self, file_size: int) -> int: + """ + 计算文件总块数 + + Args: + file_size: 文件大小 + + Returns: + 总块数 + """ + return (file_size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE + + def _read_chunk(self, file_path: str, chunk_index: int) -> bytes: + """ + 读取指定的文件块 + + 实现文件分块逻辑(64KB per chunk)(需求 4.2) + + Args: + file_path: 文件路径 + chunk_index: 块索引(从0开始) + + Returns: + 数据块 + """ + offset = chunk_index * self.CHUNK_SIZE + + with open(file_path, 'rb') as f: + f.seek(offset) + return f.read(self.CHUNK_SIZE) + + def _write_chunk(self, file_path: str, chunk_index: int, data: bytes) -> None: + """ + 写入指定的文件块 + + Args: + file_path: 文件路径 + chunk_index: 块索引 + data: 数据块 + """ + offset = chunk_index * self.CHUNK_SIZE + + # 确保文件存在 + if not os.path.exists(file_path): + # 创建空文件 + with open(file_path, 'wb') as f: + pass + + with open(file_path, 'r+b') as f: + f.seek(offset) + f.write(data) + + def split_file_to_chunks(self, file_path: str) -> List[FileChunk]: + """ + 将文件分割成多个块 + + 实现文件分块逻辑(64KB per chunk)(需求 4.2) + + Args: + file_path: 文件路径 + + Returns: + 文件块列表 + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + file_size = os.path.getsize(file_path) + total_chunks = self._get_total_chunks(file_size) + file_id = str(uuid.uuid4()) + + chunks = [] + for i in range(total_chunks): + data = self._read_chunk(file_path, i) + chunk = FileChunk( + file_id=file_id, + chunk_index=i, + total_chunks=total_chunks, + data=data + ) + chunks.append(chunk) + + return chunks + + # ==================== 进度计算和回调 ==================== + + def _calculate_progress(self, state: TransferState, + start_time: float) -> TransferProgress: + """ + 计算传输进度 + + 实现传输进度回调 (需求 4.3) + + Args: + state: 传输状态 + start_time: 开始时间 + + Returns: + 传输进度信息 + """ + elapsed = time.time() - start_time + transferred = state.transferred_size + + # 计算速度 + speed = transferred / elapsed if elapsed > 0 else 0 + + # 计算预计剩余时间 + remaining = state.file_size - transferred + eta = remaining / speed if speed > 0 else 0 + + return TransferProgress( + file_id=state.file_id, + file_name=state.file_name, + total_size=state.file_size, + transferred_size=transferred, + speed=speed, + eta=eta + ) + + def _notify_progress(self, file_id: str, progress: TransferProgress) -> None: + """ + 通知进度回调 + + Args: + file_id: 文件ID + progress: 进度信息 + """ + if file_id in self._progress_callbacks: + try: + self._progress_callbacks[file_id](progress) + except Exception as e: + logger.error(f"Progress callback error: {e}") + + # ==================== 发送文件 (需求 4.2) ==================== + + async def send_file(self, peer_id: str, file_path: str, + progress_callback: Optional[ProgressCallback] = None) -> bool: + """ + 发送文件到指定对等端 + + 实现 send_file() 发送文件 (需求 4.2) + 将文件分割成多个 Chunk 进行传输 (需求 4.2) + + Args: + peer_id: 目标对等端ID + file_path: 文件路径 + progress_callback: 进度回调函数 + + Returns: + 发送成功返回True,否则返回False + + Raises: + FileNotFoundError: 文件不存在 + FileTransferError: 传输错误 + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + if not self._send_message: + raise FileTransferError("Send message function not set") + + # 获取文件信息 + file_name = os.path.basename(file_path) + file_size = os.path.getsize(file_path) + file_hash = self.calculate_file_hash(file_path) + total_chunks = self._get_total_chunks(file_size) + file_id = str(uuid.uuid4()) + + logger.info(f"Starting file transfer: {file_name} ({file_size} bytes, " + f"{total_chunks} chunks) to {peer_id}") + + # 创建传输状态 + state = TransferState( + file_id=file_id, + file_path=file_path, + file_name=file_name, + file_size=file_size, + file_hash=file_hash, + total_chunks=total_chunks, + status=TransferStatus.IN_PROGRESS, + receiver_id=peer_id + ) + + self._active_transfers[file_id] = state + + if progress_callback: + self._progress_callbacks[file_id] = progress_callback + + # 保存传输状态(用于断点续传) + self._save_transfer_state(state) + + start_time = time.time() + + try: + # 发送文件请求消息 + file_request_payload = json.dumps({ + "file_id": file_id, + "file_name": file_name, + "file_size": file_size, + "file_hash": file_hash, + "total_chunks": total_chunks + }).encode('utf-8') + + request_msg = Message( + msg_type=MessageType.FILE_REQUEST, + sender_id="", # 由ConnectionManager填充 + receiver_id=peer_id, + timestamp=time.time(), + payload=file_request_payload + ) + + await self._send_message(peer_id, request_msg) + + # 发送所有数据块 + for chunk_index in range(total_chunks): + # 检查是否取消 + if file_id in self._cancelled_transfers: + raise TransferCancelledError("Transfer cancelled") + + # 读取数据块 + chunk_data = self._read_chunk(file_path, chunk_index) + chunk_hash = self.calculate_chunk_hash(chunk_data) + + # 创建文件块消息 + chunk_payload = json.dumps({ + "file_id": file_id, + "chunk_index": chunk_index, + "total_chunks": total_chunks, + "checksum": chunk_hash, + "data": chunk_data.hex() + }).encode('utf-8') + + chunk_msg = Message( + msg_type=MessageType.FILE_CHUNK, + sender_id="", + receiver_id=peer_id, + timestamp=time.time(), + payload=chunk_payload + ) + + await self._send_message(peer_id, chunk_msg) + + # 更新状态 + state.completed_chunks.append(chunk_index) + state.last_update = datetime.now() + + # 通知进度 + progress = self._calculate_progress(state, start_time) + self._notify_progress(file_id, progress) + + # 定期保存状态 + if chunk_index % 10 == 0: + self._save_transfer_state(state) + + # 小延迟避免网络拥塞 + await asyncio.sleep(0.001) + + # 发送完成消息 + complete_payload = json.dumps({ + "file_id": file_id, + "file_name": file_name, + "file_hash": file_hash, + "total_chunks": total_chunks + }).encode('utf-8') + + complete_msg = Message( + msg_type=MessageType.FILE_COMPLETE, + sender_id="", + receiver_id=peer_id, + timestamp=time.time(), + payload=complete_payload + ) + + await self._send_message(peer_id, complete_msg) + + # 更新状态为完成 + state.status = TransferStatus.COMPLETED + self._save_transfer_state(state) + + # 最终进度通知 + progress = self._calculate_progress(state, start_time) + self._notify_progress(file_id, progress) + + logger.info(f"File transfer completed: {file_name}") + return True + + except TransferCancelledError: + state.status = TransferStatus.CANCELLED + self._save_transfer_state(state) + logger.info(f"File transfer cancelled: {file_name}") + return False + + except Exception as e: + state.status = TransferStatus.FAILED + self._save_transfer_state(state) + logger.error(f"File transfer failed: {e}") + raise FileTransferError(f"Transfer failed: {e}") + + finally: + # 清理 + if file_id in self._progress_callbacks: + del self._progress_callbacks[file_id] + + # ==================== 接收文件 (需求 4.2) ==================== + + async def receive_file(self, file_id: str, save_path: str, + progress_callback: Optional[ProgressCallback] = None) -> bool: + """ + 接收文件并保存 + + 实现 receive_file() 接收文件 (需求 4.2) + + Args: + file_id: 文件ID + save_path: 保存路径 + progress_callback: 进度回调函数 + + Returns: + 接收成功返回True,否则返回False + """ + if file_id not in self._active_transfers: + logger.error(f"Unknown file transfer: {file_id}") + return False + + state = self._active_transfers[file_id] + state.save_path = save_path + state.status = TransferStatus.IN_PROGRESS + + if progress_callback: + self._progress_callbacks[file_id] = progress_callback + + # 等待所有块接收完成 + # 实际的块接收在 handle_file_chunk 中处理 + logger.info(f"Receiving file: {state.file_name} to {save_path}") + + return True + + def handle_file_request(self, message: Message) -> Optional[str]: + """ + 处理文件请求消息 + + Args: + message: 文件请求消息 + + Returns: + 文件ID,如果处理失败返回None + """ + try: + payload = json.loads(message.payload.decode('utf-8')) + + file_id = payload["file_id"] + file_name = payload["file_name"] + file_size = payload["file_size"] + file_hash = payload["file_hash"] + total_chunks = payload["total_chunks"] + + # 创建接收状态 + state = TransferState( + file_id=file_id, + file_path="", # 接收方不知道原始路径 + file_name=file_name, + file_size=file_size, + file_hash=file_hash, + total_chunks=total_chunks, + status=TransferStatus.PENDING, + sender_id=message.sender_id + ) + + self._active_transfers[file_id] = state + self._receive_buffers[file_id] = {} + + logger.info(f"File request received: {file_name} ({file_size} bytes)") + return file_id + + except Exception as e: + logger.error(f"Failed to handle file request: {e}") + return None + + def handle_file_chunk(self, message: Message) -> bool: + """ + 处理文件块消息 + + Args: + message: 文件块消息 + + Returns: + 处理成功返回True,否则返回False + """ + try: + payload = json.loads(message.payload.decode('utf-8')) + + file_id = payload["file_id"] + chunk_index = payload["chunk_index"] + total_chunks = payload["total_chunks"] + checksum = payload["checksum"] + data = bytes.fromhex(payload["data"]) + + # 验证块校验和 + if self.calculate_chunk_hash(data) != checksum: + logger.error(f"Chunk checksum mismatch: {file_id}[{chunk_index}]") + return False + + # 存储块数据 + if file_id not in self._receive_buffers: + self._receive_buffers[file_id] = {} + + self._receive_buffers[file_id][chunk_index] = data + + # 更新状态 + if file_id in self._active_transfers: + state = self._active_transfers[file_id] + if chunk_index not in state.completed_chunks: + state.completed_chunks.append(chunk_index) + state.last_update = datetime.now() + + # 通知进度 + if file_id in self._progress_callbacks: + progress = TransferProgress( + file_id=file_id, + file_name=state.file_name, + total_size=state.file_size, + transferred_size=len(state.completed_chunks) * self.CHUNK_SIZE, + speed=0, # 简化处理 + eta=0 + ) + self._notify_progress(file_id, progress) + + logger.debug(f"Received chunk {chunk_index + 1}/{total_chunks} for {file_id}") + return True + + except Exception as e: + logger.error(f"Failed to handle file chunk: {e}") + return False + + def handle_file_complete(self, message: Message) -> bool: + """ + 处理文件完成消息 + + 实现传输完成后的校验逻辑 (需求 4.4) + + Args: + message: 文件完成消息 + + Returns: + 处理成功返回True,否则返回False + """ + try: + payload = json.loads(message.payload.decode('utf-8')) + + file_id = payload["file_id"] + file_hash = payload["file_hash"] + + if file_id not in self._active_transfers: + logger.error(f"Unknown file transfer: {file_id}") + return False + + state = self._active_transfers[file_id] + + # 检查是否收到所有块 + if len(state.completed_chunks) != state.total_chunks: + logger.error(f"Missing chunks: received {len(state.completed_chunks)}, " + f"expected {state.total_chunks}") + state.status = TransferStatus.FAILED + return False + + # 组装文件 + if state.save_path: + success = self._assemble_file(file_id, state.save_path) + + if success: + # 验证文件完整性 + if self.verify_file_integrity(state.save_path, file_hash): + state.status = TransferStatus.COMPLETED + logger.info(f"File received and verified: {state.file_name}") + else: + state.status = TransferStatus.FAILED + logger.error(f"File integrity check failed: {state.file_name}") + return False + else: + state.status = TransferStatus.FAILED + return False + else: + # 保存路径未设置,保持数据在缓冲区 + state.status = TransferStatus.COMPLETED + + # 清理缓冲区 + if file_id in self._receive_buffers: + del self._receive_buffers[file_id] + + # 保存最终状态 + self._save_transfer_state(state) + + return True + + except Exception as e: + logger.error(f"Failed to handle file complete: {e}") + return False + + def _assemble_file(self, file_id: str, save_path: str) -> bool: + """ + 组装接收到的文件块 + + Args: + file_id: 文件ID + save_path: 保存路径 + + Returns: + 组装成功返回True,否则返回False + """ + if file_id not in self._receive_buffers: + return False + + if file_id not in self._active_transfers: + return False + + state = self._active_transfers[file_id] + buffer = self._receive_buffers[file_id] + + try: + # 确保目录存在 + os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True) + + # 按顺序写入所有块 + with open(save_path, 'wb') as f: + for i in range(state.total_chunks): + if i not in buffer: + logger.error(f"Missing chunk {i} for file {file_id}") + return False + f.write(buffer[i]) + + logger.info(f"File assembled: {save_path}") + return True + + except Exception as e: + logger.error(f"Failed to assemble file: {e}") + return False + + # ==================== 断点续传 (需求 4.5) ==================== + + def _save_transfer_state(self, state: TransferState) -> None: + """ + 保存传输状态到文件 + + 实现传输状态持久化 (需求 4.5) + + Args: + state: 传输状态 + """ + try: + state_file = self._state_dir / f"{state.file_id}.json" + with open(state_file, 'w', encoding='utf-8') as f: + json.dump(state.to_dict(), f, ensure_ascii=False, indent=2) + except Exception as e: + logger.error(f"Failed to save transfer state: {e}") + + def _load_transfer_states(self) -> None: + """ + 加载所有未完成的传输状态 + + 实现传输状态持久化 (需求 4.5) + """ + try: + for state_file in self._state_dir.glob("*.json"): + try: + with open(state_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + state = TransferState.from_dict(data) + + # 只加载未完成的传输 + if state.status in [TransferStatus.PENDING, + TransferStatus.IN_PROGRESS, + TransferStatus.PAUSED]: + self._active_transfers[state.file_id] = state + logger.info(f"Loaded transfer state: {state.file_name} " + f"({state.progress_percent:.1f}%)") + + except Exception as e: + logger.error(f"Failed to load state file {state_file}: {e}") + + except Exception as e: + logger.error(f"Failed to load transfer states: {e}") + + def _delete_transfer_state(self, file_id: str) -> None: + """ + 删除传输状态文件 + + Args: + file_id: 文件ID + """ + try: + state_file = self._state_dir / f"{file_id}.json" + if state_file.exists(): + state_file.unlink() + except Exception as e: + logger.error(f"Failed to delete transfer state: {e}") + + async def resume_transfer(self, file_id: str) -> bool: + """ + 恢复中断的传输 + + 实现 resume_transfer() 恢复传输 (需求 4.5) + 支持断点续传功能 (需求 4.5) + + Args: + file_id: 文件ID + + Returns: + 恢复成功返回True,否则返回False + """ + if file_id not in self._active_transfers: + logger.error(f"Transfer not found: {file_id}") + return False + + state = self._active_transfers[file_id] + + if state.status == TransferStatus.COMPLETED: + logger.info(f"Transfer already completed: {file_id}") + return True + + if state.status == TransferStatus.CANCELLED: + logger.error(f"Transfer was cancelled: {file_id}") + return False + + # 检查文件是否仍然存在(发送方) + if state.file_path and not os.path.exists(state.file_path): + logger.error(f"Source file no longer exists: {state.file_path}") + state.status = TransferStatus.FAILED + self._save_transfer_state(state) + return False + + if not self._send_message: + raise FileTransferError("Send message function not set") + + logger.info(f"Resuming transfer: {state.file_name} from chunk " + f"{len(state.completed_chunks)}/{state.total_chunks}") + + # 更新状态 + state.status = TransferStatus.IN_PROGRESS + state.last_update = datetime.now() + + # 从取消列表中移除 + self._cancelled_transfers.discard(file_id) + + start_time = time.time() + + try: + # 发送恢复请求(包含已完成的块列表) + resume_payload = json.dumps({ + "file_id": file_id, + "file_name": state.file_name, + "file_size": state.file_size, + "file_hash": state.file_hash, + "total_chunks": state.total_chunks, + "completed_chunks": state.completed_chunks, + "resume": True + }).encode('utf-8') + + resume_msg = Message( + msg_type=MessageType.FILE_REQUEST, + sender_id="", + receiver_id=state.receiver_id, + timestamp=time.time(), + payload=resume_payload + ) + + await self._send_message(state.receiver_id, resume_msg) + + # 只发送未完成的块 + completed_set = set(state.completed_chunks) + + for chunk_index in range(state.total_chunks): + if chunk_index in completed_set: + continue + + # 检查是否取消 + if file_id in self._cancelled_transfers: + raise TransferCancelledError("Transfer cancelled") + + # 读取并发送数据块 + chunk_data = self._read_chunk(state.file_path, chunk_index) + chunk_hash = self.calculate_chunk_hash(chunk_data) + + chunk_payload = json.dumps({ + "file_id": file_id, + "chunk_index": chunk_index, + "total_chunks": state.total_chunks, + "checksum": chunk_hash, + "data": chunk_data.hex() + }).encode('utf-8') + + chunk_msg = Message( + msg_type=MessageType.FILE_CHUNK, + sender_id="", + receiver_id=state.receiver_id, + timestamp=time.time(), + payload=chunk_payload + ) + + await self._send_message(state.receiver_id, chunk_msg) + + # 更新状态 + state.completed_chunks.append(chunk_index) + state.last_update = datetime.now() + + # 通知进度 + if file_id in self._progress_callbacks: + progress = self._calculate_progress(state, start_time) + self._notify_progress(file_id, progress) + + # 定期保存状态 + if chunk_index % 10 == 0: + self._save_transfer_state(state) + + await asyncio.sleep(0.001) + + # 发送完成消息 + complete_payload = json.dumps({ + "file_id": file_id, + "file_name": state.file_name, + "file_hash": state.file_hash, + "total_chunks": state.total_chunks + }).encode('utf-8') + + complete_msg = Message( + msg_type=MessageType.FILE_COMPLETE, + sender_id="", + receiver_id=state.receiver_id, + timestamp=time.time(), + payload=complete_payload + ) + + await self._send_message(state.receiver_id, complete_msg) + + state.status = TransferStatus.COMPLETED + self._save_transfer_state(state) + + logger.info(f"Transfer resumed and completed: {state.file_name}") + return True + + except TransferCancelledError: + state.status = TransferStatus.CANCELLED + self._save_transfer_state(state) + return False + + except Exception as e: + state.status = TransferStatus.FAILED + self._save_transfer_state(state) + logger.error(f"Resume transfer failed: {e}") + return False + + def cancel_transfer(self, file_id: str) -> None: + """ + 取消传输 + + 实现 cancel_transfer() 取消传输 (需求 4.5) + + Args: + file_id: 文件ID + """ + self._cancelled_transfers.add(file_id) + + if file_id in self._active_transfers: + state = self._active_transfers[file_id] + state.status = TransferStatus.CANCELLED + self._save_transfer_state(state) + logger.info(f"Transfer cancelled: {state.file_name}") + + # 清理缓冲区 + if file_id in self._receive_buffers: + del self._receive_buffers[file_id] + + # 清理进度回调 + if file_id in self._progress_callbacks: + del self._progress_callbacks[file_id] + + def pause_transfer(self, file_id: str) -> bool: + """ + 暂停传输 + + Args: + file_id: 文件ID + + Returns: + 暂停成功返回True,否则返回False + """ + if file_id not in self._active_transfers: + return False + + state = self._active_transfers[file_id] + + if state.status != TransferStatus.IN_PROGRESS: + return False + + state.status = TransferStatus.PAUSED + self._save_transfer_state(state) + + logger.info(f"Transfer paused: {state.file_name}") + return True + + # ==================== 传输状态查询 ==================== + + def get_transfer_progress(self, file_id: str) -> Optional[TransferProgress]: + """ + 获取传输进度 + + Args: + file_id: 文件ID + + Returns: + 传输进度信息,如果传输不存在返回None + """ + if file_id not in self._active_transfers: + return None + + state = self._active_transfers[file_id] + + return TransferProgress( + file_id=file_id, + file_name=state.file_name, + total_size=state.file_size, + transferred_size=state.transferred_size, + speed=0, # 需要实时计算 + eta=0 + ) + + def get_transfer_state(self, file_id: str) -> Optional[TransferState]: + """ + 获取传输状态 + + Args: + file_id: 文件ID + + Returns: + 传输状态,如果不存在返回None + """ + return self._active_transfers.get(file_id) + + def get_all_transfers(self) -> List[TransferState]: + """ + 获取所有传输状态 + + Returns: + 传输状态列表 + """ + return list(self._active_transfers.values()) + + def get_pending_transfers(self) -> List[TransferState]: + """ + 获取所有待恢复的传输 + + Returns: + 待恢复的传输状态列表 + """ + return [ + state for state in self._active_transfers.values() + if state.status in [TransferStatus.PENDING, + TransferStatus.IN_PROGRESS, + TransferStatus.PAUSED] + ] + + def get_transfer_record(self, file_id: str) -> Optional[FileTransferRecord]: + """ + 获取传输记录 + + Args: + file_id: 文件ID + + Returns: + 传输记录,如果不存在返回None + """ + if file_id not in self._active_transfers: + return None + + state = self._active_transfers[file_id] + + return FileTransferRecord( + transfer_id=state.file_id, + file_name=state.file_name, + file_size=state.file_size, + file_hash=state.file_hash, + sender_id=state.sender_id, + receiver_id=state.receiver_id, + status=state.status, + progress=state.progress_percent, + start_time=state.start_time, + end_time=datetime.now() if state.status == TransferStatus.COMPLETED else None + ) + + # ==================== 清理 ==================== + + def cleanup_completed_transfers(self, max_age_hours: int = 24) -> int: + """ + 清理已完成的传输记录 + + Args: + max_age_hours: 最大保留时间(小时) + + Returns: + 清理的记录数 + """ + cleaned = 0 + cutoff = datetime.now() + + for file_id in list(self._active_transfers.keys()): + state = self._active_transfers[file_id] + + if state.status in [TransferStatus.COMPLETED, + TransferStatus.CANCELLED, + TransferStatus.FAILED]: + age_hours = (cutoff - state.last_update).total_seconds() / 3600 + + if age_hours > max_age_hours: + del self._active_transfers[file_id] + self._delete_transfer_state(file_id) + cleaned += 1 + + if cleaned > 0: + logger.info(f"Cleaned {cleaned} old transfer records") + + return cleaned + + def clear_all_transfers(self) -> None: + """ + 清除所有传输记录 + """ + for file_id in list(self._active_transfers.keys()): + self._delete_transfer_state(file_id) + + self._active_transfers.clear() + self._receive_buffers.clear() + self._progress_callbacks.clear() + self._cancelled_transfers.clear() + + logger.info("All transfer records cleared") + + # ==================== 保存接收的文件 ==================== + + def save_received_file(self, file_id: str, save_path: str) -> bool: + """ + 保存接收到的文件 + + Args: + file_id: 文件ID + save_path: 保存路径 + + Returns: + 保存成功返回True,否则返回False + """ + if file_id not in self._active_transfers: + logger.error(f"Transfer not found: {file_id}") + return False + + state = self._active_transfers[file_id] + + if state.status != TransferStatus.COMPLETED: + logger.error(f"Transfer not completed: {file_id}") + return False + + # 如果数据还在缓冲区,组装文件 + if file_id in self._receive_buffers: + success = self._assemble_file(file_id, save_path) + + if success: + # 验证完整性 + if self.verify_file_integrity(save_path, state.file_hash): + state.save_path = save_path + self._save_transfer_state(state) + + # 清理缓冲区 + del self._receive_buffers[file_id] + + logger.info(f"File saved: {save_path}") + return True + else: + logger.error(f"File integrity check failed after save") + return False + else: + return False + + logger.error(f"No data in buffer for file: {file_id}") + return False diff --git a/tests/test_connection_manager.py b/tests/test_connection_manager.py new file mode 100644 index 0000000..678ddd7 --- /dev/null +++ b/tests/test_connection_manager.py @@ -0,0 +1,331 @@ +# P2P Network Communication - Connection Manager Tests +""" +测试客户端连接管理器的基本功能 + +需求: 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 +""" + +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch +from datetime import datetime + +from client.connection_manager import ( + ConnectionManager, + ConnectionState, + ConnectionError, + Connection, +) +from shared.models import ( + Message, MessageType, UserInfo, UserStatus, + ConnectionMode, PeerInfo +) +from config import ClientConfig + + +class TestConnectionManagerInit: + """测试连接管理器初始化""" + + def test_init_with_default_config(self): + """测试使用默认配置初始化""" + cm = ConnectionManager() + + assert cm.config is not None + assert cm.state == ConnectionState.DISCONNECTED + assert cm.is_connected is False + assert cm.user_id == "" + + def test_init_with_custom_config(self): + """测试使用自定义配置初始化""" + config = ClientConfig( + server_host="192.168.1.100", + server_port=9999, + heartbeat_interval=60 + ) + cm = ConnectionManager(config) + + assert cm.config.server_host == "192.168.1.100" + assert cm.config.server_port == 9999 + assert cm.config.heartbeat_interval == 60 + + +class TestConnectionState: + """测试连接状态管理""" + + def test_initial_state_is_disconnected(self): + """测试初始状态为断开""" + cm = ConnectionManager() + assert cm.state == ConnectionState.DISCONNECTED + + def test_state_callback_registration(self): + """测试状态回调注册""" + cm = ConnectionManager() + callback_called = [] + + def state_callback(state, reason): + callback_called.append((state, reason)) + + cm.add_state_callback(state_callback) + cm._set_state(ConnectionState.CONNECTING, "Test") + + assert len(callback_called) == 1 + assert callback_called[0][0] == ConnectionState.CONNECTING + assert callback_called[0][1] == "Test" + + def test_state_callback_removal(self): + """测试状态回调移除""" + cm = ConnectionManager() + callback_called = [] + + def state_callback(state, reason): + callback_called.append((state, reason)) + + cm.add_state_callback(state_callback) + cm.remove_state_callback(state_callback) + cm._set_state(ConnectionState.CONNECTING, "Test") + + assert len(callback_called) == 0 + + +class TestConnectionMode: + """测试连接模式选择 (需求 1.1, 1.2, 1.3)""" + + def test_default_mode_is_relay(self): + """测试默认模式为中转""" + cm = ConnectionManager() + mode = cm.get_connection_mode("unknown_peer") + + assert mode == ConnectionMode.RELAY + + def test_mode_is_p2p_for_discovered_peer(self): + """测试已发现的LAN对等端使用P2P模式""" + cm = ConnectionManager() + + # 模拟发现对等端 + peer_info = PeerInfo( + peer_id="peer1", + username="test_peer", + ip_address="192.168.1.100", + port=8889 + ) + cm._discovered_peers["peer1"] = peer_info + + mode = cm.get_connection_mode("peer1") + assert mode == ConnectionMode.P2P + + def test_mode_is_p2p_for_connected_peer(self): + """测试已连接的P2P对等端返回P2P模式""" + cm = ConnectionManager() + + # 模拟P2P连接 + conn = Connection( + peer_id="peer1", + reader=MagicMock(), + writer=MagicMock(), + mode=ConnectionMode.P2P, + ip_address="192.168.1.100", + port=8889 + ) + cm._peer_connections["peer1"] = conn + + mode = cm.get_connection_mode("peer1") + assert mode == ConnectionMode.P2P + + +class TestReconnectMechanism: + """测试网络重连机制 (需求 1.6)""" + + def test_reconnect_status_initial(self): + """测试初始重连状态""" + cm = ConnectionManager() + status = cm.get_reconnect_status() + + assert status["is_reconnecting"] is False + assert status["reconnect_attempts"] == 0 + assert status["auto_reconnect_enabled"] is True + + def test_enable_disable_reconnect(self): + """测试启用/禁用自动重连""" + cm = ConnectionManager() + + cm.enable_reconnect(False) + assert cm._should_reconnect is False + + cm.enable_reconnect(True) + assert cm._should_reconnect is True + + def test_set_reconnect_config(self): + """测试配置重连参数""" + cm = ConnectionManager() + + cm.set_reconnect_config(max_attempts=5, base_delay=2.0) + + assert cm.config.reconnect_attempts == 5 + assert cm.config.reconnect_delay == 2.0 + + +class TestMessageCallbacks: + """测试消息回调""" + + def test_message_callback_registration(self): + """测试消息回调注册""" + cm = ConnectionManager() + messages_received = [] + + def msg_callback(msg): + messages_received.append(msg) + + cm.add_message_callback(msg_callback) + + # 模拟接收消息 + test_msg = Message( + msg_type=MessageType.TEXT, + sender_id="user1", + receiver_id="user2", + timestamp=1234567890.0, + payload=b"Test" + ) + + # 直接调用回调 + for callback in cm._message_callbacks: + callback(test_msg) + + assert len(messages_received) == 1 + assert messages_received[0].payload == b"Test" + + def test_message_callback_removal(self): + """测试消息回调移除""" + cm = ConnectionManager() + messages_received = [] + + def msg_callback(msg): + messages_received.append(msg) + + cm.add_message_callback(msg_callback) + cm.remove_message_callback(msg_callback) + + assert len(cm._message_callbacks) == 0 + + +class TestConnectionStats: + """测试连接统计""" + + def test_connection_stats_initial(self): + """测试初始连接统计""" + cm = ConnectionManager() + stats = cm.get_connection_stats() + + assert stats["server_connected"] is False + assert stats["state"] == "disconnected" + assert stats["total_peer_connections"] == 0 + assert stats["p2p_connections"] == 0 + assert stats["discovered_peers"] == 0 + + def test_connection_stats_with_peers(self): + """测试有对等端时的连接统计""" + cm = ConnectionManager() + + # 添加发现的对等端 + cm._discovered_peers["peer1"] = PeerInfo( + peer_id="peer1", + username="test", + ip_address="192.168.1.100", + port=8889 + ) + + # 添加P2P连接 + cm._peer_connections["peer2"] = Connection( + peer_id="peer2", + reader=MagicMock(), + writer=MagicMock(), + mode=ConnectionMode.P2P, + ip_address="192.168.1.101", + port=8889 + ) + + stats = cm.get_connection_stats() + + assert stats["total_peer_connections"] == 1 + assert stats["p2p_connections"] == 1 + assert stats["discovered_peers"] == 1 + + +class TestPeerConnectionInfo: + """测试对等端连接信息""" + + def test_get_peer_info_for_connected_peer(self): + """测试获取已连接对等端信息""" + cm = ConnectionManager() + + conn = Connection( + peer_id="peer1", + reader=MagicMock(), + writer=MagicMock(), + mode=ConnectionMode.P2P, + ip_address="192.168.1.100", + port=8889 + ) + cm._peer_connections["peer1"] = conn + + info = cm.get_peer_connection_info("peer1") + + assert info is not None + assert info["peer_id"] == "peer1" + assert info["mode"] == "p2p" + assert info["ip_address"] == "192.168.1.100" + + def test_get_peer_info_for_discovered_peer(self): + """测试获取已发现对等端信息""" + cm = ConnectionManager() + + cm._discovered_peers["peer1"] = PeerInfo( + peer_id="peer1", + username="test", + ip_address="192.168.1.100", + port=8889 + ) + + info = cm.get_peer_connection_info("peer1") + + assert info is not None + assert info["peer_id"] == "peer1" + assert info["mode"] == "discovered" + + def test_get_peer_info_for_unknown_peer(self): + """测试获取未知对等端信息""" + cm = ConnectionManager() + + info = cm.get_peer_connection_info("unknown") + + assert info is None + + +class TestAllConnectionModes: + """测试获取所有连接模式""" + + def test_get_all_connection_modes(self): + """测试获取所有对等端的连接模式""" + cm = ConnectionManager() + + # 添加P2P连接 + cm._peer_connections["peer1"] = Connection( + peer_id="peer1", + reader=MagicMock(), + writer=MagicMock(), + mode=ConnectionMode.P2P, + ip_address="192.168.1.100", + port=8889 + ) + + # 添加发现的对等端 + cm._discovered_peers["peer2"] = PeerInfo( + peer_id="peer2", + username="test", + ip_address="192.168.1.101", + port=8889 + ) + + modes = cm.get_all_connection_modes() + + assert modes["peer1"] == ConnectionMode.P2P + assert modes["peer2"] == ConnectionMode.P2P diff --git a/tests/test_file_transfer.py b/tests/test_file_transfer.py new file mode 100644 index 0000000..e6c6e1a --- /dev/null +++ b/tests/test_file_transfer.py @@ -0,0 +1,580 @@ +# P2P Network Communication - File Transfer Module Tests +""" +文件传输模块测试 +测试文件分块、传输、断点续传和完整性校验功能 +""" + +import asyncio +import os +import tempfile +import pytest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +from client.file_transfer import ( + FileTransferModule, + FileTransferError, + FileNotFoundError, + FileIntegrityError, + TransferCancelledError, + TransferState, +) +from shared.models import ( + Message, MessageType, FileChunk, TransferProgress, + TransferStatus +) +from config import ClientConfig + + +class TestFileTransferModuleInit: + """文件传输模块初始化测试""" + + def test_init_with_default_config(self): + """测试使用默认配置初始化""" + module = FileTransferModule() + + assert module.config is not None + assert module.CHUNK_SIZE == 64 * 1024 + assert len(module._active_transfers) == 0 + + def test_init_with_custom_config(self): + """测试使用自定义配置初始化""" + config = ClientConfig(chunk_size=32 * 1024) + module = FileTransferModule(config=config) + + assert module.config.chunk_size == 32 * 1024 + + def test_set_send_message_func(self): + """测试设置发送消息函数""" + module = FileTransferModule() + mock_func = AsyncMock() + + module.set_send_message_func(mock_func) + + assert module._send_message == mock_func + + +class TestFileHashCalculation: + """文件哈希计算测试""" + + def test_calculate_file_hash_sha256(self): + """测试SHA256哈希计算""" + module = FileTransferModule() + + # 创建临时文件 + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(b"Hello, World!") + temp_path = f.name + + try: + hash_value = module.calculate_file_hash(temp_path, "sha256") + + # SHA256 of "Hello, World!" + expected = "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f" + assert hash_value == expected + finally: + os.unlink(temp_path) + + def test_calculate_file_hash_md5(self): + """测试MD5哈希计算""" + module = FileTransferModule() + + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(b"Hello, World!") + temp_path = f.name + + try: + hash_value = module.calculate_file_hash(temp_path, "md5") + + # MD5 of "Hello, World!" + expected = "65a8e27d8879283831b664bd8b7f0ad4" + assert hash_value == expected + finally: + os.unlink(temp_path) + + def test_calculate_file_hash_nonexistent_file(self): + """测试计算不存在文件的哈希""" + module = FileTransferModule() + + with pytest.raises(FileNotFoundError): + module.calculate_file_hash("/nonexistent/file.txt") + + def test_calculate_chunk_hash(self): + """测试数据块哈希计算""" + module = FileTransferModule() + + data = b"Test chunk data" + hash_value = module.calculate_chunk_hash(data) + + import hashlib + expected = hashlib.md5(data).hexdigest() + assert hash_value == expected + + def test_verify_file_integrity_success(self): + """测试文件完整性验证成功""" + module = FileTransferModule() + + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(b"Test content") + temp_path = f.name + + try: + expected_hash = module.calculate_file_hash(temp_path) + result = module.verify_file_integrity(temp_path, expected_hash) + + assert result is True + finally: + os.unlink(temp_path) + + def test_verify_file_integrity_failure(self): + """测试文件完整性验证失败""" + module = FileTransferModule() + + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(b"Test content") + temp_path = f.name + + try: + result = module.verify_file_integrity(temp_path, "wrong_hash") + + assert result is False + finally: + os.unlink(temp_path) + + +class TestFileChunking: + """文件分块测试""" + + def test_get_total_chunks_small_file(self): + """测试小文件的块数计算""" + module = FileTransferModule() + + # 小于一个块 + assert module._get_total_chunks(1000) == 1 + + # 正好一个块 + assert module._get_total_chunks(64 * 1024) == 1 + + def test_get_total_chunks_large_file(self): + """测试大文件的块数计算""" + module = FileTransferModule() + + # 两个块 + assert module._get_total_chunks(64 * 1024 + 1) == 2 + + # 多个块 + assert module._get_total_chunks(256 * 1024) == 4 + + def test_read_chunk(self): + """测试读取文件块""" + module = FileTransferModule() + + # 创建测试文件 + test_data = b"A" * 100 + b"B" * 100 + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(test_data) + temp_path = f.name + + try: + chunk = module._read_chunk(temp_path, 0) + assert chunk == test_data + finally: + os.unlink(temp_path) + + def test_write_chunk(self): + """测试写入文件块""" + module = FileTransferModule() + + with tempfile.NamedTemporaryFile(delete=False) as f: + temp_path = f.name + + try: + # 写入第一个块 + module._write_chunk(temp_path, 0, b"First chunk") + + with open(temp_path, 'rb') as f: + content = f.read() + + assert content == b"First chunk" + finally: + os.unlink(temp_path) + + def test_split_file_to_chunks(self): + """测试文件分块""" + module = FileTransferModule() + + # 创建测试文件(大于一个块) + test_data = b"X" * (64 * 1024 + 100) + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(test_data) + temp_path = f.name + + try: + chunks = module.split_file_to_chunks(temp_path) + + assert len(chunks) == 2 + assert chunks[0].chunk_index == 0 + assert chunks[1].chunk_index == 1 + assert chunks[0].total_chunks == 2 + assert len(chunks[0].data) == 64 * 1024 + assert len(chunks[1].data) == 100 + finally: + os.unlink(temp_path) + + def test_split_nonexistent_file(self): + """测试分块不存在的文件""" + module = FileTransferModule() + + with pytest.raises(FileNotFoundError): + module.split_file_to_chunks("/nonexistent/file.txt") + + + +class TestTransferState: + """传输状态测试""" + + def test_transfer_state_creation(self): + """测试传输状态创建""" + state = TransferState( + file_id="test-id", + file_path="/path/to/file.txt", + file_name="file.txt", + file_size=1000, + file_hash="abc123", + total_chunks=2 + ) + + assert state.file_id == "test-id" + assert state.file_name == "file.txt" + assert state.status == TransferStatus.PENDING + assert len(state.completed_chunks) == 0 + + def test_transfer_state_progress_percent(self): + """测试进度百分比计算""" + state = TransferState( + file_id="test-id", + file_path="/path/to/file.txt", + file_name="file.txt", + file_size=1000, + file_hash="abc123", + total_chunks=4, + completed_chunks=[0, 1] + ) + + assert state.progress_percent == 50.0 + + def test_transfer_state_to_dict_and_from_dict(self): + """测试状态序列化和反序列化""" + state = TransferState( + file_id="test-id", + file_path="/path/to/file.txt", + file_name="file.txt", + file_size=1000, + file_hash="abc123", + total_chunks=2, + completed_chunks=[0], + status=TransferStatus.IN_PROGRESS + ) + + state_dict = state.to_dict() + restored = TransferState.from_dict(state_dict) + + assert restored.file_id == state.file_id + assert restored.file_name == state.file_name + assert restored.status == state.status + assert restored.completed_chunks == state.completed_chunks + + +class TestTransferManagement: + """传输管理测试""" + + def setup_method(self): + """每个测试前清理状态""" + # 清理可能存在的状态文件 + import shutil + state_dir = Path("data/transfer_states") + if state_dir.exists(): + shutil.rmtree(state_dir) + + def test_cancel_transfer(self): + """测试取消传输""" + module = FileTransferModule() + + # 创建一个传输状态 + state = TransferState( + file_id="test-id", + file_path="/path/to/file.txt", + file_name="file.txt", + file_size=1000, + file_hash="abc123", + total_chunks=2 + ) + module._active_transfers["test-id"] = state + + module.cancel_transfer("test-id") + + assert "test-id" in module._cancelled_transfers + assert state.status == TransferStatus.CANCELLED + + def test_pause_transfer(self): + """测试暂停传输""" + module = FileTransferModule() + + state = TransferState( + file_id="test-id", + file_path="/path/to/file.txt", + file_name="file.txt", + file_size=1000, + file_hash="abc123", + total_chunks=2, + status=TransferStatus.IN_PROGRESS + ) + module._active_transfers["test-id"] = state + + result = module.pause_transfer("test-id") + + assert result is True + assert state.status == TransferStatus.PAUSED + + def test_pause_transfer_not_in_progress(self): + """测试暂停非进行中的传输""" + module = FileTransferModule() + + state = TransferState( + file_id="test-id", + file_path="/path/to/file.txt", + file_name="file.txt", + file_size=1000, + file_hash="abc123", + total_chunks=2, + status=TransferStatus.COMPLETED + ) + module._active_transfers["test-id"] = state + + result = module.pause_transfer("test-id") + + assert result is False + + def test_get_transfer_progress(self): + """测试获取传输进度""" + module = FileTransferModule() + + state = TransferState( + file_id="test-id", + file_path="/path/to/file.txt", + file_name="file.txt", + file_size=128 * 1024, # 2 chunks + file_hash="abc123", + total_chunks=2, + completed_chunks=[0] + ) + module._active_transfers["test-id"] = state + + progress = module.get_transfer_progress("test-id") + + assert progress is not None + assert progress.file_id == "test-id" + assert progress.file_name == "file.txt" + assert progress.total_size == 128 * 1024 + + def test_get_transfer_progress_nonexistent(self): + """测试获取不存在的传输进度""" + module = FileTransferModule() + + progress = module.get_transfer_progress("nonexistent") + + assert progress is None + + def test_get_all_transfers(self): + """测试获取所有传输""" + module = FileTransferModule() + + state1 = TransferState( + file_id="id1", + file_path="/path/1.txt", + file_name="1.txt", + file_size=1000, + file_hash="hash1", + total_chunks=1 + ) + state2 = TransferState( + file_id="id2", + file_path="/path/2.txt", + file_name="2.txt", + file_size=2000, + file_hash="hash2", + total_chunks=1 + ) + + module._active_transfers["id1"] = state1 + module._active_transfers["id2"] = state2 + + transfers = module.get_all_transfers() + + assert len(transfers) == 2 + + def test_get_pending_transfers(self): + """测试获取待恢复的传输""" + module = FileTransferModule() + + state1 = TransferState( + file_id="id1", + file_path="/path/1.txt", + file_name="1.txt", + file_size=1000, + file_hash="hash1", + total_chunks=1, + status=TransferStatus.PAUSED + ) + state2 = TransferState( + file_id="id2", + file_path="/path/2.txt", + file_name="2.txt", + file_size=2000, + file_hash="hash2", + total_chunks=1, + status=TransferStatus.COMPLETED + ) + + module._active_transfers["id1"] = state1 + module._active_transfers["id2"] = state2 + + pending = module.get_pending_transfers() + + assert len(pending) == 1 + assert pending[0].file_id == "id1" + + +class TestMessageHandling: + """消息处理测试""" + + def test_handle_file_request(self): + """测试处理文件请求消息""" + module = FileTransferModule() + + import json + payload = json.dumps({ + "file_id": "test-file-id", + "file_name": "test.txt", + "file_size": 1000, + "file_hash": "abc123", + "total_chunks": 2 + }).encode('utf-8') + + message = Message( + msg_type=MessageType.FILE_REQUEST, + sender_id="sender", + receiver_id="receiver", + timestamp=1234567890.0, + payload=payload + ) + + file_id = module.handle_file_request(message) + + assert file_id == "test-file-id" + assert "test-file-id" in module._active_transfers + assert "test-file-id" in module._receive_buffers + + def test_handle_file_chunk(self): + """测试处理文件块消息""" + module = FileTransferModule() + + # 先创建传输状态 + state = TransferState( + file_id="test-file-id", + file_path="", + file_name="test.txt", + file_size=1000, + file_hash="abc123", + total_chunks=2 + ) + module._active_transfers["test-file-id"] = state + module._receive_buffers["test-file-id"] = {} + + import json + chunk_data = b"Test chunk data" + payload = json.dumps({ + "file_id": "test-file-id", + "chunk_index": 0, + "total_chunks": 2, + "checksum": module.calculate_chunk_hash(chunk_data), + "data": chunk_data.hex() + }).encode('utf-8') + + message = Message( + msg_type=MessageType.FILE_CHUNK, + sender_id="sender", + receiver_id="receiver", + timestamp=1234567890.0, + payload=payload + ) + + result = module.handle_file_chunk(message) + + assert result is True + assert 0 in module._receive_buffers["test-file-id"] + assert module._receive_buffers["test-file-id"][0] == chunk_data + assert 0 in state.completed_chunks + + def test_handle_file_chunk_checksum_mismatch(self): + """测试处理校验和不匹配的文件块""" + module = FileTransferModule() + + state = TransferState( + file_id="test-file-id", + file_path="", + file_name="test.txt", + file_size=1000, + file_hash="abc123", + total_chunks=2 + ) + module._active_transfers["test-file-id"] = state + module._receive_buffers["test-file-id"] = {} + + import json + chunk_data = b"Test chunk data" + payload = json.dumps({ + "file_id": "test-file-id", + "chunk_index": 0, + "total_chunks": 2, + "checksum": "wrong_checksum", + "data": chunk_data.hex() + }).encode('utf-8') + + message = Message( + msg_type=MessageType.FILE_CHUNK, + sender_id="sender", + receiver_id="receiver", + timestamp=1234567890.0, + payload=payload + ) + + result = module.handle_file_chunk(message) + + assert result is False + + +class TestCleanup: + """清理功能测试""" + + def test_clear_all_transfers(self): + """测试清除所有传输""" + module = FileTransferModule() + + state = TransferState( + file_id="test-id", + file_path="/path/to/file.txt", + file_name="file.txt", + file_size=1000, + file_hash="abc123", + total_chunks=2 + ) + module._active_transfers["test-id"] = state + module._receive_buffers["test-id"] = {} + module._cancelled_transfers.add("test-id") + + module.clear_all_transfers() + + assert len(module._active_transfers) == 0 + assert len(module._receive_buffers) == 0 + assert len(module._cancelled_transfers) == 0