You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

286 lines
9.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# P2P Network Communication - Database Module
"""
Database module: MySQL connection-pool management and table initialization.

Built on the aiomysql async driver.
"""
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Optional, AsyncGenerator
import aiomysql
# Project-local configuration (server/database settings).
from config import ServerConfig, load_server_config
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class DatabaseManager:
    """Owns the aiomysql connection pool and creates the database schema.

    Call ``initialize()`` (idempotent) before using ``acquire()`` or any of
    the query helpers, and ``close()`` to release the pool.
    """

    def __init__(self, config: Optional[ServerConfig] = None):
        """
        Initialize the database manager (no I/O; the pool is created later).

        Args:
            config: Server configuration. If None, it is obtained from
                ``load_server_config()`` (presumably from environment
                variables — see that helper).
        """
        # Explicit `is None` so the behaviour matches the documented contract
        # rather than depending on the config object's truthiness.
        self.config = config if config is not None else load_server_config()
        self._pool: Optional[aiomysql.Pool] = None
        self._initialized = False
        # Guards initialize() against concurrent callers. Created lazily inside
        # initialize() so it is bound to the running event loop.
        self._init_lock: Optional[asyncio.Lock] = None

    async def initialize(self) -> None:
        """Create the connection pool and the tables; safe to call repeatedly.

        Concurrent callers are serialized so only one pool is ever created.
        If table creation fails, the freshly created pool is closed before the
        exception propagates, so a later retry does not leak it.
        """
        if self._initialized:
            return
        if self._init_lock is None:
            # No await between the check and the assignment, so this cannot
            # race within a single event loop.
            self._init_lock = asyncio.Lock()
        async with self._init_lock:
            # Re-check: another coroutine may have completed initialization
            # while we were waiting for the lock.
            if self._initialized:
                return
            await self._create_pool()
            try:
                await self._create_tables()
            except BaseException:
                # Don't leave an open pool behind (a retry would overwrite
                # and leak it); BaseException also covers cancellation.
                await self.close()
                raise
            self._initialized = True
        logger.info("Database initialized successfully")

    async def _create_pool(self) -> None:
        """Create the aiomysql connection pool from ``self.config``.

        Raises:
            Whatever ``aiomysql.create_pool`` raises (logged, then re-raised).
        """
        try:
            self._pool = await aiomysql.create_pool(
                host=self.config.db_host,
                port=self.config.db_port,
                user=self.config.db_user,
                password=self.config.db_password,
                db=self.config.db_name,
                minsize=1,
                maxsize=self.config.db_pool_size,
                charset='utf8mb4',
                # Every statement commits immediately; the helpers below
                # therefore never call commit().
                autocommit=True,
                echo=False,
            )
            logger.info(f"Database connection pool created: {self.config.db_host}:{self.config.db_port}")
        except Exception as e:
            logger.error(f"Failed to create database pool: {e}")
            raise

    async def _create_tables(self) -> None:
        """Create the schema tables if they do not already exist.

        Raises:
            Whatever the driver raises for a failing CREATE TABLE (logged
            with the table name, then re-raised).
        """
        create_users_table = """
            CREATE TABLE IF NOT EXISTS users (
                user_id VARCHAR(64) PRIMARY KEY,
                username VARCHAR(100) UNIQUE NOT NULL,
                display_name VARCHAR(200) NOT NULL,
                password_hash VARCHAR(255) NOT NULL DEFAULT '',
                public_key BLOB,
                status VARCHAR(20) DEFAULT 'offline',
                ip_address VARCHAR(45) DEFAULT '',
                port INT DEFAULT 0,
                last_seen TIMESTAMP NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                INDEX idx_username (username),
                INDEX idx_status (status)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
        """
        create_messages_table = """
            CREATE TABLE IF NOT EXISTS messages (
                message_id VARCHAR(64) PRIMARY KEY,
                sender_id VARCHAR(64) NOT NULL,
                receiver_id VARCHAR(64) NOT NULL,
                content_type VARCHAR(50) NOT NULL,
                content TEXT,
                timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                is_read BOOLEAN DEFAULT FALSE,
                is_delivered BOOLEAN DEFAULT FALSE,
                INDEX idx_sender (sender_id),
                INDEX idx_receiver (receiver_id),
                INDEX idx_timestamp (timestamp),
                INDEX idx_conversation (sender_id, receiver_id, timestamp)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
        """
        create_file_transfers_table = """
            CREATE TABLE IF NOT EXISTS file_transfers (
                transfer_id VARCHAR(64) PRIMARY KEY,
                file_name VARCHAR(500) NOT NULL,
                file_size BIGINT NOT NULL,
                file_hash VARCHAR(128) NOT NULL,
                sender_id VARCHAR(64) NOT NULL,
                receiver_id VARCHAR(64) NOT NULL,
                status VARCHAR(20) NOT NULL DEFAULT 'pending',
                progress DECIMAL(5,2) DEFAULT 0,
                start_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                end_time TIMESTAMP NULL,
                INDEX idx_sender (sender_id),
                INDEX idx_receiver (receiver_id),
                INDEX idx_status (status)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
        """
        create_offline_messages_table = """
            CREATE TABLE IF NOT EXISTS offline_messages (
                id BIGINT AUTO_INCREMENT PRIMARY KEY,
                user_id VARCHAR(64) NOT NULL,
                message_data BLOB NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                INDEX idx_user_id (user_id),
                INDEX idx_created_at (created_at)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
        """
        tables = [
            ("users", create_users_table),
            ("messages", create_messages_table),
            ("file_transfers", create_file_transfers_table),
            ("offline_messages", create_offline_messages_table),
        ]
        async with self.acquire() as conn:
            async with conn.cursor() as cursor:
                for table_name, create_sql in tables:
                    try:
                        await cursor.execute(create_sql)
                        logger.info(f"Table '{table_name}' created or already exists")
                    except Exception as e:
                        logger.error(f"Failed to create table '{table_name}': {e}")
                        raise

    @asynccontextmanager
    async def acquire(self) -> AsyncGenerator[aiomysql.Connection, None]:
        """
        Borrow a connection from the pool for the duration of an ``async with``.

        Yields:
            A pooled aiomysql connection. It is released back to the pool on
            exit, even if the body raises.

        Raises:
            RuntimeError: if the pool has not been created (call initialize()).
        """
        # Keep a local reference so the connection is released to the same
        # pool it came from, independent of later changes to self._pool.
        pool = self._pool
        if pool is None:
            raise RuntimeError("Database pool not initialized. Call initialize() first.")
        conn = await pool.acquire()
        try:
            yield conn
        finally:
            # Not awaited: aiomysql documents Pool.release() as a plain method.
            pool.release(conn)

    async def execute(self, query: str, args: tuple = ()) -> int:
        """
        Execute a write statement (INSERT, UPDATE, DELETE, ...).

        Args:
            query: SQL statement with driver placeholders.
            args: Parameter values bound to the placeholders.

        Returns:
            ``cursor.rowcount`` — the number of affected rows.
        """
        async with self.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, args)
                return cursor.rowcount

    async def fetch_one(self, query: str, args: tuple = ()) -> Optional[tuple]:
        """
        Run a query and return its first row.

        Args:
            query: SQL statement with driver placeholders.
            args: Parameter values bound to the placeholders.

        Returns:
            The first row as a tuple, or None if the query returned no rows.
        """
        async with self.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchone()

    async def fetch_all(self, query: str, args: tuple = ()) -> list:
        """
        Run a query and return all rows.

        Args:
            query: SQL statement with driver placeholders.
            args: Parameter values bound to the placeholders.

        Returns:
            All result rows, as returned by the cursor's ``fetchall()``.
        """
        async with self.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchall()

    async def fetch_dict(self, query: str, args: tuple = ()) -> Optional[dict]:
        """
        Run a query and return its first row as a dict (column name -> value).

        Args:
            query: SQL statement with driver placeholders.
            args: Parameter values bound to the placeholders.

        Returns:
            The first row as a dict, or None if the query returned no rows.
        """
        async with self.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchone()

    async def fetch_all_dict(self, query: str, args: tuple = ()) -> list:
        """
        Run a query and return all rows as dicts (column name -> value).

        Args:
            query: SQL statement with driver placeholders.
            args: Parameter values bound to the placeholders.

        Returns:
            All result rows as dicts, as returned by ``fetchall()``.
        """
        async with self.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchall()

    async def close(self) -> None:
        """Close the connection pool (if any) and mark the manager uninitialized.

        The pool is detached before awaiting shutdown, so a failed or cancelled
        ``wait_closed()`` never leaves a closed pool installed on the manager.
        """
        pool, self._pool = self._pool, None
        if pool is None:
            return
        self._initialized = False
        pool.close()
        await pool.wait_closed()
        logger.info("Database connection pool closed")

    @property
    def is_initialized(self) -> bool:
        """True once initialize() has completed and close() has not been called since."""
        return self._initialized

    @property
    def pool(self) -> Optional[aiomysql.Pool]:
        """The underlying aiomysql pool, or None if not created / already closed."""
        return self._pool
# Module-level DatabaseManager singleton, created on the first get_database() call.
_db_manager: Optional[DatabaseManager] = None
# Serializes first-time creation of _db_manager. Created lazily inside
# get_database() so it is bound to the running event loop.
_db_manager_lock: Optional[asyncio.Lock] = None


async def get_database() -> DatabaseManager:
    """
    Return the module-level DatabaseManager, creating and initializing it on first use.

    The global is only published after ``initialize()`` succeeds. Concurrent
    callers therefore never receive a half-initialized manager, and a failed
    initialization does not leave a broken instance cached for later calls.

    Returns:
        The initialized DatabaseManager instance.

    Raises:
        Whatever ``DatabaseManager()`` / ``initialize()`` raise; the next call
        will retry.
    """
    global _db_manager, _db_manager_lock
    if _db_manager is not None:
        return _db_manager
    if _db_manager_lock is None:
        # No await between the check and the assignment: race-free on one loop.
        _db_manager_lock = asyncio.Lock()
    async with _db_manager_lock:
        # Re-check: another coroutine may have finished while we waited.
        if _db_manager is None:
            manager = DatabaseManager()
            try:
                await manager.initialize()
            except BaseException:
                # Release anything a partial initialization opened.
                await manager.close()
                raise
            _db_manager = manager
        return _db_manager
async def close_database() -> None:
    """Close and forget the module-level DatabaseManager, if one exists.

    The global is cleared *before* awaiting ``close()``. A concurrent
    ``get_database()`` call therefore builds a fresh manager instead of
    returning one whose pool is being torn down. A failing ``close()`` also
    cannot leave a half-closed manager cached.
    """
    global _db_manager
    manager, _db_manager = _db_manager, None
    if manager is not None:
        await manager.close()