parent
cab4f6cbde
commit
d6b92b67dc
@ -0,0 +1,5 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
||||
.kiro/*
|
||||
|
||||
@ -0,0 +1,28 @@
|
||||
# P2P Network Communication Application Dependencies
|
||||
|
||||
# Async networking
|
||||
asyncio-mqtt>=0.16.0
|
||||
aiomysql>=0.2.0
|
||||
|
||||
# GUI Framework
|
||||
PyQt6>=6.5.0
|
||||
|
||||
# Audio/Video Processing
|
||||
PyAudio>=0.2.13
|
||||
opencv-python>=4.8.0
|
||||
ffmpeg-python>=0.2.0
|
||||
|
||||
# Encryption
|
||||
cryptography>=41.0.0
|
||||
pycryptodome>=3.19.0
|
||||
|
||||
# Database
|
||||
mysql-connector-python>=8.2.0
|
||||
|
||||
# Testing
|
||||
pytest>=7.4.0
|
||||
pytest-asyncio>=0.21.0
|
||||
hypothesis>=6.88.0
|
||||
|
||||
# Utilities
|
||||
python-dotenv>=1.0.0
|
||||
@ -0,0 +1,394 @@
|
||||
# P2P Network Communication - Data Models
|
||||
"""
|
||||
数据模型和枚举类型定义
|
||||
包含消息类型、用户状态、传输状态等核心数据结构
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""消息类型枚举"""
|
||||
TEXT = "text"
|
||||
FILE_REQUEST = "file_request"
|
||||
FILE_CHUNK = "file_chunk"
|
||||
FILE_COMPLETE = "file_complete"
|
||||
IMAGE = "image"
|
||||
AUDIO_STREAM = "audio_stream"
|
||||
VIDEO_STREAM = "video_stream"
|
||||
VOICE_CALL_REQUEST = "voice_call_request"
|
||||
VOICE_CALL_ACCEPT = "voice_call_accept"
|
||||
VOICE_CALL_REJECT = "voice_call_reject"
|
||||
VOICE_CALL_END = "voice_call_end"
|
||||
VOICE_DATA = "voice_data"
|
||||
HEARTBEAT = "heartbeat"
|
||||
USER_REGISTER = "user_register"
|
||||
USER_UNREGISTER = "user_unregister"
|
||||
USER_LIST_REQUEST = "user_list_request"
|
||||
USER_LIST_RESPONSE = "user_list_response"
|
||||
ACK = "ack"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class UserStatus(Enum):
|
||||
"""用户状态枚举"""
|
||||
ONLINE = "online"
|
||||
OFFLINE = "offline"
|
||||
BUSY = "busy"
|
||||
AWAY = "away"
|
||||
|
||||
|
||||
class TransferStatus(Enum):
|
||||
"""传输状态枚举"""
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
PAUSED = "paused"
|
||||
|
||||
|
||||
class ConnectionMode(Enum):
|
||||
"""连接模式枚举"""
|
||||
P2P = "p2p" # 局域网直连
|
||||
RELAY = "relay" # 服务器中转
|
||||
UNKNOWN = "unknown" # 未知
|
||||
|
||||
|
||||
class NetworkQuality(Enum):
|
||||
"""网络质量枚举"""
|
||||
EXCELLENT = "excellent" # 延迟 < 50ms
|
||||
GOOD = "good" # 延迟 50-100ms
|
||||
FAIR = "fair" # 延迟 100-200ms
|
||||
POOR = "poor" # 延迟 200-300ms
|
||||
BAD = "bad" # 延迟 > 300ms
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""消息数据结构"""
|
||||
msg_type: MessageType
|
||||
sender_id: str
|
||||
receiver_id: str
|
||||
timestamp: float
|
||||
payload: bytes
|
||||
checksum: str = field(default="")
|
||||
message_id: str = field(default="")
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化后计算校验和和消息ID"""
|
||||
if not self.checksum:
|
||||
self.checksum = self._calculate_checksum()
|
||||
if not self.message_id:
|
||||
self.message_id = self._generate_message_id()
|
||||
|
||||
def _calculate_checksum(self) -> str:
|
||||
"""计算消息校验和"""
|
||||
data = f"{self.msg_type.value}{self.sender_id}{self.receiver_id}{self.timestamp}"
|
||||
data_bytes = data.encode('utf-8') + self.payload
|
||||
return hashlib.md5(data_bytes).hexdigest()
|
||||
|
||||
def _generate_message_id(self) -> str:
|
||||
"""生成消息ID"""
|
||||
data = f"{self.sender_id}{self.receiver_id}{self.timestamp}{self.checksum}"
|
||||
return hashlib.sha256(data.encode('utf-8')).hexdigest()[:32]
|
||||
|
||||
def verify_checksum(self) -> bool:
|
||||
"""验证消息校验和"""
|
||||
return self.checksum == self._calculate_checksum()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"msg_type": self.msg_type.value,
|
||||
"sender_id": self.sender_id,
|
||||
"receiver_id": self.receiver_id,
|
||||
"timestamp": self.timestamp,
|
||||
"payload": self.payload.hex(), # bytes转hex字符串
|
||||
"checksum": self.checksum,
|
||||
"message_id": self.message_id,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "Message":
|
||||
"""从字典创建Message对象"""
|
||||
return cls(
|
||||
msg_type=MessageType(data["msg_type"]),
|
||||
sender_id=data["sender_id"],
|
||||
receiver_id=data["receiver_id"],
|
||||
timestamp=data["timestamp"],
|
||||
payload=bytes.fromhex(data["payload"]),
|
||||
checksum=data.get("checksum", ""),
|
||||
message_id=data.get("message_id", ""),
|
||||
)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""序列化消息为字节流"""
|
||||
return json.dumps(self.to_dict()).encode('utf-8')
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, data: bytes) -> "Message":
|
||||
"""反序列化字节流为消息"""
|
||||
return cls.from_dict(json.loads(data.decode('utf-8')))
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInfo:
|
||||
"""用户信息"""
|
||||
user_id: str
|
||||
username: str
|
||||
display_name: str
|
||||
status: UserStatus = UserStatus.OFFLINE
|
||||
ip_address: str = ""
|
||||
port: int = 0
|
||||
last_seen: Optional[datetime] = None
|
||||
public_key: bytes = field(default_factory=bytes)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"username": self.username,
|
||||
"display_name": self.display_name,
|
||||
"status": self.status.value,
|
||||
"ip_address": self.ip_address,
|
||||
"port": self.port,
|
||||
"last_seen": self.last_seen.isoformat() if self.last_seen else None,
|
||||
"public_key": self.public_key.hex() if self.public_key else "",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "UserInfo":
|
||||
"""从字典创建UserInfo对象"""
|
||||
last_seen = None
|
||||
if data.get("last_seen"):
|
||||
last_seen = datetime.fromisoformat(data["last_seen"])
|
||||
|
||||
public_key = bytes()
|
||||
if data.get("public_key"):
|
||||
public_key = bytes.fromhex(data["public_key"])
|
||||
|
||||
return cls(
|
||||
user_id=data["user_id"],
|
||||
username=data["username"],
|
||||
display_name=data["display_name"],
|
||||
status=UserStatus(data.get("status", "offline")),
|
||||
ip_address=data.get("ip_address", ""),
|
||||
port=data.get("port", 0),
|
||||
last_seen=last_seen,
|
||||
public_key=public_key,
|
||||
)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""序列化为字节流"""
|
||||
return json.dumps(self.to_dict()).encode('utf-8')
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, data: bytes) -> "UserInfo":
|
||||
"""反序列化字节流"""
|
||||
return cls.from_dict(json.loads(data.decode('utf-8')))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""聊天消息记录"""
|
||||
message_id: str
|
||||
sender_id: str
|
||||
receiver_id: str
|
||||
content_type: MessageType
|
||||
content: str # 文本内容或文件路径
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
is_read: bool = False
|
||||
is_sent: bool = False
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"message_id": self.message_id,
|
||||
"sender_id": self.sender_id,
|
||||
"receiver_id": self.receiver_id,
|
||||
"content_type": self.content_type.value,
|
||||
"content": self.content,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"is_read": self.is_read,
|
||||
"is_sent": self.is_sent,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "ChatMessage":
|
||||
"""从字典创建ChatMessage对象"""
|
||||
return cls(
|
||||
message_id=data["message_id"],
|
||||
sender_id=data["sender_id"],
|
||||
receiver_id=data["receiver_id"],
|
||||
content_type=MessageType(data["content_type"]),
|
||||
content=data["content"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"]),
|
||||
is_read=data.get("is_read", False),
|
||||
is_sent=data.get("is_sent", False),
|
||||
)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""序列化为字节流"""
|
||||
return json.dumps(self.to_dict()).encode('utf-8')
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, data: bytes) -> "ChatMessage":
|
||||
"""反序列化字节流"""
|
||||
return cls.from_dict(json.loads(data.decode('utf-8')))
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileChunk:
|
||||
"""文件块数据结构"""
|
||||
file_id: str
|
||||
chunk_index: int
|
||||
total_chunks: int
|
||||
data: bytes
|
||||
checksum: str = field(default="")
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化后计算校验和"""
|
||||
if not self.checksum:
|
||||
self.checksum = hashlib.md5(self.data).hexdigest()
|
||||
|
||||
def verify_checksum(self) -> bool:
|
||||
"""验证数据块校验和"""
|
||||
return self.checksum == hashlib.md5(self.data).hexdigest()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"file_id": self.file_id,
|
||||
"chunk_index": self.chunk_index,
|
||||
"total_chunks": self.total_chunks,
|
||||
"data": self.data.hex(),
|
||||
"checksum": self.checksum,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "FileChunk":
|
||||
"""从字典创建FileChunk对象"""
|
||||
return cls(
|
||||
file_id=data["file_id"],
|
||||
chunk_index=data["chunk_index"],
|
||||
total_chunks=data["total_chunks"],
|
||||
data=bytes.fromhex(data["data"]),
|
||||
checksum=data.get("checksum", ""),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransferProgress:
|
||||
"""传输进度信息"""
|
||||
file_id: str
|
||||
file_name: str
|
||||
total_size: int
|
||||
transferred_size: int
|
||||
speed: float = 0.0 # bytes per second
|
||||
eta: float = 0.0 # estimated time remaining in seconds
|
||||
|
||||
@property
|
||||
def progress_percent(self) -> float:
|
||||
"""获取进度百分比"""
|
||||
if self.total_size == 0:
|
||||
return 0.0
|
||||
return (self.transferred_size / self.total_size) * 100
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"file_id": self.file_id,
|
||||
"file_name": self.file_name,
|
||||
"total_size": self.total_size,
|
||||
"transferred_size": self.transferred_size,
|
||||
"speed": self.speed,
|
||||
"eta": self.eta,
|
||||
"progress_percent": self.progress_percent,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileTransferRecord:
|
||||
"""文件传输记录"""
|
||||
transfer_id: str
|
||||
file_name: str
|
||||
file_size: int
|
||||
file_hash: str
|
||||
sender_id: str
|
||||
receiver_id: str
|
||||
status: TransferStatus = TransferStatus.PENDING
|
||||
progress: float = 0.0
|
||||
start_time: datetime = field(default_factory=datetime.now)
|
||||
end_time: Optional[datetime] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"transfer_id": self.transfer_id,
|
||||
"file_name": self.file_name,
|
||||
"file_size": self.file_size,
|
||||
"file_hash": self.file_hash,
|
||||
"sender_id": self.sender_id,
|
||||
"receiver_id": self.receiver_id,
|
||||
"status": self.status.value,
|
||||
"progress": self.progress,
|
||||
"start_time": self.start_time.isoformat(),
|
||||
"end_time": self.end_time.isoformat() if self.end_time else None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "FileTransferRecord":
|
||||
"""从字典创建FileTransferRecord对象"""
|
||||
end_time = None
|
||||
if data.get("end_time"):
|
||||
end_time = datetime.fromisoformat(data["end_time"])
|
||||
|
||||
return cls(
|
||||
transfer_id=data["transfer_id"],
|
||||
file_name=data["file_name"],
|
||||
file_size=data["file_size"],
|
||||
file_hash=data["file_hash"],
|
||||
sender_id=data["sender_id"],
|
||||
receiver_id=data["receiver_id"],
|
||||
status=TransferStatus(data.get("status", "pending")),
|
||||
progress=data.get("progress", 0.0),
|
||||
start_time=datetime.fromisoformat(data["start_time"]),
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PeerInfo:
|
||||
"""对等端信息(用于局域网发现)"""
|
||||
peer_id: str
|
||||
username: str
|
||||
ip_address: str
|
||||
port: int
|
||||
discovered_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"peer_id": self.peer_id,
|
||||
"username": self.username,
|
||||
"ip_address": self.ip_address,
|
||||
"port": self.port,
|
||||
"discovered_at": self.discovered_at.isoformat(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "PeerInfo":
|
||||
"""从字典创建PeerInfo对象"""
|
||||
return cls(
|
||||
peer_id=data["peer_id"],
|
||||
username=data["username"],
|
||||
ip_address=data["ip_address"],
|
||||
port=data["port"],
|
||||
discovered_at=datetime.fromisoformat(data["discovered_at"]),
|
||||
)
|
||||
@ -0,0 +1 @@
|
||||
# P2P Network Communication - Tests Package
|
||||
@ -0,0 +1,211 @@
|
||||
# P2P Network Communication - Message Handler Tests
|
||||
"""
|
||||
测试消息处理器的基本功能
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from shared.models import Message, MessageType
|
||||
from shared.message_handler import (
|
||||
MessageHandler,
|
||||
MessageValidationError,
|
||||
MessageSerializationError,
|
||||
MessageRoutingError
|
||||
)
|
||||
|
||||
|
||||
class TestMessageHandlerSerialization:
|
||||
"""测试消息处理器的序列化功能"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试方法前初始化"""
|
||||
self.handler = MessageHandler()
|
||||
|
||||
def test_serialize_text_message(self):
|
||||
"""测试文本消息序列化"""
|
||||
msg = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="user1",
|
||||
receiver_id="user2",
|
||||
timestamp=1234567890.0,
|
||||
payload=b"Hello, World!"
|
||||
)
|
||||
|
||||
serialized = self.handler.serialize(msg)
|
||||
|
||||
assert isinstance(serialized, bytes)
|
||||
assert len(serialized) > MessageHandler.HEADER_SIZE
|
||||
|
||||
def test_deserialize_text_message(self):
|
||||
"""测试文本消息反序列化"""
|
||||
original = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="sender",
|
||||
receiver_id="receiver",
|
||||
timestamp=9876543210.0,
|
||||
payload=b"Test payload"
|
||||
)
|
||||
|
||||
serialized = self.handler.serialize(original)
|
||||
restored = self.handler.deserialize(serialized)
|
||||
|
||||
assert restored.msg_type == original.msg_type
|
||||
assert restored.sender_id == original.sender_id
|
||||
assert restored.receiver_id == original.receiver_id
|
||||
assert restored.timestamp == original.timestamp
|
||||
assert restored.payload == original.payload
|
||||
|
||||
def test_deserialize_short_data_raises_error(self):
|
||||
"""测试数据过短时抛出错误"""
|
||||
with pytest.raises(MessageSerializationError):
|
||||
self.handler.deserialize(b"short")
|
||||
|
||||
def test_deserialize_invalid_json_raises_error(self):
|
||||
"""测试无效JSON时抛出错误"""
|
||||
import struct
|
||||
# 创建有效头部但无效JSON的数据
|
||||
invalid_json = b"not valid json"
|
||||
header = struct.pack("!II", len(invalid_json), 1)
|
||||
|
||||
with pytest.raises(MessageSerializationError):
|
||||
self.handler.deserialize(header + invalid_json)
|
||||
|
||||
|
||||
class TestMessageHandlerValidation:
|
||||
"""测试消息处理器的验证功能"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试方法前初始化"""
|
||||
self.handler = MessageHandler()
|
||||
|
||||
def test_validate_valid_message(self):
|
||||
"""测试验证有效消息"""
|
||||
msg = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="user1",
|
||||
receiver_id="user2",
|
||||
timestamp=1234567890.0,
|
||||
payload=b"Valid message"
|
||||
)
|
||||
|
||||
assert self.handler.validate(msg) is True
|
||||
|
||||
def test_validate_empty_sender_raises_error(self):
|
||||
"""测试空发送者ID时抛出错误"""
|
||||
msg = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="",
|
||||
receiver_id="user2",
|
||||
timestamp=1234567890.0,
|
||||
payload=b"Test"
|
||||
)
|
||||
|
||||
with pytest.raises(MessageValidationError):
|
||||
self.handler.validate(msg)
|
||||
|
||||
def test_validate_empty_receiver_raises_error(self):
|
||||
"""测试空接收者ID时抛出错误"""
|
||||
msg = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="user1",
|
||||
receiver_id="",
|
||||
timestamp=1234567890.0,
|
||||
payload=b"Test"
|
||||
)
|
||||
|
||||
with pytest.raises(MessageValidationError):
|
||||
self.handler.validate(msg)
|
||||
|
||||
def test_validate_invalid_timestamp_raises_error(self):
|
||||
"""测试无效时间戳时抛出错误"""
|
||||
msg = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="user1",
|
||||
receiver_id="user2",
|
||||
timestamp=-1.0,
|
||||
payload=b"Test"
|
||||
)
|
||||
|
||||
with pytest.raises(MessageValidationError):
|
||||
self.handler.validate(msg)
|
||||
|
||||
|
||||
class TestMessageHandlerRouting:
|
||||
"""测试消息处理器的路由功能"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试方法前初始化"""
|
||||
self.handler = MessageHandler()
|
||||
self.received_messages = []
|
||||
|
||||
def test_register_and_route_handler(self):
|
||||
"""测试注册和路由处理器"""
|
||||
def text_handler(msg):
|
||||
self.received_messages.append(msg)
|
||||
|
||||
self.handler.register_handler(MessageType.TEXT, text_handler)
|
||||
|
||||
msg = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="user1",
|
||||
receiver_id="user2",
|
||||
timestamp=1234567890.0,
|
||||
payload=b"Routed message"
|
||||
)
|
||||
|
||||
self.handler.route(msg)
|
||||
|
||||
assert len(self.received_messages) == 1
|
||||
assert self.received_messages[0].payload == b"Routed message"
|
||||
|
||||
def test_route_without_handler_raises_error(self):
|
||||
"""测试没有处理器时路由抛出错误"""
|
||||
msg = Message(
|
||||
msg_type=MessageType.FILE_REQUEST,
|
||||
sender_id="user1",
|
||||
receiver_id="user2",
|
||||
timestamp=1234567890.0,
|
||||
payload=b"File request"
|
||||
)
|
||||
|
||||
with pytest.raises(MessageRoutingError):
|
||||
self.handler.route(msg)
|
||||
|
||||
def test_default_handler(self):
|
||||
"""测试默认处理器"""
|
||||
def default_handler(msg):
|
||||
self.received_messages.append(msg)
|
||||
|
||||
self.handler.set_default_handler(default_handler)
|
||||
|
||||
msg = Message(
|
||||
msg_type=MessageType.HEARTBEAT,
|
||||
sender_id="user1",
|
||||
receiver_id="user2",
|
||||
timestamp=1234567890.0,
|
||||
payload=b""
|
||||
)
|
||||
|
||||
self.handler.route(msg)
|
||||
|
||||
assert len(self.received_messages) == 1
|
||||
|
||||
def test_unregister_handler(self):
|
||||
"""测试注销处理器"""
|
||||
def handler(msg):
|
||||
self.received_messages.append(msg)
|
||||
|
||||
self.handler.register_handler(MessageType.TEXT, handler)
|
||||
self.handler.unregister_handler(MessageType.TEXT)
|
||||
|
||||
assert not self.handler.has_handler(MessageType.TEXT)
|
||||
|
||||
def test_get_registered_types(self):
|
||||
"""测试获取已注册的消息类型"""
|
||||
self.handler.register_handler(MessageType.TEXT, lambda m: None)
|
||||
self.handler.register_handler(MessageType.FILE_REQUEST, lambda m: None)
|
||||
|
||||
registered = self.handler.get_registered_types()
|
||||
|
||||
assert MessageType.TEXT in registered
|
||||
assert MessageType.FILE_REQUEST in registered
|
||||
assert len(registered) == 2
|
||||
@ -0,0 +1,250 @@
|
||||
# P2P Network Communication - Data Models Tests
|
||||
"""
|
||||
测试数据模型的基本功能
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from shared.models import (
|
||||
MessageType, UserStatus, TransferStatus, ConnectionMode, NetworkQuality,
|
||||
Message, UserInfo, ChatMessage, FileChunk, TransferProgress, FileTransferRecord, PeerInfo
|
||||
)
|
||||
|
||||
|
||||
class TestMessageType:
|
||||
"""测试消息类型枚举"""
|
||||
|
||||
def test_message_type_values(self):
|
||||
"""测试消息类型枚举值"""
|
||||
assert MessageType.TEXT.value == "text"
|
||||
assert MessageType.FILE_REQUEST.value == "file_request"
|
||||
assert MessageType.HEARTBEAT.value == "heartbeat"
|
||||
assert MessageType.VOICE_DATA.value == "voice_data"
|
||||
|
||||
|
||||
class TestUserStatus:
|
||||
"""测试用户状态枚举"""
|
||||
|
||||
def test_user_status_values(self):
|
||||
"""测试用户状态枚举值"""
|
||||
assert UserStatus.ONLINE.value == "online"
|
||||
assert UserStatus.OFFLINE.value == "offline"
|
||||
assert UserStatus.BUSY.value == "busy"
|
||||
assert UserStatus.AWAY.value == "away"
|
||||
|
||||
|
||||
class TestTransferStatus:
|
||||
"""测试传输状态枚举"""
|
||||
|
||||
def test_transfer_status_values(self):
|
||||
"""测试传输状态枚举值"""
|
||||
assert TransferStatus.PENDING.value == "pending"
|
||||
assert TransferStatus.IN_PROGRESS.value == "in_progress"
|
||||
assert TransferStatus.COMPLETED.value == "completed"
|
||||
assert TransferStatus.FAILED.value == "failed"
|
||||
assert TransferStatus.CANCELLED.value == "cancelled"
|
||||
|
||||
|
||||
class TestMessage:
|
||||
"""测试Message数据类"""
|
||||
|
||||
def test_message_creation(self):
|
||||
"""测试消息创建"""
|
||||
msg = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="user1",
|
||||
receiver_id="user2",
|
||||
timestamp=1234567890.0,
|
||||
payload=b"Hello, World!"
|
||||
)
|
||||
|
||||
assert msg.msg_type == MessageType.TEXT
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.receiver_id == "user2"
|
||||
assert msg.timestamp == 1234567890.0
|
||||
assert msg.payload == b"Hello, World!"
|
||||
assert msg.checksum != ""
|
||||
assert msg.message_id != ""
|
||||
|
||||
def test_message_checksum_verification(self):
|
||||
"""测试消息校验和验证"""
|
||||
msg = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="user1",
|
||||
receiver_id="user2",
|
||||
timestamp=1234567890.0,
|
||||
payload=b"Test message"
|
||||
)
|
||||
|
||||
assert msg.verify_checksum() is True
|
||||
|
||||
def test_message_to_dict_and_from_dict(self):
|
||||
"""测试消息字典转换"""
|
||||
original = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="user1",
|
||||
receiver_id="user2",
|
||||
timestamp=1234567890.0,
|
||||
payload=b"Test"
|
||||
)
|
||||
|
||||
msg_dict = original.to_dict()
|
||||
restored = Message.from_dict(msg_dict)
|
||||
|
||||
assert restored.msg_type == original.msg_type
|
||||
assert restored.sender_id == original.sender_id
|
||||
assert restored.receiver_id == original.receiver_id
|
||||
assert restored.timestamp == original.timestamp
|
||||
assert restored.payload == original.payload
|
||||
assert restored.checksum == original.checksum
|
||||
assert restored.message_id == original.message_id
|
||||
|
||||
def test_message_serialize_deserialize(self):
|
||||
"""测试消息序列化和反序列化"""
|
||||
original = Message(
|
||||
msg_type=MessageType.FILE_REQUEST,
|
||||
sender_id="sender",
|
||||
receiver_id="receiver",
|
||||
timestamp=9876543210.0,
|
||||
payload=b"file_data_here"
|
||||
)
|
||||
|
||||
serialized = original.serialize()
|
||||
restored = Message.deserialize(serialized)
|
||||
|
||||
assert restored.msg_type == original.msg_type
|
||||
assert restored.sender_id == original.sender_id
|
||||
assert restored.receiver_id == original.receiver_id
|
||||
assert restored.payload == original.payload
|
||||
|
||||
|
||||
class TestUserInfo:
|
||||
"""测试UserInfo数据类"""
|
||||
|
||||
def test_user_info_creation(self):
|
||||
"""测试用户信息创建"""
|
||||
user = UserInfo(
|
||||
user_id="uid123",
|
||||
username="testuser",
|
||||
display_name="Test User"
|
||||
)
|
||||
|
||||
assert user.user_id == "uid123"
|
||||
assert user.username == "testuser"
|
||||
assert user.display_name == "Test User"
|
||||
assert user.status == UserStatus.OFFLINE
|
||||
|
||||
def test_user_info_serialize_deserialize(self):
|
||||
"""测试用户信息序列化和反序列化"""
|
||||
original = UserInfo(
|
||||
user_id="uid456",
|
||||
username="user456",
|
||||
display_name="User 456",
|
||||
status=UserStatus.ONLINE,
|
||||
ip_address="192.168.1.100",
|
||||
port=8888
|
||||
)
|
||||
|
||||
serialized = original.serialize()
|
||||
restored = UserInfo.deserialize(serialized)
|
||||
|
||||
assert restored.user_id == original.user_id
|
||||
assert restored.username == original.username
|
||||
assert restored.status == original.status
|
||||
assert restored.ip_address == original.ip_address
|
||||
assert restored.port == original.port
|
||||
|
||||
|
||||
class TestChatMessage:
|
||||
"""测试ChatMessage数据类"""
|
||||
|
||||
def test_chat_message_creation(self):
|
||||
"""测试聊天消息创建"""
|
||||
chat_msg = ChatMessage(
|
||||
message_id="msg001",
|
||||
sender_id="user1",
|
||||
receiver_id="user2",
|
||||
content_type=MessageType.TEXT,
|
||||
content="Hello!"
|
||||
)
|
||||
|
||||
assert chat_msg.message_id == "msg001"
|
||||
assert chat_msg.content == "Hello!"
|
||||
assert chat_msg.is_read is False
|
||||
assert chat_msg.is_sent is False
|
||||
|
||||
def test_chat_message_serialize_deserialize(self):
|
||||
"""测试聊天消息序列化和反序列化"""
|
||||
original = ChatMessage(
|
||||
message_id="msg002",
|
||||
sender_id="sender",
|
||||
receiver_id="receiver",
|
||||
content_type=MessageType.IMAGE,
|
||||
content="/path/to/image.png",
|
||||
is_read=True,
|
||||
is_sent=True
|
||||
)
|
||||
|
||||
serialized = original.serialize()
|
||||
restored = ChatMessage.deserialize(serialized)
|
||||
|
||||
assert restored.message_id == original.message_id
|
||||
assert restored.content_type == original.content_type
|
||||
assert restored.content == original.content
|
||||
assert restored.is_read == original.is_read
|
||||
assert restored.is_sent == original.is_sent
|
||||
|
||||
|
||||
class TestFileChunk:
|
||||
"""测试FileChunk数据类"""
|
||||
|
||||
def test_file_chunk_creation(self):
|
||||
"""测试文件块创建"""
|
||||
chunk = FileChunk(
|
||||
file_id="file001",
|
||||
chunk_index=0,
|
||||
total_chunks=10,
|
||||
data=b"chunk_data_here"
|
||||
)
|
||||
|
||||
assert chunk.file_id == "file001"
|
||||
assert chunk.chunk_index == 0
|
||||
assert chunk.total_chunks == 10
|
||||
assert chunk.checksum != ""
|
||||
|
||||
def test_file_chunk_checksum_verification(self):
|
||||
"""测试文件块校验和验证"""
|
||||
chunk = FileChunk(
|
||||
file_id="file002",
|
||||
chunk_index=1,
|
||||
total_chunks=5,
|
||||
data=b"test_chunk_data"
|
||||
)
|
||||
|
||||
assert chunk.verify_checksum() is True
|
||||
|
||||
|
||||
class TestTransferProgress:
|
||||
"""测试TransferProgress数据类"""
|
||||
|
||||
def test_transfer_progress_percent(self):
|
||||
"""测试传输进度百分比计算"""
|
||||
progress = TransferProgress(
|
||||
file_id="file001",
|
||||
file_name="test.txt",
|
||||
total_size=1000,
|
||||
transferred_size=500
|
||||
)
|
||||
|
||||
assert progress.progress_percent == 50.0
|
||||
|
||||
def test_transfer_progress_zero_total(self):
|
||||
"""测试总大小为0时的进度计算"""
|
||||
progress = TransferProgress(
|
||||
file_id="file002",
|
||||
file_name="empty.txt",
|
||||
total_size=0,
|
||||
transferred_size=0
|
||||
)
|
||||
|
||||
assert progress.progress_percent == 0.0
|
||||
Loading…
Reference in new issue