|
|
# P2P Network Communication - Database Module
|
|
|
"""
|
|
|
数据库模块,提供MySQL数据库连接池管理和表初始化功能
|
|
|
使用aiomysql异步驱动实现高性能数据库操作
|
|
|
"""
|
|
|
|
|
|
import asyncio
|
|
|
import logging
|
|
|
from contextlib import asynccontextmanager
|
|
|
from typing import Optional, AsyncGenerator
|
|
|
|
|
|
import aiomysql
|
|
|
|
|
|
from config import ServerConfig, load_server_config
|
|
|
|
|
|
# Module-level logger; records are emitted under this module's __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class DatabaseManager:
    """Database manager: owns the aiomysql connection pool and creates the schema.

    Lifecycle: construct, ``await initialize()``, use ``acquire()`` or the
    ``execute`` / ``fetch_*`` helpers, then ``await close()``.
    """

    def __init__(self, config: Optional[ServerConfig] = None):
        """
        Initialize the database manager without opening any connection.

        Args:
            config: Server configuration. If None, it is obtained from
                ``load_server_config()`` (originally documented as loading
                from environment variables).
        """
        # Explicit None check: only a missing config triggers loading.
        self.config = config if config is not None else load_server_config()
        self._pool: Optional[aiomysql.Pool] = None
        self._initialized = False
        # Serializes initialize(). Created lazily inside a running event loop
        # because before Python 3.10 asyncio.Lock() bound itself to the
        # current event loop at construction time.
        self._init_lock: Optional[asyncio.Lock] = None

    async def initialize(self) -> None:
        """Create the connection pool and the tables, once.

        Returns immediately after a successful initialization. Concurrent
        callers are serialized, so at most one pool is created. If table
        creation fails, the freshly created pool is closed before the
        exception propagates, so a later retry does not leak it.
        """
        if self._initialized:
            return

        if self._init_lock is None:
            self._init_lock = asyncio.Lock()

        async with self._init_lock:
            # Another coroutine may have completed initialization while we
            # were waiting for the lock.
            if self._initialized:
                return

            await self._create_pool()
            try:
                await self._create_tables()
            except BaseException:
                # Do not leave a half-initialized manager holding an open pool.
                await self.close()
                raise
            self._initialized = True

        logger.info("Database initialized successfully")

    async def _create_pool(self) -> None:
        """Create the aiomysql connection pool from ``self.config``.

        Connections use the utf8mb4 charset and autocommit, so each statement
        issued through this manager is committed on its own.

        Raises:
            Whatever ``aiomysql.create_pool`` raises; the error is logged first.
        """
        try:
            self._pool = await aiomysql.create_pool(
                host=self.config.db_host,
                port=self.config.db_port,
                user=self.config.db_user,
                password=self.config.db_password,
                db=self.config.db_name,
                minsize=1,
                maxsize=self.config.db_pool_size,
                charset='utf8mb4',
                autocommit=True,
                echo=False,
            )
        except Exception as e:
            logger.error("Failed to create database pool: %s", e)
            raise
        logger.info(
            "Database connection pool created: %s:%s",
            self.config.db_host,
            self.config.db_port,
        )

    async def _create_tables(self) -> None:
        """Create every table used by the server if it does not exist yet.

        Uses ``CREATE TABLE IF NOT EXISTS``, so existing tables (and their
        current definitions) are left untouched.

        Raises:
            The first database error hit while creating a table; it is logged
            with the table name before being re-raised.
        """
        # NOTE(review): idx_username duplicates the index MySQL already builds
        # for the UNIQUE constraint on username. It is kept so that newly
        # created schemas match existing deployments.
        create_users_table = """
        CREATE TABLE IF NOT EXISTS users (
            user_id VARCHAR(64) PRIMARY KEY,
            username VARCHAR(100) UNIQUE NOT NULL,
            display_name VARCHAR(200) NOT NULL,
            password_hash VARCHAR(255) NOT NULL DEFAULT '',
            public_key BLOB,
            status VARCHAR(20) DEFAULT 'offline',
            ip_address VARCHAR(45) DEFAULT '',
            port INT DEFAULT 0,
            last_seen TIMESTAMP NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            INDEX idx_username (username),
            INDEX idx_status (status)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
        """

        # NOTE(review): idx_sender is covered by the leftmost prefix of
        # idx_conversation (sender_id, receiver_id, timestamp). It is kept for
        # schema compatibility.
        create_messages_table = """
        CREATE TABLE IF NOT EXISTS messages (
            message_id VARCHAR(64) PRIMARY KEY,
            sender_id VARCHAR(64) NOT NULL,
            receiver_id VARCHAR(64) NOT NULL,
            content_type VARCHAR(50) NOT NULL,
            content TEXT,
            timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            is_read BOOLEAN DEFAULT FALSE,
            is_delivered BOOLEAN DEFAULT FALSE,
            INDEX idx_sender (sender_id),
            INDEX idx_receiver (receiver_id),
            INDEX idx_timestamp (timestamp),
            INDEX idx_conversation (sender_id, receiver_id, timestamp)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
        """

        create_file_transfers_table = """
        CREATE TABLE IF NOT EXISTS file_transfers (
            transfer_id VARCHAR(64) PRIMARY KEY,
            file_name VARCHAR(500) NOT NULL,
            file_size BIGINT NOT NULL,
            file_hash VARCHAR(128) NOT NULL,
            sender_id VARCHAR(64) NOT NULL,
            receiver_id VARCHAR(64) NOT NULL,
            status VARCHAR(20) NOT NULL DEFAULT 'pending',
            progress DECIMAL(5,2) DEFAULT 0,
            start_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            end_time TIMESTAMP NULL,
            INDEX idx_sender (sender_id),
            INDEX idx_receiver (receiver_id),
            INDEX idx_status (status)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
        """

        create_offline_messages_table = """
        CREATE TABLE IF NOT EXISTS offline_messages (
            id BIGINT AUTO_INCREMENT PRIMARY KEY,
            user_id VARCHAR(64) NOT NULL,
            message_data BLOB NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            INDEX idx_user_id (user_id),
            INDEX idx_created_at (created_at)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
        """

        tables = [
            ("users", create_users_table),
            ("messages", create_messages_table),
            ("file_transfers", create_file_transfers_table),
            ("offline_messages", create_offline_messages_table),
        ]

        async with self.acquire() as conn:
            async with conn.cursor() as cursor:
                for table_name, create_sql in tables:
                    try:
                        await cursor.execute(create_sql)
                    except Exception as e:
                        logger.error("Failed to create table '%s': %s", table_name, e)
                        raise
                    logger.info("Table '%s' created or already exists", table_name)

    @asynccontextmanager
    async def acquire(self) -> AsyncGenerator[aiomysql.Connection, None]:
        """
        Borrow a pooled connection for the duration of an ``async with`` block.

        The connection is released back to the pool when the block exits,
        including when it exits with an exception.

        Yields:
            A database connection taken from the pool.

        Raises:
            RuntimeError: If the pool has not been created (``initialize()``
                not called yet, or ``close()`` already called).
        """
        pool = self._pool
        if pool is None:
            raise RuntimeError("Database pool not initialized. Call initialize() first.")

        conn = await pool.acquire()
        try:
            yield conn
        finally:
            # Release to the pool the connection came from. close() may have
            # detached self._pool while the connection was in use.
            pool.release(conn)

    async def execute(self, query: str, args: tuple = ()) -> int:
        """
        Execute a data-modifying SQL statement (INSERT, UPDATE, DELETE).

        Args:
            query: SQL statement with driver-style placeholders.
            args: Parameter values for the placeholders.

        Returns:
            The cursor's ``rowcount`` (number of affected rows).
        """
        async with self.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, args)
                return cursor.rowcount

    async def fetch_one(self, query: str, args: tuple = ()) -> Optional[tuple]:
        """
        Run a query and return its first row.

        Args:
            query: SQL statement with driver-style placeholders.
            args: Parameter values for the placeholders.

        Returns:
            The first row as a tuple, or None if the query returned no rows.
        """
        async with self.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchone()

    async def fetch_all(self, query: str, args: tuple = ()) -> list:
        """
        Run a query and return all rows.

        Args:
            query: SQL statement with driver-style placeholders.
            args: Parameter values for the placeholders.

        Returns:
            Every remaining row from the cursor, each row as a tuple.
        """
        async with self.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchall()

    async def fetch_dict(self, query: str, args: tuple = ()) -> Optional[dict]:
        """
        Run a query and return its first row as a dict (via ``DictCursor``).

        Args:
            query: SQL statement with driver-style placeholders.
            args: Parameter values for the placeholders.

        Returns:
            The first row keyed by column name, or None if there are no rows.
        """
        async with self.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchone()

    async def fetch_all_dict(self, query: str, args: tuple = ()) -> list:
        """
        Run a query and return all rows as dicts (via ``DictCursor``).

        Args:
            query: SQL statement with driver-style placeholders.
            args: Parameter values for the placeholders.

        Returns:
            A list of rows, each keyed by column name.
        """
        async with self.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchall()

    async def close(self) -> None:
        """
        Close the connection pool and mark the manager as uninitialized.

        Waits for the pool to finish closing. If there is no pool, this does
        nothing, so repeated calls are harmless.
        """
        pool = self._pool
        if pool is None:
            return

        # Detach first. New acquire() calls then fail fast, a concurrent
        # close() cannot close the same pool twice, and the manager state
        # stays consistent even if wait_closed() raises or is cancelled.
        self._pool = None
        self._initialized = False

        pool.close()
        await pool.wait_closed()
        logger.info("Database connection pool closed")

    @property
    def is_initialized(self) -> bool:
        """Whether ``initialize()`` has completed successfully (and not been closed)."""
        return self._initialized

    @property
    def pool(self) -> Optional[aiomysql.Pool]:
        """The underlying aiomysql pool, or None when not initialized."""
        return self._pool
|
|
|
|
|
|
|
|
|
# Global DatabaseManager instance shared by this process: created lazily by
# get_database() and reset to None by close_database().
_db_manager: Optional[DatabaseManager] = None
|
|
|
|
|
|
|
|
|
async def get_database() -> DatabaseManager:
    """
    Return the global DatabaseManager, creating and initializing it on demand.

    The manager is constructed on the first call. ``initialize()`` is awaited
    on every call, not only right after construction. It returns immediately
    once initialization has succeeded. Awaiting it each time means a caller
    never gets back a manager whose earlier ``initialize()`` raised, and never
    one that another coroutine is still initializing.

    Returns:
        The DatabaseManager instance.

    Raises:
        Any exception raised by ``DatabaseManager.initialize()``. The manager
        stays installed, so the next call retries initialization.
    """
    global _db_manager
    if _db_manager is None:
        _db_manager = DatabaseManager()
    # Keep a local reference so we return the manager we initialized even if
    # close_database() clears the global while we are awaiting.
    manager = _db_manager
    await manager.initialize()
    return manager
|
|
|
|
|
|
|
|
|
async def close_database() -> None:
    """
    Close the global DatabaseManager, if one exists.

    The global reference is cleared before ``close()`` is awaited, for two
    reasons. A concurrent ``get_database()`` then builds a fresh manager
    rather than returning one whose pool is being torn down. And if
    ``close()`` raises, a half-closed manager is not left installed. Calling
    this when no manager exists is a no-op.
    """
    global _db_manager
    manager, _db_manager = _db_manager, None
    if manager is not None:
        await manager.close()
|