diff --git a/src/backend/app/api/v1/endpoints/chat.py b/src/backend/app/api/v1/endpoints/chat.py index 58cb1d8..4e5f65d 100644 --- a/src/backend/app/api/v1/endpoints/chat.py +++ b/src/backend/app/api/v1/endpoints/chat.py @@ -21,50 +21,42 @@ from api.v1.deps import get_current_active_user, get_db router = APIRouter() -def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]: - """ - 格式化消息历史,提取 SQL 信息。 +# backend/app/api/v1/endpoints/chat.py - Args: - raw_messages (List[Any]): 原始消息列表。 +# backend/app/api/v1/endpoints/chat.py - Returns: - List[ChatResponse]: 格式化后的响应列表。 - """ +def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]: clean_history = [] for msg in raw_messages: - sql_text = None - sql_type = "UNKNOWN" - - # 1. 尝试从数据库元数据获取 - if msg.ai_statement: - # 兼容列表或单对象,取第一个核心 SQL - stmt = msg.ai_statement[0] if isinstance(msg.ai_statement, list) and len(msg.ai_statement) > 0 else msg.ai_statement - if stmt and hasattr(stmt, 'sql_text'): - sql_text = stmt.sql_text - sql_type = stmt.statement_type - - # 2. 兜底解析 - msg_type = MessageType.ASSISTANT if hasattr(msg, 'message_type') and msg.message_type == "assistant" else MessageType.USER + # 1. 直接获取在 CRUD 中挂载好的持久化数据 + # 我们不再依赖 ai_statement,而是直接拿映射好的值 + sql_text = getattr(msg, 'sql_text', None) + sql_type = getattr(msg, 'sql_type', 'UNKNOWN') + data = getattr(msg, 'data', None) # <--- 获取持久化结果 + + msg_type = MessageType.ASSISTANT if msg.message_type == "assistant" else MessageType.USER + + # 2. 如果 SQL 信息还没挂载上(比如某些异常情况),再尝试兜底解析 if msg_type == MessageType.ASSISTANT and not sql_text and msg.content: - content = msg.content - if "已生成查询语句" in content: - content = content.split(":")[-1].strip() - clean = content.replace("```sql", "").replace("```", "").strip() try: - parsed = sqlparse.parse(clean) - if parsed and parsed[0].get_type() != "UNKNOWN": - sql_text = clean - sql_type = parsed[0].get_type().upper() + # 你的原始兜底逻辑保持不变,但增加安全性 + content = msg.content + if "已生成查询语句" in content: + sql_text = content.split(":")[-1].strip() except: pass requires_conf = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"] if msg.user_confirmed: requires_conf = False + # 3. 核心修复:把 data 传进去! clean_history.append(ChatResponse( - message_id=msg.message_id if hasattr(msg, 'message_id') else msg.id, - content=msg.content, message_type=msg_type, sql_text=sql_text, sql_type=sql_type, - requires_confirmation=requires_conf, data=None + message_id=msg.message_id, + content=msg.content, + message_type=msg_type, + sql_text=sql_text, + sql_type=sql_type, + requires_confirmation=requires_conf, + data=data # <--- 修改这里:不再是 None,而是 msg.data )) return clean_history diff --git a/src/backend/app/crud/crud_message.py b/src/backend/app/crud/crud_message.py index 293565a..dee0262 100644 --- a/src/backend/app/crud/crud_message.py +++ b/src/backend/app/crud/crud_message.py @@ -45,20 +45,15 @@ class CRUDMessage: await db.refresh(db_obj) return db_obj - async def get_recent_messages(self, db: AsyncSession, session_id: int, limit: int = 50) -> List[Message]: - """ - 获取最近的历史记录。 + # backend/app/crud/crud_message.py - Args: - db (AsyncSession): 数据库会话。 - session_id (int): 会话 ID。 - limit (int): 返回的消息数量限制。 +# backend/app/crud/crud_message.py - Returns: - List[Message]: 消息列表 (包含关联的 AI 生成语句)。 + async def get_recent_messages(self, db: AsyncSession, session_id: int, limit: int = 50) -> List[Message]: """ - # 第一步:先只查消息表 - # 注意:这里去掉了 .options(selectinload(...)),解决了报错 + 获取最近的历史记录,并直接映射 SQL 执行结果。 + """ + # 1. 查消息基础信息 query = select(Message)\ .filter(Message.session_id == session_id)\ .order_by(desc(Message.created_at))\ @@ -70,39 +65,41 @@ class CRUDMessage: if not messages: return [] - # 第二步:收集所有消息的 ID - # 兼容处理:队友的模型主键可能叫 id,也可能叫 message_id - message_ids = [] - for m in messages: - # 优先取 message_id,如果没有就取 id - mid = getattr(m, "message_id", getattr(m, "id", None)) - if mid: - message_ids.append(mid) + # 2. 收集消息 ID + message_ids = [getattr(m, "message_id", getattr(m, "id", None)) for m in messages] - # 第三步:去查 SQL 语句表 (如果找到了消息ID) - ai_statements = [] + # 3. 查关联的 SQL 详情和结果 (来自 ai_generated_statement 表) + stmt_map = {} if message_ids: + # 获取 SQL 文本、类型和执行结果 [cite: 638, 811-820] stmt_query = select(AIGeneratedStatement).where( AIGeneratedStatement.message_id.in_(message_ids) ) stmt_result = await db.execute(stmt_query) ai_statements = stmt_result.scalars().all() + + # 建立 ID 映射 + for stmt in ai_statements: + stmt_map[stmt.message_id] = stmt - #第四步:手动拼装 (把查到的 SQL 塞进消息对象里) - # 制作一个字典方便查找: {message_id: [statement1, statement2]} - stmt_map = {} - for stmt in ai_statements: - if stmt.message_id not in stmt_map: - stmt_map[stmt.message_id] = [] - stmt_map[stmt.message_id].append(stmt) - - # 把 SQL 挂载到 Message 对象上 (临时属性) + # 4. 【核心修复】:直接映射到 ChatResponse 需要的字段 for m in messages: mid = getattr(m, "message_id", getattr(m, "id", None)) - # 我们给对象动态添加一个属性叫 ai_statement,这样 Endpoint 那边就不用改代码了 - m.ai_statement = stmt_map.get(mid, []) - - # 将倒序查询结果翻转为正序 + stmt = stmt_map.get(mid) + + if stmt: + # 直接给对象赋值,名称必须与 ChatResponse 中的定义一致 + m.sql_text = stmt.sql_text + m.sql_type = stmt.statement_type + # 这里的 .data 对应 ai_generated_statement 表的 execution_result 字段 [cite: 818, 830] + m.data = stmt.execution_result + else: + # 对于 user 消息或没有 SQL 的消息,设为空 + m.sql_text = None + m.sql_type = "UNKNOWN" + m.data = None + + # 5. 返回正序列表,满足历史加载需求 return list(reversed(messages)) # 实例化对象 diff --git a/src/backend/app/service/chat_service.py b/src/backend/app/service/chat_service.py index aa455c1..c6f1998 100644 --- a/src/backend/app/service/chat_service.py +++ b/src/backend/app/service/chat_service.py @@ -20,6 +20,7 @@ from fastapi import HTTPException from sqlalchemy.future import select from sqlalchemy.orm import selectinload from sqlalchemy.ext.asyncio import AsyncSession +from fastapi.encoders import jsonable_encoder from core.config import settings from core.log import log @@ -29,6 +30,7 @@ from crud.crud_project import crud_project from models.session import Session as SessionModel from schema.chat import ChatResponse, MessageType from service.mysql_service import execute_sql_with_user_check +from models.ai_generated_statement import AIGeneratedStatement # 必须导入这些模型类,否则确认接口无法运行 from models.message import Message as MessageModel @@ -217,6 +219,8 @@ async def call_ai_agent(ddl_text: str, question: str, model_key: str = None) -> # 3. 核心业务逻辑 (融合版:含模型记忆 + 安全刹车) # ========================================================= +# backend/app/service/chat_service.py + async def process_chat( db: AsyncSession, session_id: int, @@ -230,114 +234,130 @@ async def process_chat( # 1. 验证会话权限 project_id = await _verify_session_ownership(db, session_id, user_id) - # 2. 获取 Session 对象以处理模型记忆逻辑 + # 2. 获取 Session 对象 stmt = select(SessionModel).where(SessionModel.session_id == session_id) result = await db.execute(stmt) session_obj = result.scalar_one_or_none() - if not session_obj: raise HTTPException(status_code=404, detail="Session lost") # ========================================================= - # 模型选择优先级策略 + # 3. 模型选择优先级策略 (修复 'final_model_key' 未定义问题) # ========================================================= - final_model_key = DEFAULT_MODEL # 兜底 + final_model_key = DEFAULT_MODEL # 默认兜底 if selected_model: - # A. 如果用户本次明确指定了模型 -> 使用它,并更新到数据库(记忆) + # A. 用户本次明确指定了模型 final_model_key = selected_model if session_obj.current_model != selected_model: session_obj.current_model = selected_model db.add(session_obj) - await db.commit() # 保存记忆 - log.info(f"Session {session_id} model switched to: {selected_model}") + # 这里先不 commit,后面统一提交提高性能 elif session_obj.current_model: - # B. 如果用户没指定,但数据库里有记忆 -> 使用记忆的模型 + # B. 使用数据库记忆的模型 final_model_key = session_obj.current_model - log.info(f"Session {session_id} using stored model: {final_model_key}") - else: - session_obj.current_model = DEFAULT_MODEL - db.add(session_obj) - await db.commit() + + log.info(f"Session {session_id} using model: {final_model_key}") - # 3. 存用户消息 + # 4. 存用户发送的原始消息 (SF6 上下文记忆的基础) await crud_message.create_message(db, session_id, user_input, role="user") - # 4. 获取 DDL 上下文 - # 直接读取 project.ddl_statement,不再使用 schema_definition 进行转换 + # 5. 获取 DDL 上下文并调用 AI project = await crud_project.get(db, project_id) - ddl_text = "" - if project and project.ddl_statement: - ddl_text = project.ddl_statement - log.info(f"Using DDL for project {project_id} (Length: {len(ddl_text)})") - else: - # 虽然假设 ddl_statement 一定不为空,但为了稳健性保留一个 Warning - log.warning(f"Project {project_id} has empty ddl_statement!") - ddl_text = "-- Error: No DDL found for this project." - - # 5. 调用 AI 生成 SQL + ddl_text = project.ddl_statement if project and project.ddl_statement else "-- No DDL found" + if user_input.strip() in ["取消", "cancel", "Stop"]: + return ChatResponse( + message_id=0, # 指令消息不一定需要存入数据库 + content="好的,已为您取消当前操作。", + message_type=MessageType.ASSISTANT, + sql_text=None, + sql_type="ACTION_CANCEL", # 自定义类型,前端据此显示不同 UI + requires_confirmation=False, + data=None + ) + # 执行 Text-to-SQL sql_text = await call_ai_agent(ddl_text, user_input, model_key=final_model_key) - - # 6. 解析 SQL 类型 + is_meta_sql = any(kw in sql_text.upper() for kw in ["'CANCELED'", "'ERROR'"]) + + if is_meta_sql: + reply_content = "抱歉,我无法执行该操作或理解您的指令。请提供具体的业务需求(如:查询书籍)。" + # 存一条不带 SQL 执行记录的消息 + ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant") + await db.commit() + return ChatResponse( + message_id=ai_message.message_id, + content=reply_content, + message_type=MessageType.ASSISTANT, + sql_text=None, # 不给 SQL,前端就不会出黑色代码框 + sql_type="ERROR_FEEDBACK", + requires_confirmation=False, + data=None # 不给数据,前端就不会画表格 + ) + # 6. 解析 SQL 类型与安全性校验 (SF3) sql_type = "UNKNOWN" try: if sql_text and not sql_text.startswith("--"): parsed = sqlparse.parse(sql_text) if parsed: sql_type = parsed[0].get_type().upper() - except Exception: - pass + if sql_type == "UNKNOWN" and sql_text.strip().upper().startswith("SELECT"): + sql_type = "SELECT" + except: pass - # 【关键融合 2】判断是否需要确认 requires_confirm = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"] - # 7. 存 AI 回复消息 - reply_content = f"已生成查询语句:\n{sql_text}" + # 7. 【核心修复】:只创建一条 Assistant 消息,防止重复发消息 + reply_content = f"已生成sql语句:\n{sql_text}" ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant") - # 【关键融合 3】如果需要确认,必须把状态写回数据库! - # 同学的代码里漏了这一步,导致确认时会报 400 - if requires_confirm: - ai_message.requires_confirmation = True - db.add(ai_message) - await db.commit() - await db.refresh(ai_message) - - # 8. 尝试执行 SQL (带刹车) + # 8. 尝试执行 SQL (针对 DQL 查询) data = [] + execution_status = "pending" - # 【关键融合 4】只有不需要确认的操作,才立即执行 - # 同学的代码里直接执行了,这很危险 if not requires_confirm: try: database_instance = await crud_database_instance.get(db, project.instance_id) - # result 可能是一个列表(SELECT) 或 一个字典(INSERT/UPDATE) - raw_result = await execute_sql_with_user_check(sql_text, sql_type, database_instance) + exec_type = sql_type if sql_type != "UNKNOWN" else "SELECT" + raw_result = await execute_sql_with_user_check(sql_text, exec_type, database_instance) - # 【同学的优化】统一数据格式为 List[Dict] - if isinstance(raw_result, dict): - # 如果是 DML 返回的字典,包裹成列表 - data = [raw_result] - elif isinstance(raw_result, list): - # 如果是 DQL 返回的列表,直接使用 - data = raw_result - else: - data = [] + # 统一数据格式为 List[Dict] + if isinstance(raw_result, list): data = raw_result + elif isinstance(raw_result, dict): data = [raw_result] + + execution_status = "success" except Exception as e: - log.info(f"SQL Execution Error (Safe to ignore if SQL is invalid): {str(e)}") + log.error(f"SQL Execution Error: {str(e)}") + execution_status = "failed" else: - # 如果需要确认,跳过执行 + # DML 操作,标记为需确认,前端会显示确认按钮 log.info(f"SQL requires confirmation ({sql_type}), skipping immediate execution.") + ai_message.requires_confirmation = True + db.add(ai_message) + execution_status = "pending" + + # 9. 【关键持久化】:使用 jsonable_encoder 处理日期并保存到子表 (SF9) + safe_data = jsonable_encoder(data) + new_statement = AIGeneratedStatement( + message_id=ai_message.message_id, + sql_text=sql_text, + statement_type=sql_type, + execution_status=execution_status, + execution_result=safe_data, + statement_order=1 + ) + db.add(new_statement) + + # 最后统一 commit 事务,保证数据一致性 + await db.commit() - # 9. 构造响应 return ChatResponse( message_id=ai_message.message_id, content=reply_content, message_type=MessageType.ASSISTANT, sql_text=sql_text, sql_type=sql_type, - requires_confirmation=requires_confirm, - data=data + requires_confirmation=ai_message.requires_confirmation, + data=safe_data ) # ========================================================= @@ -402,24 +422,27 @@ async def confirm_and_execute_sql( execute_res = [] try: database_instance = await crud_database_instance.get(db, project.instance_id) - - # 真正执行 + # 真正执行 DML result = await execute_sql_with_user_check(sql_text, "UPDATE", database_instance) + execute_res = [result] if isinstance(result, dict) else result - # 📦 统一包装成列表 - if isinstance(result, dict): - execute_res = [result] - elif isinstance(result, list): - execute_res = result - else: - execute_res = [] - - # 更新消息状态 + # 1. 更新消息确认状态 message.user_confirmed = True db.add(message) - await db.commit() - log.info(f"User {user_id} confirmed execution of message {message_id}") + # 2. 【关键:同步更新子表结果】 + # 找到该消息对应的 SQL 记录并把执行结果填进去 + stmt_query = select(AIGeneratedStatement).where(AIGeneratedStatement.message_id == message_id) + stmt_res = await db.execute(stmt_query) + db_stmt = stmt_res.scalar_one_or_none() + + if db_stmt: + db_stmt.execution_result = execute_res + db_stmt.execution_status = "success" + db.add(db_stmt) + + await db.commit() + log.info(f"User {user_id} executed DML and persisted results.") except Exception as e: log.error(f"Execution failed: {e}")