测试对话管理与会话管理接口 #21

Merged
hnu202326010318 merged 2 commits from liguolin_branch into develop 1 month ago

@ -14,7 +14,7 @@ router = APIRouter()
# 4.1 获取报表列表
@router.get("/reports", response_model=List[schemas.Report])
async def read_reports(
projectId: str = Query(..., description="项目ID"),
projectId: int = Query(..., description="项目ID"),
db: Session = Depends(get_db),
current_user = Depends(get_current_active_user),
) -> Any:
@ -25,7 +25,7 @@ async def read_reports(
# 4.4 获取历史查询记录
@router.get("/history-queries", response_model=List[schemas.HistoryQuery])
async def read_history(
projectId: str = Query(..., description="项目ID"),
projectId: int = Query(..., description="项目ID"),
db: Session = Depends(get_db),
current_user = Depends(get_current_active_user),
) -> Any:

@ -1,62 +1,173 @@
# backend/app/crud/crud_admin_data.py
from typing import List, Literal, Tuple, Dict, Any, Optional
from datetime import datetime, date
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import func, or_
from sqlalchemy import func, or_, desc, cast, Date
from core.exceptions import DatabaseOperationFailedException
from datetime import datetime # 【新增】添加这一行导入
from sqlalchemy import desc
# 假设导入 ORM 模型
# 导入所有需要的模型
from models.user_account import UserAccount
from models.project import Project
from models.violation_log import ViolationLog
from models.user_login_history import UserLoginHistory
from models.ai_generated_statement import AIGeneratedStatement
class CRUDAdminData:
# ------------------------------------------------------------------
# 4.1.1. 获取用户列表
# 4.1.1. 获取用户列表 (带统计数据)
# ------------------------------------------------------------------
@staticmethod
async def get_user_list_with_stats(db: AsyncSession, page: int, page_size: int, search: Optional[str] = None,
status: Literal["normal", "banned", "all"] = "all") -> Tuple[
List[Dict[str, Any]], int]:
"""获取用户列表,包含项目统计和额度。"""
# 实际代码应实现复杂的 JOIN 和筛选逻辑 (如前分析)
# Mocking the data structure return:
items = [{
"user_id": 1002, "username": "admin_user", "email": "admin@example.com", "status": "banned",
"max_databases": 10, "project_count": 5, "last_login_at": datetime.now()
}]
return items, 100
# ------------------------------------------------------------------
async def get_user_list_with_stats(
db: AsyncSession,
page: int,
page_size: int,
search: Optional[str] = None,
status: Literal["normal", "banned", "all"] = "all"
) -> Tuple[List[Dict[str, Any]], int]:
"""
获取用户列表包含项目统计和额度
实现逻辑查询 UserAccount Left Join Project 表计算 count
"""
try:
# 1. 构建基础查询:选择用户列 + 项目计数
# SELECT u.*, count(p.project_id) FROM user_account u LEFT JOIN project p ON ...
stmt = (
select(
UserAccount,
func.count(Project.project_id).label("project_count")
)
.outerjoin(Project, UserAccount.user_id == Project.user_id)
.group_by(UserAccount.user_id)
)
# 2. 构建 Count 查询 (用于分页总数)
count_stmt = select(func.count(UserAccount.user_id))
# 3. 应用筛选条件
filters = []
if status != "all":
filters.append(UserAccount.status == status)
if search:
search_term = f"%{search}%"
filters.append(
or_(
UserAccount.username.ilike(search_term),
UserAccount.email.ilike(search_term)
)
)
if filters:
stmt = stmt.where(*filters)
count_stmt = count_stmt.where(*filters)
# 4. 执行总数查询
total = await db.scalar(count_stmt) or 0
# 5. 执行分页查询 (按注册时间倒序)
stmt = stmt.order_by(desc(UserAccount.created_at)).offset((page - 1) * page_size).limit(page_size)
result = await db.execute(stmt)
rows = result.all()
# 6. 组装返回数据 (将 ORM 对象转为字典,并合并且它的统计数据)
items = []
for user, project_count in rows:
items.append({
"user_id": user.user_id,
"username": user.username,
"email": user.email,
"status": user.status,
"max_databases": user.max_databases,
"last_login_at": user.last_login_at,
"project_count": project_count # 聚合查询出来的字段
})
return items, total
except Exception as e:
raise DatabaseOperationFailedException("fetch user list with stats") from e
# ------------------------------------------------------------------
# 4.3.1. 获取管理员列表
# ------------------------------------------------------------------
@staticmethod
async def get_admin_list(db: AsyncSession, page: int, page_size: int) -> Tuple[List[UserAccount], int]:
"""获取所有管理员列表"""
# 实际代码应 SELECT UserAccount.is_admin == True
# Mocking the data structure return:
return [UserAccount(user_id=1, username='admin', email='a@e.com', is_admin=True)], 1
try:
# 1. 筛选条件
query = select(UserAccount).where(UserAccount.is_admin == True)
count_query = select(func.count(UserAccount.user_id)).where(UserAccount.is_admin == True)
# 2. 获取总数
total = await db.scalar(count_query) or 0
# 3. 分页查询
query = query.order_by(desc(UserAccount.created_at)).offset((page - 1) * page_size).limit(page_size)
result = await db.execute(query)
admins = result.scalars().all()
return admins, total
except Exception as e:
raise DatabaseOperationFailedException("fetch admin list") from e
# ------------------------------------------------------------------
# 4.4.1. 获取系统统计
# ------------------------------------------------------------------
@staticmethod
async def get_system_stats(db: AsyncSession) -> Dict[str, Any]:
"""获取系统统计看板数据"""
# 实际代码应包含多表聚合查询
return {
"active_users_today": 150, "total_projects": 500, "query_count_today": 2500,
"high_risk_operations_today": 15, "system_health": "good"
}
# 添加这个缺失的方法
"""获取系统统计看板数据 (真实数据库聚合)"""
try:
today = datetime.now().date()
# 1. 今日活跃用户 (今日有登录记录的去重用户数)
# 注意SQLAlchemy 中 date 类型转换通常用 cast(col, Date)
active_users_query = select(func.count(func.distinct(UserLoginHistory.user_id))) \
.where(cast(UserLoginHistory.login_time, Date) == today)
active_users_today = await db.scalar(active_users_query) or 0
# 2. 项目总数
total_projects_query = select(func.count(Project.project_id))
total_projects = await db.scalar(total_projects_query) or 0
# 3. 今日查询数 (AI 生成语句的数量作为近似值)
query_count_query = select(func.count(AIGeneratedStatement.statement_id)) \
.where(cast(AIGeneratedStatement.created_at, Date) == today)
query_count_today = await db.scalar(query_count_query) or 0
# 4. 今日高风险操作 (ViolationLog 中 risk_level 为 HIGH 或 CRITICAL)
high_risk_query = select(func.count(ViolationLog.violation_id)) \
.where(
cast(ViolationLog.created_at, Date) == today,
ViolationLog.risk_level.in_(['HIGH', 'CRITICAL'])
)
high_risk_ops = await db.scalar(high_risk_query) or 0
# 5. 计算系统健康度 (简单逻辑)
system_health = "good"
if high_risk_ops > 10:
system_health = "critical"
elif high_risk_ops > 0:
system_health = "warning"
return {
"active_users_today": active_users_today,
"total_projects": total_projects,
"query_count_today": query_count_today,
"high_risk_operations_today": high_risk_ops,
"system_health": system_health
}
except Exception as e:
# 统计失败不应阻塞主流程可以返回全0或抛出异常这里选择抛出
raise DatabaseOperationFailedException("fetch system stats") from e
# ------------------------------------------------------------------
# 4.4.2. 获取违规日志 (复用之前修复的代码)
# ------------------------------------------------------------------
@staticmethod
async def get_violation_logs(
db: AsyncSession,
@ -68,48 +179,50 @@ class CRUDAdminData:
"""
获取违规记录列表 (支持分页筛选并关联用户名)
"""
# 1. 构建基础查询:关联 UserAccount 表以获取 username
query = select(ViolationLog, UserAccount.username) \
.join(UserAccount, ViolationLog.user_id == UserAccount.user_id)
count_query = select(func.count(ViolationLog.violation_id))
# 2. 动态添加筛选条件
filters = []
if risk_level:
filters.append(ViolationLog.risk_level == risk_level)
if resolution_status:
filters.append(ViolationLog.resolution_status == resolution_status)
if filters:
query = query.where(*filters)
count_query = count_query.where(*filters)
# 3. 执行总数查询
total_result = await db.execute(count_query)
total = total_result.scalar_one()
# 4. 执行分页查询 (按时间倒序)
query = query.order_by(desc(ViolationLog.created_at)) \
.offset((page - 1) * page_size) \
.limit(page_size)
result = await db.execute(query)
rows = result.all()
# 5. 格式化返回数据
items = []
for log_obj, username in rows:
items.append({
"violation_id": log_obj.violation_id,
"user_id": log_obj.user_id,
"username": username, # 前端需要展示用户名
"risk_level": log_obj.risk_level,
"resolution_status": log_obj.resolution_status,
"created_at": log_obj.created_at
})
return items, total
try:
# 1. 构建基础查询:关联 UserAccount 表以获取 username
query = select(ViolationLog, UserAccount.username) \
.join(UserAccount, ViolationLog.user_id == UserAccount.user_id)
count_query = select(func.count(ViolationLog.violation_id))
# 2. 动态添加筛选条件
filters = []
if risk_level:
filters.append(ViolationLog.risk_level == risk_level)
if resolution_status:
filters.append(ViolationLog.resolution_status == resolution_status)
if filters:
query = query.where(*filters)
count_query = count_query.where(*filters)
# 3. 执行总数查询
total = await db.scalar(count_query) or 0
# 4. 执行分页查询 (按时间倒序)
query = query.order_by(desc(ViolationLog.created_at)) \
.offset((page - 1) * page_size) \
.limit(page_size)
result = await db.execute(query)
rows = result.all()
# 5. 格式化返回数据
items = []
for log_obj, username in rows:
items.append({
"violation_id": log_obj.violation_id,
"user_id": log_obj.user_id,
"username": username,
"risk_level": log_obj.risk_level,
"resolution_status": log_obj.resolution_status,
"created_at": log_obj.created_at
})
return items, total
except Exception as e:
raise DatabaseOperationFailedException("fetch violation logs") from e
curd_admin_data = CRUDAdminData()

@ -3,12 +3,15 @@
from typing import List, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession as Session
from sqlalchemy.future import select
from sqlalchemy import join
from sqlalchemy import desc
from sqlalchemy.exc import SQLAlchemyError
# 隐式绝对导入
from models.query_result import QueryResult # 缓存结果表
# 导入所有相关模型以建立连接路径
from models.query_result import QueryResult
from models.ai_generated_statement import AIGeneratedStatement
from models.message import Message
from models.session import Session as SessionModel # 别名避免与 DB Session 混淆
from core.exceptions import DatabaseOperationFailedException
@ -35,7 +38,7 @@ class CRUDReport:
raise DatabaseOperationFailedException("get report data by result_id") from e
@staticmethod
async def get_history_by_project(db: Session, project_id: str) -> List[QueryResult]:
async def get_history_by_project(db: Session, project_id: int) -> List[QueryResult]:
"""
CRUD: 根据项目ID获取历史查询结果列表用于列表接口
"""
@ -48,5 +51,29 @@ class CRUDReport:
except SQLAlchemyError as e:
raise DatabaseOperationFailedException("get history queries by project") from e
@staticmethod
async def get_by_project(db: Session, project_id: int) -> List[QueryResult]:
"""
CRUD: 获取项目报表列表 (修正版增加 Project 筛选)
路径: QueryResult -> Statement -> Message -> Session -> Project
"""
try:
query = (
select(QueryResult)
.join(AIGeneratedStatement, QueryResult.statement_id == AIGeneratedStatement.statement_id)
.join(Message, AIGeneratedStatement.message_id == Message.message_id)
.join(SessionModel, Message.session_id == SessionModel.session_id)
.where(SessionModel.project_id == project_id) # 核心筛选条件
.order_by(desc(QueryResult.cached_at))
)
result = await db.execute(query)
return result.scalars().all()
except SQLAlchemyError as e:
raise DatabaseOperationFailedException("get reports by project") from e
# 兼容旧代码调用
get_history_by_project = get_by_project
crud_report = CRUDReport()

@ -10,12 +10,42 @@ from schema import report as schemas
from core.exceptions import DatabaseOperationFailedException, ItemNotFoundException
# 1. 获取报表列表
async def get_report_list(db: Session, project_id: str) -> List[schemas.Report]:
async def get_report_list(db: Session, project_id: int) -> List[schemas.Report]:
"""
业务逻辑获取报表列表
需要将 DB 中的 QueryResult 对象转换为 Schema 中的 Report 对象
"""
# 1. 获取数据库对象列表
db_objs = await crud_report.get_by_project(db, project_id)
return [schemas.Report.model_validate(obj) for obj in db_objs]
reports = []
for obj in db_objs:
# 2. 手动构造 Report 对象 (字段映射)
# 构造默认图表配置 (因为目前 DB 里没有存详细配置)
default_chart_config = schemas.ChartConfig(
xAxisKey="name",
yAxisKey="value"
)
# 确保 data 是列表格式
report_data = obj.result_data if isinstance(obj.result_data, list) else []
# 映射字段: ORM -> Schema
report = schemas.Report(
id=str(obj.result_id), # result_id -> id
projectId=str(project_id), # 传入的 project_id
name=obj.data_summary[:20] if obj.data_summary else f"Report-{obj.result_id}", # 使用摘要做标题
type=obj.chart_type or "table", # chart_type -> type
description=obj.data_summary, # data_summary -> description
data=report_data, # result_data -> data
chartConfig=default_chart_config, # 填充必填项
sourceQueryText="SELECT * FROM ...", # 暂无 SQL 文本,填占位符
updatedAt=obj.cached_at.isoformat() if obj.cached_at else "" # cached_at -> updatedAt
)
reports.append(report)
return reports
# 2. 获取报表详情数据 (从 project_service.py 移过来的)
async def get_report_data_by_id_service(db: Session, query_id: int) -> Dict[str, Any]:

Loading…
Cancel
Save