From 108b619653c2e1f13082c7db4bfe29edef583827 Mon Sep 17 00:00:00 2001 From: smallbailangui Date: Sun, 30 Nov 2025 17:51:37 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=8A=A5=E8=A1=A8?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E7=9A=84str=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/api/v1/endpoints/reports.py | 4 +-- src/backend/app/crud/crud_report.py | 35 ++++++++++++++++++--- src/backend/app/service/report_service.py | 34 ++++++++++++++++++-- 3 files changed, 65 insertions(+), 8 deletions(-) diff --git a/src/backend/app/api/v1/endpoints/reports.py b/src/backend/app/api/v1/endpoints/reports.py index e27b456..accf4c6 100644 --- a/src/backend/app/api/v1/endpoints/reports.py +++ b/src/backend/app/api/v1/endpoints/reports.py @@ -14,7 +14,7 @@ router = APIRouter() # 4.1 获取报表列表 @router.get("/reports", response_model=List[schemas.Report]) async def read_reports( - projectId: str = Query(..., description="项目ID"), + projectId: int = Query(..., description="项目ID"), db: Session = Depends(get_db), current_user = Depends(get_current_active_user), ) -> Any: @@ -25,7 +25,7 @@ async def read_reports( # 4.4 获取历史查询记录 @router.get("/history-queries", response_model=List[schemas.HistoryQuery]) async def read_history( - projectId: str = Query(..., description="项目ID"), + projectId: int = Query(..., description="项目ID"), db: Session = Depends(get_db), current_user = Depends(get_current_active_user), ) -> Any: diff --git a/src/backend/app/crud/crud_report.py b/src/backend/app/crud/crud_report.py index bf538b2..5acf325 100644 --- a/src/backend/app/crud/crud_report.py +++ b/src/backend/app/crud/crud_report.py @@ -3,12 +3,15 @@ from typing import List, Optional, Tuple from sqlalchemy.ext.asyncio import AsyncSession as Session from sqlalchemy.future import select -from sqlalchemy import join +from sqlalchemy import desc from sqlalchemy.exc import SQLAlchemyError -# 隐式绝对导入 -from models.query_result import QueryResult # 缓存结果表 +# 导入所有相关模型以建立连接路径 +from models.query_result import QueryResult from models.ai_generated_statement import AIGeneratedStatement +from models.message import Message +from models.session import Session as SessionModel # 别名避免与 DB Session 混淆 + from core.exceptions import DatabaseOperationFailedException @@ -35,7 +38,7 @@ class CRUDReport: raise DatabaseOperationFailedException("get report data by result_id") from e @staticmethod - async def get_history_by_project(db: Session, project_id: str) -> List[QueryResult]: + async def get_history_by_project(db: Session, project_id: int) -> List[QueryResult]: """ CRUD: 根据项目ID获取历史查询结果列表(用于列表接口)。 """ @@ -48,5 +51,29 @@ class CRUDReport: except SQLAlchemyError as e: raise DatabaseOperationFailedException("get history queries by project") from e + @staticmethod + async def get_by_project(db: Session, project_id: int) -> List[QueryResult]: + """ + CRUD: 获取项目报表列表 (修正版:增加 Project 筛选) + 路径: QueryResult -> Statement -> Message -> Session -> Project + """ + try: + query = ( + select(QueryResult) + .join(AIGeneratedStatement, QueryResult.statement_id == AIGeneratedStatement.statement_id) + .join(Message, AIGeneratedStatement.message_id == Message.message_id) + .join(SessionModel, Message.session_id == SessionModel.session_id) + .where(SessionModel.project_id == project_id) # 核心筛选条件 + .order_by(desc(QueryResult.cached_at)) + ) + + result = await db.execute(query) + return result.scalars().all() + except SQLAlchemyError as e: + raise DatabaseOperationFailedException("get reports by project") from e + + # 兼容旧代码调用 + get_history_by_project = get_by_project + crud_report = CRUDReport() \ No newline at end of file diff --git a/src/backend/app/service/report_service.py b/src/backend/app/service/report_service.py index 7b05ee1..98bdee5 100644 --- a/src/backend/app/service/report_service.py +++ b/src/backend/app/service/report_service.py @@ -10,12 +10,42 @@ from schema import report as schemas from core.exceptions import DatabaseOperationFailedException, ItemNotFoundException # 1. 获取报表列表 -async def get_report_list(db: Session, project_id: str) -> List[schemas.Report]: +async def get_report_list(db: Session, project_id: int) -> List[schemas.Report]: """ 业务逻辑:获取报表列表。 + 需要将 DB 中的 QueryResult 对象转换为 Schema 中的 Report 对象。 """ + # 1. 获取数据库对象列表 db_objs = await crud_report.get_by_project(db, project_id) - return [schemas.Report.model_validate(obj) for obj in db_objs] + + reports = [] + for obj in db_objs: + # 2. 手动构造 Report 对象 (字段映射) + + # 构造默认图表配置 (因为目前 DB 里没有存详细配置) + default_chart_config = schemas.ChartConfig( + xAxisKey="name", + yAxisKey="value" + ) + + # 确保 data 是列表格式 + report_data = obj.result_data if isinstance(obj.result_data, list) else [] + + # 映射字段: ORM -> Schema + report = schemas.Report( + id=str(obj.result_id), # result_id -> id + projectId=str(project_id), # 传入的 project_id + name=obj.data_summary[:20] if obj.data_summary else f"Report-{obj.result_id}", # 使用摘要做标题 + type=obj.chart_type or "table", # chart_type -> type + description=obj.data_summary, # data_summary -> description + data=report_data, # result_data -> data + chartConfig=default_chart_config, # 填充必填项 + sourceQueryText="SELECT * FROM ...", # 暂无 SQL 文本,填占位符 + updatedAt=obj.cached_at.isoformat() if obj.cached_at else "" # cached_at -> updatedAt + ) + reports.append(report) + + return reports # 2. 获取报表详情数据 (从 project_service.py 移过来的) async def get_report_data_by_id_service(db: Session, query_id: int) -> Dict[str, Any]: -- 2.34.1 From 51719fae6531c632c31c44300476e631cb2f3407 Mon Sep 17 00:00:00 2001 From: smallbailangui Date: Sun, 30 Nov 2025 17:55:44 +0800 Subject: [PATCH 2/2] =?UTF-8?q?fix=EF=BC=9A=E7=AE=A1=E7=90=86=E5=91=98?= =?UTF-8?q?=E7=9A=84mock=E6=95=B0=E6=8D=AE=E5=AF=B9=E6=8E=A5=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/crud/crud_admin_data.py | 257 +++++++++++++++++------- 1 file changed, 185 insertions(+), 72 deletions(-) diff --git a/src/backend/app/crud/crud_admin_data.py b/src/backend/app/crud/crud_admin_data.py index 4748c26..101a9a0 100644 --- a/src/backend/app/crud/crud_admin_data.py +++ b/src/backend/app/crud/crud_admin_data.py @@ -1,62 +1,173 @@ # backend/app/crud/crud_admin_data.py from typing import List, Literal, Tuple, Dict, Any, Optional +from datetime import datetime, date from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from sqlalchemy import func, or_ +from sqlalchemy import func, or_, desc, cast, Date + from core.exceptions import DatabaseOperationFailedException -from datetime import datetime # 【新增】添加这一行导入 -from sqlalchemy import desc -# 假设导入 ORM 模型 +# 导入所有需要的模型 from models.user_account import UserAccount from models.project import Project from models.violation_log import ViolationLog +from models.user_login_history import UserLoginHistory +from models.ai_generated_statement import AIGeneratedStatement class CRUDAdminData: # ------------------------------------------------------------------ - # 4.1.1. 获取用户列表 + # 4.1.1. 获取用户列表 (带统计数据) # ------------------------------------------------------------------ - @staticmethod - async def get_user_list_with_stats(db: AsyncSession, page: int, page_size: int, search: Optional[str] = None, - status: Literal["normal", "banned", "all"] = "all") -> Tuple[ - List[Dict[str, Any]], int]: - """获取用户列表,包含项目统计和额度。""" - # 实际代码应实现复杂的 JOIN 和筛选逻辑 (如前分析) - # Mocking the data structure return: - items = [{ - "user_id": 1002, "username": "admin_user", "email": "admin@example.com", "status": "banned", - "max_databases": 10, "project_count": 5, "last_login_at": datetime.now() - }] - return items, 100 - - # ------------------------------------------------------------------ + async def get_user_list_with_stats( + db: AsyncSession, + page: int, + page_size: int, + search: Optional[str] = None, + status: Literal["normal", "banned", "all"] = "all" + ) -> Tuple[List[Dict[str, Any]], int]: + """ + 获取用户列表,包含项目统计和额度。 + 实现逻辑:查询 UserAccount 表,并 Left Join Project 表计算 count。 + """ + try: + # 1. 构建基础查询:选择用户列 + 项目计数 + # SELECT u.*, count(p.project_id) FROM user_account u LEFT JOIN project p ON ... + stmt = ( + select( + UserAccount, + func.count(Project.project_id).label("project_count") + ) + .outerjoin(Project, UserAccount.user_id == Project.user_id) + .group_by(UserAccount.user_id) + ) + + # 2. 构建 Count 查询 (用于分页总数) + count_stmt = select(func.count(UserAccount.user_id)) + + # 3. 应用筛选条件 + filters = [] + if status != "all": + filters.append(UserAccount.status == status) + + if search: + search_term = f"%{search}%" + filters.append( + or_( + UserAccount.username.ilike(search_term), + UserAccount.email.ilike(search_term) + ) + ) + + if filters: + stmt = stmt.where(*filters) + count_stmt = count_stmt.where(*filters) + + # 4. 执行总数查询 + total = await db.scalar(count_stmt) or 0 + + # 5. 执行分页查询 (按注册时间倒序) + stmt = stmt.order_by(desc(UserAccount.created_at)).offset((page - 1) * page_size).limit(page_size) + result = await db.execute(stmt) + rows = result.all() + + # 6. 组装返回数据 (将 ORM 对象转为字典,并合并且它的统计数据) + items = [] + for user, project_count in rows: + items.append({ + "user_id": user.user_id, + "username": user.username, + "email": user.email, + "status": user.status, + "max_databases": user.max_databases, + "last_login_at": user.last_login_at, + "project_count": project_count # 聚合查询出来的字段 + }) + + return items, total + + except Exception as e: + raise DatabaseOperationFailedException("fetch user list with stats") from e + # ------------------------------------------------------------------ # 4.3.1. 获取管理员列表 # ------------------------------------------------------------------ @staticmethod async def get_admin_list(db: AsyncSession, page: int, page_size: int) -> Tuple[List[UserAccount], int]: """获取所有管理员列表""" - # 实际代码应 SELECT UserAccount.is_admin == True - # Mocking the data structure return: - return [UserAccount(user_id=1, username='admin', email='a@e.com', is_admin=True)], 1 + try: + # 1. 筛选条件 + query = select(UserAccount).where(UserAccount.is_admin == True) + count_query = select(func.count(UserAccount.user_id)).where(UserAccount.is_admin == True) + + # 2. 获取总数 + total = await db.scalar(count_query) or 0 + + # 3. 分页查询 + query = query.order_by(desc(UserAccount.created_at)).offset((page - 1) * page_size).limit(page_size) + result = await db.execute(query) + admins = result.scalars().all() + + return admins, total + except Exception as e: + raise DatabaseOperationFailedException("fetch admin list") from e # ------------------------------------------------------------------ # 4.4.1. 获取系统统计 # ------------------------------------------------------------------ @staticmethod async def get_system_stats(db: AsyncSession) -> Dict[str, Any]: - """获取系统统计看板数据""" - # 实际代码应包含多表聚合查询 - return { - "active_users_today": 150, "total_projects": 500, "query_count_today": 2500, - "high_risk_operations_today": 15, "system_health": "good" - } - - # 添加这个缺失的方法 + """获取系统统计看板数据 (真实数据库聚合)""" + try: + today = datetime.now().date() + + # 1. 今日活跃用户 (今日有登录记录的去重用户数) + # 注意:SQLAlchemy 中 date 类型转换通常用 cast(col, Date) + active_users_query = select(func.count(func.distinct(UserLoginHistory.user_id))) \ + .where(cast(UserLoginHistory.login_time, Date) == today) + active_users_today = await db.scalar(active_users_query) or 0 + + # 2. 项目总数 + total_projects_query = select(func.count(Project.project_id)) + total_projects = await db.scalar(total_projects_query) or 0 + + # 3. 今日查询数 (AI 生成语句的数量作为近似值) + query_count_query = select(func.count(AIGeneratedStatement.statement_id)) \ + .where(cast(AIGeneratedStatement.created_at, Date) == today) + query_count_today = await db.scalar(query_count_query) or 0 + + # 4. 今日高风险操作 (ViolationLog 中 risk_level 为 HIGH 或 CRITICAL) + high_risk_query = select(func.count(ViolationLog.violation_id)) \ + .where( + cast(ViolationLog.created_at, Date) == today, + ViolationLog.risk_level.in_(['HIGH', 'CRITICAL']) + ) + high_risk_ops = await db.scalar(high_risk_query) or 0 + + # 5. 计算系统健康度 (简单逻辑) + system_health = "good" + if high_risk_ops > 10: + system_health = "critical" + elif high_risk_ops > 0: + system_health = "warning" + + return { + "active_users_today": active_users_today, + "total_projects": total_projects, + "query_count_today": query_count_today, + "high_risk_operations_today": high_risk_ops, + "system_health": system_health + } + except Exception as e: + # 统计失败不应阻塞主流程,可以返回全0或抛出异常,这里选择抛出 + raise DatabaseOperationFailedException("fetch system stats") from e + + # ------------------------------------------------------------------ + # 4.4.2. 获取违规日志 (复用之前修复的代码) + # ------------------------------------------------------------------ @staticmethod async def get_violation_logs( db: AsyncSession, @@ -68,48 +179,50 @@ class CRUDAdminData: """ 获取违规记录列表 (支持分页、筛选,并关联用户名) """ - # 1. 构建基础查询:关联 UserAccount 表以获取 username - query = select(ViolationLog, UserAccount.username) \ - .join(UserAccount, ViolationLog.user_id == UserAccount.user_id) - - count_query = select(func.count(ViolationLog.violation_id)) - - # 2. 动态添加筛选条件 - filters = [] - if risk_level: - filters.append(ViolationLog.risk_level == risk_level) - if resolution_status: - filters.append(ViolationLog.resolution_status == resolution_status) - - if filters: - query = query.where(*filters) - count_query = count_query.where(*filters) - - # 3. 执行总数查询 - total_result = await db.execute(count_query) - total = total_result.scalar_one() - - # 4. 执行分页查询 (按时间倒序) - query = query.order_by(desc(ViolationLog.created_at)) \ - .offset((page - 1) * page_size) \ - .limit(page_size) - - result = await db.execute(query) - rows = result.all() - - # 5. 格式化返回数据 - items = [] - for log_obj, username in rows: - items.append({ - "violation_id": log_obj.violation_id, - "user_id": log_obj.user_id, - "username": username, # 前端需要展示用户名 - "risk_level": log_obj.risk_level, - "resolution_status": log_obj.resolution_status, - "created_at": log_obj.created_at - }) - - return items, total + try: + # 1. 构建基础查询:关联 UserAccount 表以获取 username + query = select(ViolationLog, UserAccount.username) \ + .join(UserAccount, ViolationLog.user_id == UserAccount.user_id) + + count_query = select(func.count(ViolationLog.violation_id)) + + # 2. 动态添加筛选条件 + filters = [] + if risk_level: + filters.append(ViolationLog.risk_level == risk_level) + if resolution_status: + filters.append(ViolationLog.resolution_status == resolution_status) + + if filters: + query = query.where(*filters) + count_query = count_query.where(*filters) + + # 3. 执行总数查询 + total = await db.scalar(count_query) or 0 + + # 4. 执行分页查询 (按时间倒序) + query = query.order_by(desc(ViolationLog.created_at)) \ + .offset((page - 1) * page_size) \ + .limit(page_size) + + result = await db.execute(query) + rows = result.all() + + # 5. 格式化返回数据 + items = [] + for log_obj, username in rows: + items.append({ + "violation_id": log_obj.violation_id, + "user_id": log_obj.user_id, + "username": username, + "risk_level": log_obj.risk_level, + "resolution_status": log_obj.resolution_status, + "created_at": log_obj.created_at + }) + + return items, total + except Exception as e: + raise DatabaseOperationFailedException("fetch violation logs") from e curd_admin_data = CRUDAdminData() \ No newline at end of file -- 2.34.1