|
|
# P2P Network Communication - Relay Server
|
|
|
"""
|
|
|
中转服务器模块
|
|
|
负责管理客户端连接、用户注册、消息转发和离线消息缓存
|
|
|
|
|
|
需求: 8.1, 8.2, 8.4, 8.5, 1.7, 2.2, 2.3, 2.5
|
|
|
"""
|
|
|
|
|
|
import asyncio
|
|
|
import logging
|
|
|
import time
|
|
|
import json
|
|
|
from asyncio import StreamReader, StreamWriter
|
|
|
from dataclasses import dataclass, field
|
|
|
from datetime import datetime
|
|
|
from typing import Dict, List, Optional, Tuple, Set
|
|
|
from collections import deque
|
|
|
|
|
|
from shared.models import (
|
|
|
Message, MessageType, UserInfo, UserStatus
|
|
|
)
|
|
|
from shared.message_handler import (
|
|
|
MessageHandler, MessageValidationError, MessageSerializationError
|
|
|
)
|
|
|
from config import ServerConfig
|
|
|
|
|
|
|
|
|
# 设置日志
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class ServerError(Exception):
|
|
|
"""服务器错误基类"""
|
|
|
pass
|
|
|
|
|
|
|
|
|
class ConnectionError(ServerError):
|
|
|
"""连接错误"""
|
|
|
pass
|
|
|
|
|
|
|
|
|
class UserRegistrationError(ServerError):
|
|
|
"""用户注册错误"""
|
|
|
pass
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
class ClientConnection:
|
|
|
"""客户端连接信息"""
|
|
|
user_id: str
|
|
|
reader: StreamReader
|
|
|
writer: StreamWriter
|
|
|
connected_at: datetime = field(default_factory=datetime.now)
|
|
|
last_heartbeat: float = field(default_factory=time.time)
|
|
|
ip_address: str = ""
|
|
|
port: int = 0
|
|
|
|
|
|
@property
|
|
|
def is_alive(self) -> bool:
|
|
|
"""检查连接是否存活(基于心跳)"""
|
|
|
return time.time() - self.last_heartbeat < 60 # 60秒超时
|
|
|
|
|
|
|
|
|
class RelayServer:
|
|
|
"""
|
|
|
中转服务器
|
|
|
|
|
|
负责:
|
|
|
- 监听指定端口等待客户端连接 (需求 8.1)
|
|
|
- 验证客户端身份并分配会话 (需求 8.2)
|
|
|
- 维护所有活跃连接的状态 (需求 8.3)
|
|
|
- 将消息路由到目标客户端 (需求 8.4)
|
|
|
- 缓存离线消息并在客户端上线后投递 (需求 8.5)
|
|
|
- 记录客户端信息并维护在线用户列表 (需求 1.7)
|
|
|
- 用户注册和注销 (需求 2.2, 2.5)
|
|
|
- 获取在线用户列表 (需求 2.3)
|
|
|
"""
|
|
|
|
|
|
def __init__(self, config: Optional[ServerConfig] = None):
|
|
|
"""
|
|
|
初始化服务器
|
|
|
|
|
|
Args:
|
|
|
config: 服务器配置,如果为None则使用默认配置
|
|
|
"""
|
|
|
self.config = config or ServerConfig()
|
|
|
self.host = self.config.host
|
|
|
self.port = self.config.port
|
|
|
self.max_connections = self.config.max_connections
|
|
|
|
|
|
# 连接管理
|
|
|
self._connections: Dict[str, ClientConnection] = {} # user_id -> connection
|
|
|
self._users: Dict[str, UserInfo] = {} # user_id -> user_info
|
|
|
|
|
|
# 离线消息缓存 (user_id -> list of messages)
|
|
|
self._offline_messages: Dict[str, deque] = {}
|
|
|
self._max_offline_messages = 1000 # 每用户最大离线消息数
|
|
|
|
|
|
# 服务器状态
|
|
|
self._server: Optional[asyncio.AbstractServer] = None
|
|
|
self._running = False
|
|
|
self._lock = asyncio.Lock()
|
|
|
|
|
|
# 消息处理器
|
|
|
self._message_handler = MessageHandler()
|
|
|
self._setup_message_handlers()
|
|
|
|
|
|
# 心跳检查任务
|
|
|
self._heartbeat_task: Optional[asyncio.Task] = None
|
|
|
|
|
|
logger.info(f"RelayServer initialized: {self.host}:{self.port}")
|
|
|
|
|
|
def _setup_message_handlers(self) -> None:
|
|
|
"""设置消息处理器"""
|
|
|
self._message_handler.register_handler(
|
|
|
MessageType.USER_REGISTER, self._handle_user_register
|
|
|
)
|
|
|
self._message_handler.register_handler(
|
|
|
MessageType.USER_UNREGISTER, self._handle_user_unregister
|
|
|
)
|
|
|
self._message_handler.register_handler(
|
|
|
MessageType.USER_LIST_REQUEST, self._handle_user_list_request
|
|
|
)
|
|
|
self._message_handler.register_handler(
|
|
|
MessageType.HEARTBEAT, self._handle_heartbeat
|
|
|
)
|
|
|
# 设置默认处理器用于转发其他类型的消息
|
|
|
self._message_handler.set_default_handler(self._handle_relay_message)
|
|
|
|
|
|
async def start(self) -> None:
|
|
|
"""
|
|
|
启动服务器
|
|
|
|
|
|
监听指定端口等待客户端连接 (需求 8.1)
|
|
|
|
|
|
Raises:
|
|
|
ServerError: 服务器启动失败时抛出
|
|
|
"""
|
|
|
if self._running:
|
|
|
logger.warning("Server is already running")
|
|
|
return
|
|
|
|
|
|
try:
|
|
|
self._server = await asyncio.start_server(
|
|
|
self._handle_client,
|
|
|
self.host,
|
|
|
self.port
|
|
|
)
|
|
|
self._running = True
|
|
|
|
|
|
# 启动心跳检查任务
|
|
|
self._heartbeat_task = asyncio.create_task(self._heartbeat_checker())
|
|
|
|
|
|
addr = self._server.sockets[0].getsockname()
|
|
|
logger.info(f"Server started on {addr[0]}:{addr[1]}")
|
|
|
|
|
|
async with self._server:
|
|
|
await self._server.serve_forever()
|
|
|
|
|
|
except OSError as e:
|
|
|
raise ServerError(f"Failed to start server: {e}")
|
|
|
except Exception as e:
|
|
|
logger.error(f"Server error: {e}")
|
|
|
raise ServerError(f"Server error: {e}")
|
|
|
|
|
|
async def stop(self) -> None:
|
|
|
"""
|
|
|
停止服务器
|
|
|
|
|
|
关闭所有连接并释放资源
|
|
|
"""
|
|
|
if not self._running:
|
|
|
logger.warning("Server is not running")
|
|
|
return
|
|
|
|
|
|
self._running = False
|
|
|
|
|
|
# 取消心跳检查任务
|
|
|
if self._heartbeat_task:
|
|
|
self._heartbeat_task.cancel()
|
|
|
try:
|
|
|
await self._heartbeat_task
|
|
|
except asyncio.CancelledError:
|
|
|
pass
|
|
|
|
|
|
# 关闭所有客户端连接
|
|
|
async with self._lock:
|
|
|
for user_id, conn in list(self._connections.items()):
|
|
|
try:
|
|
|
conn.writer.close()
|
|
|
await conn.writer.wait_closed()
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error closing connection for {user_id}: {e}")
|
|
|
self._connections.clear()
|
|
|
|
|
|
# 关闭服务器
|
|
|
if self._server:
|
|
|
self._server.close()
|
|
|
await self._server.wait_closed()
|
|
|
self._server = None
|
|
|
|
|
|
logger.info("Server stopped")
|
|
|
|
|
|
async def _handle_client(self, reader: StreamReader, writer: StreamWriter) -> None:
|
|
|
"""
|
|
|
处理客户端连接
|
|
|
|
|
|
验证客户端身份并分配会话 (需求 8.2)
|
|
|
|
|
|
Args:
|
|
|
reader: 异步读取器
|
|
|
writer: 异步写入器
|
|
|
"""
|
|
|
addr = writer.get_extra_info('peername')
|
|
|
ip_address = addr[0] if addr else "unknown"
|
|
|
port = addr[1] if addr else 0
|
|
|
|
|
|
logger.info(f"New connection from {ip_address}:{port}")
|
|
|
|
|
|
# 检查连接数限制
|
|
|
if len(self._connections) >= self.max_connections:
|
|
|
logger.warning(f"Max connections reached, rejecting {ip_address}:{port}")
|
|
|
error_msg = self._create_error_message(
|
|
|
"server", "unknown", "Server is at maximum capacity"
|
|
|
)
|
|
|
try:
|
|
|
await self._send_message(writer, error_msg)
|
|
|
except Exception:
|
|
|
pass
|
|
|
writer.close()
|
|
|
await writer.wait_closed()
|
|
|
return
|
|
|
|
|
|
user_id: Optional[str] = None
|
|
|
|
|
|
try:
|
|
|
while self._running:
|
|
|
# 读取消息
|
|
|
message = await self._read_message(reader)
|
|
|
if message is None:
|
|
|
break
|
|
|
|
|
|
# 如果是注册消息,记录user_id
|
|
|
if message.msg_type == MessageType.USER_REGISTER:
|
|
|
user_id = message.sender_id
|
|
|
# 创建临时连接对象用于注册
|
|
|
temp_conn = ClientConnection(
|
|
|
user_id=user_id,
|
|
|
reader=reader,
|
|
|
writer=writer,
|
|
|
ip_address=ip_address,
|
|
|
port=port
|
|
|
)
|
|
|
await self._process_register(message, temp_conn)
|
|
|
elif user_id:
|
|
|
# 更新心跳时间
|
|
|
if user_id in self._connections:
|
|
|
self._connections[user_id].last_heartbeat = time.time()
|
|
|
|
|
|
# 处理消息
|
|
|
await self._process_message(message, user_id)
|
|
|
else:
|
|
|
# 未注册的连接发送非注册消息
|
|
|
logger.warning(f"Unregistered client sent message: {message.msg_type}")
|
|
|
error_msg = self._create_error_message(
|
|
|
"server", message.sender_id, "Please register first"
|
|
|
)
|
|
|
await self._send_message(writer, error_msg)
|
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
logger.info(f"Connection cancelled for {ip_address}:{port}")
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error handling client {ip_address}:{port}: {e}")
|
|
|
finally:
|
|
|
# 清理连接
|
|
|
if user_id:
|
|
|
await self._cleanup_connection(user_id)
|
|
|
else:
|
|
|
writer.close()
|
|
|
try:
|
|
|
await writer.wait_closed()
|
|
|
except Exception:
|
|
|
pass
|
|
|
logger.info(f"Connection closed for {ip_address}:{port}")
|
|
|
|
|
|
async def _read_message(self, reader: StreamReader) -> Optional[Message]:
|
|
|
"""
|
|
|
从流中读取消息
|
|
|
|
|
|
Args:
|
|
|
reader: 异步读取器
|
|
|
|
|
|
Returns:
|
|
|
解析后的消息对象,如果连接关闭则返回None
|
|
|
"""
|
|
|
try:
|
|
|
# 读取消息头
|
|
|
header = await reader.readexactly(self._message_handler.HEADER_SIZE)
|
|
|
if not header:
|
|
|
return None
|
|
|
|
|
|
# 解析消息长度
|
|
|
import struct
|
|
|
payload_length, version = struct.unpack(
|
|
|
self._message_handler.HEADER_FORMAT, header
|
|
|
)
|
|
|
|
|
|
# 读取消息体
|
|
|
payload = await reader.readexactly(payload_length)
|
|
|
|
|
|
# 反序列化消息
|
|
|
full_data = header + payload
|
|
|
return self._message_handler.deserialize(full_data)
|
|
|
|
|
|
except asyncio.IncompleteReadError:
|
|
|
return None
|
|
|
except MessageSerializationError as e:
|
|
|
logger.error(f"Failed to deserialize message: {e}")
|
|
|
return None
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error reading message: {e}")
|
|
|
return None
|
|
|
|
|
|
async def _send_message(self, writer: StreamWriter, message: Message) -> bool:
|
|
|
"""
|
|
|
发送消息到客户端
|
|
|
|
|
|
Args:
|
|
|
writer: 异步写入器
|
|
|
message: 要发送的消息
|
|
|
|
|
|
Returns:
|
|
|
发送成功返回True,否则返回False
|
|
|
"""
|
|
|
try:
|
|
|
data = self._message_handler.serialize(message)
|
|
|
writer.write(data)
|
|
|
await writer.drain()
|
|
|
return True
|
|
|
except Exception as e:
|
|
|
logger.error(f"Failed to send message: {e}")
|
|
|
return False
|
|
|
|
|
|
async def _process_message(self, message: Message, sender_user_id: str) -> None:
|
|
|
"""
|
|
|
处理接收到的消息
|
|
|
|
|
|
Args:
|
|
|
message: 接收到的消息
|
|
|
sender_user_id: 发送者用户ID
|
|
|
"""
|
|
|
try:
|
|
|
# 验证消息
|
|
|
self._message_handler.validate(message)
|
|
|
|
|
|
# 根据消息类型处理
|
|
|
if message.msg_type == MessageType.USER_UNREGISTER:
|
|
|
await self._handle_user_unregister_async(message)
|
|
|
elif message.msg_type == MessageType.USER_LIST_REQUEST:
|
|
|
await self._handle_user_list_request_async(message)
|
|
|
elif message.msg_type == MessageType.HEARTBEAT:
|
|
|
await self._handle_heartbeat_async(message)
|
|
|
else:
|
|
|
# 转发消息
|
|
|
await self._relay_message_async(message)
|
|
|
|
|
|
except MessageValidationError as e:
|
|
|
logger.error(f"Message validation failed: {e}")
|
|
|
error_msg = self._create_error_message(
|
|
|
"server", sender_user_id, f"Invalid message: {e}"
|
|
|
)
|
|
|
if sender_user_id in self._connections:
|
|
|
await self._send_message(
|
|
|
self._connections[sender_user_id].writer, error_msg
|
|
|
)
|
|
|
|
|
|
async def _process_register(self, message: Message, conn: ClientConnection) -> None:
|
|
|
"""
|
|
|
处理用户注册
|
|
|
|
|
|
Args:
|
|
|
message: 注册消息
|
|
|
conn: 客户端连接
|
|
|
"""
|
|
|
user_id = message.sender_id
|
|
|
|
|
|
try:
|
|
|
# 解析用户信息
|
|
|
user_info = UserInfo.deserialize(message.payload)
|
|
|
user_info.status = UserStatus.ONLINE
|
|
|
user_info.ip_address = conn.ip_address
|
|
|
user_info.port = conn.port
|
|
|
user_info.last_seen = datetime.now()
|
|
|
|
|
|
# 注册用户
|
|
|
success = await self._register_user_async(user_id, user_info, conn)
|
|
|
|
|
|
if success:
|
|
|
# 发送注册成功响应
|
|
|
ack_msg = self._create_ack_message("server", user_id, "Registration successful")
|
|
|
await self._send_message(conn.writer, ack_msg)
|
|
|
|
|
|
# 投递离线消息
|
|
|
await self._deliver_offline_messages(user_id)
|
|
|
else:
|
|
|
error_msg = self._create_error_message(
|
|
|
"server", user_id, "Registration failed"
|
|
|
)
|
|
|
await self._send_message(conn.writer, error_msg)
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error processing registration for {user_id}: {e}")
|
|
|
error_msg = self._create_error_message(
|
|
|
"server", user_id, f"Registration error: {e}"
|
|
|
)
|
|
|
await self._send_message(conn.writer, error_msg)
|
|
|
|
|
|
async def _cleanup_connection(self, user_id: str) -> None:
|
|
|
"""
|
|
|
清理用户连接
|
|
|
|
|
|
Args:
|
|
|
user_id: 用户ID
|
|
|
"""
|
|
|
async with self._lock:
|
|
|
if user_id in self._connections:
|
|
|
conn = self._connections[user_id]
|
|
|
try:
|
|
|
conn.writer.close()
|
|
|
await conn.writer.wait_closed()
|
|
|
except Exception:
|
|
|
pass
|
|
|
del self._connections[user_id]
|
|
|
|
|
|
# 更新用户状态为离线
|
|
|
if user_id in self._users:
|
|
|
self._users[user_id].status = UserStatus.OFFLINE
|
|
|
self._users[user_id].last_seen = datetime.now()
|
|
|
|
|
|
logger.info(f"User {user_id} disconnected")
|
|
|
|
|
|
async def _heartbeat_checker(self) -> None:
|
|
|
"""心跳检查任务,定期检查连接是否存活"""
|
|
|
while self._running:
|
|
|
try:
|
|
|
await asyncio.sleep(self.config.heartbeat_interval)
|
|
|
|
|
|
async with self._lock:
|
|
|
dead_connections = []
|
|
|
for user_id, conn in self._connections.items():
|
|
|
if not conn.is_alive:
|
|
|
dead_connections.append(user_id)
|
|
|
|
|
|
for user_id in dead_connections:
|
|
|
logger.info(f"Heartbeat timeout for user {user_id}")
|
|
|
await self._cleanup_connection(user_id)
|
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
break
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error in heartbeat checker: {e}")
|
|
|
|
|
|
# ==================== 用户管理功能 (需求 1.7, 2.2, 2.3, 2.5) ====================
|
|
|
|
|
|
def register_user(self, user_id: str, connection: ClientConnection) -> bool:
|
|
|
"""
|
|
|
注册用户(同步版本,用于消息处理器回调)
|
|
|
|
|
|
记录客户端信息并维护在线用户列表 (需求 1.7)
|
|
|
允许用户注册 (需求 2.2)
|
|
|
|
|
|
Args:
|
|
|
user_id: 用户ID
|
|
|
connection: 客户端连接
|
|
|
|
|
|
Returns:
|
|
|
注册成功返回True,否则返回False
|
|
|
"""
|
|
|
# 这是同步版本,实际使用异步版本
|
|
|
return False
|
|
|
|
|
|
async def _register_user_async(
|
|
|
self, user_id: str, user_info: UserInfo, connection: ClientConnection
|
|
|
) -> bool:
|
|
|
"""
|
|
|
注册用户(异步版本)
|
|
|
|
|
|
Args:
|
|
|
user_id: 用户ID
|
|
|
user_info: 用户信息
|
|
|
connection: 客户端连接
|
|
|
|
|
|
Returns:
|
|
|
注册成功返回True,否则返回False
|
|
|
"""
|
|
|
async with self._lock:
|
|
|
# 检查是否已有连接(可能是重连)
|
|
|
if user_id in self._connections:
|
|
|
old_conn = self._connections[user_id]
|
|
|
try:
|
|
|
old_conn.writer.close()
|
|
|
await old_conn.writer.wait_closed()
|
|
|
except Exception:
|
|
|
pass
|
|
|
logger.info(f"User {user_id} reconnected, closing old connection")
|
|
|
|
|
|
# 注册新连接
|
|
|
self._connections[user_id] = connection
|
|
|
self._users[user_id] = user_info
|
|
|
|
|
|
logger.info(f"User {user_id} registered from {connection.ip_address}:{connection.port}")
|
|
|
return True
|
|
|
|
|
|
def unregister_user(self, user_id: str) -> None:
|
|
|
"""
|
|
|
注销用户(同步版本,用于消息处理器回调)
|
|
|
|
|
|
通知服务器更新在线状态 (需求 2.5)
|
|
|
|
|
|
Args:
|
|
|
user_id: 用户ID
|
|
|
"""
|
|
|
# 这是同步版本,实际使用异步版本
|
|
|
pass
|
|
|
|
|
|
async def _unregister_user_async(self, user_id: str) -> None:
|
|
|
"""
|
|
|
注销用户(异步版本)
|
|
|
|
|
|
Args:
|
|
|
user_id: 用户ID
|
|
|
"""
|
|
|
await self._cleanup_connection(user_id)
|
|
|
|
|
|
def get_online_users(self) -> List[UserInfo]:
|
|
|
"""
|
|
|
获取在线用户列表
|
|
|
|
|
|
显示当前所有在线用户 (需求 2.3)
|
|
|
|
|
|
Returns:
|
|
|
在线用户信息列表
|
|
|
"""
|
|
|
online_users = []
|
|
|
for user_id, user_info in self._users.items():
|
|
|
if user_id in self._connections and user_info.status == UserStatus.ONLINE:
|
|
|
online_users.append(user_info)
|
|
|
return online_users
|
|
|
|
|
|
def get_user_info(self, user_id: str) -> Optional[UserInfo]:
|
|
|
"""
|
|
|
获取指定用户信息
|
|
|
|
|
|
Args:
|
|
|
user_id: 用户ID
|
|
|
|
|
|
Returns:
|
|
|
用户信息,如果用户不存在则返回None
|
|
|
"""
|
|
|
return self._users.get(user_id)
|
|
|
|
|
|
def is_user_online(self, user_id: str) -> bool:
|
|
|
"""
|
|
|
检查用户是否在线
|
|
|
|
|
|
Args:
|
|
|
user_id: 用户ID
|
|
|
|
|
|
Returns:
|
|
|
用户在线返回True,否则返回False
|
|
|
"""
|
|
|
return (
|
|
|
user_id in self._connections and
|
|
|
user_id in self._users and
|
|
|
self._users[user_id].status == UserStatus.ONLINE
|
|
|
)
|
|
|
|
|
|
# ==================== 消息处理器回调 ====================
|
|
|
|
|
|
def _handle_user_register(self, message: Message) -> None:
|
|
|
"""处理用户注册消息(同步回调)"""
|
|
|
# 实际处理在 _process_register 中完成
|
|
|
pass
|
|
|
|
|
|
def _handle_user_unregister(self, message: Message) -> None:
|
|
|
"""处理用户注销消息(同步回调)"""
|
|
|
# 实际处理在异步版本中完成
|
|
|
pass
|
|
|
|
|
|
async def _handle_user_unregister_async(self, message: Message) -> None:
|
|
|
"""处理用户注销消息(异步版本)"""
|
|
|
user_id = message.sender_id
|
|
|
await self._unregister_user_async(user_id)
|
|
|
logger.info(f"User {user_id} unregistered")
|
|
|
|
|
|
def _handle_user_list_request(self, message: Message) -> None:
|
|
|
"""处理用户列表请求(同步回调)"""
|
|
|
# 实际处理在异步版本中完成
|
|
|
pass
|
|
|
|
|
|
async def _handle_user_list_request_async(self, message: Message) -> None:
|
|
|
"""处理用户列表请求(异步版本)"""
|
|
|
user_id = message.sender_id
|
|
|
|
|
|
if user_id not in self._connections:
|
|
|
return
|
|
|
|
|
|
# 获取在线用户列表
|
|
|
online_users = self.get_online_users()
|
|
|
|
|
|
# 构建响应
|
|
|
users_data = [user.to_dict() for user in online_users]
|
|
|
payload = json.dumps(users_data).encode('utf-8')
|
|
|
|
|
|
response = Message(
|
|
|
msg_type=MessageType.USER_LIST_RESPONSE,
|
|
|
sender_id="server",
|
|
|
receiver_id=user_id,
|
|
|
timestamp=time.time(),
|
|
|
payload=payload
|
|
|
)
|
|
|
|
|
|
await self._send_message(self._connections[user_id].writer, response)
|
|
|
|
|
|
def _handle_heartbeat(self, message: Message) -> None:
|
|
|
"""处理心跳消息(同步回调)"""
|
|
|
# 心跳时间更新在 _process_message 中完成
|
|
|
pass
|
|
|
|
|
|
async def _handle_heartbeat_async(self, message: Message) -> None:
|
|
|
"""处理心跳消息(异步版本)"""
|
|
|
user_id = message.sender_id
|
|
|
|
|
|
if user_id in self._connections:
|
|
|
self._connections[user_id].last_heartbeat = time.time()
|
|
|
|
|
|
# 发送心跳响应
|
|
|
response = Message(
|
|
|
msg_type=MessageType.HEARTBEAT,
|
|
|
sender_id="server",
|
|
|
receiver_id=user_id,
|
|
|
timestamp=time.time(),
|
|
|
payload=b""
|
|
|
)
|
|
|
await self._send_message(self._connections[user_id].writer, response)
|
|
|
|
|
|
def _handle_relay_message(self, message: Message) -> None:
|
|
|
"""处理需要转发的消息(同步回调)"""
|
|
|
# 实际处理在异步版本中完成
|
|
|
pass
|
|
|
|
|
|
# ==================== 消息转发功能 (需求 8.4, 8.5) ====================
|
|
|
|
|
|
async def relay_message(self, message: Message) -> bool:
|
|
|
"""
|
|
|
转发消息到目标客户端
|
|
|
|
|
|
将消息路由到目标客户端 (需求 8.4)
|
|
|
|
|
|
Args:
|
|
|
message: 要转发的消息
|
|
|
|
|
|
Returns:
|
|
|
转发成功返回True,否则返回False
|
|
|
"""
|
|
|
return await self._relay_message_async(message)
|
|
|
|
|
|
async def _relay_message_async(self, message: Message) -> bool:
|
|
|
"""
|
|
|
转发消息(异步版本)
|
|
|
|
|
|
Args:
|
|
|
message: 要转发的消息
|
|
|
|
|
|
Returns:
|
|
|
转发成功返回True,否则返回False
|
|
|
"""
|
|
|
receiver_id = message.receiver_id
|
|
|
sender_id = message.sender_id
|
|
|
|
|
|
# 记录语音通话相关消息的日志
|
|
|
voice_types = {
|
|
|
MessageType.VOICE_CALL_REQUEST,
|
|
|
MessageType.VOICE_CALL_ACCEPT,
|
|
|
MessageType.VOICE_CALL_REJECT,
|
|
|
MessageType.VOICE_CALL_END,
|
|
|
MessageType.VOICE_DATA,
|
|
|
}
|
|
|
if message.msg_type in voice_types:
|
|
|
logger.info(f"Voice message: {message.msg_type.value} from {sender_id} to {receiver_id}")
|
|
|
|
|
|
# 检查目标用户是否在线
|
|
|
if self.is_user_online(receiver_id):
|
|
|
# 在线,直接转发
|
|
|
conn = self._connections[receiver_id]
|
|
|
success = await self._send_message(conn.writer, message)
|
|
|
|
|
|
if success:
|
|
|
logger.debug(f"Message relayed from {sender_id} to {receiver_id}")
|
|
|
# 发送ACK给发送者
|
|
|
if sender_id in self._connections:
|
|
|
ack = self._create_ack_message(
|
|
|
"server", sender_id,
|
|
|
f"Message delivered to {receiver_id}"
|
|
|
)
|
|
|
await self._send_message(self._connections[sender_id].writer, ack)
|
|
|
return True
|
|
|
else:
|
|
|
logger.error(f"Failed to relay message to {receiver_id}")
|
|
|
return False
|
|
|
else:
|
|
|
# 离线,缓存消息
|
|
|
self.cache_offline_message(receiver_id, message)
|
|
|
logger.info(f"User {receiver_id} offline, message cached")
|
|
|
|
|
|
# 通知发送者消息已缓存
|
|
|
if sender_id in self._connections:
|
|
|
ack = self._create_ack_message(
|
|
|
"server", sender_id,
|
|
|
f"User {receiver_id} is offline, message cached"
|
|
|
)
|
|
|
await self._send_message(self._connections[sender_id].writer, ack)
|
|
|
return True
|
|
|
|
|
|
def cache_offline_message(self, user_id: str, message: Message) -> None:
|
|
|
"""
|
|
|
缓存离线消息
|
|
|
|
|
|
缓存消息并在客户端上线后投递 (需求 8.5)
|
|
|
|
|
|
Args:
|
|
|
user_id: 目标用户ID
|
|
|
message: 要缓存的消息
|
|
|
"""
|
|
|
if user_id not in self._offline_messages:
|
|
|
self._offline_messages[user_id] = deque(maxlen=self._max_offline_messages)
|
|
|
|
|
|
self._offline_messages[user_id].append(message)
|
|
|
logger.debug(f"Cached offline message for user {user_id}, "
|
|
|
f"total: {len(self._offline_messages[user_id])}")
|
|
|
|
|
|
async def _deliver_offline_messages(self, user_id: str) -> None:
|
|
|
"""
|
|
|
投递离线消息
|
|
|
|
|
|
当用户上线时投递所有缓存的消息 (需求 8.5)
|
|
|
|
|
|
Args:
|
|
|
user_id: 用户ID
|
|
|
"""
|
|
|
if user_id not in self._offline_messages:
|
|
|
return
|
|
|
|
|
|
if user_id not in self._connections:
|
|
|
return
|
|
|
|
|
|
messages = self._offline_messages.pop(user_id)
|
|
|
conn = self._connections[user_id]
|
|
|
|
|
|
delivered_count = 0
|
|
|
for message in messages:
|
|
|
success = await self._send_message(conn.writer, message)
|
|
|
if success:
|
|
|
delivered_count += 1
|
|
|
else:
|
|
|
# 如果发送失败,将剩余消息放回缓存
|
|
|
remaining = list(messages)[delivered_count:]
|
|
|
if remaining:
|
|
|
self._offline_messages[user_id] = deque(remaining, maxlen=self._max_offline_messages)
|
|
|
break
|
|
|
|
|
|
if delivered_count > 0:
|
|
|
logger.info(f"Delivered {delivered_count} offline messages to user {user_id}")
|
|
|
|
|
|
def get_offline_message_count(self, user_id: str) -> int:
|
|
|
"""
|
|
|
获取用户的离线消息数量
|
|
|
|
|
|
Args:
|
|
|
user_id: 用户ID
|
|
|
|
|
|
Returns:
|
|
|
离线消息数量
|
|
|
"""
|
|
|
if user_id in self._offline_messages:
|
|
|
return len(self._offline_messages[user_id])
|
|
|
return 0
|
|
|
|
|
|
def clear_offline_messages(self, user_id: str) -> None:
|
|
|
"""
|
|
|
清除用户的离线消息
|
|
|
|
|
|
Args:
|
|
|
user_id: 用户ID
|
|
|
"""
|
|
|
if user_id in self._offline_messages:
|
|
|
del self._offline_messages[user_id]
|
|
|
logger.debug(f"Cleared offline messages for user {user_id}")
|
|
|
|
|
|
# ==================== 辅助方法 ====================
|
|
|
|
|
|
def _create_error_message(self, sender_id: str, receiver_id: str, error: str) -> Message:
|
|
|
"""
|
|
|
创建错误消息
|
|
|
|
|
|
Args:
|
|
|
sender_id: 发送者ID
|
|
|
receiver_id: 接收者ID
|
|
|
error: 错误信息
|
|
|
|
|
|
Returns:
|
|
|
错误消息对象
|
|
|
"""
|
|
|
return Message(
|
|
|
msg_type=MessageType.ERROR,
|
|
|
sender_id=sender_id,
|
|
|
receiver_id=receiver_id,
|
|
|
timestamp=time.time(),
|
|
|
payload=error.encode('utf-8')
|
|
|
)
|
|
|
|
|
|
def _create_ack_message(self, sender_id: str, receiver_id: str, info: str) -> Message:
|
|
|
"""
|
|
|
创建确认消息
|
|
|
|
|
|
Args:
|
|
|
sender_id: 发送者ID
|
|
|
receiver_id: 接收者ID
|
|
|
info: 确认信息
|
|
|
|
|
|
Returns:
|
|
|
确认消息对象
|
|
|
"""
|
|
|
return Message(
|
|
|
msg_type=MessageType.ACK,
|
|
|
sender_id=sender_id,
|
|
|
receiver_id=receiver_id,
|
|
|
timestamp=time.time(),
|
|
|
payload=info.encode('utf-8')
|
|
|
)
|
|
|
|
|
|
# ==================== 属性访问 ====================
|
|
|
|
|
|
@property
|
|
|
def is_running(self) -> bool:
|
|
|
"""服务器是否正在运行"""
|
|
|
return self._running
|
|
|
|
|
|
@property
|
|
|
def connection_count(self) -> int:
|
|
|
"""当前连接数"""
|
|
|
return len(self._connections)
|
|
|
|
|
|
@property
|
|
|
def user_count(self) -> int:
|
|
|
"""注册用户数"""
|
|
|
return len(self._users)
|
|
|
|
|
|
@property
|
|
|
def online_user_count(self) -> int:
|
|
|
"""在线用户数"""
|
|
|
return len(self.get_online_users())
|