You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

906 lines
28 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# P2P Network Communication - Data Access Layer
"""
数据访问层提供用户、消息、文件传输记录的CRUD操作
"""
import json
import logging
from datetime import datetime
from typing import Optional, List
from shared.models import (
UserInfo, UserStatus, ChatMessage, MessageType,
FileTransferRecord, TransferStatus, Message
)
from server.database import DatabaseManager
logger = logging.getLogger(__name__)
class UserRepository:
    """Data access layer for rows in the ``users`` table.

    Write operations catch database errors, log them, and report failure
    through their return value. Read operations let database errors
    propagate to the caller.
    """

    def __init__(self, db: DatabaseManager):
        """Bind the repository to a database manager.

        Args:
            db: DatabaseManager whose async execute/fetch helpers run every query.
        """
        self.db = db

    async def create(self, user: UserInfo, password_hash: str = "") -> bool:
        """Insert a new user row.

        Args:
            user: User whose fields populate the new row.
            password_hash: Password hash stored with the user ("" by default).

        Returns:
            True if the INSERT completed, False if it raised (the error is logged).
        """
        sql = """
            INSERT INTO users (user_id, username, display_name, password_hash,
                               public_key, status, ip_address, port, last_seen)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
        """
        try:
            params = (
                user.user_id,
                user.username,
                user.display_name,
                password_hash,
                user.public_key or None,  # an empty/missing key is stored as NULL
                user.status.value,
                user.ip_address,
                user.port,
                user.last_seen,
            )
            await self.db.execute(sql, params)
            logger.info(f"User created: {user.user_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to create user {user.user_id}: {e}")
            return False

    async def get_by_id(self, user_id: str) -> Optional[UserInfo]:
        """Look up a user by primary id.

        Args:
            user_id: Id of the user to fetch.

        Returns:
            The matching UserInfo, or None when no row matches.
        """
        sql = """
            SELECT user_id, username, display_name, public_key, status,
                   ip_address, port, last_seen
            FROM users WHERE user_id = %s
        """
        record = await self.db.fetch_dict(sql, (user_id,))
        return self._row_to_user(record) if record else None

    async def get_by_username(self, username: str) -> Optional[UserInfo]:
        """Look up a user by username.

        Args:
            username: Username to search for.

        Returns:
            The matching UserInfo, or None when no row matches.
        """
        sql = """
            SELECT user_id, username, display_name, public_key, status,
                   ip_address, port, last_seen
            FROM users WHERE username = %s
        """
        record = await self.db.fetch_dict(sql, (username,))
        return self._row_to_user(record) if record else None

    async def update(self, user: UserInfo) -> bool:
        """Overwrite the mutable profile/presence columns of an existing user.

        Args:
            user: User carrying the new values; ``user.user_id`` selects the row.

        Returns:
            True when at least one row was affected. This assumes
            ``db.execute`` returns the affected-row count. Returns False when
            nothing matched or the UPDATE raised (the error is logged).
        """
        sql = """
            UPDATE users SET display_name = %s, public_key = %s, status = %s,
                             ip_address = %s, port = %s, last_seen = %s
            WHERE user_id = %s
        """
        try:
            affected = await self.db.execute(sql, (
                user.display_name,
                user.public_key or None,
                user.status.value,
                user.ip_address,
                user.port,
                user.last_seen,
                user.user_id,
            ))
            return affected > 0
        except Exception as e:
            logger.error(f"Failed to update user {user.user_id}: {e}")
            return False

    async def update_status(self, user_id: str, status: UserStatus,
                            ip_address: str = "", port: int = 0) -> bool:
        """Set a user's presence status and endpoint, stamping last_seen with now.

        Args:
            user_id: Id of the user to update.
            status: New presence status.
            ip_address: Address to record ("" by default).
            port: Port to record (0 by default).

        Returns:
            True when at least one row was affected, False otherwise or on a
            logged database error.
        """
        sql = """
            UPDATE users SET status = %s, ip_address = %s, port = %s, last_seen = %s
            WHERE user_id = %s
        """
        try:
            affected = await self.db.execute(sql, (
                status.value,
                ip_address,
                port,
                datetime.now(),  # local, naive timestamp
                user_id,
            ))
            return affected > 0
        except Exception as e:
            logger.error(f"Failed to update user status {user_id}: {e}")
            return False

    async def delete(self, user_id: str) -> bool:
        """Remove a user row.

        Args:
            user_id: Id of the user to delete.

        Returns:
            True when a row was deleted, False otherwise or on a logged error.
        """
        try:
            affected = await self.db.execute(
                "DELETE FROM users WHERE user_id = %s", (user_id,)
            )
            return affected > 0
        except Exception as e:
            logger.error(f"Failed to delete user {user_id}: {e}")
            return False

    async def get_online_users(self) -> List[UserInfo]:
        """Return every user whose status column holds the literal 'online'.

        NOTE(review): presumably matches ``UserStatus``'s online value; confirm.
        """
        sql = """
            SELECT user_id, username, display_name, public_key, status,
                   ip_address, port, last_seen
            FROM users WHERE status = 'online'
        """
        records = await self.db.fetch_all_dict(sql)
        return list(map(self._row_to_user, records))

    async def get_all_users(self) -> List[UserInfo]:
        """Return all users ordered by username."""
        sql = """
            SELECT user_id, username, display_name, public_key, status,
                   ip_address, port, last_seen
            FROM users ORDER BY username
        """
        records = await self.db.fetch_all_dict(sql)
        return list(map(self._row_to_user, records))

    async def exists(self, user_id: str) -> bool:
        """Return True if a user row with this id exists."""
        found = await self.db.fetch_one(
            "SELECT 1 FROM users WHERE user_id = %s", (user_id,)
        )
        return found is not None

    async def username_exists(self, username: str) -> bool:
        """Return True if some user already has this username."""
        found = await self.db.fetch_one(
            "SELECT 1 FROM users WHERE username = %s", (username,)
        )
        return found is not None

    def _row_to_user(self, row: dict) -> UserInfo:
        """Convert a dict row from the SELECT queries into a UserInfo.

        NULL/falsy columns fall back to defaults:
        - status becomes OFFLINE;
        - ip_address becomes "";
        - port becomes 0;
        - public_key becomes empty bytes.
        """
        return UserInfo(
            user_id=row['user_id'],
            username=row['username'],
            display_name=row['display_name'],
            status=(UserStatus(row['status']) if row['status']
                    else UserStatus.OFFLINE),
            ip_address=row['ip_address'] or "",
            port=row['port'] or 0,
            last_seen=row['last_seen'],
            public_key=row['public_key'] or bytes(),
        )
class MessageRepository:
    """Data access layer for chat messages in the ``messages`` table.

    ``ChatMessage.is_sent`` is written to the ``is_delivered`` column and read
    back from it (see ``create`` and ``_row_to_message``).

    List queries order by ``timestamp`` and then by ``message_id``.
    Timestamps are not unique, so without this tiebreaker the database may
    return equal-timestamp rows in any order. LIMIT/OFFSET pages could then
    skip or repeat messages.
    """

    def __init__(self, db: DatabaseManager):
        """Bind the repository to a database manager.

        Args:
            db: DatabaseManager whose async execute/fetch helpers run every query.
        """
        self.db = db

    async def create(self, message: ChatMessage) -> bool:
        """Persist a chat message.

        Args:
            message: Message to store.

        Returns:
            True if the INSERT completed, False if it raised (the error is logged).
        """
        query = """
            INSERT INTO messages (message_id, sender_id, receiver_id, content_type,
                                  content, timestamp, is_read, is_delivered)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
        """
        try:
            await self.db.execute(query, (
                message.message_id,
                message.sender_id,
                message.receiver_id,
                message.content_type.value,
                message.content,
                message.timestamp,
                message.is_read,
                # is_sent is stored in the is_delivered column.
                # NOTE(review): confirm this mapping is intended.
                message.is_sent,
            ))
            logger.debug(f"Message saved: {message.message_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to save message {message.message_id}: {e}")
            return False

    async def get_by_id(self, message_id: str) -> Optional[ChatMessage]:
        """Fetch one message by id.

        Args:
            message_id: Id of the message.

        Returns:
            The ChatMessage, or None when no row matches.
        """
        query = """
            SELECT message_id, sender_id, receiver_id, content_type, content,
                   timestamp, is_read, is_delivered
            FROM messages WHERE message_id = %s
        """
        row = await self.db.fetch_dict(query, (message_id,))
        if row:
            return self._row_to_message(row)
        return None

    async def get_conversation(self, user1_id: str, user2_id: str,
                               limit: int = 100, offset: int = 0) -> List[ChatMessage]:
        """Return messages exchanged between two users in either direction.

        Args:
            user1_id: One participant.
            user2_id: The other participant.
            limit: Maximum number of messages to return.
            offset: Number of messages to skip, counted from the oldest one.

        Returns:
            Messages ordered oldest first, with message_id as the tiebreaker.
        """
        query = """
            SELECT message_id, sender_id, receiver_id, content_type, content,
                   timestamp, is_read, is_delivered
            FROM messages
            WHERE (sender_id = %s AND receiver_id = %s)
               OR (sender_id = %s AND receiver_id = %s)
            ORDER BY timestamp ASC, message_id ASC
            LIMIT %s OFFSET %s
        """
        rows = await self.db.fetch_all_dict(query, (
            user1_id, user2_id, user2_id, user1_id, limit, offset
        ))
        return [self._row_to_message(row) for row in rows]

    async def get_user_messages(self, user_id: str, limit: int = 100,
                                offset: int = 0) -> List[ChatMessage]:
        """Return messages a user sent or received.

        Args:
            user_id: The user.
            limit: Maximum number of messages to return.
            offset: Number of messages to skip, counted from the oldest one.

        Returns:
            Messages ordered oldest first, with message_id as the tiebreaker.
        """
        query = """
            SELECT message_id, sender_id, receiver_id, content_type, content,
                   timestamp, is_read, is_delivered
            FROM messages
            WHERE sender_id = %s OR receiver_id = %s
            ORDER BY timestamp ASC, message_id ASC
            LIMIT %s OFFSET %s
        """
        rows = await self.db.fetch_all_dict(query, (user_id, user_id, limit, offset))
        return [self._row_to_message(row) for row in rows]

    async def get_unread_messages(self, user_id: str) -> List[ChatMessage]:
        """Return all unread messages addressed to a user, oldest first.

        Args:
            user_id: The receiving user.

        Returns:
            Unread messages ordered by timestamp, then message_id.
        """
        query = """
            SELECT message_id, sender_id, receiver_id, content_type, content,
                   timestamp, is_read, is_delivered
            FROM messages
            WHERE receiver_id = %s AND is_read = FALSE
            ORDER BY timestamp ASC, message_id ASC
        """
        rows = await self.db.fetch_all_dict(query, (user_id,))
        return [self._row_to_message(row) for row in rows]

    async def mark_as_read(self, message_id: str) -> bool:
        """Set is_read on one message.

        Returns:
            True when a row was affected, False otherwise or on a logged error.
        """
        query = "UPDATE messages SET is_read = TRUE WHERE message_id = %s"
        try:
            rows = await self.db.execute(query, (message_id,))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to mark message as read {message_id}: {e}")
            return False

    async def mark_conversation_as_read(self, user_id: str, sender_id: str) -> int:
        """Mark every unread message from ``sender_id`` to ``user_id`` as read.

        Args:
            user_id: The receiving user.
            sender_id: The sending user.

        Returns:
            The value returned by ``db.execute`` (presumably the affected-row
            count), or 0 on a logged error.
        """
        query = """
            UPDATE messages SET is_read = TRUE
            WHERE receiver_id = %s AND sender_id = %s AND is_read = FALSE
        """
        try:
            return await self.db.execute(query, (user_id, sender_id))
        except Exception as e:
            logger.error(f"Failed to mark conversation as read: {e}")
            return 0

    async def mark_as_delivered(self, message_id: str) -> bool:
        """Set is_delivered on one message.

        Returns:
            True when a row was affected, False otherwise or on a logged error.
        """
        query = "UPDATE messages SET is_delivered = TRUE WHERE message_id = %s"
        try:
            rows = await self.db.execute(query, (message_id,))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to mark message as delivered {message_id}: {e}")
            return False

    async def delete(self, message_id: str) -> bool:
        """Delete one message.

        Returns:
            True when a row was deleted, False otherwise or on a logged error.
        """
        query = "DELETE FROM messages WHERE message_id = %s"
        try:
            rows = await self.db.execute(query, (message_id,))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to delete message {message_id}: {e}")
            return False

    async def delete_conversation(self, user1_id: str, user2_id: str) -> int:
        """Delete all messages exchanged between two users in either direction.

        Returns:
            The value returned by ``db.execute`` (presumably the deleted-row
            count), or 0 on a logged error.
        """
        query = """
            DELETE FROM messages
            WHERE (sender_id = %s AND receiver_id = %s)
               OR (sender_id = %s AND receiver_id = %s)
        """
        try:
            return await self.db.execute(query, (user1_id, user2_id, user2_id, user1_id))
        except Exception as e:
            logger.error(f"Failed to delete conversation: {e}")
            return 0

    def _row_to_message(self, row: dict) -> ChatMessage:
        """Convert a dict row from the SELECT queries into a ChatMessage.

        A NULL content becomes "". The is_delivered column populates
        ``is_sent``, mirroring ``create``.
        """
        return ChatMessage(
            message_id=row['message_id'],
            sender_id=row['sender_id'],
            receiver_id=row['receiver_id'],
            content_type=MessageType(row['content_type']),
            content=row['content'] or "",
            timestamp=row['timestamp'],
            is_read=bool(row['is_read']),
            is_sent=bool(row['is_delivered']),
        )
class FileTransferRepository:
    """Data access layer for file-transfer records in ``file_transfers``.

    Write operations catch database errors, log them, and report failure via
    their return value. Read operations let database errors propagate.
    """

    # Column list and FROM clause shared by every read query. It must stay in
    # sync with _row_to_record, which reads exactly these keys. The trailing
    # whitespace lets a WHERE clause be appended directly.
    _SELECT_SQL = """
        SELECT transfer_id, file_name, file_size, file_hash, sender_id,
               receiver_id, status, progress, start_time, end_time
        FROM file_transfers
    """

    def __init__(self, db: DatabaseManager):
        """Bind the repository to a database manager.

        Args:
            db: DatabaseManager whose async execute/fetch helpers run every query.
        """
        self.db = db

    async def create(self, record: FileTransferRecord) -> bool:
        """Insert a new transfer record.

        Args:
            record: Transfer whose fields populate the new row.

        Returns:
            True if the INSERT completed, False if it raised (the error is logged).
        """
        query = """
            INSERT INTO file_transfers (transfer_id, file_name, file_size, file_hash,
                                        sender_id, receiver_id, status, progress,
                                        start_time, end_time)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """
        try:
            await self.db.execute(query, (
                record.transfer_id,
                record.file_name,
                record.file_size,
                record.file_hash,
                record.sender_id,
                record.receiver_id,
                record.status.value,
                record.progress,
                record.start_time,
                record.end_time,
            ))
            logger.info(f"File transfer record created: {record.transfer_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to create file transfer record: {e}")
            return False

    async def get_by_id(self, transfer_id: str) -> Optional[FileTransferRecord]:
        """Fetch one transfer record by id.

        Returns:
            The FileTransferRecord, or None when no row matches.
        """
        query = self._SELECT_SQL + "WHERE transfer_id = %s"
        row = await self.db.fetch_dict(query, (transfer_id,))
        if row:
            return self._row_to_record(row)
        return None

    async def update_status(self, transfer_id: str, status: TransferStatus,
                            progress: Optional[float] = None) -> bool:
        """Set a transfer's status and, when given, its progress.

        Args:
            transfer_id: Id of the transfer.
            status: New status.
            progress: New progress value; left unchanged when None.

        Returns:
            True when a row was affected, False otherwise or on a logged error.
        """
        if progress is not None:
            query = """
                UPDATE file_transfers SET status = %s, progress = %s
                WHERE transfer_id = %s
            """
            args = (status.value, progress, transfer_id)
        else:
            query = "UPDATE file_transfers SET status = %s WHERE transfer_id = %s"
            args = (status.value, transfer_id)
        try:
            rows = await self.db.execute(query, args)
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to update transfer status: {e}")
            return False

    async def update_progress(self, transfer_id: str, progress: float) -> bool:
        """Set a transfer's progress (a percentage; this method does not clamp it).

        Returns:
            True when a row was affected, False otherwise or on a logged error.
        """
        query = "UPDATE file_transfers SET progress = %s WHERE transfer_id = %s"
        try:
            rows = await self.db.execute(query, (progress, transfer_id))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to update transfer progress: {e}")
            return False

    async def complete(self, transfer_id: str) -> bool:
        """Mark a transfer COMPLETED, set progress to 100, stamp end_time with now.

        Returns:
            True when a row was affected, False otherwise or on a logged error.
        """
        query = """
            UPDATE file_transfers SET status = %s, progress = 100, end_time = %s
            WHERE transfer_id = %s
        """
        try:
            rows = await self.db.execute(query, (
                TransferStatus.COMPLETED.value, datetime.now(), transfer_id
            ))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to complete transfer: {e}")
            return False

    async def fail(self, transfer_id: str) -> bool:
        """Mark a transfer FAILED and stamp end_time with now (progress untouched).

        Returns:
            True when a row was affected, False otherwise or on a logged error.
        """
        query = """
            UPDATE file_transfers SET status = %s, end_time = %s
            WHERE transfer_id = %s
        """
        try:
            rows = await self.db.execute(query, (
                TransferStatus.FAILED.value, datetime.now(), transfer_id
            ))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to mark transfer as failed: {e}")
            return False

    async def get_user_transfers(self, user_id: str,
                                 status: Optional[TransferStatus] = None
                                 ) -> List[FileTransferRecord]:
        """Return transfers a user sent or received, newest start_time first.

        Args:
            user_id: The user.
            status: When given, only transfers in this status are returned.

        Returns:
            Matching transfer records.
        """
        where = "WHERE (sender_id = %s OR receiver_id = %s)"
        args: tuple = (user_id, user_id)
        # Compare with None explicitly. Enum members mixed with int/str take
        # their truthiness from the mixed-in value, so a falsy member would
        # otherwise silently disable the filter.
        if status is not None:
            where += " AND status = %s"
            args += (status.value,)
        query = self._SELECT_SQL + where + " ORDER BY start_time DESC"
        rows = await self.db.fetch_all_dict(query, args)
        return [self._row_to_record(row) for row in rows]

    async def get_pending_transfers(self, user_id: str) -> List[FileTransferRecord]:
        """Return a user's unfinished transfers, newest start_time first.

        Unfinished means status 'pending', 'in_progress' or 'paused'.
        """
        query = self._SELECT_SQL + (
            "WHERE (sender_id = %s OR receiver_id = %s) "
            "AND status IN ('pending', 'in_progress', 'paused') "
            "ORDER BY start_time DESC"
        )
        rows = await self.db.fetch_all_dict(query, (user_id, user_id))
        return [self._row_to_record(row) for row in rows]

    async def delete(self, transfer_id: str) -> bool:
        """Delete one transfer record.

        Returns:
            True when a row was deleted, False otherwise or on a logged error.
        """
        query = "DELETE FROM file_transfers WHERE transfer_id = %s"
        try:
            rows = await self.db.execute(query, (transfer_id,))
            return rows > 0
        except Exception as e:
            logger.error(f"Failed to delete transfer record: {e}")
            return False

    def _row_to_record(self, row: dict) -> FileTransferRecord:
        """Convert a dict row from ``_SELECT_SQL`` into a FileTransferRecord."""
        return FileTransferRecord(
            transfer_id=row['transfer_id'],
            file_name=row['file_name'],
            file_size=row['file_size'],
            file_hash=row['file_hash'],
            sender_id=row['sender_id'],
            receiver_id=row['receiver_id'],
            status=TransferStatus(row['status']),
            # A NULL progress would make float() raise TypeError; treat it as 0.
            progress=float(row['progress'] or 0),
            start_time=row['start_time'],
            end_time=row['end_time'],
        )
class OfflineMessageRepository:
    """Data access layer for serialized messages queued in ``offline_messages``.

    Messages are stored via ``Message.serialize()`` and restored with
    ``Message.deserialize()``. Rows that fail to deserialize are logged and
    skipped when reading; they are not deleted here.
    """

    def __init__(self, db: DatabaseManager):
        """Bind the repository to a database manager.

        Args:
            db: DatabaseManager whose async execute/fetch helpers run every query.
        """
        self.db = db

    async def cache(self, user_id: str, message: Message) -> bool:
        """Queue a serialized message for a user.

        Args:
            user_id: Recipient the message is queued for.
            message: Message to serialize and store.

        Returns:
            True if serialization and INSERT succeeded, False otherwise
            (the error is logged).
        """
        insert_sql = """
            INSERT INTO offline_messages (user_id, message_data)
            VALUES (%s, %s)
        """
        try:
            payload = message.serialize()
            await self.db.execute(insert_sql, (user_id, payload))
            logger.debug(f"Offline message cached for user: {user_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to cache offline message: {e}")
            return False

    async def get_messages(self, user_id: str) -> List[Message]:
        """Return a user's queued messages, oldest created_at first.

        Undecodable rows are logged and omitted.
        """
        pairs = await self.get_messages_with_ids(user_id)
        return [msg for _record_id, msg in pairs]

    async def get_messages_with_ids(self, user_id: str) -> List[tuple]:
        """Return ``(row id, Message)`` pairs for a user, oldest created_at first.

        The row id can be passed to ``delete`` once the message is handled.
        Undecodable rows are logged and omitted.
        """
        select_sql = """
            SELECT id, message_data FROM offline_messages
            WHERE user_id = %s ORDER BY created_at ASC
        """
        records = await self.db.fetch_all_dict(select_sql, (user_id,))
        decoded = []
        for record in records:
            try:
                msg = Message.deserialize(record['message_data'])
                decoded.append((record['id'], msg))
            except Exception as e:
                logger.error(f"Failed to deserialize offline message: {e}")
        return decoded

    async def delete(self, message_id: int) -> bool:
        """Delete one queued message by its row id.

        Returns:
            True when a row was deleted, False otherwise or on a logged error.
        """
        try:
            affected = await self.db.execute(
                "DELETE FROM offline_messages WHERE id = %s", (message_id,)
            )
            return affected > 0
        except Exception as e:
            logger.error(f"Failed to delete offline message: {e}")
            return False

    async def delete_all(self, user_id: str) -> int:
        """Delete every queued message for a user.

        Returns:
            The value returned by ``db.execute`` (presumably the deleted-row
            count), or 0 on a logged error.
        """
        try:
            return await self.db.execute(
                "DELETE FROM offline_messages WHERE user_id = %s", (user_id,)
            )
        except Exception as e:
            logger.error(f"Failed to delete all offline messages: {e}")
            return 0

    async def count(self, user_id: str) -> int:
        """Return how many messages are queued for a user (0 if no row comes back)."""
        row = await self.db.fetch_dict(
            "SELECT COUNT(*) as count FROM offline_messages WHERE user_id = %s",
            (user_id,),
        )
        if not row:
            return 0
        return row['count']

    async def cleanup_old_messages(self, max_age_seconds: int) -> int:
        """Delete queued messages whose created_at is older than the given age.

        Args:
            max_age_seconds: Maximum age, in seconds, relative to the
                database's NOW().

        Returns:
            The value returned by ``db.execute`` (presumably the deleted-row
            count), or 0 on a logged error.
        """
        purge_sql = """
            DELETE FROM offline_messages
            WHERE created_at < DATE_SUB(NOW(), INTERVAL %s SECOND)
        """
        try:
            return await self.db.execute(purge_sql, (max_age_seconds,))
        except Exception as e:
            logger.error(f"Failed to cleanup old offline messages: {e}")
            return 0