# P2P Network Communication - File Transfer Module Tests
"""
Tests for the file transfer module.

Covers file chunking, transfer management, resumable transfers and
integrity verification.
"""

import asyncio
import hashlib
import json
import os
import tempfile
import pytest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

# NOTE: ``FileNotFoundError`` below is the exception exported by
# ``client.file_transfer``; importing it shadows the builtin of the same
# name for the rest of this file.
from client.file_transfer import (
    FileTransferModule,
    FileTransferError,
    FileNotFoundError,
    FileIntegrityError,
    TransferCancelledError,
    TransferState,
)
from shared.models import (
    Message, MessageType, FileChunk, TransferProgress, TransferStatus
)
from config import ClientConfig


# Chunk size the module is expected to use by default
# (asserted in TestFileTransferModuleInit.test_init_with_default_config).
DEFAULT_CHUNK_SIZE = 64 * 1024


def _make_state(**overrides):
    """Build a ``TransferState`` from test defaults, overriding any field."""
    fields = {
        "file_id": "test-id",
        "file_path": "/path/to/file.txt",
        "file_name": "file.txt",
        "file_size": 1000,
        "file_hash": "abc123",
        "total_chunks": 2,
    }
    fields.update(overrides)
    return TransferState(**fields)


def _make_message(msg_type, body):
    """Build a ``Message`` whose payload is *body* JSON-encoded as UTF-8."""
    return Message(
        msg_type=msg_type,
        sender_id="sender",
        receiver_id="receiver",
        timestamp=1234567890.0,
        payload=json.dumps(body).encode("utf-8"),
    )


def _write_file(directory, data, name="data.bin"):
    """Write *data* to ``directory / name`` and return the path as a str."""
    path = directory / name
    path.write_bytes(data)
    return str(path)


class TestFileTransferModuleInit:
    """Initialization of FileTransferModule."""

    def test_init_with_default_config(self):
        """Default construction yields a config, 64 KiB chunks, no transfers."""
        module = FileTransferModule()

        assert module.config is not None
        assert module.CHUNK_SIZE == DEFAULT_CHUNK_SIZE
        assert len(module._active_transfers) == 0

    def test_init_with_custom_config(self):
        """A caller-supplied config is stored on the module."""
        config = ClientConfig(chunk_size=32 * 1024)
        module = FileTransferModule(config=config)

        assert module.config.chunk_size == 32 * 1024

    def test_set_send_message_func(self):
        """The send-message callback is stored for later use."""
        module = FileTransferModule()
        mock_func = AsyncMock()

        module.set_send_message_func(mock_func)

        assert module._send_message == mock_func


class TestFileHashCalculation:
    """File and chunk hash computation."""

    def test_calculate_file_hash_sha256(self, tmp_path):
        """SHA-256 of a known file matches the reference digest."""
        module = FileTransferModule()
        path = _write_file(tmp_path, b"Hello, World!")

        # SHA-256 of b"Hello, World!"
        expected = "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"
        assert module.calculate_file_hash(path, "sha256") == expected

    def test_calculate_file_hash_md5(self, tmp_path):
        """MD5 of a known file matches the reference digest."""
        module = FileTransferModule()
        path = _write_file(tmp_path, b"Hello, World!")

        # MD5 of b"Hello, World!"
        expected = "65a8e27d8879283831b664bd8b7f0ad4"
        assert module.calculate_file_hash(path, "md5") == expected

    def test_calculate_file_hash_nonexistent_file(self):
        """Hashing a missing file raises the module's FileNotFoundError."""
        module = FileTransferModule()

        with pytest.raises(FileNotFoundError):
            module.calculate_file_hash("/nonexistent/file.txt")

    def test_calculate_chunk_hash(self):
        """Chunk hashes are hex MD5 digests of the raw chunk bytes."""
        module = FileTransferModule()
        data = b"Test chunk data"

        assert module.calculate_chunk_hash(data) == hashlib.md5(data).hexdigest()

    def test_verify_file_integrity_success(self, tmp_path):
        """Integrity check passes when the expected hash is the file's hash."""
        module = FileTransferModule()
        path = _write_file(tmp_path, b"Test content")

        expected_hash = module.calculate_file_hash(path)

        assert module.verify_file_integrity(path, expected_hash) is True

    def test_verify_file_integrity_failure(self, tmp_path):
        """Integrity check fails for a non-matching hash."""
        module = FileTransferModule()
        path = _write_file(tmp_path, b"Test content")

        assert module.verify_file_integrity(path, "wrong_hash") is False


class TestFileChunking:
    """Splitting files into chunks and chunk-level I/O."""

    def test_get_total_chunks_small_file(self):
        """Files up to one chunk in size need exactly one chunk."""
        module = FileTransferModule()

        # Smaller than one chunk.
        assert module._get_total_chunks(1000) == 1
        # Exactly one chunk.
        assert module._get_total_chunks(DEFAULT_CHUNK_SIZE) == 1

    def test_get_total_chunks_large_file(self):
        """Chunk count rounds up for sizes past a chunk boundary."""
        module = FileTransferModule()

        # One byte over a chunk -> two chunks.
        assert module._get_total_chunks(DEFAULT_CHUNK_SIZE + 1) == 2
        # Exact multiple -> no extra chunk.
        assert module._get_total_chunks(4 * DEFAULT_CHUNK_SIZE) == 4

    def test_read_chunk(self, tmp_path):
        """Reading chunk 0 of a sub-chunk-sized file returns all its bytes."""
        module = FileTransferModule()
        test_data = b"A" * 100 + b"B" * 100
        path = _write_file(tmp_path, test_data)

        assert module._read_chunk(path, 0) == test_data

    def test_write_chunk(self, tmp_path):
        """Writing chunk 0 into an empty file stores exactly those bytes."""
        module = FileTransferModule()
        path = _write_file(tmp_path, b"")

        module._write_chunk(path, 0, b"First chunk")

        assert Path(path).read_bytes() == b"First chunk"

    def test_split_file_to_chunks(self, tmp_path):
        """A file slightly larger than one chunk splits into full + tail."""
        module = FileTransferModule()
        path = _write_file(tmp_path, b"X" * (DEFAULT_CHUNK_SIZE + 100))

        chunks = module.split_file_to_chunks(path)

        assert len(chunks) == 2
        assert chunks[0].chunk_index == 0
        assert chunks[1].chunk_index == 1
        assert chunks[0].total_chunks == 2
        assert len(chunks[0].data) == DEFAULT_CHUNK_SIZE
        assert len(chunks[1].data) == 100

    def test_split_nonexistent_file(self):
        """Splitting a missing file raises the module's FileNotFoundError."""
        module = FileTransferModule()

        with pytest.raises(FileNotFoundError):
            module.split_file_to_chunks("/nonexistent/file.txt")


class TestTransferState:
    """TransferState construction, progress and (de)serialization."""

    def test_transfer_state_creation(self):
        """A new state is PENDING with no completed chunks."""
        state = _make_state()

        assert state.file_id == "test-id"
        assert state.file_name == "file.txt"
        assert state.status == TransferStatus.PENDING
        assert len(state.completed_chunks) == 0

    def test_transfer_state_progress_percent(self):
        """Two of four chunks completed is 50% progress."""
        state = _make_state(total_chunks=4, completed_chunks=[0, 1])

        assert state.progress_percent == 50.0

    def test_transfer_state_to_dict_and_from_dict(self):
        """to_dict/from_dict round-trips the key fields."""
        state = _make_state(
            completed_chunks=[0],
            status=TransferStatus.IN_PROGRESS,
        )

        restored = TransferState.from_dict(state.to_dict())

        assert restored.file_id == state.file_id
        assert restored.file_name == state.file_name
        assert restored.status == state.status
        assert restored.completed_chunks == state.completed_chunks


class TestTransferManagement:
    """Cancelling, pausing and querying transfers."""

    @pytest.fixture(autouse=True)
    def _isolated_workdir(self, tmp_path, monkeypatch):
        """Run each test from an empty temporary working directory.

        Persisted transfer states are assumed to live under the CWD-relative
        ``data/transfer_states`` directory (TODO confirm against
        FileTransferModule). A fresh working directory guarantees no stale
        states are picked up, without deleting a real directory in the
        checkout the tests happen to be launched from.
        """
        monkeypatch.chdir(tmp_path)

    def test_cancel_transfer(self):
        """Cancelling marks the id as cancelled and the state CANCELLED."""
        module = FileTransferModule()
        state = _make_state()
        module._active_transfers["test-id"] = state

        module.cancel_transfer("test-id")

        assert "test-id" in module._cancelled_transfers
        assert state.status == TransferStatus.CANCELLED

    def test_pause_transfer(self):
        """An in-progress transfer can be paused."""
        module = FileTransferModule()
        state = _make_state(status=TransferStatus.IN_PROGRESS)
        module._active_transfers["test-id"] = state

        assert module.pause_transfer("test-id") is True
        assert state.status == TransferStatus.PAUSED

    def test_pause_transfer_not_in_progress(self):
        """Pausing a transfer that is not in progress is refused."""
        module = FileTransferModule()
        state = _make_state(status=TransferStatus.COMPLETED)
        module._active_transfers["test-id"] = state

        assert module.pause_transfer("test-id") is False

    def test_get_transfer_progress(self):
        """Progress for a known transfer reflects its state."""
        module = FileTransferModule()
        state = _make_state(
            file_size=2 * DEFAULT_CHUNK_SIZE,  # two full chunks
            completed_chunks=[0],
        )
        module._active_transfers["test-id"] = state

        progress = module.get_transfer_progress("test-id")

        assert progress is not None
        assert progress.file_id == "test-id"
        assert progress.file_name == "file.txt"
        assert progress.total_size == 2 * DEFAULT_CHUNK_SIZE

    def test_get_transfer_progress_nonexistent(self):
        """Progress for an unknown transfer id is None."""
        module = FileTransferModule()

        assert module.get_transfer_progress("nonexistent") is None

    def test_get_all_transfers(self):
        """All registered transfers are returned."""
        module = FileTransferModule()
        module._active_transfers["id1"] = _make_state(
            file_id="id1", file_path="/path/1.txt", file_name="1.txt",
            file_size=1000, file_hash="hash1", total_chunks=1,
        )
        module._active_transfers["id2"] = _make_state(
            file_id="id2", file_path="/path/2.txt", file_name="2.txt",
            file_size=2000, file_hash="hash2", total_chunks=1,
        )

        assert len(module.get_all_transfers()) == 2

    def test_get_pending_transfers(self):
        """Only resumable (paused) transfers are reported as pending."""
        module = FileTransferModule()
        module._active_transfers["id1"] = _make_state(
            file_id="id1", file_path="/path/1.txt", file_name="1.txt",
            file_size=1000, file_hash="hash1", total_chunks=1,
            status=TransferStatus.PAUSED,
        )
        module._active_transfers["id2"] = _make_state(
            file_id="id2", file_path="/path/2.txt", file_name="2.txt",
            file_size=2000, file_hash="hash2", total_chunks=1,
            status=TransferStatus.COMPLETED,
        )

        pending = module.get_pending_transfers()

        assert len(pending) == 1
        assert pending[0].file_id == "id1"


class TestMessageHandling:
    """Handling of incoming FILE_REQUEST / FILE_CHUNK messages."""

    def test_handle_file_request(self):
        """A file request registers a transfer and a receive buffer."""
        module = FileTransferModule()
        message = _make_message(MessageType.FILE_REQUEST, {
            "file_id": "test-file-id",
            "file_name": "test.txt",
            "file_size": 1000,
            "file_hash": "abc123",
            "total_chunks": 2,
        })

        file_id = module.handle_file_request(message)

        assert file_id == "test-file-id"
        assert "test-file-id" in module._active_transfers
        assert "test-file-id" in module._receive_buffers

    def test_handle_file_chunk(self):
        """A chunk with a valid checksum is buffered and marked complete."""
        module = FileTransferModule()
        state = _make_state(
            file_id="test-file-id", file_path="", file_name="test.txt",
        )
        module._active_transfers["test-file-id"] = state
        module._receive_buffers["test-file-id"] = {}

        chunk_data = b"Test chunk data"
        message = _make_message(MessageType.FILE_CHUNK, {
            "file_id": "test-file-id",
            "chunk_index": 0,
            "total_chunks": 2,
            "checksum": module.calculate_chunk_hash(chunk_data),
            "data": chunk_data.hex(),
        })

        assert module.handle_file_chunk(message) is True
        assert 0 in module._receive_buffers["test-file-id"]
        assert module._receive_buffers["test-file-id"][0] == chunk_data
        assert 0 in state.completed_chunks

    def test_handle_file_chunk_checksum_mismatch(self):
        """A chunk whose checksum does not match is rejected."""
        module = FileTransferModule()
        module._active_transfers["test-file-id"] = _make_state(
            file_id="test-file-id", file_path="", file_name="test.txt",
        )
        module._receive_buffers["test-file-id"] = {}

        chunk_data = b"Test chunk data"
        message = _make_message(MessageType.FILE_CHUNK, {
            "file_id": "test-file-id",
            "chunk_index": 0,
            "total_chunks": 2,
            "checksum": "wrong_checksum",
            "data": chunk_data.hex(),
        })

        assert module.handle_file_chunk(message) is False


class TestCleanup:
    """Bulk cleanup of transfer bookkeeping."""

    def test_clear_all_transfers(self):
        """clear_all_transfers empties transfers, buffers and cancellations."""
        module = FileTransferModule()
        module._active_transfers["test-id"] = _make_state()
        module._receive_buffers["test-id"] = {}
        module._cancelled_transfers.add("test-id")

        module.clear_all_transfers()

        assert len(module._active_transfers) == 0
        assert len(module._receive_buffers) == 0
        assert len(module._cancelled_transfers) == 0