确保基础模块测试通过

main
杨博文 4 months ago
parent cab4f6cbde
commit d6b92b67dc

5
.gitignore vendored

@ -0,0 +1,5 @@
__pycache__/
*.pyc
.kiro/*

@ -0,0 +1,7 @@
# P2P Network Communication - Client Module
"""
客户端模块运行在Windows系统上
提供用户界面和P2P通信功能
"""
__version__ = "0.1.0"

@ -0,0 +1,106 @@
# P2P Network Communication - Configuration Module
"""
配置模块包含服务器和客户端的配置参数
"""
import os
from dataclasses import dataclass
from typing import Optional
@dataclass
class ServerConfig:
"""服务器配置"""
host: str = "0.0.0.0"
port: int = 8888
max_connections: int = 1000
heartbeat_interval: int = 30 # seconds
offline_message_ttl: int = 7 * 24 * 3600 # 7 days in seconds
# Database configuration
db_host: str = "localhost"
db_port: int = 3306
db_name: str = "p2p_chat"
db_user: str = "root"
db_password: str = ""
db_pool_size: int = 10
@dataclass
class ClientConfig:
"""客户端配置"""
server_host: str = "127.0.0.1"
server_port: int = 8888
# LAN discovery
lan_broadcast_port: int = 8889
lan_discovery_timeout: float = 2.0 # seconds
# Connection settings
connection_timeout: float = 10.0 # seconds
reconnect_attempts: int = 3
reconnect_delay: float = 1.0 # seconds (will be multiplied: 1s, 2s, 4s)
heartbeat_interval: int = 30 # seconds
# File transfer
chunk_size: int = 64 * 1024 # 64KB per chunk
max_file_size: int = 2 * 1024 * 1024 * 1024 # 2GB max
# Voice chat
audio_sample_rate: int = 16000
audio_channels: int = 1
audio_chunk_duration: float = 0.02 # 20ms per chunk
max_voice_latency: float = 0.3 # 300ms max latency
# UI settings
window_width: int = 1024
window_height: int = 768
thumbnail_size: tuple = (200, 200)
# Local storage
data_dir: str = "data"
cache_dir: str = "cache"
@dataclass
class SecurityConfig:
"""安全配置"""
use_tls: bool = True
cert_file: Optional[str] = None
key_file: Optional[str] = None
# Encryption
encryption_algorithm: str = "AES-256-GCM"
key_derivation: str = "PBKDF2"
key_iterations: int = 100000
def load_server_config() -> ServerConfig:
"""从环境变量加载服务器配置"""
return ServerConfig(
host=os.getenv("P2P_SERVER_HOST", "0.0.0.0"),
port=int(os.getenv("P2P_SERVER_PORT", "8888")),
max_connections=int(os.getenv("P2P_MAX_CONNECTIONS", "1000")),
db_host=os.getenv("P2P_DB_HOST", "localhost"),
db_port=int(os.getenv("P2P_DB_PORT", "3306")),
db_name=os.getenv("P2P_DB_NAME", "p2p_chat"),
db_user=os.getenv("P2P_DB_USER", "root"),
db_password=os.getenv("P2P_DB_PASSWORD", ""),
)
def load_client_config() -> ClientConfig:
"""从环境变量加载客户端配置"""
return ClientConfig(
server_host=os.getenv("P2P_SERVER_HOST", "127.0.0.1"),
server_port=int(os.getenv("P2P_SERVER_PORT", "8888")),
)
def load_security_config() -> SecurityConfig:
"""从环境变量加载安全配置"""
return SecurityConfig(
use_tls=os.getenv("P2P_USE_TLS", "true").lower() == "true",
cert_file=os.getenv("P2P_CERT_FILE"),
key_file=os.getenv("P2P_KEY_FILE"),
)

@ -0,0 +1,28 @@
# P2P Network Communication Application Dependencies
# Async networking
asyncio-mqtt>=0.16.0
aiomysql>=0.2.0
# GUI Framework
PyQt6>=6.5.0
# Audio/Video Processing
PyAudio>=0.2.13
opencv-python>=4.8.0
ffmpeg-python>=0.2.0
# Encryption
cryptography>=41.0.0
pycryptodome>=3.19.0
# Database
mysql-connector-python>=8.2.0
# Testing
pytest>=7.4.0
pytest-asyncio>=0.21.0
hypothesis>=6.88.0
# Utilities
python-dotenv>=1.0.0

@ -0,0 +1,7 @@
# P2P Network Communication - Server Module
"""
服务器端模块部署在OpenEuler系统上
负责用户管理消息中转和离线消息缓存
"""
__version__ = "0.1.0"

@ -0,0 +1,47 @@
# P2P Network Communication - Shared Module
"""
共享模块包含客户端和服务器共用的数据模型和枚举类型
"""
__version__ = "0.1.0"
from shared.models import (
MessageType,
UserStatus,
TransferStatus,
ConnectionMode,
NetworkQuality,
Message,
UserInfo,
ChatMessage,
FileChunk,
TransferProgress,
FileTransferRecord,
PeerInfo,
)
from shared.message_handler import (
MessageHandler,
MessageValidationError,
MessageSerializationError,
MessageRoutingError,
)
__all__ = [
"MessageType",
"UserStatus",
"TransferStatus",
"ConnectionMode",
"NetworkQuality",
"Message",
"UserInfo",
"ChatMessage",
"FileChunk",
"TransferProgress",
"FileTransferRecord",
"PeerInfo",
"MessageHandler",
"MessageValidationError",
"MessageSerializationError",
"MessageRoutingError",
]

@ -0,0 +1,272 @@
# P2P Network Communication - Message Handler
"""
消息处理器模块
负责消息的序列化反序列化校验和路由
"""
import json
import struct
import hashlib
import logging
from typing import Callable, Dict, Optional, Any
from dataclasses import dataclass
from shared.models import Message, MessageType
# 设置日志
logger = logging.getLogger(__name__)
class MessageValidationError(Exception):
"""消息验证错误"""
pass
class MessageSerializationError(Exception):
"""消息序列化错误"""
pass
class MessageRoutingError(Exception):
"""消息路由错误"""
pass
# 消息处理回调类型
MessageCallback = Callable[[Message], None]
class MessageHandler:
"""
消息处理器
负责处理和路由各类消息包括
- 消息序列化/反序列化
- 消息校验
- 消息路由到对应处理器
需求: 3.1, 3.2
"""
# 消息头格式: 4字节长度 + 4字节版本
HEADER_FORMAT = "!II" # Network byte order, 2 unsigned ints
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
PROTOCOL_VERSION = 1
def __init__(self):
"""初始化消息处理器"""
self._handlers: Dict[MessageType, MessageCallback] = {}
self._default_handler: Optional[MessageCallback] = None
def serialize(self, message: Message) -> bytes:
"""
序列化消息为字节流
格式: [4字节长度][4字节版本][JSON数据]
Args:
message: 要序列化的消息对象
Returns:
序列化后的字节流
Raises:
MessageSerializationError: 序列化失败时抛出
"""
try:
# 将消息转换为JSON字节流
json_data = json.dumps(message.to_dict(), ensure_ascii=False).encode('utf-8')
# 构建消息头
header = struct.pack(
self.HEADER_FORMAT,
len(json_data),
self.PROTOCOL_VERSION
)
return header + json_data
except (TypeError, ValueError, json.JSONEncodeError) as e:
raise MessageSerializationError(f"Failed to serialize message: {e}")
def deserialize(self, data: bytes) -> Message:
"""
反序列化字节流为消息
Args:
data: 包含消息头和消息体的字节流
Returns:
反序列化后的消息对象
Raises:
MessageSerializationError: 反序列化失败时抛出
"""
try:
if len(data) < self.HEADER_SIZE:
raise MessageSerializationError(
f"Data too short: expected at least {self.HEADER_SIZE} bytes, got {len(data)}"
)
# 解析消息头
header = data[:self.HEADER_SIZE]
payload_length, version = struct.unpack(self.HEADER_FORMAT, header)
# 版本检查
if version != self.PROTOCOL_VERSION:
raise MessageSerializationError(
f"Unsupported protocol version: {version}, expected {self.PROTOCOL_VERSION}"
)
# 检查数据长度
expected_length = self.HEADER_SIZE + payload_length
if len(data) < expected_length:
raise MessageSerializationError(
f"Incomplete message: expected {expected_length} bytes, got {len(data)}"
)
# 解析JSON数据
json_data = data[self.HEADER_SIZE:self.HEADER_SIZE + payload_length]
message_dict = json.loads(json_data.decode('utf-8'))
return Message.from_dict(message_dict)
except json.JSONDecodeError as e:
raise MessageSerializationError(f"Failed to parse JSON: {e}")
except (KeyError, ValueError) as e:
raise MessageSerializationError(f"Invalid message format: {e}")
except struct.error as e:
raise MessageSerializationError(f"Failed to unpack header: {e}")
def validate(self, message: Message) -> bool:
"""
验证消息完整性
检查项:
- 校验和是否正确
- 必要字段是否存在
- 消息类型是否有效
Args:
message: 要验证的消息对象
Returns:
验证通过返回True否则返回False
Raises:
MessageValidationError: 验证失败时抛出包含详细错误信息
"""
errors = []
# 检查必要字段
if not message.sender_id:
errors.append("sender_id is required")
if not message.receiver_id:
errors.append("receiver_id is required")
if message.timestamp <= 0:
errors.append("timestamp must be positive")
# 检查消息类型
if not isinstance(message.msg_type, MessageType):
errors.append(f"Invalid message type: {message.msg_type}")
# 验证校验和
if not message.verify_checksum():
errors.append("Checksum verification failed")
if errors:
raise MessageValidationError("; ".join(errors))
return True
def register_handler(self, msg_type: MessageType, handler: MessageCallback) -> None:
"""
注册消息处理器
Args:
msg_type: 消息类型
handler: 处理回调函数
"""
self._handlers[msg_type] = handler
logger.debug(f"Registered handler for message type: {msg_type.value}")
def unregister_handler(self, msg_type: MessageType) -> None:
"""
注销消息处理器
Args:
msg_type: 消息类型
"""
if msg_type in self._handlers:
del self._handlers[msg_type]
logger.debug(f"Unregistered handler for message type: {msg_type.value}")
def set_default_handler(self, handler: Optional[MessageCallback]) -> None:
"""
设置默认处理器用于未注册类型的消息
Args:
handler: 默认处理回调函数None表示清除
"""
self._default_handler = handler
def route(self, message: Message) -> None:
"""
路由消息到对应处理器
根据消息类型将消息分发到已注册的处理器
如果没有找到对应的处理器则使用默认处理器
如果没有默认处理器则抛出异常
Args:
message: 要路由的消息
Raises:
MessageRoutingError: 找不到处理器时抛出
MessageValidationError: 消息验证失败时抛出
"""
# 先验证消息
self.validate(message)
# 查找处理器
handler = self._handlers.get(message.msg_type)
if handler is None:
handler = self._default_handler
if handler is None:
raise MessageRoutingError(
f"No handler registered for message type: {message.msg_type.value}"
)
# 调用处理器
try:
handler(message)
logger.debug(f"Routed message {message.message_id} to handler for {message.msg_type.value}")
except Exception as e:
logger.error(f"Handler error for message {message.message_id}: {e}")
raise
def has_handler(self, msg_type: MessageType) -> bool:
"""
检查是否有指定类型的处理器
Args:
msg_type: 消息类型
Returns:
如果有处理器返回True否则返回False
"""
return msg_type in self._handlers or self._default_handler is not None
def get_registered_types(self) -> list:
"""
获取所有已注册处理器的消息类型
Returns:
消息类型列表
"""
return list(self._handlers.keys())

@ -0,0 +1,394 @@
# P2P Network Communication - Data Models
"""
数据模型和枚举类型定义
包含消息类型用户状态传输状态等核心数据结构
"""
import hashlib
import json
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from typing import Optional
class MessageType(Enum):
"""消息类型枚举"""
TEXT = "text"
FILE_REQUEST = "file_request"
FILE_CHUNK = "file_chunk"
FILE_COMPLETE = "file_complete"
IMAGE = "image"
AUDIO_STREAM = "audio_stream"
VIDEO_STREAM = "video_stream"
VOICE_CALL_REQUEST = "voice_call_request"
VOICE_CALL_ACCEPT = "voice_call_accept"
VOICE_CALL_REJECT = "voice_call_reject"
VOICE_CALL_END = "voice_call_end"
VOICE_DATA = "voice_data"
HEARTBEAT = "heartbeat"
USER_REGISTER = "user_register"
USER_UNREGISTER = "user_unregister"
USER_LIST_REQUEST = "user_list_request"
USER_LIST_RESPONSE = "user_list_response"
ACK = "ack"
ERROR = "error"
class UserStatus(Enum):
"""用户状态枚举"""
ONLINE = "online"
OFFLINE = "offline"
BUSY = "busy"
AWAY = "away"
class TransferStatus(Enum):
"""传输状态枚举"""
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
PAUSED = "paused"
class ConnectionMode(Enum):
"""连接模式枚举"""
P2P = "p2p" # 局域网直连
RELAY = "relay" # 服务器中转
UNKNOWN = "unknown" # 未知
class NetworkQuality(Enum):
"""网络质量枚举"""
EXCELLENT = "excellent" # 延迟 < 50ms
GOOD = "good" # 延迟 50-100ms
FAIR = "fair" # 延迟 100-200ms
POOR = "poor" # 延迟 200-300ms
BAD = "bad" # 延迟 > 300ms
@dataclass
class Message:
"""消息数据结构"""
msg_type: MessageType
sender_id: str
receiver_id: str
timestamp: float
payload: bytes
checksum: str = field(default="")
message_id: str = field(default="")
def __post_init__(self):
"""初始化后计算校验和和消息ID"""
if not self.checksum:
self.checksum = self._calculate_checksum()
if not self.message_id:
self.message_id = self._generate_message_id()
def _calculate_checksum(self) -> str:
"""计算消息校验和"""
data = f"{self.msg_type.value}{self.sender_id}{self.receiver_id}{self.timestamp}"
data_bytes = data.encode('utf-8') + self.payload
return hashlib.md5(data_bytes).hexdigest()
def _generate_message_id(self) -> str:
"""生成消息ID"""
data = f"{self.sender_id}{self.receiver_id}{self.timestamp}{self.checksum}"
return hashlib.sha256(data.encode('utf-8')).hexdigest()[:32]
def verify_checksum(self) -> bool:
"""验证消息校验和"""
return self.checksum == self._calculate_checksum()
def to_dict(self) -> dict:
"""转换为字典"""
return {
"msg_type": self.msg_type.value,
"sender_id": self.sender_id,
"receiver_id": self.receiver_id,
"timestamp": self.timestamp,
"payload": self.payload.hex(), # bytes转hex字符串
"checksum": self.checksum,
"message_id": self.message_id,
}
@classmethod
def from_dict(cls, data: dict) -> "Message":
"""从字典创建Message对象"""
return cls(
msg_type=MessageType(data["msg_type"]),
sender_id=data["sender_id"],
receiver_id=data["receiver_id"],
timestamp=data["timestamp"],
payload=bytes.fromhex(data["payload"]),
checksum=data.get("checksum", ""),
message_id=data.get("message_id", ""),
)
def serialize(self) -> bytes:
"""序列化消息为字节流"""
return json.dumps(self.to_dict()).encode('utf-8')
@classmethod
def deserialize(cls, data: bytes) -> "Message":
"""反序列化字节流为消息"""
return cls.from_dict(json.loads(data.decode('utf-8')))
@dataclass
class UserInfo:
"""用户信息"""
user_id: str
username: str
display_name: str
status: UserStatus = UserStatus.OFFLINE
ip_address: str = ""
port: int = 0
last_seen: Optional[datetime] = None
public_key: bytes = field(default_factory=bytes)
def to_dict(self) -> dict:
"""转换为字典"""
return {
"user_id": self.user_id,
"username": self.username,
"display_name": self.display_name,
"status": self.status.value,
"ip_address": self.ip_address,
"port": self.port,
"last_seen": self.last_seen.isoformat() if self.last_seen else None,
"public_key": self.public_key.hex() if self.public_key else "",
}
@classmethod
def from_dict(cls, data: dict) -> "UserInfo":
"""从字典创建UserInfo对象"""
last_seen = None
if data.get("last_seen"):
last_seen = datetime.fromisoformat(data["last_seen"])
public_key = bytes()
if data.get("public_key"):
public_key = bytes.fromhex(data["public_key"])
return cls(
user_id=data["user_id"],
username=data["username"],
display_name=data["display_name"],
status=UserStatus(data.get("status", "offline")),
ip_address=data.get("ip_address", ""),
port=data.get("port", 0),
last_seen=last_seen,
public_key=public_key,
)
def serialize(self) -> bytes:
"""序列化为字节流"""
return json.dumps(self.to_dict()).encode('utf-8')
@classmethod
def deserialize(cls, data: bytes) -> "UserInfo":
"""反序列化字节流"""
return cls.from_dict(json.loads(data.decode('utf-8')))
@dataclass
class ChatMessage:
"""聊天消息记录"""
message_id: str
sender_id: str
receiver_id: str
content_type: MessageType
content: str # 文本内容或文件路径
timestamp: datetime = field(default_factory=datetime.now)
is_read: bool = False
is_sent: bool = False
def to_dict(self) -> dict:
"""转换为字典"""
return {
"message_id": self.message_id,
"sender_id": self.sender_id,
"receiver_id": self.receiver_id,
"content_type": self.content_type.value,
"content": self.content,
"timestamp": self.timestamp.isoformat(),
"is_read": self.is_read,
"is_sent": self.is_sent,
}
@classmethod
def from_dict(cls, data: dict) -> "ChatMessage":
"""从字典创建ChatMessage对象"""
return cls(
message_id=data["message_id"],
sender_id=data["sender_id"],
receiver_id=data["receiver_id"],
content_type=MessageType(data["content_type"]),
content=data["content"],
timestamp=datetime.fromisoformat(data["timestamp"]),
is_read=data.get("is_read", False),
is_sent=data.get("is_sent", False),
)
def serialize(self) -> bytes:
"""序列化为字节流"""
return json.dumps(self.to_dict()).encode('utf-8')
@classmethod
def deserialize(cls, data: bytes) -> "ChatMessage":
"""反序列化字节流"""
return cls.from_dict(json.loads(data.decode('utf-8')))
@dataclass
class FileChunk:
"""文件块数据结构"""
file_id: str
chunk_index: int
total_chunks: int
data: bytes
checksum: str = field(default="")
def __post_init__(self):
"""初始化后计算校验和"""
if not self.checksum:
self.checksum = hashlib.md5(self.data).hexdigest()
def verify_checksum(self) -> bool:
"""验证数据块校验和"""
return self.checksum == hashlib.md5(self.data).hexdigest()
def to_dict(self) -> dict:
"""转换为字典"""
return {
"file_id": self.file_id,
"chunk_index": self.chunk_index,
"total_chunks": self.total_chunks,
"data": self.data.hex(),
"checksum": self.checksum,
}
@classmethod
def from_dict(cls, data: dict) -> "FileChunk":
"""从字典创建FileChunk对象"""
return cls(
file_id=data["file_id"],
chunk_index=data["chunk_index"],
total_chunks=data["total_chunks"],
data=bytes.fromhex(data["data"]),
checksum=data.get("checksum", ""),
)
@dataclass
class TransferProgress:
"""传输进度信息"""
file_id: str
file_name: str
total_size: int
transferred_size: int
speed: float = 0.0 # bytes per second
eta: float = 0.0 # estimated time remaining in seconds
@property
def progress_percent(self) -> float:
"""获取进度百分比"""
if self.total_size == 0:
return 0.0
return (self.transferred_size / self.total_size) * 100
def to_dict(self) -> dict:
"""转换为字典"""
return {
"file_id": self.file_id,
"file_name": self.file_name,
"total_size": self.total_size,
"transferred_size": self.transferred_size,
"speed": self.speed,
"eta": self.eta,
"progress_percent": self.progress_percent,
}
@dataclass
class FileTransferRecord:
"""文件传输记录"""
transfer_id: str
file_name: str
file_size: int
file_hash: str
sender_id: str
receiver_id: str
status: TransferStatus = TransferStatus.PENDING
progress: float = 0.0
start_time: datetime = field(default_factory=datetime.now)
end_time: Optional[datetime] = None
def to_dict(self) -> dict:
"""转换为字典"""
return {
"transfer_id": self.transfer_id,
"file_name": self.file_name,
"file_size": self.file_size,
"file_hash": self.file_hash,
"sender_id": self.sender_id,
"receiver_id": self.receiver_id,
"status": self.status.value,
"progress": self.progress,
"start_time": self.start_time.isoformat(),
"end_time": self.end_time.isoformat() if self.end_time else None,
}
@classmethod
def from_dict(cls, data: dict) -> "FileTransferRecord":
"""从字典创建FileTransferRecord对象"""
end_time = None
if data.get("end_time"):
end_time = datetime.fromisoformat(data["end_time"])
return cls(
transfer_id=data["transfer_id"],
file_name=data["file_name"],
file_size=data["file_size"],
file_hash=data["file_hash"],
sender_id=data["sender_id"],
receiver_id=data["receiver_id"],
status=TransferStatus(data.get("status", "pending")),
progress=data.get("progress", 0.0),
start_time=datetime.fromisoformat(data["start_time"]),
end_time=end_time,
)
@dataclass
class PeerInfo:
"""对等端信息(用于局域网发现)"""
peer_id: str
username: str
ip_address: str
port: int
discovered_at: datetime = field(default_factory=datetime.now)
def to_dict(self) -> dict:
"""转换为字典"""
return {
"peer_id": self.peer_id,
"username": self.username,
"ip_address": self.ip_address,
"port": self.port,
"discovered_at": self.discovered_at.isoformat(),
}
@classmethod
def from_dict(cls, data: dict) -> "PeerInfo":
"""从字典创建PeerInfo对象"""
return cls(
peer_id=data["peer_id"],
username=data["username"],
ip_address=data["ip_address"],
port=data["port"],
discovered_at=datetime.fromisoformat(data["discovered_at"]),
)

@ -0,0 +1 @@
# P2P Network Communication - Tests Package

@ -0,0 +1,211 @@
# P2P Network Communication - Message Handler Tests
"""
测试消息处理器的基本功能
"""
import pytest
from shared.models import Message, MessageType
from shared.message_handler import (
MessageHandler,
MessageValidationError,
MessageSerializationError,
MessageRoutingError
)
class TestMessageHandlerSerialization:
"""测试消息处理器的序列化功能"""
def setup_method(self):
"""每个测试方法前初始化"""
self.handler = MessageHandler()
def test_serialize_text_message(self):
"""测试文本消息序列化"""
msg = Message(
msg_type=MessageType.TEXT,
sender_id="user1",
receiver_id="user2",
timestamp=1234567890.0,
payload=b"Hello, World!"
)
serialized = self.handler.serialize(msg)
assert isinstance(serialized, bytes)
assert len(serialized) > MessageHandler.HEADER_SIZE
def test_deserialize_text_message(self):
"""测试文本消息反序列化"""
original = Message(
msg_type=MessageType.TEXT,
sender_id="sender",
receiver_id="receiver",
timestamp=9876543210.0,
payload=b"Test payload"
)
serialized = self.handler.serialize(original)
restored = self.handler.deserialize(serialized)
assert restored.msg_type == original.msg_type
assert restored.sender_id == original.sender_id
assert restored.receiver_id == original.receiver_id
assert restored.timestamp == original.timestamp
assert restored.payload == original.payload
def test_deserialize_short_data_raises_error(self):
"""测试数据过短时抛出错误"""
with pytest.raises(MessageSerializationError):
self.handler.deserialize(b"short")
def test_deserialize_invalid_json_raises_error(self):
"""测试无效JSON时抛出错误"""
import struct
# 创建有效头部但无效JSON的数据
invalid_json = b"not valid json"
header = struct.pack("!II", len(invalid_json), 1)
with pytest.raises(MessageSerializationError):
self.handler.deserialize(header + invalid_json)
class TestMessageHandlerValidation:
"""测试消息处理器的验证功能"""
def setup_method(self):
"""每个测试方法前初始化"""
self.handler = MessageHandler()
def test_validate_valid_message(self):
"""测试验证有效消息"""
msg = Message(
msg_type=MessageType.TEXT,
sender_id="user1",
receiver_id="user2",
timestamp=1234567890.0,
payload=b"Valid message"
)
assert self.handler.validate(msg) is True
def test_validate_empty_sender_raises_error(self):
"""测试空发送者ID时抛出错误"""
msg = Message(
msg_type=MessageType.TEXT,
sender_id="",
receiver_id="user2",
timestamp=1234567890.0,
payload=b"Test"
)
with pytest.raises(MessageValidationError):
self.handler.validate(msg)
def test_validate_empty_receiver_raises_error(self):
"""测试空接收者ID时抛出错误"""
msg = Message(
msg_type=MessageType.TEXT,
sender_id="user1",
receiver_id="",
timestamp=1234567890.0,
payload=b"Test"
)
with pytest.raises(MessageValidationError):
self.handler.validate(msg)
def test_validate_invalid_timestamp_raises_error(self):
"""测试无效时间戳时抛出错误"""
msg = Message(
msg_type=MessageType.TEXT,
sender_id="user1",
receiver_id="user2",
timestamp=-1.0,
payload=b"Test"
)
with pytest.raises(MessageValidationError):
self.handler.validate(msg)
class TestMessageHandlerRouting:
"""测试消息处理器的路由功能"""
def setup_method(self):
"""每个测试方法前初始化"""
self.handler = MessageHandler()
self.received_messages = []
def test_register_and_route_handler(self):
"""测试注册和路由处理器"""
def text_handler(msg):
self.received_messages.append(msg)
self.handler.register_handler(MessageType.TEXT, text_handler)
msg = Message(
msg_type=MessageType.TEXT,
sender_id="user1",
receiver_id="user2",
timestamp=1234567890.0,
payload=b"Routed message"
)
self.handler.route(msg)
assert len(self.received_messages) == 1
assert self.received_messages[0].payload == b"Routed message"
def test_route_without_handler_raises_error(self):
"""测试没有处理器时路由抛出错误"""
msg = Message(
msg_type=MessageType.FILE_REQUEST,
sender_id="user1",
receiver_id="user2",
timestamp=1234567890.0,
payload=b"File request"
)
with pytest.raises(MessageRoutingError):
self.handler.route(msg)
def test_default_handler(self):
"""测试默认处理器"""
def default_handler(msg):
self.received_messages.append(msg)
self.handler.set_default_handler(default_handler)
msg = Message(
msg_type=MessageType.HEARTBEAT,
sender_id="user1",
receiver_id="user2",
timestamp=1234567890.0,
payload=b""
)
self.handler.route(msg)
assert len(self.received_messages) == 1
def test_unregister_handler(self):
"""测试注销处理器"""
def handler(msg):
self.received_messages.append(msg)
self.handler.register_handler(MessageType.TEXT, handler)
self.handler.unregister_handler(MessageType.TEXT)
assert not self.handler.has_handler(MessageType.TEXT)
def test_get_registered_types(self):
"""测试获取已注册的消息类型"""
self.handler.register_handler(MessageType.TEXT, lambda m: None)
self.handler.register_handler(MessageType.FILE_REQUEST, lambda m: None)
registered = self.handler.get_registered_types()
assert MessageType.TEXT in registered
assert MessageType.FILE_REQUEST in registered
assert len(registered) == 2

@ -0,0 +1,250 @@
# P2P Network Communication - Data Models Tests
"""
测试数据模型的基本功能
"""
import pytest
from datetime import datetime
from shared.models import (
MessageType, UserStatus, TransferStatus, ConnectionMode, NetworkQuality,
Message, UserInfo, ChatMessage, FileChunk, TransferProgress, FileTransferRecord, PeerInfo
)
class TestMessageType:
"""测试消息类型枚举"""
def test_message_type_values(self):
"""测试消息类型枚举值"""
assert MessageType.TEXT.value == "text"
assert MessageType.FILE_REQUEST.value == "file_request"
assert MessageType.HEARTBEAT.value == "heartbeat"
assert MessageType.VOICE_DATA.value == "voice_data"
class TestUserStatus:
"""测试用户状态枚举"""
def test_user_status_values(self):
"""测试用户状态枚举值"""
assert UserStatus.ONLINE.value == "online"
assert UserStatus.OFFLINE.value == "offline"
assert UserStatus.BUSY.value == "busy"
assert UserStatus.AWAY.value == "away"
class TestTransferStatus:
"""测试传输状态枚举"""
def test_transfer_status_values(self):
"""测试传输状态枚举值"""
assert TransferStatus.PENDING.value == "pending"
assert TransferStatus.IN_PROGRESS.value == "in_progress"
assert TransferStatus.COMPLETED.value == "completed"
assert TransferStatus.FAILED.value == "failed"
assert TransferStatus.CANCELLED.value == "cancelled"
class TestMessage:
"""测试Message数据类"""
def test_message_creation(self):
"""测试消息创建"""
msg = Message(
msg_type=MessageType.TEXT,
sender_id="user1",
receiver_id="user2",
timestamp=1234567890.0,
payload=b"Hello, World!"
)
assert msg.msg_type == MessageType.TEXT
assert msg.sender_id == "user1"
assert msg.receiver_id == "user2"
assert msg.timestamp == 1234567890.0
assert msg.payload == b"Hello, World!"
assert msg.checksum != ""
assert msg.message_id != ""
def test_message_checksum_verification(self):
"""测试消息校验和验证"""
msg = Message(
msg_type=MessageType.TEXT,
sender_id="user1",
receiver_id="user2",
timestamp=1234567890.0,
payload=b"Test message"
)
assert msg.verify_checksum() is True
def test_message_to_dict_and_from_dict(self):
"""测试消息字典转换"""
original = Message(
msg_type=MessageType.TEXT,
sender_id="user1",
receiver_id="user2",
timestamp=1234567890.0,
payload=b"Test"
)
msg_dict = original.to_dict()
restored = Message.from_dict(msg_dict)
assert restored.msg_type == original.msg_type
assert restored.sender_id == original.sender_id
assert restored.receiver_id == original.receiver_id
assert restored.timestamp == original.timestamp
assert restored.payload == original.payload
assert restored.checksum == original.checksum
assert restored.message_id == original.message_id
def test_message_serialize_deserialize(self):
"""测试消息序列化和反序列化"""
original = Message(
msg_type=MessageType.FILE_REQUEST,
sender_id="sender",
receiver_id="receiver",
timestamp=9876543210.0,
payload=b"file_data_here"
)
serialized = original.serialize()
restored = Message.deserialize(serialized)
assert restored.msg_type == original.msg_type
assert restored.sender_id == original.sender_id
assert restored.receiver_id == original.receiver_id
assert restored.payload == original.payload
class TestUserInfo:
"""测试UserInfo数据类"""
def test_user_info_creation(self):
"""测试用户信息创建"""
user = UserInfo(
user_id="uid123",
username="testuser",
display_name="Test User"
)
assert user.user_id == "uid123"
assert user.username == "testuser"
assert user.display_name == "Test User"
assert user.status == UserStatus.OFFLINE
def test_user_info_serialize_deserialize(self):
"""测试用户信息序列化和反序列化"""
original = UserInfo(
user_id="uid456",
username="user456",
display_name="User 456",
status=UserStatus.ONLINE,
ip_address="192.168.1.100",
port=8888
)
serialized = original.serialize()
restored = UserInfo.deserialize(serialized)
assert restored.user_id == original.user_id
assert restored.username == original.username
assert restored.status == original.status
assert restored.ip_address == original.ip_address
assert restored.port == original.port
class TestChatMessage:
"""测试ChatMessage数据类"""
def test_chat_message_creation(self):
"""测试聊天消息创建"""
chat_msg = ChatMessage(
message_id="msg001",
sender_id="user1",
receiver_id="user2",
content_type=MessageType.TEXT,
content="Hello!"
)
assert chat_msg.message_id == "msg001"
assert chat_msg.content == "Hello!"
assert chat_msg.is_read is False
assert chat_msg.is_sent is False
def test_chat_message_serialize_deserialize(self):
"""测试聊天消息序列化和反序列化"""
original = ChatMessage(
message_id="msg002",
sender_id="sender",
receiver_id="receiver",
content_type=MessageType.IMAGE,
content="/path/to/image.png",
is_read=True,
is_sent=True
)
serialized = original.serialize()
restored = ChatMessage.deserialize(serialized)
assert restored.message_id == original.message_id
assert restored.content_type == original.content_type
assert restored.content == original.content
assert restored.is_read == original.is_read
assert restored.is_sent == original.is_sent
class TestFileChunk:
"""测试FileChunk数据类"""
def test_file_chunk_creation(self):
"""测试文件块创建"""
chunk = FileChunk(
file_id="file001",
chunk_index=0,
total_chunks=10,
data=b"chunk_data_here"
)
assert chunk.file_id == "file001"
assert chunk.chunk_index == 0
assert chunk.total_chunks == 10
assert chunk.checksum != ""
def test_file_chunk_checksum_verification(self):
"""测试文件块校验和验证"""
chunk = FileChunk(
file_id="file002",
chunk_index=1,
total_chunks=5,
data=b"test_chunk_data"
)
assert chunk.verify_checksum() is True
class TestTransferProgress:
"""测试TransferProgress数据类"""
def test_transfer_progress_percent(self):
"""测试传输进度百分比计算"""
progress = TransferProgress(
file_id="file001",
file_name="test.txt",
total_size=1000,
transferred_size=500
)
assert progress.progress_percent == 50.0
def test_transfer_progress_zero_total(self):
"""测试总大小为0时的进度计算"""
progress = TransferProgress(
file_id="file002",
file_name="empty.txt",
total_size=0,
transferred_size=0
)
assert progress.progress_percent == 0.0
Loading…
Cancel
Save