You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1290 lines
43 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# P2P Network Communication - File Transfer Module
"""
文件传输模块
负责文件的分块传输、断点续传和完整性校验
需求: 4.2, 4.4, 4.5
"""
import asyncio
import hashlib
import json
import logging
import os
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Callable, Dict, List, Optional, Any
from shared.models import (
Message, MessageType, FileChunk, TransferProgress,
FileTransferRecord, TransferStatus
)
from shared.message_handler import MessageHandler
from config import ClientConfig
# Module-level logger
logger = logging.getLogger(__name__)
class FileTransferError(Exception):
    """Base class for all errors raised by this file-transfer module."""
    pass
class FileNotFoundError(FileTransferError):
    """Raised when a file to be transferred does not exist.

    NOTE(review): this shadows the builtin ``FileNotFoundError`` inside this
    module (it subclasses FileTransferError, not OSError), so callers outside
    this module catching the builtin will NOT catch this exception —
    consider renaming (e.g. TransferFileNotFoundError).
    """
    pass
class FileIntegrityError(FileTransferError):
    """Raised when a received file fails its hash verification."""
    pass
class TransferCancelledError(FileTransferError):
    """Raised internally to abort a send loop after cancel_transfer()."""
    pass
# Signature of transfer-progress callbacks: invoked with a TransferProgress
# snapshot each time progress is updated.
ProgressCallback = Callable[[TransferProgress], None]
@dataclass
class TransferState:
    """Per-transfer state, persisted to disk to support resume (req 4.5)."""
    file_id: str
    file_path: str
    file_name: str
    file_size: int
    file_hash: str
    total_chunks: int
    completed_chunks: List[int] = field(default_factory=list)
    status: TransferStatus = TransferStatus.PENDING
    sender_id: str = ""
    receiver_id: str = ""
    start_time: datetime = field(default_factory=datetime.now)
    last_update: datetime = field(default_factory=datetime.now)
    save_path: str = ""  # receiver-side destination path

    def to_dict(self) -> dict:
        """Serialize to a JSON-compatible dict for persistence."""
        return {
            "file_id": self.file_id,
            "file_path": self.file_path,
            "file_name": self.file_name,
            "file_size": self.file_size,
            "file_hash": self.file_hash,
            "total_chunks": self.total_chunks,
            "completed_chunks": self.completed_chunks,
            "status": self.status.value,
            "sender_id": self.sender_id,
            "receiver_id": self.receiver_id,
            "start_time": self.start_time.isoformat(),
            "last_update": self.last_update.isoformat(),
            "save_path": self.save_path,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "TransferState":
        """Rebuild a TransferState from a dict produced by to_dict()."""
        return cls(
            file_id=data["file_id"],
            file_path=data["file_path"],
            file_name=data["file_name"],
            file_size=data["file_size"],
            file_hash=data["file_hash"],
            total_chunks=data["total_chunks"],
            completed_chunks=data.get("completed_chunks", []),
            status=TransferStatus(data.get("status", "pending")),
            sender_id=data.get("sender_id", ""),
            receiver_id=data.get("receiver_id", ""),
            start_time=datetime.fromisoformat(data["start_time"]),
            last_update=datetime.fromisoformat(data["last_update"]),
            save_path=data.get("save_path", ""),
        )

    @property
    def progress_percent(self) -> float:
        """Completed-chunk percentage in [0, 100]."""
        if self.total_chunks == 0:
            return 0.0
        return (len(self.completed_chunks) / self.total_chunks) * 100

    @property
    def transferred_size(self) -> int:
        """Bytes transferred so far, accounting for the short final chunk."""
        chunk_size = FileTransferModule.CHUNK_SIZE
        done = len(self.completed_chunks)
        if done == 0:
            return 0
        # BUGFIX: chunk indices are 0-based, so the final chunk has index
        # total_chunks - 1. The original tested ``total_chunks in
        # completed_chunks``, which can never be true, so the short-last-chunk
        # correction never applied and progress was overstated by up to one
        # chunk.
        if (self.total_chunks - 1) in self.completed_chunks:
            last_chunk_size = self.file_size % chunk_size
            if last_chunk_size == 0:
                last_chunk_size = chunk_size
            return (done - 1) * chunk_size + last_chunk_size
        return done * chunk_size
class FileTransferModule:
    """
    File transfer module.

    Responsibilities:
    - chunked file transfer (req 4.2)
    - transfer progress callbacks (req 4.3)
    - file integrity verification (req 4.4)
    - resumable transfers (req 4.5)
    """
    # Size of each transfer chunk: 64 KiB.
    CHUNK_SIZE = 64 * 1024
    # Subdirectory (under the configured data dir) for persisted transfer state.
    STATE_DIR = "transfer_states"
    def __init__(self, config: Optional[ClientConfig] = None,
                 send_message_func: Optional[Callable] = None):
        """
        Initialize the file transfer module.

        Args:
            config: Client configuration; a default ClientConfig is created
                when omitted.
            send_message_func: Async function used to send messages,
                normally supplied by the ConnectionManager.
        """
        self.config = config or ClientConfig()
        self._send_message = send_message_func
        # Active transfers keyed by file_id.
        self._active_transfers: Dict[str, TransferState] = {}
        # file_ids flagged for cancellation; checked inside send loops.
        self._cancelled_transfers: set = set()
        # Receive buffers: file_id -> {chunk_index: chunk bytes}.
        self._receive_buffers: Dict[str, Dict[int, bytes]] = {}
        # Per-transfer progress callbacks keyed by file_id.
        self._progress_callbacks: Dict[str, ProgressCallback] = {}
        # Fired when a file is fully received: (sender_id, file_name, file_path).
        self._file_received_callbacks: List[Callable[[str, str, str], None]] = []
        # Message handler (not otherwise referenced in this module's visible code).
        self._message_handler = MessageHandler()
        # The state directory must exist before _load_transfer_states() scans it.
        self._state_dir = Path(self.config.data_dir) / self.STATE_DIR
        self._state_dir.mkdir(parents=True, exist_ok=True)
        # Default destination directory for received files.
        self._downloads_dir = Path(self.config.downloads_dir)
        self._downloads_dir.mkdir(parents=True, exist_ok=True)
        # Restore unfinished transfers persisted by a previous run (req 4.5).
        self._load_transfer_states()
        logger.info("FileTransferModule initialized")
def set_send_message_func(self, func: Callable) -> None:
"""
设置发送消息函数
Args:
func: 发送消息的异步函数
"""
self._send_message = func
def add_file_received_callback(self, callback: Callable[[str, str, str], None]) -> None:
"""
添加文件接收完成回调
Args:
callback: 回调函数 (sender_id, file_name, file_path)
"""
self._file_received_callbacks.append(callback)
def remove_file_received_callback(self, callback: Callable[[str, str, str], None]) -> None:
"""
移除文件接收完成回调
"""
if callback in self._file_received_callbacks:
self._file_received_callbacks.remove(callback)
    # ==================== File hashing (req 4.4) ====================
def calculate_file_hash(self, file_path: str, algorithm: str = "sha256") -> str:
"""
计算文件哈希值
实现 MD5/SHA256 哈希计算 (需求 4.4)
Args:
file_path: 文件路径
algorithm: 哈希算法 ("md5""sha256")
Returns:
文件哈希值(十六进制字符串)
Raises:
FileNotFoundError: 文件不存在
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
if algorithm == "md5":
hasher = hashlib.md5()
else:
hasher = hashlib.sha256()
with open(file_path, 'rb') as f:
while True:
data = f.read(self.CHUNK_SIZE)
if not data:
break
hasher.update(data)
return hasher.hexdigest()
def calculate_chunk_hash(self, data: bytes) -> str:
"""
计算数据块哈希值
Args:
data: 数据块
Returns:
MD5哈希值
"""
return hashlib.md5(data).hexdigest()
def verify_file_integrity(self, file_path: str, expected_hash: str,
algorithm: str = "sha256") -> bool:
"""
验证文件完整性
实现传输完成后的校验逻辑 (需求 4.4)
Args:
file_path: 文件路径
expected_hash: 期望的哈希值
algorithm: 哈希算法
Returns:
校验通过返回True否则返回False
"""
try:
actual_hash = self.calculate_file_hash(file_path, algorithm)
return actual_hash == expected_hash
except Exception as e:
logger.error(f"File integrity verification failed: {e}")
return False
    # ==================== File chunking (req 4.2) ====================
def _get_total_chunks(self, file_size: int) -> int:
"""
计算文件总块数
Args:
file_size: 文件大小
Returns:
总块数
"""
return (file_size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE
def _read_chunk(self, file_path: str, chunk_index: int) -> bytes:
"""
读取指定的文件块
实现文件分块逻辑64KB per chunk(需求 4.2)
Args:
file_path: 文件路径
chunk_index: 块索引从0开始
Returns:
数据块
"""
offset = chunk_index * self.CHUNK_SIZE
with open(file_path, 'rb') as f:
f.seek(offset)
return f.read(self.CHUNK_SIZE)
def _write_chunk(self, file_path: str, chunk_index: int, data: bytes) -> None:
"""
写入指定的文件块
Args:
file_path: 文件路径
chunk_index: 块索引
data: 数据块
"""
offset = chunk_index * self.CHUNK_SIZE
# 确保文件存在
if not os.path.exists(file_path):
# 创建空文件
with open(file_path, 'wb') as f:
pass
with open(file_path, 'r+b') as f:
f.seek(offset)
f.write(data)
def split_file_to_chunks(self, file_path: str) -> List[FileChunk]:
"""
将文件分割成多个块
实现文件分块逻辑64KB per chunk(需求 4.2)
Args:
file_path: 文件路径
Returns:
文件块列表
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
file_size = os.path.getsize(file_path)
total_chunks = self._get_total_chunks(file_size)
file_id = str(uuid.uuid4())
chunks = []
for i in range(total_chunks):
data = self._read_chunk(file_path, i)
chunk = FileChunk(
file_id=file_id,
chunk_index=i,
total_chunks=total_chunks,
data=data
)
chunks.append(chunk)
return chunks
    # ==================== Progress calculation and callbacks ====================
def _calculate_progress(self, state: TransferState,
start_time: float) -> TransferProgress:
"""
计算传输进度
实现传输进度回调 (需求 4.3)
Args:
state: 传输状态
start_time: 开始时间
Returns:
传输进度信息
"""
elapsed = time.time() - start_time
transferred = state.transferred_size
# 计算速度
speed = transferred / elapsed if elapsed > 0 else 0
# 计算预计剩余时间
remaining = state.file_size - transferred
eta = remaining / speed if speed > 0 else 0
return TransferProgress(
file_id=state.file_id,
file_name=state.file_name,
total_size=state.file_size,
transferred_size=transferred,
speed=speed,
eta=eta
)
def _notify_progress(self, file_id: str, progress: TransferProgress) -> None:
"""
通知进度回调
Args:
file_id: 文件ID
progress: 进度信息
"""
if file_id in self._progress_callbacks:
try:
self._progress_callbacks[file_id](progress)
except Exception as e:
logger.error(f"Progress callback error: {e}")
    # ==================== Sending files (req 4.2) ====================
    async def send_file(self, peer_id: str, file_path: str,
                        progress_callback: Optional[ProgressCallback] = None) -> bool:
        """
        Send a file to a peer by streaming a FILE_REQUEST, one FILE_CHUNK
        per 64KB piece, then a FILE_COMPLETE message (req 4.2).

        Args:
            peer_id: Target peer ID.
            file_path: Path of the file to send.
            progress_callback: Optional progress callback (req 4.3).

        Returns:
            True on success; False when the transfer was cancelled.

        Raises:
            FileNotFoundError: The file does not exist (module-local subclass).
            FileTransferError: No send function configured, or sending failed.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        if not self._send_message:
            raise FileTransferError("Send message function not set")
        # Gather file metadata; the whole-file hash lets the receiver verify
        # integrity after reassembly (req 4.4).
        file_name = os.path.basename(file_path)
        file_size = os.path.getsize(file_path)
        file_hash = self.calculate_file_hash(file_path)
        total_chunks = self._get_total_chunks(file_size)
        file_id = str(uuid.uuid4())
        logger.info(f"Starting file transfer: {file_name} ({file_size} bytes, "
                    f"{total_chunks} chunks) to {peer_id}")
        # Create the transfer state (also used for resume support, req 4.5).
        state = TransferState(
            file_id=file_id,
            file_path=file_path,
            file_name=file_name,
            file_size=file_size,
            file_hash=file_hash,
            total_chunks=total_chunks,
            status=TransferStatus.IN_PROGRESS,
            receiver_id=peer_id
        )
        self._active_transfers[file_id] = state
        if progress_callback:
            self._progress_callbacks[file_id] = progress_callback
        # Persist state so the transfer can be resumed after interruption.
        self._save_transfer_state(state)
        start_time = time.time()
        try:
            # Announce the transfer so the receiver can allocate buffers.
            file_request_payload = json.dumps({
                "file_id": file_id,
                "file_name": file_name,
                "file_size": file_size,
                "file_hash": file_hash,
                "total_chunks": total_chunks
            }).encode('utf-8')
            request_msg = Message(
                msg_type=MessageType.FILE_REQUEST,
                sender_id="",  # filled in by the ConnectionManager
                receiver_id=peer_id,
                timestamp=time.time(),
                payload=file_request_payload
            )
            await self._send_message(peer_id, request_msg)
            # Stream every chunk in order.
            for chunk_index in range(total_chunks):
                # Abort promptly if cancel_transfer() was called.
                if file_id in self._cancelled_transfers:
                    raise TransferCancelledError("Transfer cancelled")
                # Read the chunk and checksum it so the receiver can detect
                # per-chunk corruption.
                chunk_data = self._read_chunk(file_path, chunk_index)
                chunk_hash = self.calculate_chunk_hash(chunk_data)
                # Build the chunk message (data is hex-encoded for JSON).
                chunk_payload = json.dumps({
                    "file_id": file_id,
                    "chunk_index": chunk_index,
                    "total_chunks": total_chunks,
                    "checksum": chunk_hash,
                    "data": chunk_data.hex()
                }).encode('utf-8')
                chunk_msg = Message(
                    msg_type=MessageType.FILE_CHUNK,
                    sender_id="",
                    receiver_id=peer_id,
                    timestamp=time.time(),
                    payload=chunk_payload
                )
                await self._send_message(peer_id, chunk_msg)
                # Record progress.
                state.completed_chunks.append(chunk_index)
                state.last_update = datetime.now()
                # Notify the progress callback (req 4.3).
                progress = self._calculate_progress(state, start_time)
                self._notify_progress(file_id, progress)
                # Persist state every 10 chunks to bound loss on crash.
                if chunk_index % 10 == 0:
                    self._save_transfer_state(state)
                # Tiny pause to avoid flooding the network.
                await asyncio.sleep(0.001)
            # Tell the receiver the stream is complete so it can assemble
            # and verify the file.
            complete_payload = json.dumps({
                "file_id": file_id,
                "file_name": file_name,
                "file_hash": file_hash,
                "total_chunks": total_chunks
            }).encode('utf-8')
            complete_msg = Message(
                msg_type=MessageType.FILE_COMPLETE,
                sender_id="",
                receiver_id=peer_id,
                timestamp=time.time(),
                payload=complete_payload
            )
            await self._send_message(peer_id, complete_msg)
            # Mark completed and persist the final state.
            state.status = TransferStatus.COMPLETED
            self._save_transfer_state(state)
            # Final progress notification.
            progress = self._calculate_progress(state, start_time)
            self._notify_progress(file_id, progress)
            logger.info(f"File transfer completed: {file_name}")
            return True
        except TransferCancelledError:
            state.status = TransferStatus.CANCELLED
            self._save_transfer_state(state)
            logger.info(f"File transfer cancelled: {file_name}")
            return False
        except Exception as e:
            state.status = TransferStatus.FAILED
            self._save_transfer_state(state)
            logger.error(f"File transfer failed: {e}")
            raise FileTransferError(f"Transfer failed: {e}")
        finally:
            # Always drop the per-transfer progress callback.
            if file_id in self._progress_callbacks:
                del self._progress_callbacks[file_id]
    # ==================== Receiving files (req 4.2) ====================
async def receive_file(self, file_id: str, save_path: str,
progress_callback: Optional[ProgressCallback] = None) -> bool:
"""
接收文件并保存
实现 receive_file() 接收文件 (需求 4.2)
Args:
file_id: 文件ID
save_path: 保存路径
progress_callback: 进度回调函数
Returns:
接收成功返回True否则返回False
"""
if file_id not in self._active_transfers:
logger.error(f"Unknown file transfer: {file_id}")
return False
state = self._active_transfers[file_id]
state.save_path = save_path
state.status = TransferStatus.IN_PROGRESS
if progress_callback:
self._progress_callbacks[file_id] = progress_callback
# 等待所有块接收完成
# 实际的块接收在 handle_file_chunk 中处理
logger.info(f"Receiving file: {state.file_name} to {save_path}")
return True
def handle_file_request(self, message: Message) -> Optional[str]:
"""
处理文件请求消息
Args:
message: 文件请求消息
Returns:
文件ID如果处理失败返回None
"""
try:
payload = json.loads(message.payload.decode('utf-8'))
file_id = payload["file_id"]
file_name = payload["file_name"]
file_size = payload["file_size"]
file_hash = payload["file_hash"]
total_chunks = payload["total_chunks"]
# 创建接收状态
state = TransferState(
file_id=file_id,
file_path="", # 接收方不知道原始路径
file_name=file_name,
file_size=file_size,
file_hash=file_hash,
total_chunks=total_chunks,
status=TransferStatus.PENDING,
sender_id=message.sender_id
)
self._active_transfers[file_id] = state
self._receive_buffers[file_id] = {}
logger.info(f"File request received: {file_name} ({file_size} bytes)")
return file_id
except Exception as e:
logger.error(f"Failed to handle file request: {e}")
return None
def handle_file_chunk(self, message: Message) -> bool:
"""
处理文件块消息
Args:
message: 文件块消息
Returns:
处理成功返回True否则返回False
"""
try:
payload = json.loads(message.payload.decode('utf-8'))
file_id = payload["file_id"]
chunk_index = payload["chunk_index"]
total_chunks = payload["total_chunks"]
checksum = payload["checksum"]
data = bytes.fromhex(payload["data"])
# 验证块校验和
if self.calculate_chunk_hash(data) != checksum:
logger.error(f"Chunk checksum mismatch: {file_id}[{chunk_index}]")
return False
# 存储块数据
if file_id not in self._receive_buffers:
self._receive_buffers[file_id] = {}
self._receive_buffers[file_id][chunk_index] = data
# 更新状态
if file_id in self._active_transfers:
state = self._active_transfers[file_id]
if chunk_index not in state.completed_chunks:
state.completed_chunks.append(chunk_index)
state.last_update = datetime.now()
# 通知进度
if file_id in self._progress_callbacks:
progress = TransferProgress(
file_id=file_id,
file_name=state.file_name,
total_size=state.file_size,
transferred_size=len(state.completed_chunks) * self.CHUNK_SIZE,
speed=0, # 简化处理
eta=0
)
self._notify_progress(file_id, progress)
logger.debug(f"Received chunk {chunk_index + 1}/{total_chunks} for {file_id}")
return True
except Exception as e:
logger.error(f"Failed to handle file chunk: {e}")
return False
    def handle_file_complete(self, message: Message) -> bool:
        """
        Handle FILE_COMPLETE: assemble the buffered chunks, verify the file
        hash (req 4.4) and fire file-received callbacks.

        Args:
            message: FILE_COMPLETE message whose payload is UTF-8 JSON.

        Returns:
            True when the file was assembled and verified; False otherwise.
        """
        try:
            payload = json.loads(message.payload.decode('utf-8'))
            file_id = payload["file_id"]
            file_hash = payload["file_hash"]
            if file_id not in self._active_transfers:
                logger.error(f"Unknown file transfer: {file_id}")
                return False
            state = self._active_transfers[file_id]
            # Every chunk must have arrived before assembly.
            if len(state.completed_chunks) != state.total_chunks:
                logger.error(f"Missing chunks: received {len(state.completed_chunks)}, "
                             f"expected {state.total_chunks}")
                state.status = TransferStatus.FAILED
                return False
            # Without an explicit save path, default to the downloads dir.
            if not state.save_path:
                state.save_path = str(self._downloads_dir / state.file_name)
            # If the target already exists, append a numeric suffix before
            # the extension until the name is free.
            base_path = state.save_path
            counter = 1
            while os.path.exists(state.save_path):
                name, ext = os.path.splitext(base_path)
                state.save_path = f"{name}_{counter}{ext}"
                counter += 1
            # Write the buffered chunks to disk in index order.
            success = self._assemble_file(file_id, state.save_path)
            if success:
                # Verify the reassembled file against the sender's hash.
                if self.verify_file_integrity(state.save_path, file_hash):
                    state.status = TransferStatus.COMPLETED
                    logger.info(f"File received and verified: {state.file_name} -> {state.save_path}")
                    # Notify listeners that a file has arrived.
                    logger.info(f"Notifying {len(self._file_received_callbacks)} file received callbacks")
                    for callback in self._file_received_callbacks:
                        try:
                            logger.info(f"Calling callback with: {state.sender_id}, {state.file_name}, {state.save_path}")
                            callback(state.sender_id, state.file_name, state.save_path)
                        except Exception as e:
                            logger.error(f"File received callback error: {e}")
                else:
                    state.status = TransferStatus.FAILED
                    logger.error(f"File integrity check failed: {state.file_name}")
                    return False
            else:
                state.status = TransferStatus.FAILED
                return False
            # Release the receive buffer.
            if file_id in self._receive_buffers:
                del self._receive_buffers[file_id]
            # Persist the final status.
            self._save_transfer_state(state)
            return True
        except Exception as e:
            logger.error(f"Failed to handle file complete: {e}")
            return False
def _assemble_file(self, file_id: str, save_path: str) -> bool:
"""
组装接收到的文件块
Args:
file_id: 文件ID
save_path: 保存路径
Returns:
组装成功返回True否则返回False
"""
if file_id not in self._receive_buffers:
return False
if file_id not in self._active_transfers:
return False
state = self._active_transfers[file_id]
buffer = self._receive_buffers[file_id]
try:
# 确保目录存在
os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
# 按顺序写入所有块
with open(save_path, 'wb') as f:
for i in range(state.total_chunks):
if i not in buffer:
logger.error(f"Missing chunk {i} for file {file_id}")
return False
f.write(buffer[i])
logger.info(f"File assembled: {save_path}")
return True
except Exception as e:
logger.error(f"Failed to assemble file: {e}")
return False
    # ==================== Resumable transfers (req 4.5) ====================
def _save_transfer_state(self, state: TransferState) -> None:
"""
保存传输状态到文件
实现传输状态持久化 (需求 4.5)
Args:
state: 传输状态
"""
try:
state_file = self._state_dir / f"{state.file_id}.json"
with open(state_file, 'w', encoding='utf-8') as f:
json.dump(state.to_dict(), f, ensure_ascii=False, indent=2)
except Exception as e:
logger.error(f"Failed to save transfer state: {e}")
def _load_transfer_states(self) -> None:
"""
加载所有未完成的传输状态
实现传输状态持久化 (需求 4.5)
"""
try:
for state_file in self._state_dir.glob("*.json"):
try:
with open(state_file, 'r', encoding='utf-8') as f:
data = json.load(f)
state = TransferState.from_dict(data)
# 只加载未完成的传输
if state.status in [TransferStatus.PENDING,
TransferStatus.IN_PROGRESS,
TransferStatus.PAUSED]:
self._active_transfers[state.file_id] = state
logger.info(f"Loaded transfer state: {state.file_name} "
f"({state.progress_percent:.1f}%)")
except Exception as e:
logger.error(f"Failed to load state file {state_file}: {e}")
except Exception as e:
logger.error(f"Failed to load transfer states: {e}")
def _delete_transfer_state(self, file_id: str) -> None:
"""
删除传输状态文件
Args:
file_id: 文件ID
"""
try:
state_file = self._state_dir / f"{file_id}.json"
if state_file.exists():
state_file.unlink()
except Exception as e:
logger.error(f"Failed to delete transfer state: {e}")
    async def resume_transfer(self, file_id: str) -> bool:
        """
        Resume an interrupted outgoing transfer (req 4.5), re-announcing it
        and sending only the chunks missing from completed_chunks.

        Args:
            file_id: Transfer identifier.

        Returns:
            True when the transfer completed (or already had); False when it
            is unknown, cancelled, the source file is gone, or sending failed.

        Raises:
            FileTransferError: No send function is configured.
        """
        if file_id not in self._active_transfers:
            logger.error(f"Transfer not found: {file_id}")
            return False
        state = self._active_transfers[file_id]
        if state.status == TransferStatus.COMPLETED:
            logger.info(f"Transfer already completed: {file_id}")
            return True
        if state.status == TransferStatus.CANCELLED:
            logger.error(f"Transfer was cancelled: {file_id}")
            return False
        # Sender side: the source file must still exist to resume.
        if state.file_path and not os.path.exists(state.file_path):
            logger.error(f"Source file no longer exists: {state.file_path}")
            state.status = TransferStatus.FAILED
            self._save_transfer_state(state)
            return False
        if not self._send_message:
            raise FileTransferError("Send message function not set")
        logger.info(f"Resuming transfer: {state.file_name} from chunk "
                    f"{len(state.completed_chunks)}/{state.total_chunks}")
        # Mark as in progress again.
        state.status = TransferStatus.IN_PROGRESS
        state.last_update = datetime.now()
        # Clear any stale cancellation flag from a previous attempt.
        self._cancelled_transfers.discard(file_id)
        start_time = time.time()
        try:
            # Re-announce the transfer, including the list of chunks the
            # receiver should already have.
            resume_payload = json.dumps({
                "file_id": file_id,
                "file_name": state.file_name,
                "file_size": state.file_size,
                "file_hash": state.file_hash,
                "total_chunks": state.total_chunks,
                "completed_chunks": state.completed_chunks,
                "resume": True
            }).encode('utf-8')
            resume_msg = Message(
                msg_type=MessageType.FILE_REQUEST,
                sender_id="",
                receiver_id=state.receiver_id,
                timestamp=time.time(),
                payload=resume_payload
            )
            await self._send_message(state.receiver_id, resume_msg)
            # Send only the chunks not yet marked complete.
            completed_set = set(state.completed_chunks)
            for chunk_index in range(state.total_chunks):
                if chunk_index in completed_set:
                    continue
                # Abort promptly if cancel_transfer() was called.
                if file_id in self._cancelled_transfers:
                    raise TransferCancelledError("Transfer cancelled")
                # Read, checksum and send the chunk.
                chunk_data = self._read_chunk(state.file_path, chunk_index)
                chunk_hash = self.calculate_chunk_hash(chunk_data)
                chunk_payload = json.dumps({
                    "file_id": file_id,
                    "chunk_index": chunk_index,
                    "total_chunks": state.total_chunks,
                    "checksum": chunk_hash,
                    "data": chunk_data.hex()
                }).encode('utf-8')
                chunk_msg = Message(
                    msg_type=MessageType.FILE_CHUNK,
                    sender_id="",
                    receiver_id=state.receiver_id,
                    timestamp=time.time(),
                    payload=chunk_payload
                )
                await self._send_message(state.receiver_id, chunk_msg)
                # Record progress.
                state.completed_chunks.append(chunk_index)
                state.last_update = datetime.now()
                # Notify progress (only when a callback is registered here,
                # unlike send_file which always computes the snapshot).
                if file_id in self._progress_callbacks:
                    progress = self._calculate_progress(state, start_time)
                    self._notify_progress(file_id, progress)
                # Persist state every 10 chunks to bound loss on crash.
                if chunk_index % 10 == 0:
                    self._save_transfer_state(state)
                await asyncio.sleep(0.001)
            # Signal completion so the receiver assembles and verifies.
            complete_payload = json.dumps({
                "file_id": file_id,
                "file_name": state.file_name,
                "file_hash": state.file_hash,
                "total_chunks": state.total_chunks
            }).encode('utf-8')
            complete_msg = Message(
                msg_type=MessageType.FILE_COMPLETE,
                sender_id="",
                receiver_id=state.receiver_id,
                timestamp=time.time(),
                payload=complete_payload
            )
            await self._send_message(state.receiver_id, complete_msg)
            state.status = TransferStatus.COMPLETED
            self._save_transfer_state(state)
            logger.info(f"Transfer resumed and completed: {state.file_name}")
            return True
        except TransferCancelledError:
            state.status = TransferStatus.CANCELLED
            self._save_transfer_state(state)
            return False
        except Exception as e:
            state.status = TransferStatus.FAILED
            self._save_transfer_state(state)
            logger.error(f"Resume transfer failed: {e}")
            return False
def cancel_transfer(self, file_id: str) -> None:
"""
取消传输
实现 cancel_transfer() 取消传输 (需求 4.5)
Args:
file_id: 文件ID
"""
self._cancelled_transfers.add(file_id)
if file_id in self._active_transfers:
state = self._active_transfers[file_id]
state.status = TransferStatus.CANCELLED
self._save_transfer_state(state)
logger.info(f"Transfer cancelled: {state.file_name}")
# 清理缓冲区
if file_id in self._receive_buffers:
del self._receive_buffers[file_id]
# 清理进度回调
if file_id in self._progress_callbacks:
del self._progress_callbacks[file_id]
def pause_transfer(self, file_id: str) -> bool:
"""
暂停传输
Args:
file_id: 文件ID
Returns:
暂停成功返回True否则返回False
"""
if file_id not in self._active_transfers:
return False
state = self._active_transfers[file_id]
if state.status != TransferStatus.IN_PROGRESS:
return False
state.status = TransferStatus.PAUSED
self._save_transfer_state(state)
logger.info(f"Transfer paused: {state.file_name}")
return True
    # ==================== Transfer state queries ====================
def get_transfer_progress(self, file_id: str) -> Optional[TransferProgress]:
"""
获取传输进度
Args:
file_id: 文件ID
Returns:
传输进度信息如果传输不存在返回None
"""
if file_id not in self._active_transfers:
return None
state = self._active_transfers[file_id]
return TransferProgress(
file_id=file_id,
file_name=state.file_name,
total_size=state.file_size,
transferred_size=state.transferred_size,
speed=0, # 需要实时计算
eta=0
)
def get_transfer_state(self, file_id: str) -> Optional[TransferState]:
"""
获取传输状态
Args:
file_id: 文件ID
Returns:
传输状态如果不存在返回None
"""
return self._active_transfers.get(file_id)
def get_all_transfers(self) -> List[TransferState]:
"""
获取所有传输状态
Returns:
传输状态列表
"""
return list(self._active_transfers.values())
def get_pending_transfers(self) -> List[TransferState]:
"""
获取所有待恢复的传输
Returns:
待恢复的传输状态列表
"""
return [
state for state in self._active_transfers.values()
if state.status in [TransferStatus.PENDING,
TransferStatus.IN_PROGRESS,
TransferStatus.PAUSED]
]
def get_transfer_record(self, file_id: str) -> Optional[FileTransferRecord]:
"""
获取传输记录
Args:
file_id: 文件ID
Returns:
传输记录如果不存在返回None
"""
if file_id not in self._active_transfers:
return None
state = self._active_transfers[file_id]
return FileTransferRecord(
transfer_id=state.file_id,
file_name=state.file_name,
file_size=state.file_size,
file_hash=state.file_hash,
sender_id=state.sender_id,
receiver_id=state.receiver_id,
status=state.status,
progress=state.progress_percent,
start_time=state.start_time,
end_time=datetime.now() if state.status == TransferStatus.COMPLETED else None
)
    # ==================== Cleanup ====================
def cleanup_completed_transfers(self, max_age_hours: int = 24) -> int:
"""
清理已完成的传输记录
Args:
max_age_hours: 最大保留时间(小时)
Returns:
清理的记录数
"""
cleaned = 0
cutoff = datetime.now()
for file_id in list(self._active_transfers.keys()):
state = self._active_transfers[file_id]
if state.status in [TransferStatus.COMPLETED,
TransferStatus.CANCELLED,
TransferStatus.FAILED]:
age_hours = (cutoff - state.last_update).total_seconds() / 3600
if age_hours > max_age_hours:
del self._active_transfers[file_id]
self._delete_transfer_state(file_id)
cleaned += 1
if cleaned > 0:
logger.info(f"Cleaned {cleaned} old transfer records")
return cleaned
def clear_all_transfers(self) -> None:
"""
清除所有传输记录
"""
for file_id in list(self._active_transfers.keys()):
self._delete_transfer_state(file_id)
self._active_transfers.clear()
self._receive_buffers.clear()
self._progress_callbacks.clear()
self._cancelled_transfers.clear()
logger.info("All transfer records cleared")
    # ==================== Saving received files ====================
def save_received_file(self, file_id: str, save_path: str) -> bool:
"""
保存接收到的文件
Args:
file_id: 文件ID
save_path: 保存路径
Returns:
保存成功返回True否则返回False
"""
if file_id not in self._active_transfers:
logger.error(f"Transfer not found: {file_id}")
return False
state = self._active_transfers[file_id]
if state.status != TransferStatus.COMPLETED:
logger.error(f"Transfer not completed: {file_id}")
return False
# 如果数据还在缓冲区,组装文件
if file_id in self._receive_buffers:
success = self._assemble_file(file_id, save_path)
if success:
# 验证完整性
if self.verify_file_integrity(save_path, state.file_hash):
state.save_path = save_path
self._save_transfer_state(state)
# 清理缓冲区
del self._receive_buffers[file_id]
logger.info(f"File saved: {save_path}")
return True
else:
logger.error(f"File integrity check failed after save")
return False
else:
return False
logger.error(f"No data in buffer for file: {file_id}")
return False