parent
842aa282cc
commit
f9bd845a11
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,331 @@
|
||||
# P2P Network Communication - Connection Manager Tests
|
||||
"""
|
||||
测试客户端连接管理器的基本功能
|
||||
|
||||
需求: 1.1, 1.2, 1.3, 1.4, 1.5, 1.6
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
from client.connection_manager import (
|
||||
ConnectionManager,
|
||||
ConnectionState,
|
||||
ConnectionError,
|
||||
Connection,
|
||||
)
|
||||
from shared.models import (
|
||||
Message, MessageType, UserInfo, UserStatus,
|
||||
ConnectionMode, PeerInfo
|
||||
)
|
||||
from config import ClientConfig
|
||||
|
||||
|
||||
class TestConnectionManagerInit:
|
||||
"""测试连接管理器初始化"""
|
||||
|
||||
def test_init_with_default_config(self):
|
||||
"""测试使用默认配置初始化"""
|
||||
cm = ConnectionManager()
|
||||
|
||||
assert cm.config is not None
|
||||
assert cm.state == ConnectionState.DISCONNECTED
|
||||
assert cm.is_connected is False
|
||||
assert cm.user_id == ""
|
||||
|
||||
def test_init_with_custom_config(self):
|
||||
"""测试使用自定义配置初始化"""
|
||||
config = ClientConfig(
|
||||
server_host="192.168.1.100",
|
||||
server_port=9999,
|
||||
heartbeat_interval=60
|
||||
)
|
||||
cm = ConnectionManager(config)
|
||||
|
||||
assert cm.config.server_host == "192.168.1.100"
|
||||
assert cm.config.server_port == 9999
|
||||
assert cm.config.heartbeat_interval == 60
|
||||
|
||||
|
||||
class TestConnectionState:
|
||||
"""测试连接状态管理"""
|
||||
|
||||
def test_initial_state_is_disconnected(self):
|
||||
"""测试初始状态为断开"""
|
||||
cm = ConnectionManager()
|
||||
assert cm.state == ConnectionState.DISCONNECTED
|
||||
|
||||
def test_state_callback_registration(self):
|
||||
"""测试状态回调注册"""
|
||||
cm = ConnectionManager()
|
||||
callback_called = []
|
||||
|
||||
def state_callback(state, reason):
|
||||
callback_called.append((state, reason))
|
||||
|
||||
cm.add_state_callback(state_callback)
|
||||
cm._set_state(ConnectionState.CONNECTING, "Test")
|
||||
|
||||
assert len(callback_called) == 1
|
||||
assert callback_called[0][0] == ConnectionState.CONNECTING
|
||||
assert callback_called[0][1] == "Test"
|
||||
|
||||
def test_state_callback_removal(self):
|
||||
"""测试状态回调移除"""
|
||||
cm = ConnectionManager()
|
||||
callback_called = []
|
||||
|
||||
def state_callback(state, reason):
|
||||
callback_called.append((state, reason))
|
||||
|
||||
cm.add_state_callback(state_callback)
|
||||
cm.remove_state_callback(state_callback)
|
||||
cm._set_state(ConnectionState.CONNECTING, "Test")
|
||||
|
||||
assert len(callback_called) == 0
|
||||
|
||||
|
||||
class TestConnectionMode:
|
||||
"""测试连接模式选择 (需求 1.1, 1.2, 1.3)"""
|
||||
|
||||
def test_default_mode_is_relay(self):
|
||||
"""测试默认模式为中转"""
|
||||
cm = ConnectionManager()
|
||||
mode = cm.get_connection_mode("unknown_peer")
|
||||
|
||||
assert mode == ConnectionMode.RELAY
|
||||
|
||||
def test_mode_is_p2p_for_discovered_peer(self):
|
||||
"""测试已发现的LAN对等端使用P2P模式"""
|
||||
cm = ConnectionManager()
|
||||
|
||||
# 模拟发现对等端
|
||||
peer_info = PeerInfo(
|
||||
peer_id="peer1",
|
||||
username="test_peer",
|
||||
ip_address="192.168.1.100",
|
||||
port=8889
|
||||
)
|
||||
cm._discovered_peers["peer1"] = peer_info
|
||||
|
||||
mode = cm.get_connection_mode("peer1")
|
||||
assert mode == ConnectionMode.P2P
|
||||
|
||||
def test_mode_is_p2p_for_connected_peer(self):
|
||||
"""测试已连接的P2P对等端返回P2P模式"""
|
||||
cm = ConnectionManager()
|
||||
|
||||
# 模拟P2P连接
|
||||
conn = Connection(
|
||||
peer_id="peer1",
|
||||
reader=MagicMock(),
|
||||
writer=MagicMock(),
|
||||
mode=ConnectionMode.P2P,
|
||||
ip_address="192.168.1.100",
|
||||
port=8889
|
||||
)
|
||||
cm._peer_connections["peer1"] = conn
|
||||
|
||||
mode = cm.get_connection_mode("peer1")
|
||||
assert mode == ConnectionMode.P2P
|
||||
|
||||
|
||||
class TestReconnectMechanism:
|
||||
"""测试网络重连机制 (需求 1.6)"""
|
||||
|
||||
def test_reconnect_status_initial(self):
|
||||
"""测试初始重连状态"""
|
||||
cm = ConnectionManager()
|
||||
status = cm.get_reconnect_status()
|
||||
|
||||
assert status["is_reconnecting"] is False
|
||||
assert status["reconnect_attempts"] == 0
|
||||
assert status["auto_reconnect_enabled"] is True
|
||||
|
||||
def test_enable_disable_reconnect(self):
|
||||
"""测试启用/禁用自动重连"""
|
||||
cm = ConnectionManager()
|
||||
|
||||
cm.enable_reconnect(False)
|
||||
assert cm._should_reconnect is False
|
||||
|
||||
cm.enable_reconnect(True)
|
||||
assert cm._should_reconnect is True
|
||||
|
||||
def test_set_reconnect_config(self):
|
||||
"""测试配置重连参数"""
|
||||
cm = ConnectionManager()
|
||||
|
||||
cm.set_reconnect_config(max_attempts=5, base_delay=2.0)
|
||||
|
||||
assert cm.config.reconnect_attempts == 5
|
||||
assert cm.config.reconnect_delay == 2.0
|
||||
|
||||
|
||||
class TestMessageCallbacks:
|
||||
"""测试消息回调"""
|
||||
|
||||
def test_message_callback_registration(self):
|
||||
"""测试消息回调注册"""
|
||||
cm = ConnectionManager()
|
||||
messages_received = []
|
||||
|
||||
def msg_callback(msg):
|
||||
messages_received.append(msg)
|
||||
|
||||
cm.add_message_callback(msg_callback)
|
||||
|
||||
# 模拟接收消息
|
||||
test_msg = Message(
|
||||
msg_type=MessageType.TEXT,
|
||||
sender_id="user1",
|
||||
receiver_id="user2",
|
||||
timestamp=1234567890.0,
|
||||
payload=b"Test"
|
||||
)
|
||||
|
||||
# 直接调用回调
|
||||
for callback in cm._message_callbacks:
|
||||
callback(test_msg)
|
||||
|
||||
assert len(messages_received) == 1
|
||||
assert messages_received[0].payload == b"Test"
|
||||
|
||||
def test_message_callback_removal(self):
|
||||
"""测试消息回调移除"""
|
||||
cm = ConnectionManager()
|
||||
messages_received = []
|
||||
|
||||
def msg_callback(msg):
|
||||
messages_received.append(msg)
|
||||
|
||||
cm.add_message_callback(msg_callback)
|
||||
cm.remove_message_callback(msg_callback)
|
||||
|
||||
assert len(cm._message_callbacks) == 0
|
||||
|
||||
|
||||
class TestConnectionStats:
|
||||
"""测试连接统计"""
|
||||
|
||||
def test_connection_stats_initial(self):
|
||||
"""测试初始连接统计"""
|
||||
cm = ConnectionManager()
|
||||
stats = cm.get_connection_stats()
|
||||
|
||||
assert stats["server_connected"] is False
|
||||
assert stats["state"] == "disconnected"
|
||||
assert stats["total_peer_connections"] == 0
|
||||
assert stats["p2p_connections"] == 0
|
||||
assert stats["discovered_peers"] == 0
|
||||
|
||||
def test_connection_stats_with_peers(self):
|
||||
"""测试有对等端时的连接统计"""
|
||||
cm = ConnectionManager()
|
||||
|
||||
# 添加发现的对等端
|
||||
cm._discovered_peers["peer1"] = PeerInfo(
|
||||
peer_id="peer1",
|
||||
username="test",
|
||||
ip_address="192.168.1.100",
|
||||
port=8889
|
||||
)
|
||||
|
||||
# 添加P2P连接
|
||||
cm._peer_connections["peer2"] = Connection(
|
||||
peer_id="peer2",
|
||||
reader=MagicMock(),
|
||||
writer=MagicMock(),
|
||||
mode=ConnectionMode.P2P,
|
||||
ip_address="192.168.1.101",
|
||||
port=8889
|
||||
)
|
||||
|
||||
stats = cm.get_connection_stats()
|
||||
|
||||
assert stats["total_peer_connections"] == 1
|
||||
assert stats["p2p_connections"] == 1
|
||||
assert stats["discovered_peers"] == 1
|
||||
|
||||
|
||||
class TestPeerConnectionInfo:
|
||||
"""测试对等端连接信息"""
|
||||
|
||||
def test_get_peer_info_for_connected_peer(self):
|
||||
"""测试获取已连接对等端信息"""
|
||||
cm = ConnectionManager()
|
||||
|
||||
conn = Connection(
|
||||
peer_id="peer1",
|
||||
reader=MagicMock(),
|
||||
writer=MagicMock(),
|
||||
mode=ConnectionMode.P2P,
|
||||
ip_address="192.168.1.100",
|
||||
port=8889
|
||||
)
|
||||
cm._peer_connections["peer1"] = conn
|
||||
|
||||
info = cm.get_peer_connection_info("peer1")
|
||||
|
||||
assert info is not None
|
||||
assert info["peer_id"] == "peer1"
|
||||
assert info["mode"] == "p2p"
|
||||
assert info["ip_address"] == "192.168.1.100"
|
||||
|
||||
def test_get_peer_info_for_discovered_peer(self):
|
||||
"""测试获取已发现对等端信息"""
|
||||
cm = ConnectionManager()
|
||||
|
||||
cm._discovered_peers["peer1"] = PeerInfo(
|
||||
peer_id="peer1",
|
||||
username="test",
|
||||
ip_address="192.168.1.100",
|
||||
port=8889
|
||||
)
|
||||
|
||||
info = cm.get_peer_connection_info("peer1")
|
||||
|
||||
assert info is not None
|
||||
assert info["peer_id"] == "peer1"
|
||||
assert info["mode"] == "discovered"
|
||||
|
||||
def test_get_peer_info_for_unknown_peer(self):
|
||||
"""测试获取未知对等端信息"""
|
||||
cm = ConnectionManager()
|
||||
|
||||
info = cm.get_peer_connection_info("unknown")
|
||||
|
||||
assert info is None
|
||||
|
||||
|
||||
class TestAllConnectionModes:
|
||||
"""测试获取所有连接模式"""
|
||||
|
||||
def test_get_all_connection_modes(self):
|
||||
"""测试获取所有对等端的连接模式"""
|
||||
cm = ConnectionManager()
|
||||
|
||||
# 添加P2P连接
|
||||
cm._peer_connections["peer1"] = Connection(
|
||||
peer_id="peer1",
|
||||
reader=MagicMock(),
|
||||
writer=MagicMock(),
|
||||
mode=ConnectionMode.P2P,
|
||||
ip_address="192.168.1.100",
|
||||
port=8889
|
||||
)
|
||||
|
||||
# 添加发现的对等端
|
||||
cm._discovered_peers["peer2"] = PeerInfo(
|
||||
peer_id="peer2",
|
||||
username="test",
|
||||
ip_address="192.168.1.101",
|
||||
port=8889
|
||||
)
|
||||
|
||||
modes = cm.get_all_connection_modes()
|
||||
|
||||
assert modes["peer1"] == ConnectionMode.P2P
|
||||
assert modes["peer2"] == ConnectionMode.P2P
|
||||
@ -0,0 +1,580 @@
|
||||
# P2P Network Communication - File Transfer Module Tests
|
||||
"""
|
||||
文件传输模块测试
|
||||
测试文件分块、传输、断点续传和完整性校验功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from client.file_transfer import (
|
||||
FileTransferModule,
|
||||
FileTransferError,
|
||||
FileNotFoundError,
|
||||
FileIntegrityError,
|
||||
TransferCancelledError,
|
||||
TransferState,
|
||||
)
|
||||
from shared.models import (
|
||||
Message, MessageType, FileChunk, TransferProgress,
|
||||
TransferStatus
|
||||
)
|
||||
from config import ClientConfig
|
||||
|
||||
|
||||
class TestFileTransferModuleInit:
|
||||
"""文件传输模块初始化测试"""
|
||||
|
||||
def test_init_with_default_config(self):
|
||||
"""测试使用默认配置初始化"""
|
||||
module = FileTransferModule()
|
||||
|
||||
assert module.config is not None
|
||||
assert module.CHUNK_SIZE == 64 * 1024
|
||||
assert len(module._active_transfers) == 0
|
||||
|
||||
def test_init_with_custom_config(self):
|
||||
"""测试使用自定义配置初始化"""
|
||||
config = ClientConfig(chunk_size=32 * 1024)
|
||||
module = FileTransferModule(config=config)
|
||||
|
||||
assert module.config.chunk_size == 32 * 1024
|
||||
|
||||
def test_set_send_message_func(self):
|
||||
"""测试设置发送消息函数"""
|
||||
module = FileTransferModule()
|
||||
mock_func = AsyncMock()
|
||||
|
||||
module.set_send_message_func(mock_func)
|
||||
|
||||
assert module._send_message == mock_func
|
||||
|
||||
|
||||
class TestFileHashCalculation:
|
||||
"""文件哈希计算测试"""
|
||||
|
||||
def test_calculate_file_hash_sha256(self):
|
||||
"""测试SHA256哈希计算"""
|
||||
module = FileTransferModule()
|
||||
|
||||
# 创建临时文件
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
f.write(b"Hello, World!")
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
hash_value = module.calculate_file_hash(temp_path, "sha256")
|
||||
|
||||
# SHA256 of "Hello, World!"
|
||||
expected = "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"
|
||||
assert hash_value == expected
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_calculate_file_hash_md5(self):
|
||||
"""测试MD5哈希计算"""
|
||||
module = FileTransferModule()
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
f.write(b"Hello, World!")
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
hash_value = module.calculate_file_hash(temp_path, "md5")
|
||||
|
||||
# MD5 of "Hello, World!"
|
||||
expected = "65a8e27d8879283831b664bd8b7f0ad4"
|
||||
assert hash_value == expected
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_calculate_file_hash_nonexistent_file(self):
|
||||
"""测试计算不存在文件的哈希"""
|
||||
module = FileTransferModule()
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
module.calculate_file_hash("/nonexistent/file.txt")
|
||||
|
||||
def test_calculate_chunk_hash(self):
|
||||
"""测试数据块哈希计算"""
|
||||
module = FileTransferModule()
|
||||
|
||||
data = b"Test chunk data"
|
||||
hash_value = module.calculate_chunk_hash(data)
|
||||
|
||||
import hashlib
|
||||
expected = hashlib.md5(data).hexdigest()
|
||||
assert hash_value == expected
|
||||
|
||||
def test_verify_file_integrity_success(self):
|
||||
"""测试文件完整性验证成功"""
|
||||
module = FileTransferModule()
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
f.write(b"Test content")
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
expected_hash = module.calculate_file_hash(temp_path)
|
||||
result = module.verify_file_integrity(temp_path, expected_hash)
|
||||
|
||||
assert result is True
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_verify_file_integrity_failure(self):
|
||||
"""测试文件完整性验证失败"""
|
||||
module = FileTransferModule()
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
f.write(b"Test content")
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
result = module.verify_file_integrity(temp_path, "wrong_hash")
|
||||
|
||||
assert result is False
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
|
||||
class TestFileChunking:
|
||||
"""文件分块测试"""
|
||||
|
||||
def test_get_total_chunks_small_file(self):
|
||||
"""测试小文件的块数计算"""
|
||||
module = FileTransferModule()
|
||||
|
||||
# 小于一个块
|
||||
assert module._get_total_chunks(1000) == 1
|
||||
|
||||
# 正好一个块
|
||||
assert module._get_total_chunks(64 * 1024) == 1
|
||||
|
||||
def test_get_total_chunks_large_file(self):
|
||||
"""测试大文件的块数计算"""
|
||||
module = FileTransferModule()
|
||||
|
||||
# 两个块
|
||||
assert module._get_total_chunks(64 * 1024 + 1) == 2
|
||||
|
||||
# 多个块
|
||||
assert module._get_total_chunks(256 * 1024) == 4
|
||||
|
||||
def test_read_chunk(self):
|
||||
"""测试读取文件块"""
|
||||
module = FileTransferModule()
|
||||
|
||||
# 创建测试文件
|
||||
test_data = b"A" * 100 + b"B" * 100
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
f.write(test_data)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
chunk = module._read_chunk(temp_path, 0)
|
||||
assert chunk == test_data
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_write_chunk(self):
|
||||
"""测试写入文件块"""
|
||||
module = FileTransferModule()
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
# 写入第一个块
|
||||
module._write_chunk(temp_path, 0, b"First chunk")
|
||||
|
||||
with open(temp_path, 'rb') as f:
|
||||
content = f.read()
|
||||
|
||||
assert content == b"First chunk"
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_split_file_to_chunks(self):
|
||||
"""测试文件分块"""
|
||||
module = FileTransferModule()
|
||||
|
||||
# 创建测试文件(大于一个块)
|
||||
test_data = b"X" * (64 * 1024 + 100)
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
f.write(test_data)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
chunks = module.split_file_to_chunks(temp_path)
|
||||
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0].chunk_index == 0
|
||||
assert chunks[1].chunk_index == 1
|
||||
assert chunks[0].total_chunks == 2
|
||||
assert len(chunks[0].data) == 64 * 1024
|
||||
assert len(chunks[1].data) == 100
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_split_nonexistent_file(self):
|
||||
"""测试分块不存在的文件"""
|
||||
module = FileTransferModule()
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
module.split_file_to_chunks("/nonexistent/file.txt")
|
||||
|
||||
|
||||
|
||||
class TestTransferState:
|
||||
"""传输状态测试"""
|
||||
|
||||
def test_transfer_state_creation(self):
|
||||
"""测试传输状态创建"""
|
||||
state = TransferState(
|
||||
file_id="test-id",
|
||||
file_path="/path/to/file.txt",
|
||||
file_name="file.txt",
|
||||
file_size=1000,
|
||||
file_hash="abc123",
|
||||
total_chunks=2
|
||||
)
|
||||
|
||||
assert state.file_id == "test-id"
|
||||
assert state.file_name == "file.txt"
|
||||
assert state.status == TransferStatus.PENDING
|
||||
assert len(state.completed_chunks) == 0
|
||||
|
||||
def test_transfer_state_progress_percent(self):
|
||||
"""测试进度百分比计算"""
|
||||
state = TransferState(
|
||||
file_id="test-id",
|
||||
file_path="/path/to/file.txt",
|
||||
file_name="file.txt",
|
||||
file_size=1000,
|
||||
file_hash="abc123",
|
||||
total_chunks=4,
|
||||
completed_chunks=[0, 1]
|
||||
)
|
||||
|
||||
assert state.progress_percent == 50.0
|
||||
|
||||
def test_transfer_state_to_dict_and_from_dict(self):
|
||||
"""测试状态序列化和反序列化"""
|
||||
state = TransferState(
|
||||
file_id="test-id",
|
||||
file_path="/path/to/file.txt",
|
||||
file_name="file.txt",
|
||||
file_size=1000,
|
||||
file_hash="abc123",
|
||||
total_chunks=2,
|
||||
completed_chunks=[0],
|
||||
status=TransferStatus.IN_PROGRESS
|
||||
)
|
||||
|
||||
state_dict = state.to_dict()
|
||||
restored = TransferState.from_dict(state_dict)
|
||||
|
||||
assert restored.file_id == state.file_id
|
||||
assert restored.file_name == state.file_name
|
||||
assert restored.status == state.status
|
||||
assert restored.completed_chunks == state.completed_chunks
|
||||
|
||||
|
||||
class TestTransferManagement:
|
||||
"""传输管理测试"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清理状态"""
|
||||
# 清理可能存在的状态文件
|
||||
import shutil
|
||||
state_dir = Path("data/transfer_states")
|
||||
if state_dir.exists():
|
||||
shutil.rmtree(state_dir)
|
||||
|
||||
def test_cancel_transfer(self):
|
||||
"""测试取消传输"""
|
||||
module = FileTransferModule()
|
||||
|
||||
# 创建一个传输状态
|
||||
state = TransferState(
|
||||
file_id="test-id",
|
||||
file_path="/path/to/file.txt",
|
||||
file_name="file.txt",
|
||||
file_size=1000,
|
||||
file_hash="abc123",
|
||||
total_chunks=2
|
||||
)
|
||||
module._active_transfers["test-id"] = state
|
||||
|
||||
module.cancel_transfer("test-id")
|
||||
|
||||
assert "test-id" in module._cancelled_transfers
|
||||
assert state.status == TransferStatus.CANCELLED
|
||||
|
||||
def test_pause_transfer(self):
|
||||
"""测试暂停传输"""
|
||||
module = FileTransferModule()
|
||||
|
||||
state = TransferState(
|
||||
file_id="test-id",
|
||||
file_path="/path/to/file.txt",
|
||||
file_name="file.txt",
|
||||
file_size=1000,
|
||||
file_hash="abc123",
|
||||
total_chunks=2,
|
||||
status=TransferStatus.IN_PROGRESS
|
||||
)
|
||||
module._active_transfers["test-id"] = state
|
||||
|
||||
result = module.pause_transfer("test-id")
|
||||
|
||||
assert result is True
|
||||
assert state.status == TransferStatus.PAUSED
|
||||
|
||||
def test_pause_transfer_not_in_progress(self):
|
||||
"""测试暂停非进行中的传输"""
|
||||
module = FileTransferModule()
|
||||
|
||||
state = TransferState(
|
||||
file_id="test-id",
|
||||
file_path="/path/to/file.txt",
|
||||
file_name="file.txt",
|
||||
file_size=1000,
|
||||
file_hash="abc123",
|
||||
total_chunks=2,
|
||||
status=TransferStatus.COMPLETED
|
||||
)
|
||||
module._active_transfers["test-id"] = state
|
||||
|
||||
result = module.pause_transfer("test-id")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_get_transfer_progress(self):
|
||||
"""测试获取传输进度"""
|
||||
module = FileTransferModule()
|
||||
|
||||
state = TransferState(
|
||||
file_id="test-id",
|
||||
file_path="/path/to/file.txt",
|
||||
file_name="file.txt",
|
||||
file_size=128 * 1024, # 2 chunks
|
||||
file_hash="abc123",
|
||||
total_chunks=2,
|
||||
completed_chunks=[0]
|
||||
)
|
||||
module._active_transfers["test-id"] = state
|
||||
|
||||
progress = module.get_transfer_progress("test-id")
|
||||
|
||||
assert progress is not None
|
||||
assert progress.file_id == "test-id"
|
||||
assert progress.file_name == "file.txt"
|
||||
assert progress.total_size == 128 * 1024
|
||||
|
||||
def test_get_transfer_progress_nonexistent(self):
|
||||
"""测试获取不存在的传输进度"""
|
||||
module = FileTransferModule()
|
||||
|
||||
progress = module.get_transfer_progress("nonexistent")
|
||||
|
||||
assert progress is None
|
||||
|
||||
def test_get_all_transfers(self):
|
||||
"""测试获取所有传输"""
|
||||
module = FileTransferModule()
|
||||
|
||||
state1 = TransferState(
|
||||
file_id="id1",
|
||||
file_path="/path/1.txt",
|
||||
file_name="1.txt",
|
||||
file_size=1000,
|
||||
file_hash="hash1",
|
||||
total_chunks=1
|
||||
)
|
||||
state2 = TransferState(
|
||||
file_id="id2",
|
||||
file_path="/path/2.txt",
|
||||
file_name="2.txt",
|
||||
file_size=2000,
|
||||
file_hash="hash2",
|
||||
total_chunks=1
|
||||
)
|
||||
|
||||
module._active_transfers["id1"] = state1
|
||||
module._active_transfers["id2"] = state2
|
||||
|
||||
transfers = module.get_all_transfers()
|
||||
|
||||
assert len(transfers) == 2
|
||||
|
||||
def test_get_pending_transfers(self):
|
||||
"""测试获取待恢复的传输"""
|
||||
module = FileTransferModule()
|
||||
|
||||
state1 = TransferState(
|
||||
file_id="id1",
|
||||
file_path="/path/1.txt",
|
||||
file_name="1.txt",
|
||||
file_size=1000,
|
||||
file_hash="hash1",
|
||||
total_chunks=1,
|
||||
status=TransferStatus.PAUSED
|
||||
)
|
||||
state2 = TransferState(
|
||||
file_id="id2",
|
||||
file_path="/path/2.txt",
|
||||
file_name="2.txt",
|
||||
file_size=2000,
|
||||
file_hash="hash2",
|
||||
total_chunks=1,
|
||||
status=TransferStatus.COMPLETED
|
||||
)
|
||||
|
||||
module._active_transfers["id1"] = state1
|
||||
module._active_transfers["id2"] = state2
|
||||
|
||||
pending = module.get_pending_transfers()
|
||||
|
||||
assert len(pending) == 1
|
||||
assert pending[0].file_id == "id1"
|
||||
|
||||
|
||||
class TestMessageHandling:
|
||||
"""消息处理测试"""
|
||||
|
||||
def test_handle_file_request(self):
|
||||
"""测试处理文件请求消息"""
|
||||
module = FileTransferModule()
|
||||
|
||||
import json
|
||||
payload = json.dumps({
|
||||
"file_id": "test-file-id",
|
||||
"file_name": "test.txt",
|
||||
"file_size": 1000,
|
||||
"file_hash": "abc123",
|
||||
"total_chunks": 2
|
||||
}).encode('utf-8')
|
||||
|
||||
message = Message(
|
||||
msg_type=MessageType.FILE_REQUEST,
|
||||
sender_id="sender",
|
||||
receiver_id="receiver",
|
||||
timestamp=1234567890.0,
|
||||
payload=payload
|
||||
)
|
||||
|
||||
file_id = module.handle_file_request(message)
|
||||
|
||||
assert file_id == "test-file-id"
|
||||
assert "test-file-id" in module._active_transfers
|
||||
assert "test-file-id" in module._receive_buffers
|
||||
|
||||
def test_handle_file_chunk(self):
|
||||
"""测试处理文件块消息"""
|
||||
module = FileTransferModule()
|
||||
|
||||
# 先创建传输状态
|
||||
state = TransferState(
|
||||
file_id="test-file-id",
|
||||
file_path="",
|
||||
file_name="test.txt",
|
||||
file_size=1000,
|
||||
file_hash="abc123",
|
||||
total_chunks=2
|
||||
)
|
||||
module._active_transfers["test-file-id"] = state
|
||||
module._receive_buffers["test-file-id"] = {}
|
||||
|
||||
import json
|
||||
chunk_data = b"Test chunk data"
|
||||
payload = json.dumps({
|
||||
"file_id": "test-file-id",
|
||||
"chunk_index": 0,
|
||||
"total_chunks": 2,
|
||||
"checksum": module.calculate_chunk_hash(chunk_data),
|
||||
"data": chunk_data.hex()
|
||||
}).encode('utf-8')
|
||||
|
||||
message = Message(
|
||||
msg_type=MessageType.FILE_CHUNK,
|
||||
sender_id="sender",
|
||||
receiver_id="receiver",
|
||||
timestamp=1234567890.0,
|
||||
payload=payload
|
||||
)
|
||||
|
||||
result = module.handle_file_chunk(message)
|
||||
|
||||
assert result is True
|
||||
assert 0 in module._receive_buffers["test-file-id"]
|
||||
assert module._receive_buffers["test-file-id"][0] == chunk_data
|
||||
assert 0 in state.completed_chunks
|
||||
|
||||
def test_handle_file_chunk_checksum_mismatch(self):
|
||||
"""测试处理校验和不匹配的文件块"""
|
||||
module = FileTransferModule()
|
||||
|
||||
state = TransferState(
|
||||
file_id="test-file-id",
|
||||
file_path="",
|
||||
file_name="test.txt",
|
||||
file_size=1000,
|
||||
file_hash="abc123",
|
||||
total_chunks=2
|
||||
)
|
||||
module._active_transfers["test-file-id"] = state
|
||||
module._receive_buffers["test-file-id"] = {}
|
||||
|
||||
import json
|
||||
chunk_data = b"Test chunk data"
|
||||
payload = json.dumps({
|
||||
"file_id": "test-file-id",
|
||||
"chunk_index": 0,
|
||||
"total_chunks": 2,
|
||||
"checksum": "wrong_checksum",
|
||||
"data": chunk_data.hex()
|
||||
}).encode('utf-8')
|
||||
|
||||
message = Message(
|
||||
msg_type=MessageType.FILE_CHUNK,
|
||||
sender_id="sender",
|
||||
receiver_id="receiver",
|
||||
timestamp=1234567890.0,
|
||||
payload=payload
|
||||
)
|
||||
|
||||
result = module.handle_file_chunk(message)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCleanup:
|
||||
"""清理功能测试"""
|
||||
|
||||
def test_clear_all_transfers(self):
|
||||
"""测试清除所有传输"""
|
||||
module = FileTransferModule()
|
||||
|
||||
state = TransferState(
|
||||
file_id="test-id",
|
||||
file_path="/path/to/file.txt",
|
||||
file_name="file.txt",
|
||||
file_size=1000,
|
||||
file_hash="abc123",
|
||||
total_chunks=2
|
||||
)
|
||||
module._active_transfers["test-id"] = state
|
||||
module._receive_buffers["test-id"] = {}
|
||||
module._cancelled_transfers.add("test-id")
|
||||
|
||||
module.clear_all_transfers()
|
||||
|
||||
assert len(module._active_transfers) == 0
|
||||
assert len(module._receive_buffers) == 0
|
||||
assert len(module._cancelled_transfers) == 0
|
||||
Loading…
Reference in new issue