确保服务器模块测试通过

main
杨博文 3 weeks ago
parent d6b92b67dc
commit 842aa282cc

@ -5,3 +5,11 @@
"""
__version__ = "0.1.0"
from server.relay_server import (
RelayServer,
ClientConnection,
ServerError,
ConnectionError,
UserRegistrationError,
)

@ -0,0 +1,285 @@
# P2P Network Communication - Database Module
"""
数据库模块提供MySQL数据库连接池管理和表初始化功能
使用aiomysql异步驱动实现高性能数据库操作
"""
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Optional, AsyncGenerator
import aiomysql
from config import ServerConfig, load_server_config
logger = logging.getLogger(__name__)
class DatabaseManager:
"""数据库管理器,负责连接池管理和数据库初始化"""
def __init__(self, config: Optional[ServerConfig] = None):
"""
初始化数据库管理器
Args:
config: 服务器配置如果为None则从环境变量加载
"""
self.config = config or load_server_config()
self._pool: Optional[aiomysql.Pool] = None
self._initialized = False
async def initialize(self) -> None:
"""初始化数据库连接池和表结构"""
if self._initialized:
return
await self._create_pool()
await self._create_tables()
self._initialized = True
logger.info("Database initialized successfully")
async def _create_pool(self) -> None:
"""创建数据库连接池"""
try:
self._pool = await aiomysql.create_pool(
host=self.config.db_host,
port=self.config.db_port,
user=self.config.db_user,
password=self.config.db_password,
db=self.config.db_name,
minsize=1,
maxsize=self.config.db_pool_size,
charset='utf8mb4',
autocommit=True,
echo=False,
)
logger.info(f"Database connection pool created: {self.config.db_host}:{self.config.db_port}")
except Exception as e:
logger.error(f"Failed to create database pool: {e}")
raise
async def _create_tables(self) -> None:
"""创建数据库表结构"""
create_users_table = """
CREATE TABLE IF NOT EXISTS users (
user_id VARCHAR(64) PRIMARY KEY,
username VARCHAR(100) UNIQUE NOT NULL,
display_name VARCHAR(200) NOT NULL,
password_hash VARCHAR(255) NOT NULL DEFAULT '',
public_key BLOB,
status VARCHAR(20) DEFAULT 'offline',
ip_address VARCHAR(45) DEFAULT '',
port INT DEFAULT 0,
last_seen TIMESTAMP NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_username (username),
INDEX idx_status (status)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
"""
create_messages_table = """
CREATE TABLE IF NOT EXISTS messages (
message_id VARCHAR(64) PRIMARY KEY,
sender_id VARCHAR(64) NOT NULL,
receiver_id VARCHAR(64) NOT NULL,
content_type VARCHAR(50) NOT NULL,
content TEXT,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
is_read BOOLEAN DEFAULT FALSE,
is_delivered BOOLEAN DEFAULT FALSE,
INDEX idx_sender (sender_id),
INDEX idx_receiver (receiver_id),
INDEX idx_timestamp (timestamp),
INDEX idx_conversation (sender_id, receiver_id, timestamp)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
"""
create_file_transfers_table = """
CREATE TABLE IF NOT EXISTS file_transfers (
transfer_id VARCHAR(64) PRIMARY KEY,
file_name VARCHAR(500) NOT NULL,
file_size BIGINT NOT NULL,
file_hash VARCHAR(128) NOT NULL,
sender_id VARCHAR(64) NOT NULL,
receiver_id VARCHAR(64) NOT NULL,
status VARCHAR(20) NOT NULL DEFAULT 'pending',
progress DECIMAL(5,2) DEFAULT 0,
start_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
end_time TIMESTAMP NULL,
INDEX idx_sender (sender_id),
INDEX idx_receiver (receiver_id),
INDEX idx_status (status)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
"""
create_offline_messages_table = """
CREATE TABLE IF NOT EXISTS offline_messages (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
message_data BLOB NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_user_id (user_id),
INDEX idx_created_at (created_at)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
"""
tables = [
("users", create_users_table),
("messages", create_messages_table),
("file_transfers", create_file_transfers_table),
("offline_messages", create_offline_messages_table),
]
async with self.acquire() as conn:
async with conn.cursor() as cursor:
for table_name, create_sql in tables:
try:
await cursor.execute(create_sql)
logger.info(f"Table '{table_name}' created or already exists")
except Exception as e:
logger.error(f"Failed to create table '{table_name}': {e}")
raise
@asynccontextmanager
async def acquire(self) -> AsyncGenerator[aiomysql.Connection, None]:
"""
获取数据库连接的上下文管理器
Yields:
数据库连接对象
"""
if not self._pool:
raise RuntimeError("Database pool not initialized. Call initialize() first.")
conn = await self._pool.acquire()
try:
yield conn
finally:
self._pool.release(conn)
async def execute(self, query: str, args: tuple = ()) -> int:
"""
执行SQL语句INSERT, UPDATE, DELETE
Args:
query: SQL语句
args: 参数元组
Returns:
受影响的行数
"""
async with self.acquire() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, args)
return cursor.rowcount
async def fetch_one(self, query: str, args: tuple = ()) -> Optional[tuple]:
"""
查询单条记录
Args:
query: SQL语句
args: 参数元组
Returns:
查询结果元组如果没有结果返回None
"""
async with self.acquire() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, args)
return await cursor.fetchone()
async def fetch_all(self, query: str, args: tuple = ()) -> list:
"""
查询多条记录
Args:
query: SQL语句
args: 参数元组
Returns:
查询结果列表
"""
async with self.acquire() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, args)
return await cursor.fetchall()
async def fetch_dict(self, query: str, args: tuple = ()) -> Optional[dict]:
"""
查询单条记录并返回字典
Args:
query: SQL语句
args: 参数元组
Returns:
查询结果字典如果没有结果返回None
"""
async with self.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
await cursor.execute(query, args)
return await cursor.fetchone()
async def fetch_all_dict(self, query: str, args: tuple = ()) -> list:
"""
查询多条记录并返回字典列表
Args:
query: SQL语句
args: 参数元组
Returns:
查询结果字典列表
"""
async with self.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
await cursor.execute(query, args)
return await cursor.fetchall()
async def close(self) -> None:
"""关闭数据库连接池"""
if self._pool:
self._pool.close()
await self._pool.wait_closed()
self._pool = None
self._initialized = False
logger.info("Database connection pool closed")
@property
def is_initialized(self) -> bool:
"""检查数据库是否已初始化"""
return self._initialized
@property
def pool(self) -> Optional[aiomysql.Pool]:
"""获取连接池对象"""
return self._pool
# 全局数据库管理器实例
_db_manager: Optional[DatabaseManager] = None
async def get_database() -> DatabaseManager:
"""
获取全局数据库管理器实例
Returns:
DatabaseManager实例
"""
global _db_manager
if _db_manager is None:
_db_manager = DatabaseManager()
await _db_manager.initialize()
return _db_manager
async def close_database() -> None:
"""关闭全局数据库连接"""
global _db_manager
if _db_manager:
await _db_manager.close()
_db_manager = None

@ -0,0 +1,851 @@
# P2P Network Communication - Relay Server
"""
中转服务器模块
负责管理客户端连接用户注册消息转发和离线消息缓存
需求: 8.1, 8.2, 8.4, 8.5, 1.7, 2.2, 2.3, 2.5
"""
import asyncio
import logging
import time
import json
from asyncio import StreamReader, StreamWriter
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Set
from collections import deque
from shared.models import (
Message, MessageType, UserInfo, UserStatus
)
from shared.message_handler import (
MessageHandler, MessageValidationError, MessageSerializationError
)
from config import ServerConfig
# 设置日志
logger = logging.getLogger(__name__)
class ServerError(Exception):
"""服务器错误基类"""
pass
class ConnectionError(ServerError):
"""连接错误"""
pass
class UserRegistrationError(ServerError):
"""用户注册错误"""
pass
@dataclass
class ClientConnection:
"""客户端连接信息"""
user_id: str
reader: StreamReader
writer: StreamWriter
connected_at: datetime = field(default_factory=datetime.now)
last_heartbeat: float = field(default_factory=time.time)
ip_address: str = ""
port: int = 0
@property
def is_alive(self) -> bool:
"""检查连接是否存活(基于心跳)"""
return time.time() - self.last_heartbeat < 60 # 60秒超时
class RelayServer:
"""
中转服务器
负责:
- 监听指定端口等待客户端连接 (需求 8.1)
- 验证客户端身份并分配会话 (需求 8.2)
- 维护所有活跃连接的状态 (需求 8.3)
- 将消息路由到目标客户端 (需求 8.4)
- 缓存离线消息并在客户端上线后投递 (需求 8.5)
- 记录客户端信息并维护在线用户列表 (需求 1.7)
- 用户注册和注销 (需求 2.2, 2.5)
- 获取在线用户列表 (需求 2.3)
"""
def __init__(self, config: Optional[ServerConfig] = None):
"""
初始化服务器
Args:
config: 服务器配置如果为None则使用默认配置
"""
self.config = config or ServerConfig()
self.host = self.config.host
self.port = self.config.port
self.max_connections = self.config.max_connections
# 连接管理
self._connections: Dict[str, ClientConnection] = {} # user_id -> connection
self._users: Dict[str, UserInfo] = {} # user_id -> user_info
# 离线消息缓存 (user_id -> list of messages)
self._offline_messages: Dict[str, deque] = {}
self._max_offline_messages = 1000 # 每用户最大离线消息数
# 服务器状态
self._server: Optional[asyncio.AbstractServer] = None
self._running = False
self._lock = asyncio.Lock()
# 消息处理器
self._message_handler = MessageHandler()
self._setup_message_handlers()
# 心跳检查任务
self._heartbeat_task: Optional[asyncio.Task] = None
logger.info(f"RelayServer initialized: {self.host}:{self.port}")
def _setup_message_handlers(self) -> None:
"""设置消息处理器"""
self._message_handler.register_handler(
MessageType.USER_REGISTER, self._handle_user_register
)
self._message_handler.register_handler(
MessageType.USER_UNREGISTER, self._handle_user_unregister
)
self._message_handler.register_handler(
MessageType.USER_LIST_REQUEST, self._handle_user_list_request
)
self._message_handler.register_handler(
MessageType.HEARTBEAT, self._handle_heartbeat
)
# 设置默认处理器用于转发其他类型的消息
self._message_handler.set_default_handler(self._handle_relay_message)
async def start(self) -> None:
"""
启动服务器
监听指定端口等待客户端连接 (需求 8.1)
Raises:
ServerError: 服务器启动失败时抛出
"""
if self._running:
logger.warning("Server is already running")
return
try:
self._server = await asyncio.start_server(
self._handle_client,
self.host,
self.port
)
self._running = True
# 启动心跳检查任务
self._heartbeat_task = asyncio.create_task(self._heartbeat_checker())
addr = self._server.sockets[0].getsockname()
logger.info(f"Server started on {addr[0]}:{addr[1]}")
async with self._server:
await self._server.serve_forever()
except OSError as e:
raise ServerError(f"Failed to start server: {e}")
except Exception as e:
logger.error(f"Server error: {e}")
raise ServerError(f"Server error: {e}")
async def stop(self) -> None:
"""
停止服务器
关闭所有连接并释放资源
"""
if not self._running:
logger.warning("Server is not running")
return
self._running = False
# 取消心跳检查任务
if self._heartbeat_task:
self._heartbeat_task.cancel()
try:
await self._heartbeat_task
except asyncio.CancelledError:
pass
# 关闭所有客户端连接
async with self._lock:
for user_id, conn in list(self._connections.items()):
try:
conn.writer.close()
await conn.writer.wait_closed()
except Exception as e:
logger.error(f"Error closing connection for {user_id}: {e}")
self._connections.clear()
# 关闭服务器
if self._server:
self._server.close()
await self._server.wait_closed()
self._server = None
logger.info("Server stopped")
async def _handle_client(self, reader: StreamReader, writer: StreamWriter) -> None:
"""
处理客户端连接
验证客户端身份并分配会话 (需求 8.2)
Args:
reader: 异步读取器
writer: 异步写入器
"""
addr = writer.get_extra_info('peername')
ip_address = addr[0] if addr else "unknown"
port = addr[1] if addr else 0
logger.info(f"New connection from {ip_address}:{port}")
# 检查连接数限制
if len(self._connections) >= self.max_connections:
logger.warning(f"Max connections reached, rejecting {ip_address}:{port}")
error_msg = self._create_error_message(
"server", "unknown", "Server is at maximum capacity"
)
try:
await self._send_message(writer, error_msg)
except Exception:
pass
writer.close()
await writer.wait_closed()
return
user_id: Optional[str] = None
try:
while self._running:
# 读取消息
message = await self._read_message(reader)
if message is None:
break
# 如果是注册消息记录user_id
if message.msg_type == MessageType.USER_REGISTER:
user_id = message.sender_id
# 创建临时连接对象用于注册
temp_conn = ClientConnection(
user_id=user_id,
reader=reader,
writer=writer,
ip_address=ip_address,
port=port
)
await self._process_register(message, temp_conn)
elif user_id:
# 更新心跳时间
if user_id in self._connections:
self._connections[user_id].last_heartbeat = time.time()
# 处理消息
await self._process_message(message, user_id)
else:
# 未注册的连接发送非注册消息
logger.warning(f"Unregistered client sent message: {message.msg_type}")
error_msg = self._create_error_message(
"server", message.sender_id, "Please register first"
)
await self._send_message(writer, error_msg)
except asyncio.CancelledError:
logger.info(f"Connection cancelled for {ip_address}:{port}")
except Exception as e:
logger.error(f"Error handling client {ip_address}:{port}: {e}")
finally:
# 清理连接
if user_id:
await self._cleanup_connection(user_id)
else:
writer.close()
try:
await writer.wait_closed()
except Exception:
pass
logger.info(f"Connection closed for {ip_address}:{port}")
async def _read_message(self, reader: StreamReader) -> Optional[Message]:
"""
从流中读取消息
Args:
reader: 异步读取器
Returns:
解析后的消息对象如果连接关闭则返回None
"""
try:
# 读取消息头
header = await reader.readexactly(self._message_handler.HEADER_SIZE)
if not header:
return None
# 解析消息长度
import struct
payload_length, version = struct.unpack(
self._message_handler.HEADER_FORMAT, header
)
# 读取消息体
payload = await reader.readexactly(payload_length)
# 反序列化消息
full_data = header + payload
return self._message_handler.deserialize(full_data)
except asyncio.IncompleteReadError:
return None
except MessageSerializationError as e:
logger.error(f"Failed to deserialize message: {e}")
return None
except Exception as e:
logger.error(f"Error reading message: {e}")
return None
async def _send_message(self, writer: StreamWriter, message: Message) -> bool:
"""
发送消息到客户端
Args:
writer: 异步写入器
message: 要发送的消息
Returns:
发送成功返回True否则返回False
"""
try:
data = self._message_handler.serialize(message)
writer.write(data)
await writer.drain()
return True
except Exception as e:
logger.error(f"Failed to send message: {e}")
return False
async def _process_message(self, message: Message, sender_user_id: str) -> None:
"""
处理接收到的消息
Args:
message: 接收到的消息
sender_user_id: 发送者用户ID
"""
try:
# 验证消息
self._message_handler.validate(message)
# 根据消息类型处理
if message.msg_type == MessageType.USER_UNREGISTER:
await self._handle_user_unregister_async(message)
elif message.msg_type == MessageType.USER_LIST_REQUEST:
await self._handle_user_list_request_async(message)
elif message.msg_type == MessageType.HEARTBEAT:
await self._handle_heartbeat_async(message)
else:
# 转发消息
await self._relay_message_async(message)
except MessageValidationError as e:
logger.error(f"Message validation failed: {e}")
error_msg = self._create_error_message(
"server", sender_user_id, f"Invalid message: {e}"
)
if sender_user_id in self._connections:
await self._send_message(
self._connections[sender_user_id].writer, error_msg
)
async def _process_register(self, message: Message, conn: ClientConnection) -> None:
"""
处理用户注册
Args:
message: 注册消息
conn: 客户端连接
"""
user_id = message.sender_id
try:
# 解析用户信息
user_info = UserInfo.deserialize(message.payload)
user_info.status = UserStatus.ONLINE
user_info.ip_address = conn.ip_address
user_info.port = conn.port
user_info.last_seen = datetime.now()
# 注册用户
success = await self._register_user_async(user_id, user_info, conn)
if success:
# 发送注册成功响应
ack_msg = self._create_ack_message("server", user_id, "Registration successful")
await self._send_message(conn.writer, ack_msg)
# 投递离线消息
await self._deliver_offline_messages(user_id)
else:
error_msg = self._create_error_message(
"server", user_id, "Registration failed"
)
await self._send_message(conn.writer, error_msg)
except Exception as e:
logger.error(f"Error processing registration for {user_id}: {e}")
error_msg = self._create_error_message(
"server", user_id, f"Registration error: {e}"
)
await self._send_message(conn.writer, error_msg)
async def _cleanup_connection(self, user_id: str) -> None:
"""
清理用户连接
Args:
user_id: 用户ID
"""
async with self._lock:
if user_id in self._connections:
conn = self._connections[user_id]
try:
conn.writer.close()
await conn.writer.wait_closed()
except Exception:
pass
del self._connections[user_id]
# 更新用户状态为离线
if user_id in self._users:
self._users[user_id].status = UserStatus.OFFLINE
self._users[user_id].last_seen = datetime.now()
logger.info(f"User {user_id} disconnected")
async def _heartbeat_checker(self) -> None:
"""心跳检查任务,定期检查连接是否存活"""
while self._running:
try:
await asyncio.sleep(self.config.heartbeat_interval)
async with self._lock:
dead_connections = []
for user_id, conn in self._connections.items():
if not conn.is_alive:
dead_connections.append(user_id)
for user_id in dead_connections:
logger.info(f"Heartbeat timeout for user {user_id}")
await self._cleanup_connection(user_id)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in heartbeat checker: {e}")
# ==================== 用户管理功能 (需求 1.7, 2.2, 2.3, 2.5) ====================
def register_user(self, user_id: str, connection: ClientConnection) -> bool:
"""
注册用户同步版本用于消息处理器回调
记录客户端信息并维护在线用户列表 (需求 1.7)
允许用户注册 (需求 2.2)
Args:
user_id: 用户ID
connection: 客户端连接
Returns:
注册成功返回True否则返回False
"""
# 这是同步版本,实际使用异步版本
return False
async def _register_user_async(
self, user_id: str, user_info: UserInfo, connection: ClientConnection
) -> bool:
"""
注册用户异步版本
Args:
user_id: 用户ID
user_info: 用户信息
connection: 客户端连接
Returns:
注册成功返回True否则返回False
"""
async with self._lock:
# 检查是否已有连接(可能是重连)
if user_id in self._connections:
old_conn = self._connections[user_id]
try:
old_conn.writer.close()
await old_conn.writer.wait_closed()
except Exception:
pass
logger.info(f"User {user_id} reconnected, closing old connection")
# 注册新连接
self._connections[user_id] = connection
self._users[user_id] = user_info
logger.info(f"User {user_id} registered from {connection.ip_address}:{connection.port}")
return True
def unregister_user(self, user_id: str) -> None:
"""
注销用户同步版本用于消息处理器回调
通知服务器更新在线状态 (需求 2.5)
Args:
user_id: 用户ID
"""
# 这是同步版本,实际使用异步版本
pass
async def _unregister_user_async(self, user_id: str) -> None:
"""
注销用户异步版本
Args:
user_id: 用户ID
"""
await self._cleanup_connection(user_id)
def get_online_users(self) -> List[UserInfo]:
"""
获取在线用户列表
显示当前所有在线用户 (需求 2.3)
Returns:
在线用户信息列表
"""
online_users = []
for user_id, user_info in self._users.items():
if user_id in self._connections and user_info.status == UserStatus.ONLINE:
online_users.append(user_info)
return online_users
def get_user_info(self, user_id: str) -> Optional[UserInfo]:
"""
获取指定用户信息
Args:
user_id: 用户ID
Returns:
用户信息如果用户不存在则返回None
"""
return self._users.get(user_id)
def is_user_online(self, user_id: str) -> bool:
"""
检查用户是否在线
Args:
user_id: 用户ID
Returns:
用户在线返回True否则返回False
"""
return (
user_id in self._connections and
user_id in self._users and
self._users[user_id].status == UserStatus.ONLINE
)
# ==================== 消息处理器回调 ====================
def _handle_user_register(self, message: Message) -> None:
"""处理用户注册消息(同步回调)"""
# 实际处理在 _process_register 中完成
pass
def _handle_user_unregister(self, message: Message) -> None:
"""处理用户注销消息(同步回调)"""
# 实际处理在异步版本中完成
pass
async def _handle_user_unregister_async(self, message: Message) -> None:
"""处理用户注销消息(异步版本)"""
user_id = message.sender_id
await self._unregister_user_async(user_id)
logger.info(f"User {user_id} unregistered")
def _handle_user_list_request(self, message: Message) -> None:
"""处理用户列表请求(同步回调)"""
# 实际处理在异步版本中完成
pass
async def _handle_user_list_request_async(self, message: Message) -> None:
"""处理用户列表请求(异步版本)"""
user_id = message.sender_id
if user_id not in self._connections:
return
# 获取在线用户列表
online_users = self.get_online_users()
# 构建响应
users_data = [user.to_dict() for user in online_users]
payload = json.dumps(users_data).encode('utf-8')
response = Message(
msg_type=MessageType.USER_LIST_RESPONSE,
sender_id="server",
receiver_id=user_id,
timestamp=time.time(),
payload=payload
)
await self._send_message(self._connections[user_id].writer, response)
def _handle_heartbeat(self, message: Message) -> None:
"""处理心跳消息(同步回调)"""
# 心跳时间更新在 _process_message 中完成
pass
async def _handle_heartbeat_async(self, message: Message) -> None:
"""处理心跳消息(异步版本)"""
user_id = message.sender_id
if user_id in self._connections:
self._connections[user_id].last_heartbeat = time.time()
# 发送心跳响应
response = Message(
msg_type=MessageType.HEARTBEAT,
sender_id="server",
receiver_id=user_id,
timestamp=time.time(),
payload=b""
)
await self._send_message(self._connections[user_id].writer, response)
def _handle_relay_message(self, message: Message) -> None:
"""处理需要转发的消息(同步回调)"""
# 实际处理在异步版本中完成
pass
# ==================== 消息转发功能 (需求 8.4, 8.5) ====================
async def relay_message(self, message: Message) -> bool:
"""
转发消息到目标客户端
将消息路由到目标客户端 (需求 8.4)
Args:
message: 要转发的消息
Returns:
转发成功返回True否则返回False
"""
return await self._relay_message_async(message)
async def _relay_message_async(self, message: Message) -> bool:
"""
转发消息异步版本
Args:
message: 要转发的消息
Returns:
转发成功返回True否则返回False
"""
receiver_id = message.receiver_id
sender_id = message.sender_id
# 检查目标用户是否在线
if self.is_user_online(receiver_id):
# 在线,直接转发
conn = self._connections[receiver_id]
success = await self._send_message(conn.writer, message)
if success:
logger.debug(f"Message relayed from {sender_id} to {receiver_id}")
# 发送ACK给发送者
if sender_id in self._connections:
ack = self._create_ack_message(
"server", sender_id,
f"Message delivered to {receiver_id}"
)
await self._send_message(self._connections[sender_id].writer, ack)
return True
else:
logger.error(f"Failed to relay message to {receiver_id}")
return False
else:
# 离线,缓存消息
self.cache_offline_message(receiver_id, message)
logger.info(f"User {receiver_id} offline, message cached")
# 通知发送者消息已缓存
if sender_id in self._connections:
ack = self._create_ack_message(
"server", sender_id,
f"User {receiver_id} is offline, message cached"
)
await self._send_message(self._connections[sender_id].writer, ack)
return True
def cache_offline_message(self, user_id: str, message: Message) -> None:
"""
缓存离线消息
缓存消息并在客户端上线后投递 (需求 8.5)
Args:
user_id: 目标用户ID
message: 要缓存的消息
"""
if user_id not in self._offline_messages:
self._offline_messages[user_id] = deque(maxlen=self._max_offline_messages)
self._offline_messages[user_id].append(message)
logger.debug(f"Cached offline message for user {user_id}, "
f"total: {len(self._offline_messages[user_id])}")
async def _deliver_offline_messages(self, user_id: str) -> None:
"""
投递离线消息
当用户上线时投递所有缓存的消息 (需求 8.5)
Args:
user_id: 用户ID
"""
if user_id not in self._offline_messages:
return
if user_id not in self._connections:
return
messages = self._offline_messages.pop(user_id)
conn = self._connections[user_id]
delivered_count = 0
for message in messages:
success = await self._send_message(conn.writer, message)
if success:
delivered_count += 1
else:
# 如果发送失败,将剩余消息放回缓存
remaining = list(messages)[delivered_count:]
if remaining:
self._offline_messages[user_id] = deque(remaining, maxlen=self._max_offline_messages)
break
if delivered_count > 0:
logger.info(f"Delivered {delivered_count} offline messages to user {user_id}")
def get_offline_message_count(self, user_id: str) -> int:
"""
获取用户的离线消息数量
Args:
user_id: 用户ID
Returns:
离线消息数量
"""
if user_id in self._offline_messages:
return len(self._offline_messages[user_id])
return 0
def clear_offline_messages(self, user_id: str) -> None:
"""
清除用户的离线消息
Args:
user_id: 用户ID
"""
if user_id in self._offline_messages:
del self._offline_messages[user_id]
logger.debug(f"Cleared offline messages for user {user_id}")
# ==================== 辅助方法 ====================
def _create_error_message(self, sender_id: str, receiver_id: str, error: str) -> Message:
"""
创建错误消息
Args:
sender_id: 发送者ID
receiver_id: 接收者ID
error: 错误信息
Returns:
错误消息对象
"""
return Message(
msg_type=MessageType.ERROR,
sender_id=sender_id,
receiver_id=receiver_id,
timestamp=time.time(),
payload=error.encode('utf-8')
)
def _create_ack_message(self, sender_id: str, receiver_id: str, info: str) -> Message:
"""
创建确认消息
Args:
sender_id: 发送者ID
receiver_id: 接收者ID
info: 确认信息
Returns:
确认消息对象
"""
return Message(
msg_type=MessageType.ACK,
sender_id=sender_id,
receiver_id=receiver_id,
timestamp=time.time(),
payload=info.encode('utf-8')
)
# ==================== 属性访问 ====================
@property
def is_running(self) -> bool:
"""服务器是否正在运行"""
return self._running
@property
def connection_count(self) -> int:
"""当前连接数"""
return len(self._connections)
@property
def user_count(self) -> int:
"""注册用户数"""
return len(self._users)
@property
def online_user_count(self) -> int:
"""在线用户数"""
return len(self.get_online_users())

@ -0,0 +1,905 @@
# P2P Network Communication - Data Access Layer
"""
数据访问层提供用户消息文件传输记录的CRUD操作
"""
import json
import logging
from datetime import datetime
from typing import Optional, List
from shared.models import (
UserInfo, UserStatus, ChatMessage, MessageType,
FileTransferRecord, TransferStatus, Message
)
from server.database import DatabaseManager
logger = logging.getLogger(__name__)
class UserRepository:
"""用户数据访问层"""
def __init__(self, db: DatabaseManager):
"""
初始化用户仓库
Args:
db: 数据库管理器实例
"""
self.db = db
async def create(self, user: UserInfo, password_hash: str = "") -> bool:
"""
创建新用户
Args:
user: 用户信息对象
password_hash: 密码哈希值
Returns:
创建是否成功
"""
query = """
INSERT INTO users (user_id, username, display_name, password_hash,
public_key, status, ip_address, port, last_seen)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
try:
await self.db.execute(query, (
user.user_id,
user.username,
user.display_name,
password_hash,
user.public_key if user.public_key else None,
user.status.value,
user.ip_address,
user.port,
user.last_seen,
))
logger.info(f"User created: {user.user_id}")
return True
except Exception as e:
logger.error(f"Failed to create user {user.user_id}: {e}")
return False
async def get_by_id(self, user_id: str) -> Optional[UserInfo]:
"""
根据用户ID获取用户信息
Args:
user_id: 用户ID
Returns:
用户信息对象如果不存在返回None
"""
query = """
SELECT user_id, username, display_name, public_key, status,
ip_address, port, last_seen
FROM users WHERE user_id = %s
"""
row = await self.db.fetch_dict(query, (user_id,))
if row:
return self._row_to_user(row)
return None
async def get_by_username(self, username: str) -> Optional[UserInfo]:
"""
根据用户名获取用户信息
Args:
username: 用户名
Returns:
用户信息对象如果不存在返回None
"""
query = """
SELECT user_id, username, display_name, public_key, status,
ip_address, port, last_seen
FROM users WHERE username = %s
"""
row = await self.db.fetch_dict(query, (username,))
if row:
return self._row_to_user(row)
return None
async def update(self, user: UserInfo) -> bool:
"""
更新用户信息
Args:
user: 用户信息对象
Returns:
更新是否成功
"""
query = """
UPDATE users SET display_name = %s, public_key = %s, status = %s,
ip_address = %s, port = %s, last_seen = %s
WHERE user_id = %s
"""
try:
rows = await self.db.execute(query, (
user.display_name,
user.public_key if user.public_key else None,
user.status.value,
user.ip_address,
user.port,
user.last_seen,
user.user_id,
))
return rows > 0
except Exception as e:
logger.error(f"Failed to update user {user.user_id}: {e}")
return False
async def update_status(self, user_id: str, status: UserStatus,
ip_address: str = "", port: int = 0) -> bool:
"""
更新用户在线状态
Args:
user_id: 用户ID
status: 用户状态
ip_address: IP地址
port: 端口号
Returns:
更新是否成功
"""
query = """
UPDATE users SET status = %s, ip_address = %s, port = %s, last_seen = %s
WHERE user_id = %s
"""
try:
rows = await self.db.execute(query, (
status.value,
ip_address,
port,
datetime.now(),
user_id,
))
return rows > 0
except Exception as e:
logger.error(f"Failed to update user status {user_id}: {e}")
return False
async def delete(self, user_id: str) -> bool:
"""
删除用户
Args:
user_id: 用户ID
Returns:
删除是否成功
"""
query = "DELETE FROM users WHERE user_id = %s"
try:
rows = await self.db.execute(query, (user_id,))
return rows > 0
except Exception as e:
logger.error(f"Failed to delete user {user_id}: {e}")
return False
async def get_online_users(self) -> List[UserInfo]:
"""
获取所有在线用户
Returns:
在线用户列表
"""
query = """
SELECT user_id, username, display_name, public_key, status,
ip_address, port, last_seen
FROM users WHERE status = 'online'
"""
rows = await self.db.fetch_all_dict(query)
return [self._row_to_user(row) for row in rows]
async def get_all_users(self) -> List[UserInfo]:
"""
获取所有用户
Returns:
用户列表
"""
query = """
SELECT user_id, username, display_name, public_key, status,
ip_address, port, last_seen
FROM users ORDER BY username
"""
rows = await self.db.fetch_all_dict(query)
return [self._row_to_user(row) for row in rows]
async def exists(self, user_id: str) -> bool:
"""
检查用户是否存在
Args:
user_id: 用户ID
Returns:
用户是否存在
"""
query = "SELECT 1 FROM users WHERE user_id = %s"
row = await self.db.fetch_one(query, (user_id,))
return row is not None
async def username_exists(self, username: str) -> bool:
"""
检查用户名是否已存在
Args:
username: 用户名
Returns:
用户名是否已存在
"""
query = "SELECT 1 FROM users WHERE username = %s"
row = await self.db.fetch_one(query, (username,))
return row is not None
def _row_to_user(self, row: dict) -> UserInfo:
"""将数据库行转换为UserInfo对象"""
return UserInfo(
user_id=row['user_id'],
username=row['username'],
display_name=row['display_name'],
status=UserStatus(row['status']) if row['status'] else UserStatus.OFFLINE,
ip_address=row['ip_address'] or "",
port=row['port'] or 0,
last_seen=row['last_seen'],
public_key=row['public_key'] or bytes(),
)
class MessageRepository:
"""消息数据访问层"""
def __init__(self, db: DatabaseManager):
"""
初始化消息仓库
Args:
db: 数据库管理器实例
"""
self.db = db
async def create(self, message: ChatMessage) -> bool:
"""
保存聊天消息
Args:
message: 聊天消息对象
Returns:
保存是否成功
"""
query = """
INSERT INTO messages (message_id, sender_id, receiver_id, content_type,
content, timestamp, is_read, is_delivered)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
"""
try:
await self.db.execute(query, (
message.message_id,
message.sender_id,
message.receiver_id,
message.content_type.value,
message.content,
message.timestamp,
message.is_read,
message.is_sent,
))
logger.debug(f"Message saved: {message.message_id}")
return True
except Exception as e:
logger.error(f"Failed to save message {message.message_id}: {e}")
return False
async def get_by_id(self, message_id: str) -> Optional[ChatMessage]:
"""
根据消息ID获取消息
Args:
message_id: 消息ID
Returns:
聊天消息对象如果不存在返回None
"""
query = """
SELECT message_id, sender_id, receiver_id, content_type, content,
timestamp, is_read, is_delivered
FROM messages WHERE message_id = %s
"""
row = await self.db.fetch_dict(query, (message_id,))
if row:
return self._row_to_message(row)
return None
async def get_conversation(self, user1_id: str, user2_id: str,
limit: int = 100, offset: int = 0) -> List[ChatMessage]:
"""
获取两个用户之间的聊天历史
Args:
user1_id: 用户1的ID
user2_id: 用户2的ID
limit: 返回消息数量限制
offset: 偏移量
Returns:
聊天消息列表按时间升序排列
"""
query = """
SELECT message_id, sender_id, receiver_id, content_type, content,
timestamp, is_read, is_delivered
FROM messages
WHERE (sender_id = %s AND receiver_id = %s)
OR (sender_id = %s AND receiver_id = %s)
ORDER BY timestamp ASC
LIMIT %s OFFSET %s
"""
rows = await self.db.fetch_all_dict(query, (
user1_id, user2_id, user2_id, user1_id, limit, offset
))
return [self._row_to_message(row) for row in rows]
async def get_user_messages(self, user_id: str, limit: int = 100,
offset: int = 0) -> List[ChatMessage]:
"""
获取用户的所有消息发送和接收
Args:
user_id: 用户ID
limit: 返回消息数量限制
offset: 偏移量
Returns:
聊天消息列表按时间升序排列
"""
query = """
SELECT message_id, sender_id, receiver_id, content_type, content,
timestamp, is_read, is_delivered
FROM messages
WHERE sender_id = %s OR receiver_id = %s
ORDER BY timestamp ASC
LIMIT %s OFFSET %s
"""
rows = await self.db.fetch_all_dict(query, (user_id, user_id, limit, offset))
return [self._row_to_message(row) for row in rows]
async def get_unread_messages(self, user_id: str) -> List[ChatMessage]:
"""
获取用户的未读消息
Args:
user_id: 用户ID
Returns:
未读消息列表
"""
query = """
SELECT message_id, sender_id, receiver_id, content_type, content,
timestamp, is_read, is_delivered
FROM messages
WHERE receiver_id = %s AND is_read = FALSE
ORDER BY timestamp ASC
"""
rows = await self.db.fetch_all_dict(query, (user_id,))
return [self._row_to_message(row) for row in rows]
async def mark_as_read(self, message_id: str) -> bool:
"""
标记消息为已读
Args:
message_id: 消息ID
Returns:
更新是否成功
"""
query = "UPDATE messages SET is_read = TRUE WHERE message_id = %s"
try:
rows = await self.db.execute(query, (message_id,))
return rows > 0
except Exception as e:
logger.error(f"Failed to mark message as read {message_id}: {e}")
return False
async def mark_conversation_as_read(self, user_id: str, sender_id: str) -> int:
"""
标记与某用户的所有消息为已读
Args:
user_id: 接收者用户ID
sender_id: 发送者用户ID
Returns:
更新的消息数量
"""
query = """
UPDATE messages SET is_read = TRUE
WHERE receiver_id = %s AND sender_id = %s AND is_read = FALSE
"""
try:
return await self.db.execute(query, (user_id, sender_id))
except Exception as e:
logger.error(f"Failed to mark conversation as read: {e}")
return 0
async def mark_as_delivered(self, message_id: str) -> bool:
"""
标记消息为已投递
Args:
message_id: 消息ID
Returns:
更新是否成功
"""
query = "UPDATE messages SET is_delivered = TRUE WHERE message_id = %s"
try:
rows = await self.db.execute(query, (message_id,))
return rows > 0
except Exception as e:
logger.error(f"Failed to mark message as delivered {message_id}: {e}")
return False
async def delete(self, message_id: str) -> bool:
"""
删除消息
Args:
message_id: 消息ID
Returns:
删除是否成功
"""
query = "DELETE FROM messages WHERE message_id = %s"
try:
rows = await self.db.execute(query, (message_id,))
return rows > 0
except Exception as e:
logger.error(f"Failed to delete message {message_id}: {e}")
return False
async def delete_conversation(self, user1_id: str, user2_id: str) -> int:
"""
删除两个用户之间的所有消息
Args:
user1_id: 用户1的ID
user2_id: 用户2的ID
Returns:
删除的消息数量
"""
query = """
DELETE FROM messages
WHERE (sender_id = %s AND receiver_id = %s)
OR (sender_id = %s AND receiver_id = %s)
"""
try:
return await self.db.execute(query, (user1_id, user2_id, user2_id, user1_id))
except Exception as e:
logger.error(f"Failed to delete conversation: {e}")
return 0
def _row_to_message(self, row: dict) -> ChatMessage:
"""将数据库行转换为ChatMessage对象"""
return ChatMessage(
message_id=row['message_id'],
sender_id=row['sender_id'],
receiver_id=row['receiver_id'],
content_type=MessageType(row['content_type']),
content=row['content'] or "",
timestamp=row['timestamp'],
is_read=bool(row['is_read']),
is_sent=bool(row['is_delivered']),
)
class FileTransferRepository:
"""文件传输记录数据访问层"""
def __init__(self, db: DatabaseManager):
"""
初始化文件传输仓库
Args:
db: 数据库管理器实例
"""
self.db = db
async def create(self, record: FileTransferRecord) -> bool:
"""
创建文件传输记录
Args:
record: 文件传输记录对象
Returns:
创建是否成功
"""
query = """
INSERT INTO file_transfers (transfer_id, file_name, file_size, file_hash,
sender_id, receiver_id, status, progress,
start_time, end_time)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
try:
await self.db.execute(query, (
record.transfer_id,
record.file_name,
record.file_size,
record.file_hash,
record.sender_id,
record.receiver_id,
record.status.value,
record.progress,
record.start_time,
record.end_time,
))
logger.info(f"File transfer record created: {record.transfer_id}")
return True
except Exception as e:
logger.error(f"Failed to create file transfer record: {e}")
return False
async def get_by_id(self, transfer_id: str) -> Optional[FileTransferRecord]:
"""
根据传输ID获取记录
Args:
transfer_id: 传输ID
Returns:
文件传输记录对象如果不存在返回None
"""
query = """
SELECT transfer_id, file_name, file_size, file_hash, sender_id,
receiver_id, status, progress, start_time, end_time
FROM file_transfers WHERE transfer_id = %s
"""
row = await self.db.fetch_dict(query, (transfer_id,))
if row:
return self._row_to_record(row)
return None
async def update_status(self, transfer_id: str, status: TransferStatus,
progress: float = None) -> bool:
"""
更新传输状态
Args:
transfer_id: 传输ID
status: 新状态
progress: 进度可选
Returns:
更新是否成功
"""
if progress is not None:
query = """
UPDATE file_transfers SET status = %s, progress = %s
WHERE transfer_id = %s
"""
args = (status.value, progress, transfer_id)
else:
query = "UPDATE file_transfers SET status = %s WHERE transfer_id = %s"
args = (status.value, transfer_id)
try:
rows = await self.db.execute(query, args)
return rows > 0
except Exception as e:
logger.error(f"Failed to update transfer status: {e}")
return False
async def update_progress(self, transfer_id: str, progress: float) -> bool:
"""
更新传输进度
Args:
transfer_id: 传输ID
progress: 进度百分比0-100
Returns:
更新是否成功
"""
query = "UPDATE file_transfers SET progress = %s WHERE transfer_id = %s"
try:
rows = await self.db.execute(query, (progress, transfer_id))
return rows > 0
except Exception as e:
logger.error(f"Failed to update transfer progress: {e}")
return False
async def complete(self, transfer_id: str) -> bool:
"""
标记传输完成
Args:
transfer_id: 传输ID
Returns:
更新是否成功
"""
query = """
UPDATE file_transfers SET status = %s, progress = 100, end_time = %s
WHERE transfer_id = %s
"""
try:
rows = await self.db.execute(query, (
TransferStatus.COMPLETED.value, datetime.now(), transfer_id
))
return rows > 0
except Exception as e:
logger.error(f"Failed to complete transfer: {e}")
return False
async def fail(self, transfer_id: str) -> bool:
"""
标记传输失败
Args:
transfer_id: 传输ID
Returns:
更新是否成功
"""
query = """
UPDATE file_transfers SET status = %s, end_time = %s
WHERE transfer_id = %s
"""
try:
rows = await self.db.execute(query, (
TransferStatus.FAILED.value, datetime.now(), transfer_id
))
return rows > 0
except Exception as e:
logger.error(f"Failed to mark transfer as failed: {e}")
return False
async def get_user_transfers(self, user_id: str,
status: TransferStatus = None) -> List[FileTransferRecord]:
"""
获取用户的文件传输记录
Args:
user_id: 用户ID
status: 过滤状态可选
Returns:
文件传输记录列表
"""
if status:
query = """
SELECT transfer_id, file_name, file_size, file_hash, sender_id,
receiver_id, status, progress, start_time, end_time
FROM file_transfers
WHERE (sender_id = %s OR receiver_id = %s) AND status = %s
ORDER BY start_time DESC
"""
rows = await self.db.fetch_all_dict(query, (user_id, user_id, status.value))
else:
query = """
SELECT transfer_id, file_name, file_size, file_hash, sender_id,
receiver_id, status, progress, start_time, end_time
FROM file_transfers
WHERE sender_id = %s OR receiver_id = %s
ORDER BY start_time DESC
"""
rows = await self.db.fetch_all_dict(query, (user_id, user_id))
return [self._row_to_record(row) for row in rows]
async def get_pending_transfers(self, user_id: str) -> List[FileTransferRecord]:
"""
获取用户待处理的传输
Args:
user_id: 用户ID
Returns:
待处理的文件传输记录列表
"""
query = """
SELECT transfer_id, file_name, file_size, file_hash, sender_id,
receiver_id, status, progress, start_time, end_time
FROM file_transfers
WHERE (sender_id = %s OR receiver_id = %s)
AND status IN ('pending', 'in_progress', 'paused')
ORDER BY start_time DESC
"""
rows = await self.db.fetch_all_dict(query, (user_id, user_id))
return [self._row_to_record(row) for row in rows]
async def delete(self, transfer_id: str) -> bool:
"""
删除传输记录
Args:
transfer_id: 传输ID
Returns:
删除是否成功
"""
query = "DELETE FROM file_transfers WHERE transfer_id = %s"
try:
rows = await self.db.execute(query, (transfer_id,))
return rows > 0
except Exception as e:
logger.error(f"Failed to delete transfer record: {e}")
return False
def _row_to_record(self, row: dict) -> FileTransferRecord:
"""将数据库行转换为FileTransferRecord对象"""
return FileTransferRecord(
transfer_id=row['transfer_id'],
file_name=row['file_name'],
file_size=row['file_size'],
file_hash=row['file_hash'],
sender_id=row['sender_id'],
receiver_id=row['receiver_id'],
status=TransferStatus(row['status']),
progress=float(row['progress']),
start_time=row['start_time'],
end_time=row['end_time'],
)
class OfflineMessageRepository:
"""离线消息数据访问层"""
def __init__(self, db: DatabaseManager):
"""
初始化离线消息仓库
Args:
db: 数据库管理器实例
"""
self.db = db
async def cache(self, user_id: str, message: Message) -> bool:
"""
缓存离线消息
Args:
user_id: 目标用户ID
message: 消息对象
Returns:
缓存是否成功
"""
query = """
INSERT INTO offline_messages (user_id, message_data)
VALUES (%s, %s)
"""
try:
message_data = message.serialize()
await self.db.execute(query, (user_id, message_data))
logger.debug(f"Offline message cached for user: {user_id}")
return True
except Exception as e:
logger.error(f"Failed to cache offline message: {e}")
return False
async def get_messages(self, user_id: str) -> List[Message]:
"""
获取用户的所有离线消息
Args:
user_id: 用户ID
Returns:
离线消息列表按创建时间升序排列
"""
query = """
SELECT id, message_data FROM offline_messages
WHERE user_id = %s ORDER BY created_at ASC
"""
rows = await self.db.fetch_all_dict(query, (user_id,))
messages = []
for row in rows:
try:
message = Message.deserialize(row['message_data'])
messages.append(message)
except Exception as e:
logger.error(f"Failed to deserialize offline message: {e}")
return messages
async def get_messages_with_ids(self, user_id: str) -> List[tuple]:
"""
获取用户的所有离线消息及其ID
Args:
user_id: 用户ID
Returns:
(消息ID, 消息对象)元组列表
"""
query = """
SELECT id, message_data FROM offline_messages
WHERE user_id = %s ORDER BY created_at ASC
"""
rows = await self.db.fetch_all_dict(query, (user_id,))
results = []
for row in rows:
try:
message = Message.deserialize(row['message_data'])
results.append((row['id'], message))
except Exception as e:
logger.error(f"Failed to deserialize offline message: {e}")
return results
async def delete(self, message_id: int) -> bool:
"""
删除单条离线消息
Args:
message_id: 离线消息记录ID
Returns:
删除是否成功
"""
query = "DELETE FROM offline_messages WHERE id = %s"
try:
rows = await self.db.execute(query, (message_id,))
return rows > 0
except Exception as e:
logger.error(f"Failed to delete offline message: {e}")
return False
async def delete_all(self, user_id: str) -> int:
"""
删除用户的所有离线消息
Args:
user_id: 用户ID
Returns:
删除的消息数量
"""
query = "DELETE FROM offline_messages WHERE user_id = %s"
try:
return await self.db.execute(query, (user_id,))
except Exception as e:
logger.error(f"Failed to delete all offline messages: {e}")
return 0
async def count(self, user_id: str) -> int:
"""
获取用户的离线消息数量
Args:
user_id: 用户ID
Returns:
离线消息数量
"""
query = "SELECT COUNT(*) as count FROM offline_messages WHERE user_id = %s"
row = await self.db.fetch_dict(query, (user_id,))
return row['count'] if row else 0
async def cleanup_old_messages(self, max_age_seconds: int) -> int:
"""
清理过期的离线消息
Args:
max_age_seconds: 最大保留时间
Returns:
删除的消息数量
"""
query = """
DELETE FROM offline_messages
WHERE created_at < DATE_SUB(NOW(), INTERVAL %s SECOND)
"""
try:
return await self.db.execute(query, (max_age_seconds,))
except Exception as e:
logger.error(f"Failed to cleanup old offline messages: {e}")
return 0

@ -0,0 +1,365 @@
# P2P Network Communication - Relay Server Tests
"""
中转服务器单元测试
测试服务器核心功能用户管理和消息转发
"""
import asyncio
import pytest
import time
import json
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime
from server.relay_server import (
RelayServer, ClientConnection, ServerError
)
from shared.models import (
Message, MessageType, UserInfo, UserStatus
)
from config import ServerConfig
class TestRelayServerInit:
"""测试服务器初始化"""
def test_init_with_default_config(self):
"""测试使用默认配置初始化"""
server = RelayServer()
assert server.host == "0.0.0.0"
assert server.port == 8888
assert server.max_connections == 1000
assert not server.is_running
assert server.connection_count == 0
def test_init_with_custom_config(self):
"""测试使用自定义配置初始化"""
config = ServerConfig(
host="127.0.0.1",
port=9999,
max_connections=500
)
server = RelayServer(config)
assert server.host == "127.0.0.1"
assert server.port == 9999
assert server.max_connections == 500
class TestUserManagement:
"""测试用户管理功能"""
@pytest.fixture
def server(self):
"""创建服务器实例"""
return RelayServer()
@pytest.fixture
def mock_connection(self):
"""创建模拟连接"""
reader = AsyncMock()
writer = AsyncMock()
writer.close = MagicMock()
writer.wait_closed = AsyncMock()
return ClientConnection(
user_id="test_user",
reader=reader,
writer=writer,
ip_address="192.168.1.100",
port=12345
)
@pytest.fixture
def user_info(self):
"""创建用户信息"""
return UserInfo(
user_id="test_user",
username="testuser",
display_name="Test User",
status=UserStatus.ONLINE
)
@pytest.mark.asyncio
async def test_register_user(self, server, mock_connection, user_info):
"""测试用户注册"""
success = await server._register_user_async(
"test_user", user_info, mock_connection
)
assert success
assert server.is_user_online("test_user")
assert server.connection_count == 1
@pytest.mark.asyncio
async def test_register_user_reconnect(self, server, mock_connection, user_info):
"""测试用户重连(替换旧连接)"""
# 第一次注册
await server._register_user_async("test_user", user_info, mock_connection)
# 创建新连接
new_reader = AsyncMock()
new_writer = AsyncMock()
new_writer.close = MagicMock()
new_writer.wait_closed = AsyncMock()
new_connection = ClientConnection(
user_id="test_user",
reader=new_reader,
writer=new_writer,
ip_address="192.168.1.101",
port=12346
)
# 重新注册
success = await server._register_user_async("test_user", user_info, new_connection)
assert success
assert server.connection_count == 1 # 仍然只有一个连接
# 旧连接应该被关闭
mock_connection.writer.close.assert_called()
@pytest.mark.asyncio
async def test_unregister_user(self, server, mock_connection, user_info):
"""测试用户注销"""
await server._register_user_async("test_user", user_info, mock_connection)
assert server.is_user_online("test_user")
await server._unregister_user_async("test_user")
assert not server.is_user_online("test_user")
assert server.connection_count == 0
@pytest.mark.asyncio
async def test_get_online_users(self, server, mock_connection, user_info):
"""测试获取在线用户列表"""
# 初始为空
assert len(server.get_online_users()) == 0
# 注册用户
await server._register_user_async("test_user", user_info, mock_connection)
online_users = server.get_online_users()
assert len(online_users) == 1
assert online_users[0].user_id == "test_user"
@pytest.mark.asyncio
async def test_get_user_info(self, server, mock_connection, user_info):
"""测试获取用户信息"""
# 用户不存在
assert server.get_user_info("test_user") is None
# 注册用户
await server._register_user_async("test_user", user_info, mock_connection)
info = server.get_user_info("test_user")
assert info is not None
assert info.user_id == "test_user"
assert info.username == "testuser"
class TestMessageRelay:
"""测试消息转发功能"""
@pytest.fixture
def server(self):
"""创建服务器实例"""
return RelayServer()
@pytest.fixture
def create_mock_connection(self):
"""创建模拟连接的工厂函数"""
def _create(user_id):
reader = AsyncMock()
writer = AsyncMock()
writer.close = MagicMock()
writer.wait_closed = AsyncMock()
writer.write = MagicMock()
writer.drain = AsyncMock()
return ClientConnection(
user_id=user_id,
reader=reader,
writer=writer,
ip_address="192.168.1.100",
port=12345
)
return _create
@pytest.mark.asyncio
async def test_relay_message_to_online_user(self, server, create_mock_connection):
"""测试转发消息给在线用户"""
# 注册发送者和接收者
sender_conn = create_mock_connection("sender")
receiver_conn = create_mock_connection("receiver")
sender_info = UserInfo(
user_id="sender", username="sender",
display_name="Sender", status=UserStatus.ONLINE
)
receiver_info = UserInfo(
user_id="receiver", username="receiver",
display_name="Receiver", status=UserStatus.ONLINE
)
await server._register_user_async("sender", sender_info, sender_conn)
await server._register_user_async("receiver", receiver_info, receiver_conn)
# 创建消息
message = Message(
msg_type=MessageType.TEXT,
sender_id="sender",
receiver_id="receiver",
timestamp=time.time(),
payload=b"Hello, receiver!"
)
# 转发消息
success = await server.relay_message(message)
assert success
# 验证消息被发送到接收者
receiver_conn.writer.write.assert_called()
@pytest.mark.asyncio
async def test_relay_message_to_offline_user(self, server, create_mock_connection):
"""测试转发消息给离线用户(缓存)"""
# 只注册发送者
sender_conn = create_mock_connection("sender")
sender_info = UserInfo(
user_id="sender", username="sender",
display_name="Sender", status=UserStatus.ONLINE
)
await server._register_user_async("sender", sender_info, sender_conn)
# 创建消息
message = Message(
msg_type=MessageType.TEXT,
sender_id="sender",
receiver_id="offline_user",
timestamp=time.time(),
payload=b"Hello, offline user!"
)
# 转发消息(应该被缓存)
success = await server.relay_message(message)
assert success
# 验证消息被缓存
assert server.get_offline_message_count("offline_user") == 1
def test_cache_offline_message(self, server):
"""测试离线消息缓存"""
message = Message(
msg_type=MessageType.TEXT,
sender_id="sender",
receiver_id="receiver",
timestamp=time.time(),
payload=b"Test message"
)
server.cache_offline_message("receiver", message)
assert server.get_offline_message_count("receiver") == 1
# 缓存多条消息
server.cache_offline_message("receiver", message)
server.cache_offline_message("receiver", message)
assert server.get_offline_message_count("receiver") == 3
def test_clear_offline_messages(self, server):
"""测试清除离线消息"""
message = Message(
msg_type=MessageType.TEXT,
sender_id="sender",
receiver_id="receiver",
timestamp=time.time(),
payload=b"Test message"
)
server.cache_offline_message("receiver", message)
assert server.get_offline_message_count("receiver") == 1
server.clear_offline_messages("receiver")
assert server.get_offline_message_count("receiver") == 0
@pytest.mark.asyncio
async def test_deliver_offline_messages(self, server, create_mock_connection):
"""测试投递离线消息"""
# 缓存离线消息
message1 = Message(
msg_type=MessageType.TEXT,
sender_id="sender",
receiver_id="receiver",
timestamp=time.time(),
payload=b"Message 1"
)
message2 = Message(
msg_type=MessageType.TEXT,
sender_id="sender",
receiver_id="receiver",
timestamp=time.time() + 1,
payload=b"Message 2"
)
server.cache_offline_message("receiver", message1)
server.cache_offline_message("receiver", message2)
assert server.get_offline_message_count("receiver") == 2
# 用户上线
receiver_conn = create_mock_connection("receiver")
receiver_info = UserInfo(
user_id="receiver", username="receiver",
display_name="Receiver", status=UserStatus.ONLINE
)
await server._register_user_async("receiver", receiver_info, receiver_conn)
# 投递离线消息
await server._deliver_offline_messages("receiver")
# 验证消息被投递
assert receiver_conn.writer.write.call_count >= 2
assert server.get_offline_message_count("receiver") == 0
class TestHelperMethods:
"""测试辅助方法"""
@pytest.fixture
def server(self):
"""创建服务器实例"""
return RelayServer()
def test_create_error_message(self, server):
"""测试创建错误消息"""
error_msg = server._create_error_message(
"server", "user1", "Test error"
)
assert error_msg.msg_type == MessageType.ERROR
assert error_msg.sender_id == "server"
assert error_msg.receiver_id == "user1"
assert error_msg.payload == b"Test error"
def test_create_ack_message(self, server):
"""测试创建确认消息"""
ack_msg = server._create_ack_message(
"server", "user1", "Test ack"
)
assert ack_msg.msg_type == MessageType.ACK
assert ack_msg.sender_id == "server"
assert ack_msg.receiver_id == "user1"
assert ack_msg.payload == b"Test ack"
class TestClientConnection:
"""测试客户端连接类"""
def test_connection_is_alive(self):
"""测试连接存活检查"""
reader = AsyncMock()
writer = AsyncMock()
conn = ClientConnection(
user_id="test",
reader=reader,
writer=writer,
last_heartbeat=time.time()
)
assert conn.is_alive
# 模拟超时
conn.last_heartbeat = time.time() - 120 # 2分钟前
assert not conn.is_alive
Loading…
Cancel
Save