feat: 增加AI模型选择功能与会话记忆,对接云端模型 #31

Merged
hnu202326010328 merged 1 commits from wanglirong_branch into develop 2 months ago

@ -10,6 +10,7 @@ from sqlalchemy import Column, String, CheckConstraint, Index, Integer, ForeignK
from sqlalchemy.sql import func
from core.database import Base
from sqlalchemy.orm import relationship
class Session(Base):
"""
会话表 ORM 模型
@ -20,6 +21,7 @@ class Session(Base):
session_id (int): 会话ID
project_id (int): 关联的项目ID
session_name (str): 会话名称
current_model (str): 当前会话偏好的AI模型ID (新增)
created_at (datetime): 创建时间
last_activity (datetime): 最后活动时间
"""
@ -43,14 +45,20 @@ class Session(Base):
nullable=False,
comment='关联的项目ID'
)
# 反向关联:让 Session 知道它属于哪个 Project
project = relationship("Project", back_populates="sessions")
session_name = Column(
String(100),
nullable=False,
default='New Session',
comment='会话名称'
)
# 【新增字段】记录当前会话使用的模型
current_model = Column(
String(50),
nullable=True,
comment="当前会话偏好的AI模型ID"
)
created_at = Column(
DateTime(timezone=True),
server_default=func.now(),
@ -63,5 +71,12 @@ class Session(Base):
nullable=True,
comment='最后活动时间'
)
# 关联关系
# 反向关联:让 Session 知道它属于哪个 Project
project = relationship("Project", back_populates="sessions")
# messages = relationship("Message", back_populates="session", cascade="all, delete-orphan") # 如果你有 Message 模型的话
class Config:
from_attributes = True

@ -7,6 +7,7 @@
- 提示词工程 (Prompt Engineering)
- 会话所有权校验
- AI 响应的解析与格式化
- 会话模型记忆逻辑
"""
# backend/app/service/chat_service.py
@ -35,36 +36,37 @@ from service.mysql_service import execute_sql_with_user_check
# =========================================================
# 1. 模型配置注册表
# =========================================================
# 建议:长期来看,这些配置也可以移入数据库或 YAML,但目前作为常量定义在 Service 层是可以接受的
# 重点是 API Key 必须从 settings 读取
MODEL_REGISTRY = {
"my-finetuned-sql": {
"name": "My Fine-Tuned SQL Model",
"api_url": "http://26.64.77.145:1234/v1/chat/completions", # 这里的IP如果是固定的可以留着,如果是变动的建议放config
# 1. 填入云服务器地址 (保留 /v1/chat/completions)
"api_url": "http://1.92.127.206:8080/v1/chat/completions",
"model_id": "codellama/CodeLlama-13b-Instruct-hf",
"api_key": "dummy-key", # 本地模型通常不需要 Key
# 2. 填入真实密钥
"api_key": "sk-2025texttosql",
"type": "local_finetune"
},
"xiyan-sql": {
"name": "XiYan-SQL (QwenCoder-32B)",
"api_url": "https://api-inference.modelscope.cn/v1/chat/completions",
"model_id": "XGenerationLab/XiYanSQL-QwenCoder-32B-2504",
"api_key": settings.ai.modelscope_api_key, # <--- 从配置读取
"api_key": settings.ai.modelscope_api_key,
"type": "general_llm"
},
"qwen-coder-32b": {
"name": "Qwen2.5-Coder-32B",
"api_url": "https://api-inference.modelscope.cn/v1/chat/completions",
"model_id": "Qwen/Qwen2.5-Coder-32B-Instruct",
"api_key": settings.ai.modelscope_api_key, # <--- 从配置读取
"api_key": settings.ai.modelscope_api_key,
"type": "general_llm"
},
"deepseek-v3": {
"name": "DeepSeek V3.1",
"api_url": "https://api-inference.modelscope.cn/v1/chat/completions",
"model_id": "deepseek-ai/DeepSeek-V3.1",
"api_key": settings.ai.modelscope_api_key, # <--- 从配置读取
"api_key": settings.ai.modelscope_api_key,
"type": "general_llm"
}
}
@ -72,43 +74,27 @@ MODEL_REGISTRY = {
DEFAULT_MODEL = "my-finetuned-sql"
# =========================================================
# 2. 辅助函数 (逻辑拆分)
# 2. 辅助函数
# =========================================================
def _format_schema_to_text(schema_data: Any) -> str:
"""
将 Schema JSON 转换为模型易读的文本格式
兼容 List 和 Dict 两种结构
Args:
schema_data (Any): 原始 Schema 数据 (List or Dict)
Returns:
str: 格式化后的 Schema 文本描述
"""
"""将 Schema JSON 转换为模型易读的文本格式 (兼容 List/Dict)"""
if not schema_data:
return ""
try:
# 1. 如果是字符串,先转成对象
if isinstance(schema_data, str):
schema_data = json.loads(schema_data)
# 2. 如果是字典且包含 'tables' 键,提取出列表
if isinstance(schema_data, dict) and "tables" in schema_data:
schema_data = schema_data["tables"]
# 3. 现在的 schema_data 应该是一个列表了,开始遍历
lines = []
if isinstance(schema_data, list):
for table in schema_data:
# 兼容 table 可能是 dict 或者 object 的情况
if isinstance(table, dict):
t_name = table.get("table_name", "unknown")
cols = table.get("columns", [])
else:
# 万一数据很怪,做一个容错
continue
if isinstance(cols, str):
@ -116,31 +102,16 @@ def _format_schema_to_text(schema_data: Any) -> str:
col_str = ", ".join(str(c) for c in cols)
lines.append(f"Table: {t_name}, columns = [{col_str}]")
else:
# 如果结构实在太乱,直接转字符串兜底
return str(schema_data)
return "\n".join(lines)
except Exception as e:
log.error(f"Schema format error: {e}")
# 出错了也不要崩,把原始数据给 AI,看它能不能看懂
return str(schema_data)
def _build_ai_messages(model_config: Dict, schema_text: str, question: str) -> List[Dict]:
"""
根据模型类型构建对应的 Prompt 策略
解决函数过长问题,将 Prompt 逻辑抽离
Args:
model_config (Dict): 模型配置字典
schema_text (str): Schema 文本描述
question (str): 用户问题
Returns:
List[Dict]: 构建好的消息列表 (role/content)
"""
"""构建 Prompt 策略"""
if model_config["type"] == "local_finetune":
# 微调模型 (严格格式)
prompt_content = f"""I want you to act as a SQL terminal in front of an database.
Here is the schema:
{schema_text}
@ -155,7 +126,6 @@ I want you to answer the following question.
{"role": "user", "content": prompt_content}
]
else:
# 通用大模型 (思维链与规则引导)
system_prompt = f"""You are a generic SQL expert.
Your task is to generate valid SQL queries based on the provided database schema and user question.
@ -175,24 +145,10 @@ Your task is to generate valid SQL queries based on the provided database schema
]
async def _verify_session_ownership(db: AsyncSession, session_id: int, user_id: int) -> int:
"""
验证会话所有权,防止越权访问 (IDOR)
Args:
db (AsyncSession): 数据库会话
session_id (int): 会话 ID
user_id (int): 用户 ID
Returns:
int: 关联的项目 ID (project_id)
Raises:
HTTPException: 会话不存在、项目不存在或无权限时抛出
"""
# 联表查询:Session -> Project,检查 Project.user_id 是否匹配
"""验证会话所有权,防止越权"""
stmt = (
select(SessionModel)
.options(selectinload(SessionModel.project)) # 预加载 Project 避免 N+1
.options(selectinload(SessionModel.project))
.where(SessionModel.session_id == session_id)
)
result = await db.execute(stmt)
@ -204,29 +160,15 @@ async def _verify_session_ownership(db: AsyncSession, session_id: int, user_id:
if not session.project:
raise HTTPException(status_code=404, detail="Project not found for this session")
# 【关键安全检查】
if session.project.user_id != user_id:
log.warning(f"Security Alert: User {user_id} tried to access session {session_id} belonging to user {session.project.user_id}")
raise HTTPException(status_code=403, detail="Permission denied: You do not own this session")
return session.project_id
# =========================================================
# 3. 核心业务逻辑
# =========================================================
async def call_ai_agent(schema_text: str, question: str, model_key: str = None) -> str:
"""
调用 AI 接口生成 SQL
Args:
schema_text (str): 数据库 Schema 描述
question (str): 用户问题
model_key (str, optional): 模型标识 Key
Returns:
str: 生成的 SQL 语句
"""
"""调用 AI 接口生成 SQL"""
# 1. 确定配置
if not model_key or model_key not in MODEL_REGISTRY:
model_key = DEFAULT_MODEL
@ -234,47 +176,64 @@ async def call_ai_agent(schema_text: str, question: str, model_key: str = None)
config = MODEL_REGISTRY[model_key]
log.info(f"Using AI Model: {config['name']} ({config['model_id']})")
# 2. 构建 Prompt
messages = _build_ai_messages(config, schema_text, question)
# 3. 构建请求 Payload
# 2. 构建 Payload
payload = {
"model": config["model_id"],
"messages": messages,
"temperature": 0.1,
"stream": False
"stream": False,
"max_tokens": 512,
# 【新增】告诉模型看到这些符号就闭嘴
"stop": ["<|im_end|>", "<|im_start|>", "User:", "Assistant:"]
}
if config["type"] == "general_llm":
payload["max_tokens"] = 1024
headers = {
"Authorization": f"Bearer {config['api_key']}",
"Content-Type": "application/json"
}
# 4. 执行网络请求
try:
async with httpx.AsyncClient(timeout=60) as client:
async with httpx.AsyncClient(timeout=300.0) as client:
resp = await client.post(config["api_url"], json=payload, headers=headers)
resp.raise_for_status()
raw = resp.json()
# 解析响应
content = ""
if "choices" in raw and len(raw["choices"]) > 0:
content = raw["choices"][0]["message"]["content"]
# 清理结果
# =====================================================
# 【关键修复】清洗数据,截断废话
# =====================================================
# 1. 如果模型输出了停止符,只取前面的部分
if "<|im_end|>" in content:
content = content.split("<|im_end|>")[0]
if "<|im_start|>" in content:
content = content.split("<|im_start|>")[0]
# 2. 有时候模型会把 SQL 写在 Markdown 块里,先去 Markdown
clean_sql = content.strip().replace("```sql", "").replace("```", "").strip()
# 3. 如果还是有多行,且第一行就是完整的 SQL (以分号结尾),就只取第一行
# 防止它在 SQL 后面通过换行继续自言自语
if ";\n" in clean_sql:
clean_sql = clean_sql.split(";\n")[0] + ";"
elif clean_sql.count(";") > 1:
# 如果有多条 SQL,只取第一条
clean_sql = clean_sql.split(";")[0] + ";"
return clean_sql
except Exception as e:
log.error(f"AI Call Error ({model_key}): {e}")
# 这里返回错误字符串是可以的,让用户知道 AI 挂了,而不是整个页面崩溃
return f"-- AI Service Error: {str(e)}"
# =========================================================
# 3. 核心业务逻辑 (含模型记忆)
# =========================================================
async def process_chat(
db: AsyncSession,
@ -285,34 +244,59 @@ async def process_chat(
) -> ChatResponse:
"""
处理用户聊天请求的主流程
Args:
db (AsyncSession): 数据库会话
session_id (int): 会话 ID
user_input (str): 用户输入
user_id (int): 用户 ID
selected_model (str, optional): 选择的模型
Returns:
ChatResponse: 聊天响应对象
"""
# 1. 安全检查:确认会话属于当前用户,并获取 project_id
# 1. 验证会话权限
project_id = await _verify_session_ownership(db, session_id, user_id)
# 2. 存用户消息 (先存库,保证有记录)
# 2. 获取 Session 对象以处理模型记忆逻辑
stmt = select(SessionModel).where(SessionModel.session_id == session_id)
result = await db.execute(stmt)
session_obj = result.scalar_one_or_none()
if not session_obj:
raise HTTPException(status_code=404, detail="Session lost")
# =========================================================
# 【核心逻辑优化】模型选择优先级策略
# =========================================================
final_model_key = DEFAULT_MODEL # 兜底
if selected_model:
# A. 如果用户本次明确指定了模型 -> 使用它,并更新到数据库(记忆)
final_model_key = selected_model
if session_obj.current_model != selected_model:
session_obj.current_model = selected_model
db.add(session_obj)
await db.commit() # 保存记忆
log.info(f"Session {session_id} model switched to: {selected_model}")
elif session_obj.current_model:
# B. 如果用户没指定,但数据库里有记忆 -> 使用记忆的模型
final_model_key = session_obj.current_model
log.info(f"Session {session_id} using stored model: {final_model_key}")
else:
# C. 既没指定也没记忆 -> 使用系统默认,并保存到数据库作为初始记忆
final_model_key = DEFAULT_MODEL
session_obj.current_model = DEFAULT_MODEL
db.add(session_obj)
await db.commit()
# =========================================================
# 3. 存用户消息
await crud_message.create_message(db, session_id, user_input, role="user")
# 3. 获取 Schema 上下文
# 4. 获取 Schema 上下文
project = await crud_project.get(db, project_id)
if not project or not project.schema_definition:
schema_text = "No schema defined."
else:
schema_text = _format_schema_to_text(project.schema_definition)
# 4. 调用 AI 生成 SQL
sql_text = await call_ai_agent(schema_text, user_input, model_key=selected_model)
# 5. 调用 AI 生成 SQL (使用记忆或指定的模型)
sql_text = await call_ai_agent(schema_text, user_input, model_key=final_model_key)
# 5. 简单解析 SQL 类型
# 6. 简单解析 SQL 类型
sql_type = "UNKNOWN"
try:
if sql_text and not sql_text.startswith("--"):
@ -322,20 +306,20 @@ async def process_chat(
except Exception:
pass
# 6. 存 AI 回复消息
# 7. 存 AI 回复消息
reply_content = f"已生成查询语句:\n{sql_text}"
ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant")
# 8. 尝试执行 SQL
data = []
# 7. 尝试执行 SQL
try:
database_instance = await crud_database_instance.get(db, project.instance_id)
# 注意:这里调用的是 execute_sql_with_user_check,它会检查 SQL 是否安全
data = await execute_sql_with_user_check(sql_text, sql_type, database_instance)
except Exception as e:
log.info(f"SQL Execution Error: {str(e)}", exc_info=True)
log.info(f"SQL Execution Error (Safe to ignore if SQL is invalid): {str(e)}")
# 8. 构造响应
# 9. 构造响应
return ChatResponse(
message_id=ai_message.message_id,
content=reply_content,

@ -104,7 +104,7 @@ dev:
# SSL
ssl: false
# plugin
plugin: sha256_password
plugin: mysql_native_password
# 数据库连接驱动
driver: mysql+aiomysql
# sqlalchemy连接池配置

Loading…
Cancel
Save