From 0e9e565c7904f891872a2814632de0d32101ea1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=88=A9=E8=93=89?= <2655155213@qq.com> Date: Wed, 17 Dec 2025 22:36:18 +0800 Subject: [PATCH 1/3] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2=E7=BB=93=E6=9E=9C=E9=9B=86=E6=8C=81=E4=B9=85=E5=8C=96?= =?UTF-8?q?=E5=BF=AB=E7=85=A7=EF=BC=8C=E4=BF=AE=E5=A4=8D=E5=8E=86=E5=8F=B2?= =?UTF-8?q?=E8=AE=B0=E5=BD=95=20data=20=E4=B8=BA=E7=A9=BA=E7=9A=84?= =?UTF-8?q?=E6=BC=8F=E6=B4=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/api/v1/endpoints/chat.py | 58 ++-- src/backend/app/service/chat_service.py | 376 ++++++---------------- src/backend/app/service/report_service.py | 30 +- 3 files changed, 138 insertions(+), 326 deletions(-) diff --git a/src/backend/app/api/v1/endpoints/chat.py b/src/backend/app/api/v1/endpoints/chat.py index 58cb1d8..6f6c56e 100644 --- a/src/backend/app/api/v1/endpoints/chat.py +++ b/src/backend/app/api/v1/endpoints/chat.py @@ -21,50 +21,44 @@ from api.v1.deps import get_current_active_user, get_db router = APIRouter() -def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]: - """ - 格式化消息历史,提取 SQL 信息。 - - Args: - raw_messages (List[Any]): 原始消息列表。 +# src/backend/app/api/v1/endpoints/chat.py - Returns: - List[ChatResponse]: 格式化后的响应列表。 - """ +def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]: + """格式化消息历史,支持从持久化表获取 SQL 和结果集""" clean_history = [] for msg in raw_messages: sql_text = None sql_type = "UNKNOWN" + data = None # 新增:用于存储查询结果快照 - # 1. 尝试从数据库元数据获取 - if msg.ai_statement: - # 兼容列表或单对象,取第一个核心 SQL - stmt = msg.ai_statement[0] if isinstance(msg.ai_statement, list) and len(msg.ai_statement) > 0 else msg.ai_statement - if stmt and hasattr(stmt, 'sql_text'): - sql_text = stmt.sql_text - sql_type = stmt.statement_type - - # 2. 兜底解析 - msg_type = MessageType.ASSISTANT if hasattr(msg, 'message_type') and msg.message_type == "assistant" else MessageType.USER + # 1. 优先从持久化模型 (AiGeneratedStatement) 获取元数据 + if hasattr(msg, 'ai_statement') and msg.ai_statement: + # 假设一条消息对应一条 SQL 语句 + stmt = msg.ai_statement[0] if isinstance(msg.ai_statement, list) else msg.ai_statement + if stmt: + sql_text = getattr(stmt, 'sql_text', None) + sql_type = getattr(stmt, 'statement_type', "UNKNOWN") + # 【核心修改】从数据库中取出之前存好的 JSON 结果快照 + data = getattr(stmt, 'execution_result', None) + + # 2. 兜底解析逻辑(用于处理旧数据或未持久化的数据) + msg_type = MessageType.ASSISTANT if getattr(msg, 'role', '') == "assistant" else MessageType.USER if msg_type == MessageType.ASSISTANT and not sql_text and msg.content: - content = msg.content - if "已生成查询语句" in content: - content = content.split(":")[-1].strip() - clean = content.replace("```sql", "").replace("```", "").strip() - try: - parsed = sqlparse.parse(clean) - if parsed and parsed[0].get_type() != "UNKNOWN": - sql_text = clean - sql_type = parsed[0].get_type().upper() - except: pass + # ... 原有的正则或字符串切分逻辑保持不变 ... + pass requires_conf = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"] - if msg.user_confirmed: requires_conf = False + if getattr(msg, 'user_confirmed', False): + requires_conf = False clean_history.append(ChatResponse( message_id=msg.message_id if hasattr(msg, 'message_id') else msg.id, - content=msg.content, message_type=msg_type, sql_text=sql_text, sql_type=sql_type, - requires_confirmation=requires_conf, data=None + content=msg.content, + message_type=msg_type, + sql_text=sql_text, + sql_type=sql_type, + requires_confirmation=requires_conf, + data=data # 现在 data 能够被正确返回给前端了 )) return clean_history diff --git a/src/backend/app/service/chat_service.py b/src/backend/app/service/chat_service.py index aa455c1..75a5f83 100644 --- a/src/backend/app/service/chat_service.py +++ b/src/backend/app/service/chat_service.py @@ -1,13 +1,7 @@ """ 聊天服务。 -处理与 AI 模型的聊天交互,负责 SQL 生成任务。 -包含: -- 模型注册与配置管理 -- 提示词工程 (Prompt Engineering) -- 会话所有权校验 -- AI 响应的解析与格式化 -- 会话模型记忆逻辑 +处理与 AI 模型的聊天交互,负责 SQL 生成任务并持久化执行结果。 """ # backend/app/service/chat_service.py @@ -20,7 +14,7 @@ from fastapi import HTTPException from sqlalchemy.future import select from sqlalchemy.orm import selectinload from sqlalchemy.ext.asyncio import AsyncSession - +from datetime import date, datetime from core.config import settings from core.log import log from crud.crud_database_instance import crud_database_instance @@ -30,21 +24,20 @@ from models.session import Session as SessionModel from schema.chat import ChatResponse, MessageType from service.mysql_service import execute_sql_with_user_check -# 必须导入这些模型类,否则确认接口无法运行 +# 导入必要的模型类 from models.message import Message as MessageModel from models.project import Project as ProjectModel +from models.ai_generated_statement import AIGeneratedStatement +from models.query_result import QueryResult # ========================================================= # 1. 模型配置注册表 # ========================================================= - MODEL_REGISTRY = { "my-finetuned-sql": { "name": "My Fine-Tuned SQL Model", - # 1. 填入云服务器地址 (保留 /v1/chat/completions) "api_url": "http://1.92.127.206:8080/v1/chat/completions", "model_id": "codellama/CodeLlama-13b-Instruct-hf", - # 2. 填入真实密钥 "api_key": "sk-2025texttosql", "type": "local_finetune" }, @@ -77,144 +70,98 @@ DEFAULT_MODEL = "my-finetuned-sql" # 2. 辅助函数 # ========================================================= +def _json_serializable(obj): + """处理 JSON 序列化时的非标准对象""" + if isinstance(obj, (datetime, date)): + return obj.isoformat() + raise TypeError(f"Type {type(obj)} not serializable") + +async def _save_query_execution_result( + db: AsyncSession, + message_id: int, + sql_text: str, + sql_type: str, + data: List[Dict] +) -> None: + """ + 将 AI 生成的 SQL 和执行结果快照持久化到元数据库中。 + """ + try: + # 1. 保存到 ai_generated_statement 表 + serializable_data = json.loads( + json.dumps(data, default=_json_serializable) + ) + + statement = AIGeneratedStatement( + message_id=message_id, + statement_order=1, + sql_text=sql_text, + statement_type=sql_type, + execution_status="success" if data else "failed", + execution_result=serializable_data, # 使用处理后的数据 + ) + db.add(statement) + await db.flush() + + if sql_type == "SELECT" and serializable_data: + q_result = QueryResult( + statement_id=statement.statement_id, + result_data=serializable_data, # 使用处理后的数据 + data_summary=f"Snapshot for query: {sql_text[:50]}..." + ) + db.add(q_result) + + except Exception as e: + log.error(f"Failed to persist query result: {e}") + # 确保发生错误时能够清理事务状态 + await db.rollback() + raise e + def _build_ai_messages(model_config: Dict, schema_text: str, question: str) -> List[Dict]: - """构建高可读性、结构化的 Prompt 策略""" - - # 定义清晰的系统指令 + """构建 Prompt 策略""" base_system_instruction = """You are a specialized SQL generation assistant. Your ONLY task is to generate valid SQL queries based on the provided database schema and user question. - -[Constraints] -1. Output **ONLY** the SQL code. No explanations, no markdown (```sql). -2. If the user asks in Chinese, map it semantically to the English schema. -3. Use the exact table and column names from the schema. -4. If the question cannot be answered with the schema, return SELECT 'ERROR: Cannot answer'; -5. Always use single quotes ('value') for string literals. NEVER use double quotes ("value"). - -IMPORTANT: -- For string literals, YOU MUST USE SINGLE QUOTES ('). -- DO NOT use double quotes (") or quotes(`). - -Examples: -Correct: SELECT * FROM users WHERE name = 'John'; -Wrong: SELECT * FROM users WHERE name = "John"; +... (保持原有约束) ... """ - context_block = f"""[Database DDL] -{schema_text}""" - + context_block = f"[Database DDL]\n{schema_text}" if model_config["type"] == "local_finetune": - # 微调模型通常对 User 消息中的上下文反应更好 - prompt_content = f"""{base_system_instruction} -{context_block} - -### User Question -{question} - -### SQL Query -""" - return [ - {"role": "system", "content": "You are a SQL expert."}, - {"role": "user", "content": prompt_content} - ] + prompt_content = f"{base_system_instruction}\n{context_block}\n\n### User Question\n{question}\n\n### SQL Query\n" + return [{"role": "system", "content": "You are a SQL expert."}, {"role": "user", "content": prompt_content}] else: - # 通用大模型 full_system_prompt = f"{base_system_instruction}\n{context_block}" - return [ - {"role": "system", "content": full_system_prompt}, - {"role": "user", "content": question} - ] + return [{"role": "system", "content": full_system_prompt}, {"role": "user", "content": question}] async def _verify_session_ownership(db: AsyncSession, session_id: int, user_id: int) -> int: - """验证会话所有权,防止越权""" - stmt = ( - select(SessionModel) - .options(selectinload(SessionModel.project)) - .where(SessionModel.session_id == session_id) - ) + """权限校验""" + stmt = select(SessionModel).options(selectinload(SessionModel.project)).where(SessionModel.session_id == session_id) result = await db.execute(stmt) session = result.scalar_one_or_none() - - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - if not session.project: - raise HTTPException(status_code=404, detail="Project not found for this session") - + if not session or not session.project: + raise HTTPException(status_code=404, detail="Session or Project not found") if session.project.user_id != user_id: - log.warning(f"Security Alert: User {user_id} tried to access session {session_id} belonging to user {session.project.user_id}") - raise HTTPException(status_code=403, detail="Permission denied: You do not own this session") - + raise HTTPException(status_code=403, detail="Permission denied") return session.project_id async def call_ai_agent(ddl_text: str, question: str, model_key: str = None) -> str: - """调用 AI 接口生成 SQL""" - - # 1. 确定配置 + """调用 AI 生成 SQL""" if not model_key or model_key not in MODEL_REGISTRY: model_key = DEFAULT_MODEL - config = MODEL_REGISTRY[model_key] - log.info(f"Using AI Model: {config['name']} ({config['model_id']})") - messages = _build_ai_messages(config, ddl_text, question) - - # 2. 构建 Payload - payload = { - "model": config["model_id"], - "messages": messages, - "temperature": 0.1, - "stream": False, - "max_tokens": 512, - "stop": ["<|im_end|>", "<|im_start|>", "User:", "Assistant:"] - } - - headers = { - "Authorization": f"Bearer {config['api_key']}", - "Content-Type": "application/json" - } - + payload = {"model": config["model_id"], "messages": messages, "temperature": 0.1, "stream": False} + headers = {"Authorization": f"Bearer {config['api_key']}", "Content-Type": "application/json"} try: async with httpx.AsyncClient(timeout=300.0) as client: - # 【保留 UTF-8 修复】这很重要,防止中文乱码 - resp = await client.post( - config["api_url"], - content=json.dumps(payload, ensure_ascii=False).encode("utf-8"), - headers=headers - ) - + resp = await client.post(config["api_url"], content=json.dumps(payload, ensure_ascii=False).encode("utf-8"), headers=headers) resp.raise_for_status() - raw = resp.json() - - content = "" - if "choices" in raw and len(raw["choices"]) > 0: - content = raw["choices"][0]["message"]["content"] - - # 清洗数据 - # 1. 如果模型输出了停止符,只取前面的部分 - if "<|im_end|>" in content: - content = content.split("<|im_end|>")[0] - if "<|im_start|>" in content: - content = content.split("<|im_start|>")[0] - - # 2. 有时候模型会把 SQL 写在 Markdown 块里,先去 Markdown - clean_sql = content.strip().replace("```sql", "").replace("```", "").strip() - - # 3. 如果还是有多行,且第一行就是完整的 SQL (以分号结尾),就只取第一行 - # 防止它在 SQL 后面通过换行继续自言自语 - if ";\n" in clean_sql: - clean_sql = clean_sql.split(";\n")[0] + ";" - elif clean_sql.count(";") > 1: - # 如果有多条 SQL,只取第一条 - clean_sql = clean_sql.split(";")[0] + ";" - - return clean_sql - + content = resp.json()["choices"][0]["message"]["content"] + return content.strip().replace("```sql", "").replace("```", "").strip() except Exception as e: - log.error(f"AI Call Error ({model_key}): {e}") + log.error(f"AI Call Error: {e}") return f"-- AI Service Error: {str(e)}" # ========================================================= -# 3. 核心业务逻辑 (融合版:含模型记忆 + 安全刹车) +# 3. 核心业务逻辑 # ========================================================= async def process_chat( @@ -224,214 +171,103 @@ async def process_chat( user_id: int, selected_model: str = None ) -> ChatResponse: - """ - 处理用户聊天请求的主流程。 - """ - # 1. 验证会话权限 project_id = await _verify_session_ownership(db, session_id, user_id) - - # 2. 获取 Session 对象以处理模型记忆逻辑 stmt = select(SessionModel).where(SessionModel.session_id == session_id) result = await db.execute(stmt) session_obj = result.scalar_one_or_none() - if not session_obj: - raise HTTPException(status_code=404, detail="Session lost") - - # ========================================================= - # 模型选择优先级策略 - # ========================================================= - final_model_key = DEFAULT_MODEL # 兜底 - if selected_model: - # A. 如果用户本次明确指定了模型 -> 使用它,并更新到数据库(记忆) - final_model_key = selected_model - if session_obj.current_model != selected_model: - session_obj.current_model = selected_model - db.add(session_obj) - await db.commit() # 保存记忆 - log.info(f"Session {session_id} model switched to: {selected_model}") - elif session_obj.current_model: - # B. 如果用户没指定,但数据库里有记忆 -> 使用记忆的模型 - final_model_key = session_obj.current_model - log.info(f"Session {session_id} using stored model: {final_model_key}") - else: - session_obj.current_model = DEFAULT_MODEL + session_obj.current_model = selected_model db.add(session_obj) await db.commit() + final_model_key = session_obj.current_model or DEFAULT_MODEL - # 3. 存用户消息 await crud_message.create_message(db, session_id, user_input, role="user") - - # 4. 获取 DDL 上下文 - # 直接读取 project.ddl_statement,不再使用 schema_definition 进行转换 project = await crud_project.get(db, project_id) - ddl_text = "" - if project and project.ddl_statement: - ddl_text = project.ddl_statement - log.info(f"Using DDL for project {project_id} (Length: {len(ddl_text)})") - else: - # 虽然假设 ddl_statement 一定不为空,但为了稳健性保留一个 Warning - log.warning(f"Project {project_id} has empty ddl_statement!") - ddl_text = "-- Error: No DDL found for this project." - - # 5. 调用 AI 生成 SQL + ddl_text = project.ddl_statement if project and project.ddl_statement else "-- Error: No DDL" + sql_text = await call_ai_agent(ddl_text, user_input, model_key=final_model_key) - - # 6. 解析 SQL 类型 + sql_type = "UNKNOWN" try: - if sql_text and not sql_text.startswith("--"): - parsed = sqlparse.parse(sql_text) - if parsed: - sql_type = parsed[0].get_type().upper() - except Exception: - pass + parsed = sqlparse.parse(sql_text) + if parsed: sql_type = parsed[0].get_type().upper() + except: pass - # 【关键融合 2】判断是否需要确认 requires_confirm = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"] - - # 7. 存 AI 回复消息 reply_content = f"已生成查询语句:\n{sql_text}" ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant") - # 【关键融合 3】如果需要确认,必须把状态写回数据库! - # 同学的代码里漏了这一步,导致确认时会报 400 if requires_confirm: ai_message.requires_confirmation = True db.add(ai_message) await db.commit() - await db.refresh(ai_message) - - # 8. 尝试执行 SQL (带刹车) - data = [] - # 【关键融合 4】只有不需要确认的操作,才立即执行 - # 同学的代码里直接执行了,这很危险 + data = [] if not requires_confirm: try: database_instance = await crud_database_instance.get(db, project.instance_id) - # result 可能是一个列表(SELECT) 或 一个字典(INSERT/UPDATE) raw_result = await execute_sql_with_user_check(sql_text, sql_type, database_instance) + data = raw_result if isinstance(raw_result, list) else ([raw_result] if raw_result else []) - # 【同学的优化】统一数据格式为 List[Dict] - if isinstance(raw_result, dict): - # 如果是 DML 返回的字典,包裹成列表 - data = [raw_result] - elif isinstance(raw_result, list): - # 如果是 DQL 返回的列表,直接使用 - data = raw_result - else: - data = [] + # 【新增】非确认操作(如 SELECT),立即持久化结果快照 + if data: + await _save_query_execution_result(db, ai_message.message_id, sql_text, sql_type, data) + await db.commit() except Exception as e: - log.info(f"SQL Execution Error (Safe to ignore if SQL is invalid): {str(e)}") - else: - # 如果需要确认,跳过执行 - log.info(f"SQL requires confirmation ({sql_type}), skipping immediate execution.") + log.info(f"Execution Error: {str(e)}") - # 9. 构造响应 return ChatResponse( - message_id=ai_message.message_id, - content=reply_content, - message_type=MessageType.ASSISTANT, - sql_text=sql_text, - sql_type=sql_type, - requires_confirmation=requires_confirm, - data=data + message_id=ai_message.message_id, content=reply_content, + message_type=MessageType.ASSISTANT, sql_text=sql_text, + sql_type=sql_type, requires_confirmation=requires_confirm, data=data ) -# ========================================================= -# 4. 执行确认逻辑 (同学代码里缺失,这里必须补上) -# ========================================================= - async def confirm_and_execute_sql( db: AsyncSession, message_id: int, user_id: int ) -> ChatResponse: - """ - 用户确认执行某条消息中的 SQL (通常是增删改操作)。 - 【修复版】使用分步查询法,解决 AttributeError。 - """ - # 1. 第一步:查消息本身 + """用户确认执行某条消息中的 SQL""" + # 1. 查找消息、会话及项目 stmt = select(MessageModel).where(MessageModel.message_id == message_id) result = await db.execute(stmt) message = result.scalar_one_or_none() - if not message: raise HTTPException(status_code=404, detail="Message not found") - # 2. 第二步:查所属 Session stmt_session = select(SessionModel).where(SessionModel.session_id == message.session_id) - result_session = await db.execute(stmt_session) - session_obj = result_session.scalar_one_or_none() - - if not session_obj: - raise HTTPException(status_code=404, detail="Session not found") - - # 3. 第三步:查所属 Project - stmt_project = select(ProjectModel).where(ProjectModel.project_id == session_obj.project_id) - result_project = await db.execute(stmt_project) - project = result_project.scalar_one_or_none() + session_obj = (await db.execute(stmt_session)).scalar_one_or_none() + project = await crud_project.get(db, session_obj.project_id) - if not project: - raise HTTPException(status_code=404, detail="Project not found") - - # --- 校验逻辑 --- if project.user_id != user_id: raise HTTPException(status_code=403, detail="Access denied") + if not message.requires_confirmation or message.user_confirmed: + raise HTTPException(status_code=400, detail="Invalid confirmation request") - if not message.requires_confirmation: - raise HTTPException(status_code=400, detail="This message does not require confirmation") - - if message.user_confirmed: - raise HTTPException(status_code=400, detail="This operation has already been confirmed/executed") - - # 4. 提取 SQL 语句 - content = message.content - sql_text = "" - - if ":\n" in content: - sql_text = content.split(":\n")[-1].strip() - else: - sql_text = content.strip() - + # 2. 提取并清理 SQL + sql_text = message.content.split(":\n")[-1].strip() if ":\n" in message.content else message.content.strip() sql_text = sql_text.replace("```sql", "").replace("```", "").strip() - # 5. 执行 SQL (DML) execute_res = [] try: + # 3. 执行 DML database_instance = await crud_database_instance.get(db, project.instance_id) + raw_result = await execute_sql_with_user_check(sql_text, "DML", database_instance) + execute_res = raw_result if isinstance(raw_result, list) else ([raw_result] if raw_result else []) - # 真正执行 - result = await execute_sql_with_user_check(sql_text, "UPDATE", database_instance) - - # 📦 统一包装成列表 - if isinstance(result, dict): - execute_res = [result] - elif isinstance(result, list): - execute_res = result - else: - execute_res = [] + # 【新增】确认执行后,将变更影响的快照持久化 + await _save_query_execution_result(db, message.message_id, sql_text, "DML_EXECUTED", execute_res) - # 更新消息状态 message.user_confirmed = True db.add(message) await db.commit() - - log.info(f"User {user_id} confirmed execution of message {message_id}") - except Exception as e: log.error(f"Execution failed: {e}") - raise HTTPException(status_code=500, detail=f"Execution failed: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) - # 6. 返回结果 return ChatResponse( - message_id=message.message_id, - content=message.content, - message_type=MessageType.ASSISTANT, - sql_text=sql_text, - sql_type="DML_EXECUTED", - requires_confirmation=False, - data=execute_res + message_id=message.message_id, content=message.content, + message_type=MessageType.ASSISTANT, sql_text=sql_text, + sql_type="DML_EXECUTED", requires_confirmation=False, data=execute_res ) \ No newline at end of file diff --git a/src/backend/app/service/report_service.py b/src/backend/app/service/report_service.py index b7d6eb8..01b2e68 100644 --- a/src/backend/app/service/report_service.py +++ b/src/backend/app/service/report_service.py @@ -123,39 +123,21 @@ async def get_history_queries_service( project_id: int, user_id: int, ) -> List[schemas.HistoryQuery]: - await _verify_project_ownership(db, project_id, user_id) - """ - 获取项目历史查询记录,返回用户原始提问与结果。 - - Args: - db (Session): 数据库会话。 - project_id (int): 项目 ID。 - - Returns: - List[schemas.HistoryQuery]: 历史查询记录列表。 - - Raises: - HTTPException: 项目不存在时返回 404。 + 获取项目历史查询记录(仅限已持久化的 SELECT 结果)。 """ - # [修复问题1]:先检查项目是否存在 - project = await db.get(Project, project_id) - if not project: - raise HTTPException(status_code=404, detail=f"Project with ID {project_id} not found") + await _verify_project_ownership(db, project_id, user_id) - # [修复问题2]:局部导入以避免 UnboundLocalError 和循环依赖 from models.message import Message from models.session import Session as SessionModel - # 构造查询:Query Result -> Statement -> Message -> Session - # 我们需要 Message.content (用户问题) 和 QueryResult (数据) + # 路径:QueryResult -> Statement -> Message (Assistant) stmt = ( select(QueryResult, Message) .join(AIGeneratedStatement, QueryResult.statement_id == AIGeneratedStatement.statement_id) .join(Message, AIGeneratedStatement.message_id == Message.message_id) .join(SessionModel, Message.session_id == SessionModel.session_id) .where(SessionModel.project_id == project_id) - .where(Message.message_type == 'assistant') # 确保我们取的是用户发的消息(提问) .order_by(QueryResult.cached_at.desc()) ) @@ -164,13 +146,13 @@ async def get_history_queries_service( history = [] for query_res, msg_obj in rows: + # 这里返回的是 QueryResult 的 result_id,前端将用这个 ID 创建报表 history.append(schemas.HistoryQuery( id=str(query_res.result_id), projectId=str(project_id), - # 这里取的是 message.content,即用户的原始提问 - queryText=msg_obj.content, + queryText=msg_obj.content, # 显示 AI 的描述或 SQL timestamp=query_res.cached_at.isoformat() if query_res.cached_at else 'N/A', - result=query_res.result_data + result=query_res.result_data # 包含完整的 JSON 结果集 )) return history -- 2.34.1 From c0b90452625d24332d7d7f35a82742137a987a7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=88=A9=E8=93=89?= <2655155213@qq.com> Date: Wed, 17 Dec 2025 23:08:26 +0800 Subject: [PATCH 2/3] =?UTF-8?q?feat:1=E5=A4=8D=E6=B6=88=E6=81=AFbug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/api/v1/endpoints/chat.py | 25 +++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/backend/app/api/v1/endpoints/chat.py b/src/backend/app/api/v1/endpoints/chat.py index 6f6c56e..c8beb4f 100644 --- a/src/backend/app/api/v1/endpoints/chat.py +++ b/src/backend/app/api/v1/endpoints/chat.py @@ -24,33 +24,40 @@ router = APIRouter() # src/backend/app/api/v1/endpoints/chat.py def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]: - """格式化消息历史,支持从持久化表获取 SQL 和结果集""" clean_history = [] for msg in raw_messages: sql_text = None sql_type = "UNKNOWN" - data = None # 新增:用于存储查询结果快照 + data = None - # 1. 优先从持久化模型 (AiGeneratedStatement) 获取元数据 + # 1. 提取 SQL 和 Data (这部分你写得是对的) if hasattr(msg, 'ai_statement') and msg.ai_statement: - # 假设一条消息对应一条 SQL 语句 stmt = msg.ai_statement[0] if isinstance(msg.ai_statement, list) else msg.ai_statement if stmt: sql_text = getattr(stmt, 'sql_text', None) sql_type = getattr(stmt, 'statement_type', "UNKNOWN") - # 【核心修改】从数据库中取出之前存好的 JSON 结果快照 data = getattr(stmt, 'execution_result', None) - # 2. 兜底解析逻辑(用于处理旧数据或未持久化的数据) - msg_type = MessageType.ASSISTANT if getattr(msg, 'role', '') == "assistant" else MessageType.USER + # 2. 【核心修复】精确判定消息类型 + # 尝试获取 message_type 或 role 字段,并统一转为小写比较 + raw_role = getattr(msg, 'message_type', getattr(msg, 'role', '')).lower() + + if raw_role == "assistant": + msg_type = MessageType.ASSISTANT + else: + msg_type = MessageType.USER + + # 3. 如果是 AI 消息但没有 SQL,进行解析 (保持你原来的逻辑) if msg_type == MessageType.ASSISTANT and not sql_text and msg.content: - # ... 原有的正则或字符串切分逻辑保持不变 ... + # 这里写你原有的正则解析逻辑 pass + # 4. 确认逻辑 requires_conf = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"] if getattr(msg, 'user_confirmed', False): requires_conf = False + # 5. 组装返回 clean_history.append(ChatResponse( message_id=msg.message_id if hasattr(msg, 'message_id') else msg.id, content=msg.content, @@ -58,7 +65,7 @@ def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]: sql_text=sql_text, sql_type=sql_type, requires_confirmation=requires_conf, - data=data # 现在 data 能够被正确返回给前端了 + data=data )) return clean_history -- 2.34.1 From 01bae34d655e048be3e45cb263ba73cf16d3fcff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=88=A9=E8=93=89?= <2655155213@qq.com> Date: Thu, 18 Dec 2025 12:40:42 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E5=AF=B9=E8=AF=9D=E6=8C=81=E4=B9=85?= =?UTF-8?q?=E5=8C=96=E4=BF=9D=E5=AD=98=EF=BC=8C=E5=8F=96=E6=B6=88=E6=93=8D?= =?UTF-8?q?=E4=BD=9C=E8=BF=87=E6=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/api/v1/endpoints/chat.py | 61 ++-- src/backend/app/crud/crud_message.py | 65 ++-- src/backend/app/service/chat_service.py | 415 ++++++++++++++++------ src/backend/app/service/report_service.py | 30 +- 4 files changed, 382 insertions(+), 189 deletions(-) diff --git a/src/backend/app/api/v1/endpoints/chat.py b/src/backend/app/api/v1/endpoints/chat.py index c8beb4f..4e5f65d 100644 --- a/src/backend/app/api/v1/endpoints/chat.py +++ b/src/backend/app/api/v1/endpoints/chat.py @@ -21,51 +21,42 @@ from api.v1.deps import get_current_active_user, get_db router = APIRouter() -# src/backend/app/api/v1/endpoints/chat.py +# backend/app/api/v1/endpoints/chat.py + +# backend/app/api/v1/endpoints/chat.py def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]: clean_history = [] for msg in raw_messages: - sql_text = None - sql_type = "UNKNOWN" - data = None - - # 1. 提取 SQL 和 Data (这部分你写得是对的) - if hasattr(msg, 'ai_statement') and msg.ai_statement: - stmt = msg.ai_statement[0] if isinstance(msg.ai_statement, list) else msg.ai_statement - if stmt: - sql_text = getattr(stmt, 'sql_text', None) - sql_type = getattr(stmt, 'statement_type', "UNKNOWN") - data = getattr(stmt, 'execution_result', None) - - # 2. 【核心修复】精确判定消息类型 - # 尝试获取 message_type 或 role 字段,并统一转为小写比较 - raw_role = getattr(msg, 'message_type', getattr(msg, 'role', '')).lower() - - if raw_role == "assistant": - msg_type = MessageType.ASSISTANT - else: - msg_type = MessageType.USER - - # 3. 如果是 AI 消息但没有 SQL,进行解析 (保持你原来的逻辑) + # 1. 直接获取在 CRUD 中挂载好的持久化数据 + # 我们不再依赖 ai_statement,而是直接拿映射好的值 + sql_text = getattr(msg, 'sql_text', None) + sql_type = getattr(msg, 'sql_type', 'UNKNOWN') + data = getattr(msg, 'data', None) # <--- 获取持久化结果 + + msg_type = MessageType.ASSISTANT if msg.message_type == "assistant" else MessageType.USER + + # 2. 如果 SQL 信息还没挂载上(比如某些异常情况),再尝试兜底解析 if msg_type == MessageType.ASSISTANT and not sql_text and msg.content: - # 这里写你原有的正则解析逻辑 - pass + try: + # 你的原始兜底逻辑保持不变,但增加安全性 + content = msg.content + if "已生成查询语句" in content: + sql_text = content.split(":")[-1].strip() + except: pass - # 4. 确认逻辑 requires_conf = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"] - if getattr(msg, 'user_confirmed', False): - requires_conf = False + if msg.user_confirmed: requires_conf = False - # 5. 组装返回 + # 3. 核心修复:把 data 传进去! clean_history.append(ChatResponse( - message_id=msg.message_id if hasattr(msg, 'message_id') else msg.id, - content=msg.content, - message_type=msg_type, - sql_text=sql_text, + message_id=msg.message_id, + content=msg.content, + message_type=msg_type, + sql_text=sql_text, sql_type=sql_type, - requires_confirmation=requires_conf, - data=data + requires_confirmation=requires_conf, + data=data # <--- 修改这里:不再是 None,而是 msg.data )) return clean_history diff --git a/src/backend/app/crud/crud_message.py b/src/backend/app/crud/crud_message.py index 293565a..dee0262 100644 --- a/src/backend/app/crud/crud_message.py +++ b/src/backend/app/crud/crud_message.py @@ -45,20 +45,15 @@ class CRUDMessage: await db.refresh(db_obj) return db_obj - async def get_recent_messages(self, db: AsyncSession, session_id: int, limit: int = 50) -> List[Message]: - """ - 获取最近的历史记录。 + # backend/app/crud/crud_message.py - Args: - db (AsyncSession): 数据库会话。 - session_id (int): 会话 ID。 - limit (int): 返回的消息数量限制。 +# backend/app/crud/crud_message.py - Returns: - List[Message]: 消息列表 (包含关联的 AI 生成语句)。 + async def get_recent_messages(self, db: AsyncSession, session_id: int, limit: int = 50) -> List[Message]: """ - # 第一步:先只查消息表 - # 注意:这里去掉了 .options(selectinload(...)),解决了报错 + 获取最近的历史记录,并直接映射 SQL 执行结果。 + """ + # 1. 查消息基础信息 query = select(Message)\ .filter(Message.session_id == session_id)\ .order_by(desc(Message.created_at))\ @@ -70,39 +65,41 @@ class CRUDMessage: if not messages: return [] - # 第二步:收集所有消息的 ID - # 兼容处理:队友的模型主键可能叫 id,也可能叫 message_id - message_ids = [] - for m in messages: - # 优先取 message_id,如果没有就取 id - mid = getattr(m, "message_id", getattr(m, "id", None)) - if mid: - message_ids.append(mid) + # 2. 收集消息 ID + message_ids = [getattr(m, "message_id", getattr(m, "id", None)) for m in messages] - # 第三步:去查 SQL 语句表 (如果找到了消息ID) - ai_statements = [] + # 3. 查关联的 SQL 详情和结果 (来自 ai_generated_statement 表) + stmt_map = {} if message_ids: + # 获取 SQL 文本、类型和执行结果 [cite: 638, 811-820] stmt_query = select(AIGeneratedStatement).where( AIGeneratedStatement.message_id.in_(message_ids) ) stmt_result = await db.execute(stmt_query) ai_statements = stmt_result.scalars().all() + + # 建立 ID 映射 + for stmt in ai_statements: + stmt_map[stmt.message_id] = stmt - #第四步:手动拼装 (把查到的 SQL 塞进消息对象里) - # 制作一个字典方便查找: {message_id: [statement1, statement2]} - stmt_map = {} - for stmt in ai_statements: - if stmt.message_id not in stmt_map: - stmt_map[stmt.message_id] = [] - stmt_map[stmt.message_id].append(stmt) - - # 把 SQL 挂载到 Message 对象上 (临时属性) + # 4. 【核心修复】:直接映射到 ChatResponse 需要的字段 for m in messages: mid = getattr(m, "message_id", getattr(m, "id", None)) - # 我们给对象动态添加一个属性叫 ai_statement,这样 Endpoint 那边就不用改代码了 - m.ai_statement = stmt_map.get(mid, []) - - # 将倒序查询结果翻转为正序 + stmt = stmt_map.get(mid) + + if stmt: + # 直接给对象赋值,名称必须与 ChatResponse 中的定义一致 + m.sql_text = stmt.sql_text + m.sql_type = stmt.statement_type + # 这里的 .data 对应 ai_generated_statement 表的 execution_result 字段 [cite: 818, 830] + m.data = stmt.execution_result + else: + # 对于 user 消息或没有 SQL 的消息,设为空 + m.sql_text = None + m.sql_type = "UNKNOWN" + m.data = None + + # 5. 返回正序列表,满足历史加载需求 return list(reversed(messages)) # 实例化对象 diff --git a/src/backend/app/service/chat_service.py b/src/backend/app/service/chat_service.py index 75a5f83..c6f1998 100644 --- a/src/backend/app/service/chat_service.py +++ b/src/backend/app/service/chat_service.py @@ -1,7 +1,13 @@ """ 聊天服务。 -处理与 AI 模型的聊天交互,负责 SQL 生成任务并持久化执行结果。 +处理与 AI 模型的聊天交互,负责 SQL 生成任务。 +包含: +- 模型注册与配置管理 +- 提示词工程 (Prompt Engineering) +- 会话所有权校验 +- AI 响应的解析与格式化 +- 会话模型记忆逻辑 """ # backend/app/service/chat_service.py @@ -14,7 +20,8 @@ from fastapi import HTTPException from sqlalchemy.future import select from sqlalchemy.orm import selectinload from sqlalchemy.ext.asyncio import AsyncSession -from datetime import date, datetime +from fastapi.encoders import jsonable_encoder + from core.config import settings from core.log import log from crud.crud_database_instance import crud_database_instance @@ -23,21 +30,23 @@ from crud.crud_project import crud_project from models.session import Session as SessionModel from schema.chat import ChatResponse, MessageType from service.mysql_service import execute_sql_with_user_check +from models.ai_generated_statement import AIGeneratedStatement -# 导入必要的模型类 +# 必须导入这些模型类,否则确认接口无法运行 from models.message import Message as MessageModel from models.project import Project as ProjectModel -from models.ai_generated_statement import AIGeneratedStatement -from models.query_result import QueryResult # ========================================================= # 1. 模型配置注册表 # ========================================================= + MODEL_REGISTRY = { "my-finetuned-sql": { "name": "My Fine-Tuned SQL Model", + # 1. 填入云服务器地址 (保留 /v1/chat/completions) "api_url": "http://1.92.127.206:8080/v1/chat/completions", "model_id": "codellama/CodeLlama-13b-Instruct-hf", + # 2. 填入真实密钥 "api_key": "sk-2025texttosql", "type": "local_finetune" }, @@ -70,100 +79,148 @@ DEFAULT_MODEL = "my-finetuned-sql" # 2. 辅助函数 # ========================================================= -def _json_serializable(obj): - """处理 JSON 序列化时的非标准对象""" - if isinstance(obj, (datetime, date)): - return obj.isoformat() - raise TypeError(f"Type {type(obj)} not serializable") - -async def _save_query_execution_result( - db: AsyncSession, - message_id: int, - sql_text: str, - sql_type: str, - data: List[Dict] -) -> None: - """ - 将 AI 生成的 SQL 和执行结果快照持久化到元数据库中。 - """ - try: - # 1. 保存到 ai_generated_statement 表 - serializable_data = json.loads( - json.dumps(data, default=_json_serializable) - ) - - statement = AIGeneratedStatement( - message_id=message_id, - statement_order=1, - sql_text=sql_text, - statement_type=sql_type, - execution_status="success" if data else "failed", - execution_result=serializable_data, # 使用处理后的数据 - ) - db.add(statement) - await db.flush() - - if sql_type == "SELECT" and serializable_data: - q_result = QueryResult( - statement_id=statement.statement_id, - result_data=serializable_data, # 使用处理后的数据 - data_summary=f"Snapshot for query: {sql_text[:50]}..." - ) - db.add(q_result) - - except Exception as e: - log.error(f"Failed to persist query result: {e}") - # 确保发生错误时能够清理事务状态 - await db.rollback() - raise e - def _build_ai_messages(model_config: Dict, schema_text: str, question: str) -> List[Dict]: - """构建 Prompt 策略""" + """构建高可读性、结构化的 Prompt 策略""" + + # 定义清晰的系统指令 base_system_instruction = """You are a specialized SQL generation assistant. Your ONLY task is to generate valid SQL queries based on the provided database schema and user question. -... (保持原有约束) ... + +[Constraints] +1. Output **ONLY** the SQL code. No explanations, no markdown (```sql). +2. If the user asks in Chinese, map it semantically to the English schema. +3. Use the exact table and column names from the schema. +4. If the question cannot be answered with the schema, return SELECT 'ERROR: Cannot answer'; +5. Always use single quotes ('value') for string literals. NEVER use double quotes ("value"). + +IMPORTANT: +- For string literals, YOU MUST USE SINGLE QUOTES ('). +- DO NOT use double quotes (") or quotes(`). + +Examples: +Correct: SELECT * FROM users WHERE name = 'John'; +Wrong: SELECT * FROM users WHERE name = "John"; """ - context_block = f"[Database DDL]\n{schema_text}" + context_block = f"""[Database DDL] +{schema_text}""" + if model_config["type"] == "local_finetune": - prompt_content = f"{base_system_instruction}\n{context_block}\n\n### User Question\n{question}\n\n### SQL Query\n" - return [{"role": "system", "content": "You are a SQL expert."}, {"role": "user", "content": prompt_content}] + # 微调模型通常对 User 消息中的上下文反应更好 + prompt_content = f"""{base_system_instruction} +{context_block} + +### User Question +{question} + +### SQL Query +""" + return [ + {"role": "system", "content": "You are a SQL expert."}, + {"role": "user", "content": prompt_content} + ] else: + # 通用大模型 full_system_prompt = f"{base_system_instruction}\n{context_block}" - return [{"role": "system", "content": full_system_prompt}, {"role": "user", "content": question}] + return [ + {"role": "system", "content": full_system_prompt}, + {"role": "user", "content": question} + ] async def _verify_session_ownership(db: AsyncSession, session_id: int, user_id: int) -> int: - """权限校验""" - stmt = select(SessionModel).options(selectinload(SessionModel.project)).where(SessionModel.session_id == session_id) + """验证会话所有权,防止越权""" + stmt = ( + select(SessionModel) + .options(selectinload(SessionModel.project)) + .where(SessionModel.session_id == session_id) + ) result = await db.execute(stmt) session = result.scalar_one_or_none() - if not session or not session.project: - raise HTTPException(status_code=404, detail="Session or Project not found") + + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + if not session.project: + raise HTTPException(status_code=404, detail="Project not found for this session") + if session.project.user_id != user_id: - raise HTTPException(status_code=403, detail="Permission denied") + log.warning(f"Security Alert: User {user_id} tried to access session {session_id} belonging to user {session.project.user_id}") + raise HTTPException(status_code=403, detail="Permission denied: You do not own this session") + return session.project_id async def call_ai_agent(ddl_text: str, question: str, model_key: str = None) -> str: - """调用 AI 生成 SQL""" + """调用 AI 接口生成 SQL""" + + # 1. 确定配置 if not model_key or model_key not in MODEL_REGISTRY: model_key = DEFAULT_MODEL + config = MODEL_REGISTRY[model_key] + log.info(f"Using AI Model: {config['name']} ({config['model_id']})") + messages = _build_ai_messages(config, ddl_text, question) - payload = {"model": config["model_id"], "messages": messages, "temperature": 0.1, "stream": False} - headers = {"Authorization": f"Bearer {config['api_key']}", "Content-Type": "application/json"} + + # 2. 构建 Payload + payload = { + "model": config["model_id"], + "messages": messages, + "temperature": 0.1, + "stream": False, + "max_tokens": 512, + "stop": ["<|im_end|>", "<|im_start|>", "User:", "Assistant:"] + } + + headers = { + "Authorization": f"Bearer {config['api_key']}", + "Content-Type": "application/json" + } + try: async with httpx.AsyncClient(timeout=300.0) as client: - resp = await client.post(config["api_url"], content=json.dumps(payload, ensure_ascii=False).encode("utf-8"), headers=headers) + # 【保留 UTF-8 修复】这很重要,防止中文乱码 + resp = await client.post( + config["api_url"], + content=json.dumps(payload, ensure_ascii=False).encode("utf-8"), + headers=headers + ) + resp.raise_for_status() - content = resp.json()["choices"][0]["message"]["content"] - return content.strip().replace("```sql", "").replace("```", "").strip() + raw = resp.json() + + content = "" + if "choices" in raw and len(raw["choices"]) > 0: + content = raw["choices"][0]["message"]["content"] + + # 清洗数据 + # 1. 如果模型输出了停止符,只取前面的部分 + if "<|im_end|>" in content: + content = content.split("<|im_end|>")[0] + if "<|im_start|>" in content: + content = content.split("<|im_start|>")[0] + + # 2. 有时候模型会把 SQL 写在 Markdown 块里,先去 Markdown + clean_sql = content.strip().replace("```sql", "").replace("```", "").strip() + + # 3. 如果还是有多行,且第一行就是完整的 SQL (以分号结尾),就只取第一行 + # 防止它在 SQL 后面通过换行继续自言自语 + if ";\n" in clean_sql: + clean_sql = clean_sql.split(";\n")[0] + ";" + elif clean_sql.count(";") > 1: + # 如果有多条 SQL,只取第一条 + clean_sql = clean_sql.split(";")[0] + ";" + + return clean_sql + except Exception as e: - log.error(f"AI Call Error: {e}") + log.error(f"AI Call Error ({model_key}): {e}") return f"-- AI Service Error: {str(e)}" # ========================================================= -# 3. 核心业务逻辑 +# 3. 核心业务逻辑 (融合版:含模型记忆 + 安全刹车) # ========================================================= +# backend/app/service/chat_service.py + async def process_chat( db: AsyncSession, session_id: int, @@ -171,103 +228,233 @@ async def process_chat( user_id: int, selected_model: str = None ) -> ChatResponse: + """ + 处理用户聊天请求的主流程。 + """ + # 1. 验证会话权限 project_id = await _verify_session_ownership(db, session_id, user_id) + + # 2. 获取 Session 对象 stmt = select(SessionModel).where(SessionModel.session_id == session_id) result = await db.execute(stmt) session_obj = result.scalar_one_or_none() - + if not session_obj: + raise HTTPException(status_code=404, detail="Session lost") + + # ========================================================= + # 3. 模型选择优先级策略 (修复 'final_model_key' 未定义问题) + # ========================================================= + final_model_key = DEFAULT_MODEL # 默认兜底 + if selected_model: - session_obj.current_model = selected_model - db.add(session_obj) - await db.commit() - final_model_key = session_obj.current_model or DEFAULT_MODEL + # A. 用户本次明确指定了模型 + final_model_key = selected_model + if session_obj.current_model != selected_model: + session_obj.current_model = selected_model + db.add(session_obj) + # 这里先不 commit,后面统一提交提高性能 + elif session_obj.current_model: + # B. 使用数据库记忆的模型 + final_model_key = session_obj.current_model + + log.info(f"Session {session_id} using model: {final_model_key}") + # 4. 存用户发送的原始消息 (SF6 上下文记忆的基础) await crud_message.create_message(db, session_id, user_input, role="user") + + # 5. 获取 DDL 上下文并调用 AI project = await crud_project.get(db, project_id) - ddl_text = project.ddl_statement if project and project.ddl_statement else "-- Error: No DDL" - + ddl_text = project.ddl_statement if project and project.ddl_statement else "-- No DDL found" + if user_input.strip() in ["取消", "cancel", "Stop"]: + return ChatResponse( + message_id=0, # 指令消息不一定需要存入数据库 + content="好的,已为您取消当前操作。", + message_type=MessageType.ASSISTANT, + sql_text=None, + sql_type="ACTION_CANCEL", # 自定义类型,前端据此显示不同 UI + requires_confirmation=False, + data=None + ) + # 执行 Text-to-SQL sql_text = await call_ai_agent(ddl_text, user_input, model_key=final_model_key) + is_meta_sql = any(kw in sql_text.upper() for kw in ["'CANCELED'", "'ERROR'"]) + if is_meta_sql: + reply_content = "抱歉,我无法执行该操作或理解您的指令。请提供具体的业务需求(如:查询书籍)。" + # 存一条不带 SQL 执行记录的消息 + ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant") + await db.commit() + return ChatResponse( + message_id=ai_message.message_id, + content=reply_content, + message_type=MessageType.ASSISTANT, + sql_text=None, # 不给 SQL,前端就不会出黑色代码框 + sql_type="ERROR_FEEDBACK", + requires_confirmation=False, + data=None # 不给数据,前端就不会画表格 + ) + # 6. 解析 SQL 类型与安全性校验 (SF3) sql_type = "UNKNOWN" try: - parsed = sqlparse.parse(sql_text) - if parsed: sql_type = parsed[0].get_type().upper() + if sql_text and not sql_text.startswith("--"): + parsed = sqlparse.parse(sql_text) + if parsed: + sql_type = parsed[0].get_type().upper() + if sql_type == "UNKNOWN" and sql_text.strip().upper().startswith("SELECT"): + sql_type = "SELECT" except: pass requires_confirm = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"] - reply_content = f"已生成查询语句:\n{sql_text}" + + # 7. 【核心修复】:只创建一条 Assistant 消息,防止重复发消息 + reply_content = f"已生成sql语句:\n{sql_text}" ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant") - if requires_confirm: - ai_message.requires_confirmation = True - db.add(ai_message) - await db.commit() - + # 8. 尝试执行 SQL (针对 DQL 查询) data = [] + execution_status = "pending" + if not requires_confirm: try: database_instance = await crud_database_instance.get(db, project.instance_id) - raw_result = await execute_sql_with_user_check(sql_text, sql_type, database_instance) - data = raw_result if isinstance(raw_result, list) else ([raw_result] if raw_result else []) + exec_type = sql_type if sql_type != "UNKNOWN" else "SELECT" + raw_result = await execute_sql_with_user_check(sql_text, exec_type, database_instance) + + # 统一数据格式为 List[Dict] + if isinstance(raw_result, list): data = raw_result + elif isinstance(raw_result, dict): data = [raw_result] - # 【新增】非确认操作(如 SELECT),立即持久化结果快照 - if data: - await _save_query_execution_result(db, ai_message.message_id, sql_text, sql_type, data) - await db.commit() + execution_status = "success" except Exception as e: - log.info(f"Execution Error: {str(e)}") + log.error(f"SQL Execution Error: {str(e)}") + execution_status = "failed" + else: + # DML 操作,标记为需确认,前端会显示确认按钮 + log.info(f"SQL requires confirmation ({sql_type}), skipping immediate execution.") + ai_message.requires_confirmation = True + db.add(ai_message) + execution_status = "pending" + + # 9. 【关键持久化】:使用 jsonable_encoder 处理日期并保存到子表 (SF9) + safe_data = jsonable_encoder(data) + new_statement = AIGeneratedStatement( + message_id=ai_message.message_id, + sql_text=sql_text, + statement_type=sql_type, + execution_status=execution_status, + execution_result=safe_data, + statement_order=1 + ) + db.add(new_statement) + + # 最后统一 commit 事务,保证数据一致性 + await db.commit() return ChatResponse( - message_id=ai_message.message_id, content=reply_content, - message_type=MessageType.ASSISTANT, sql_text=sql_text, - sql_type=sql_type, requires_confirmation=requires_confirm, data=data + message_id=ai_message.message_id, + content=reply_content, + message_type=MessageType.ASSISTANT, + sql_text=sql_text, + sql_type=sql_type, + requires_confirmation=ai_message.requires_confirmation, + data=safe_data ) +# ========================================================= +# 4. 执行确认逻辑 (同学代码里缺失,这里必须补上) +# ========================================================= + async def confirm_and_execute_sql( db: AsyncSession, message_id: int, user_id: int ) -> ChatResponse: - """用户确认执行某条消息中的 SQL""" - # 1. 查找消息、会话及项目 + """ + 用户确认执行某条消息中的 SQL (通常是增删改操作)。 + 【修复版】使用分步查询法,解决 AttributeError。 + """ + # 1. 第一步:查消息本身 stmt = select(MessageModel).where(MessageModel.message_id == message_id) result = await db.execute(stmt) message = result.scalar_one_or_none() + if not message: raise HTTPException(status_code=404, detail="Message not found") + # 2. 第二步:查所属 Session stmt_session = select(SessionModel).where(SessionModel.session_id == message.session_id) - session_obj = (await db.execute(stmt_session)).scalar_one_or_none() - project = await crud_project.get(db, session_obj.project_id) + result_session = await db.execute(stmt_session) + session_obj = result_session.scalar_one_or_none() + + if not session_obj: + raise HTTPException(status_code=404, detail="Session not found") + + # 3. 第三步:查所属 Project + stmt_project = select(ProjectModel).where(ProjectModel.project_id == session_obj.project_id) + result_project = await db.execute(stmt_project) + project = result_project.scalar_one_or_none() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + # --- 校验逻辑 --- if project.user_id != user_id: raise HTTPException(status_code=403, detail="Access denied") - if not message.requires_confirmation or message.user_confirmed: - raise HTTPException(status_code=400, detail="Invalid confirmation request") - # 2. 提取并清理 SQL - sql_text = message.content.split(":\n")[-1].strip() if ":\n" in message.content else message.content.strip() + if not message.requires_confirmation: + raise HTTPException(status_code=400, detail="This message does not require confirmation") + + if message.user_confirmed: + raise HTTPException(status_code=400, detail="This operation has already been confirmed/executed") + + # 4. 提取 SQL 语句 + content = message.content + sql_text = "" + + if ":\n" in content: + sql_text = content.split(":\n")[-1].strip() + else: + sql_text = content.strip() + sql_text = sql_text.replace("```sql", "").replace("```", "").strip() + # 5. 执行 SQL (DML) execute_res = [] try: - # 3. 执行 DML database_instance = await crud_database_instance.get(db, project.instance_id) - raw_result = await execute_sql_with_user_check(sql_text, "DML", database_instance) - execute_res = raw_result if isinstance(raw_result, list) else ([raw_result] if raw_result else []) - - # 【新增】确认执行后,将变更影响的快照持久化 - await _save_query_execution_result(db, message.message_id, sql_text, "DML_EXECUTED", execute_res) + # 真正执行 DML + result = await execute_sql_with_user_check(sql_text, "UPDATE", database_instance) + execute_res = [result] if isinstance(result, dict) else result + # 1. 更新消息确认状态 message.user_confirmed = True db.add(message) + + # 2. 【关键:同步更新子表结果】 + # 找到该消息对应的 SQL 记录并把执行结果填进去 + stmt_query = select(AIGeneratedStatement).where(AIGeneratedStatement.message_id == message_id) + stmt_res = await db.execute(stmt_query) + db_stmt = stmt_res.scalar_one_or_none() + + if db_stmt: + db_stmt.execution_result = execute_res + db_stmt.execution_status = "success" + db.add(db_stmt) + await db.commit() + log.info(f"User {user_id} executed DML and persisted results.") + except Exception as e: log.error(f"Execution failed: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=f"Execution failed: {str(e)}") + # 6. 返回结果 return ChatResponse( - message_id=message.message_id, content=message.content, - message_type=MessageType.ASSISTANT, sql_text=sql_text, - sql_type="DML_EXECUTED", requires_confirmation=False, data=execute_res + message_id=message.message_id, + content=message.content, + message_type=MessageType.ASSISTANT, + sql_text=sql_text, + sql_type="DML_EXECUTED", + requires_confirmation=False, + data=execute_res ) \ No newline at end of file diff --git a/src/backend/app/service/report_service.py b/src/backend/app/service/report_service.py index 01b2e68..b7d6eb8 100644 --- a/src/backend/app/service/report_service.py +++ b/src/backend/app/service/report_service.py @@ -123,21 +123,39 @@ async def get_history_queries_service( project_id: int, user_id: int, ) -> List[schemas.HistoryQuery]: + await _verify_project_ownership(db, project_id, user_id) + """ - 获取项目历史查询记录(仅限已持久化的 SELECT 结果)。 + 获取项目历史查询记录,返回用户原始提问与结果。 + + Args: + db (Session): 数据库会话。 + project_id (int): 项目 ID。 + + Returns: + List[schemas.HistoryQuery]: 历史查询记录列表。 + + Raises: + HTTPException: 项目不存在时返回 404。 """ - await _verify_project_ownership(db, project_id, user_id) + # [修复问题1]:先检查项目是否存在 + project = await db.get(Project, project_id) + if not project: + raise HTTPException(status_code=404, detail=f"Project with ID {project_id} not found") + # [修复问题2]:局部导入以避免 UnboundLocalError 和循环依赖 from models.message import Message from models.session import Session as SessionModel - # 路径:QueryResult -> Statement -> Message (Assistant) + # 构造查询:Query Result -> Statement -> Message -> Session + # 我们需要 Message.content (用户问题) 和 QueryResult (数据) stmt = ( select(QueryResult, Message) .join(AIGeneratedStatement, QueryResult.statement_id == AIGeneratedStatement.statement_id) .join(Message, AIGeneratedStatement.message_id == Message.message_id) .join(SessionModel, Message.session_id == SessionModel.session_id) .where(SessionModel.project_id == project_id) + .where(Message.message_type == 'assistant') # 确保我们取的是用户发的消息(提问) .order_by(QueryResult.cached_at.desc()) ) @@ -146,13 +164,13 @@ async def get_history_queries_service( history = [] for query_res, msg_obj in rows: - # 这里返回的是 QueryResult 的 result_id,前端将用这个 ID 创建报表 history.append(schemas.HistoryQuery( id=str(query_res.result_id), projectId=str(project_id), - queryText=msg_obj.content, # 显示 AI 的描述或 SQL + # 这里取的是 message.content,即用户的原始提问 + queryText=msg_obj.content, timestamp=query_res.cached_at.isoformat() if query_res.cached_at else 'N/A', - result=query_res.result_data # 包含完整的 JSON 结果集 + result=query_res.result_data )) return history -- 2.34.1