diff --git a/src/backend/app/api/v1/endpoints/reports.py b/src/backend/app/api/v1/endpoints/reports.py index 2071f3c..3377191 100644 --- a/src/backend/app/api/v1/endpoints/reports.py +++ b/src/backend/app/api/v1/endpoints/reports.py @@ -1,50 +1,93 @@ # backend/app/api/v1/endpoints/reports.py -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, Query, Path, HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from schema.report import ReportCreate, Report, HistoryQuery, UpdateChartType +from typing import List + +# 导入正确的 ReportUpdate Schema +from schema.report import ReportCreate, Report, HistoryQuery, ReportUpdate from api.v1.deps import get_db from service import report_service router = APIRouter() -# 4.1 获取报表列表 -@router.get("/reports", response_model=list[Report]) +# ------------------------------------------- +# 1. 获取报表列表 (Read) +# ------------------------------------------- +@router.get("/reports", response_model=List[Report]) async def read_reports( - projectId: int = Query(...), + projectId: int = Query(..., description="项目ID"), db: AsyncSession = Depends(get_db), ): + """ + 获取指定项目下的所有保存的报表配置。 + """ return await report_service.get_report_list(db, projectId) - -# 4.4 获取历史查询记录 -@router.get("/history-queries", response_model=list[HistoryQuery]) +# ------------------------------------------- +# 2. 获取历史查询记录 +# ------------------------------------------- +@router.get("/history-queries", response_model=List[HistoryQuery]) async def read_history( - projectId: int = Query(...), + projectId: int = Query(..., description="项目ID"), db: AsyncSession = Depends(get_db), ): + """ + 获取历史查询结果,用于作为创建新报表的数据源。 + """ return await report_service.get_history_queries_service(db, projectId) - -# 3.5.1 创建报表(实时渲染) +# ------------------------------------------- +# 3. 创建报表 (Create) +# ------------------------------------------- @router.post("/projects/{project_id}/reports", response_model=Report) async def create_report( project_id: int, body: ReportCreate, db: AsyncSession = Depends(get_db), ): + """ + 保存报表配置。 + """ return await report_service.create_report_service(db, project_id, body) +# ------------------------------------------- +# 4. 删除报表 (Delete) +# ------------------------------------------- +@router.delete("/reports/{report_id}", response_model=bool) +async def delete_report( + report_id: int = Path(..., description="报表ID"), + db: AsyncSession = Depends(get_db), +): + """ + 删除指定的报表配置。 + """ + success = await report_service.delete_report_service(db, report_id) + if not success: + raise HTTPException(status_code=404, detail="Report not found") + return True -# 3.5.2 修改图表类型(已删除该接口) - - +# ------------------------------------------- +# 5. 修改报表信息 (Update) - [已修复] +# ------------------------------------------- +@router.put("/reports/{report_id}", response_model=Report) +async def update_report( + body: ReportUpdate, # 注意:Pydantic模型通常放在 Depends 前面 + report_id: int = Path(..., description="报表ID"), + db: AsyncSession = Depends(get_db), +): + """ + 修改报表配置(例如:修改图表类型、修改标题)。 + """ + return await report_service.update_report_service(db, report_id, body) -# 3.5.3 导出报表 +# ------------------------------------------- +# 6. 导出报表 (Export) +# ------------------------------------------- @router.get("/reports/{report_id}/export", response_model=dict) async def export_report( report_id: int, - format: str = Query("png"), + format: str = Query("png", regex="^(png|jpeg|pdf)$"), db: AsyncSession = Depends(get_db) ): - return await report_service.export_report_service(db, report_id, format) + return await report_service.export_report_service(db, report_id, format) \ No newline at end of file diff --git a/src/backend/app/models/report.py b/src/backend/app/models/report.py index 0479597..6227ffd 100644 --- a/src/backend/app/models/report.py +++ b/src/backend/app/models/report.py @@ -1,16 +1,28 @@ -from sqlalchemy import Column, Integer, String, ForeignKey, DateTime, func +from sqlalchemy import Column, Integer, String, Text, ForeignKey, JSON, DateTime +from sqlalchemy.sql import func from sqlalchemy.orm import relationship from core.database import Base -class Report(Base): - __tablename__ = "report" +class AnalysisReport(Base): + """ + 报表配置表:用于持久化保存用户对某个查询结果的展示配置 + """ + __tablename__ = "analysis_report" report_id = Column(Integer, primary_key=True, index=True) - project_id = Column(Integer, ForeignKey("project.project_id"), nullable=False) - - report_name = Column(String(255), nullable=False) - query_id = Column(Integer, ForeignKey("query_history.query_id"), nullable=False) - chart_type = Column(String(50), nullable=False) - + project_id = Column(Integer, ForeignKey("project.project_id", ondelete="CASCADE"), nullable=False) + # 关联的数据源(查询结果) + result_id = Column(Integer, ForeignKey("query_result.result_id", ondelete="CASCADE"), nullable=False) + + # 用户自定义的配置 + name = Column(String(100), nullable=False) # 报表名称 + description = Column(Text, nullable=True) + chart_type = Column(String(50), default="table") # bar, line, pie, etc. + chart_config = Column(JSON, nullable=True) # 存储 {xAxisKey: "...", yAxisKey: "..."} + created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + + # 关联关系 + # project = relationship("Project", back_populates="reports") # 需在Project model中添加对应关系 + # query_result = relationship("QueryResult") # 需在QueryResult model中添加对应关系 \ No newline at end of file diff --git a/src/backend/app/schema/report.py b/src/backend/app/schema/report.py index 0d6d2b5..73d14a3 100644 --- a/src/backend/app/schema/report.py +++ b/src/backend/app/schema/report.py @@ -2,41 +2,49 @@ from pydantic import BaseModel, ConfigDict from typing import List, Optional, Any, Dict +from datetime import datetime # 图表配置 class ChartConfig(BaseModel): xAxisKey: str yAxisKey: str -# 创建报表 +# 1. 创建报表请求 class ReportCreate(BaseModel): report_name: str - query_id: int + query_id: int # 关联的 result_id chart_type: str = "table" description: Optional[str] = None - chartConfig: Optional[ChartConfig] = None # ← 新增 + chartConfig: Optional[ChartConfig] = None +# 2. 修改报表请求 (新增) +class ReportUpdate(BaseModel): + report_name: Optional[str] = None + chart_type: Optional[str] = None + description: Optional[str] = None + chartConfig: Optional[ChartConfig] = None -# 单个报表响应 +# 3. 报表响应 (调整为从 AnalysisReport 获取元数据,从 QueryResult 获取 Data) class Report(BaseModel): - id: str + id: str # report_id projectId: str - name: str - type: str + name: str # 用户自定义的名称 + type: str # 用户保存的 chart_type description: Optional[str] - data: List[Dict[str, Any]] - chartConfig: ChartConfig | None # ← 必须加回 + + # 以下数据来自关联的 QueryResult + data: List[Dict[str, Any]] sourceQueryText: Optional[str] + + chartConfig: Optional[ChartConfig] updatedAt: str + + model_config = ConfigDict(from_attributes=True) -# 历史查询 +# 4. 历史查询 (保持不变) class HistoryQuery(BaseModel): id: str projectId: str queryText: str timestamp: str - result: Optional[Any] = None - -class UpdateChartType(BaseModel): - chart_type: str - model_config = ConfigDict(from_attributes=True) + result: Optional[Any] = None \ No newline at end of file diff --git a/src/backend/app/service/chat_service.py b/src/backend/app/service/chat_service.py index d639fd9..8cbae38 100644 --- a/src/backend/app/service/chat_service.py +++ b/src/backend/app/service/chat_service.py @@ -1,25 +1,67 @@ import json import httpx import sqlparse +import asyncio +import random from fastapi import HTTPException from sqlalchemy.future import select from sqlalchemy.ext.asyncio import AsyncSession from crud.crud_message import crud_message from crud.crud_project import crud_project -from crud.crud_knowledge import crud_knowledge from models.session import Session as SessionModel from schema.chat import ChatResponse, MessageType from core.log import log - # ----------------------- # AI 配置 # ----------------------- AI_SERVICE_URL = "http://26.64.77.145:1234/v1/chat/completions" AI_MODEL = "codellama/CodeLlama-13b-Instruct-hf" -AI_API_KEY = "sk-YQjmNgkBJqRTsZCsr7r0zkHoLb6G0exL9u8gEkJTf5oZQXmE" +AI_API_KEY = "dummy-key" + +# 【重要】Mock 开关 +# True = 开启模拟模式(不联网,返回假数据,用于开发调试) +# False = 关闭模拟模式(尝试连接真实 AI) +MOCK_MODE = True +# ----------------------- +# 工具:将 Schema JSON 转为日志里的文本格式 +# ----------------------- +def format_schema_to_text(schema_data): + """ + 将 JSON 对象转换为模型习惯的文本格式: + Table: table_name, columns = [col1, col2, ...] + """ + if not schema_data: + return "" + + lines = [] + try: + # 如果数据库里存的是字符串,先转成对象 + if isinstance(schema_data, str): + schema_data = json.loads(schema_data) + + # 遍历表结构 + # 假设结构是: [{"table_name": "student", "columns": ["id", "name"]}] + for table in schema_data: + t_name = table.get("table_name", "unknown") + cols = table.get("columns", []) + + # 容错处理:确保 cols 是列表 + if isinstance(cols, str): + cols = [cols] + + # 构造日志里的核心格式 + col_str = ", ".join(str(c) for c in cols) + line = f"Table: {t_name}, columns = [{col_str}]" + lines.append(line) + + return "\n".join(lines) + except Exception as e: + log.error(f"Schema formatting error: {e}") + # 如果解析失败,为了不报错,返回原始字符串 + return str(schema_data) # ----------------------- # 获取 Session → Project @@ -29,68 +71,79 @@ async def get_project_id_by_session(db: AsyncSession, session_id: int) -> int: select(SessionModel).where(SessionModel.session_id == session_id) ) session = result.scalar_one_or_none() - if not session: raise HTTPException(404, "Session not found") - return session.project_id - # ----------------------- -# 获取 Schema +# 获取 Schema (已修改为返回特定文本格式) # ----------------------- -async def get_project_schema(db, project_id): +async def get_project_schema_text(db, project_id): project = await crud_project.get(db, project_id) - if not project or not project.schema_definition: return "No schema defined." - sd = project.schema_definition - return json.dumps(sd, ensure_ascii=False, indent=2) if isinstance(sd, (dict, list)) else str(sd) - + # 调用上面的工具函数进行转换 + return format_schema_to_text(project.schema_definition) # ----------------------- -# 获取术语库 +# Mock 逻辑 (模拟 AI) # ----------------------- -async def get_domain_knowledge(db, project_id): - items = await crud_knowledge.get_by_project(db, project_id) - - if not items: - return "" - - txt = "\n[业务术语]\n" - for t in items: - txt += f"- {t.term}: {t.definition}\n" - - return txt - +async def mock_ai_response(question: str): + """模拟 AI 的行为,根据关键词返回不同类型的 SQL""" + log.info(f"【MOCK模式】正在模拟 AI 回复... 问题: {question}") + + # 模拟 1.5 秒网络延迟,让前端 Loading 转一会儿 + await asyncio.sleep(1.5) + + q = question.lower() + + # 根据问题包含的词,返回不同的 SQL,测试前端展示效果 + if "删除" in q or "delete" in q: + return "DELETE FROM student WHERE id = 1001;" + + elif "修改" in q or "update" in q: + return "UPDATE course SET credit = 4 WHERE name = 'Software Engineering';" + + elif "插入" in q or "添加" in q or "insert" in q: + return "INSERT INTO student (id, name, age) VALUES (2024001, 'Test User', 20);" + + elif "平均" in q or "avg" in q: + return "SELECT AVG(score) FROM exam_results WHERE course_id = 'SE101';" + + else: + # 默认查询 + return "SELECT * FROM student WHERE major = 'Software Engineering' LIMIT 10;" # ----------------------- -# 调用 Claude Agent +# 调用 AI Agent # ----------------------- -async def call_ai_agent(schema, glossary, question): - system_prompt = f""" -你是 SQL 专家,请根据 Schema 与 业务术语生成 SQL。 - -[Schema] -{schema} - -[Domain Knowledge] -{glossary} - -要求: -1. 必须返回 JSON -2. JSON 必须包含字段 "sql" +async def call_ai_agent(schema_text, question): + + # 1. 如果开启了 Mock 模式,直接拦截并返回 + if MOCK_MODE: + return await mock_ai_response(question) + + # 2. 构造符合日志格式的 Prompt + # 注意:这里严格遵循了 "Table: ..., columns = [...]" 和 "### Response:" + user_prompt_content = f"""I want you to act as a SQL terminal in front of an database. +Here is the schema: +{schema_text} + +I want you to answer the following question. +### Question: {question} + +### Response: """ payload = { "model": AI_MODEL, "messages": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": question} + {"role": "system", "content": "You are a SQL expert."}, + {"role": "user", "content": user_prompt_content} ], - "temperature": 0.1, - "stream": False # 添加 stream 参数 + "temperature": 0.1, + "stream": False } headers = { @@ -99,40 +152,30 @@ async def call_ai_agent(schema, glossary, question): } try: - log.info(f"Payload: {payload}") + log.info(f"Payload sending to AI: {payload}") async with httpx.AsyncClient(timeout=60) as client: resp = await client.post(AI_SERVICE_URL, json=payload, headers=headers) - resp.raise_for_status() raw = resp.json() - - log.info(f"Raw Response: {raw}") - # 更健壮的响应解析 + content = "" if "choices" in raw and len(raw["choices"]) > 0: content = raw["choices"][0]["message"]["content"] - else: - # 尝试其他可能的响应格式 - content = raw.get("message", {}).get("content", "") or raw.get("content", "") # 清理内容 - clean = content.replace("```json", "").replace("```", "").strip() + clean_sql = content.strip() + if clean_sql.startswith("```sql"): + clean_sql = clean_sql.replace("```sql", "").replace("```", "") - # 尝试解析 JSON - try: - return json.loads(clean) - except json.JSONDecodeError: - # 如果不是 JSON,返回原始内容 - return {"sql": clean} + return clean_sql.strip() except Exception as e: - print("AI Error:", e) - return {"sql": f"-- AI Error: {str(e)}"} - + log.error(f"AI Connection Error: {e}") + return f"-- Error calling AI: {str(e)}" # ----------------------- -# 主流程(修改版) +# 主流程 # ----------------------- async def process_chat(db: AsyncSession, session_id: int, user_input: str, user_id: int): @@ -140,37 +183,38 @@ async def process_chat(db: AsyncSession, session_id: int, user_input: str, user_ user_msg = await crud_message.create_message( db, session_id, user_input, role="user" ) - log.info(f"User message stored: {user_msg}") # 2. 获取上下文 project_id = await get_project_id_by_session(db, session_id) - schema = await get_project_schema(db, project_id) - glossary = await get_domain_knowledge(db, project_id) - log.info(f"Schema: {schema}") - log.info(f"Glossary: {glossary}") - - # 3. 模型生成 SQL - ai_json = await call_ai_agent(schema, glossary, user_input) - sql_text = ai_json.get("sql", "-- no sql") + + # 获取转换成文本格式的 Schema (Change: 使用新函数) + schema_text = await get_project_schema_text(db, project_id) + + # 3. 模型生成 SQL + sql_text = await call_ai_agent(schema_text, user_input) # 4. SQL 类型判断 + sql_type = "UNKNOWN" try: - parsed = sqlparse.parse(sql_text) - sql_type = parsed[0].get_type().upper() if parsed else "UNKNOWN" - except: - sql_type = "UNKNOWN" + if sql_text and not sql_text.startswith("--"): + parsed = sqlparse.parse(sql_text) + if parsed: + sql_type = parsed[0].get_type().upper() + except Exception as e: + log.warning(f"SQL Parse warning: {e}") # 5. 创建 AI 回复消息 - reply_content = f"生成 SQL 类型:{sql_type}" + reply_content = f"已生成查询语句:\n{sql_text}" + ai_message = await crud_message.create_message( db, session_id, reply_content, role="assistant" ) - # 6. 返回 AI 回复给前端 + # 6. 返回结果 return ChatResponse( message_id=ai_message.message_id, - content=reply_content, # AI 回复的文本内容 - message_type=MessageType.ASSISTANT, # 标记为 AI 消息 + content=reply_content, + message_type=MessageType.ASSISTANT, sql_text=sql_text, sql_type=sql_type, requires_confirmation=sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER"], diff --git a/src/backend/app/service/chat_service_test.py b/src/backend/app/service/chat_service_test.py new file mode 100644 index 0000000..d639fd9 --- /dev/null +++ b/src/backend/app/service/chat_service_test.py @@ -0,0 +1,178 @@ +import json +import httpx +import sqlparse +from fastapi import HTTPException +from sqlalchemy.future import select +from sqlalchemy.ext.asyncio import AsyncSession + +from crud.crud_message import crud_message +from crud.crud_project import crud_project +from crud.crud_knowledge import crud_knowledge +from models.session import Session as SessionModel +from schema.chat import ChatResponse, MessageType +from core.log import log + + +# ----------------------- +# AI 配置 +# ----------------------- +AI_SERVICE_URL = "http://26.64.77.145:1234/v1/chat/completions" +AI_MODEL = "codellama/CodeLlama-13b-Instruct-hf" +AI_API_KEY = "sk-YQjmNgkBJqRTsZCsr7r0zkHoLb6G0exL9u8gEkJTf5oZQXmE" + + +# ----------------------- +# 获取 Session → Project +# ----------------------- +async def get_project_id_by_session(db: AsyncSession, session_id: int) -> int: + result = await db.execute( + select(SessionModel).where(SessionModel.session_id == session_id) + ) + session = result.scalar_one_or_none() + + if not session: + raise HTTPException(404, "Session not found") + + return session.project_id + + +# ----------------------- +# 获取 Schema +# ----------------------- +async def get_project_schema(db, project_id): + project = await crud_project.get(db, project_id) + + if not project or not project.schema_definition: + return "No schema defined." + + sd = project.schema_definition + return json.dumps(sd, ensure_ascii=False, indent=2) if isinstance(sd, (dict, list)) else str(sd) + + +# ----------------------- +# 获取术语库 +# ----------------------- +async def get_domain_knowledge(db, project_id): + items = await crud_knowledge.get_by_project(db, project_id) + + if not items: + return "" + + txt = "\n[业务术语]\n" + for t in items: + txt += f"- {t.term}: {t.definition}\n" + + return txt + + +# ----------------------- +# 调用 Claude Agent +# ----------------------- +async def call_ai_agent(schema, glossary, question): + system_prompt = f""" +你是 SQL 专家,请根据 Schema 与 业务术语生成 SQL。 + +[Schema] +{schema} + +[Domain Knowledge] +{glossary} + +要求: +1. 必须返回 JSON +2. JSON 必须包含字段 "sql" +""" + + payload = { + "model": AI_MODEL, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question} + ], + "temperature": 0.1, + "stream": False # 添加 stream 参数 + } + + headers = { + "Authorization": f"Bearer {AI_API_KEY}", + "Content-Type": "application/json" + } + + try: + log.info(f"Payload: {payload}") + async with httpx.AsyncClient(timeout=60) as client: + resp = await client.post(AI_SERVICE_URL, json=payload, headers=headers) + + + resp.raise_for_status() + raw = resp.json() + + log.info(f"Raw Response: {raw}") + + # 更健壮的响应解析 + if "choices" in raw and len(raw["choices"]) > 0: + content = raw["choices"][0]["message"]["content"] + else: + # 尝试其他可能的响应格式 + content = raw.get("message", {}).get("content", "") or raw.get("content", "") + + # 清理内容 + clean = content.replace("```json", "").replace("```", "").strip() + + # 尝试解析 JSON + try: + return json.loads(clean) + except json.JSONDecodeError: + # 如果不是 JSON,返回原始内容 + return {"sql": clean} + + except Exception as e: + print("AI Error:", e) + return {"sql": f"-- AI Error: {str(e)}"} + + +# ----------------------- +# 主流程(修改版) +# ----------------------- +async def process_chat(db: AsyncSession, session_id: int, user_input: str, user_id: int): + + # 1. 存用户消息 + user_msg = await crud_message.create_message( + db, session_id, user_input, role="user" + ) + log.info(f"User message stored: {user_msg}") + + # 2. 获取上下文 + project_id = await get_project_id_by_session(db, session_id) + schema = await get_project_schema(db, project_id) + glossary = await get_domain_knowledge(db, project_id) + log.info(f"Schema: {schema}") + log.info(f"Glossary: {glossary}") + + # 3. 模型生成 SQL + ai_json = await call_ai_agent(schema, glossary, user_input) + sql_text = ai_json.get("sql", "-- no sql") + + # 4. SQL 类型判断 + try: + parsed = sqlparse.parse(sql_text) + sql_type = parsed[0].get_type().upper() if parsed else "UNKNOWN" + except: + sql_type = "UNKNOWN" + + # 5. 创建 AI 回复消息 + reply_content = f"生成 SQL 类型:{sql_type}" + ai_message = await crud_message.create_message( + db, session_id, reply_content, role="assistant" + ) + + # 6. 返回 AI 回复给前端 + return ChatResponse( + message_id=ai_message.message_id, + content=reply_content, # AI 回复的文本内容 + message_type=MessageType.ASSISTANT, # 标记为 AI 消息 + sql_text=sql_text, + sql_type=sql_type, + requires_confirmation=sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER"], + data=None + ) \ No newline at end of file diff --git a/src/backend/app/service/report_service.py b/src/backend/app/service/report_service.py index 35b650e..27221ae 100644 --- a/src/backend/app/service/report_service.py +++ b/src/backend/app/service/report_service.py @@ -2,133 +2,225 @@ from typing import List, Dict, Any from sqlalchemy.ext.asyncio import AsyncSession as Session +from sqlalchemy import select, delete, update from fastapi import HTTPException -from crud.crud_report import crud_report from schema import report as schemas -from core.exceptions import DatabaseOperationFailedException, ItemNotFoundException +# 导入核心模型 +from models.report import AnalysisReport +from models.query_result import QueryResult +from models.ai_generated_statement import AIGeneratedStatement from models.project import Project -from sqlalchemy import select +# 注意:Message 和 Session 我们将在函数内部导入,或者你可以尝试在这里导入 +# 如果报错循环依赖,请保持函数内导入 # -------------------------- -# 1. 获取报表列表 +# 1. 获取报表列表 (Read) # -------------------------- async def get_report_list(db: Session, project_id: int) -> List[schemas.Report]: - - db_objs = await crud_report.get_by_project(db, project_id) + """ + 查询 AnalysisReport 表,并join QueryResult 获取数据 + """ + # [修复问题1]:先检查项目是否存在 + project = await db.get(Project, project_id) + if not project: + raise HTTPException(status_code=404, detail=f"Project with ID {project_id} not found") + + stmt = ( + select(AnalysisReport, QueryResult, AIGeneratedStatement) + .join(QueryResult, AnalysisReport.result_id == QueryResult.result_id) + .join(AIGeneratedStatement, QueryResult.statement_id == AIGeneratedStatement.statement_id) + .where(AnalysisReport.project_id == project_id) + .order_by(AnalysisReport.created_at.desc()) + ) + + result = await db.execute(stmt) + rows = result.all() + reports = [] - - for obj in db_objs: - - default_chart_config = schemas.ChartConfig( - xAxisKey="name", - yAxisKey="value" - ) - - report_data = obj.result_data if isinstance(obj.result_data, list) else [] - - report = schemas.Report( - id=str(obj.result_id), + for report_obj, result_obj, stmt_obj in rows: + # 处理 chart_config (从数据库JSON转为Pydantic对象) + c_config = None + if report_obj.chart_config: + # 兼容处理:确保 chart_config 是字典 + config_dict = report_obj.chart_config if isinstance(report_obj.chart_config, dict) else {} + # 只有当字典不为空且包含必要的键时才转换 + if config_dict.get('xAxisKey') and config_dict.get('yAxisKey'): + c_config = schemas.ChartConfig(**config_dict) + + reports.append(schemas.Report( + id=str(report_obj.report_id), projectId=str(project_id), - name=obj.data_summary[:20] if obj.data_summary else f"Report-{obj.result_id}", - type=obj.chart_type or "table", - description=obj.data_summary, - data=report_data, - chartConfig=default_chart_config, - sourceQueryText="SELECT * FROM ...", - updatedAt=obj.cached_at.isoformat() if obj.cached_at else "" - ) - reports.append(report) + name=report_obj.name, + type=report_obj.chart_type, + description=report_obj.description, + data=result_obj.result_data if isinstance(result_obj.result_data, list) else [], + chartConfig=c_config, + sourceQueryText=stmt_obj.sql_text, + updatedAt=report_obj.updated_at.isoformat() if report_obj.updated_at else report_obj.created_at.isoformat() + )) return reports # -------------------------- -# 2. 获取报表详情 -# -------------------------- -async def get_report_data_by_id_service(db: Session, query_id: int) -> Dict[str, Any]: - - raw = await crud_report.get_report_data_by_result_id(db, query_id) - - if not raw: - raise ItemNotFoundException(f"Query result with ID {query_id} not found.") - - query_result, statement = raw - - return { - "result_data": query_result.result_data, - "data_summary": query_result.data_summary, - "chart_type": query_result.chart_type, - "sql_text": statement.sql_text, - "cached_at": query_result.cached_at.isoformat() if query_result.cached_at else None - } - - -# -------------------------- -# 3. 获取历史查询记录 +# 2. 获取历史查询记录 (Source for creating reports) # -------------------------- async def get_history_queries_service(db: Session, project_id: int) -> List[schemas.HistoryQuery]: + """ + 获取历史查询结果。 + 修正逻辑:关联 Message 表,返回用户原始的自然语言问题。 + """ + # [修复问题1]:先检查项目是否存在 + project = await db.get(Project, project_id) + if not project: + raise HTTPException(status_code=404, detail=f"Project with ID {project_id} not found") + + # [修复问题2]:局部导入以避免 UnboundLocalError 和循环依赖 + from models.message import Message + from models.session import Session as SessionModel + + # 构造查询:Query Result -> Statement -> Message -> Session + # 我们需要 Message.content (用户问题) 和 QueryResult (数据) + stmt = ( + select(QueryResult, Message) + .join(AIGeneratedStatement, QueryResult.statement_id == AIGeneratedStatement.statement_id) + .join(Message, AIGeneratedStatement.message_id == Message.message_id) + .join(SessionModel, Message.session_id == SessionModel.session_id) + .where(SessionModel.project_id == project_id) + .where(Message.message_type == 'user') # 确保我们取的是用户发的消息(提问) + .order_by(QueryResult.cached_at.desc()) + ) - db_objs = await crud_report.get_history_by_project(db, project_id) - history = [] + result = await db.execute(stmt) + rows = result.all() - for obj in db_objs: + history = [] + for query_res, msg_obj in rows: history.append(schemas.HistoryQuery( - id=str(obj.result_id), - projectId=str(project_id), # ← 修复类型错误 - queryText=obj.data_summary or "N/A", - timestamp=obj.cached_at.isoformat() if obj.cached_at else 'N/A', - result=obj.result_data -)) - + id=str(query_res.result_id), + projectId=str(project_id), + # 这里取的是 message.content,即用户的原始提问 + queryText=msg_obj.content, + timestamp=query_res.cached_at.isoformat() if query_res.cached_at else 'N/A', + result=query_res.result_data + )) return history # -------------------------- -# 4. 创建报表(实时渲染,不入库) +# 3. 创建报表 (Create - 真正入库) # -------------------------- async def create_report_service(db: Session, project_id: int, payload: schemas.ReportCreate): - - project_exists = await db.execute(select(Project).where(Project.project_id == project_id)) - if not project_exists.scalars().first(): + # 1. 校验项目 + project_exists = await db.get(Project, project_id) + if not project_exists: raise HTTPException(status_code=404, detail="Project not found") - raw = await crud_report.get_report_data_by_result_id(db, payload.query_id) - if not raw: - raise HTTPException(status_code=404, detail="Query result not found") + + # 2. 校验数据源 (Query Result) 是否存在 + result_exists = await db.get(QueryResult, payload.query_id) + if not result_exists: + raise HTTPException(status_code=404, detail="Query result (Data Source) not found") + + # 3. 创建 AnalysisReport 对象 + new_report = AnalysisReport( + project_id=project_id, + result_id=payload.query_id, + name=payload.report_name, + chart_type=payload.chart_type, + description=payload.description, + chart_config=payload.chartConfig.model_dump() if payload.chartConfig else None + ) - query_result, statement = raw + db.add(new_report) + await db.commit() + await db.refresh(new_report) + + # 4. 获取关联 SQL 文本用于返回 + stmt_obj = await db.get(AIGeneratedStatement, result_exists.statement_id) return schemas.Report( - id=str(query_result.result_id), + id=str(new_report.report_id), projectId=str(project_id), - name=payload.report_name, - type=payload.chart_type, - description=query_result.data_summary, - data=query_result.result_data, + name=new_report.name, + type=new_report.chart_type, + description=new_report.description, + data=result_exists.result_data, chartConfig=payload.chartConfig, - sourceQueryText=statement.sql_text, - updatedAt=query_result.cached_at.isoformat() if query_result.cached_at else "" + sourceQueryText=stmt_obj.sql_text if stmt_obj else "", + updatedAt=new_report.created_at.isoformat() ) +# -------------------------- +# 4. 删除报表 (Delete) +# -------------------------- +async def delete_report_service(db: Session, report_id: int) -> bool: + stmt = delete(AnalysisReport).where(AnalysisReport.report_id == report_id) + result = await db.execute(stmt) + await db.commit() + + if result.rowcount == 0: + return False + return True + # -------------------------- -# 6. 导出报表 +# 5. 修改报表 (Update) # -------------------------- -async def export_report_service(db: Session, report_id: int, format: str): +async def update_report_service(db: Session, report_id: int, payload: schemas.ReportUpdate): + # 1. 检查是否存在 + report_obj = await db.get(AnalysisReport, report_id) + if not report_obj: + raise HTTPException(status_code=404, detail="Report not found") + + # 2. 更新字段 + if payload.report_name is not None: + report_obj.name = payload.report_name + if payload.chart_type is not None: + report_obj.chart_type = payload.chart_type + if payload.description is not None: + report_obj.description = payload.description + if payload.chartConfig is not None: + report_obj.chart_config = payload.chartConfig.model_dump() + + await db.commit() + await db.refresh(report_obj) + + # 3. 组装返回数据 + result_obj = await db.get(QueryResult, report_obj.result_id) + stmt_obj = await db.get(AIGeneratedStatement, result_obj.statement_id) + + c_config = None + if report_obj.chart_config: + config_dict = report_obj.chart_config if isinstance(report_obj.chart_config, dict) else {} + if config_dict.get('xAxisKey') and config_dict.get('yAxisKey'): + c_config = schemas.ChartConfig(**config_dict) + + return schemas.Report( + id=str(report_obj.report_id), + projectId=str(report_obj.project_id), + name=report_obj.name, + type=report_obj.chart_type, + description=report_obj.description, + data=result_obj.result_data, + chartConfig=c_config, + sourceQueryText=stmt_obj.sql_text, + updatedAt=report_obj.updated_at.isoformat() if report_obj.updated_at else "" + ) - if format not in ("png", "pdf"): - raise HTTPException(status_code=400, detail="format must be png or pdf") +# -------------------------- +# 6. 导出报表 (Export) +# -------------------------- +async def export_report_service(db: Session, report_id: int, format: str): # 检查 report 是否存在 - raw = await crud_report.get_report_data_by_result_id(db, report_id) - if not raw: + report_obj = await db.get(AnalysisReport, report_id) + if not report_obj: raise HTTPException(status_code=404, detail="Report not found") - # 正常返回 return { "download_url": f"https://fake-cdn.example.com/reports/{report_id}.{format}", "expires_at": "2025-12-01T15:00:00Z" - } - + } \ No newline at end of file