You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

863 lines
29 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# P2P Network Communication - Relay Server
"""
中转服务器模块
负责管理客户端连接、用户注册、消息转发和离线消息缓存
需求: 8.1, 8.2, 8.4, 8.5, 1.7, 2.2, 2.3, 2.5
"""
import asyncio
import logging
import time
import json
from asyncio import StreamReader, StreamWriter
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Set
from collections import deque
from shared.models import (
Message, MessageType, UserInfo, UserStatus
)
from shared.message_handler import (
MessageHandler, MessageValidationError, MessageSerializationError
)
from config import ServerConfig
# 设置日志
logger = logging.getLogger(__name__)
class ServerError(Exception):
"""服务器错误基类"""
pass
class ConnectionError(ServerError):
"""连接错误"""
pass
class UserRegistrationError(ServerError):
"""用户注册错误"""
pass
@dataclass
class ClientConnection:
"""客户端连接信息"""
user_id: str
reader: StreamReader
writer: StreamWriter
connected_at: datetime = field(default_factory=datetime.now)
last_heartbeat: float = field(default_factory=time.time)
ip_address: str = ""
port: int = 0
@property
def is_alive(self) -> bool:
"""检查连接是否存活(基于心跳)"""
return time.time() - self.last_heartbeat < 60 # 60秒超时
class RelayServer:
"""
中转服务器
负责:
- 监听指定端口等待客户端连接 (需求 8.1)
- 验证客户端身份并分配会话 (需求 8.2)
- 维护所有活跃连接的状态 (需求 8.3)
- 将消息路由到目标客户端 (需求 8.4)
- 缓存离线消息并在客户端上线后投递 (需求 8.5)
- 记录客户端信息并维护在线用户列表 (需求 1.7)
- 用户注册和注销 (需求 2.2, 2.5)
- 获取在线用户列表 (需求 2.3)
"""
def __init__(self, config: Optional[ServerConfig] = None):
"""
初始化服务器
Args:
config: 服务器配置如果为None则使用默认配置
"""
self.config = config or ServerConfig()
self.host = self.config.host
self.port = self.config.port
self.max_connections = self.config.max_connections
# 连接管理
self._connections: Dict[str, ClientConnection] = {} # user_id -> connection
self._users: Dict[str, UserInfo] = {} # user_id -> user_info
# 离线消息缓存 (user_id -> list of messages)
self._offline_messages: Dict[str, deque] = {}
self._max_offline_messages = 1000 # 每用户最大离线消息数
# 服务器状态
self._server: Optional[asyncio.AbstractServer] = None
self._running = False
self._lock = asyncio.Lock()
# 消息处理器
self._message_handler = MessageHandler()
self._setup_message_handlers()
# 心跳检查任务
self._heartbeat_task: Optional[asyncio.Task] = None
logger.info(f"RelayServer initialized: {self.host}:{self.port}")
def _setup_message_handlers(self) -> None:
"""设置消息处理器"""
self._message_handler.register_handler(
MessageType.USER_REGISTER, self._handle_user_register
)
self._message_handler.register_handler(
MessageType.USER_UNREGISTER, self._handle_user_unregister
)
self._message_handler.register_handler(
MessageType.USER_LIST_REQUEST, self._handle_user_list_request
)
self._message_handler.register_handler(
MessageType.HEARTBEAT, self._handle_heartbeat
)
# 设置默认处理器用于转发其他类型的消息
self._message_handler.set_default_handler(self._handle_relay_message)
async def start(self) -> None:
"""
启动服务器
监听指定端口等待客户端连接 (需求 8.1)
Raises:
ServerError: 服务器启动失败时抛出
"""
if self._running:
logger.warning("Server is already running")
return
try:
self._server = await asyncio.start_server(
self._handle_client,
self.host,
self.port
)
self._running = True
# 启动心跳检查任务
self._heartbeat_task = asyncio.create_task(self._heartbeat_checker())
addr = self._server.sockets[0].getsockname()
logger.info(f"Server started on {addr[0]}:{addr[1]}")
async with self._server:
await self._server.serve_forever()
except OSError as e:
raise ServerError(f"Failed to start server: {e}")
except Exception as e:
logger.error(f"Server error: {e}")
raise ServerError(f"Server error: {e}")
async def stop(self) -> None:
"""
停止服务器
关闭所有连接并释放资源
"""
if not self._running:
logger.warning("Server is not running")
return
self._running = False
# 取消心跳检查任务
if self._heartbeat_task:
self._heartbeat_task.cancel()
try:
await self._heartbeat_task
except asyncio.CancelledError:
pass
# 关闭所有客户端连接
async with self._lock:
for user_id, conn in list(self._connections.items()):
try:
conn.writer.close()
await conn.writer.wait_closed()
except Exception as e:
logger.error(f"Error closing connection for {user_id}: {e}")
self._connections.clear()
# 关闭服务器
if self._server:
self._server.close()
await self._server.wait_closed()
self._server = None
logger.info("Server stopped")
async def _handle_client(self, reader: StreamReader, writer: StreamWriter) -> None:
"""
处理客户端连接
验证客户端身份并分配会话 (需求 8.2)
Args:
reader: 异步读取器
writer: 异步写入器
"""
addr = writer.get_extra_info('peername')
ip_address = addr[0] if addr else "unknown"
port = addr[1] if addr else 0
logger.info(f"New connection from {ip_address}:{port}")
# 检查连接数限制
if len(self._connections) >= self.max_connections:
logger.warning(f"Max connections reached, rejecting {ip_address}:{port}")
error_msg = self._create_error_message(
"server", "unknown", "Server is at maximum capacity"
)
try:
await self._send_message(writer, error_msg)
except Exception:
pass
writer.close()
await writer.wait_closed()
return
user_id: Optional[str] = None
try:
while self._running:
# 读取消息
message = await self._read_message(reader)
if message is None:
break
# 如果是注册消息记录user_id
if message.msg_type == MessageType.USER_REGISTER:
user_id = message.sender_id
# 创建临时连接对象用于注册
temp_conn = ClientConnection(
user_id=user_id,
reader=reader,
writer=writer,
ip_address=ip_address,
port=port
)
await self._process_register(message, temp_conn)
elif user_id:
# 更新心跳时间
if user_id in self._connections:
self._connections[user_id].last_heartbeat = time.time()
# 处理消息
await self._process_message(message, user_id)
else:
# 未注册的连接发送非注册消息
logger.warning(f"Unregistered client sent message: {message.msg_type}")
error_msg = self._create_error_message(
"server", message.sender_id, "Please register first"
)
await self._send_message(writer, error_msg)
except asyncio.CancelledError:
logger.info(f"Connection cancelled for {ip_address}:{port}")
except Exception as e:
logger.error(f"Error handling client {ip_address}:{port}: {e}")
finally:
# 清理连接
if user_id:
await self._cleanup_connection(user_id)
else:
writer.close()
try:
await writer.wait_closed()
except Exception:
pass
logger.info(f"Connection closed for {ip_address}:{port}")
async def _read_message(self, reader: StreamReader) -> Optional[Message]:
"""
从流中读取消息
Args:
reader: 异步读取器
Returns:
解析后的消息对象如果连接关闭则返回None
"""
try:
# 读取消息头
header = await reader.readexactly(self._message_handler.HEADER_SIZE)
if not header:
return None
# 解析消息长度
import struct
payload_length, version = struct.unpack(
self._message_handler.HEADER_FORMAT, header
)
# 读取消息体
payload = await reader.readexactly(payload_length)
# 反序列化消息
full_data = header + payload
return self._message_handler.deserialize(full_data)
except asyncio.IncompleteReadError:
return None
except MessageSerializationError as e:
logger.error(f"Failed to deserialize message: {e}")
return None
except Exception as e:
logger.error(f"Error reading message: {e}")
return None
async def _send_message(self, writer: StreamWriter, message: Message) -> bool:
"""
发送消息到客户端
Args:
writer: 异步写入器
message: 要发送的消息
Returns:
发送成功返回True否则返回False
"""
try:
data = self._message_handler.serialize(message)
writer.write(data)
await writer.drain()
return True
except Exception as e:
logger.error(f"Failed to send message: {e}")
return False
async def _process_message(self, message: Message, sender_user_id: str) -> None:
"""
处理接收到的消息
Args:
message: 接收到的消息
sender_user_id: 发送者用户ID
"""
try:
# 验证消息
self._message_handler.validate(message)
# 根据消息类型处理
if message.msg_type == MessageType.USER_UNREGISTER:
await self._handle_user_unregister_async(message)
elif message.msg_type == MessageType.USER_LIST_REQUEST:
await self._handle_user_list_request_async(message)
elif message.msg_type == MessageType.HEARTBEAT:
await self._handle_heartbeat_async(message)
else:
# 转发消息
await self._relay_message_async(message)
except MessageValidationError as e:
logger.error(f"Message validation failed: {e}")
error_msg = self._create_error_message(
"server", sender_user_id, f"Invalid message: {e}"
)
if sender_user_id in self._connections:
await self._send_message(
self._connections[sender_user_id].writer, error_msg
)
async def _process_register(self, message: Message, conn: ClientConnection) -> None:
"""
处理用户注册
Args:
message: 注册消息
conn: 客户端连接
"""
user_id = message.sender_id
try:
# 解析用户信息
user_info = UserInfo.deserialize(message.payload)
user_info.status = UserStatus.ONLINE
user_info.ip_address = conn.ip_address
user_info.port = conn.port
user_info.last_seen = datetime.now()
# 注册用户
success = await self._register_user_async(user_id, user_info, conn)
if success:
# 发送注册成功响应
ack_msg = self._create_ack_message("server", user_id, "Registration successful")
await self._send_message(conn.writer, ack_msg)
# 投递离线消息
await self._deliver_offline_messages(user_id)
else:
error_msg = self._create_error_message(
"server", user_id, "Registration failed"
)
await self._send_message(conn.writer, error_msg)
except Exception as e:
logger.error(f"Error processing registration for {user_id}: {e}")
error_msg = self._create_error_message(
"server", user_id, f"Registration error: {e}"
)
await self._send_message(conn.writer, error_msg)
async def _cleanup_connection(self, user_id: str) -> None:
"""
清理用户连接
Args:
user_id: 用户ID
"""
async with self._lock:
if user_id in self._connections:
conn = self._connections[user_id]
try:
conn.writer.close()
await conn.writer.wait_closed()
except Exception:
pass
del self._connections[user_id]
# 更新用户状态为离线
if user_id in self._users:
self._users[user_id].status = UserStatus.OFFLINE
self._users[user_id].last_seen = datetime.now()
logger.info(f"User {user_id} disconnected")
async def _heartbeat_checker(self) -> None:
"""心跳检查任务,定期检查连接是否存活"""
while self._running:
try:
await asyncio.sleep(self.config.heartbeat_interval)
async with self._lock:
dead_connections = []
for user_id, conn in self._connections.items():
if not conn.is_alive:
dead_connections.append(user_id)
for user_id in dead_connections:
logger.info(f"Heartbeat timeout for user {user_id}")
await self._cleanup_connection(user_id)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in heartbeat checker: {e}")
# ==================== 用户管理功能 (需求 1.7, 2.2, 2.3, 2.5) ====================
def register_user(self, user_id: str, connection: ClientConnection) -> bool:
"""
注册用户(同步版本,用于消息处理器回调)
记录客户端信息并维护在线用户列表 (需求 1.7)
允许用户注册 (需求 2.2)
Args:
user_id: 用户ID
connection: 客户端连接
Returns:
注册成功返回True否则返回False
"""
# 这是同步版本,实际使用异步版本
return False
async def _register_user_async(
self, user_id: str, user_info: UserInfo, connection: ClientConnection
) -> bool:
"""
注册用户(异步版本)
Args:
user_id: 用户ID
user_info: 用户信息
connection: 客户端连接
Returns:
注册成功返回True否则返回False
"""
async with self._lock:
# 检查是否已有连接(可能是重连)
if user_id in self._connections:
old_conn = self._connections[user_id]
try:
old_conn.writer.close()
await old_conn.writer.wait_closed()
except Exception:
pass
logger.info(f"User {user_id} reconnected, closing old connection")
# 注册新连接
self._connections[user_id] = connection
self._users[user_id] = user_info
logger.info(f"User {user_id} registered from {connection.ip_address}:{connection.port}")
return True
def unregister_user(self, user_id: str) -> None:
"""
注销用户(同步版本,用于消息处理器回调)
通知服务器更新在线状态 (需求 2.5)
Args:
user_id: 用户ID
"""
# 这是同步版本,实际使用异步版本
pass
async def _unregister_user_async(self, user_id: str) -> None:
"""
注销用户(异步版本)
Args:
user_id: 用户ID
"""
await self._cleanup_connection(user_id)
def get_online_users(self) -> List[UserInfo]:
"""
获取在线用户列表
显示当前所有在线用户 (需求 2.3)
Returns:
在线用户信息列表
"""
online_users = []
for user_id, user_info in self._users.items():
if user_id in self._connections and user_info.status == UserStatus.ONLINE:
online_users.append(user_info)
return online_users
def get_user_info(self, user_id: str) -> Optional[UserInfo]:
"""
获取指定用户信息
Args:
user_id: 用户ID
Returns:
用户信息如果用户不存在则返回None
"""
return self._users.get(user_id)
def is_user_online(self, user_id: str) -> bool:
"""
检查用户是否在线
Args:
user_id: 用户ID
Returns:
用户在线返回True否则返回False
"""
return (
user_id in self._connections and
user_id in self._users and
self._users[user_id].status == UserStatus.ONLINE
)
# ==================== 消息处理器回调 ====================
def _handle_user_register(self, message: Message) -> None:
"""处理用户注册消息(同步回调)"""
# 实际处理在 _process_register 中完成
pass
def _handle_user_unregister(self, message: Message) -> None:
"""处理用户注销消息(同步回调)"""
# 实际处理在异步版本中完成
pass
async def _handle_user_unregister_async(self, message: Message) -> None:
"""处理用户注销消息(异步版本)"""
user_id = message.sender_id
await self._unregister_user_async(user_id)
logger.info(f"User {user_id} unregistered")
def _handle_user_list_request(self, message: Message) -> None:
"""处理用户列表请求(同步回调)"""
# 实际处理在异步版本中完成
pass
async def _handle_user_list_request_async(self, message: Message) -> None:
"""处理用户列表请求(异步版本)"""
user_id = message.sender_id
if user_id not in self._connections:
return
# 获取在线用户列表
online_users = self.get_online_users()
# 构建响应
users_data = [user.to_dict() for user in online_users]
payload = json.dumps(users_data).encode('utf-8')
response = Message(
msg_type=MessageType.USER_LIST_RESPONSE,
sender_id="server",
receiver_id=user_id,
timestamp=time.time(),
payload=payload
)
await self._send_message(self._connections[user_id].writer, response)
def _handle_heartbeat(self, message: Message) -> None:
"""处理心跳消息(同步回调)"""
# 心跳时间更新在 _process_message 中完成
pass
async def _handle_heartbeat_async(self, message: Message) -> None:
"""处理心跳消息(异步版本)"""
user_id = message.sender_id
if user_id in self._connections:
self._connections[user_id].last_heartbeat = time.time()
# 发送心跳响应
response = Message(
msg_type=MessageType.HEARTBEAT,
sender_id="server",
receiver_id=user_id,
timestamp=time.time(),
payload=b""
)
await self._send_message(self._connections[user_id].writer, response)
def _handle_relay_message(self, message: Message) -> None:
"""处理需要转发的消息(同步回调)"""
# 实际处理在异步版本中完成
pass
# ==================== 消息转发功能 (需求 8.4, 8.5) ====================
async def relay_message(self, message: Message) -> bool:
"""
转发消息到目标客户端
将消息路由到目标客户端 (需求 8.4)
Args:
message: 要转发的消息
Returns:
转发成功返回True否则返回False
"""
return await self._relay_message_async(message)
async def _relay_message_async(self, message: Message) -> bool:
"""
转发消息(异步版本)
Args:
message: 要转发的消息
Returns:
转发成功返回True否则返回False
"""
receiver_id = message.receiver_id
sender_id = message.sender_id
# 记录语音通话相关消息的日志
voice_types = {
MessageType.VOICE_CALL_REQUEST,
MessageType.VOICE_CALL_ACCEPT,
MessageType.VOICE_CALL_REJECT,
MessageType.VOICE_CALL_END,
MessageType.VOICE_DATA,
}
if message.msg_type in voice_types:
logger.info(f"Voice message: {message.msg_type.value} from {sender_id} to {receiver_id}")
# 检查目标用户是否在线
if self.is_user_online(receiver_id):
# 在线,直接转发
conn = self._connections[receiver_id]
success = await self._send_message(conn.writer, message)
if success:
logger.debug(f"Message relayed from {sender_id} to {receiver_id}")
# 发送ACK给发送者
if sender_id in self._connections:
ack = self._create_ack_message(
"server", sender_id,
f"Message delivered to {receiver_id}"
)
await self._send_message(self._connections[sender_id].writer, ack)
return True
else:
logger.error(f"Failed to relay message to {receiver_id}")
return False
else:
# 离线,缓存消息
self.cache_offline_message(receiver_id, message)
logger.info(f"User {receiver_id} offline, message cached")
# 通知发送者消息已缓存
if sender_id in self._connections:
ack = self._create_ack_message(
"server", sender_id,
f"User {receiver_id} is offline, message cached"
)
await self._send_message(self._connections[sender_id].writer, ack)
return True
def cache_offline_message(self, user_id: str, message: Message) -> None:
"""
缓存离线消息
缓存消息并在客户端上线后投递 (需求 8.5)
Args:
user_id: 目标用户ID
message: 要缓存的消息
"""
if user_id not in self._offline_messages:
self._offline_messages[user_id] = deque(maxlen=self._max_offline_messages)
self._offline_messages[user_id].append(message)
logger.debug(f"Cached offline message for user {user_id}, "
f"total: {len(self._offline_messages[user_id])}")
async def _deliver_offline_messages(self, user_id: str) -> None:
"""
投递离线消息
当用户上线时投递所有缓存的消息 (需求 8.5)
Args:
user_id: 用户ID
"""
if user_id not in self._offline_messages:
return
if user_id not in self._connections:
return
messages = self._offline_messages.pop(user_id)
conn = self._connections[user_id]
delivered_count = 0
for message in messages:
success = await self._send_message(conn.writer, message)
if success:
delivered_count += 1
else:
# 如果发送失败,将剩余消息放回缓存
remaining = list(messages)[delivered_count:]
if remaining:
self._offline_messages[user_id] = deque(remaining, maxlen=self._max_offline_messages)
break
if delivered_count > 0:
logger.info(f"Delivered {delivered_count} offline messages to user {user_id}")
def get_offline_message_count(self, user_id: str) -> int:
"""
获取用户的离线消息数量
Args:
user_id: 用户ID
Returns:
离线消息数量
"""
if user_id in self._offline_messages:
return len(self._offline_messages[user_id])
return 0
def clear_offline_messages(self, user_id: str) -> None:
"""
清除用户的离线消息
Args:
user_id: 用户ID
"""
if user_id in self._offline_messages:
del self._offline_messages[user_id]
logger.debug(f"Cleared offline messages for user {user_id}")
# ==================== 辅助方法 ====================
def _create_error_message(self, sender_id: str, receiver_id: str, error: str) -> Message:
"""
创建错误消息
Args:
sender_id: 发送者ID
receiver_id: 接收者ID
error: 错误信息
Returns:
错误消息对象
"""
return Message(
msg_type=MessageType.ERROR,
sender_id=sender_id,
receiver_id=receiver_id,
timestamp=time.time(),
payload=error.encode('utf-8')
)
def _create_ack_message(self, sender_id: str, receiver_id: str, info: str) -> Message:
"""
创建确认消息
Args:
sender_id: 发送者ID
receiver_id: 接收者ID
info: 确认信息
Returns:
确认消息对象
"""
return Message(
msg_type=MessageType.ACK,
sender_id=sender_id,
receiver_id=receiver_id,
timestamp=time.time(),
payload=info.encode('utf-8')
)
# ==================== 属性访问 ====================
@property
def is_running(self) -> bool:
"""服务器是否正在运行"""
return self._running
@property
def connection_count(self) -> int:
"""当前连接数"""
return len(self._connections)
@property
def user_count(self) -> int:
"""注册用户数"""
return len(self._users)
@property
def online_user_count(self) -> int:
"""在线用户数"""
return len(self.get_online_users())