feat: 修复报表功能和对话功能 #28

Merged
hnu202326010328 merged 1 commits from wanglirong_branch into develop 2 months ago

@ -1,50 +1,93 @@
# backend/app/api/v1/endpoints/reports.py
from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, Query, Path, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from schema.report import ReportCreate, Report, HistoryQuery, UpdateChartType
from typing import List
# 导入正确的 ReportUpdate Schema
from schema.report import ReportCreate, Report, HistoryQuery, ReportUpdate
from api.v1.deps import get_db
from service import report_service
router = APIRouter()
# 4.1 获取报表列表
@router.get("/reports", response_model=list[Report])
# -------------------------------------------
# 1. 获取报表列表 (Read)
# -------------------------------------------
@router.get("/reports", response_model=List[Report])
async def read_reports(
projectId: int = Query(...),
projectId: int = Query(..., description="项目ID"),
db: AsyncSession = Depends(get_db),
):
"""
获取指定项目下的所有保存的报表配置
"""
return await report_service.get_report_list(db, projectId)
# 4.4 获取历史查询记录
@router.get("/history-queries", response_model=list[HistoryQuery])
# -------------------------------------------
# 2. 获取历史查询记录
# -------------------------------------------
@router.get("/history-queries", response_model=List[HistoryQuery])
async def read_history(
projectId: int = Query(...),
projectId: int = Query(..., description="项目ID"),
db: AsyncSession = Depends(get_db),
):
"""
获取历史查询结果用于作为创建新报表的数据源
"""
return await report_service.get_history_queries_service(db, projectId)
# 3.5.1 创建报表(实时渲染)
# -------------------------------------------
# 3. 创建报表 (Create)
# -------------------------------------------
@router.post("/projects/{project_id}/reports", response_model=Report)
async def create_report(
project_id: int,
body: ReportCreate,
db: AsyncSession = Depends(get_db),
):
"""
保存报表配置
"""
return await report_service.create_report_service(db, project_id, body)
# -------------------------------------------
# 4. 删除报表 (Delete)
# -------------------------------------------
@router.delete("/reports/{report_id}", response_model=bool)
async def delete_report(
report_id: int = Path(..., description="报表ID"),
db: AsyncSession = Depends(get_db),
):
"""
删除指定的报表配置
"""
success = await report_service.delete_report_service(db, report_id)
if not success:
raise HTTPException(status_code=404, detail="Report not found")
return True
# 3.5.2 修改图表类型(已删除该接口)
# -------------------------------------------
# 5. 修改报表信息 (Update) - [已修复]
# -------------------------------------------
@router.put("/reports/{report_id}", response_model=Report)
async def update_report(
body: ReportUpdate, # 注意Pydantic模型通常放在 Depends 前面
report_id: int = Path(..., description="报表ID"),
db: AsyncSession = Depends(get_db),
):
"""
修改报表配置例如修改图表类型修改标题
"""
return await report_service.update_report_service(db, report_id, body)
# 3.5.3 导出报表
# -------------------------------------------
# 6. 导出报表 (Export)
# -------------------------------------------
@router.get("/reports/{report_id}/export", response_model=dict)
async def export_report(
report_id: int,
format: str = Query("png"),
format: str = Query("png", regex="^(png|jpeg|pdf)$"),
db: AsyncSession = Depends(get_db)
):
return await report_service.export_report_service(db, report_id, format)
return await report_service.export_report_service(db, report_id, format)

@ -1,16 +1,28 @@
from sqlalchemy import Column, Integer, String, ForeignKey, DateTime, func
from sqlalchemy import Column, Integer, String, Text, ForeignKey, JSON, DateTime
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from core.database import Base
class Report(Base):
__tablename__ = "report"
class AnalysisReport(Base):
"""
报表配置表用于持久化保存用户对某个查询结果的展示配置
"""
__tablename__ = "analysis_report"
report_id = Column(Integer, primary_key=True, index=True)
project_id = Column(Integer, ForeignKey("project.project_id"), nullable=False)
report_name = Column(String(255), nullable=False)
query_id = Column(Integer, ForeignKey("query_history.query_id"), nullable=False)
chart_type = Column(String(50), nullable=False)
project_id = Column(Integer, ForeignKey("project.project_id", ondelete="CASCADE"), nullable=False)
# 关联的数据源(查询结果)
result_id = Column(Integer, ForeignKey("query_result.result_id", ondelete="CASCADE"), nullable=False)
# 用户自定义的配置
name = Column(String(100), nullable=False) # 报表名称
description = Column(Text, nullable=True)
chart_type = Column(String(50), default="table") # bar, line, pie, etc.
chart_config = Column(JSON, nullable=True) # 存储 {xAxisKey: "...", yAxisKey: "..."}
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
# 关联关系
# project = relationship("Project", back_populates="reports") # 需在Project model中添加对应关系
# query_result = relationship("QueryResult") # 需在QueryResult model中添加对应关系

@ -2,41 +2,49 @@
from pydantic import BaseModel, ConfigDict
from typing import List, Optional, Any, Dict
from datetime import datetime
# 图表配置
class ChartConfig(BaseModel):
xAxisKey: str
yAxisKey: str
# 创建报表
# 1. 创建报表请求
class ReportCreate(BaseModel):
report_name: str
query_id: int
query_id: int # 关联的 result_id
chart_type: str = "table"
description: Optional[str] = None
chartConfig: Optional[ChartConfig] = None # ← 新增
chartConfig: Optional[ChartConfig] = None
# 2. 修改报表请求 (新增)
class ReportUpdate(BaseModel):
report_name: Optional[str] = None
chart_type: Optional[str] = None
description: Optional[str] = None
chartConfig: Optional[ChartConfig] = None
# 单个报表响应
# 3. 报表响应 (调整为从 AnalysisReport 获取元数据,从 QueryResult 获取 Data)
class Report(BaseModel):
id: str
id: str # report_id
projectId: str
name: str
type: str
name: str # 用户自定义的名称
type: str # 用户保存的 chart_type
description: Optional[str]
data: List[Dict[str, Any]]
chartConfig: ChartConfig | None # ← 必须加回
# 以下数据来自关联的 QueryResult
data: List[Dict[str, Any]]
sourceQueryText: Optional[str]
chartConfig: Optional[ChartConfig]
updatedAt: str
model_config = ConfigDict(from_attributes=True)
# 历史查询
# 4. 历史查询 (保持不变)
class HistoryQuery(BaseModel):
id: str
projectId: str
queryText: str
timestamp: str
result: Optional[Any] = None
class UpdateChartType(BaseModel):
chart_type: str
model_config = ConfigDict(from_attributes=True)
result: Optional[Any] = None

@ -1,25 +1,67 @@
import json
import httpx
import sqlparse
import asyncio
import random
from fastapi import HTTPException
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession
from crud.crud_message import crud_message
from crud.crud_project import crud_project
from crud.crud_knowledge import crud_knowledge
from models.session import Session as SessionModel
from schema.chat import ChatResponse, MessageType
from core.log import log
# -----------------------
# AI 配置
# -----------------------
AI_SERVICE_URL = "http://26.64.77.145:1234/v1/chat/completions"
AI_MODEL = "codellama/CodeLlama-13b-Instruct-hf"
AI_API_KEY = "sk-YQjmNgkBJqRTsZCsr7r0zkHoLb6G0exL9u8gEkJTf5oZQXmE"
AI_API_KEY = "dummy-key"
# 【重要】Mock 开关
# True = 开启模拟模式(不联网,返回假数据,用于开发调试)
# False = 关闭模拟模式(尝试连接真实 AI
MOCK_MODE = True
# -----------------------
# 工具:将 Schema JSON 转为日志里的文本格式
# -----------------------
def format_schema_to_text(schema_data):
"""
JSON 对象转换为模型习惯的文本格式
Table: table_name, columns = [col1, col2, ...]
"""
if not schema_data:
return ""
lines = []
try:
# 如果数据库里存的是字符串,先转成对象
if isinstance(schema_data, str):
schema_data = json.loads(schema_data)
# 遍历表结构
# 假设结构是: [{"table_name": "student", "columns": ["id", "name"]}]
for table in schema_data:
t_name = table.get("table_name", "unknown")
cols = table.get("columns", [])
# 容错处理:确保 cols 是列表
if isinstance(cols, str):
cols = [cols]
# 构造日志里的核心格式
col_str = ", ".join(str(c) for c in cols)
line = f"Table: {t_name}, columns = [{col_str}]"
lines.append(line)
return "\n".join(lines)
except Exception as e:
log.error(f"Schema formatting error: {e}")
# 如果解析失败,为了不报错,返回原始字符串
return str(schema_data)
# -----------------------
# 获取 Session → Project
@ -29,68 +71,79 @@ async def get_project_id_by_session(db: AsyncSession, session_id: int) -> int:
select(SessionModel).where(SessionModel.session_id == session_id)
)
session = result.scalar_one_or_none()
if not session:
raise HTTPException(404, "Session not found")
return session.project_id
# -----------------------
# 获取 Schema
# 获取 Schema (已修改为返回特定文本格式)
# -----------------------
async def get_project_schema(db, project_id):
async def get_project_schema_text(db, project_id):
project = await crud_project.get(db, project_id)
if not project or not project.schema_definition:
return "No schema defined."
sd = project.schema_definition
return json.dumps(sd, ensure_ascii=False, indent=2) if isinstance(sd, (dict, list)) else str(sd)
# 调用上面的工具函数进行转换
return format_schema_to_text(project.schema_definition)
# -----------------------
# 获取术语库
# Mock 逻辑 (模拟 AI)
# -----------------------
async def get_domain_knowledge(db, project_id):
items = await crud_knowledge.get_by_project(db, project_id)
if not items:
return ""
txt = "\n[业务术语]\n"
for t in items:
txt += f"- {t.term}: {t.definition}\n"
return txt
async def mock_ai_response(question: str):
"""模拟 AI 的行为,根据关键词返回不同类型的 SQL"""
log.info(f"【MOCK模式】正在模拟 AI 回复... 问题: {question}")
# 模拟 1.5 秒网络延迟,让前端 Loading 转一会儿
await asyncio.sleep(1.5)
q = question.lower()
# 根据问题包含的词,返回不同的 SQL测试前端展示效果
if "删除" in q or "delete" in q:
return "DELETE FROM student WHERE id = 1001;"
elif "修改" in q or "update" in q:
return "UPDATE course SET credit = 4 WHERE name = 'Software Engineering';"
elif "插入" in q or "添加" in q or "insert" in q:
return "INSERT INTO student (id, name, age) VALUES (2024001, 'Test User', 20);"
elif "平均" in q or "avg" in q:
return "SELECT AVG(score) FROM exam_results WHERE course_id = 'SE101';"
else:
# 默认查询
return "SELECT * FROM student WHERE major = 'Software Engineering' LIMIT 10;"
# -----------------------
# 调用 Claude Agent
# 调用 AI Agent
# -----------------------
async def call_ai_agent(schema, glossary, question):
system_prompt = f"""
你是 SQL 专家请根据 Schema 业务术语生成 SQL
[Schema]
{schema}
[Domain Knowledge]
{glossary}
要求
1. 必须返回 JSON
2. JSON 必须包含字段 "sql"
async def call_ai_agent(schema_text, question):
# 1. 如果开启了 Mock 模式,直接拦截并返回
if MOCK_MODE:
return await mock_ai_response(question)
# 2. 构造符合日志格式的 Prompt
# 注意:这里严格遵循了 "Table: ..., columns = [...]" 和 "### Response:"
user_prompt_content = f"""I want you to act as a SQL terminal in front of an database.
Here is the schema:
{schema_text}
I want you to answer the following question.
### Question: {question}
### Response:
"""
payload = {
"model": AI_MODEL,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question}
{"role": "system", "content": "You are a SQL expert."},
{"role": "user", "content": user_prompt_content}
],
"temperature": 0.1,
"stream": False # 添加 stream 参数
"temperature": 0.1,
"stream": False
}
headers = {
@ -99,40 +152,30 @@ async def call_ai_agent(schema, glossary, question):
}
try:
log.info(f"Payload: {payload}")
log.info(f"Payload sending to AI: {payload}")
async with httpx.AsyncClient(timeout=60) as client:
resp = await client.post(AI_SERVICE_URL, json=payload, headers=headers)
resp.raise_for_status()
raw = resp.json()
log.info(f"Raw Response: {raw}")
# 更健壮的响应解析
content = ""
if "choices" in raw and len(raw["choices"]) > 0:
content = raw["choices"][0]["message"]["content"]
else:
# 尝试其他可能的响应格式
content = raw.get("message", {}).get("content", "") or raw.get("content", "")
# 清理内容
clean = content.replace("```json", "").replace("```", "").strip()
clean_sql = content.strip()
if clean_sql.startswith("```sql"):
clean_sql = clean_sql.replace("```sql", "").replace("```", "")
# 尝试解析 JSON
try:
return json.loads(clean)
except json.JSONDecodeError:
# 如果不是 JSON返回原始内容
return {"sql": clean}
return clean_sql.strip()
except Exception as e:
print("AI Error:", e)
return {"sql": f"-- AI Error: {str(e)}"}
log.error(f"AI Connection Error: {e}")
return f"-- Error calling AI: {str(e)}"
# -----------------------
# 主流程(修改版)
# 主流程
# -----------------------
async def process_chat(db: AsyncSession, session_id: int, user_input: str, user_id: int):
@ -140,37 +183,38 @@ async def process_chat(db: AsyncSession, session_id: int, user_input: str, user_
user_msg = await crud_message.create_message(
db, session_id, user_input, role="user"
)
log.info(f"User message stored: {user_msg}")
# 2. 获取上下文
project_id = await get_project_id_by_session(db, session_id)
schema = await get_project_schema(db, project_id)
glossary = await get_domain_knowledge(db, project_id)
log.info(f"Schema: {schema}")
log.info(f"Glossary: {glossary}")
# 3. 模型生成 SQL
ai_json = await call_ai_agent(schema, glossary, user_input)
sql_text = ai_json.get("sql", "-- no sql")
# 获取转换成文本格式的 Schema (Change: 使用新函数)
schema_text = await get_project_schema_text(db, project_id)
# 3. 模型生成 SQL
sql_text = await call_ai_agent(schema_text, user_input)
# 4. SQL 类型判断
sql_type = "UNKNOWN"
try:
parsed = sqlparse.parse(sql_text)
sql_type = parsed[0].get_type().upper() if parsed else "UNKNOWN"
except:
sql_type = "UNKNOWN"
if sql_text and not sql_text.startswith("--"):
parsed = sqlparse.parse(sql_text)
if parsed:
sql_type = parsed[0].get_type().upper()
except Exception as e:
log.warning(f"SQL Parse warning: {e}")
# 5. 创建 AI 回复消息
reply_content = f"生成 SQL 类型:{sql_type}"
reply_content = f"已生成查询语句:\n{sql_text}"
ai_message = await crud_message.create_message(
db, session_id, reply_content, role="assistant"
)
# 6. 返回 AI 回复给前端
# 6. 返回结果
return ChatResponse(
message_id=ai_message.message_id,
content=reply_content, # AI 回复的文本内容
message_type=MessageType.ASSISTANT, # 标记为 AI 消息
content=reply_content,
message_type=MessageType.ASSISTANT,
sql_text=sql_text,
sql_type=sql_type,
requires_confirmation=sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER"],

@ -0,0 +1,178 @@
import json
import httpx
import sqlparse
from fastapi import HTTPException
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession
from crud.crud_message import crud_message
from crud.crud_project import crud_project
from crud.crud_knowledge import crud_knowledge
from models.session import Session as SessionModel
from schema.chat import ChatResponse, MessageType
from core.log import log
# -----------------------
# AI 配置
# -----------------------
AI_SERVICE_URL = "http://26.64.77.145:1234/v1/chat/completions"
AI_MODEL = "codellama/CodeLlama-13b-Instruct-hf"
AI_API_KEY = "sk-YQjmNgkBJqRTsZCsr7r0zkHoLb6G0exL9u8gEkJTf5oZQXmE"
# -----------------------
# 获取 Session → Project
# -----------------------
async def get_project_id_by_session(db: AsyncSession, session_id: int) -> int:
result = await db.execute(
select(SessionModel).where(SessionModel.session_id == session_id)
)
session = result.scalar_one_or_none()
if not session:
raise HTTPException(404, "Session not found")
return session.project_id
# -----------------------
# 获取 Schema
# -----------------------
async def get_project_schema(db, project_id):
project = await crud_project.get(db, project_id)
if not project or not project.schema_definition:
return "No schema defined."
sd = project.schema_definition
return json.dumps(sd, ensure_ascii=False, indent=2) if isinstance(sd, (dict, list)) else str(sd)
# -----------------------
# 获取术语库
# -----------------------
async def get_domain_knowledge(db, project_id):
items = await crud_knowledge.get_by_project(db, project_id)
if not items:
return ""
txt = "\n[业务术语]\n"
for t in items:
txt += f"- {t.term}: {t.definition}\n"
return txt
# -----------------------
# 调用 Claude Agent
# -----------------------
async def call_ai_agent(schema, glossary, question):
system_prompt = f"""
你是 SQL 专家请根据 Schema 业务术语生成 SQL
[Schema]
{schema}
[Domain Knowledge]
{glossary}
要求
1. 必须返回 JSON
2. JSON 必须包含字段 "sql"
"""
payload = {
"model": AI_MODEL,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question}
],
"temperature": 0.1,
"stream": False # 添加 stream 参数
}
headers = {
"Authorization": f"Bearer {AI_API_KEY}",
"Content-Type": "application/json"
}
try:
log.info(f"Payload: {payload}")
async with httpx.AsyncClient(timeout=60) as client:
resp = await client.post(AI_SERVICE_URL, json=payload, headers=headers)
resp.raise_for_status()
raw = resp.json()
log.info(f"Raw Response: {raw}")
# 更健壮的响应解析
if "choices" in raw and len(raw["choices"]) > 0:
content = raw["choices"][0]["message"]["content"]
else:
# 尝试其他可能的响应格式
content = raw.get("message", {}).get("content", "") or raw.get("content", "")
# 清理内容
clean = content.replace("```json", "").replace("```", "").strip()
# 尝试解析 JSON
try:
return json.loads(clean)
except json.JSONDecodeError:
# 如果不是 JSON返回原始内容
return {"sql": clean}
except Exception as e:
print("AI Error:", e)
return {"sql": f"-- AI Error: {str(e)}"}
# -----------------------
# 主流程(修改版)
# -----------------------
async def process_chat(db: AsyncSession, session_id: int, user_input: str, user_id: int):
# 1. 存用户消息
user_msg = await crud_message.create_message(
db, session_id, user_input, role="user"
)
log.info(f"User message stored: {user_msg}")
# 2. 获取上下文
project_id = await get_project_id_by_session(db, session_id)
schema = await get_project_schema(db, project_id)
glossary = await get_domain_knowledge(db, project_id)
log.info(f"Schema: {schema}")
log.info(f"Glossary: {glossary}")
# 3. 模型生成 SQL
ai_json = await call_ai_agent(schema, glossary, user_input)
sql_text = ai_json.get("sql", "-- no sql")
# 4. SQL 类型判断
try:
parsed = sqlparse.parse(sql_text)
sql_type = parsed[0].get_type().upper() if parsed else "UNKNOWN"
except:
sql_type = "UNKNOWN"
# 5. 创建 AI 回复消息
reply_content = f"生成 SQL 类型:{sql_type}"
ai_message = await crud_message.create_message(
db, session_id, reply_content, role="assistant"
)
# 6. 返回 AI 回复给前端
return ChatResponse(
message_id=ai_message.message_id,
content=reply_content, # AI 回复的文本内容
message_type=MessageType.ASSISTANT, # 标记为 AI 消息
sql_text=sql_text,
sql_type=sql_type,
requires_confirmation=sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER"],
data=None
)

@ -2,133 +2,225 @@
from typing import List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession as Session
from sqlalchemy import select, delete, update
from fastapi import HTTPException
from crud.crud_report import crud_report
from schema import report as schemas
from core.exceptions import DatabaseOperationFailedException, ItemNotFoundException
# 导入核心模型
from models.report import AnalysisReport
from models.query_result import QueryResult
from models.ai_generated_statement import AIGeneratedStatement
from models.project import Project
from sqlalchemy import select
# 注意Message 和 Session 我们将在函数内部导入,或者你可以尝试在这里导入
# 如果报错循环依赖,请保持函数内导入
# --------------------------
# 1. 获取报表列表
# 1. 获取报表列表 (Read)
# --------------------------
async def get_report_list(db: Session, project_id: int) -> List[schemas.Report]:
db_objs = await crud_report.get_by_project(db, project_id)
"""
查询 AnalysisReport 并join QueryResult 获取数据
"""
# [修复问题1]:先检查项目是否存在
project = await db.get(Project, project_id)
if not project:
raise HTTPException(status_code=404, detail=f"Project with ID {project_id} not found")
stmt = (
select(AnalysisReport, QueryResult, AIGeneratedStatement)
.join(QueryResult, AnalysisReport.result_id == QueryResult.result_id)
.join(AIGeneratedStatement, QueryResult.statement_id == AIGeneratedStatement.statement_id)
.where(AnalysisReport.project_id == project_id)
.order_by(AnalysisReport.created_at.desc())
)
result = await db.execute(stmt)
rows = result.all()
reports = []
for obj in db_objs:
default_chart_config = schemas.ChartConfig(
xAxisKey="name",
yAxisKey="value"
)
report_data = obj.result_data if isinstance(obj.result_data, list) else []
report = schemas.Report(
id=str(obj.result_id),
for report_obj, result_obj, stmt_obj in rows:
# 处理 chart_config (从数据库JSON转为Pydantic对象)
c_config = None
if report_obj.chart_config:
# 兼容处理:确保 chart_config 是字典
config_dict = report_obj.chart_config if isinstance(report_obj.chart_config, dict) else {}
# 只有当字典不为空且包含必要的键时才转换
if config_dict.get('xAxisKey') and config_dict.get('yAxisKey'):
c_config = schemas.ChartConfig(**config_dict)
reports.append(schemas.Report(
id=str(report_obj.report_id),
projectId=str(project_id),
name=obj.data_summary[:20] if obj.data_summary else f"Report-{obj.result_id}",
type=obj.chart_type or "table",
description=obj.data_summary,
data=report_data,
chartConfig=default_chart_config,
sourceQueryText="SELECT * FROM ...",
updatedAt=obj.cached_at.isoformat() if obj.cached_at else ""
)
reports.append(report)
name=report_obj.name,
type=report_obj.chart_type,
description=report_obj.description,
data=result_obj.result_data if isinstance(result_obj.result_data, list) else [],
chartConfig=c_config,
sourceQueryText=stmt_obj.sql_text,
updatedAt=report_obj.updated_at.isoformat() if report_obj.updated_at else report_obj.created_at.isoformat()
))
return reports
# --------------------------
# 2. 获取报表详情
# --------------------------
async def get_report_data_by_id_service(db: Session, query_id: int) -> Dict[str, Any]:
raw = await crud_report.get_report_data_by_result_id(db, query_id)
if not raw:
raise ItemNotFoundException(f"Query result with ID {query_id} not found.")
query_result, statement = raw
return {
"result_data": query_result.result_data,
"data_summary": query_result.data_summary,
"chart_type": query_result.chart_type,
"sql_text": statement.sql_text,
"cached_at": query_result.cached_at.isoformat() if query_result.cached_at else None
}
# --------------------------
# 3. 获取历史查询记录
# 2. 获取历史查询记录 (Source for creating reports)
# --------------------------
async def get_history_queries_service(db: Session, project_id: int) -> List[schemas.HistoryQuery]:
"""
获取历史查询结果
修正逻辑关联 Message 返回用户原始的自然语言问题
"""
# [修复问题1]:先检查项目是否存在
project = await db.get(Project, project_id)
if not project:
raise HTTPException(status_code=404, detail=f"Project with ID {project_id} not found")
# [修复问题2]:局部导入以避免 UnboundLocalError 和循环依赖
from models.message import Message
from models.session import Session as SessionModel
# 构造查询Query Result -> Statement -> Message -> Session
# 我们需要 Message.content (用户问题) 和 QueryResult (数据)
stmt = (
select(QueryResult, Message)
.join(AIGeneratedStatement, QueryResult.statement_id == AIGeneratedStatement.statement_id)
.join(Message, AIGeneratedStatement.message_id == Message.message_id)
.join(SessionModel, Message.session_id == SessionModel.session_id)
.where(SessionModel.project_id == project_id)
.where(Message.message_type == 'user') # 确保我们取的是用户发的消息(提问)
.order_by(QueryResult.cached_at.desc())
)
db_objs = await crud_report.get_history_by_project(db, project_id)
history = []
result = await db.execute(stmt)
rows = result.all()
for obj in db_objs:
history = []
for query_res, msg_obj in rows:
history.append(schemas.HistoryQuery(
id=str(obj.result_id),
projectId=str(project_id), # ← 修复类型错误
queryText=obj.data_summary or "N/A",
timestamp=obj.cached_at.isoformat() if obj.cached_at else 'N/A',
result=obj.result_data
))
id=str(query_res.result_id),
projectId=str(project_id),
# 这里取的是 message.content即用户的原始提问
queryText=msg_obj.content,
timestamp=query_res.cached_at.isoformat() if query_res.cached_at else 'N/A',
result=query_res.result_data
))
return history
# --------------------------
# 4. 创建报表(实时渲染,不入库)
# 3. 创建报表 (Create - 真正入库)
# --------------------------
async def create_report_service(db: Session, project_id: int, payload: schemas.ReportCreate):
project_exists = await db.execute(select(Project).where(Project.project_id == project_id))
if not project_exists.scalars().first():
# 1. 校验项目
project_exists = await db.get(Project, project_id)
if not project_exists:
raise HTTPException(status_code=404, detail="Project not found")
raw = await crud_report.get_report_data_by_result_id(db, payload.query_id)
if not raw:
raise HTTPException(status_code=404, detail="Query result not found")
# 2. 校验数据源 (Query Result) 是否存在
result_exists = await db.get(QueryResult, payload.query_id)
if not result_exists:
raise HTTPException(status_code=404, detail="Query result (Data Source) not found")
# 3. 创建 AnalysisReport 对象
new_report = AnalysisReport(
project_id=project_id,
result_id=payload.query_id,
name=payload.report_name,
chart_type=payload.chart_type,
description=payload.description,
chart_config=payload.chartConfig.model_dump() if payload.chartConfig else None
)
query_result, statement = raw
db.add(new_report)
await db.commit()
await db.refresh(new_report)
# 4. 获取关联 SQL 文本用于返回
stmt_obj = await db.get(AIGeneratedStatement, result_exists.statement_id)
return schemas.Report(
id=str(query_result.result_id),
id=str(new_report.report_id),
projectId=str(project_id),
name=payload.report_name,
type=payload.chart_type,
description=query_result.data_summary,
data=query_result.result_data,
name=new_report.name,
type=new_report.chart_type,
description=new_report.description,
data=result_exists.result_data,
chartConfig=payload.chartConfig,
sourceQueryText=statement.sql_text,
updatedAt=query_result.cached_at.isoformat() if query_result.cached_at else ""
sourceQueryText=stmt_obj.sql_text if stmt_obj else "",
updatedAt=new_report.created_at.isoformat()
)
# --------------------------
# 4. 删除报表 (Delete)
# --------------------------
async def delete_report_service(db: Session, report_id: int) -> bool:
stmt = delete(AnalysisReport).where(AnalysisReport.report_id == report_id)
result = await db.execute(stmt)
await db.commit()
if result.rowcount == 0:
return False
return True
# --------------------------
# 6. 导出报表
# 5. 修改报表 (Update)
# --------------------------
async def export_report_service(db: Session, report_id: int, format: str):
async def update_report_service(db: Session, report_id: int, payload: schemas.ReportUpdate):
# 1. 检查是否存在
report_obj = await db.get(AnalysisReport, report_id)
if not report_obj:
raise HTTPException(status_code=404, detail="Report not found")
# 2. 更新字段
if payload.report_name is not None:
report_obj.name = payload.report_name
if payload.chart_type is not None:
report_obj.chart_type = payload.chart_type
if payload.description is not None:
report_obj.description = payload.description
if payload.chartConfig is not None:
report_obj.chart_config = payload.chartConfig.model_dump()
await db.commit()
await db.refresh(report_obj)
# 3. 组装返回数据
result_obj = await db.get(QueryResult, report_obj.result_id)
stmt_obj = await db.get(AIGeneratedStatement, result_obj.statement_id)
c_config = None
if report_obj.chart_config:
config_dict = report_obj.chart_config if isinstance(report_obj.chart_config, dict) else {}
if config_dict.get('xAxisKey') and config_dict.get('yAxisKey'):
c_config = schemas.ChartConfig(**config_dict)
return schemas.Report(
id=str(report_obj.report_id),
projectId=str(report_obj.project_id),
name=report_obj.name,
type=report_obj.chart_type,
description=report_obj.description,
data=result_obj.result_data,
chartConfig=c_config,
sourceQueryText=stmt_obj.sql_text,
updatedAt=report_obj.updated_at.isoformat() if report_obj.updated_at else ""
)
if format not in ("png", "pdf"):
raise HTTPException(status_code=400, detail="format must be png or pdf")
# --------------------------
# 6. 导出报表 (Export)
# --------------------------
async def export_report_service(db: Session, report_id: int, format: str):
# 检查 report 是否存在
raw = await crud_report.get_report_data_by_result_id(db, report_id)
if not raw:
report_obj = await db.get(AnalysisReport, report_id)
if not report_obj:
raise HTTPException(status_code=404, detail="Report not found")
# 正常返回
return {
"download_url": f"https://fake-cdn.example.com/reports/{report_id}.{format}",
"expires_at": "2025-12-01T15:00:00Z"
}
}
Loading…
Cancel
Save