|
|
|
|
@ -7,6 +7,7 @@
|
|
|
|
|
- 提示词工程 (Prompt Engineering)
|
|
|
|
|
- 会话所有权校验
|
|
|
|
|
- AI 响应的解析与格式化
|
|
|
|
|
- 会话模型记忆逻辑
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# backend/app/service/chat_service.py
|
|
|
|
|
@ -35,36 +36,37 @@ from service.mysql_service import execute_sql_with_user_check
|
|
|
|
|
# =========================================================
|
|
|
|
|
# 1. 模型配置注册表
|
|
|
|
|
# =========================================================
|
|
|
|
|
# 建议:长期来看,这些配置也可以移入数据库或 YAML,但目前作为常量定义在 Service 层是可以接受的
|
|
|
|
|
# 重点是 API Key 必须从 settings 读取
|
|
|
|
|
|
|
|
|
|
import os  # used only for the environment lookup of the fine-tuned model's key

# Registry of the chat-completion back ends this service can call.
# Each entry maps a registry key to:
#   name     - human-readable label written to the log
#   api_url  - chat-completions endpoint that receives the POST request
#   model_id - value sent as the "model" field of the request payload
#   api_key  - bearer token placed in the Authorization header
#   type     - "local_finetune" or "general_llm"; selects the prompt strategy
#              and the max_tokens budget
MODEL_REGISTRY = {
    "my-finetuned-sql": {
        "name": "My Fine-Tuned SQL Model",
        # Cloud server hosting the fine-tuned model (keep the
        # /v1/chat/completions suffix).
        # NOTE(review): this is plain HTTP, so the bearer token is sent
        # unencrypted - confirm that is acceptable for this endpoint.
        "api_url": "http://1.92.127.206:8080/v1/chat/completions",
        "model_id": "codellama/CodeLlama-13b-Instruct-hf",
        # API keys must not live in source control. The key is read from the
        # FINETUNED_SQL_API_KEY environment variable; "dummy-key" is only a
        # placeholder for servers that do not check the token.
        "api_key": os.environ.get("FINETUNED_SQL_API_KEY", "dummy-key"),
        "type": "local_finetune",
    },

    "xiyan-sql": {
        "name": "XiYan-SQL (QwenCoder-32B)",
        "api_url": "https://api-inference.modelscope.cn/v1/chat/completions",
        "model_id": "XGenerationLab/XiYanSQL-QwenCoder-32B-2504",
        # ModelScope key comes from application settings.
        "api_key": settings.ai.modelscope_api_key,
        "type": "general_llm",
    },
    "qwen-coder-32b": {
        "name": "Qwen2.5-Coder-32B",
        "api_url": "https://api-inference.modelscope.cn/v1/chat/completions",
        "model_id": "Qwen/Qwen2.5-Coder-32B-Instruct",
        "api_key": settings.ai.modelscope_api_key,
        "type": "general_llm",
    },
    "deepseek-v3": {
        "name": "DeepSeek V3.1",
        "api_url": "https://api-inference.modelscope.cn/v1/chat/completions",
        "model_id": "deepseek-ai/DeepSeek-V3.1",
        "api_key": settings.ai.modelscope_api_key,
        "type": "general_llm",
    },
}
|
|
|
|
|
@ -72,43 +74,27 @@ MODEL_REGISTRY = {
|
|
|
|
|
# Registry key used when the caller passes no model key or one that is not
# present in MODEL_REGISTRY.
DEFAULT_MODEL = "my-finetuned-sql"
|
|
|
|
|
|
|
|
|
|
# =========================================================
|
|
|
|
|
# 2. 辅助函数 (逻辑拆分)
|
|
|
|
|
# 2. 辅助函数
|
|
|
|
|
# =========================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _format_schema_to_text(schema_data: Any) -> str:
|
|
|
|
|
"""
|
|
|
|
|
将 Schema JSON 转换为模型易读的文本格式。
|
|
|
|
|
|
|
|
|
|
兼容 List 和 Dict 两种结构。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
schema_data (Any): 原始 Schema 数据 (List or Dict)。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
str: 格式化后的 Schema 文本描述。
|
|
|
|
|
"""
|
|
|
|
|
"""将 Schema JSON 转换为模型易读的文本格式 (兼容 List/Dict)"""
|
|
|
|
|
if not schema_data:
|
|
|
|
|
return ""
|
|
|
|
|
try:
|
|
|
|
|
# 1. 如果是字符串,先转成对象
|
|
|
|
|
if isinstance(schema_data, str):
|
|
|
|
|
schema_data = json.loads(schema_data)
|
|
|
|
|
|
|
|
|
|
# 2. 如果是字典且包含 'tables' 键,提取出列表
|
|
|
|
|
if isinstance(schema_data, dict) and "tables" in schema_data:
|
|
|
|
|
schema_data = schema_data["tables"]
|
|
|
|
|
|
|
|
|
|
# 3. 现在的 schema_data 应该是一个列表了,开始遍历
|
|
|
|
|
lines = []
|
|
|
|
|
if isinstance(schema_data, list):
|
|
|
|
|
for table in schema_data:
|
|
|
|
|
# 兼容 table 可能是 dict 或者 object 的情况
|
|
|
|
|
if isinstance(table, dict):
|
|
|
|
|
t_name = table.get("table_name", "unknown")
|
|
|
|
|
cols = table.get("columns", [])
|
|
|
|
|
else:
|
|
|
|
|
# 万一数据很怪,做一个容错
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if isinstance(cols, str):
|
|
|
|
|
@ -116,31 +102,16 @@ def _format_schema_to_text(schema_data: Any) -> str:
|
|
|
|
|
col_str = ", ".join(str(c) for c in cols)
|
|
|
|
|
lines.append(f"Table: {t_name}, columns = [{col_str}]")
|
|
|
|
|
else:
|
|
|
|
|
# 如果结构实在太乱,直接转字符串兜底
|
|
|
|
|
return str(schema_data)
|
|
|
|
|
|
|
|
|
|
return "\n".join(lines)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
log.error(f"Schema format error: {e}")
|
|
|
|
|
# 出错了也不要崩,把原始数据给 AI,看它能不能看懂
|
|
|
|
|
return str(schema_data)
|
|
|
|
|
|
|
|
|
|
def _build_ai_messages(model_config: Dict, schema_text: str, question: str) -> List[Dict]:
|
|
|
|
|
"""
|
|
|
|
|
根据模型类型构建对应的 Prompt 策略。
|
|
|
|
|
|
|
|
|
|
解决“函数过长”问题,将 Prompt 逻辑抽离。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
model_config (Dict): 模型配置字典。
|
|
|
|
|
schema_text (str): Schema 文本描述。
|
|
|
|
|
question (str): 用户问题。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List[Dict]: 构建好的消息列表 (role/content)。
|
|
|
|
|
"""
|
|
|
|
|
"""构建 Prompt 策略"""
|
|
|
|
|
if model_config["type"] == "local_finetune":
|
|
|
|
|
# 微调模型 (严格格式)
|
|
|
|
|
prompt_content = f"""I want you to act as a SQL terminal in front of an database.
|
|
|
|
|
Here is the schema:
|
|
|
|
|
{schema_text}
|
|
|
|
|
@ -155,7 +126,6 @@ I want you to answer the following question.
|
|
|
|
|
{"role": "user", "content": prompt_content}
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
# 通用大模型 (思维链与规则引导)
|
|
|
|
|
system_prompt = f"""You are a generic SQL expert.
|
|
|
|
|
Your task is to generate valid SQL queries based on the provided database schema and user question.
|
|
|
|
|
|
|
|
|
|
@ -175,24 +145,10 @@ Your task is to generate valid SQL queries based on the provided database schema
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
async def _verify_session_ownership(db: AsyncSession, session_id: int, user_id: int) -> int:
|
|
|
|
|
"""
|
|
|
|
|
验证会话所有权,防止越权访问 (IDOR)。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
db (AsyncSession): 数据库会话。
|
|
|
|
|
session_id (int): 会话 ID。
|
|
|
|
|
user_id (int): 用户 ID。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
int: 关联的项目 ID (project_id)。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
HTTPException: 会话不存在、项目不存在或无权限时抛出。
|
|
|
|
|
"""
|
|
|
|
|
# 联表查询:Session -> Project,检查 Project.user_id 是否匹配
|
|
|
|
|
"""验证会话所有权,防止越权"""
|
|
|
|
|
stmt = (
|
|
|
|
|
select(SessionModel)
|
|
|
|
|
.options(selectinload(SessionModel.project)) # 预加载 Project 避免 N+1
|
|
|
|
|
.options(selectinload(SessionModel.project))
|
|
|
|
|
.where(SessionModel.session_id == session_id)
|
|
|
|
|
)
|
|
|
|
|
result = await db.execute(stmt)
|
|
|
|
|
@ -204,29 +160,15 @@ async def _verify_session_ownership(db: AsyncSession, session_id: int, user_id:
|
|
|
|
|
if not session.project:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Project not found for this session")
|
|
|
|
|
|
|
|
|
|
# 【关键安全检查】
|
|
|
|
|
if session.project.user_id != user_id:
|
|
|
|
|
log.warning(f"Security Alert: User {user_id} tried to access session {session_id} belonging to user {session.project.user_id}")
|
|
|
|
|
raise HTTPException(status_code=403, detail="Permission denied: You do not own this session")
|
|
|
|
|
|
|
|
|
|
return session.project_id
|
|
|
|
|
|
|
|
|
|
# =========================================================
|
|
|
|
|
# 3. 核心业务逻辑
|
|
|
|
|
# =========================================================
|
|
|
|
|
|
|
|
|
|
def _first_sql_statement(sql: str) -> str:
    """Return *sql* cut off right after its first statement-ending semicolon.

    Semicolons inside single-quoted, double-quoted or backtick-quoted spans
    are not treated as terminators, so a literal such as 'a;b' stays intact.
    Inside single and double quotes a backslash escapes the next character
    (assumes MySQL-style escaping - TODO confirm against the target database).

    If the quotes never balance (for example an apostrophe in leading prose),
    the text is cut at the first semicolon of any kind. If there is no
    semicolon at all, *sql* is returned unchanged.
    """
    open_quote = ""
    skip_next = False
    for pos, ch in enumerate(sql):
        if skip_next:
            skip_next = False
        elif open_quote:
            if ch == "\\" and open_quote != "`":
                skip_next = True
            elif ch == open_quote:
                open_quote = ""
        elif ch in "'\"`":
            open_quote = ch
        elif ch == ";":
            return sql[: pos + 1]

    if open_quote and ";" in sql:
        # Unbalanced quoting: fall back to the plain "first semicolon" cut.
        return sql.split(";")[0] + ";"
    return sql


def _clean_model_output(content: str) -> str:
    """Reduce raw model output to a single SQL statement.

    Steps: drop everything from the first chat stop marker onwards, remove
    Markdown code-fence markers (with or without a "sql" tag, in any letter
    case), then keep only the first statement.
    """
    import re

    # The model may echo its stop markers; only the text before them counts.
    for marker in ("<|im_end|>", "<|im_start|>"):
        content = content.split(marker)[0]

    # Case-insensitive so that a "```SQL" fence does not leave a stray "SQL".
    without_fences = re.sub(r"```(?:sql)?", "", content.strip(), flags=re.IGNORECASE)

    # Anything after the first statement is extra statements or commentary.
    return _first_sql_statement(without_fences.strip())


async def call_ai_agent(schema_text: str, question: str, model_key: str = None) -> str:
    """Call the selected chat-completion endpoint and return the generated SQL.

    Args:
        schema_text: Schema description handed to the prompt builder.
        question: The user's question handed to the prompt builder.
        model_key: Key into MODEL_REGISTRY. A missing or unknown key falls
            back to DEFAULT_MODEL.

    Returns:
        The cleaned model output: at most one SQL statement, or an empty
        string when the response carries no choices or no content. If the
        request or the response handling raises, the error is logged and a
        string starting with "-- AI Service Error: " is returned instead.
    """
    # 1. Resolve the model configuration.
    if not model_key or model_key not in MODEL_REGISTRY:
        model_key = DEFAULT_MODEL

    config = MODEL_REGISTRY[model_key]
    log.info(f"Using AI Model: {config['name']} ({config['model_id']})")

    messages = _build_ai_messages(config, schema_text, question)

    # 2. Build the request payload.
    payload = {
        "model": config["model_id"],
        "messages": messages,
        "temperature": 0.1,
        "stream": False,
        "max_tokens": 512,
        # Ask the server to stop generating once it emits one of these markers.
        "stop": ["<|im_end|>", "<|im_start|>", "User:", "Assistant:"],
    }

    # General-purpose models get a larger completion budget.
    if config["type"] == "general_llm":
        payload["max_tokens"] = 1024

    headers = {
        "Authorization": f"Bearer {config['api_key']}",
        "Content-Type": "application/json",
    }

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(config["api_url"], json=payload, headers=headers)

            resp.raise_for_status()
            raw = resp.json()

        content = ""
        if raw.get("choices"):
            # The API may return a null "content"; treat it like an empty
            # reply instead of failing with a TypeError further down.
            content = raw["choices"][0]["message"]["content"] or ""

        return _clean_model_output(content)

    except Exception as e:
        log.error(f"AI Call Error ({model_key}): {e}")
        # Returning a SQL comment lets the caller show the failure to the
        # user instead of failing the whole request.
        return f"-- AI Service Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
# =========================================================
|
|
|
|
|
# 3. 核心业务逻辑 (含模型记忆)
|
|
|
|
|
# =========================================================
|
|
|
|
|
|
|
|
|
|
async def process_chat(
|
|
|
|
|
db: AsyncSession,
|
|
|
|
|
@ -285,34 +244,59 @@ async def process_chat(
|
|
|
|
|
) -> ChatResponse:
|
|
|
|
|
"""
|
|
|
|
|
处理用户聊天请求的主流程。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
db (AsyncSession): 数据库会话。
|
|
|
|
|
session_id (int): 会话 ID。
|
|
|
|
|
user_input (str): 用户输入。
|
|
|
|
|
user_id (int): 用户 ID。
|
|
|
|
|
selected_model (str, optional): 选择的模型。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
ChatResponse: 聊天响应对象。
|
|
|
|
|
"""
|
|
|
|
|
# 1. 安全检查:确认会话属于当前用户,并获取 project_id
|
|
|
|
|
# 1. 验证会话权限
|
|
|
|
|
project_id = await _verify_session_ownership(db, session_id, user_id)
|
|
|
|
|
|
|
|
|
|
# 2. 存用户消息 (先存库,保证有记录)
|
|
|
|
|
# 2. 获取 Session 对象以处理模型记忆逻辑
|
|
|
|
|
stmt = select(SessionModel).where(SessionModel.session_id == session_id)
|
|
|
|
|
result = await db.execute(stmt)
|
|
|
|
|
session_obj = result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
if not session_obj:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Session lost")
|
|
|
|
|
|
|
|
|
|
# =========================================================
|
|
|
|
|
# 【核心逻辑优化】模型选择优先级策略
|
|
|
|
|
# =========================================================
|
|
|
|
|
final_model_key = DEFAULT_MODEL # 兜底
|
|
|
|
|
|
|
|
|
|
if selected_model:
|
|
|
|
|
# A. 如果用户本次明确指定了模型 -> 使用它,并更新到数据库(记忆)
|
|
|
|
|
final_model_key = selected_model
|
|
|
|
|
if session_obj.current_model != selected_model:
|
|
|
|
|
session_obj.current_model = selected_model
|
|
|
|
|
db.add(session_obj)
|
|
|
|
|
await db.commit() # 保存记忆
|
|
|
|
|
log.info(f"Session {session_id} model switched to: {selected_model}")
|
|
|
|
|
|
|
|
|
|
elif session_obj.current_model:
|
|
|
|
|
# B. 如果用户没指定,但数据库里有记忆 -> 使用记忆的模型
|
|
|
|
|
final_model_key = session_obj.current_model
|
|
|
|
|
log.info(f"Session {session_id} using stored model: {final_model_key}")
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
# C. 既没指定也没记忆 -> 使用系统默认,并保存到数据库作为初始记忆
|
|
|
|
|
final_model_key = DEFAULT_MODEL
|
|
|
|
|
session_obj.current_model = DEFAULT_MODEL
|
|
|
|
|
db.add(session_obj)
|
|
|
|
|
await db.commit()
|
|
|
|
|
# =========================================================
|
|
|
|
|
|
|
|
|
|
# 3. 存用户消息
|
|
|
|
|
await crud_message.create_message(db, session_id, user_input, role="user")
|
|
|
|
|
|
|
|
|
|
# 3. 获取 Schema 上下文
|
|
|
|
|
# 4. 获取 Schema 上下文
|
|
|
|
|
project = await crud_project.get(db, project_id)
|
|
|
|
|
if not project or not project.schema_definition:
|
|
|
|
|
schema_text = "No schema defined."
|
|
|
|
|
else:
|
|
|
|
|
schema_text = _format_schema_to_text(project.schema_definition)
|
|
|
|
|
|
|
|
|
|
# 4. 调用 AI 生成 SQL
|
|
|
|
|
sql_text = await call_ai_agent(schema_text, user_input, model_key=selected_model)
|
|
|
|
|
# 5. 调用 AI 生成 SQL (使用记忆或指定的模型)
|
|
|
|
|
sql_text = await call_ai_agent(schema_text, user_input, model_key=final_model_key)
|
|
|
|
|
|
|
|
|
|
# 5. 简单解析 SQL 类型
|
|
|
|
|
# 6. 简单解析 SQL 类型
|
|
|
|
|
sql_type = "UNKNOWN"
|
|
|
|
|
try:
|
|
|
|
|
if sql_text and not sql_text.startswith("--"):
|
|
|
|
|
@ -322,20 +306,20 @@ async def process_chat(
|
|
|
|
|
except Exception:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
# 6. 存 AI 回复消息
|
|
|
|
|
# 7. 存 AI 回复消息
|
|
|
|
|
reply_content = f"已生成查询语句:\n{sql_text}"
|
|
|
|
|
ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant")
|
|
|
|
|
|
|
|
|
|
# 8. 尝试执行 SQL
|
|
|
|
|
data = []
|
|
|
|
|
# 7. 尝试执行 SQL
|
|
|
|
|
try:
|
|
|
|
|
database_instance = await crud_database_instance.get(db, project.instance_id)
|
|
|
|
|
|
|
|
|
|
# 注意:这里调用的是 execute_sql_with_user_check,它会检查 SQL 是否安全
|
|
|
|
|
data = await execute_sql_with_user_check(sql_text, sql_type, database_instance)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
log.info(f"SQL Execution Error: {str(e)}", exc_info=True)
|
|
|
|
|
log.info(f"SQL Execution Error (Safe to ignore if SQL is invalid): {str(e)}")
|
|
|
|
|
|
|
|
|
|
# 8. 构造响应
|
|
|
|
|
# 9. 构造响应
|
|
|
|
|
return ChatResponse(
|
|
|
|
|
message_id=ai_message.message_id,
|
|
|
|
|
content=reply_content,
|
|
|
|
|
|