wanglirong_branch #34

Merged
hnu202326010328 merged 2 commits from wanglirong_branch into develop 2 months ago

@ -21,50 +21,51 @@ from api.v1.deps import get_current_active_user, get_db
router = APIRouter()
def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]:
"""
格式化消息历史提取 SQL 信息
# src/backend/app/api/v1/endpoints/chat.py
Args:
raw_messages (List[Any]): 原始消息列表
Returns:
List[ChatResponse]: 格式化后的响应列表
"""
def _format_history_response(raw_messages: List[Any]) -> List[ChatResponse]:
clean_history = []
for msg in raw_messages:
sql_text = None
sql_type = "UNKNOWN"
data = None
# 1. 尝试从数据库元数据获取
if msg.ai_statement:
# 兼容列表或单对象,取第一个核心 SQL
stmt = msg.ai_statement[0] if isinstance(msg.ai_statement, list) and len(msg.ai_statement) > 0 else msg.ai_statement
if stmt and hasattr(stmt, 'sql_text'):
sql_text = stmt.sql_text
sql_type = stmt.statement_type
# 2. 兜底解析
msg_type = MessageType.ASSISTANT if hasattr(msg, 'message_type') and msg.message_type == "assistant" else MessageType.USER
# 1. 提取 SQL 和 Data (这部分你写得是对的)
if hasattr(msg, 'ai_statement') and msg.ai_statement:
stmt = msg.ai_statement[0] if isinstance(msg.ai_statement, list) else msg.ai_statement
if stmt:
sql_text = getattr(stmt, 'sql_text', None)
sql_type = getattr(stmt, 'statement_type', "UNKNOWN")
data = getattr(stmt, 'execution_result', None)
# 2. 【核心修复】精确判定消息类型
# 尝试获取 message_type 或 role 字段,并统一转为小写比较
raw_role = getattr(msg, 'message_type', getattr(msg, 'role', '')).lower()
if raw_role == "assistant":
msg_type = MessageType.ASSISTANT
else:
msg_type = MessageType.USER
# 3. 如果是 AI 消息但没有 SQL进行解析 (保持你原来的逻辑)
if msg_type == MessageType.ASSISTANT and not sql_text and msg.content:
content = msg.content
if "已生成查询语句" in content:
content = content.split("")[-1].strip()
clean = content.replace("```sql", "").replace("```", "").strip()
try:
parsed = sqlparse.parse(clean)
if parsed and parsed[0].get_type() != "UNKNOWN":
sql_text = clean
sql_type = parsed[0].get_type().upper()
except: pass
# 这里写你原有的正则解析逻辑
pass
# 4. 确认逻辑
requires_conf = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"]
if msg.user_confirmed: requires_conf = False
if getattr(msg, 'user_confirmed', False):
requires_conf = False
# 5. 组装返回
clean_history.append(ChatResponse(
message_id=msg.message_id if hasattr(msg, 'message_id') else msg.id,
content=msg.content, message_type=msg_type, sql_text=sql_text, sql_type=sql_type,
requires_confirmation=requires_conf, data=None
content=msg.content,
message_type=msg_type,
sql_text=sql_text,
sql_type=sql_type,
requires_confirmation=requires_conf,
data=data
))
return clean_history

@ -1,13 +1,7 @@
"""
聊天服务
处理与 AI 模型的聊天交互负责 SQL 生成任务
包含
- 模型注册与配置管理
- 提示词工程 (Prompt Engineering)
- 会话所有权校验
- AI 响应的解析与格式化
- 会话模型记忆逻辑
处理与 AI 模型的聊天交互负责 SQL 生成任务并持久化执行结果
"""
# backend/app/service/chat_service.py
@ -20,7 +14,7 @@ from fastapi import HTTPException
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import date, datetime
from core.config import settings
from core.log import log
from crud.crud_database_instance import crud_database_instance
@ -30,21 +24,20 @@ from models.session import Session as SessionModel
from schema.chat import ChatResponse, MessageType
from service.mysql_service import execute_sql_with_user_check
# 必须导入这些模型类,否则确认接口无法运行
# 导入必要的模型类
from models.message import Message as MessageModel
from models.project import Project as ProjectModel
from models.ai_generated_statement import AIGeneratedStatement
from models.query_result import QueryResult
# =========================================================
# 1. 模型配置注册表
# =========================================================
MODEL_REGISTRY = {
"my-finetuned-sql": {
"name": "My Fine-Tuned SQL Model",
# 1. 填入云服务器地址 (保留 /v1/chat/completions)
"api_url": "http://1.92.127.206:8080/v1/chat/completions",
"model_id": "codellama/CodeLlama-13b-Instruct-hf",
# 2. 填入真实密钥
"api_key": "sk-2025texttosql",
"type": "local_finetune"
},
@ -77,144 +70,98 @@ DEFAULT_MODEL = "my-finetuned-sql"
# 2. 辅助函数
# =========================================================
def _json_serializable(obj):
"""处理 JSON 序列化时的非标准对象"""
if isinstance(obj, (datetime, date)):
return obj.isoformat()
raise TypeError(f"Type {type(obj)} not serializable")
async def _save_query_execution_result(
db: AsyncSession,
message_id: int,
sql_text: str,
sql_type: str,
data: List[Dict]
) -> None:
"""
AI 生成的 SQL 和执行结果快照持久化到元数据库中
"""
try:
# 1. 保存到 ai_generated_statement 表
serializable_data = json.loads(
json.dumps(data, default=_json_serializable)
)
statement = AIGeneratedStatement(
message_id=message_id,
statement_order=1,
sql_text=sql_text,
statement_type=sql_type,
execution_status="success" if data else "failed",
execution_result=serializable_data, # 使用处理后的数据
)
db.add(statement)
await db.flush()
if sql_type == "SELECT" and serializable_data:
q_result = QueryResult(
statement_id=statement.statement_id,
result_data=serializable_data, # 使用处理后的数据
data_summary=f"Snapshot for query: {sql_text[:50]}..."
)
db.add(q_result)
except Exception as e:
log.error(f"Failed to persist query result: {e}")
# 确保发生错误时能够清理事务状态
await db.rollback()
raise e
def _build_ai_messages(model_config: Dict, schema_text: str, question: str) -> List[Dict]:
"""构建高可读性、结构化的 Prompt 策略"""
# 定义清晰的系统指令
"""构建 Prompt 策略"""
base_system_instruction = """You are a specialized SQL generation assistant.
Your ONLY task is to generate valid SQL queries based on the provided database schema and user question.
[Constraints]
1. Output **ONLY** the SQL code. No explanations, no markdown (```sql).
2. If the user asks in Chinese, map it semantically to the English schema.
3. Use the exact table and column names from the schema.
4. If the question cannot be answered with the schema, return SELECT 'ERROR: Cannot answer';
5. Always use single quotes ('value') for string literals. NEVER use double quotes ("value").
IMPORTANT:
- For string literals, YOU MUST USE SINGLE QUOTES (').
- DO NOT use double quotes (") or quotes(`).
Examples:
Correct: SELECT * FROM users WHERE name = 'John';
Wrong: SELECT * FROM users WHERE name = "John";
... (保持原有约束) ...
"""
context_block = f"""[Database DDL]
{schema_text}"""
context_block = f"[Database DDL]\n{schema_text}"
if model_config["type"] == "local_finetune":
# 微调模型通常对 User 消息中的上下文反应更好
prompt_content = f"""{base_system_instruction}
{context_block}
### User Question
{question}
### SQL Query
"""
return [
{"role": "system", "content": "You are a SQL expert."},
{"role": "user", "content": prompt_content}
]
prompt_content = f"{base_system_instruction}\n{context_block}\n\n### User Question\n{question}\n\n### SQL Query\n"
return [{"role": "system", "content": "You are a SQL expert."}, {"role": "user", "content": prompt_content}]
else:
# 通用大模型
full_system_prompt = f"{base_system_instruction}\n{context_block}"
return [
{"role": "system", "content": full_system_prompt},
{"role": "user", "content": question}
]
return [{"role": "system", "content": full_system_prompt}, {"role": "user", "content": question}]
async def _verify_session_ownership(db: AsyncSession, session_id: int, user_id: int) -> int:
"""验证会话所有权,防止越权"""
stmt = (
select(SessionModel)
.options(selectinload(SessionModel.project))
.where(SessionModel.session_id == session_id)
)
"""权限校验"""
stmt = select(SessionModel).options(selectinload(SessionModel.project)).where(SessionModel.session_id == session_id)
result = await db.execute(stmt)
session = result.scalar_one_or_none()
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if not session.project:
raise HTTPException(status_code=404, detail="Project not found for this session")
if not session or not session.project:
raise HTTPException(status_code=404, detail="Session or Project not found")
if session.project.user_id != user_id:
log.warning(f"Security Alert: User {user_id} tried to access session {session_id} belonging to user {session.project.user_id}")
raise HTTPException(status_code=403, detail="Permission denied: You do not own this session")
raise HTTPException(status_code=403, detail="Permission denied")
return session.project_id
async def call_ai_agent(ddl_text: str, question: str, model_key: str = None) -> str:
"""调用 AI 接口生成 SQL"""
# 1. 确定配置
"""调用 AI 生成 SQL"""
if not model_key or model_key not in MODEL_REGISTRY:
model_key = DEFAULT_MODEL
config = MODEL_REGISTRY[model_key]
log.info(f"Using AI Model: {config['name']} ({config['model_id']})")
messages = _build_ai_messages(config, ddl_text, question)
# 2. 构建 Payload
payload = {
"model": config["model_id"],
"messages": messages,
"temperature": 0.1,
"stream": False,
"max_tokens": 512,
"stop": ["<|im_end|>", "<|im_start|>", "User:", "Assistant:"]
}
headers = {
"Authorization": f"Bearer {config['api_key']}",
"Content-Type": "application/json"
}
payload = {"model": config["model_id"], "messages": messages, "temperature": 0.1, "stream": False}
headers = {"Authorization": f"Bearer {config['api_key']}", "Content-Type": "application/json"}
try:
async with httpx.AsyncClient(timeout=300.0) as client:
# 【保留 UTF-8 修复】这很重要,防止中文乱码
resp = await client.post(
config["api_url"],
content=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
headers=headers
)
resp = await client.post(config["api_url"], content=json.dumps(payload, ensure_ascii=False).encode("utf-8"), headers=headers)
resp.raise_for_status()
raw = resp.json()
content = ""
if "choices" in raw and len(raw["choices"]) > 0:
content = raw["choices"][0]["message"]["content"]
# 清洗数据
# 1. 如果模型输出了停止符,只取前面的部分
if "<|im_end|>" in content:
content = content.split("<|im_end|>")[0]
if "<|im_start|>" in content:
content = content.split("<|im_start|>")[0]
# 2. 有时候模型会把 SQL 写在 Markdown 块里,先去 Markdown
clean_sql = content.strip().replace("```sql", "").replace("```", "").strip()
# 3. 如果还是有多行,且第一行就是完整的 SQL (以分号结尾),就只取第一行
# 防止它在 SQL 后面通过换行继续自言自语
if ";\n" in clean_sql:
clean_sql = clean_sql.split(";\n")[0] + ";"
elif clean_sql.count(";") > 1:
# 如果有多条 SQL只取第一条
clean_sql = clean_sql.split(";")[0] + ";"
return clean_sql
content = resp.json()["choices"][0]["message"]["content"]
return content.strip().replace("```sql", "").replace("```", "").strip()
except Exception as e:
log.error(f"AI Call Error ({model_key}): {e}")
log.error(f"AI Call Error: {e}")
return f"-- AI Service Error: {str(e)}"
# =========================================================
# 3. 核心业务逻辑 (融合版:含模型记忆 + 安全刹车)
# 3. 核心业务逻辑
# =========================================================
async def process_chat(
@ -224,214 +171,103 @@ async def process_chat(
user_id: int,
selected_model: str = None
) -> ChatResponse:
"""
处理用户聊天请求的主流程
"""
# 1. 验证会话权限
project_id = await _verify_session_ownership(db, session_id, user_id)
# 2. 获取 Session 对象以处理模型记忆逻辑
stmt = select(SessionModel).where(SessionModel.session_id == session_id)
result = await db.execute(stmt)
session_obj = result.scalar_one_or_none()
if not session_obj:
raise HTTPException(status_code=404, detail="Session lost")
# =========================================================
# 模型选择优先级策略
# =========================================================
final_model_key = DEFAULT_MODEL # 兜底
if selected_model:
# A. 如果用户本次明确指定了模型 -> 使用它,并更新到数据库(记忆)
final_model_key = selected_model
if session_obj.current_model != selected_model:
session_obj.current_model = selected_model
db.add(session_obj)
await db.commit() # 保存记忆
log.info(f"Session {session_id} model switched to: {selected_model}")
elif session_obj.current_model:
# B. 如果用户没指定,但数据库里有记忆 -> 使用记忆的模型
final_model_key = session_obj.current_model
log.info(f"Session {session_id} using stored model: {final_model_key}")
else:
session_obj.current_model = DEFAULT_MODEL
session_obj.current_model = selected_model
db.add(session_obj)
await db.commit()
final_model_key = session_obj.current_model or DEFAULT_MODEL
# 3. 存用户消息
await crud_message.create_message(db, session_id, user_input, role="user")
# 4. 获取 DDL 上下文
# 直接读取 project.ddl_statement不再使用 schema_definition 进行转换
project = await crud_project.get(db, project_id)
ddl_text = ""
if project and project.ddl_statement:
ddl_text = project.ddl_statement
log.info(f"Using DDL for project {project_id} (Length: {len(ddl_text)})")
else:
# 虽然假设 ddl_statement 一定不为空,但为了稳健性保留一个 Warning
log.warning(f"Project {project_id} has empty ddl_statement!")
ddl_text = "-- Error: No DDL found for this project."
# 5. 调用 AI 生成 SQL
ddl_text = project.ddl_statement if project and project.ddl_statement else "-- Error: No DDL"
sql_text = await call_ai_agent(ddl_text, user_input, model_key=final_model_key)
# 6. 解析 SQL 类型
sql_type = "UNKNOWN"
try:
if sql_text and not sql_text.startswith("--"):
parsed = sqlparse.parse(sql_text)
if parsed:
sql_type = parsed[0].get_type().upper()
except Exception:
pass
parsed = sqlparse.parse(sql_text)
if parsed: sql_type = parsed[0].get_type().upper()
except: pass
# 【关键融合 2】判断是否需要确认
requires_confirm = sql_type in ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE"]
# 7. 存 AI 回复消息
reply_content = f"已生成查询语句:\n{sql_text}"
ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant")
# 【关键融合 3】如果需要确认必须把状态写回数据库
# 同学的代码里漏了这一步,导致确认时会报 400
if requires_confirm:
ai_message.requires_confirmation = True
db.add(ai_message)
await db.commit()
await db.refresh(ai_message)
# 8. 尝试执行 SQL (带刹车)
data = []
# 【关键融合 4】只有不需要确认的操作才立即执行
# 同学的代码里直接执行了,这很危险
data = []
if not requires_confirm:
try:
database_instance = await crud_database_instance.get(db, project.instance_id)
# result 可能是一个列表(SELECT) 或 一个字典(INSERT/UPDATE)
raw_result = await execute_sql_with_user_check(sql_text, sql_type, database_instance)
data = raw_result if isinstance(raw_result, list) else ([raw_result] if raw_result else [])
# 【同学的优化】统一数据格式为 List[Dict]
if isinstance(raw_result, dict):
# 如果是 DML 返回的字典,包裹成列表
data = [raw_result]
elif isinstance(raw_result, list):
# 如果是 DQL 返回的列表,直接使用
data = raw_result
else:
data = []
# 【新增】非确认操作(如 SELECT立即持久化结果快照
if data:
await _save_query_execution_result(db, ai_message.message_id, sql_text, sql_type, data)
await db.commit()
except Exception as e:
log.info(f"SQL Execution Error (Safe to ignore if SQL is invalid): {str(e)}")
else:
# 如果需要确认,跳过执行
log.info(f"SQL requires confirmation ({sql_type}), skipping immediate execution.")
log.info(f"Execution Error: {str(e)}")
# 9. 构造响应
return ChatResponse(
message_id=ai_message.message_id,
content=reply_content,
message_type=MessageType.ASSISTANT,
sql_text=sql_text,
sql_type=sql_type,
requires_confirmation=requires_confirm,
data=data
message_id=ai_message.message_id, content=reply_content,
message_type=MessageType.ASSISTANT, sql_text=sql_text,
sql_type=sql_type, requires_confirmation=requires_confirm, data=data
)
# =========================================================
# 4. 执行确认逻辑 (同学代码里缺失,这里必须补上)
# =========================================================
async def confirm_and_execute_sql(
db: AsyncSession,
message_id: int,
user_id: int
) -> ChatResponse:
"""
用户确认执行某条消息中的 SQL (通常是增删改操作)
修复版使用分步查询法解决 AttributeError
"""
# 1. 第一步:查消息本身
"""用户确认执行某条消息中的 SQL"""
# 1. 查找消息、会话及项目
stmt = select(MessageModel).where(MessageModel.message_id == message_id)
result = await db.execute(stmt)
message = result.scalar_one_or_none()
if not message:
raise HTTPException(status_code=404, detail="Message not found")
# 2. 第二步:查所属 Session
stmt_session = select(SessionModel).where(SessionModel.session_id == message.session_id)
result_session = await db.execute(stmt_session)
session_obj = result_session.scalar_one_or_none()
if not session_obj:
raise HTTPException(status_code=404, detail="Session not found")
# 3. 第三步:查所属 Project
stmt_project = select(ProjectModel).where(ProjectModel.project_id == session_obj.project_id)
result_project = await db.execute(stmt_project)
project = result_project.scalar_one_or_none()
session_obj = (await db.execute(stmt_session)).scalar_one_or_none()
project = await crud_project.get(db, session_obj.project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# --- 校验逻辑 ---
if project.user_id != user_id:
raise HTTPException(status_code=403, detail="Access denied")
if not message.requires_confirmation or message.user_confirmed:
raise HTTPException(status_code=400, detail="Invalid confirmation request")
if not message.requires_confirmation:
raise HTTPException(status_code=400, detail="This message does not require confirmation")
if message.user_confirmed:
raise HTTPException(status_code=400, detail="This operation has already been confirmed/executed")
# 4. 提取 SQL 语句
content = message.content
sql_text = ""
if "\n" in content:
sql_text = content.split("\n")[-1].strip()
else:
sql_text = content.strip()
# 2. 提取并清理 SQL
sql_text = message.content.split("\n")[-1].strip() if "\n" in message.content else message.content.strip()
sql_text = sql_text.replace("```sql", "").replace("```", "").strip()
# 5. 执行 SQL (DML)
execute_res = []
try:
# 3. 执行 DML
database_instance = await crud_database_instance.get(db, project.instance_id)
raw_result = await execute_sql_with_user_check(sql_text, "DML", database_instance)
execute_res = raw_result if isinstance(raw_result, list) else ([raw_result] if raw_result else [])
# 真正执行
result = await execute_sql_with_user_check(sql_text, "UPDATE", database_instance)
# 📦 统一包装成列表
if isinstance(result, dict):
execute_res = [result]
elif isinstance(result, list):
execute_res = result
else:
execute_res = []
# 【新增】确认执行后,将变更影响的快照持久化
await _save_query_execution_result(db, message.message_id, sql_text, "DML_EXECUTED", execute_res)
# 更新消息状态
message.user_confirmed = True
db.add(message)
await db.commit()
log.info(f"User {user_id} confirmed execution of message {message_id}")
except Exception as e:
log.error(f"Execution failed: {e}")
raise HTTPException(status_code=500, detail=f"Execution failed: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# 6. 返回结果
return ChatResponse(
message_id=message.message_id,
content=message.content,
message_type=MessageType.ASSISTANT,
sql_text=sql_text,
sql_type="DML_EXECUTED",
requires_confirmation=False,
data=execute_res
message_id=message.message_id, content=message.content,
message_type=MessageType.ASSISTANT, sql_text=sql_text,
sql_type="DML_EXECUTED", requires_confirmation=False, data=execute_res
)

@ -123,39 +123,21 @@ async def get_history_queries_service(
project_id: int,
user_id: int,
) -> List[schemas.HistoryQuery]:
await _verify_project_ownership(db, project_id, user_id)
"""
获取项目历史查询记录返回用户原始提问与结果
Args:
db (Session): 数据库会话
project_id (int): 项目 ID
Returns:
List[schemas.HistoryQuery]: 历史查询记录列表
Raises:
HTTPException: 项目不存在时返回 404
获取项目历史查询记录仅限已持久化的 SELECT 结果
"""
# [修复问题1]:先检查项目是否存在
project = await db.get(Project, project_id)
if not project:
raise HTTPException(status_code=404, detail=f"Project with ID {project_id} not found")
await _verify_project_ownership(db, project_id, user_id)
# [修复问题2]:局部导入以避免 UnboundLocalError 和循环依赖
from models.message import Message
from models.session import Session as SessionModel
# 构造查询Query Result -> Statement -> Message -> Session
# 我们需要 Message.content (用户问题) 和 QueryResult (数据)
# 路径QueryResult -> Statement -> Message (Assistant)
stmt = (
select(QueryResult, Message)
.join(AIGeneratedStatement, QueryResult.statement_id == AIGeneratedStatement.statement_id)
.join(Message, AIGeneratedStatement.message_id == Message.message_id)
.join(SessionModel, Message.session_id == SessionModel.session_id)
.where(SessionModel.project_id == project_id)
.where(Message.message_type == 'assistant') # 确保我们取的是用户发的消息(提问)
.order_by(QueryResult.cached_at.desc())
)
@ -164,13 +146,13 @@ async def get_history_queries_service(
history = []
for query_res, msg_obj in rows:
# 这里返回的是 QueryResult 的 result_id前端将用这个 ID 创建报表
history.append(schemas.HistoryQuery(
id=str(query_res.result_id),
projectId=str(project_id),
# 这里取的是 message.content即用户的原始提问
queryText=msg_obj.content,
queryText=msg_obj.content, # 显示 AI 的描述或 SQL
timestamp=query_res.cached_at.isoformat() if query_res.cached_at else 'N/A',
result=query_res.result_data
result=query_res.result_data # 包含完整的 JSON 结果集
))
return history

Loading…
Cancel
Save