|
|
# P2P Network Communication - Data Access Layer
|
|
|
"""
|
|
|
Data access layer providing CRUD operations for users, chat messages, file transfer records, and cached offline messages.
|
|
|
"""
|
|
|
|
|
|
import json
|
|
|
import logging
|
|
|
from datetime import datetime
|
|
|
from typing import Optional, List
|
|
|
|
|
|
from shared.models import (
|
|
|
UserInfo, UserStatus, ChatMessage, MessageType,
|
|
|
FileTransferRecord, TransferStatus, Message
|
|
|
)
|
|
|
from server.database import DatabaseManager
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class UserRepository:
    """Data access object for the ``users`` table."""

    def __init__(self, db: DatabaseManager):
        """
        Bind the repository to a database manager.

        Args:
            db: Database manager through which every query of this repository runs.
        """
        self.db = db

    async def _fetch_user(self, sql: str, params: tuple) -> Optional[UserInfo]:
        """Run a single-row SELECT and map the row to a UserInfo (None if no row)."""
        record = await self.db.fetch_dict(sql, params)
        if not record:
            return None
        return self._row_to_user(record)

    async def _fetch_users(self, sql: str) -> List[UserInfo]:
        """Run a parameterless multi-row SELECT and map every row to a UserInfo."""
        records = await self.db.fetch_all_dict(sql)
        return list(map(self._row_to_user, records))

    async def _row_exists(self, sql: str, params: tuple) -> bool:
        """Return True when the given SELECT yields a row (fetch_one is not None)."""
        first = await self.db.fetch_one(sql, params)
        return first is not None

    async def create(self, user: UserInfo, password_hash: str = "") -> bool:
        """
        Insert a new row into ``users``.

        Args:
            user: User whose fields populate the row. A falsy ``public_key``
                is written as NULL.
            password_hash: Value stored in the ``password_hash`` column.

        Returns:
            True once the INSERT has executed; False if anything inside the
            attempt raised (the error is logged, not propagated).
        """
        sql = """
            INSERT INTO users (user_id, username, display_name, password_hash,
                               public_key, status, ip_address, port, last_seen)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
        """
        try:
            params = (
                user.user_id,
                user.username,
                user.display_name,
                password_hash,
                user.public_key or None,  # empty key -> NULL
                user.status.value,
                user.ip_address,
                user.port,
                user.last_seen,
            )
            await self.db.execute(sql, params)
            logger.info(f"User created: {user.user_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to create user {user.user_id}: {e}")
            return False

    async def get_by_id(self, user_id: str) -> Optional[UserInfo]:
        """
        Look up a user by primary id.

        Args:
            user_id: Id to match against ``users.user_id``.

        Returns:
            The matching UserInfo, or None when no row matches.
        """
        sql = """
            SELECT user_id, username, display_name, public_key, status,
                   ip_address, port, last_seen
            FROM users WHERE user_id = %s
        """
        return await self._fetch_user(sql, (user_id,))

    async def get_by_username(self, username: str) -> Optional[UserInfo]:
        """
        Look up a user by username.

        Args:
            username: Value to match against ``users.username``.

        Returns:
            The matching UserInfo, or None when no row matches.
        """
        sql = """
            SELECT user_id, username, display_name, public_key, status,
                   ip_address, port, last_seen
            FROM users WHERE username = %s
        """
        return await self._fetch_user(sql, (username,))

    async def update(self, user: UserInfo) -> bool:
        """
        Overwrite the mutable columns of an existing user row.

        Args:
            user: Source of the new values; ``user.user_id`` selects the row.

        Returns:
            True if execute() reported more than zero affected rows; False
            otherwise or on error (logged).
            NOTE(review): drivers that count *changed* rather than *matched*
            rows report 0 when the values are identical -- confirm the setting.
        """
        sql = """
            UPDATE users SET display_name = %s, public_key = %s, status = %s,
                             ip_address = %s, port = %s, last_seen = %s
            WHERE user_id = %s
        """
        try:
            affected = await self.db.execute(sql, (
                user.display_name,
                user.public_key or None,  # empty key -> NULL
                user.status.value,
                user.ip_address,
                user.port,
                user.last_seen,
                user.user_id,
            ))
            return affected > 0
        except Exception as e:
            logger.error(f"Failed to update user {user.user_id}: {e}")
            return False

    async def update_status(self, user_id: str, status: UserStatus,
                            ip_address: str = "", port: int = 0) -> bool:
        """
        Set a user's status and endpoint, stamping ``last_seen`` with now.

        The address and port are always written, so omitting them stores ""
        and 0. ``last_seen`` uses ``datetime.now()`` (naive local time).

        Args:
            user_id: Row to update.
            status: New status; its ``.value`` is stored.
            ip_address: Address to store.
            port: Port to store.

        Returns:
            True if more than zero rows were affected; False otherwise or on
            error (logged).
        """
        sql = """
            UPDATE users SET status = %s, ip_address = %s, port = %s, last_seen = %s
            WHERE user_id = %s
        """
        try:
            affected = await self.db.execute(sql, (
                status.value,
                ip_address,
                port,
                datetime.now(),
                user_id,
            ))
            return affected > 0
        except Exception as e:
            logger.error(f"Failed to update user status {user_id}: {e}")
            return False

    async def delete(self, user_id: str) -> bool:
        """
        Remove a user row.

        Args:
            user_id: Row to delete.

        Returns:
            True if more than zero rows were deleted; False otherwise or on
            error (logged).
        """
        sql = "DELETE FROM users WHERE user_id = %s"
        try:
            affected = await self.db.execute(sql, (user_id,))
            return affected > 0
        except Exception as e:
            logger.error(f"Failed to delete user {user_id}: {e}")
            return False

    async def get_online_users(self) -> List[UserInfo]:
        """
        List users whose stored status is the literal ``'online'``.

        NOTE(review): the literal presumably equals ``UserStatus.ONLINE.value``;
        verify against shared.models.

        Returns:
            Matching users, in whatever order the database returns them.
        """
        sql = """
            SELECT user_id, username, display_name, public_key, status,
                   ip_address, port, last_seen
            FROM users WHERE status = 'online'
        """
        return await self._fetch_users(sql)

    async def get_all_users(self) -> List[UserInfo]:
        """
        List every user, ordered by username.

        Returns:
            All users sorted by ``username``.
        """
        sql = """
            SELECT user_id, username, display_name, public_key, status,
                   ip_address, port, last_seen
            FROM users ORDER BY username
        """
        return await self._fetch_users(sql)

    async def exists(self, user_id: str) -> bool:
        """
        Check whether a user with this id exists.

        Args:
            user_id: Id to probe.

        Returns:
            True if a row matched.
        """
        return await self._row_exists("SELECT 1 FROM users WHERE user_id = %s", (user_id,))

    async def username_exists(self, username: str) -> bool:
        """
        Check whether a username is already taken.

        Args:
            username: Username to probe.

        Returns:
            True if a row matched.
        """
        return await self._row_exists("SELECT 1 FROM users WHERE username = %s", (username,))

    def _row_to_user(self, row: dict) -> UserInfo:
        """
        Map a ``users`` row dict to UserInfo.

        NULL/empty columns fall back to defaults: status -> OFFLINE,
        ip_address -> "", port -> 0, public_key -> b"".
        """
        if row['status']:
            decoded_status = UserStatus(row['status'])
        else:
            decoded_status = UserStatus.OFFLINE
        return UserInfo(
            user_id=row['user_id'],
            username=row['username'],
            display_name=row['display_name'],
            status=decoded_status,
            ip_address=row['ip_address'] or "",
            port=row['port'] or 0,
            last_seen=row['last_seen'],
            public_key=row['public_key'] or bytes(),
        )
|
|
|
|
|
|
|
|
|
class MessageRepository:
    """
    Data access object for the ``messages`` table (persisted chat messages).

    Note: ``ChatMessage.is_sent`` is written to, and read back from, the
    ``is_delivered`` column.
    """

    def __init__(self, db: DatabaseManager):
        """
        Bind the repository to a database manager.

        Args:
            db: Database manager through which every query of this repository runs.
        """
        self.db = db

    async def create(self, message: ChatMessage) -> bool:
        """
        Persist a chat message.

        Args:
            message: Message to insert; ``message.is_sent`` is stored in the
                ``is_delivered`` column.

        Returns:
            True once the INSERT has executed; False if it raised (logged).
        """
        query = """
            INSERT INTO messages (message_id, sender_id, receiver_id, content_type,
                                  content, timestamp, is_read, is_delivered)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
        """
        try:
            await self.db.execute(query, (
                message.message_id,
                message.sender_id,
                message.receiver_id,
                message.content_type.value,
                message.content,
                message.timestamp,
                message.is_read,
                message.is_sent,  # persisted as is_delivered (see _row_to_message)
            ))
            logger.debug(f"Message saved: {message.message_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to save message {message.message_id}: {e}")
            return False

    async def get_by_id(self, message_id: str) -> Optional[ChatMessage]:
        """
        Fetch one message by id.

        Args:
            message_id: Id to match against ``messages.message_id``.

        Returns:
            The message, or None if no row matches.
        """
        query = """
            SELECT message_id, sender_id, receiver_id, content_type, content,
                   timestamp, is_read, is_delivered
            FROM messages WHERE message_id = %s
        """
        row = await self.db.fetch_dict(query, (message_id,))
        if row:
            return self._row_to_message(row)
        return None

    async def get_conversation(self, user1_id: str, user2_id: str,
                               limit: int = 100, offset: int = 0) -> List[ChatMessage]:
        """
        Page through the messages exchanged between two users (either direction).

        Args:
            user1_id: One participant.
            user2_id: The other participant.
            limit: Maximum number of messages to return.
            offset: Number of messages to skip, counted from the oldest.

        Returns:
            Messages ordered oldest first (by timestamp, ties broken by
            message_id so consecutive pages neither skip nor repeat rows).
        """
        # Without a unique tie-breaker, rows sharing a timestamp may be returned
        # in a different order on each query, breaking LIMIT/OFFSET paging.
        query = """
            SELECT message_id, sender_id, receiver_id, content_type, content,
                   timestamp, is_read, is_delivered
            FROM messages
            WHERE (sender_id = %s AND receiver_id = %s)
               OR (sender_id = %s AND receiver_id = %s)
            ORDER BY timestamp ASC, message_id ASC
            LIMIT %s OFFSET %s
        """
        rows = await self.db.fetch_all_dict(query, (
            user1_id, user2_id, user2_id, user1_id, limit, offset
        ))
        return [self._row_to_message(row) for row in rows]

    async def get_user_messages(self, user_id: str, limit: int = 100,
                                offset: int = 0) -> List[ChatMessage]:
        """
        Page through every message a user sent or received.

        Args:
            user_id: The user.
            limit: Maximum number of messages to return.
            offset: Number of messages to skip, counted from the oldest.

        Returns:
            Messages ordered oldest first (timestamp, then message_id for a
            stable page order).
        """
        query = """
            SELECT message_id, sender_id, receiver_id, content_type, content,
                   timestamp, is_read, is_delivered
            FROM messages
            WHERE sender_id = %s OR receiver_id = %s
            ORDER BY timestamp ASC, message_id ASC
            LIMIT %s OFFSET %s
        """
        rows = await self.db.fetch_all_dict(query, (user_id, user_id, limit, offset))
        return [self._row_to_message(row) for row in rows]

    async def get_unread_messages(self, user_id: str) -> List[ChatMessage]:
        """
        List messages addressed to a user that are not yet marked read.

        Args:
            user_id: The receiving user.

        Returns:
            Unread messages ordered oldest first (timestamp, then message_id).
        """
        query = """
            SELECT message_id, sender_id, receiver_id, content_type, content,
                   timestamp, is_read, is_delivered
            FROM messages
            WHERE receiver_id = %s AND is_read = FALSE
            ORDER BY timestamp ASC, message_id ASC
        """
        rows = await self.db.fetch_all_dict(query, (user_id,))
        return [self._row_to_message(row) for row in rows]

    async def mark_as_read(self, message_id: str) -> bool:
        """
        Set ``is_read`` on one message.

        Args:
            message_id: Message to update.

        Returns:
            True if more than zero rows were affected; False otherwise or on
            error (logged).
        """
        query = "UPDATE messages SET is_read = TRUE WHERE message_id = %s"
        try:
            rows = await self.db.execute(query, (message_id,))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to mark message as read {message_id}: {e}")
            return False

    async def mark_conversation_as_read(self, user_id: str, sender_id: str) -> int:
        """
        Mark every unread message from ``sender_id`` to ``user_id`` as read.

        Args:
            user_id: The receiving user.
            sender_id: The sending user.

        Returns:
            The value returned by execute() (affected row count), or 0 on
            error (logged).
        """
        query = """
            UPDATE messages SET is_read = TRUE
            WHERE receiver_id = %s AND sender_id = %s AND is_read = FALSE
        """
        try:
            return await self.db.execute(query, (user_id, sender_id))
        except Exception as e:
            logger.error(f"Failed to mark conversation as read: {e}")
            return 0

    async def mark_as_delivered(self, message_id: str) -> bool:
        """
        Set ``is_delivered`` on one message.

        Args:
            message_id: Message to update.

        Returns:
            True if more than zero rows were affected; False otherwise or on
            error (logged).
        """
        query = "UPDATE messages SET is_delivered = TRUE WHERE message_id = %s"
        try:
            rows = await self.db.execute(query, (message_id,))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to mark message as delivered {message_id}: {e}")
            return False

    async def delete(self, message_id: str) -> bool:
        """
        Delete one message.

        Args:
            message_id: Message to delete.

        Returns:
            True if more than zero rows were deleted; False otherwise or on
            error (logged).
        """
        query = "DELETE FROM messages WHERE message_id = %s"
        try:
            rows = await self.db.execute(query, (message_id,))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to delete message {message_id}: {e}")
            return False

    async def delete_conversation(self, user1_id: str, user2_id: str) -> int:
        """
        Delete all messages exchanged between two users (either direction).

        Args:
            user1_id: One participant.
            user2_id: The other participant.

        Returns:
            The value returned by execute() (deleted row count), or 0 on
            error (logged).
        """
        query = """
            DELETE FROM messages
            WHERE (sender_id = %s AND receiver_id = %s)
               OR (sender_id = %s AND receiver_id = %s)
        """
        try:
            return await self.db.execute(query, (user1_id, user2_id, user2_id, user1_id))
        except Exception as e:
            logger.error(f"Failed to delete conversation: {e}")
            return 0

    def _row_to_message(self, row: dict) -> ChatMessage:
        """
        Map a ``messages`` row dict to ChatMessage.

        NULL content becomes ""; the ``is_delivered`` column populates
        ``is_sent``.
        """
        return ChatMessage(
            message_id=row['message_id'],
            sender_id=row['sender_id'],
            receiver_id=row['receiver_id'],
            content_type=MessageType(row['content_type']),
            content=row['content'] or "",
            timestamp=row['timestamp'],
            is_read=bool(row['is_read']),
            is_sent=bool(row['is_delivered']),
        )
|
|
|
|
|
|
|
|
|
class FileTransferRepository:
    """Data access object for the ``file_transfers`` table."""

    def __init__(self, db: DatabaseManager):
        """
        Bind the repository to a database manager.

        Args:
            db: Database manager through which every query of this repository runs.
        """
        self.db = db

    async def create(self, record: FileTransferRecord) -> bool:
        """
        Insert a file transfer record.

        Args:
            record: Record whose fields populate the new row.

        Returns:
            True once the INSERT has executed; False if it raised (logged).
        """
        query = """
            INSERT INTO file_transfers (transfer_id, file_name, file_size, file_hash,
                                        sender_id, receiver_id, status, progress,
                                        start_time, end_time)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """
        try:
            await self.db.execute(query, (
                record.transfer_id,
                record.file_name,
                record.file_size,
                record.file_hash,
                record.sender_id,
                record.receiver_id,
                record.status.value,
                record.progress,
                record.start_time,
                record.end_time,
            ))
            logger.info(f"File transfer record created: {record.transfer_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to create file transfer record {record.transfer_id}: {e}")
            return False

    async def get_by_id(self, transfer_id: str) -> Optional[FileTransferRecord]:
        """
        Fetch one transfer record by id.

        Args:
            transfer_id: Id to match against ``file_transfers.transfer_id``.

        Returns:
            The record, or None if no row matches.
        """
        query = """
            SELECT transfer_id, file_name, file_size, file_hash, sender_id,
                   receiver_id, status, progress, start_time, end_time
            FROM file_transfers WHERE transfer_id = %s
        """
        row = await self.db.fetch_dict(query, (transfer_id,))
        if row:
            return self._row_to_record(row)
        return None

    async def update_status(self, transfer_id: str, status: TransferStatus,
                            progress: Optional[float] = None) -> bool:
        """
        Set a transfer's status and, when given, its progress.

        ``end_time`` is not touched here; ``complete()``/``fail()`` set it.

        Args:
            transfer_id: Transfer to update.
            status: New status; its ``.value`` is stored.
            progress: New progress, or None to leave the column unchanged.

        Returns:
            True if more than zero rows were affected; False otherwise or on
            error (logged).
        """
        if progress is not None:
            query = """
                UPDATE file_transfers SET status = %s, progress = %s
                WHERE transfer_id = %s
            """
            args = (status.value, progress, transfer_id)
        else:
            query = "UPDATE file_transfers SET status = %s WHERE transfer_id = %s"
            args = (status.value, transfer_id)

        try:
            rows = await self.db.execute(query, args)
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to update transfer status {transfer_id}: {e}")
            return False

    async def update_progress(self, transfer_id: str, progress: float) -> bool:
        """
        Set a transfer's progress.

        Args:
            transfer_id: Transfer to update.
            progress: Progress percentage (0-100).

        Returns:
            True if more than zero rows were affected; False otherwise or on
            error (logged).
        """
        query = "UPDATE file_transfers SET progress = %s WHERE transfer_id = %s"
        try:
            rows = await self.db.execute(query, (progress, transfer_id))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to update transfer progress {transfer_id}: {e}")
            return False

    async def complete(self, transfer_id: str) -> bool:
        """
        Mark a transfer COMPLETED with progress 100 and ``end_time`` = now.

        Args:
            transfer_id: Transfer to update.

        Returns:
            True if more than zero rows were affected; False otherwise or on
            error (logged).
        """
        query = """
            UPDATE file_transfers SET status = %s, progress = 100, end_time = %s
            WHERE transfer_id = %s
        """
        try:
            rows = await self.db.execute(query, (
                TransferStatus.COMPLETED.value, datetime.now(), transfer_id
            ))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to complete transfer {transfer_id}: {e}")
            return False

    async def fail(self, transfer_id: str) -> bool:
        """
        Mark a transfer FAILED with ``end_time`` = now (progress is kept).

        Args:
            transfer_id: Transfer to update.

        Returns:
            True if more than zero rows were affected; False otherwise or on
            error (logged).
        """
        query = """
            UPDATE file_transfers SET status = %s, end_time = %s
            WHERE transfer_id = %s
        """
        try:
            rows = await self.db.execute(query, (
                TransferStatus.FAILED.value, datetime.now(), transfer_id
            ))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to mark transfer as failed {transfer_id}: {e}")
            return False

    async def get_user_transfers(self, user_id: str,
                                 status: Optional[TransferStatus] = None
                                 ) -> List[FileTransferRecord]:
        """
        List transfers a user sent or received, newest ``start_time`` first.

        Args:
            user_id: The user.
            status: If not None, only transfers in this status are returned.

        Returns:
            Matching transfer records.
        """
        # Compare against None explicitly: an enum member can be falsy
        # (e.g. IntEnum/str mixins with value 0 or ""), and `if status:` would
        # then silently drop the filter.
        if status is not None:
            query = """
                SELECT transfer_id, file_name, file_size, file_hash, sender_id,
                       receiver_id, status, progress, start_time, end_time
                FROM file_transfers
                WHERE (sender_id = %s OR receiver_id = %s) AND status = %s
                ORDER BY start_time DESC
            """
            rows = await self.db.fetch_all_dict(query, (user_id, user_id, status.value))
        else:
            query = """
                SELECT transfer_id, file_name, file_size, file_hash, sender_id,
                       receiver_id, status, progress, start_time, end_time
                FROM file_transfers
                WHERE sender_id = %s OR receiver_id = %s
                ORDER BY start_time DESC
            """
            rows = await self.db.fetch_all_dict(query, (user_id, user_id))

        return [self._row_to_record(row) for row in rows]

    async def get_pending_transfers(self, user_id: str) -> List[FileTransferRecord]:
        """
        List a user's transfers whose status is 'pending', 'in_progress' or 'paused'.

        NOTE(review): these literals presumably mirror TransferStatus values;
        verify against shared.models.

        Args:
            user_id: The user (as sender or receiver).

        Returns:
            Matching records, newest ``start_time`` first.
        """
        query = """
            SELECT transfer_id, file_name, file_size, file_hash, sender_id,
                   receiver_id, status, progress, start_time, end_time
            FROM file_transfers
            WHERE (sender_id = %s OR receiver_id = %s)
              AND status IN ('pending', 'in_progress', 'paused')
            ORDER BY start_time DESC
        """
        rows = await self.db.fetch_all_dict(query, (user_id, user_id))
        return [self._row_to_record(row) for row in rows]

    async def delete(self, transfer_id: str) -> bool:
        """
        Delete a transfer record.

        Args:
            transfer_id: Record to delete.

        Returns:
            True if more than zero rows were deleted; False otherwise or on
            error (logged).
        """
        query = "DELETE FROM file_transfers WHERE transfer_id = %s"
        try:
            rows = await self.db.execute(query, (transfer_id,))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to delete transfer record {transfer_id}: {e}")
            return False

    def _row_to_record(self, row: dict) -> FileTransferRecord:
        """
        Map a ``file_transfers`` row dict to FileTransferRecord.

        A NULL progress column is read as 0.0 instead of crashing float().
        """
        raw_progress = row['progress']
        return FileTransferRecord(
            transfer_id=row['transfer_id'],
            file_name=row['file_name'],
            file_size=row['file_size'],
            file_hash=row['file_hash'],
            sender_id=row['sender_id'],
            receiver_id=row['receiver_id'],
            status=TransferStatus(row['status']),
            progress=float(raw_progress) if raw_progress is not None else 0.0,
            start_time=row['start_time'],
            end_time=row['end_time'],
        )
|
|
|
|
|
|
|
|
|
class OfflineMessageRepository:
    """Data access object for the ``offline_messages`` table (queued messages)."""

    def __init__(self, db: DatabaseManager):
        """
        Bind the repository to a database manager.

        Args:
            db: Database manager through which every query of this repository runs.
        """
        self.db = db

    async def cache(self, user_id: str, message: Message) -> bool:
        """
        Queue a message for a user by storing ``message.serialize()``.

        Args:
            user_id: Recipient the message is queued for.
            message: Message to serialize and store.

        Returns:
            True once stored; False if serialization or the INSERT raised
            (logged).
        """
        sql = """
            INSERT INTO offline_messages (user_id, message_data)
            VALUES (%s, %s)
        """
        try:
            payload = message.serialize()
            await self.db.execute(sql, (user_id, payload))
            logger.debug(f"Offline message cached for user: {user_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to cache offline message: {e}")
            return False

    async def get_messages(self, user_id: str) -> List[Message]:
        """
        Load a user's queued messages, oldest ``created_at`` first.

        Rows that fail to deserialize are logged and skipped.

        Args:
            user_id: Recipient whose queue is read.

        Returns:
            The deserialized messages.
        """
        pairs = await self.get_messages_with_ids(user_id)
        return [msg for _row_id, msg in pairs]

    async def get_messages_with_ids(self, user_id: str) -> List[tuple]:
        """
        Load a user's queued messages together with their row ids.

        Ordered oldest ``created_at`` first; rows that fail to deserialize are
        logged and skipped.

        Args:
            user_id: Recipient whose queue is read.

        Returns:
            ``(id, message)`` tuples, where ``id`` is the offline_messages row id.
        """
        sql = """
            SELECT id, message_data FROM offline_messages
            WHERE user_id = %s ORDER BY created_at ASC
        """
        records = await self.db.fetch_all_dict(sql, (user_id,))
        decoded = []
        for record in records:
            try:
                msg = Message.deserialize(record['message_data'])
                decoded.append((record['id'], msg))
            except Exception as e:
                logger.error(f"Failed to deserialize offline message: {e}")
        return decoded

    async def delete(self, message_id: int) -> bool:
        """
        Remove one queued message by row id.

        Args:
            message_id: ``offline_messages.id`` of the row to delete.

        Returns:
            True if more than zero rows were deleted; False otherwise or on
            error (logged).
        """
        sql = "DELETE FROM offline_messages WHERE id = %s"
        try:
            affected = await self.db.execute(sql, (message_id,))
            return affected > 0
        except Exception as e:
            logger.error(f"Failed to delete offline message: {e}")
            return False

    async def delete_all(self, user_id: str) -> int:
        """
        Remove every queued message for a user.

        Args:
            user_id: Recipient whose queue is cleared.

        Returns:
            The value returned by execute() (deleted row count), or 0 on
            error (logged).
        """
        sql = "DELETE FROM offline_messages WHERE user_id = %s"
        try:
            return await self.db.execute(sql, (user_id,))
        except Exception as e:
            logger.error(f"Failed to delete all offline messages: {e}")
            return 0

    async def count(self, user_id: str) -> int:
        """
        Count a user's queued messages.

        Args:
            user_id: Recipient to count for.

        Returns:
            The COUNT(*) result, or 0 if no row came back.
        """
        sql = "SELECT COUNT(*) as count FROM offline_messages WHERE user_id = %s"
        result = await self.db.fetch_dict(sql, (user_id,))
        if not result:
            return 0
        return result['count']

    async def cleanup_old_messages(self, max_age_seconds: int) -> int:
        """
        Delete queued messages whose ``created_at`` is older than the given age.

        Uses MySQL ``DATE_SUB(NOW(), INTERVAL n SECOND)`` syntax.

        Args:
            max_age_seconds: Maximum age to keep, in seconds.

        Returns:
            The value returned by execute() (deleted row count), or 0 on
            error (logged).
        """
        sql = """
            DELETE FROM offline_messages
            WHERE created_at < DATE_SUB(NOW(), INTERVAL %s SECOND)
        """
        try:
            return await self.db.execute(sql, (max_age_seconds,))
        except Exception as e:
            logger.error(f"Failed to cleanup old offline messages: {e}")
            return 0
|