# P2P Network Communication - Relay Server """ 中转服务器模块 负责管理客户端连接、用户注册、消息转发和离线消息缓存 需求: 8.1, 8.2, 8.4, 8.5, 1.7, 2.2, 2.3, 2.5 """ import asyncio import logging import time import json from asyncio import StreamReader, StreamWriter from dataclasses import dataclass, field from datetime import datetime from typing import Dict, List, Optional, Tuple, Set from collections import deque from shared.models import ( Message, MessageType, UserInfo, UserStatus ) from shared.message_handler import ( MessageHandler, MessageValidationError, MessageSerializationError ) from config import ServerConfig # 设置日志 logger = logging.getLogger(__name__) class ServerError(Exception): """服务器错误基类""" pass class ConnectionError(ServerError): """连接错误""" pass class UserRegistrationError(ServerError): """用户注册错误""" pass @dataclass class ClientConnection: """客户端连接信息""" user_id: str reader: StreamReader writer: StreamWriter connected_at: datetime = field(default_factory=datetime.now) last_heartbeat: float = field(default_factory=time.time) ip_address: str = "" port: int = 0 @property def is_alive(self) -> bool: """检查连接是否存活(基于心跳)""" return time.time() - self.last_heartbeat < 60 # 60秒超时 class RelayServer: """ 中转服务器 负责: - 监听指定端口等待客户端连接 (需求 8.1) - 验证客户端身份并分配会话 (需求 8.2) - 维护所有活跃连接的状态 (需求 8.3) - 将消息路由到目标客户端 (需求 8.4) - 缓存离线消息并在客户端上线后投递 (需求 8.5) - 记录客户端信息并维护在线用户列表 (需求 1.7) - 用户注册和注销 (需求 2.2, 2.5) - 获取在线用户列表 (需求 2.3) """ def __init__(self, config: Optional[ServerConfig] = None): """ 初始化服务器 Args: config: 服务器配置,如果为None则使用默认配置 """ self.config = config or ServerConfig() self.host = self.config.host self.port = self.config.port self.max_connections = self.config.max_connections # 连接管理 self._connections: Dict[str, ClientConnection] = {} # user_id -> connection self._users: Dict[str, UserInfo] = {} # user_id -> user_info # 离线消息缓存 (user_id -> list of messages) self._offline_messages: Dict[str, deque] = {} self._max_offline_messages = 1000 # 每用户最大离线消息数 # 服务器状态 self._server: Optional[asyncio.AbstractServer] = None self._running = False self._lock = asyncio.Lock() # 消息处理器 self._message_handler = MessageHandler() self._setup_message_handlers() # 心跳检查任务 self._heartbeat_task: Optional[asyncio.Task] = None logger.info(f"RelayServer initialized: {self.host}:{self.port}") def _setup_message_handlers(self) -> None: """设置消息处理器""" self._message_handler.register_handler( MessageType.USER_REGISTER, self._handle_user_register ) self._message_handler.register_handler( MessageType.USER_UNREGISTER, self._handle_user_unregister ) self._message_handler.register_handler( MessageType.USER_LIST_REQUEST, self._handle_user_list_request ) self._message_handler.register_handler( MessageType.HEARTBEAT, self._handle_heartbeat ) # 设置默认处理器用于转发其他类型的消息 self._message_handler.set_default_handler(self._handle_relay_message) async def start(self) -> None: """ 启动服务器 监听指定端口等待客户端连接 (需求 8.1) Raises: ServerError: 服务器启动失败时抛出 """ if self._running: logger.warning("Server is already running") return try: self._server = await asyncio.start_server( self._handle_client, self.host, self.port ) self._running = True # 启动心跳检查任务 self._heartbeat_task = asyncio.create_task(self._heartbeat_checker()) addr = self._server.sockets[0].getsockname() logger.info(f"Server started on {addr[0]}:{addr[1]}") async with self._server: await self._server.serve_forever() except OSError as e: raise ServerError(f"Failed to start server: {e}") except Exception as e: logger.error(f"Server error: {e}") raise ServerError(f"Server error: {e}") async def stop(self) -> None: """ 停止服务器 关闭所有连接并释放资源 """ if not self._running: logger.warning("Server is not running") return self._running = False # 取消心跳检查任务 if self._heartbeat_task: self._heartbeat_task.cancel() try: await self._heartbeat_task except asyncio.CancelledError: pass # 关闭所有客户端连接 async with self._lock: for user_id, conn in list(self._connections.items()): try: conn.writer.close() await conn.writer.wait_closed() except Exception as e: logger.error(f"Error closing connection for {user_id}: {e}") self._connections.clear() # 关闭服务器 if self._server: self._server.close() await self._server.wait_closed() self._server = None logger.info("Server stopped") async def _handle_client(self, reader: StreamReader, writer: StreamWriter) -> None: """ 处理客户端连接 验证客户端身份并分配会话 (需求 8.2) Args: reader: 异步读取器 writer: 异步写入器 """ addr = writer.get_extra_info('peername') ip_address = addr[0] if addr else "unknown" port = addr[1] if addr else 0 logger.info(f"New connection from {ip_address}:{port}") # 检查连接数限制 if len(self._connections) >= self.max_connections: logger.warning(f"Max connections reached, rejecting {ip_address}:{port}") error_msg = self._create_error_message( "server", "unknown", "Server is at maximum capacity" ) try: await self._send_message(writer, error_msg) except Exception: pass writer.close() await writer.wait_closed() return user_id: Optional[str] = None try: while self._running: # 读取消息 message = await self._read_message(reader) if message is None: break # 如果是注册消息,记录user_id if message.msg_type == MessageType.USER_REGISTER: user_id = message.sender_id # 创建临时连接对象用于注册 temp_conn = ClientConnection( user_id=user_id, reader=reader, writer=writer, ip_address=ip_address, port=port ) await self._process_register(message, temp_conn) elif user_id: # 更新心跳时间 if user_id in self._connections: self._connections[user_id].last_heartbeat = time.time() # 处理消息 await self._process_message(message, user_id) else: # 未注册的连接发送非注册消息 logger.warning(f"Unregistered client sent message: {message.msg_type}") error_msg = self._create_error_message( "server", message.sender_id, "Please register first" ) await self._send_message(writer, error_msg) except asyncio.CancelledError: logger.info(f"Connection cancelled for {ip_address}:{port}") except Exception as e: logger.error(f"Error handling client {ip_address}:{port}: {e}") finally: # 清理连接 if user_id: await self._cleanup_connection(user_id) else: writer.close() try: await writer.wait_closed() except Exception: pass logger.info(f"Connection closed for {ip_address}:{port}") async def _read_message(self, reader: StreamReader) -> Optional[Message]: """ 从流中读取消息 Args: reader: 异步读取器 Returns: 解析后的消息对象,如果连接关闭则返回None """ try: # 读取消息头 header = await reader.readexactly(self._message_handler.HEADER_SIZE) if not header: return None # 解析消息长度 import struct payload_length, version = struct.unpack( self._message_handler.HEADER_FORMAT, header ) # 读取消息体 payload = await reader.readexactly(payload_length) # 反序列化消息 full_data = header + payload return self._message_handler.deserialize(full_data) except asyncio.IncompleteReadError: return None except MessageSerializationError as e: logger.error(f"Failed to deserialize message: {e}") return None except Exception as e: logger.error(f"Error reading message: {e}") return None async def _send_message(self, writer: StreamWriter, message: Message) -> bool: """ 发送消息到客户端 Args: writer: 异步写入器 message: 要发送的消息 Returns: 发送成功返回True,否则返回False """ try: data = self._message_handler.serialize(message) writer.write(data) await writer.drain() return True except Exception as e: logger.error(f"Failed to send message: {e}") return False async def _process_message(self, message: Message, sender_user_id: str) -> None: """ 处理接收到的消息 Args: message: 接收到的消息 sender_user_id: 发送者用户ID """ try: # 验证消息 self._message_handler.validate(message) # 根据消息类型处理 if message.msg_type == MessageType.USER_UNREGISTER: await self._handle_user_unregister_async(message) elif message.msg_type == MessageType.USER_LIST_REQUEST: await self._handle_user_list_request_async(message) elif message.msg_type == MessageType.HEARTBEAT: await self._handle_heartbeat_async(message) else: # 转发消息 await self._relay_message_async(message) except MessageValidationError as e: logger.error(f"Message validation failed: {e}") error_msg = self._create_error_message( "server", sender_user_id, f"Invalid message: {e}" ) if sender_user_id in self._connections: await self._send_message( self._connections[sender_user_id].writer, error_msg ) async def _process_register(self, message: Message, conn: ClientConnection) -> None: """ 处理用户注册 Args: message: 注册消息 conn: 客户端连接 """ user_id = message.sender_id try: # 解析用户信息 user_info = UserInfo.deserialize(message.payload) user_info.status = UserStatus.ONLINE user_info.ip_address = conn.ip_address user_info.port = conn.port user_info.last_seen = datetime.now() # 注册用户 success = await self._register_user_async(user_id, user_info, conn) if success: # 发送注册成功响应 ack_msg = self._create_ack_message("server", user_id, "Registration successful") await self._send_message(conn.writer, ack_msg) # 投递离线消息 await self._deliver_offline_messages(user_id) else: error_msg = self._create_error_message( "server", user_id, "Registration failed" ) await self._send_message(conn.writer, error_msg) except Exception as e: logger.error(f"Error processing registration for {user_id}: {e}") error_msg = self._create_error_message( "server", user_id, f"Registration error: {e}" ) await self._send_message(conn.writer, error_msg) async def _cleanup_connection(self, user_id: str) -> None: """ 清理用户连接 Args: user_id: 用户ID """ async with self._lock: if user_id in self._connections: conn = self._connections[user_id] try: conn.writer.close() await conn.writer.wait_closed() except Exception: pass del self._connections[user_id] # 更新用户状态为离线 if user_id in self._users: self._users[user_id].status = UserStatus.OFFLINE self._users[user_id].last_seen = datetime.now() logger.info(f"User {user_id} disconnected") async def _heartbeat_checker(self) -> None: """心跳检查任务,定期检查连接是否存活""" while self._running: try: await asyncio.sleep(self.config.heartbeat_interval) async with self._lock: dead_connections = [] for user_id, conn in self._connections.items(): if not conn.is_alive: dead_connections.append(user_id) for user_id in dead_connections: logger.info(f"Heartbeat timeout for user {user_id}") await self._cleanup_connection(user_id) except asyncio.CancelledError: break except Exception as e: logger.error(f"Error in heartbeat checker: {e}") # ==================== 用户管理功能 (需求 1.7, 2.2, 2.3, 2.5) ==================== def register_user(self, user_id: str, connection: ClientConnection) -> bool: """ 注册用户(同步版本,用于消息处理器回调) 记录客户端信息并维护在线用户列表 (需求 1.7) 允许用户注册 (需求 2.2) Args: user_id: 用户ID connection: 客户端连接 Returns: 注册成功返回True,否则返回False """ # 这是同步版本,实际使用异步版本 return False async def _register_user_async( self, user_id: str, user_info: UserInfo, connection: ClientConnection ) -> bool: """ 注册用户(异步版本) Args: user_id: 用户ID user_info: 用户信息 connection: 客户端连接 Returns: 注册成功返回True,否则返回False """ async with self._lock: # 检查是否已有连接(可能是重连) if user_id in self._connections: old_conn = self._connections[user_id] try: old_conn.writer.close() await old_conn.writer.wait_closed() except Exception: pass logger.info(f"User {user_id} reconnected, closing old connection") # 注册新连接 self._connections[user_id] = connection self._users[user_id] = user_info logger.info(f"User {user_id} registered from {connection.ip_address}:{connection.port}") return True def unregister_user(self, user_id: str) -> None: """ 注销用户(同步版本,用于消息处理器回调) 通知服务器更新在线状态 (需求 2.5) Args: user_id: 用户ID """ # 这是同步版本,实际使用异步版本 pass async def _unregister_user_async(self, user_id: str) -> None: """ 注销用户(异步版本) Args: user_id: 用户ID """ await self._cleanup_connection(user_id) def get_online_users(self) -> List[UserInfo]: """ 获取在线用户列表 显示当前所有在线用户 (需求 2.3) Returns: 在线用户信息列表 """ online_users = [] for user_id, user_info in self._users.items(): if user_id in self._connections and user_info.status == UserStatus.ONLINE: online_users.append(user_info) return online_users def get_user_info(self, user_id: str) -> Optional[UserInfo]: """ 获取指定用户信息 Args: user_id: 用户ID Returns: 用户信息,如果用户不存在则返回None """ return self._users.get(user_id) def is_user_online(self, user_id: str) -> bool: """ 检查用户是否在线 Args: user_id: 用户ID Returns: 用户在线返回True,否则返回False """ return ( user_id in self._connections and user_id in self._users and self._users[user_id].status == UserStatus.ONLINE ) # ==================== 消息处理器回调 ==================== def _handle_user_register(self, message: Message) -> None: """处理用户注册消息(同步回调)""" # 实际处理在 _process_register 中完成 pass def _handle_user_unregister(self, message: Message) -> None: """处理用户注销消息(同步回调)""" # 实际处理在异步版本中完成 pass async def _handle_user_unregister_async(self, message: Message) -> None: """处理用户注销消息(异步版本)""" user_id = message.sender_id await self._unregister_user_async(user_id) logger.info(f"User {user_id} unregistered") def _handle_user_list_request(self, message: Message) -> None: """处理用户列表请求(同步回调)""" # 实际处理在异步版本中完成 pass async def _handle_user_list_request_async(self, message: Message) -> None: """处理用户列表请求(异步版本)""" user_id = message.sender_id if user_id not in self._connections: return # 获取在线用户列表 online_users = self.get_online_users() # 构建响应 users_data = [user.to_dict() for user in online_users] payload = json.dumps(users_data).encode('utf-8') response = Message( msg_type=MessageType.USER_LIST_RESPONSE, sender_id="server", receiver_id=user_id, timestamp=time.time(), payload=payload ) await self._send_message(self._connections[user_id].writer, response) def _handle_heartbeat(self, message: Message) -> None: """处理心跳消息(同步回调)""" # 心跳时间更新在 _process_message 中完成 pass async def _handle_heartbeat_async(self, message: Message) -> None: """处理心跳消息(异步版本)""" user_id = message.sender_id if user_id in self._connections: self._connections[user_id].last_heartbeat = time.time() # 发送心跳响应 response = Message( msg_type=MessageType.HEARTBEAT, sender_id="server", receiver_id=user_id, timestamp=time.time(), payload=b"" ) await self._send_message(self._connections[user_id].writer, response) def _handle_relay_message(self, message: Message) -> None: """处理需要转发的消息(同步回调)""" # 实际处理在异步版本中完成 pass # ==================== 消息转发功能 (需求 8.4, 8.5) ==================== async def relay_message(self, message: Message) -> bool: """ 转发消息到目标客户端 将消息路由到目标客户端 (需求 8.4) Args: message: 要转发的消息 Returns: 转发成功返回True,否则返回False """ return await self._relay_message_async(message) async def _relay_message_async(self, message: Message) -> bool: """ 转发消息(异步版本) Args: message: 要转发的消息 Returns: 转发成功返回True,否则返回False """ receiver_id = message.receiver_id sender_id = message.sender_id # 记录语音通话相关消息的日志 voice_types = { MessageType.VOICE_CALL_REQUEST, MessageType.VOICE_CALL_ACCEPT, MessageType.VOICE_CALL_REJECT, MessageType.VOICE_CALL_END, MessageType.VOICE_DATA, } if message.msg_type in voice_types: logger.info(f"Voice message: {message.msg_type.value} from {sender_id} to {receiver_id}") # 检查目标用户是否在线 if self.is_user_online(receiver_id): # 在线,直接转发 conn = self._connections[receiver_id] success = await self._send_message(conn.writer, message) if success: logger.debug(f"Message relayed from {sender_id} to {receiver_id}") # 发送ACK给发送者 if sender_id in self._connections: ack = self._create_ack_message( "server", sender_id, f"Message delivered to {receiver_id}" ) await self._send_message(self._connections[sender_id].writer, ack) return True else: logger.error(f"Failed to relay message to {receiver_id}") return False else: # 离线,缓存消息 self.cache_offline_message(receiver_id, message) logger.info(f"User {receiver_id} offline, message cached") # 通知发送者消息已缓存 if sender_id in self._connections: ack = self._create_ack_message( "server", sender_id, f"User {receiver_id} is offline, message cached" ) await self._send_message(self._connections[sender_id].writer, ack) return True def cache_offline_message(self, user_id: str, message: Message) -> None: """ 缓存离线消息 缓存消息并在客户端上线后投递 (需求 8.5) Args: user_id: 目标用户ID message: 要缓存的消息 """ if user_id not in self._offline_messages: self._offline_messages[user_id] = deque(maxlen=self._max_offline_messages) self._offline_messages[user_id].append(message) logger.debug(f"Cached offline message for user {user_id}, " f"total: {len(self._offline_messages[user_id])}") async def _deliver_offline_messages(self, user_id: str) -> None: """ 投递离线消息 当用户上线时投递所有缓存的消息 (需求 8.5) Args: user_id: 用户ID """ if user_id not in self._offline_messages: return if user_id not in self._connections: return messages = self._offline_messages.pop(user_id) conn = self._connections[user_id] delivered_count = 0 for message in messages: success = await self._send_message(conn.writer, message) if success: delivered_count += 1 else: # 如果发送失败,将剩余消息放回缓存 remaining = list(messages)[delivered_count:] if remaining: self._offline_messages[user_id] = deque(remaining, maxlen=self._max_offline_messages) break if delivered_count > 0: logger.info(f"Delivered {delivered_count} offline messages to user {user_id}") def get_offline_message_count(self, user_id: str) -> int: """ 获取用户的离线消息数量 Args: user_id: 用户ID Returns: 离线消息数量 """ if user_id in self._offline_messages: return len(self._offline_messages[user_id]) return 0 def clear_offline_messages(self, user_id: str) -> None: """ 清除用户的离线消息 Args: user_id: 用户ID """ if user_id in self._offline_messages: del self._offline_messages[user_id] logger.debug(f"Cleared offline messages for user {user_id}") # ==================== 辅助方法 ==================== def _create_error_message(self, sender_id: str, receiver_id: str, error: str) -> Message: """ 创建错误消息 Args: sender_id: 发送者ID receiver_id: 接收者ID error: 错误信息 Returns: 错误消息对象 """ return Message( msg_type=MessageType.ERROR, sender_id=sender_id, receiver_id=receiver_id, timestamp=time.time(), payload=error.encode('utf-8') ) def _create_ack_message(self, sender_id: str, receiver_id: str, info: str) -> Message: """ 创建确认消息 Args: sender_id: 发送者ID receiver_id: 接收者ID info: 确认信息 Returns: 确认消息对象 """ return Message( msg_type=MessageType.ACK, sender_id=sender_id, receiver_id=receiver_id, timestamp=time.time(), payload=info.encode('utf-8') ) # ==================== 属性访问 ==================== @property def is_running(self) -> bool: """服务器是否正在运行""" return self._running @property def connection_count(self) -> int: """当前连接数""" return len(self._connections) @property def user_count(self) -> int: """注册用户数""" return len(self._users) @property def online_user_count(self) -> int: """在线用户数""" return len(self.get_online_users())