|
|
|
|
@ -20,6 +20,7 @@ from fastapi import HTTPException
|
|
|
|
|
from sqlalchemy.future import select
|
|
|
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
from fastapi.encoders import jsonable_encoder
|
|
|
|
|
|
|
|
|
|
from core.config import settings
|
|
|
|
|
from core.log import log
|
|
|
|
|
@ -29,6 +30,7 @@ from crud.crud_project import crud_project
|
|
|
|
|
from models.session import Session as SessionModel
|
|
|
|
|
from schema.chat import ChatResponse, MessageType
|
|
|
|
|
from service.mysql_service import execute_sql_with_user_check
|
|
|
|
|
from models.ai_generated_statement import AIGeneratedStatement
|
|
|
|
|
|
|
|
|
|
# 必须导入这些模型类,否则确认接口无法运行
|
|
|
|
|
from models.message import Message as MessageModel
|
|
|
|
|
@ -217,6 +219,8 @@ async def call_ai_agent(ddl_text: str, question: str, model_key: str = None) ->
|
|
|
|
|
# 3. 核心业务逻辑 (融合版:含模型记忆 + 安全刹车)
|
|
|
|
|
# =========================================================
|
|
|
|
|
|
|
|
|
|
# backend/app/service/chat_service.py
|
|
|
|
|
|
|
|
|
|
async def process_chat(
|
|
|
|
|
db: AsyncSession,
|
|
|
|
|
session_id: int,
|
|
|
|
|
@ -230,114 +234,130 @@ async def process_chat(
|
|
|
|
|
# 1. 验证会话权限
|
|
|
|
|
project_id = await _verify_session_ownership(db, session_id, user_id)
|
|
|
|
|
|
|
|
|
|
# 2. 获取 Session 对象以处理模型记忆逻辑
|
|
|
|
|
# 2. 获取 Session 对象
|
|
|
|
|
stmt = select(SessionModel).where(SessionModel.session_id == session_id)
|
|
|
|
|
result = await db.execute(stmt)
|
|
|
|
|
session_obj = result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
if not session_obj:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Session lost")
|
|
|
|
|
|
|
|
|
|
# =========================================================
|
|
|
|
|
# 模型选择优先级策略
|
|
|
|
|
# 3. 模型选择优先级策略 (修复 'final_model_key' 未定义问题)
|
|
|
|
|
# =========================================================
|
|
|
|
|
final_model_key = DEFAULT_MODEL # 兜底
|
|
|
|
|
final_model_key = DEFAULT_MODEL # 默认兜底
|
|
|
|
|
|
|
|
|
|
if selected_model:
|
|
|
|
|
# A. 如果用户本次明确指定了模型 -> 使用它,并更新到数据库(记忆)
|
|
|
|
|
# A. 用户本次明确指定了模型
|
|
|
|
|
final_model_key = selected_model
|
|
|
|
|
if session_obj.current_model != selected_model:
|
|
|
|
|
session_obj.current_model = selected_model
|
|
|
|
|
db.add(session_obj)
|
|
|
|
|
await db.commit() # 保存记忆
|
|
|
|
|
log.info(f"Session {session_id} model switched to: {selected_model}")
|
|
|
|
|
# 这里先不 commit,后面统一提交提高性能
|
|
|
|
|
elif session_obj.current_model:
|
|
|
|
|
# B. 如果用户没指定,但数据库里有记忆 -> 使用记忆的模型
|
|
|
|
|
# B. 使用数据库记忆的模型
|
|
|
|
|
final_model_key = session_obj.current_model
|
|
|
|
|
log.info(f"Session {session_id} using stored model: {final_model_key}")
|
|
|
|
|
else:
|
|
|
|
|
session_obj.current_model = DEFAULT_MODEL
|
|
|
|
|
db.add(session_obj)
|
|
|
|
|
await db.commit()
|
|
|
|
|
|
|
|
|
|
log.info(f"Session {session_id} using model: {final_model_key}")
|
|
|
|
|
|
|
|
|
|
# 3. 存用户消息
|
|
|
|
|
# 4. 存用户发送的原始消息 (SF6 上下文记忆的基础)
|
|
|
|
|
await crud_message.create_message(db, session_id, user_input, role="user")
|
|
|
|
|
|
|
|
|
|
# 4. 获取 DDL 上下文
|
|
|
|
|
# 直接读取 project.ddl_statement,不再使用 schema_definition 进行转换
|
|
|
|
|
# 5. 获取 DDL 上下文并调用 AI
|
|
|
|
|
project = await crud_project.get(db, project_id)
|
|
|
|
|
ddl_text = ""
|
|
|
|
|
if project and project.ddl_statement:
|
|
|
|
|
ddl_text = project.ddl_statement
|
|
|
|
|
log.info(f"Using DDL for project {project_id} (Length: {len(ddl_text)})")
|
|
|
|
|
else:
|
|
|
|
|
# 虽然假设 ddl_statement 一定不为空,但为了稳健性保留一个 Warning
|
|
|
|
|
log.warning(f"Project {project_id} has empty ddl_statement!")
|
|
|
|
|
ddl_text = "-- Error: No DDL found for this project."
|
|
|
|
|
|
|
|
|
|
# 5. 调用 AI 生成 SQL
|
|
|
|
|
ddl_text = project.ddl_statement if project and project.ddl_statement else "-- No DDL found"
|
|
|
|
|
if user_input.strip() in ["取消", "cancel", "Stop"]:
|
|
|
|
|
return ChatResponse(
|
|
|
|
|
message_id=0, # 指令消息不一定需要存入数据库
|
|
|
|
|
content="好的,已为您取消当前操作。",
|
|
|
|
|
message_type=MessageType.ASSISTANT,
|
|
|
|
|
sql_text=None,
|
|
|
|
|
sql_type="ACTION_CANCEL", # 自定义类型,前端据此显示不同 UI
|
|
|
|
|
requires_confirmation=False,
|
|
|
|
|
data=None
|
|
|
|
|
)
|
|
|
|
|
# 执行 Text-to-SQL
|
|
|
|
|
sql_text = await call_ai_agent(ddl_text, user_input, model_key=final_model_key)
|
|
|
|
|
|
|
|
|
|
# 6. 解析 SQL 类型
|
|
|
|
|
is_meta_sql = any(kw in sql_text.upper() for kw in ["'CANCELED'", "'ERROR'"])
|
|
|
|
|
|
|
|
|
|
if is_meta_sql:
|
|
|
|
|
reply_content = "抱歉,我无法执行该操作或理解您的指令。请提供具体的业务需求(如:查询书籍)。"
|
|
|
|
|
# 存一条不带 SQL 执行记录的消息
|
|
|
|
|
ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant")
|
|
|
|
|
await db.commit()
|
|
|
|
|
return ChatResponse(
|
|
|
|
|
message_id=ai_message.message_id,
|
|
|
|
|
content=reply_content,
|
|
|
|
|
message_type=MessageType.ASSISTANT,
|
|
|
|
|
sql_text=None, # 不给 SQL,前端就不会出黑色代码框
|
|
|
|
|
sql_type="ERROR_FEEDBACK",
|
|
|
|
|
requires_confirmation=False,
|
|
|
|
|
data=None # 不给数据,前端就不会画表格
|
|
|
|
|
)
|
|
|
|
|
# 6. 解析 SQL 类型与安全性校验 (SF3)
|
|
|
|
|
sql_type = "UNKNOWN"
|
|
|
|
|
try:
|
|
|
|
|
if sql_text and not sql_text.startswith("--"):
|
|
|
|
|
parsed = sqlparse.parse(sql_text)
|
|
|
|
|
if parsed:
|
|
|
|
|
sql_type = parsed[0].get_type().upper()
|
|
|
|
|
except Exception:
|
|
|
|
|
pass
|
|
|
|
|
if sql_type == "UNKNOWN" and sql_text.strip().upper().startswith("SELECT"):
|
|
|
|
|
sql_type = "SELECT"
|
|
|
|
|
except: pass
|
|
|
|
|
|
|
|
|
|
# 【关键融合 2】判断是否需要确认
|
|
|
|
|
requires_confirm = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"]
|
|
|
|
|
|
|
|
|
|
# 7. 存 AI 回复消息
|
|
|
|
|
reply_content = f"已生成查询语句:\n{sql_text}"
|
|
|
|
|
# 7. 【核心修复】:只创建一条 Assistant 消息,防止重复发消息
|
|
|
|
|
reply_content = f"已生成sql语句:\n{sql_text}"
|
|
|
|
|
ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant")
|
|
|
|
|
|
|
|
|
|
# 【关键融合 3】如果需要确认,必须把状态写回数据库!
|
|
|
|
|
# 同学的代码里漏了这一步,导致确认时会报 400
|
|
|
|
|
if requires_confirm:
|
|
|
|
|
ai_message.requires_confirmation = True
|
|
|
|
|
db.add(ai_message)
|
|
|
|
|
await db.commit()
|
|
|
|
|
await db.refresh(ai_message)
|
|
|
|
|
|
|
|
|
|
# 8. 尝试执行 SQL (带刹车)
|
|
|
|
|
# 8. 尝试执行 SQL (针对 DQL 查询)
|
|
|
|
|
data = []
|
|
|
|
|
execution_status = "pending"
|
|
|
|
|
|
|
|
|
|
# 【关键融合 4】只有不需要确认的操作,才立即执行
|
|
|
|
|
# 同学的代码里直接执行了,这很危险
|
|
|
|
|
if not requires_confirm:
|
|
|
|
|
try:
|
|
|
|
|
database_instance = await crud_database_instance.get(db, project.instance_id)
|
|
|
|
|
# result 可能是一个列表(SELECT) 或 一个字典(INSERT/UPDATE)
|
|
|
|
|
raw_result = await execute_sql_with_user_check(sql_text, sql_type, database_instance)
|
|
|
|
|
exec_type = sql_type if sql_type != "UNKNOWN" else "SELECT"
|
|
|
|
|
raw_result = await execute_sql_with_user_check(sql_text, exec_type, database_instance)
|
|
|
|
|
|
|
|
|
|
# 【同学的优化】统一数据格式为 List[Dict]
|
|
|
|
|
if isinstance(raw_result, dict):
|
|
|
|
|
# 如果是 DML 返回的字典,包裹成列表
|
|
|
|
|
data = [raw_result]
|
|
|
|
|
elif isinstance(raw_result, list):
|
|
|
|
|
# 如果是 DQL 返回的列表,直接使用
|
|
|
|
|
data = raw_result
|
|
|
|
|
else:
|
|
|
|
|
data = []
|
|
|
|
|
# 统一数据格式为 List[Dict]
|
|
|
|
|
if isinstance(raw_result, list): data = raw_result
|
|
|
|
|
elif isinstance(raw_result, dict): data = [raw_result]
|
|
|
|
|
|
|
|
|
|
execution_status = "success"
|
|
|
|
|
except Exception as e:
|
|
|
|
|
log.info(f"SQL Execution Error (Safe to ignore if SQL is invalid): {str(e)}")
|
|
|
|
|
log.error(f"SQL Execution Error: {str(e)}")
|
|
|
|
|
execution_status = "failed"
|
|
|
|
|
else:
|
|
|
|
|
# 如果需要确认,跳过执行
|
|
|
|
|
# DML 操作,标记为需确认,前端会显示确认按钮
|
|
|
|
|
log.info(f"SQL requires confirmation ({sql_type}), skipping immediate execution.")
|
|
|
|
|
ai_message.requires_confirmation = True
|
|
|
|
|
db.add(ai_message)
|
|
|
|
|
execution_status = "pending"
|
|
|
|
|
|
|
|
|
|
# 9. 【关键持久化】:使用 jsonable_encoder 处理日期并保存到子表 (SF9)
|
|
|
|
|
safe_data = jsonable_encoder(data)
|
|
|
|
|
new_statement = AIGeneratedStatement(
|
|
|
|
|
message_id=ai_message.message_id,
|
|
|
|
|
sql_text=sql_text,
|
|
|
|
|
statement_type=sql_type,
|
|
|
|
|
execution_status=execution_status,
|
|
|
|
|
execution_result=safe_data,
|
|
|
|
|
statement_order=1
|
|
|
|
|
)
|
|
|
|
|
db.add(new_statement)
|
|
|
|
|
|
|
|
|
|
# 最后统一 commit 事务,保证数据一致性
|
|
|
|
|
await db.commit()
|
|
|
|
|
|
|
|
|
|
# 9. 构造响应
|
|
|
|
|
return ChatResponse(
|
|
|
|
|
message_id=ai_message.message_id,
|
|
|
|
|
content=reply_content,
|
|
|
|
|
message_type=MessageType.ASSISTANT,
|
|
|
|
|
sql_text=sql_text,
|
|
|
|
|
sql_type=sql_type,
|
|
|
|
|
requires_confirmation=requires_confirm,
|
|
|
|
|
data=data
|
|
|
|
|
requires_confirmation=ai_message.requires_confirmation,
|
|
|
|
|
data=safe_data
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# =========================================================
|
|
|
|
|
@ -402,24 +422,27 @@ async def confirm_and_execute_sql(
|
|
|
|
|
execute_res = []
|
|
|
|
|
try:
|
|
|
|
|
database_instance = await crud_database_instance.get(db, project.instance_id)
|
|
|
|
|
|
|
|
|
|
# 真正执行
|
|
|
|
|
# 真正执行 DML
|
|
|
|
|
result = await execute_sql_with_user_check(sql_text, "UPDATE", database_instance)
|
|
|
|
|
execute_res = [result] if isinstance(result, dict) else result
|
|
|
|
|
|
|
|
|
|
# 📦 统一包装成列表
|
|
|
|
|
if isinstance(result, dict):
|
|
|
|
|
execute_res = [result]
|
|
|
|
|
elif isinstance(result, list):
|
|
|
|
|
execute_res = result
|
|
|
|
|
else:
|
|
|
|
|
execute_res = []
|
|
|
|
|
|
|
|
|
|
# 更新消息状态
|
|
|
|
|
# 1. 更新消息确认状态
|
|
|
|
|
message.user_confirmed = True
|
|
|
|
|
db.add(message)
|
|
|
|
|
await db.commit()
|
|
|
|
|
|
|
|
|
|
log.info(f"User {user_id} confirmed execution of message {message_id}")
|
|
|
|
|
# 2. 【关键:同步更新子表结果】
|
|
|
|
|
# 找到该消息对应的 SQL 记录并把执行结果填进去
|
|
|
|
|
stmt_query = select(AIGeneratedStatement).where(AIGeneratedStatement.message_id == message_id)
|
|
|
|
|
stmt_res = await db.execute(stmt_query)
|
|
|
|
|
db_stmt = stmt_res.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
if db_stmt:
|
|
|
|
|
db_stmt.execution_result = execute_res
|
|
|
|
|
db_stmt.execution_status = "success"
|
|
|
|
|
db.add(db_stmt)
|
|
|
|
|
|
|
|
|
|
await db.commit()
|
|
|
|
|
log.info(f"User {user_id} executed DML and persisted results.")
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
log.error(f"Execution failed: {e}")
|
|
|
|
|
|