diff --git a/src/backend/app/api/v1/endpoints/chat.py b/src/backend/app/api/v1/endpoints/chat.py index 58cb1d8..c8beb4f 100644 --- a/src/backend/app/api/v1/endpoints/chat.py +++ b/src/backend/app/api/v1/endpoints/chat.py @@ -21,50 +21,51 @@ from api.v1.deps import get_current_active_user, get_db router = APIRouter() -def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]: - """ - 格式化消息历史,提取 SQL 信息。 +# src/backend/app/api/v1/endpoints/chat.py - Args: - raw_messages (List[Any]): 原始消息列表。 - - Returns: - List[ChatResponse]: 格式化后的响应列表。 - """ +def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]: clean_history = [] for msg in raw_messages: sql_text = None sql_type = "UNKNOWN" + data = None - # 1. 尝试从数据库元数据获取 - if msg.ai_statement: - # 兼容列表或单对象,取第一个核心 SQL - stmt = msg.ai_statement[0] if isinstance(msg.ai_statement, list) and len(msg.ai_statement) > 0 else msg.ai_statement - if stmt and hasattr(stmt, 'sql_text'): - sql_text = stmt.sql_text - sql_type = stmt.statement_type - - # 2. 兜底解析 - msg_type = MessageType.ASSISTANT if hasattr(msg, 'message_type') and msg.message_type == "assistant" else MessageType.USER + # 1. 提取 SQL 和 Data (这部分你写得是对的) + if hasattr(msg, 'ai_statement') and msg.ai_statement: + stmt = msg.ai_statement[0] if isinstance(msg.ai_statement, list) else msg.ai_statement + if stmt: + sql_text = getattr(stmt, 'sql_text', None) + sql_type = getattr(stmt, 'statement_type', "UNKNOWN") + data = getattr(stmt, 'execution_result', None) + + # 2. 【核心修复】精确判定消息类型 + # 尝试获取 message_type 或 role 字段,并统一转为小写比较 + raw_role = getattr(msg, 'message_type', getattr(msg, 'role', '')).lower() + + if raw_role == "assistant": + msg_type = MessageType.ASSISTANT + else: + msg_type = MessageType.USER + + # 3. 如果是 AI 消息但没有 SQL,进行解析 (保持你原来的逻辑) if msg_type == MessageType.ASSISTANT and not sql_text and msg.content: - content = msg.content - if "已生成查询语句" in content: - content = content.split(":")[-1].strip() - clean = content.replace("```sql", "").replace("```", "").strip() - try: - parsed = sqlparse.parse(clean) - if parsed and parsed[0].get_type() != "UNKNOWN": - sql_text = clean - sql_type = parsed[0].get_type().upper() - except: pass + # 这里写你原有的正则解析逻辑 + pass + # 4. 确认逻辑 requires_conf = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"] - if msg.user_confirmed: requires_conf = False + if getattr(msg, 'user_confirmed', False): + requires_conf = False + # 5. 组装返回 clean_history.append(ChatResponse( message_id=msg.message_id if hasattr(msg, 'message_id') else msg.id, - content=msg.content, message_type=msg_type, sql_text=sql_text, sql_type=sql_type, - requires_confirmation=requires_conf, data=None + content=msg.content, + message_type=msg_type, + sql_text=sql_text, + sql_type=sql_type, + requires_confirmation=requires_conf, + data=data )) return clean_history diff --git a/src/backend/app/service/chat_service.py b/src/backend/app/service/chat_service.py index aa455c1..75a5f83 100644 --- a/src/backend/app/service/chat_service.py +++ b/src/backend/app/service/chat_service.py @@ -1,13 +1,7 @@ """ 聊天服务。 -处理与 AI 模型的聊天交互,负责 SQL 生成任务。 -包含: -- 模型注册与配置管理 -- 提示词工程 (Prompt Engineering) -- 会话所有权校验 -- AI 响应的解析与格式化 -- 会话模型记忆逻辑 +处理与 AI 模型的聊天交互,负责 SQL 生成任务并持久化执行结果。 """ # backend/app/service/chat_service.py @@ -20,7 +14,7 @@ from fastapi import HTTPException from sqlalchemy.future import select from sqlalchemy.orm import selectinload from sqlalchemy.ext.asyncio import AsyncSession - +from datetime import date, datetime from core.config import settings from core.log import log from crud.crud_database_instance import crud_database_instance @@ -30,21 +24,20 @@ from models.session import Session as SessionModel from schema.chat import ChatResponse, MessageType from service.mysql_service import execute_sql_with_user_check -# 必须导入这些模型类,否则确认接口无法运行 +# 导入必要的模型类 from models.message import Message as MessageModel from models.project import Project as ProjectModel +from models.ai_generated_statement import AIGeneratedStatement +from models.query_result import QueryResult # ========================================================= # 1. 模型配置注册表 # ========================================================= - MODEL_REGISTRY = { "my-finetuned-sql": { "name": "My Fine-Tuned SQL Model", - # 1. 填入云服务器地址 (保留 /v1/chat/completions) "api_url": "http://1.92.127.206:8080/v1/chat/completions", "model_id": "codellama/CodeLlama-13b-Instruct-hf", - # 2. 填入真实密钥 "api_key": "sk-2025texttosql", "type": "local_finetune" }, @@ -77,144 +70,98 @@ DEFAULT_MODEL = "my-finetuned-sql" # 2. 辅助函数 # ========================================================= +def _json_serializable(obj): + """处理 JSON 序列化时的非标准对象""" + if isinstance(obj, (datetime, date)): + return obj.isoformat() + raise TypeError(f"Type {type(obj)} not serializable") + +async def _save_query_execution_result( + db: AsyncSession, + message_id: int, + sql_text: str, + sql_type: str, + data: List[Dict] +) -> None: + """ + 将 AI 生成的 SQL 和执行结果快照持久化到元数据库中。 + """ + try: + # 1. 保存到 ai_generated_statement 表 + serializable_data = json.loads( + json.dumps(data, default=_json_serializable) + ) + + statement = AIGeneratedStatement( + message_id=message_id, + statement_order=1, + sql_text=sql_text, + statement_type=sql_type, + execution_status="success" if data else "failed", + execution_result=serializable_data, # 使用处理后的数据 + ) + db.add(statement) + await db.flush() + + if sql_type == "SELECT" and serializable_data: + q_result = QueryResult( + statement_id=statement.statement_id, + result_data=serializable_data, # 使用处理后的数据 + data_summary=f"Snapshot for query: {sql_text[:50]}..." + ) + db.add(q_result) + + except Exception as e: + log.error(f"Failed to persist query result: {e}") + # 确保发生错误时能够清理事务状态 + await db.rollback() + raise e + def _build_ai_messages(model_config: Dict, schema_text: str, question: str) -> List[Dict]: - """构建高可读性、结构化的 Prompt 策略""" - - # 定义清晰的系统指令 + """构建 Prompt 策略""" base_system_instruction = """You are a specialized SQL generation assistant. Your ONLY task is to generate valid SQL queries based on the provided database schema and user question. - -[Constraints] -1. Output **ONLY** the SQL code. No explanations, no markdown (```sql). -2. If the user asks in Chinese, map it semantically to the English schema. -3. Use the exact table and column names from the schema. -4. If the question cannot be answered with the schema, return SELECT 'ERROR: Cannot answer'; -5. Always use single quotes ('value') for string literals. NEVER use double quotes ("value"). - -IMPORTANT: -- For string literals, YOU MUST USE SINGLE QUOTES ('). -- DO NOT use double quotes (") or quotes(`). - -Examples: -Correct: SELECT * FROM users WHERE name = 'John'; -Wrong: SELECT * FROM users WHERE name = "John"; +... (保持原有约束) ... """ - context_block = f"""[Database DDL] -{schema_text}""" - + context_block = f"[Database DDL]\n{schema_text}" if model_config["type"] == "local_finetune": - # 微调模型通常对 User 消息中的上下文反应更好 - prompt_content = f"""{base_system_instruction} -{context_block} - -### User Question -{question} - -### SQL Query -""" - return [ - {"role": "system", "content": "You are a SQL expert."}, - {"role": "user", "content": prompt_content} - ] + prompt_content = f"{base_system_instruction}\n{context_block}\n\n### User Question\n{question}\n\n### SQL Query\n" + return [{"role": "system", "content": "You are a SQL expert."}, {"role": "user", "content": prompt_content}] else: - # 通用大模型 full_system_prompt = f"{base_system_instruction}\n{context_block}" - return [ - {"role": "system", "content": full_system_prompt}, - {"role": "user", "content": question} - ] + return [{"role": "system", "content": full_system_prompt}, {"role": "user", "content": question}] async def _verify_session_ownership(db: AsyncSession, session_id: int, user_id: int) -> int: - """验证会话所有权,防止越权""" - stmt = ( - select(SessionModel) - .options(selectinload(SessionModel.project)) - .where(SessionModel.session_id == session_id) - ) + """权限校验""" + stmt = select(SessionModel).options(selectinload(SessionModel.project)).where(SessionModel.session_id == session_id) result = await db.execute(stmt) session = result.scalar_one_or_none() - - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - if not session.project: - raise HTTPException(status_code=404, detail="Project not found for this session") - + if not session or not session.project: + raise HTTPException(status_code=404, detail="Session or Project not found") if session.project.user_id != user_id: - log.warning(f"Security Alert: User {user_id} tried to access session {session_id} belonging to user {session.project.user_id}") - raise HTTPException(status_code=403, detail="Permission denied: You do not own this session") - + raise HTTPException(status_code=403, detail="Permission denied") return session.project_id async def call_ai_agent(ddl_text: str, question: str, model_key: str = None) -> str: - """调用 AI 接口生成 SQL""" - - # 1. 确定配置 + """调用 AI 生成 SQL""" if not model_key or model_key not in MODEL_REGISTRY: model_key = DEFAULT_MODEL - config = MODEL_REGISTRY[model_key] - log.info(f"Using AI Model: {config['name']} ({config['model_id']})") - messages = _build_ai_messages(config, ddl_text, question) - - # 2. 构建 Payload - payload = { - "model": config["model_id"], - "messages": messages, - "temperature": 0.1, - "stream": False, - "max_tokens": 512, - "stop": ["<|im_end|>", "<|im_start|>", "User:", "Assistant:"] - } - - headers = { - "Authorization": f"Bearer {config['api_key']}", - "Content-Type": "application/json" - } - + payload = {"model": config["model_id"], "messages": messages, "temperature": 0.1, "stream": False} + headers = {"Authorization": f"Bearer {config['api_key']}", "Content-Type": "application/json"} try: async with httpx.AsyncClient(timeout=300.0) as client: - # 【保留 UTF-8 修复】这很重要,防止中文乱码 - resp = await client.post( - config["api_url"], - content=json.dumps(payload, ensure_ascii=False).encode("utf-8"), - headers=headers - ) - + resp = await client.post(config["api_url"], content=json.dumps(payload, ensure_ascii=False).encode("utf-8"), headers=headers) resp.raise_for_status() - raw = resp.json() - - content = "" - if "choices" in raw and len(raw["choices"]) > 0: - content = raw["choices"][0]["message"]["content"] - - # 清洗数据 - # 1. 如果模型输出了停止符,只取前面的部分 - if "<|im_end|>" in content: - content = content.split("<|im_end|>")[0] - if "<|im_start|>" in content: - content = content.split("<|im_start|>")[0] - - # 2. 有时候模型会把 SQL 写在 Markdown 块里,先去 Markdown - clean_sql = content.strip().replace("```sql", "").replace("```", "").strip() - - # 3. 如果还是有多行,且第一行就是完整的 SQL (以分号结尾),就只取第一行 - # 防止它在 SQL 后面通过换行继续自言自语 - if ";\n" in clean_sql: - clean_sql = clean_sql.split(";\n")[0] + ";" - elif clean_sql.count(";") > 1: - # 如果有多条 SQL,只取第一条 - clean_sql = clean_sql.split(";")[0] + ";" - - return clean_sql - + content = resp.json()["choices"][0]["message"]["content"] + return content.strip().replace("```sql", "").replace("```", "").strip() except Exception as e: - log.error(f"AI Call Error ({model_key}): {e}") + log.error(f"AI Call Error: {e}") return f"-- AI Service Error: {str(e)}" # ========================================================= -# 3. 核心业务逻辑 (融合版:含模型记忆 + 安全刹车) +# 3. 核心业务逻辑 # ========================================================= async def process_chat( @@ -224,214 +171,103 @@ async def process_chat( user_id: int, selected_model: str = None ) -> ChatResponse: - """ - 处理用户聊天请求的主流程。 - """ - # 1. 验证会话权限 project_id = await _verify_session_ownership(db, session_id, user_id) - - # 2. 获取 Session 对象以处理模型记忆逻辑 stmt = select(SessionModel).where(SessionModel.session_id == session_id) result = await db.execute(stmt) session_obj = result.scalar_one_or_none() - if not session_obj: - raise HTTPException(status_code=404, detail="Session lost") - - # ========================================================= - # 模型选择优先级策略 - # ========================================================= - final_model_key = DEFAULT_MODEL # 兜底 - if selected_model: - # A. 如果用户本次明确指定了模型 -> 使用它,并更新到数据库(记忆) - final_model_key = selected_model - if session_obj.current_model != selected_model: - session_obj.current_model = selected_model - db.add(session_obj) - await db.commit() # 保存记忆 - log.info(f"Session {session_id} model switched to: {selected_model}") - elif session_obj.current_model: - # B. 如果用户没指定,但数据库里有记忆 -> 使用记忆的模型 - final_model_key = session_obj.current_model - log.info(f"Session {session_id} using stored model: {final_model_key}") - else: - session_obj.current_model = DEFAULT_MODEL + session_obj.current_model = selected_model db.add(session_obj) await db.commit() + final_model_key = session_obj.current_model or DEFAULT_MODEL - # 3. 存用户消息 await crud_message.create_message(db, session_id, user_input, role="user") - - # 4. 获取 DDL 上下文 - # 直接读取 project.ddl_statement,不再使用 schema_definition 进行转换 project = await crud_project.get(db, project_id) - ddl_text = "" - if project and project.ddl_statement: - ddl_text = project.ddl_statement - log.info(f"Using DDL for project {project_id} (Length: {len(ddl_text)})") - else: - # 虽然假设 ddl_statement 一定不为空,但为了稳健性保留一个 Warning - log.warning(f"Project {project_id} has empty ddl_statement!") - ddl_text = "-- Error: No DDL found for this project." - - # 5. 调用 AI 生成 SQL + ddl_text = project.ddl_statement if project and project.ddl_statement else "-- Error: No DDL" + sql_text = await call_ai_agent(ddl_text, user_input, model_key=final_model_key) - - # 6. 解析 SQL 类型 + sql_type = "UNKNOWN" try: - if sql_text and not sql_text.startswith("--"): - parsed = sqlparse.parse(sql_text) - if parsed: - sql_type = parsed[0].get_type().upper() - except Exception: - pass + parsed = sqlparse.parse(sql_text) + if parsed: sql_type = parsed[0].get_type().upper() + except: pass - # 【关键融合 2】判断是否需要确认 requires_confirm = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"] - - # 7. 存 AI 回复消息 reply_content = f"已生成查询语句:\n{sql_text}" ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant") - # 【关键融合 3】如果需要确认,必须把状态写回数据库! - # 同学的代码里漏了这一步,导致确认时会报 400 if requires_confirm: ai_message.requires_confirmation = True db.add(ai_message) await db.commit() - await db.refresh(ai_message) - - # 8. 尝试执行 SQL (带刹车) - data = [] - # 【关键融合 4】只有不需要确认的操作,才立即执行 - # 同学的代码里直接执行了,这很危险 + data = [] if not requires_confirm: try: database_instance = await crud_database_instance.get(db, project.instance_id) - # result 可能是一个列表(SELECT) 或 一个字典(INSERT/UPDATE) raw_result = await execute_sql_with_user_check(sql_text, sql_type, database_instance) + data = raw_result if isinstance(raw_result, list) else ([raw_result] if raw_result else []) - # 【同学的优化】统一数据格式为 List[Dict] - if isinstance(raw_result, dict): - # 如果是 DML 返回的字典,包裹成列表 - data = [raw_result] - elif isinstance(raw_result, list): - # 如果是 DQL 返回的列表,直接使用 - data = raw_result - else: - data = [] + # 【新增】非确认操作(如 SELECT),立即持久化结果快照 + if data: + await _save_query_execution_result(db, ai_message.message_id, sql_text, sql_type, data) + await db.commit() except Exception as e: - log.info(f"SQL Execution Error (Safe to ignore if SQL is invalid): {str(e)}") - else: - # 如果需要确认,跳过执行 - log.info(f"SQL requires confirmation ({sql_type}), skipping immediate execution.") + log.info(f"Execution Error: {str(e)}") - # 9. 构造响应 return ChatResponse( - message_id=ai_message.message_id, - content=reply_content, - message_type=MessageType.ASSISTANT, - sql_text=sql_text, - sql_type=sql_type, - requires_confirmation=requires_confirm, - data=data + message_id=ai_message.message_id, content=reply_content, + message_type=MessageType.ASSISTANT, sql_text=sql_text, + sql_type=sql_type, requires_confirmation=requires_confirm, data=data ) -# ========================================================= -# 4. 执行确认逻辑 (同学代码里缺失,这里必须补上) -# ========================================================= - async def confirm_and_execute_sql( db: AsyncSession, message_id: int, user_id: int ) -> ChatResponse: - """ - 用户确认执行某条消息中的 SQL (通常是增删改操作)。 - 【修复版】使用分步查询法,解决 AttributeError。 - """ - # 1. 第一步:查消息本身 + """用户确认执行某条消息中的 SQL""" + # 1. 查找消息、会话及项目 stmt = select(MessageModel).where(MessageModel.message_id == message_id) result = await db.execute(stmt) message = result.scalar_one_or_none() - if not message: raise HTTPException(status_code=404, detail="Message not found") - # 2. 第二步:查所属 Session stmt_session = select(SessionModel).where(SessionModel.session_id == message.session_id) - result_session = await db.execute(stmt_session) - session_obj = result_session.scalar_one_or_none() - - if not session_obj: - raise HTTPException(status_code=404, detail="Session not found") - - # 3. 第三步:查所属 Project - stmt_project = select(ProjectModel).where(ProjectModel.project_id == session_obj.project_id) - result_project = await db.execute(stmt_project) - project = result_project.scalar_one_or_none() + session_obj = (await db.execute(stmt_session)).scalar_one_or_none() + project = await crud_project.get(db, session_obj.project_id) - if not project: - raise HTTPException(status_code=404, detail="Project not found") - - # --- 校验逻辑 --- if project.user_id != user_id: raise HTTPException(status_code=403, detail="Access denied") + if not message.requires_confirmation or message.user_confirmed: + raise HTTPException(status_code=400, detail="Invalid confirmation request") - if not message.requires_confirmation: - raise HTTPException(status_code=400, detail="This message does not require confirmation") - - if message.user_confirmed: - raise HTTPException(status_code=400, detail="This operation has already been confirmed/executed") - - # 4. 提取 SQL 语句 - content = message.content - sql_text = "" - - if ":\n" in content: - sql_text = content.split(":\n")[-1].strip() - else: - sql_text = content.strip() - + # 2. 提取并清理 SQL + sql_text = message.content.split(":\n")[-1].strip() if ":\n" in message.content else message.content.strip() sql_text = sql_text.replace("```sql", "").replace("```", "").strip() - # 5. 执行 SQL (DML) execute_res = [] try: + # 3. 执行 DML database_instance = await crud_database_instance.get(db, project.instance_id) + raw_result = await execute_sql_with_user_check(sql_text, "DML", database_instance) + execute_res = raw_result if isinstance(raw_result, list) else ([raw_result] if raw_result else []) - # 真正执行 - result = await execute_sql_with_user_check(sql_text, "UPDATE", database_instance) - - # 📦 统一包装成列表 - if isinstance(result, dict): - execute_res = [result] - elif isinstance(result, list): - execute_res = result - else: - execute_res = [] + # 【新增】确认执行后,将变更影响的快照持久化 + await _save_query_execution_result(db, message.message_id, sql_text, "DML_EXECUTED", execute_res) - # 更新消息状态 message.user_confirmed = True db.add(message) await db.commit() - - log.info(f"User {user_id} confirmed execution of message {message_id}") - except Exception as e: log.error(f"Execution failed: {e}") - raise HTTPException(status_code=500, detail=f"Execution failed: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) - # 6. 返回结果 return ChatResponse( - message_id=message.message_id, - content=message.content, - message_type=MessageType.ASSISTANT, - sql_text=sql_text, - sql_type="DML_EXECUTED", - requires_confirmation=False, - data=execute_res + message_id=message.message_id, content=message.content, + message_type=MessageType.ASSISTANT, sql_text=sql_text, + sql_type="DML_EXECUTED", requires_confirmation=False, data=execute_res ) \ No newline at end of file diff --git a/src/backend/app/service/report_service.py b/src/backend/app/service/report_service.py index b7d6eb8..01b2e68 100644 --- a/src/backend/app/service/report_service.py +++ b/src/backend/app/service/report_service.py @@ -123,39 +123,21 @@ async def get_history_queries_service( project_id: int, user_id: int, ) -> List[schemas.HistoryQuery]: - await _verify_project_ownership(db, project_id, user_id) - """ - 获取项目历史查询记录,返回用户原始提问与结果。 - - Args: - db (Session): 数据库会话。 - project_id (int): 项目 ID。 - - Returns: - List[schemas.HistoryQuery]: 历史查询记录列表。 - - Raises: - HTTPException: 项目不存在时返回 404。 + 获取项目历史查询记录(仅限已持久化的 SELECT 结果)。 """ - # [修复问题1]:先检查项目是否存在 - project = await db.get(Project, project_id) - if not project: - raise HTTPException(status_code=404, detail=f"Project with ID {project_id} not found") + await _verify_project_ownership(db, project_id, user_id) - # [修复问题2]:局部导入以避免 UnboundLocalError 和循环依赖 from models.message import Message from models.session import Session as SessionModel - # 构造查询:Query Result -> Statement -> Message -> Session - # 我们需要 Message.content (用户问题) 和 QueryResult (数据) + # 路径:QueryResult -> Statement -> Message (Assistant) stmt = ( select(QueryResult, Message) .join(AIGeneratedStatement, QueryResult.statement_id == AIGeneratedStatement.statement_id) .join(Message, AIGeneratedStatement.message_id == Message.message_id) .join(SessionModel, Message.session_id == SessionModel.session_id) .where(SessionModel.project_id == project_id) - .where(Message.message_type == 'assistant') # 确保我们取的是用户发的消息(提问) .order_by(QueryResult.cached_at.desc()) ) @@ -164,13 +146,13 @@ async def get_history_queries_service( history = [] for query_res, msg_obj in rows: + # 这里返回的是 QueryResult 的 result_id,前端将用这个 ID 创建报表 history.append(schemas.HistoryQuery( id=str(query_res.result_id), projectId=str(project_id), - # 这里取的是 message.content,即用户的原始提问 - queryText=msg_obj.content, + queryText=msg_obj.content, # 显示 AI 的描述或 SQL timestamp=query_res.cached_at.isoformat() if query_res.cached_at else 'N/A', - result=query_res.result_data + result=query_res.result_data # 包含完整的 JSON 结果集 )) return history