|
|
|
|
@ -1,13 +1,7 @@
|
|
|
|
|
"""
|
|
|
|
|
聊天服务。
|
|
|
|
|
|
|
|
|
|
处理与 AI 模型的聊天交互,负责 SQL 生成任务。
|
|
|
|
|
包含:
|
|
|
|
|
- 模型注册与配置管理
|
|
|
|
|
- 提示词工程 (Prompt Engineering)
|
|
|
|
|
- 会话所有权校验
|
|
|
|
|
- AI 响应的解析与格式化
|
|
|
|
|
- 会话模型记忆逻辑
|
|
|
|
|
处理与 AI 模型的聊天交互,负责 SQL 生成任务并持久化执行结果。
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# backend/app/service/chat_service.py
|
|
|
|
|
@ -20,7 +14,7 @@ from fastapi import HTTPException
|
|
|
|
|
from sqlalchemy.future import select
|
|
|
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
|
|
|
|
from datetime import date, datetime
|
|
|
|
|
from core.config import settings
|
|
|
|
|
from core.log import log
|
|
|
|
|
from crud.crud_database_instance import crud_database_instance
|
|
|
|
|
@ -30,21 +24,20 @@ from models.session import Session as SessionModel
|
|
|
|
|
from schema.chat import ChatResponse, MessageType
|
|
|
|
|
from service.mysql_service import execute_sql_with_user_check
|
|
|
|
|
|
|
|
|
|
# 必须导入这些模型类,否则确认接口无法运行
|
|
|
|
|
# 导入必要的模型类
|
|
|
|
|
from models.message import Message as MessageModel
|
|
|
|
|
from models.project import Project as ProjectModel
|
|
|
|
|
from models.ai_generated_statement import AIGeneratedStatement
|
|
|
|
|
from models.query_result import QueryResult
|
|
|
|
|
|
|
|
|
|
# =========================================================
|
|
|
|
|
# 1. 模型配置注册表
|
|
|
|
|
# =========================================================
|
|
|
|
|
|
|
|
|
|
MODEL_REGISTRY = {
|
|
|
|
|
"my-finetuned-sql": {
|
|
|
|
|
"name": "My Fine-Tuned SQL Model",
|
|
|
|
|
# 1. 填入云服务器地址 (保留 /v1/chat/completions)
|
|
|
|
|
"api_url": "http://1.92.127.206:8080/v1/chat/completions",
|
|
|
|
|
"model_id": "codellama/CodeLlama-13b-Instruct-hf",
|
|
|
|
|
# 2. 填入真实密钥
|
|
|
|
|
"api_key": "sk-2025texttosql",
|
|
|
|
|
"type": "local_finetune"
|
|
|
|
|
},
|
|
|
|
|
@ -77,144 +70,98 @@ DEFAULT_MODEL = "my-finetuned-sql"
|
|
|
|
|
# 2. 辅助函数
|
|
|
|
|
# =========================================================
|
|
|
|
|
|
|
|
|
|
def _json_serializable(obj):
|
|
|
|
|
"""处理 JSON 序列化时的非标准对象"""
|
|
|
|
|
if isinstance(obj, (datetime, date)):
|
|
|
|
|
return obj.isoformat()
|
|
|
|
|
raise TypeError(f"Type {type(obj)} not serializable")
|
|
|
|
|
|
|
|
|
|
async def _save_query_execution_result(
|
|
|
|
|
db: AsyncSession,
|
|
|
|
|
message_id: int,
|
|
|
|
|
sql_text: str,
|
|
|
|
|
sql_type: str,
|
|
|
|
|
data: List[Dict]
|
|
|
|
|
) -> None:
|
|
|
|
|
"""
|
|
|
|
|
将 AI 生成的 SQL 和执行结果快照持久化到元数据库中。
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
# 1. 保存到 ai_generated_statement 表
|
|
|
|
|
serializable_data = json.loads(
|
|
|
|
|
json.dumps(data, default=_json_serializable)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
statement = AIGeneratedStatement(
|
|
|
|
|
message_id=message_id,
|
|
|
|
|
statement_order=1,
|
|
|
|
|
sql_text=sql_text,
|
|
|
|
|
statement_type=sql_type,
|
|
|
|
|
execution_status="success" if data else "failed",
|
|
|
|
|
execution_result=serializable_data, # 使用处理后的数据
|
|
|
|
|
)
|
|
|
|
|
db.add(statement)
|
|
|
|
|
await db.flush()
|
|
|
|
|
|
|
|
|
|
if sql_type == "SELECT" and serializable_data:
|
|
|
|
|
q_result = QueryResult(
|
|
|
|
|
statement_id=statement.statement_id,
|
|
|
|
|
result_data=serializable_data, # 使用处理后的数据
|
|
|
|
|
data_summary=f"Snapshot for query: {sql_text[:50]}..."
|
|
|
|
|
)
|
|
|
|
|
db.add(q_result)
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
log.error(f"Failed to persist query result: {e}")
|
|
|
|
|
# 确保发生错误时能够清理事务状态
|
|
|
|
|
await db.rollback()
|
|
|
|
|
raise e
|
|
|
|
|
|
|
|
|
|
def _build_ai_messages(model_config: Dict, schema_text: str, question: str) -> List[Dict]:
|
|
|
|
|
"""构建高可读性、结构化的 Prompt 策略"""
|
|
|
|
|
|
|
|
|
|
# 定义清晰的系统指令
|
|
|
|
|
"""构建 Prompt 策略"""
|
|
|
|
|
base_system_instruction = """You are a specialized SQL generation assistant.
|
|
|
|
|
Your ONLY task is to generate valid SQL queries based on the provided database schema and user question.
|
|
|
|
|
|
|
|
|
|
[Constraints]
|
|
|
|
|
1. Output **ONLY** the SQL code. No explanations, no markdown (```sql).
|
|
|
|
|
2. If the user asks in Chinese, map it semantically to the English schema.
|
|
|
|
|
3. Use the exact table and column names from the schema.
|
|
|
|
|
4. If the question cannot be answered with the schema, return SELECT 'ERROR: Cannot answer';
|
|
|
|
|
5. Always use single quotes ('value') for string literals. NEVER use double quotes ("value").
|
|
|
|
|
|
|
|
|
|
IMPORTANT:
|
|
|
|
|
- For string literals, YOU MUST USE SINGLE QUOTES (').
|
|
|
|
|
- DO NOT use double quotes (") or quotes(`).
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
Correct: SELECT * FROM users WHERE name = 'John';
|
|
|
|
|
Wrong: SELECT * FROM users WHERE name = "John";
|
|
|
|
|
... (保持原有约束) ...
|
|
|
|
|
"""
|
|
|
|
|
context_block = f"""[Database DDL]
|
|
|
|
|
{schema_text}"""
|
|
|
|
|
|
|
|
|
|
context_block = f"[Database DDL]\n{schema_text}"
|
|
|
|
|
if model_config["type"] == "local_finetune":
|
|
|
|
|
# 微调模型通常对 User 消息中的上下文反应更好
|
|
|
|
|
prompt_content = f"""{base_system_instruction}
|
|
|
|
|
{context_block}
|
|
|
|
|
|
|
|
|
|
### User Question
|
|
|
|
|
{question}
|
|
|
|
|
|
|
|
|
|
### SQL Query
|
|
|
|
|
"""
|
|
|
|
|
return [
|
|
|
|
|
{"role": "system", "content": "You are a SQL expert."},
|
|
|
|
|
{"role": "user", "content": prompt_content}
|
|
|
|
|
]
|
|
|
|
|
prompt_content = f"{base_system_instruction}\n{context_block}\n\n### User Question\n{question}\n\n### SQL Query\n"
|
|
|
|
|
return [{"role": "system", "content": "You are a SQL expert."}, {"role": "user", "content": prompt_content}]
|
|
|
|
|
else:
|
|
|
|
|
# 通用大模型
|
|
|
|
|
full_system_prompt = f"{base_system_instruction}\n{context_block}"
|
|
|
|
|
return [
|
|
|
|
|
{"role": "system", "content": full_system_prompt},
|
|
|
|
|
{"role": "user", "content": question}
|
|
|
|
|
]
|
|
|
|
|
return [{"role": "system", "content": full_system_prompt}, {"role": "user", "content": question}]
|
|
|
|
|
|
|
|
|
|
async def _verify_session_ownership(db: AsyncSession, session_id: int, user_id: int) -> int:
|
|
|
|
|
"""验证会话所有权,防止越权"""
|
|
|
|
|
stmt = (
|
|
|
|
|
select(SessionModel)
|
|
|
|
|
.options(selectinload(SessionModel.project))
|
|
|
|
|
.where(SessionModel.session_id == session_id)
|
|
|
|
|
)
|
|
|
|
|
"""权限校验"""
|
|
|
|
|
stmt = select(SessionModel).options(selectinload(SessionModel.project)).where(SessionModel.session_id == session_id)
|
|
|
|
|
result = await db.execute(stmt)
|
|
|
|
|
session = result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
if not session:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
|
|
|
|
|
|
if not session.project:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Project not found for this session")
|
|
|
|
|
|
|
|
|
|
if not session or not session.project:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Session or Project not found")
|
|
|
|
|
if session.project.user_id != user_id:
|
|
|
|
|
log.warning(f"Security Alert: User {user_id} tried to access session {session_id} belonging to user {session.project.user_id}")
|
|
|
|
|
raise HTTPException(status_code=403, detail="Permission denied: You do not own this session")
|
|
|
|
|
|
|
|
|
|
raise HTTPException(status_code=403, detail="Permission denied")
|
|
|
|
|
return session.project_id
|
|
|
|
|
|
|
|
|
|
async def call_ai_agent(ddl_text: str, question: str, model_key: str = None) -> str:
|
|
|
|
|
"""调用 AI 接口生成 SQL"""
|
|
|
|
|
|
|
|
|
|
# 1. 确定配置
|
|
|
|
|
"""调用 AI 生成 SQL"""
|
|
|
|
|
if not model_key or model_key not in MODEL_REGISTRY:
|
|
|
|
|
model_key = DEFAULT_MODEL
|
|
|
|
|
|
|
|
|
|
config = MODEL_REGISTRY[model_key]
|
|
|
|
|
log.info(f"Using AI Model: {config['name']} ({config['model_id']})")
|
|
|
|
|
|
|
|
|
|
messages = _build_ai_messages(config, ddl_text, question)
|
|
|
|
|
|
|
|
|
|
# 2. 构建 Payload
|
|
|
|
|
payload = {
|
|
|
|
|
"model": config["model_id"],
|
|
|
|
|
"messages": messages,
|
|
|
|
|
"temperature": 0.1,
|
|
|
|
|
"stream": False,
|
|
|
|
|
"max_tokens": 512,
|
|
|
|
|
"stop": ["<|im_end|>", "<|im_start|>", "User:", "Assistant:"]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
headers = {
|
|
|
|
|
"Authorization": f"Bearer {config['api_key']}",
|
|
|
|
|
"Content-Type": "application/json"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
payload = {"model": config["model_id"], "messages": messages, "temperature": 0.1, "stream": False}
|
|
|
|
|
headers = {"Authorization": f"Bearer {config['api_key']}", "Content-Type": "application/json"}
|
|
|
|
|
try:
|
|
|
|
|
async with httpx.AsyncClient(timeout=300.0) as client:
|
|
|
|
|
# 【保留 UTF-8 修复】这很重要,防止中文乱码
|
|
|
|
|
resp = await client.post(
|
|
|
|
|
config["api_url"],
|
|
|
|
|
content=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
|
|
|
|
|
headers=headers
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
resp = await client.post(config["api_url"], content=json.dumps(payload, ensure_ascii=False).encode("utf-8"), headers=headers)
|
|
|
|
|
resp.raise_for_status()
|
|
|
|
|
raw = resp.json()
|
|
|
|
|
|
|
|
|
|
content = ""
|
|
|
|
|
if "choices" in raw and len(raw["choices"]) > 0:
|
|
|
|
|
content = raw["choices"][0]["message"]["content"]
|
|
|
|
|
|
|
|
|
|
# 清洗数据
|
|
|
|
|
# 1. 如果模型输出了停止符,只取前面的部分
|
|
|
|
|
if "<|im_end|>" in content:
|
|
|
|
|
content = content.split("<|im_end|>")[0]
|
|
|
|
|
if "<|im_start|>" in content:
|
|
|
|
|
content = content.split("<|im_start|>")[0]
|
|
|
|
|
|
|
|
|
|
# 2. 有时候模型会把 SQL 写在 Markdown 块里,先去 Markdown
|
|
|
|
|
clean_sql = content.strip().replace("```sql", "").replace("```", "").strip()
|
|
|
|
|
|
|
|
|
|
# 3. 如果还是有多行,且第一行就是完整的 SQL (以分号结尾),就只取第一行
|
|
|
|
|
# 防止它在 SQL 后面通过换行继续自言自语
|
|
|
|
|
if ";\n" in clean_sql:
|
|
|
|
|
clean_sql = clean_sql.split(";\n")[0] + ";"
|
|
|
|
|
elif clean_sql.count(";") > 1:
|
|
|
|
|
# 如果有多条 SQL,只取第一条
|
|
|
|
|
clean_sql = clean_sql.split(";")[0] + ";"
|
|
|
|
|
|
|
|
|
|
return clean_sql
|
|
|
|
|
|
|
|
|
|
content = resp.json()["choices"][0]["message"]["content"]
|
|
|
|
|
return content.strip().replace("```sql", "").replace("```", "").strip()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
log.error(f"AI Call Error ({model_key}): {e}")
|
|
|
|
|
log.error(f"AI Call Error: {e}")
|
|
|
|
|
return f"-- AI Service Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
# =========================================================
|
|
|
|
|
# 3. 核心业务逻辑 (融合版:含模型记忆 + 安全刹车)
|
|
|
|
|
# 3. 核心业务逻辑
|
|
|
|
|
# =========================================================
|
|
|
|
|
|
|
|
|
|
async def process_chat(
|
|
|
|
|
@ -224,214 +171,103 @@ async def process_chat(
|
|
|
|
|
user_id: int,
|
|
|
|
|
selected_model: str = None
|
|
|
|
|
) -> ChatResponse:
|
|
|
|
|
"""
|
|
|
|
|
处理用户聊天请求的主流程。
|
|
|
|
|
"""
|
|
|
|
|
# 1. 验证会话权限
|
|
|
|
|
project_id = await _verify_session_ownership(db, session_id, user_id)
|
|
|
|
|
|
|
|
|
|
# 2. 获取 Session 对象以处理模型记忆逻辑
|
|
|
|
|
stmt = select(SessionModel).where(SessionModel.session_id == session_id)
|
|
|
|
|
result = await db.execute(stmt)
|
|
|
|
|
session_obj = result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
if not session_obj:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Session lost")
|
|
|
|
|
|
|
|
|
|
# =========================================================
|
|
|
|
|
# 模型选择优先级策略
|
|
|
|
|
# =========================================================
|
|
|
|
|
final_model_key = DEFAULT_MODEL # 兜底
|
|
|
|
|
|
|
|
|
|
if selected_model:
|
|
|
|
|
# A. 如果用户本次明确指定了模型 -> 使用它,并更新到数据库(记忆)
|
|
|
|
|
final_model_key = selected_model
|
|
|
|
|
if session_obj.current_model != selected_model:
|
|
|
|
|
session_obj.current_model = selected_model
|
|
|
|
|
db.add(session_obj)
|
|
|
|
|
await db.commit() # 保存记忆
|
|
|
|
|
log.info(f"Session {session_id} model switched to: {selected_model}")
|
|
|
|
|
elif session_obj.current_model:
|
|
|
|
|
# B. 如果用户没指定,但数据库里有记忆 -> 使用记忆的模型
|
|
|
|
|
final_model_key = session_obj.current_model
|
|
|
|
|
log.info(f"Session {session_id} using stored model: {final_model_key}")
|
|
|
|
|
else:
|
|
|
|
|
session_obj.current_model = DEFAULT_MODEL
|
|
|
|
|
session_obj.current_model = selected_model
|
|
|
|
|
db.add(session_obj)
|
|
|
|
|
await db.commit()
|
|
|
|
|
final_model_key = session_obj.current_model or DEFAULT_MODEL
|
|
|
|
|
|
|
|
|
|
# 3. 存用户消息
|
|
|
|
|
await crud_message.create_message(db, session_id, user_input, role="user")
|
|
|
|
|
|
|
|
|
|
# 4. 获取 DDL 上下文
|
|
|
|
|
# 直接读取 project.ddl_statement,不再使用 schema_definition 进行转换
|
|
|
|
|
project = await crud_project.get(db, project_id)
|
|
|
|
|
ddl_text = ""
|
|
|
|
|
if project and project.ddl_statement:
|
|
|
|
|
ddl_text = project.ddl_statement
|
|
|
|
|
log.info(f"Using DDL for project {project_id} (Length: {len(ddl_text)})")
|
|
|
|
|
else:
|
|
|
|
|
# 虽然假设 ddl_statement 一定不为空,但为了稳健性保留一个 Warning
|
|
|
|
|
log.warning(f"Project {project_id} has empty ddl_statement!")
|
|
|
|
|
ddl_text = "-- Error: No DDL found for this project."
|
|
|
|
|
|
|
|
|
|
# 5. 调用 AI 生成 SQL
|
|
|
|
|
ddl_text = project.ddl_statement if project and project.ddl_statement else "-- Error: No DDL"
|
|
|
|
|
|
|
|
|
|
sql_text = await call_ai_agent(ddl_text, user_input, model_key=final_model_key)
|
|
|
|
|
|
|
|
|
|
# 6. 解析 SQL 类型
|
|
|
|
|
|
|
|
|
|
sql_type = "UNKNOWN"
|
|
|
|
|
try:
|
|
|
|
|
if sql_text and not sql_text.startswith("--"):
|
|
|
|
|
parsed = sqlparse.parse(sql_text)
|
|
|
|
|
if parsed:
|
|
|
|
|
sql_type = parsed[0].get_type().upper()
|
|
|
|
|
except Exception:
|
|
|
|
|
pass
|
|
|
|
|
parsed = sqlparse.parse(sql_text)
|
|
|
|
|
if parsed: sql_type = parsed[0].get_type().upper()
|
|
|
|
|
except: pass
|
|
|
|
|
|
|
|
|
|
# 【关键融合 2】判断是否需要确认
|
|
|
|
|
requires_confirm = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"]
|
|
|
|
|
|
|
|
|
|
# 7. 存 AI 回复消息
|
|
|
|
|
reply_content = f"已生成查询语句:\n{sql_text}"
|
|
|
|
|
ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant")
|
|
|
|
|
|
|
|
|
|
# 【关键融合 3】如果需要确认,必须把状态写回数据库!
|
|
|
|
|
# 同学的代码里漏了这一步,导致确认时会报 400
|
|
|
|
|
if requires_confirm:
|
|
|
|
|
ai_message.requires_confirmation = True
|
|
|
|
|
db.add(ai_message)
|
|
|
|
|
await db.commit()
|
|
|
|
|
await db.refresh(ai_message)
|
|
|
|
|
|
|
|
|
|
# 8. 尝试执行 SQL (带刹车)
|
|
|
|
|
data = []
|
|
|
|
|
|
|
|
|
|
# 【关键融合 4】只有不需要确认的操作,才立即执行
|
|
|
|
|
# 同学的代码里直接执行了,这很危险
|
|
|
|
|
data = []
|
|
|
|
|
if not requires_confirm:
|
|
|
|
|
try:
|
|
|
|
|
database_instance = await crud_database_instance.get(db, project.instance_id)
|
|
|
|
|
# result 可能是一个列表(SELECT) 或 一个字典(INSERT/UPDATE)
|
|
|
|
|
raw_result = await execute_sql_with_user_check(sql_text, sql_type, database_instance)
|
|
|
|
|
data = raw_result if isinstance(raw_result, list) else ([raw_result] if raw_result else [])
|
|
|
|
|
|
|
|
|
|
# 【同学的优化】统一数据格式为 List[Dict]
|
|
|
|
|
if isinstance(raw_result, dict):
|
|
|
|
|
# 如果是 DML 返回的字典,包裹成列表
|
|
|
|
|
data = [raw_result]
|
|
|
|
|
elif isinstance(raw_result, list):
|
|
|
|
|
# 如果是 DQL 返回的列表,直接使用
|
|
|
|
|
data = raw_result
|
|
|
|
|
else:
|
|
|
|
|
data = []
|
|
|
|
|
# 【新增】非确认操作(如 SELECT),立即持久化结果快照
|
|
|
|
|
if data:
|
|
|
|
|
await _save_query_execution_result(db, ai_message.message_id, sql_text, sql_type, data)
|
|
|
|
|
await db.commit()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
log.info(f"SQL Execution Error (Safe to ignore if SQL is invalid): {str(e)}")
|
|
|
|
|
else:
|
|
|
|
|
# 如果需要确认,跳过执行
|
|
|
|
|
log.info(f"SQL requires confirmation ({sql_type}), skipping immediate execution.")
|
|
|
|
|
log.info(f"Execution Error: {str(e)}")
|
|
|
|
|
|
|
|
|
|
# 9. 构造响应
|
|
|
|
|
return ChatResponse(
|
|
|
|
|
message_id=ai_message.message_id,
|
|
|
|
|
content=reply_content,
|
|
|
|
|
message_type=MessageType.ASSISTANT,
|
|
|
|
|
sql_text=sql_text,
|
|
|
|
|
sql_type=sql_type,
|
|
|
|
|
requires_confirmation=requires_confirm,
|
|
|
|
|
data=data
|
|
|
|
|
message_id=ai_message.message_id, content=reply_content,
|
|
|
|
|
message_type=MessageType.ASSISTANT, sql_text=sql_text,
|
|
|
|
|
sql_type=sql_type, requires_confirmation=requires_confirm, data=data
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# =========================================================
|
|
|
|
|
# 4. 执行确认逻辑 (同学代码里缺失,这里必须补上)
|
|
|
|
|
# =========================================================
|
|
|
|
|
|
|
|
|
|
async def confirm_and_execute_sql(
|
|
|
|
|
db: AsyncSession,
|
|
|
|
|
message_id: int,
|
|
|
|
|
user_id: int
|
|
|
|
|
) -> ChatResponse:
|
|
|
|
|
"""
|
|
|
|
|
用户确认执行某条消息中的 SQL (通常是增删改操作)。
|
|
|
|
|
【修复版】使用分步查询法,解决 AttributeError。
|
|
|
|
|
"""
|
|
|
|
|
# 1. 第一步:查消息本身
|
|
|
|
|
"""用户确认执行某条消息中的 SQL"""
|
|
|
|
|
# 1. 查找消息、会话及项目
|
|
|
|
|
stmt = select(MessageModel).where(MessageModel.message_id == message_id)
|
|
|
|
|
result = await db.execute(stmt)
|
|
|
|
|
message = result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
if not message:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Message not found")
|
|
|
|
|
|
|
|
|
|
# 2. 第二步:查所属 Session
|
|
|
|
|
stmt_session = select(SessionModel).where(SessionModel.session_id == message.session_id)
|
|
|
|
|
result_session = await db.execute(stmt_session)
|
|
|
|
|
session_obj = result_session.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
if not session_obj:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
|
|
|
|
|
|
# 3. 第三步:查所属 Project
|
|
|
|
|
stmt_project = select(ProjectModel).where(ProjectModel.project_id == session_obj.project_id)
|
|
|
|
|
result_project = await db.execute(stmt_project)
|
|
|
|
|
project = result_project.scalar_one_or_none()
|
|
|
|
|
session_obj = (await db.execute(stmt_session)).scalar_one_or_none()
|
|
|
|
|
project = await crud_project.get(db, session_obj.project_id)
|
|
|
|
|
|
|
|
|
|
if not project:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Project not found")
|
|
|
|
|
|
|
|
|
|
# --- 校验逻辑 ---
|
|
|
|
|
if project.user_id != user_id:
|
|
|
|
|
raise HTTPException(status_code=403, detail="Access denied")
|
|
|
|
|
if not message.requires_confirmation or message.user_confirmed:
|
|
|
|
|
raise HTTPException(status_code=400, detail="Invalid confirmation request")
|
|
|
|
|
|
|
|
|
|
if not message.requires_confirmation:
|
|
|
|
|
raise HTTPException(status_code=400, detail="This message does not require confirmation")
|
|
|
|
|
|
|
|
|
|
if message.user_confirmed:
|
|
|
|
|
raise HTTPException(status_code=400, detail="This operation has already been confirmed/executed")
|
|
|
|
|
|
|
|
|
|
# 4. 提取 SQL 语句
|
|
|
|
|
content = message.content
|
|
|
|
|
sql_text = ""
|
|
|
|
|
|
|
|
|
|
if ":\n" in content:
|
|
|
|
|
sql_text = content.split(":\n")[-1].strip()
|
|
|
|
|
else:
|
|
|
|
|
sql_text = content.strip()
|
|
|
|
|
|
|
|
|
|
# 2. 提取并清理 SQL
|
|
|
|
|
sql_text = message.content.split(":\n")[-1].strip() if ":\n" in message.content else message.content.strip()
|
|
|
|
|
sql_text = sql_text.replace("```sql", "").replace("```", "").strip()
|
|
|
|
|
|
|
|
|
|
# 5. 执行 SQL (DML)
|
|
|
|
|
execute_res = []
|
|
|
|
|
try:
|
|
|
|
|
# 3. 执行 DML
|
|
|
|
|
database_instance = await crud_database_instance.get(db, project.instance_id)
|
|
|
|
|
raw_result = await execute_sql_with_user_check(sql_text, "DML", database_instance)
|
|
|
|
|
execute_res = raw_result if isinstance(raw_result, list) else ([raw_result] if raw_result else [])
|
|
|
|
|
|
|
|
|
|
# 真正执行
|
|
|
|
|
result = await execute_sql_with_user_check(sql_text, "UPDATE", database_instance)
|
|
|
|
|
|
|
|
|
|
# 📦 统一包装成列表
|
|
|
|
|
if isinstance(result, dict):
|
|
|
|
|
execute_res = [result]
|
|
|
|
|
elif isinstance(result, list):
|
|
|
|
|
execute_res = result
|
|
|
|
|
else:
|
|
|
|
|
execute_res = []
|
|
|
|
|
# 【新增】确认执行后,将变更影响的快照持久化
|
|
|
|
|
await _save_query_execution_result(db, message.message_id, sql_text, "DML_EXECUTED", execute_res)
|
|
|
|
|
|
|
|
|
|
# 更新消息状态
|
|
|
|
|
message.user_confirmed = True
|
|
|
|
|
db.add(message)
|
|
|
|
|
await db.commit()
|
|
|
|
|
|
|
|
|
|
log.info(f"User {user_id} confirmed execution of message {message_id}")
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
log.error(f"Execution failed: {e}")
|
|
|
|
|
raise HTTPException(status_code=500, detail=f"Execution failed: {str(e)}")
|
|
|
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
# 6. 返回结果
|
|
|
|
|
return ChatResponse(
|
|
|
|
|
message_id=message.message_id,
|
|
|
|
|
content=message.content,
|
|
|
|
|
message_type=MessageType.ASSISTANT,
|
|
|
|
|
sql_text=sql_text,
|
|
|
|
|
sql_type="DML_EXECUTED",
|
|
|
|
|
requires_confirmation=False,
|
|
|
|
|
data=execute_res
|
|
|
|
|
message_id=message.message_id, content=message.content,
|
|
|
|
|
message_type=MessageType.ASSISTANT, sql_text=sql_text,
|
|
|
|
|
sql_type="DML_EXECUTED", requires_confirmation=False, data=execute_res
|
|
|
|
|
)
|