You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

581 lines
17 KiB

# P2P Network Communication - File Transfer Module Tests
"""
文件传输模块测试
测试文件分块、传输、断点续传和完整性校验功能
"""
import asyncio
import os
import tempfile
import pytest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from client.file_transfer import (
FileTransferModule,
FileTransferError,
FileNotFoundError,
FileIntegrityError,
TransferCancelledError,
TransferState,
)
from shared.models import (
Message, MessageType, FileChunk, TransferProgress,
TransferStatus
)
from config import ClientConfig
class TestFileTransferModuleInit:
    """Tests covering construction and basic wiring of FileTransferModule."""

    def test_init_with_default_config(self):
        """A module built without arguments has a config, 64 KiB chunks and no transfers."""
        ftm = FileTransferModule()

        assert ftm.config is not None
        assert ftm.CHUNK_SIZE == 64 * 1024
        assert len(ftm._active_transfers) == 0

    def test_init_with_custom_config(self):
        """A caller-supplied config is kept, including its chunk size."""
        custom_chunk = 32 * 1024
        ftm = FileTransferModule(config=ClientConfig(chunk_size=custom_chunk))

        assert ftm.config.chunk_size == custom_chunk

    def test_set_send_message_func(self):
        """The send callback handed to the setter is stored on the module."""
        ftm = FileTransferModule()
        sender = AsyncMock()

        ftm.set_send_message_func(sender)

        assert ftm._send_message == sender
class TestFileHashCalculation:
    """Tests for whole-file hashing, chunk hashing and integrity verification."""

    @staticmethod
    def _write_file(directory, data):
        """Create ``directory/sample.bin`` containing *data* and return its path as ``str``.

        Files live under pytest's per-test ``tmp_path`` directory. pytest
        removes them automatically even when a test fails, so no
        ``delete=False`` / ``try``/``finally`` bookkeeping is needed.
        """
        path = directory / "sample.bin"
        path.write_bytes(data)
        # Hand the module a plain string path, as the earlier
        # NamedTemporaryFile(...).name-based tests did.
        return str(path)

    def test_calculate_file_hash_sha256(self, tmp_path):
        """SHA-256 of a known payload matches its well-known digest."""
        module = FileTransferModule()
        path = self._write_file(tmp_path, b"Hello, World!")

        hash_value = module.calculate_file_hash(path, "sha256")

        # SHA-256 of b"Hello, World!"
        expected = "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"
        assert hash_value == expected

    def test_calculate_file_hash_md5(self, tmp_path):
        """MD5 of a known payload matches its well-known digest."""
        module = FileTransferModule()
        path = self._write_file(tmp_path, b"Hello, World!")

        hash_value = module.calculate_file_hash(path, "md5")

        # MD5 of b"Hello, World!"
        expected = "65a8e27d8879283831b664bd8b7f0ad4"
        assert hash_value == expected

    def test_calculate_file_hash_nonexistent_file(self):
        """Hashing a missing path raises FileNotFoundError (as imported from client.file_transfer)."""
        module = FileTransferModule()
        with pytest.raises(FileNotFoundError):
            module.calculate_file_hash("/nonexistent/file.txt")

    def test_calculate_chunk_hash(self):
        """Chunk hashing agrees with hashlib's MD5 hex digest of the same bytes."""
        import hashlib

        module = FileTransferModule()
        data = b"Test chunk data"

        assert module.calculate_chunk_hash(data) == hashlib.md5(data).hexdigest()

    def test_verify_file_integrity_success(self, tmp_path):
        """A file verifies against the hash the module itself computed for it."""
        module = FileTransferModule()
        path = self._write_file(tmp_path, b"Test content")

        expected_hash = module.calculate_file_hash(path)

        assert module.verify_file_integrity(path, expected_hash) is True

    def test_verify_file_integrity_failure(self, tmp_path):
        """A file does not verify against an unrelated hash string."""
        module = FileTransferModule()
        path = self._write_file(tmp_path, b"Test content")

        assert module.verify_file_integrity(path, "wrong_hash") is False
class TestFileChunking:
    """Tests for chunk-count arithmetic, chunk I/O and splitting a file into chunks."""

    # Chunk size the default FileTransferModule is expected to use (64 KiB).
    CHUNK = 64 * 1024

    def test_get_total_chunks_small_file(self):
        """Sizes up to and including one chunk need exactly one chunk."""
        module = FileTransferModule()
        # Less than one chunk.
        assert module._get_total_chunks(1000) == 1
        # Exactly one chunk.
        assert module._get_total_chunks(self.CHUNK) == 1

    def test_get_total_chunks_large_file(self):
        """Larger sizes round up to the next whole chunk."""
        module = FileTransferModule()
        # One byte past a chunk boundary spills into a second chunk.
        assert module._get_total_chunks(self.CHUNK + 1) == 2
        # An exact multiple needs no extra chunk.
        assert module._get_total_chunks(4 * self.CHUNK) == 4

    def test_read_chunk(self, tmp_path):
        """Reading chunk 0 of a file smaller than one chunk returns the whole file."""
        module = FileTransferModule()
        test_data = b"A" * 100 + b"B" * 100
        path = tmp_path / "source.bin"
        path.write_bytes(test_data)

        assert module._read_chunk(str(path), 0) == test_data

    def test_write_chunk(self, tmp_path):
        """Writing chunk 0 into an existing empty file leaves exactly that data."""
        module = FileTransferModule()
        path = tmp_path / "target.bin"
        # Pre-create the empty file, as NamedTemporaryFile did before; the
        # writer may require the destination to exist.
        path.write_bytes(b"")

        module._write_chunk(str(path), 0, b"First chunk")

        assert path.read_bytes() == b"First chunk"

    def test_split_file_to_chunks(self, tmp_path):
        """A file one chunk plus 100 bytes long splits into a full chunk and a 100-byte tail."""
        module = FileTransferModule()
        size = self.CHUNK + 100
        # Non-uniform content, so misordered or misaligned chunks are caught.
        # A run of identical bytes would reassemble "correctly" regardless.
        test_data = bytes(i % 251 for i in range(size))
        path = tmp_path / "source.bin"
        path.write_bytes(test_data)

        chunks = module.split_file_to_chunks(str(path))

        assert len(chunks) == 2
        assert [c.chunk_index for c in chunks] == [0, 1]
        assert all(c.total_chunks == 2 for c in chunks)
        assert len(chunks[0].data) == self.CHUNK
        assert len(chunks[1].data) == 100
        # Concatenating the chunks in index order must reproduce the file.
        assert b"".join(c.data for c in chunks) == test_data

    def test_split_nonexistent_file(self):
        """Splitting a missing path raises FileNotFoundError (as imported from client.file_transfer)."""
        module = FileTransferModule()
        with pytest.raises(FileNotFoundError):
            module.split_file_to_chunks("/nonexistent/file.txt")
class TestTransferState:
    """Tests for the TransferState record: defaults, progress and round-tripping."""

    @staticmethod
    def _make_state(**overrides):
        """Build a TransferState from fixed baseline fields, replacing any given in *overrides*."""
        fields = {
            "file_id": "test-id",
            "file_path": "/path/to/file.txt",
            "file_name": "file.txt",
            "file_size": 1000,
            "file_hash": "abc123",
            "total_chunks": 2,
        }
        fields.update(overrides)
        return TransferState(**fields)

    def test_transfer_state_creation(self):
        """A fresh state keeps its identifiers, starts PENDING and has no finished chunks."""
        record = self._make_state()

        assert record.file_id == "test-id"
        assert record.file_name == "file.txt"
        assert record.status == TransferStatus.PENDING
        assert len(record.completed_chunks) == 0

    def test_transfer_state_progress_percent(self):
        """Two of four chunks done reports 50 percent progress."""
        record = self._make_state(total_chunks=4, completed_chunks=[0, 1])

        assert record.progress_percent == 50.0

    def test_transfer_state_to_dict_and_from_dict(self):
        """to_dict followed by from_dict preserves id, name, status and completed chunks."""
        original = self._make_state(
            completed_chunks=[0],
            status=TransferStatus.IN_PROGRESS,
        )

        round_tripped = TransferState.from_dict(original.to_dict())

        for attr in ("file_id", "file_name", "status", "completed_chunks"):
            assert getattr(round_tripped, attr) == getattr(original, attr)
class TestTransferManagement:
    """Tests for cancelling, pausing and querying tracked transfers."""

    @pytest.fixture(autouse=True)
    def _isolated_workdir(self, tmp_path, monkeypatch):
        """Run each test from an empty, per-test temporary working directory.

        The previous ``setup_method`` called ``shutil.rmtree`` on
        ``data/transfer_states`` relative to the current working directory.
        That deletes real persisted transfer state whenever the suite is run
        from a live checkout. Running from ``tmp_path`` gives every test an
        empty relative ``data/transfer_states`` without touching the real one,
        and pytest discards it afterwards.

        NOTE(review): this assumes FileTransferModule resolves its state
        directory relative to the working directory at call time, as the old
        cleanup path implied. Confirm against client.file_transfer.
        """
        monkeypatch.chdir(tmp_path)

    @staticmethod
    def _make_state(**overrides):
        """Build a 1000-byte, two-chunk TransferState for "test-id", applying *overrides*."""
        fields = dict(
            file_id="test-id",
            file_path="/path/to/file.txt",
            file_name="file.txt",
            file_size=1000,
            file_hash="abc123",
            total_chunks=2,
        )
        fields.update(overrides)
        return TransferState(**fields)

    @staticmethod
    def _make_numbered_state(n, file_size, **overrides):
        """Build a single-chunk TransferState "id<n>" for "/path/<n>.txt", applying *overrides*."""
        fields = dict(
            file_id=f"id{n}",
            file_path=f"/path/{n}.txt",
            file_name=f"{n}.txt",
            file_size=file_size,
            file_hash=f"hash{n}",
            total_chunks=1,
        )
        fields.update(overrides)
        return TransferState(**fields)

    @staticmethod
    def _register(module, *states):
        """Insert *states* into the module's active-transfer table, keyed by file_id."""
        for state in states:
            module._active_transfers[state.file_id] = state

    def test_cancel_transfer(self):
        """Cancelling records the id as cancelled and marks the state CANCELLED."""
        module = FileTransferModule()
        state = self._make_state()
        self._register(module, state)

        module.cancel_transfer("test-id")

        assert "test-id" in module._cancelled_transfers
        assert state.status == TransferStatus.CANCELLED

    def test_pause_transfer(self):
        """Pausing an in-progress transfer succeeds and marks it PAUSED."""
        module = FileTransferModule()
        state = self._make_state(status=TransferStatus.IN_PROGRESS)
        self._register(module, state)

        assert module.pause_transfer("test-id") is True
        assert state.status == TransferStatus.PAUSED

    def test_pause_transfer_not_in_progress(self):
        """Pausing a transfer that is not in progress (here COMPLETED) is refused."""
        module = FileTransferModule()
        self._register(module, self._make_state(status=TransferStatus.COMPLETED))

        assert module.pause_transfer("test-id") is False

    def test_get_transfer_progress(self):
        """Progress for a tracked transfer reports its id, name and total size."""
        module = FileTransferModule()
        # 128 KiB is two 64 KiB chunks; the first one is already done.
        self._register(
            module,
            self._make_state(file_size=128 * 1024, completed_chunks=[0]),
        )

        progress = module.get_transfer_progress("test-id")

        assert progress is not None
        assert progress.file_id == "test-id"
        assert progress.file_name == "file.txt"
        assert progress.total_size == 128 * 1024

    def test_get_transfer_progress_nonexistent(self):
        """Progress for an unknown id is None."""
        module = FileTransferModule()

        assert module.get_transfer_progress("nonexistent") is None

    def test_get_all_transfers(self):
        """Every registered transfer is returned, regardless of status."""
        module = FileTransferModule()
        self._register(
            module,
            self._make_numbered_state(1, 1000),
            self._make_numbered_state(2, 2000),
        )

        assert len(module.get_all_transfers()) == 2

    def test_get_pending_transfers(self):
        """Only resumable transfers are reported: the PAUSED one, not the COMPLETED one."""
        module = FileTransferModule()
        self._register(
            module,
            self._make_numbered_state(1, 1000, status=TransferStatus.PAUSED),
            self._make_numbered_state(2, 2000, status=TransferStatus.COMPLETED),
        )

        pending = module.get_pending_transfers()

        assert len(pending) == 1
        assert pending[0].file_id == "id1"
class TestMessageHandling:
    """Tests for processing inbound FILE_REQUEST and FILE_CHUNK messages."""

    FILE_ID = "test-file-id"

    @staticmethod
    def _build_message(msg_type, body):
        """Serialise *body* as UTF-8 JSON and wrap it in a "sender" -> "receiver" Message."""
        import json

        encoded = json.dumps(body).encode('utf-8')
        return Message(
            msg_type=msg_type,
            sender_id="sender",
            receiver_id="receiver",
            timestamp=1234567890.0,
            payload=encoded,
        )

    def _prepare_receiver(self, module):
        """Register a two-chunk transfer for FILE_ID with an empty receive buffer; return its state."""
        state = TransferState(
            file_id=self.FILE_ID,
            file_path="",
            file_name="test.txt",
            file_size=1000,
            file_hash="abc123",
            total_chunks=2,
        )
        module._active_transfers[self.FILE_ID] = state
        module._receive_buffers[self.FILE_ID] = {}
        return state

    def _chunk_message(self, data, checksum):
        """Build a FILE_CHUNK message carrying *data* (hex-encoded) as chunk 0 of 2."""
        return self._build_message(MessageType.FILE_CHUNK, {
            "file_id": self.FILE_ID,
            "chunk_index": 0,
            "total_chunks": 2,
            "checksum": checksum,
            "data": data.hex(),
        })

    def test_handle_file_request(self):
        """A FILE_REQUEST registers both an active transfer and a receive buffer."""
        module = FileTransferModule()
        request = self._build_message(MessageType.FILE_REQUEST, {
            "file_id": self.FILE_ID,
            "file_name": "test.txt",
            "file_size": 1000,
            "file_hash": "abc123",
            "total_chunks": 2,
        })

        returned_id = module.handle_file_request(request)

        assert returned_id == self.FILE_ID
        assert self.FILE_ID in module._active_transfers
        assert self.FILE_ID in module._receive_buffers

    def test_handle_file_chunk(self):
        """A chunk with a valid checksum is buffered and marked completed."""
        module = FileTransferModule()
        state = self._prepare_receiver(module)
        payload_bytes = b"Test chunk data"
        chunk_msg = self._chunk_message(
            payload_bytes, module.calculate_chunk_hash(payload_bytes)
        )

        assert module.handle_file_chunk(chunk_msg) is True

        buffered = module._receive_buffers[self.FILE_ID]
        assert 0 in buffered
        assert buffered[0] == payload_bytes
        assert 0 in state.completed_chunks

    def test_handle_file_chunk_checksum_mismatch(self):
        """A chunk whose checksum does not match its data is rejected."""
        module = FileTransferModule()
        self._prepare_receiver(module)
        chunk_msg = self._chunk_message(b"Test chunk data", "wrong_checksum")

        assert module.handle_file_chunk(chunk_msg) is False
class TestCleanup:
    """Tests for wiping all transfer bookkeeping in a single call."""

    def test_clear_all_transfers(self):
        """clear_all_transfers empties active transfers, receive buffers and cancellations."""
        module = FileTransferModule()
        transfer_id = "test-id"
        module._active_transfers[transfer_id] = TransferState(
            file_id=transfer_id,
            file_path="/path/to/file.txt",
            file_name="file.txt",
            file_size=1000,
            file_hash="abc123",
            total_chunks=2,
        )
        module._receive_buffers[transfer_id] = {}
        module._cancelled_transfers.add(transfer_id)

        module.clear_all_transfers()

        for container in (
            module._active_transfers,
            module._receive_buffers,
            module._cancelled_transfers,
        ):
            assert len(container) == 0