wanglirong_branch #35

Merged
hnu202326010328 merged 4 commits from wanglirong_branch into develop 2 months ago

@ -21,50 +21,42 @@ from api.v1.deps import get_current_active_user, get_db
router = APIRouter()
def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]:
"""
格式化消息历史提取 SQL 信息
# backend/app/api/v1/endpoints/chat.py
Args:
raw_messages (List[Any]): 原始消息列表
# backend/app/api/v1/endpoints/chat.py
Returns:
List[ChatResponse]: 格式化后的响应列表
"""
def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]:
clean_history = []
for msg in raw_messages:
sql_text = None
sql_type = "UNKNOWN"
# 1. 尝试从数据库元数据获取
if msg.ai_statement:
# 兼容列表或单对象,取第一个核心 SQL
stmt = msg.ai_statement[0] if isinstance(msg.ai_statement, list) and len(msg.ai_statement) > 0 else msg.ai_statement
if stmt and hasattr(stmt, 'sql_text'):
sql_text = stmt.sql_text
sql_type = stmt.statement_type
# 2. 兜底解析
msg_type = MessageType.ASSISTANT if hasattr(msg, 'message_type') and msg.message_type == "assistant" else MessageType.USER
# 1. 直接获取在 CRUD 中挂载好的持久化数据
# 我们不再依赖 ai_statement而是直接拿映射好的值
sql_text = getattr(msg, 'sql_text', None)
sql_type = getattr(msg, 'sql_type', 'UNKNOWN')
data = getattr(msg, 'data', None) # <--- 获取持久化结果
msg_type = MessageType.ASSISTANT if msg.message_type == "assistant" else MessageType.USER
# 2. 如果 SQL 信息还没挂载上(比如某些异常情况),再尝试兜底解析
if msg_type == MessageType.ASSISTANT and not sql_text and msg.content:
content = msg.content
if "已生成查询语句" in content:
content = content.split("")[-1].strip()
clean = content.replace("```sql", "").replace("```", "").strip()
try:
parsed = sqlparse.parse(clean)
if parsed and parsed[0].get_type() != "UNKNOWN":
sql_text = clean
sql_type = parsed[0].get_type().upper()
# 你的原始兜底逻辑保持不变,但增加安全性
content = msg.content
if "已生成查询语句" in content:
sql_text = content.split("")[-1].strip()
except: pass
requires_conf = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"]
if msg.user_confirmed: requires_conf = False
# 3. 核心修复:把 data 传进去!
clean_history.append(ChatResponse(
message_id=msg.message_id if hasattr(msg, 'message_id') else msg.id,
content=msg.content, message_type=msg_type, sql_text=sql_text, sql_type=sql_type,
requires_confirmation=requires_conf, data=None
message_id=msg.message_id,
content=msg.content,
message_type=msg_type,
sql_text=sql_text,
sql_type=sql_type,
requires_confirmation=requires_conf,
data=data # <--- 修改这里:不再是 None而是 msg.data
))
return clean_history

@ -45,20 +45,15 @@ class CRUDMessage:
await db.refresh(db_obj)
return db_obj
async def get_recent_messages(self, db: AsyncSession, session_id: int, limit: int = 50) -> List[Message]:
"""
获取最近的历史记录
# backend/app/crud/crud_message.py
Args:
db (AsyncSession): 数据库会话
session_id (int): 会话 ID
limit (int): 返回的消息数量限制
# backend/app/crud/crud_message.py
Returns:
List[Message]: 消息列表 (包含关联的 AI 生成语句)
async def get_recent_messages(self, db: AsyncSession, session_id: int, limit: int = 50) -> List[Message]:
"""
# 第一步:先只查消息表
# 注意:这里去掉了 .options(selectinload(...)),解决了报错
获取最近的历史记录并直接映射 SQL 执行结果
"""
# 1. 查消息基础信息
query = select(Message)\
.filter(Message.session_id == session_id)\
.order_by(desc(Message.created_at))\
@ -70,39 +65,41 @@ class CRUDMessage:
if not messages:
return []
# 第二步:收集所有消息的 ID
# 兼容处理:队友的模型主键可能叫 id也可能叫 message_id
message_ids = []
for m in messages:
# 优先取 message_id如果没有就取 id
mid = getattr(m, "message_id", getattr(m, "id", None))
if mid:
message_ids.append(mid)
# 2. 收集消息 ID
message_ids = [getattr(m, "message_id", getattr(m, "id", None)) for m in messages]
# 第三步:去查 SQL 语句表 (如果找到了消息ID)
ai_statements = []
# 3. 查关联的 SQL 详情和结果 (来自 ai_generated_statement 表)
stmt_map = {}
if message_ids:
# 获取 SQL 文本、类型和执行结果 [cite: 638, 811-820]
stmt_query = select(AIGeneratedStatement).where(
AIGeneratedStatement.message_id.in_(message_ids)
)
stmt_result = await db.execute(stmt_query)
ai_statements = stmt_result.scalars().all()
# 建立 ID 映射
for stmt in ai_statements:
stmt_map[stmt.message_id] = stmt
#第四步:手动拼装 (把查到的 SQL 塞进消息对象里)
# 制作一个字典方便查找: {message_id: [statement1, statement2]}
stmt_map = {}
for stmt in ai_statements:
if stmt.message_id not in stmt_map:
stmt_map[stmt.message_id] = []
stmt_map[stmt.message_id].append(stmt)
# 把 SQL 挂载到 Message 对象上 (临时属性)
# 4. 【核心修复】:直接映射到 ChatResponse 需要的字段
for m in messages:
mid = getattr(m, "message_id", getattr(m, "id", None))
# 我们给对象动态添加一个属性叫 ai_statement这样 Endpoint 那边就不用改代码了
m.ai_statement = stmt_map.get(mid, [])
# 将倒序查询结果翻转为正序
stmt = stmt_map.get(mid)
if stmt:
# 直接给对象赋值,名称必须与 ChatResponse 中的定义一致
m.sql_text = stmt.sql_text
m.sql_type = stmt.statement_type
# 这里的 .data 对应 ai_generated_statement 表的 execution_result 字段 [cite: 818, 830]
m.data = stmt.execution_result
else:
# 对于 user 消息或没有 SQL 的消息,设为空
m.sql_text = None
m.sql_type = "UNKNOWN"
m.data = None
# 5. 返回正序列表,满足历史加载需求
return list(reversed(messages))
# 实例化对象

@ -20,6 +20,7 @@ from fastapi import HTTPException
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi.encoders import jsonable_encoder
from core.config import settings
from core.log import log
@ -29,6 +30,7 @@ from crud.crud_project import crud_project
from models.session import Session as SessionModel
from schema.chat import ChatResponse, MessageType
from service.mysql_service import execute_sql_with_user_check
from models.ai_generated_statement import AIGeneratedStatement
# 必须导入这些模型类,否则确认接口无法运行
from models.message import Message as MessageModel
@ -217,6 +219,8 @@ async def call_ai_agent(ddl_text: str, question: str, model_key: str = None) ->
# 3. 核心业务逻辑 (融合版:含模型记忆 + 安全刹车)
# =========================================================
# backend/app/service/chat_service.py
async def process_chat(
db: AsyncSession,
session_id: int,
@ -230,114 +234,130 @@ async def process_chat(
# 1. 验证会话权限
project_id = await _verify_session_ownership(db, session_id, user_id)
# 2. 获取 Session 对象以处理模型记忆逻辑
# 2. 获取 Session 对象
stmt = select(SessionModel).where(SessionModel.session_id == session_id)
result = await db.execute(stmt)
session_obj = result.scalar_one_or_none()
if not session_obj:
raise HTTPException(status_code=404, detail="Session lost")
# =========================================================
# 模型选择优先级策略
# 3. 模型选择优先级策略 (修复 'final_model_key' 未定义问题)
# =========================================================
final_model_key = DEFAULT_MODEL # 兜底
final_model_key = DEFAULT_MODEL # 默认兜底
if selected_model:
# A. 如果用户本次明确指定了模型 -> 使用它,并更新到数据库(记忆)
# A. 用户本次明确指定了模型
final_model_key = selected_model
if session_obj.current_model != selected_model:
session_obj.current_model = selected_model
db.add(session_obj)
await db.commit() # 保存记忆
log.info(f"Session {session_id} model switched to: {selected_model}")
# 这里先不 commit后面统一提交提高性能
elif session_obj.current_model:
# B. 如果用户没指定,但数据库里有记忆 -> 使用记忆的模型
# B. 使用数据库记忆的模型
final_model_key = session_obj.current_model
log.info(f"Session {session_id} using stored model: {final_model_key}")
else:
session_obj.current_model = DEFAULT_MODEL
db.add(session_obj)
await db.commit()
log.info(f"Session {session_id} using model: {final_model_key}")
# 3. 存用户消息
# 4. 存用户发送的原始消息 (SF6 上下文记忆的基础)
await crud_message.create_message(db, session_id, user_input, role="user")
# 4. 获取 DDL 上下文
# 直接读取 project.ddl_statement不再使用 schema_definition 进行转换
# 5. 获取 DDL 上下文并调用 AI
project = await crud_project.get(db, project_id)
ddl_text = ""
if project and project.ddl_statement:
ddl_text = project.ddl_statement
log.info(f"Using DDL for project {project_id} (Length: {len(ddl_text)})")
else:
# 虽然假设 ddl_statement 一定不为空,但为了稳健性保留一个 Warning
log.warning(f"Project {project_id} has empty ddl_statement!")
ddl_text = "-- Error: No DDL found for this project."
# 5. 调用 AI 生成 SQL
ddl_text = project.ddl_statement if project and project.ddl_statement else "-- No DDL found"
if user_input.strip() in ["取消", "cancel", "Stop"]:
return ChatResponse(
message_id=0, # 指令消息不一定需要存入数据库
content="好的,已为您取消当前操作。",
message_type=MessageType.ASSISTANT,
sql_text=None,
sql_type="ACTION_CANCEL", # 自定义类型,前端据此显示不同 UI
requires_confirmation=False,
data=None
)
# 执行 Text-to-SQL
sql_text = await call_ai_agent(ddl_text, user_input, model_key=final_model_key)
# 6. 解析 SQL 类型
is_meta_sql = any(kw in sql_text.upper() for kw in ["'CANCELED'", "'ERROR'"])
if is_meta_sql:
reply_content = "抱歉,我无法执行该操作或理解您的指令。请提供具体的业务需求(如:查询书籍)。"
# 存一条不带 SQL 执行记录的消息
ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant")
await db.commit()
return ChatResponse(
message_id=ai_message.message_id,
content=reply_content,
message_type=MessageType.ASSISTANT,
sql_text=None, # 不给 SQL前端就不会出黑色代码框
sql_type="ERROR_FEEDBACK",
requires_confirmation=False,
data=None # 不给数据,前端就不会画表格
)
# 6. 解析 SQL 类型与安全性校验 (SF3)
sql_type = "UNKNOWN"
try:
if sql_text and not sql_text.startswith("--"):
parsed = sqlparse.parse(sql_text)
if parsed:
sql_type = parsed[0].get_type().upper()
except Exception:
pass
if sql_type == "UNKNOWN" and sql_text.strip().upper().startswith("SELECT"):
sql_type = "SELECT"
except: pass
# 【关键融合 2】判断是否需要确认
requires_confirm = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"]
# 7. 存 AI 回复消息
reply_content = f"已生成查询语句:\n{sql_text}"
# 7. 【核心修复】:只创建一条 Assistant 消息,防止重复发消息
reply_content = f"已生成sql语句:\n{sql_text}"
ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant")
# 【关键融合 3】如果需要确认必须把状态写回数据库
# 同学的代码里漏了这一步,导致确认时会报 400
if requires_confirm:
ai_message.requires_confirmation = True
db.add(ai_message)
await db.commit()
await db.refresh(ai_message)
# 8. 尝试执行 SQL (带刹车)
# 8. 尝试执行 SQL (针对 DQL 查询)
data = []
execution_status = "pending"
# 【关键融合 4】只有不需要确认的操作才立即执行
# 同学的代码里直接执行了,这很危险
if not requires_confirm:
try:
database_instance = await crud_database_instance.get(db, project.instance_id)
# result 可能是一个列表(SELECT) 或 一个字典(INSERT/UPDATE)
raw_result = await execute_sql_with_user_check(sql_text, sql_type, database_instance)
exec_type = sql_type if sql_type != "UNKNOWN" else "SELECT"
raw_result = await execute_sql_with_user_check(sql_text, exec_type, database_instance)
# 【同学的优化】统一数据格式为 List[Dict]
if isinstance(raw_result, dict):
# 如果是 DML 返回的字典,包裹成列表
data = [raw_result]
elif isinstance(raw_result, list):
# 如果是 DQL 返回的列表,直接使用
data = raw_result
else:
data = []
# 统一数据格式为 List[Dict]
if isinstance(raw_result, list): data = raw_result
elif isinstance(raw_result, dict): data = [raw_result]
execution_status = "success"
except Exception as e:
log.info(f"SQL Execution Error (Safe to ignore if SQL is invalid): {str(e)}")
log.error(f"SQL Execution Error: {str(e)}")
execution_status = "failed"
else:
# 如果需要确认,跳过执行
# DML 操作,标记为需确认,前端会显示确认按钮
log.info(f"SQL requires confirmation ({sql_type}), skipping immediate execution.")
ai_message.requires_confirmation = True
db.add(ai_message)
execution_status = "pending"
# 9. 【关键持久化】:使用 jsonable_encoder 处理日期并保存到子表 (SF9)
safe_data = jsonable_encoder(data)
new_statement = AIGeneratedStatement(
message_id=ai_message.message_id,
sql_text=sql_text,
statement_type=sql_type,
execution_status=execution_status,
execution_result=safe_data,
statement_order=1
)
db.add(new_statement)
# 最后统一 commit 事务,保证数据一致性
await db.commit()
# 9. 构造响应
return ChatResponse(
message_id=ai_message.message_id,
content=reply_content,
message_type=MessageType.ASSISTANT,
sql_text=sql_text,
sql_type=sql_type,
requires_confirmation=requires_confirm,
data=data
requires_confirmation=ai_message.requires_confirmation,
data=safe_data
)
# =========================================================
@ -402,24 +422,27 @@ async def confirm_and_execute_sql(
execute_res = []
try:
database_instance = await crud_database_instance.get(db, project.instance_id)
# 真正执行
# 真正执行 DML
result = await execute_sql_with_user_check(sql_text, "UPDATE", database_instance)
execute_res = [result] if isinstance(result, dict) else result
# 📦 统一包装成列表
if isinstance(result, dict):
execute_res = [result]
elif isinstance(result, list):
execute_res = result
else:
execute_res = []
# 更新消息状态
# 1. 更新消息确认状态
message.user_confirmed = True
db.add(message)
await db.commit()
log.info(f"User {user_id} confirmed execution of message {message_id}")
# 2. 【关键:同步更新子表结果】
# 找到该消息对应的 SQL 记录并把执行结果填进去
stmt_query = select(AIGeneratedStatement).where(AIGeneratedStatement.message_id == message_id)
stmt_res = await db.execute(stmt_query)
db_stmt = stmt_res.scalar_one_or_none()
if db_stmt:
db_stmt.execution_result = execute_res
db_stmt.execution_status = "success"
db.add(db_stmt)
await db.commit()
log.info(f"User {user_id} executed DML and persisted results.")
except Exception as e:
log.error(f"Execution failed: {e}")

Loading…
Cancel
Save