# P2P Network Communication - Database Module
"""
Database module providing MySQL connection-pool management and table
initialization.

Uses the aiomysql asynchronous driver for database operations.
"""

import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Optional, AsyncGenerator

import aiomysql

from config import ServerConfig, load_server_config

logger = logging.getLogger(__name__)


class DatabaseManager:
    """Database manager responsible for the connection pool and schema setup."""

    def __init__(self, config: Optional[ServerConfig] = None):
        """
        Initialize the database manager.

        Args:
            config: Server configuration. If None (or otherwise falsy), the
                configuration is obtained from load_server_config().
        """
        self.config = config or load_server_config()
        self._pool: Optional[aiomysql.Pool] = None
        self._initialized = False

    async def initialize(self) -> None:
        """
        Create the connection pool and the table structure.

        Does nothing if the manager is already initialized. If table creation
        fails, the freshly created pool is closed again before the error is
        re-raised, so a failed attempt does not leave an open pool behind.
        """
        if self._initialized:
            return

        await self._create_pool()
        try:
            await self._create_tables()
        except BaseException:
            # Cleanup-and-re-raise (this also covers task cancellation):
            # without it, a retry would overwrite self._pool and the pool
            # created above would never be closed.
            await self.close()
            raise

        self._initialized = True
        logger.info("Database initialized successfully")

    async def _create_pool(self) -> None:
        """
        Create the aiomysql connection pool from the server configuration.

        Raises:
            Exception: Whatever aiomysql.create_pool raises; the error is
                logged and re-raised unchanged.
        """
        cfg = self.config
        try:
            self._pool = await aiomysql.create_pool(
                host=cfg.db_host,
                port=cfg.db_port,
                user=cfg.db_user,
                password=cfg.db_password,
                db=cfg.db_name,
                minsize=1,
                maxsize=cfg.db_pool_size,
                charset='utf8mb4',
                autocommit=True,
                echo=False,
            )
            logger.info(
                "Database connection pool created: %s:%s",
                cfg.db_host,
                cfg.db_port,
            )
        except Exception as e:
            logger.error("Failed to create database pool: %s", e)
            raise

    async def _create_tables(self) -> None:
        """
        Create the database tables if they do not exist yet.

        The statements run in order on a single pooled connection.

        Raises:
            Exception: The first error raised while executing a statement;
                it is logged with the table name and re-raised.
        """
        create_users_table = """
        CREATE TABLE IF NOT EXISTS users (
            user_id VARCHAR(64) PRIMARY KEY,
            username VARCHAR(100) UNIQUE NOT NULL,
            display_name VARCHAR(200) NOT NULL,
            password_hash VARCHAR(255) NOT NULL DEFAULT '',
            public_key BLOB,
            status VARCHAR(20) DEFAULT 'offline',
            ip_address VARCHAR(45) DEFAULT '',
            port INT DEFAULT 0,
            last_seen TIMESTAMP NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            INDEX idx_username (username),
            INDEX idx_status (status)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
        """

        create_messages_table = """
        CREATE TABLE IF NOT EXISTS messages (
            message_id VARCHAR(64) PRIMARY KEY,
            sender_id VARCHAR(64) NOT NULL,
            receiver_id VARCHAR(64) NOT NULL,
            content_type VARCHAR(50) NOT NULL,
            content TEXT,
            timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            is_read BOOLEAN DEFAULT FALSE,
            is_delivered BOOLEAN DEFAULT FALSE,
            INDEX idx_sender (sender_id),
            INDEX idx_receiver (receiver_id),
            INDEX idx_timestamp (timestamp),
            INDEX idx_conversation (sender_id, receiver_id, timestamp)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
        """

        create_file_transfers_table = """
        CREATE TABLE IF NOT EXISTS file_transfers (
            transfer_id VARCHAR(64) PRIMARY KEY,
            file_name VARCHAR(500) NOT NULL,
            file_size BIGINT NOT NULL,
            file_hash VARCHAR(128) NOT NULL,
            sender_id VARCHAR(64) NOT NULL,
            receiver_id VARCHAR(64) NOT NULL,
            status VARCHAR(20) NOT NULL DEFAULT 'pending',
            progress DECIMAL(5,2) DEFAULT 0,
            start_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            end_time TIMESTAMP NULL,
            INDEX idx_sender (sender_id),
            INDEX idx_receiver (receiver_id),
            INDEX idx_status (status)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
        """

        create_offline_messages_table = """
        CREATE TABLE IF NOT EXISTS offline_messages (
            id BIGINT AUTO_INCREMENT PRIMARY KEY,
            user_id VARCHAR(64) NOT NULL,
            message_data BLOB NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            INDEX idx_user_id (user_id),
            INDEX idx_created_at (created_at)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
        """

        tables = [
            ("users", create_users_table),
            ("messages", create_messages_table),
            ("file_transfers", create_file_transfers_table),
            ("offline_messages", create_offline_messages_table),
        ]

        async with self.acquire() as conn:
            async with conn.cursor() as cursor:
                for table_name, create_sql in tables:
                    try:
                        await cursor.execute(create_sql)
                        logger.info(
                            "Table '%s' created or already exists", table_name
                        )
                    except Exception as e:
                        logger.error(
                            "Failed to create table '%s': %s", table_name, e
                        )
                        raise

    @asynccontextmanager
    async def acquire(self) -> AsyncGenerator[aiomysql.Connection, None]:
        """
        Context manager that hands out a pooled database connection.

        Yields:
            A connection acquired from the pool; it is released back to the
            pool when the ``async with`` block exits, even on error.

        Raises:
            RuntimeError: If the pool has not been created yet.
        """
        pool = self._pool
        if pool is None:
            raise RuntimeError(
                "Database pool not initialized. Call initialize() first."
            )

        conn = await pool.acquire()
        try:
            yield conn
        finally:
            # Release to the pool the connection came from, not to whatever
            # self._pool happens to reference by the time the block exits.
            pool.release(conn)

    async def execute(self, query: str, args: tuple = ()) -> int:
        """
        Execute a SQL statement (INSERT, UPDATE, DELETE).

        Args:
            query: SQL statement.
            args: Parameter tuple passed to the cursor alongside the query.

        Returns:
            The cursor's rowcount after execution (number of affected rows).
        """
        async with self.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, args)
                return cursor.rowcount

    async def fetch_one(self, query: str, args: tuple = ()) -> Optional[tuple]:
        """
        Query a single record.

        Args:
            query: SQL statement.
            args: Parameter tuple passed to the cursor alongside the query.

        Returns:
            The first result row, or None if there is no result.
        """
        async with self.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchone()

    async def fetch_all(self, query: str, args: tuple = ()) -> list:
        """
        Query multiple records.

        Args:
            query: SQL statement.
            args: Parameter tuple passed to the cursor alongside the query.

        Returns:
            All result rows as returned by the cursor's fetchall().
        """
        async with self.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchall()

    async def fetch_dict(self, query: str, args: tuple = ()) -> Optional[dict]:
        """
        Query a single record and return it as a dictionary.

        Args:
            query: SQL statement.
            args: Parameter tuple passed to the cursor alongside the query.

        Returns:
            The first result row from an aiomysql.DictCursor, or None if
            there is no result.
        """
        async with self.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchone()

    async def fetch_all_dict(self, query: str, args: tuple = ()) -> list:
        """
        Query multiple records and return them as dictionaries.

        Args:
            query: SQL statement.
            args: Parameter tuple passed to the cursor alongside the query.

        Returns:
            All result rows from an aiomysql.DictCursor.
        """
        async with self.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchall()

    async def close(self) -> None:
        """
        Close the database connection pool.

        Does nothing if no pool exists. Otherwise the pool is closed, awaited
        until fully shut down, and the manager is marked as uninitialized.
        """
        pool = self._pool
        if pool is None:
            return

        pool.close()
        await pool.wait_closed()
        self._pool = None
        self._initialized = False
        logger.info("Database connection pool closed")

    @property
    def is_initialized(self) -> bool:
        """Whether initialize() has completed successfully."""
        return self._initialized

    @property
    def pool(self) -> Optional[aiomysql.Pool]:
        """The underlying connection pool, or None if it does not exist."""
        return self._pool


# Global database manager instance
_db_manager: Optional[DatabaseManager] = None


async def get_database() -> DatabaseManager:
    """
    Return the global database manager, creating it on first use.

    The manager is published to the module-level global only after
    initialize() has succeeded, so callers never receive an uninitialized
    instance and a failed initialization is retried on the next call.

    Returns:
        The shared DatabaseManager instance.
    """
    global _db_manager
    if _db_manager is None:
        manager = DatabaseManager()
        await manager.initialize()
        if _db_manager is None:
            _db_manager = manager
        else:
            # Another coroutine finished initializing while this one was
            # awaiting; keep the published instance and discard this one.
            await manager.close()
    return _db_manager


async def close_database() -> None:
    """Close the global database manager, if one exists, and reset it."""
    global _db_manager
    if _db_manager is not None:
        await _db_manager.close()
        _db_manager = None