diff --git a/server/__init__.py b/server/__init__.py index 3ff8bde..45919ca 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -5,3 +5,11 @@ """ __version__ = "0.1.0" + +from server.relay_server import ( + RelayServer, + ClientConnection, + ServerError, + ConnectionError, + UserRegistrationError, +) diff --git a/server/database.py b/server/database.py new file mode 100644 index 0000000..b840083 --- /dev/null +++ b/server/database.py @@ -0,0 +1,285 @@ +# P2P Network Communication - Database Module +""" +数据库模块,提供MySQL数据库连接池管理和表初始化功能 +使用aiomysql异步驱动实现高性能数据库操作 +""" + +import asyncio +import logging +from contextlib import asynccontextmanager +from typing import Optional, AsyncGenerator + +import aiomysql + +from config import ServerConfig, load_server_config + +logger = logging.getLogger(__name__) + + +class DatabaseManager: + """数据库管理器,负责连接池管理和数据库初始化""" + + def __init__(self, config: Optional[ServerConfig] = None): + """ + 初始化数据库管理器 + + Args: + config: 服务器配置,如果为None则从环境变量加载 + """ + self.config = config or load_server_config() + self._pool: Optional[aiomysql.Pool] = None + self._initialized = False + + async def initialize(self) -> None: + """初始化数据库连接池和表结构""" + if self._initialized: + return + + await self._create_pool() + await self._create_tables() + self._initialized = True + logger.info("Database initialized successfully") + + async def _create_pool(self) -> None: + """创建数据库连接池""" + try: + self._pool = await aiomysql.create_pool( + host=self.config.db_host, + port=self.config.db_port, + user=self.config.db_user, + password=self.config.db_password, + db=self.config.db_name, + minsize=1, + maxsize=self.config.db_pool_size, + charset='utf8mb4', + autocommit=True, + echo=False, + ) + logger.info(f"Database connection pool created: {self.config.db_host}:{self.config.db_port}") + except Exception as e: + logger.error(f"Failed to create database pool: {e}") + raise + + async def _create_tables(self) -> None: + """创建数据库表结构""" + create_users_table = """ + CREATE TABLE IF NOT EXISTS users ( + user_id VARCHAR(64) PRIMARY KEY, + username VARCHAR(100) UNIQUE NOT NULL, + display_name VARCHAR(200) NOT NULL, + password_hash VARCHAR(255) NOT NULL DEFAULT '', + public_key BLOB, + status VARCHAR(20) DEFAULT 'offline', + ip_address VARCHAR(45) DEFAULT '', + port INT DEFAULT 0, + last_seen TIMESTAMP NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + INDEX idx_username (username), + INDEX idx_status (status) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + """ + + create_messages_table = """ + CREATE TABLE IF NOT EXISTS messages ( + message_id VARCHAR(64) PRIMARY KEY, + sender_id VARCHAR(64) NOT NULL, + receiver_id VARCHAR(64) NOT NULL, + content_type VARCHAR(50) NOT NULL, + content TEXT, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + is_read BOOLEAN DEFAULT FALSE, + is_delivered BOOLEAN DEFAULT FALSE, + INDEX idx_sender (sender_id), + INDEX idx_receiver (receiver_id), + INDEX idx_timestamp (timestamp), + INDEX idx_conversation (sender_id, receiver_id, timestamp) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + """ + + create_file_transfers_table = """ + CREATE TABLE IF NOT EXISTS file_transfers ( + transfer_id VARCHAR(64) PRIMARY KEY, + file_name VARCHAR(500) NOT NULL, + file_size BIGINT NOT NULL, + file_hash VARCHAR(128) NOT NULL, + sender_id VARCHAR(64) NOT NULL, + receiver_id VARCHAR(64) NOT NULL, + status VARCHAR(20) NOT NULL DEFAULT 'pending', + progress DECIMAL(5,2) DEFAULT 0, + start_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + end_time TIMESTAMP NULL, + INDEX idx_sender (sender_id), + INDEX idx_receiver (receiver_id), + INDEX idx_status (status) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + """ + + create_offline_messages_table = """ + CREATE TABLE IF NOT EXISTS offline_messages ( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + user_id VARCHAR(64) NOT NULL, + message_data BLOB NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + INDEX idx_user_id (user_id), + INDEX idx_created_at (created_at) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + """ + + tables = [ + ("users", create_users_table), + ("messages", create_messages_table), + ("file_transfers", create_file_transfers_table), + ("offline_messages", create_offline_messages_table), + ] + + async with self.acquire() as conn: + async with conn.cursor() as cursor: + for table_name, create_sql in tables: + try: + await cursor.execute(create_sql) + logger.info(f"Table '{table_name}' created or already exists") + except Exception as e: + logger.error(f"Failed to create table '{table_name}': {e}") + raise + + @asynccontextmanager + async def acquire(self) -> AsyncGenerator[aiomysql.Connection, None]: + """ + 获取数据库连接的上下文管理器 + + Yields: + 数据库连接对象 + """ + if not self._pool: + raise RuntimeError("Database pool not initialized. Call initialize() first.") + + conn = await self._pool.acquire() + try: + yield conn + finally: + self._pool.release(conn) + + async def execute(self, query: str, args: tuple = ()) -> int: + """ + 执行SQL语句(INSERT, UPDATE, DELETE) + + Args: + query: SQL语句 + args: 参数元组 + + Returns: + 受影响的行数 + """ + async with self.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute(query, args) + return cursor.rowcount + + async def fetch_one(self, query: str, args: tuple = ()) -> Optional[tuple]: + """ + 查询单条记录 + + Args: + query: SQL语句 + args: 参数元组 + + Returns: + 查询结果元组,如果没有结果返回None + """ + async with self.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute(query, args) + return await cursor.fetchone() + + async def fetch_all(self, query: str, args: tuple = ()) -> list: + """ + 查询多条记录 + + Args: + query: SQL语句 + args: 参数元组 + + Returns: + 查询结果列表 + """ + async with self.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute(query, args) + return await cursor.fetchall() + + async def fetch_dict(self, query: str, args: tuple = ()) -> Optional[dict]: + """ + 查询单条记录并返回字典 + + Args: + query: SQL语句 + args: 参数元组 + + Returns: + 查询结果字典,如果没有结果返回None + """ + async with self.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cursor: + await cursor.execute(query, args) + return await cursor.fetchone() + + async def fetch_all_dict(self, query: str, args: tuple = ()) -> list: + """ + 查询多条记录并返回字典列表 + + Args: + query: SQL语句 + args: 参数元组 + + Returns: + 查询结果字典列表 + """ + async with self.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cursor: + await cursor.execute(query, args) + return await cursor.fetchall() + + async def close(self) -> None: + """关闭数据库连接池""" + if self._pool: + self._pool.close() + await self._pool.wait_closed() + self._pool = None + self._initialized = False + logger.info("Database connection pool closed") + + @property + def is_initialized(self) -> bool: + """检查数据库是否已初始化""" + return self._initialized + + @property + def pool(self) -> Optional[aiomysql.Pool]: + """获取连接池对象""" + return self._pool + + +# 全局数据库管理器实例 +_db_manager: Optional[DatabaseManager] = None + + +async def get_database() -> DatabaseManager: + """ + 获取全局数据库管理器实例 + + Returns: + DatabaseManager实例 + """ + global _db_manager + if _db_manager is None: + _db_manager = DatabaseManager() + await _db_manager.initialize() + return _db_manager + + +async def close_database() -> None: + """关闭全局数据库连接""" + global _db_manager + if _db_manager: + await _db_manager.close() + _db_manager = None diff --git a/server/relay_server.py b/server/relay_server.py new file mode 100644 index 0000000..a4a281d --- /dev/null +++ b/server/relay_server.py @@ -0,0 +1,851 @@ +# P2P Network Communication - Relay Server +""" +中转服务器模块 +负责管理客户端连接、用户注册、消息转发和离线消息缓存 + +需求: 8.1, 8.2, 8.4, 8.5, 1.7, 2.2, 2.3, 2.5 +""" + +import asyncio +import logging +import time +import json +from asyncio import StreamReader, StreamWriter +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, List, Optional, Tuple, Set +from collections import deque + +from shared.models import ( + Message, MessageType, UserInfo, UserStatus +) +from shared.message_handler import ( + MessageHandler, MessageValidationError, MessageSerializationError +) +from config import ServerConfig + + +# 设置日志 +logger = logging.getLogger(__name__) + + +class ServerError(Exception): + """服务器错误基类""" + pass + + +class ConnectionError(ServerError): + """连接错误""" + pass + + +class UserRegistrationError(ServerError): + """用户注册错误""" + pass + + +@dataclass +class ClientConnection: + """客户端连接信息""" + user_id: str + reader: StreamReader + writer: StreamWriter + connected_at: datetime = field(default_factory=datetime.now) + last_heartbeat: float = field(default_factory=time.time) + ip_address: str = "" + port: int = 0 + + @property + def is_alive(self) -> bool: + """检查连接是否存活(基于心跳)""" + return time.time() - self.last_heartbeat < 60 # 60秒超时 + + +class RelayServer: + """ + 中转服务器 + + 负责: + - 监听指定端口等待客户端连接 (需求 8.1) + - 验证客户端身份并分配会话 (需求 8.2) + - 维护所有活跃连接的状态 (需求 8.3) + - 将消息路由到目标客户端 (需求 8.4) + - 缓存离线消息并在客户端上线后投递 (需求 8.5) + - 记录客户端信息并维护在线用户列表 (需求 1.7) + - 用户注册和注销 (需求 2.2, 2.5) + - 获取在线用户列表 (需求 2.3) + """ + + def __init__(self, config: Optional[ServerConfig] = None): + """ + 初始化服务器 + + Args: + config: 服务器配置,如果为None则使用默认配置 + """ + self.config = config or ServerConfig() + self.host = self.config.host + self.port = self.config.port + self.max_connections = self.config.max_connections + + # 连接管理 + self._connections: Dict[str, ClientConnection] = {} # user_id -> connection + self._users: Dict[str, UserInfo] = {} # user_id -> user_info + + # 离线消息缓存 (user_id -> list of messages) + self._offline_messages: Dict[str, deque] = {} + self._max_offline_messages = 1000 # 每用户最大离线消息数 + + # 服务器状态 + self._server: Optional[asyncio.AbstractServer] = None + self._running = False + self._lock = asyncio.Lock() + + # 消息处理器 + self._message_handler = MessageHandler() + self._setup_message_handlers() + + # 心跳检查任务 + self._heartbeat_task: Optional[asyncio.Task] = None + + logger.info(f"RelayServer initialized: {self.host}:{self.port}") + + def _setup_message_handlers(self) -> None: + """设置消息处理器""" + self._message_handler.register_handler( + MessageType.USER_REGISTER, self._handle_user_register + ) + self._message_handler.register_handler( + MessageType.USER_UNREGISTER, self._handle_user_unregister + ) + self._message_handler.register_handler( + MessageType.USER_LIST_REQUEST, self._handle_user_list_request + ) + self._message_handler.register_handler( + MessageType.HEARTBEAT, self._handle_heartbeat + ) + # 设置默认处理器用于转发其他类型的消息 + self._message_handler.set_default_handler(self._handle_relay_message) + + async def start(self) -> None: + """ + 启动服务器 + + 监听指定端口等待客户端连接 (需求 8.1) + + Raises: + ServerError: 服务器启动失败时抛出 + """ + if self._running: + logger.warning("Server is already running") + return + + try: + self._server = await asyncio.start_server( + self._handle_client, + self.host, + self.port + ) + self._running = True + + # 启动心跳检查任务 + self._heartbeat_task = asyncio.create_task(self._heartbeat_checker()) + + addr = self._server.sockets[0].getsockname() + logger.info(f"Server started on {addr[0]}:{addr[1]}") + + async with self._server: + await self._server.serve_forever() + + except OSError as e: + raise ServerError(f"Failed to start server: {e}") + except Exception as e: + logger.error(f"Server error: {e}") + raise ServerError(f"Server error: {e}") + + async def stop(self) -> None: + """ + 停止服务器 + + 关闭所有连接并释放资源 + """ + if not self._running: + logger.warning("Server is not running") + return + + self._running = False + + # 取消心跳检查任务 + if self._heartbeat_task: + self._heartbeat_task.cancel() + try: + await self._heartbeat_task + except asyncio.CancelledError: + pass + + # 关闭所有客户端连接 + async with self._lock: + for user_id, conn in list(self._connections.items()): + try: + conn.writer.close() + await conn.writer.wait_closed() + except Exception as e: + logger.error(f"Error closing connection for {user_id}: {e}") + self._connections.clear() + + # 关闭服务器 + if self._server: + self._server.close() + await self._server.wait_closed() + self._server = None + + logger.info("Server stopped") + + async def _handle_client(self, reader: StreamReader, writer: StreamWriter) -> None: + """ + 处理客户端连接 + + 验证客户端身份并分配会话 (需求 8.2) + + Args: + reader: 异步读取器 + writer: 异步写入器 + """ + addr = writer.get_extra_info('peername') + ip_address = addr[0] if addr else "unknown" + port = addr[1] if addr else 0 + + logger.info(f"New connection from {ip_address}:{port}") + + # 检查连接数限制 + if len(self._connections) >= self.max_connections: + logger.warning(f"Max connections reached, rejecting {ip_address}:{port}") + error_msg = self._create_error_message( + "server", "unknown", "Server is at maximum capacity" + ) + try: + await self._send_message(writer, error_msg) + except Exception: + pass + writer.close() + await writer.wait_closed() + return + + user_id: Optional[str] = None + + try: + while self._running: + # 读取消息 + message = await self._read_message(reader) + if message is None: + break + + # 如果是注册消息,记录user_id + if message.msg_type == MessageType.USER_REGISTER: + user_id = message.sender_id + # 创建临时连接对象用于注册 + temp_conn = ClientConnection( + user_id=user_id, + reader=reader, + writer=writer, + ip_address=ip_address, + port=port + ) + await self._process_register(message, temp_conn) + elif user_id: + # 更新心跳时间 + if user_id in self._connections: + self._connections[user_id].last_heartbeat = time.time() + + # 处理消息 + await self._process_message(message, user_id) + else: + # 未注册的连接发送非注册消息 + logger.warning(f"Unregistered client sent message: {message.msg_type}") + error_msg = self._create_error_message( + "server", message.sender_id, "Please register first" + ) + await self._send_message(writer, error_msg) + + except asyncio.CancelledError: + logger.info(f"Connection cancelled for {ip_address}:{port}") + except Exception as e: + logger.error(f"Error handling client {ip_address}:{port}: {e}") + finally: + # 清理连接 + if user_id: + await self._cleanup_connection(user_id) + else: + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + logger.info(f"Connection closed for {ip_address}:{port}") + + async def _read_message(self, reader: StreamReader) -> Optional[Message]: + """ + 从流中读取消息 + + Args: + reader: 异步读取器 + + Returns: + 解析后的消息对象,如果连接关闭则返回None + """ + try: + # 读取消息头 + header = await reader.readexactly(self._message_handler.HEADER_SIZE) + if not header: + return None + + # 解析消息长度 + import struct + payload_length, version = struct.unpack( + self._message_handler.HEADER_FORMAT, header + ) + + # 读取消息体 + payload = await reader.readexactly(payload_length) + + # 反序列化消息 + full_data = header + payload + return self._message_handler.deserialize(full_data) + + except asyncio.IncompleteReadError: + return None + except MessageSerializationError as e: + logger.error(f"Failed to deserialize message: {e}") + return None + except Exception as e: + logger.error(f"Error reading message: {e}") + return None + + async def _send_message(self, writer: StreamWriter, message: Message) -> bool: + """ + 发送消息到客户端 + + Args: + writer: 异步写入器 + message: 要发送的消息 + + Returns: + 发送成功返回True,否则返回False + """ + try: + data = self._message_handler.serialize(message) + writer.write(data) + await writer.drain() + return True + except Exception as e: + logger.error(f"Failed to send message: {e}") + return False + + async def _process_message(self, message: Message, sender_user_id: str) -> None: + """ + 处理接收到的消息 + + Args: + message: 接收到的消息 + sender_user_id: 发送者用户ID + """ + try: + # 验证消息 + self._message_handler.validate(message) + + # 根据消息类型处理 + if message.msg_type == MessageType.USER_UNREGISTER: + await self._handle_user_unregister_async(message) + elif message.msg_type == MessageType.USER_LIST_REQUEST: + await self._handle_user_list_request_async(message) + elif message.msg_type == MessageType.HEARTBEAT: + await self._handle_heartbeat_async(message) + else: + # 转发消息 + await self._relay_message_async(message) + + except MessageValidationError as e: + logger.error(f"Message validation failed: {e}") + error_msg = self._create_error_message( + "server", sender_user_id, f"Invalid message: {e}" + ) + if sender_user_id in self._connections: + await self._send_message( + self._connections[sender_user_id].writer, error_msg + ) + + async def _process_register(self, message: Message, conn: ClientConnection) -> None: + """ + 处理用户注册 + + Args: + message: 注册消息 + conn: 客户端连接 + """ + user_id = message.sender_id + + try: + # 解析用户信息 + user_info = UserInfo.deserialize(message.payload) + user_info.status = UserStatus.ONLINE + user_info.ip_address = conn.ip_address + user_info.port = conn.port + user_info.last_seen = datetime.now() + + # 注册用户 + success = await self._register_user_async(user_id, user_info, conn) + + if success: + # 发送注册成功响应 + ack_msg = self._create_ack_message("server", user_id, "Registration successful") + await self._send_message(conn.writer, ack_msg) + + # 投递离线消息 + await self._deliver_offline_messages(user_id) + else: + error_msg = self._create_error_message( + "server", user_id, "Registration failed" + ) + await self._send_message(conn.writer, error_msg) + + except Exception as e: + logger.error(f"Error processing registration for {user_id}: {e}") + error_msg = self._create_error_message( + "server", user_id, f"Registration error: {e}" + ) + await self._send_message(conn.writer, error_msg) + + async def _cleanup_connection(self, user_id: str) -> None: + """ + 清理用户连接 + + Args: + user_id: 用户ID + """ + async with self._lock: + if user_id in self._connections: + conn = self._connections[user_id] + try: + conn.writer.close() + await conn.writer.wait_closed() + except Exception: + pass + del self._connections[user_id] + + # 更新用户状态为离线 + if user_id in self._users: + self._users[user_id].status = UserStatus.OFFLINE + self._users[user_id].last_seen = datetime.now() + + logger.info(f"User {user_id} disconnected") + + async def _heartbeat_checker(self) -> None: + """心跳检查任务,定期检查连接是否存活""" + while self._running: + try: + await asyncio.sleep(self.config.heartbeat_interval) + + async with self._lock: + dead_connections = [] + for user_id, conn in self._connections.items(): + if not conn.is_alive: + dead_connections.append(user_id) + + for user_id in dead_connections: + logger.info(f"Heartbeat timeout for user {user_id}") + await self._cleanup_connection(user_id) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in heartbeat checker: {e}") + + # ==================== 用户管理功能 (需求 1.7, 2.2, 2.3, 2.5) ==================== + + def register_user(self, user_id: str, connection: ClientConnection) -> bool: + """ + 注册用户(同步版本,用于消息处理器回调) + + 记录客户端信息并维护在线用户列表 (需求 1.7) + 允许用户注册 (需求 2.2) + + Args: + user_id: 用户ID + connection: 客户端连接 + + Returns: + 注册成功返回True,否则返回False + """ + # 这是同步版本,实际使用异步版本 + return False + + async def _register_user_async( + self, user_id: str, user_info: UserInfo, connection: ClientConnection + ) -> bool: + """ + 注册用户(异步版本) + + Args: + user_id: 用户ID + user_info: 用户信息 + connection: 客户端连接 + + Returns: + 注册成功返回True,否则返回False + """ + async with self._lock: + # 检查是否已有连接(可能是重连) + if user_id in self._connections: + old_conn = self._connections[user_id] + try: + old_conn.writer.close() + await old_conn.writer.wait_closed() + except Exception: + pass + logger.info(f"User {user_id} reconnected, closing old connection") + + # 注册新连接 + self._connections[user_id] = connection + self._users[user_id] = user_info + + logger.info(f"User {user_id} registered from {connection.ip_address}:{connection.port}") + return True + + def unregister_user(self, user_id: str) -> None: + """ + 注销用户(同步版本,用于消息处理器回调) + + 通知服务器更新在线状态 (需求 2.5) + + Args: + user_id: 用户ID + """ + # 这是同步版本,实际使用异步版本 + pass + + async def _unregister_user_async(self, user_id: str) -> None: + """ + 注销用户(异步版本) + + Args: + user_id: 用户ID + """ + await self._cleanup_connection(user_id) + + def get_online_users(self) -> List[UserInfo]: + """ + 获取在线用户列表 + + 显示当前所有在线用户 (需求 2.3) + + Returns: + 在线用户信息列表 + """ + online_users = [] + for user_id, user_info in self._users.items(): + if user_id in self._connections and user_info.status == UserStatus.ONLINE: + online_users.append(user_info) + return online_users + + def get_user_info(self, user_id: str) -> Optional[UserInfo]: + """ + 获取指定用户信息 + + Args: + user_id: 用户ID + + Returns: + 用户信息,如果用户不存在则返回None + """ + return self._users.get(user_id) + + def is_user_online(self, user_id: str) -> bool: + """ + 检查用户是否在线 + + Args: + user_id: 用户ID + + Returns: + 用户在线返回True,否则返回False + """ + return ( + user_id in self._connections and + user_id in self._users and + self._users[user_id].status == UserStatus.ONLINE + ) + + # ==================== 消息处理器回调 ==================== + + def _handle_user_register(self, message: Message) -> None: + """处理用户注册消息(同步回调)""" + # 实际处理在 _process_register 中完成 + pass + + def _handle_user_unregister(self, message: Message) -> None: + """处理用户注销消息(同步回调)""" + # 实际处理在异步版本中完成 + pass + + async def _handle_user_unregister_async(self, message: Message) -> None: + """处理用户注销消息(异步版本)""" + user_id = message.sender_id + await self._unregister_user_async(user_id) + logger.info(f"User {user_id} unregistered") + + def _handle_user_list_request(self, message: Message) -> None: + """处理用户列表请求(同步回调)""" + # 实际处理在异步版本中完成 + pass + + async def _handle_user_list_request_async(self, message: Message) -> None: + """处理用户列表请求(异步版本)""" + user_id = message.sender_id + + if user_id not in self._connections: + return + + # 获取在线用户列表 + online_users = self.get_online_users() + + # 构建响应 + users_data = [user.to_dict() for user in online_users] + payload = json.dumps(users_data).encode('utf-8') + + response = Message( + msg_type=MessageType.USER_LIST_RESPONSE, + sender_id="server", + receiver_id=user_id, + timestamp=time.time(), + payload=payload + ) + + await self._send_message(self._connections[user_id].writer, response) + + def _handle_heartbeat(self, message: Message) -> None: + """处理心跳消息(同步回调)""" + # 心跳时间更新在 _process_message 中完成 + pass + + async def _handle_heartbeat_async(self, message: Message) -> None: + """处理心跳消息(异步版本)""" + user_id = message.sender_id + + if user_id in self._connections: + self._connections[user_id].last_heartbeat = time.time() + + # 发送心跳响应 + response = Message( + msg_type=MessageType.HEARTBEAT, + sender_id="server", + receiver_id=user_id, + timestamp=time.time(), + payload=b"" + ) + await self._send_message(self._connections[user_id].writer, response) + + def _handle_relay_message(self, message: Message) -> None: + """处理需要转发的消息(同步回调)""" + # 实际处理在异步版本中完成 + pass + + # ==================== 消息转发功能 (需求 8.4, 8.5) ==================== + + async def relay_message(self, message: Message) -> bool: + """ + 转发消息到目标客户端 + + 将消息路由到目标客户端 (需求 8.4) + + Args: + message: 要转发的消息 + + Returns: + 转发成功返回True,否则返回False + """ + return await self._relay_message_async(message) + + async def _relay_message_async(self, message: Message) -> bool: + """ + 转发消息(异步版本) + + Args: + message: 要转发的消息 + + Returns: + 转发成功返回True,否则返回False + """ + receiver_id = message.receiver_id + sender_id = message.sender_id + + # 检查目标用户是否在线 + if self.is_user_online(receiver_id): + # 在线,直接转发 + conn = self._connections[receiver_id] + success = await self._send_message(conn.writer, message) + + if success: + logger.debug(f"Message relayed from {sender_id} to {receiver_id}") + # 发送ACK给发送者 + if sender_id in self._connections: + ack = self._create_ack_message( + "server", sender_id, + f"Message delivered to {receiver_id}" + ) + await self._send_message(self._connections[sender_id].writer, ack) + return True + else: + logger.error(f"Failed to relay message to {receiver_id}") + return False + else: + # 离线,缓存消息 + self.cache_offline_message(receiver_id, message) + logger.info(f"User {receiver_id} offline, message cached") + + # 通知发送者消息已缓存 + if sender_id in self._connections: + ack = self._create_ack_message( + "server", sender_id, + f"User {receiver_id} is offline, message cached" + ) + await self._send_message(self._connections[sender_id].writer, ack) + return True + + def cache_offline_message(self, user_id: str, message: Message) -> None: + """ + 缓存离线消息 + + 缓存消息并在客户端上线后投递 (需求 8.5) + + Args: + user_id: 目标用户ID + message: 要缓存的消息 + """ + if user_id not in self._offline_messages: + self._offline_messages[user_id] = deque(maxlen=self._max_offline_messages) + + self._offline_messages[user_id].append(message) + logger.debug(f"Cached offline message for user {user_id}, " + f"total: {len(self._offline_messages[user_id])}") + + async def _deliver_offline_messages(self, user_id: str) -> None: + """ + 投递离线消息 + + 当用户上线时投递所有缓存的消息 (需求 8.5) + + Args: + user_id: 用户ID + """ + if user_id not in self._offline_messages: + return + + if user_id not in self._connections: + return + + messages = self._offline_messages.pop(user_id) + conn = self._connections[user_id] + + delivered_count = 0 + for message in messages: + success = await self._send_message(conn.writer, message) + if success: + delivered_count += 1 + else: + # 如果发送失败,将剩余消息放回缓存 + remaining = list(messages)[delivered_count:] + if remaining: + self._offline_messages[user_id] = deque(remaining, maxlen=self._max_offline_messages) + break + + if delivered_count > 0: + logger.info(f"Delivered {delivered_count} offline messages to user {user_id}") + + def get_offline_message_count(self, user_id: str) -> int: + """ + 获取用户的离线消息数量 + + Args: + user_id: 用户ID + + Returns: + 离线消息数量 + """ + if user_id in self._offline_messages: + return len(self._offline_messages[user_id]) + return 0 + + def clear_offline_messages(self, user_id: str) -> None: + """ + 清除用户的离线消息 + + Args: + user_id: 用户ID + """ + if user_id in self._offline_messages: + del self._offline_messages[user_id] + logger.debug(f"Cleared offline messages for user {user_id}") + + # ==================== 辅助方法 ==================== + + def _create_error_message(self, sender_id: str, receiver_id: str, error: str) -> Message: + """ + 创建错误消息 + + Args: + sender_id: 发送者ID + receiver_id: 接收者ID + error: 错误信息 + + Returns: + 错误消息对象 + """ + return Message( + msg_type=MessageType.ERROR, + sender_id=sender_id, + receiver_id=receiver_id, + timestamp=time.time(), + payload=error.encode('utf-8') + ) + + def _create_ack_message(self, sender_id: str, receiver_id: str, info: str) -> Message: + """ + 创建确认消息 + + Args: + sender_id: 发送者ID + receiver_id: 接收者ID + info: 确认信息 + + Returns: + 确认消息对象 + """ + return Message( + msg_type=MessageType.ACK, + sender_id=sender_id, + receiver_id=receiver_id, + timestamp=time.time(), + payload=info.encode('utf-8') + ) + + # ==================== 属性访问 ==================== + + @property + def is_running(self) -> bool: + """服务器是否正在运行""" + return self._running + + @property + def connection_count(self) -> int: + """当前连接数""" + return len(self._connections) + + @property + def user_count(self) -> int: + """注册用户数""" + return len(self._users) + + @property + def online_user_count(self) -> int: + """在线用户数""" + return len(self.get_online_users()) diff --git a/server/repositories.py b/server/repositories.py new file mode 100644 index 0000000..29b23c8 --- /dev/null +++ b/server/repositories.py @@ -0,0 +1,905 @@ +# P2P Network Communication - Data Access Layer +""" +数据访问层,提供用户、消息、文件传输记录的CRUD操作 +""" + +import json +import logging +from datetime import datetime +from typing import Optional, List + +from shared.models import ( + UserInfo, UserStatus, ChatMessage, MessageType, + FileTransferRecord, TransferStatus, Message +) +from server.database import DatabaseManager + +logger = logging.getLogger(__name__) + + +class UserRepository: + """用户数据访问层""" + + def __init__(self, db: DatabaseManager): + """ + 初始化用户仓库 + + Args: + db: 数据库管理器实例 + """ + self.db = db + + async def create(self, user: UserInfo, password_hash: str = "") -> bool: + """ + 创建新用户 + + Args: + user: 用户信息对象 + password_hash: 密码哈希值 + + Returns: + 创建是否成功 + """ + query = """ + INSERT INTO users (user_id, username, display_name, password_hash, + public_key, status, ip_address, port, last_seen) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) + """ + try: + await self.db.execute(query, ( + user.user_id, + user.username, + user.display_name, + password_hash, + user.public_key if user.public_key else None, + user.status.value, + user.ip_address, + user.port, + user.last_seen, + )) + logger.info(f"User created: {user.user_id}") + return True + except Exception as e: + logger.error(f"Failed to create user {user.user_id}: {e}") + return False + + async def get_by_id(self, user_id: str) -> Optional[UserInfo]: + """ + 根据用户ID获取用户信息 + + Args: + user_id: 用户ID + + Returns: + 用户信息对象,如果不存在返回None + """ + query = """ + SELECT user_id, username, display_name, public_key, status, + ip_address, port, last_seen + FROM users WHERE user_id = %s + """ + row = await self.db.fetch_dict(query, (user_id,)) + if row: + return self._row_to_user(row) + return None + + async def get_by_username(self, username: str) -> Optional[UserInfo]: + """ + 根据用户名获取用户信息 + + Args: + username: 用户名 + + Returns: + 用户信息对象,如果不存在返回None + """ + query = """ + SELECT user_id, username, display_name, public_key, status, + ip_address, port, last_seen + FROM users WHERE username = %s + """ + row = await self.db.fetch_dict(query, (username,)) + if row: + return self._row_to_user(row) + return None + + async def update(self, user: UserInfo) -> bool: + """ + 更新用户信息 + + Args: + user: 用户信息对象 + + Returns: + 更新是否成功 + """ + query = """ + UPDATE users SET display_name = %s, public_key = %s, status = %s, + ip_address = %s, port = %s, last_seen = %s + WHERE user_id = %s + """ + try: + rows = await self.db.execute(query, ( + user.display_name, + user.public_key if user.public_key else None, + user.status.value, + user.ip_address, + user.port, + user.last_seen, + user.user_id, + )) + return rows > 0 + except Exception as e: + logger.error(f"Failed to update user {user.user_id}: {e}") + return False + + async def update_status(self, user_id: str, status: UserStatus, + ip_address: str = "", port: int = 0) -> bool: + """ + 更新用户在线状态 + + Args: + user_id: 用户ID + status: 用户状态 + ip_address: IP地址 + port: 端口号 + + Returns: + 更新是否成功 + """ + query = """ + UPDATE users SET status = %s, ip_address = %s, port = %s, last_seen = %s + WHERE user_id = %s + """ + try: + rows = await self.db.execute(query, ( + status.value, + ip_address, + port, + datetime.now(), + user_id, + )) + return rows > 0 + except Exception as e: + logger.error(f"Failed to update user status {user_id}: {e}") + return False + + async def delete(self, user_id: str) -> bool: + """ + 删除用户 + + Args: + user_id: 用户ID + + Returns: + 删除是否成功 + """ + query = "DELETE FROM users WHERE user_id = %s" + try: + rows = await self.db.execute(query, (user_id,)) + return rows > 0 + except Exception as e: + logger.error(f"Failed to delete user {user_id}: {e}") + return False + + async def get_online_users(self) -> List[UserInfo]: + """ + 获取所有在线用户 + + Returns: + 在线用户列表 + """ + query = """ + SELECT user_id, username, display_name, public_key, status, + ip_address, port, last_seen + FROM users WHERE status = 'online' + """ + rows = await self.db.fetch_all_dict(query) + return [self._row_to_user(row) for row in rows] + + async def get_all_users(self) -> List[UserInfo]: + """ + 获取所有用户 + + Returns: + 用户列表 + """ + query = """ + SELECT user_id, username, display_name, public_key, status, + ip_address, port, last_seen + FROM users ORDER BY username + """ + rows = await self.db.fetch_all_dict(query) + return [self._row_to_user(row) for row in rows] + + async def exists(self, user_id: str) -> bool: + """ + 检查用户是否存在 + + Args: + user_id: 用户ID + + Returns: + 用户是否存在 + """ + query = "SELECT 1 FROM users WHERE user_id = %s" + row = await self.db.fetch_one(query, (user_id,)) + return row is not None + + async def username_exists(self, username: str) -> bool: + """ + 检查用户名是否已存在 + + Args: + username: 用户名 + + Returns: + 用户名是否已存在 + """ + query = "SELECT 1 FROM users WHERE username = %s" + row = await self.db.fetch_one(query, (username,)) + return row is not None + + def _row_to_user(self, row: dict) -> UserInfo: + """将数据库行转换为UserInfo对象""" + return UserInfo( + user_id=row['user_id'], + username=row['username'], + display_name=row['display_name'], + status=UserStatus(row['status']) if row['status'] else UserStatus.OFFLINE, + ip_address=row['ip_address'] or "", + port=row['port'] or 0, + last_seen=row['last_seen'], + public_key=row['public_key'] or bytes(), + ) + + +class MessageRepository: + """消息数据访问层""" + + def __init__(self, db: DatabaseManager): + """ + 初始化消息仓库 + + Args: + db: 数据库管理器实例 + """ + self.db = db + + async def create(self, message: ChatMessage) -> bool: + """ + 保存聊天消息 + + Args: + message: 聊天消息对象 + + Returns: + 保存是否成功 + """ + query = """ + INSERT INTO messages (message_id, sender_id, receiver_id, content_type, + content, timestamp, is_read, is_delivered) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) + """ + try: + await self.db.execute(query, ( + message.message_id, + message.sender_id, + message.receiver_id, + message.content_type.value, + message.content, + message.timestamp, + message.is_read, + message.is_sent, + )) + logger.debug(f"Message saved: {message.message_id}") + return True + except Exception as e: + logger.error(f"Failed to save message {message.message_id}: {e}") + return False + + async def get_by_id(self, message_id: str) -> Optional[ChatMessage]: + """ + 根据消息ID获取消息 + + Args: + message_id: 消息ID + + Returns: + 聊天消息对象,如果不存在返回None + """ + query = """ + SELECT message_id, sender_id, receiver_id, content_type, content, + timestamp, is_read, is_delivered + FROM messages WHERE message_id = %s + """ + row = await self.db.fetch_dict(query, (message_id,)) + if row: + return self._row_to_message(row) + return None + + async def get_conversation(self, user1_id: str, user2_id: str, + limit: int = 100, offset: int = 0) -> List[ChatMessage]: + """ + 获取两个用户之间的聊天历史 + + Args: + user1_id: 用户1的ID + user2_id: 用户2的ID + limit: 返回消息数量限制 + offset: 偏移量 + + Returns: + 聊天消息列表,按时间升序排列 + """ + query = """ + SELECT message_id, sender_id, receiver_id, content_type, content, + timestamp, is_read, is_delivered + FROM messages + WHERE (sender_id = %s AND receiver_id = %s) + OR (sender_id = %s AND receiver_id = %s) + ORDER BY timestamp ASC + LIMIT %s OFFSET %s + """ + rows = await self.db.fetch_all_dict(query, ( + user1_id, user2_id, user2_id, user1_id, limit, offset + )) + return [self._row_to_message(row) for row in rows] + + async def get_user_messages(self, user_id: str, limit: int = 100, + offset: int = 0) -> List[ChatMessage]: + """ + 获取用户的所有消息(发送和接收) + + Args: + user_id: 用户ID + limit: 返回消息数量限制 + offset: 偏移量 + + Returns: + 聊天消息列表,按时间升序排列 + """ + query = """ + SELECT message_id, sender_id, receiver_id, content_type, content, + timestamp, is_read, is_delivered + FROM messages + WHERE sender_id = %s OR receiver_id = %s + ORDER BY timestamp ASC + LIMIT %s OFFSET %s + """ + rows = await self.db.fetch_all_dict(query, (user_id, user_id, limit, offset)) + return [self._row_to_message(row) for row in rows] + + async def get_unread_messages(self, user_id: str) -> List[ChatMessage]: + """ + 获取用户的未读消息 + + Args: + user_id: 用户ID + + Returns: + 未读消息列表 + """ + query = """ + SELECT message_id, sender_id, receiver_id, content_type, content, + timestamp, is_read, is_delivered + FROM messages + WHERE receiver_id = %s AND is_read = FALSE + ORDER BY timestamp ASC + """ + rows = await self.db.fetch_all_dict(query, (user_id,)) + return [self._row_to_message(row) for row in rows] + + async def mark_as_read(self, message_id: str) -> bool: + """ + 标记消息为已读 + + Args: + message_id: 消息ID + + Returns: + 更新是否成功 + """ + query = "UPDATE messages SET is_read = TRUE WHERE message_id = %s" + try: + rows = await self.db.execute(query, (message_id,)) + return rows > 0 + except Exception as e: + logger.error(f"Failed to mark message as read {message_id}: {e}") + return False + + async def mark_conversation_as_read(self, user_id: str, sender_id: str) -> int: + """ + 标记与某用户的所有消息为已读 + + Args: + user_id: 接收者用户ID + sender_id: 发送者用户ID + + Returns: + 更新的消息数量 + """ + query = """ + UPDATE messages SET is_read = TRUE + WHERE receiver_id = %s AND sender_id = %s AND is_read = FALSE + """ + try: + return await self.db.execute(query, (user_id, sender_id)) + except Exception as e: + logger.error(f"Failed to mark conversation as read: {e}") + return 0 + + async def mark_as_delivered(self, message_id: str) -> bool: + """ + 标记消息为已投递 + + Args: + message_id: 消息ID + + Returns: + 更新是否成功 + """ + query = "UPDATE messages SET is_delivered = TRUE WHERE message_id = %s" + try: + rows = await self.db.execute(query, (message_id,)) + return rows > 0 + except Exception as e: + logger.error(f"Failed to mark message as delivered {message_id}: {e}") + return False + + async def delete(self, message_id: str) -> bool: + """ + 删除消息 + + Args: + message_id: 消息ID + + Returns: + 删除是否成功 + """ + query = "DELETE FROM messages WHERE message_id = %s" + try: + rows = await self.db.execute(query, (message_id,)) + return rows > 0 + except Exception as e: + logger.error(f"Failed to delete message {message_id}: {e}") + return False + + async def delete_conversation(self, user1_id: str, user2_id: str) -> int: + """ + 删除两个用户之间的所有消息 + + Args: + user1_id: 用户1的ID + user2_id: 用户2的ID + + Returns: + 删除的消息数量 + """ + query = """ + DELETE FROM messages + WHERE (sender_id = %s AND receiver_id = %s) + OR (sender_id = %s AND receiver_id = %s) + """ + try: + return await self.db.execute(query, (user1_id, user2_id, user2_id, user1_id)) + except Exception as e: + logger.error(f"Failed to delete conversation: {e}") + return 0 + + def _row_to_message(self, row: dict) -> ChatMessage: + """将数据库行转换为ChatMessage对象""" + return ChatMessage( + message_id=row['message_id'], + sender_id=row['sender_id'], + receiver_id=row['receiver_id'], + content_type=MessageType(row['content_type']), + content=row['content'] or "", + timestamp=row['timestamp'], + is_read=bool(row['is_read']), + is_sent=bool(row['is_delivered']), + ) + + +class FileTransferRepository: + """文件传输记录数据访问层""" + + def __init__(self, db: DatabaseManager): + """ + 初始化文件传输仓库 + + Args: + db: 数据库管理器实例 + """ + self.db = db + + async def create(self, record: FileTransferRecord) -> bool: + """ + 创建文件传输记录 + + Args: + record: 文件传输记录对象 + + Returns: + 创建是否成功 + """ + query = """ + INSERT INTO file_transfers (transfer_id, file_name, file_size, file_hash, + sender_id, receiver_id, status, progress, + start_time, end_time) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """ + try: + await self.db.execute(query, ( + record.transfer_id, + record.file_name, + record.file_size, + record.file_hash, + record.sender_id, + record.receiver_id, + record.status.value, + record.progress, + record.start_time, + record.end_time, + )) + logger.info(f"File transfer record created: {record.transfer_id}") + return True + except Exception as e: + logger.error(f"Failed to create file transfer record: {e}") + return False + + async def get_by_id(self, transfer_id: str) -> Optional[FileTransferRecord]: + """ + 根据传输ID获取记录 + + Args: + transfer_id: 传输ID + + Returns: + 文件传输记录对象,如果不存在返回None + """ + query = """ + SELECT transfer_id, file_name, file_size, file_hash, sender_id, + receiver_id, status, progress, start_time, end_time + FROM file_transfers WHERE transfer_id = %s + """ + row = await self.db.fetch_dict(query, (transfer_id,)) + if row: + return self._row_to_record(row) + return None + + async def update_status(self, transfer_id: str, status: TransferStatus, + progress: float = None) -> bool: + """ + 更新传输状态 + + Args: + transfer_id: 传输ID + status: 新状态 + progress: 进度(可选) + + Returns: + 更新是否成功 + """ + if progress is not None: + query = """ + UPDATE file_transfers SET status = %s, progress = %s + WHERE transfer_id = %s + """ + args = (status.value, progress, transfer_id) + else: + query = "UPDATE file_transfers SET status = %s WHERE transfer_id = %s" + args = (status.value, transfer_id) + + try: + rows = await self.db.execute(query, args) + return rows > 0 + except Exception as e: + logger.error(f"Failed to update transfer status: {e}") + return False + + async def update_progress(self, transfer_id: str, progress: float) -> bool: + """ + 更新传输进度 + + Args: + transfer_id: 传输ID + progress: 进度百分比(0-100) + + Returns: + 更新是否成功 + """ + query = "UPDATE file_transfers SET progress = %s WHERE transfer_id = %s" + try: + rows = await self.db.execute(query, (progress, transfer_id)) + return rows > 0 + except Exception as e: + logger.error(f"Failed to update transfer progress: {e}") + return False + + async def complete(self, transfer_id: str) -> bool: + """ + 标记传输完成 + + Args: + transfer_id: 传输ID + + Returns: + 更新是否成功 + """ + query = """ + UPDATE file_transfers SET status = %s, progress = 100, end_time = %s + WHERE transfer_id = %s + """ + try: + rows = await self.db.execute(query, ( + TransferStatus.COMPLETED.value, datetime.now(), transfer_id + )) + return rows > 0 + except Exception as e: + logger.error(f"Failed to complete transfer: {e}") + return False + + async def fail(self, transfer_id: str) -> bool: + """ + 标记传输失败 + + Args: + transfer_id: 传输ID + + Returns: + 更新是否成功 + """ + query = """ + UPDATE file_transfers SET status = %s, end_time = %s + WHERE transfer_id = %s + """ + try: + rows = await self.db.execute(query, ( + TransferStatus.FAILED.value, datetime.now(), transfer_id + )) + return rows > 0 + except Exception as e: + logger.error(f"Failed to mark transfer as failed: {e}") + return False + + async def get_user_transfers(self, user_id: str, + status: TransferStatus = None) -> List[FileTransferRecord]: + """ + 获取用户的文件传输记录 + + Args: + user_id: 用户ID + status: 过滤状态(可选) + + Returns: + 文件传输记录列表 + """ + if status: + query = """ + SELECT transfer_id, file_name, file_size, file_hash, sender_id, + receiver_id, status, progress, start_time, end_time + FROM file_transfers + WHERE (sender_id = %s OR receiver_id = %s) AND status = %s + ORDER BY start_time DESC + """ + rows = await self.db.fetch_all_dict(query, (user_id, user_id, status.value)) + else: + query = """ + SELECT transfer_id, file_name, file_size, file_hash, sender_id, + receiver_id, status, progress, start_time, end_time + FROM file_transfers + WHERE sender_id = %s OR receiver_id = %s + ORDER BY start_time DESC + """ + rows = await self.db.fetch_all_dict(query, (user_id, user_id)) + + return [self._row_to_record(row) for row in rows] + + async def get_pending_transfers(self, user_id: str) -> List[FileTransferRecord]: + """ + 获取用户待处理的传输 + + Args: + user_id: 用户ID + + Returns: + 待处理的文件传输记录列表 + """ + query = """ + SELECT transfer_id, file_name, file_size, file_hash, sender_id, + receiver_id, status, progress, start_time, end_time + FROM file_transfers + WHERE (sender_id = %s OR receiver_id = %s) + AND status IN ('pending', 'in_progress', 'paused') + ORDER BY start_time DESC + """ + rows = await self.db.fetch_all_dict(query, (user_id, user_id)) + return [self._row_to_record(row) for row in rows] + + async def delete(self, transfer_id: str) -> bool: + """ + 删除传输记录 + + Args: + transfer_id: 传输ID + + Returns: + 删除是否成功 + """ + query = "DELETE FROM file_transfers WHERE transfer_id = %s" + try: + rows = await self.db.execute(query, (transfer_id,)) + return rows > 0 + except Exception as e: + logger.error(f"Failed to delete transfer record: {e}") + return False + + def _row_to_record(self, row: dict) -> FileTransferRecord: + """将数据库行转换为FileTransferRecord对象""" + return FileTransferRecord( + transfer_id=row['transfer_id'], + file_name=row['file_name'], + file_size=row['file_size'], + file_hash=row['file_hash'], + sender_id=row['sender_id'], + receiver_id=row['receiver_id'], + status=TransferStatus(row['status']), + progress=float(row['progress']), + start_time=row['start_time'], + end_time=row['end_time'], + ) + + +class OfflineMessageRepository: + """离线消息数据访问层""" + + def __init__(self, db: DatabaseManager): + """ + 初始化离线消息仓库 + + Args: + db: 数据库管理器实例 + """ + self.db = db + + async def cache(self, user_id: str, message: Message) -> bool: + """ + 缓存离线消息 + + Args: + user_id: 目标用户ID + message: 消息对象 + + Returns: + 缓存是否成功 + """ + query = """ + INSERT INTO offline_messages (user_id, message_data) + VALUES (%s, %s) + """ + try: + message_data = message.serialize() + await self.db.execute(query, (user_id, message_data)) + logger.debug(f"Offline message cached for user: {user_id}") + return True + except Exception as e: + logger.error(f"Failed to cache offline message: {e}") + return False + + async def get_messages(self, user_id: str) -> List[Message]: + """ + 获取用户的所有离线消息 + + Args: + user_id: 用户ID + + Returns: + 离线消息列表,按创建时间升序排列 + """ + query = """ + SELECT id, message_data FROM offline_messages + WHERE user_id = %s ORDER BY created_at ASC + """ + rows = await self.db.fetch_all_dict(query, (user_id,)) + messages = [] + for row in rows: + try: + message = Message.deserialize(row['message_data']) + messages.append(message) + except Exception as e: + logger.error(f"Failed to deserialize offline message: {e}") + return messages + + async def get_messages_with_ids(self, user_id: str) -> List[tuple]: + """ + 获取用户的所有离线消息及其ID + + Args: + user_id: 用户ID + + Returns: + (消息ID, 消息对象)元组列表 + """ + query = """ + SELECT id, message_data FROM offline_messages + WHERE user_id = %s ORDER BY created_at ASC + """ + rows = await self.db.fetch_all_dict(query, (user_id,)) + results = [] + for row in rows: + try: + message = Message.deserialize(row['message_data']) + results.append((row['id'], message)) + except Exception as e: + logger.error(f"Failed to deserialize offline message: {e}") + return results + + async def delete(self, message_id: int) -> bool: + """ + 删除单条离线消息 + + Args: + message_id: 离线消息记录ID + + Returns: + 删除是否成功 + """ + query = "DELETE FROM offline_messages WHERE id = %s" + try: + rows = await self.db.execute(query, (message_id,)) + return rows > 0 + except Exception as e: + logger.error(f"Failed to delete offline message: {e}") + return False + + async def delete_all(self, user_id: str) -> int: + """ + 删除用户的所有离线消息 + + Args: + user_id: 用户ID + + Returns: + 删除的消息数量 + """ + query = "DELETE FROM offline_messages WHERE user_id = %s" + try: + return await self.db.execute(query, (user_id,)) + except Exception as e: + logger.error(f"Failed to delete all offline messages: {e}") + return 0 + + async def count(self, user_id: str) -> int: + """ + 获取用户的离线消息数量 + + Args: + user_id: 用户ID + + Returns: + 离线消息数量 + """ + query = "SELECT COUNT(*) as count FROM offline_messages WHERE user_id = %s" + row = await self.db.fetch_dict(query, (user_id,)) + return row['count'] if row else 0 + + async def cleanup_old_messages(self, max_age_seconds: int) -> int: + """ + 清理过期的离线消息 + + Args: + max_age_seconds: 最大保留时间(秒) + + Returns: + 删除的消息数量 + """ + query = """ + DELETE FROM offline_messages + WHERE created_at < DATE_SUB(NOW(), INTERVAL %s SECOND) + """ + try: + return await self.db.execute(query, (max_age_seconds,)) + except Exception as e: + logger.error(f"Failed to cleanup old offline messages: {e}") + return 0 diff --git a/tests/test_relay_server.py b/tests/test_relay_server.py new file mode 100644 index 0000000..069d8db --- /dev/null +++ b/tests/test_relay_server.py @@ -0,0 +1,365 @@ +# P2P Network Communication - Relay Server Tests +""" +中转服务器单元测试 +测试服务器核心功能、用户管理和消息转发 +""" + +import asyncio +import pytest +import time +import json +from unittest.mock import AsyncMock, MagicMock, patch +from datetime import datetime + +from server.relay_server import ( + RelayServer, ClientConnection, ServerError +) +from shared.models import ( + Message, MessageType, UserInfo, UserStatus +) +from config import ServerConfig + + +class TestRelayServerInit: + """测试服务器初始化""" + + def test_init_with_default_config(self): + """测试使用默认配置初始化""" + server = RelayServer() + assert server.host == "0.0.0.0" + assert server.port == 8888 + assert server.max_connections == 1000 + assert not server.is_running + assert server.connection_count == 0 + + def test_init_with_custom_config(self): + """测试使用自定义配置初始化""" + config = ServerConfig( + host="127.0.0.1", + port=9999, + max_connections=500 + ) + server = RelayServer(config) + assert server.host == "127.0.0.1" + assert server.port == 9999 + assert server.max_connections == 500 + + +class TestUserManagement: + """测试用户管理功能""" + + @pytest.fixture + def server(self): + """创建服务器实例""" + return RelayServer() + + @pytest.fixture + def mock_connection(self): + """创建模拟连接""" + reader = AsyncMock() + writer = AsyncMock() + writer.close = MagicMock() + writer.wait_closed = AsyncMock() + return ClientConnection( + user_id="test_user", + reader=reader, + writer=writer, + ip_address="192.168.1.100", + port=12345 + ) + + @pytest.fixture + def user_info(self): + """创建用户信息""" + return UserInfo( + user_id="test_user", + username="testuser", + display_name="Test User", + status=UserStatus.ONLINE + ) + + @pytest.mark.asyncio + async def test_register_user(self, server, mock_connection, user_info): + """测试用户注册""" + success = await server._register_user_async( + "test_user", user_info, mock_connection + ) + assert success + assert server.is_user_online("test_user") + assert server.connection_count == 1 + + @pytest.mark.asyncio + async def test_register_user_reconnect(self, server, mock_connection, user_info): + """测试用户重连(替换旧连接)""" + # 第一次注册 + await server._register_user_async("test_user", user_info, mock_connection) + + # 创建新连接 + new_reader = AsyncMock() + new_writer = AsyncMock() + new_writer.close = MagicMock() + new_writer.wait_closed = AsyncMock() + new_connection = ClientConnection( + user_id="test_user", + reader=new_reader, + writer=new_writer, + ip_address="192.168.1.101", + port=12346 + ) + + # 重新注册 + success = await server._register_user_async("test_user", user_info, new_connection) + assert success + assert server.connection_count == 1 # 仍然只有一个连接 + + # 旧连接应该被关闭 + mock_connection.writer.close.assert_called() + + @pytest.mark.asyncio + async def test_unregister_user(self, server, mock_connection, user_info): + """测试用户注销""" + await server._register_user_async("test_user", user_info, mock_connection) + assert server.is_user_online("test_user") + + await server._unregister_user_async("test_user") + assert not server.is_user_online("test_user") + assert server.connection_count == 0 + + @pytest.mark.asyncio + async def test_get_online_users(self, server, mock_connection, user_info): + """测试获取在线用户列表""" + # 初始为空 + assert len(server.get_online_users()) == 0 + + # 注册用户 + await server._register_user_async("test_user", user_info, mock_connection) + + online_users = server.get_online_users() + assert len(online_users) == 1 + assert online_users[0].user_id == "test_user" + + @pytest.mark.asyncio + async def test_get_user_info(self, server, mock_connection, user_info): + """测试获取用户信息""" + # 用户不存在 + assert server.get_user_info("test_user") is None + + # 注册用户 + await server._register_user_async("test_user", user_info, mock_connection) + + info = server.get_user_info("test_user") + assert info is not None + assert info.user_id == "test_user" + assert info.username == "testuser" + + +class TestMessageRelay: + """测试消息转发功能""" + + @pytest.fixture + def server(self): + """创建服务器实例""" + return RelayServer() + + @pytest.fixture + def create_mock_connection(self): + """创建模拟连接的工厂函数""" + def _create(user_id): + reader = AsyncMock() + writer = AsyncMock() + writer.close = MagicMock() + writer.wait_closed = AsyncMock() + writer.write = MagicMock() + writer.drain = AsyncMock() + return ClientConnection( + user_id=user_id, + reader=reader, + writer=writer, + ip_address="192.168.1.100", + port=12345 + ) + return _create + + @pytest.mark.asyncio + async def test_relay_message_to_online_user(self, server, create_mock_connection): + """测试转发消息给在线用户""" + # 注册发送者和接收者 + sender_conn = create_mock_connection("sender") + receiver_conn = create_mock_connection("receiver") + + sender_info = UserInfo( + user_id="sender", username="sender", + display_name="Sender", status=UserStatus.ONLINE + ) + receiver_info = UserInfo( + user_id="receiver", username="receiver", + display_name="Receiver", status=UserStatus.ONLINE + ) + + await server._register_user_async("sender", sender_info, sender_conn) + await server._register_user_async("receiver", receiver_info, receiver_conn) + + # 创建消息 + message = Message( + msg_type=MessageType.TEXT, + sender_id="sender", + receiver_id="receiver", + timestamp=time.time(), + payload=b"Hello, receiver!" + ) + + # 转发消息 + success = await server.relay_message(message) + assert success + + # 验证消息被发送到接收者 + receiver_conn.writer.write.assert_called() + + @pytest.mark.asyncio + async def test_relay_message_to_offline_user(self, server, create_mock_connection): + """测试转发消息给离线用户(缓存)""" + # 只注册发送者 + sender_conn = create_mock_connection("sender") + sender_info = UserInfo( + user_id="sender", username="sender", + display_name="Sender", status=UserStatus.ONLINE + ) + await server._register_user_async("sender", sender_info, sender_conn) + + # 创建消息 + message = Message( + msg_type=MessageType.TEXT, + sender_id="sender", + receiver_id="offline_user", + timestamp=time.time(), + payload=b"Hello, offline user!" + ) + + # 转发消息(应该被缓存) + success = await server.relay_message(message) + assert success + + # 验证消息被缓存 + assert server.get_offline_message_count("offline_user") == 1 + + def test_cache_offline_message(self, server): + """测试离线消息缓存""" + message = Message( + msg_type=MessageType.TEXT, + sender_id="sender", + receiver_id="receiver", + timestamp=time.time(), + payload=b"Test message" + ) + + server.cache_offline_message("receiver", message) + assert server.get_offline_message_count("receiver") == 1 + + # 缓存多条消息 + server.cache_offline_message("receiver", message) + server.cache_offline_message("receiver", message) + assert server.get_offline_message_count("receiver") == 3 + + def test_clear_offline_messages(self, server): + """测试清除离线消息""" + message = Message( + msg_type=MessageType.TEXT, + sender_id="sender", + receiver_id="receiver", + timestamp=time.time(), + payload=b"Test message" + ) + + server.cache_offline_message("receiver", message) + assert server.get_offline_message_count("receiver") == 1 + + server.clear_offline_messages("receiver") + assert server.get_offline_message_count("receiver") == 0 + + @pytest.mark.asyncio + async def test_deliver_offline_messages(self, server, create_mock_connection): + """测试投递离线消息""" + # 缓存离线消息 + message1 = Message( + msg_type=MessageType.TEXT, + sender_id="sender", + receiver_id="receiver", + timestamp=time.time(), + payload=b"Message 1" + ) + message2 = Message( + msg_type=MessageType.TEXT, + sender_id="sender", + receiver_id="receiver", + timestamp=time.time() + 1, + payload=b"Message 2" + ) + + server.cache_offline_message("receiver", message1) + server.cache_offline_message("receiver", message2) + assert server.get_offline_message_count("receiver") == 2 + + # 用户上线 + receiver_conn = create_mock_connection("receiver") + receiver_info = UserInfo( + user_id="receiver", username="receiver", + display_name="Receiver", status=UserStatus.ONLINE + ) + await server._register_user_async("receiver", receiver_info, receiver_conn) + + # 投递离线消息 + await server._deliver_offline_messages("receiver") + + # 验证消息被投递 + assert receiver_conn.writer.write.call_count >= 2 + assert server.get_offline_message_count("receiver") == 0 + + +class TestHelperMethods: + """测试辅助方法""" + + @pytest.fixture + def server(self): + """创建服务器实例""" + return RelayServer() + + def test_create_error_message(self, server): + """测试创建错误消息""" + error_msg = server._create_error_message( + "server", "user1", "Test error" + ) + assert error_msg.msg_type == MessageType.ERROR + assert error_msg.sender_id == "server" + assert error_msg.receiver_id == "user1" + assert error_msg.payload == b"Test error" + + def test_create_ack_message(self, server): + """测试创建确认消息""" + ack_msg = server._create_ack_message( + "server", "user1", "Test ack" + ) + assert ack_msg.msg_type == MessageType.ACK + assert ack_msg.sender_id == "server" + assert ack_msg.receiver_id == "user1" + assert ack_msg.payload == b"Test ack" + + +class TestClientConnection: + """测试客户端连接类""" + + def test_connection_is_alive(self): + """测试连接存活检查""" + reader = AsyncMock() + writer = AsyncMock() + + conn = ClientConnection( + user_id="test", + reader=reader, + writer=writer, + last_heartbeat=time.time() + ) + assert conn.is_alive + + # 模拟超时 + conn.last_heartbeat = time.time() - 120 # 2分钟前 + assert not conn.is_alive