From d329fd77973f4cdd6c8f14c39c0550e76c2574c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=88=A9=E8=93=89?= <2655155213@qq.com> Date: Tue, 16 Dec 2025 15:03:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0AI=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E9=80=89=E6=8B=A9=E5=8A=9F=E8=83=BD=E4=B8=8E=E4=BC=9A=E8=AF=9D?= =?UTF-8?q?=E8=AE=B0=E5=BF=86=EF=BC=8C=E5=AF=B9=E6=8E=A5=E4=BA=91=E7=AB=AF?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/models/session.py | 19 ++- src/backend/app/service/chat_service.py | 196 +++++++++++------------- src/backend/config.yaml | 2 +- 3 files changed, 108 insertions(+), 109 deletions(-) diff --git a/src/backend/app/models/session.py b/src/backend/app/models/session.py index 89c2810..70a6258 100644 --- a/src/backend/app/models/session.py +++ b/src/backend/app/models/session.py @@ -10,6 +10,7 @@ from sqlalchemy import Column, String, CheckConstraint, Index, Integer, ForeignK from sqlalchemy.sql import func from core.database import Base from sqlalchemy.orm import relationship + class Session(Base): """ 会话表 ORM 模型。 @@ -20,6 +21,7 @@ class Session(Base): session_id (int): 会话ID。 project_id (int): 关联的项目ID。 session_name (str): 会话名称。 + current_model (str): 当前会话偏好的AI模型ID (新增)。 created_at (datetime): 创建时间。 last_activity (datetime): 最后活动时间。 """ @@ -43,14 +45,20 @@ class Session(Base): nullable=False, comment='关联的项目ID' ) - # 反向关联:让 Session 知道它属于哪个 Project - project = relationship("Project", back_populates="sessions") session_name = Column( String(100), nullable=False, default='New Session', comment='会话名称' ) + + # 【新增字段】记录当前会话使用的模型 + current_model = Column( + String(50), + nullable=True, + comment="当前会话偏好的AI模型ID" + ) + created_at = Column( DateTime(timezone=True), server_default=func.now(), @@ -63,5 +71,12 @@ class Session(Base): nullable=True, comment='最后活动时间' ) + + # 关联关系 + # 反向关联:让 Session 知道它属于哪个 Project + project = relationship("Project", back_populates="sessions") + + # messages = relationship("Message", back_populates="session", cascade="all, delete-orphan") # 如果你有 Message 模型的话 + class Config: from_attributes = True \ No newline at end of file diff --git a/src/backend/app/service/chat_service.py b/src/backend/app/service/chat_service.py index e43130f..e6cbc66 100644 --- a/src/backend/app/service/chat_service.py +++ b/src/backend/app/service/chat_service.py @@ -7,6 +7,7 @@ - 提示词工程 (Prompt Engineering) - 会话所有权校验 - AI 响应的解析与格式化 +- 会话模型记忆逻辑 """ # backend/app/service/chat_service.py @@ -35,36 +36,37 @@ from service.mysql_service import execute_sql_with_user_check # ========================================================= # 1. 模型配置注册表 # ========================================================= -# 建议:长期来看,这些配置也可以移入数据库或 YAML,但目前作为常量定义在 Service 层是可以接受的 -# 重点是 API Key 必须从 settings 读取 MODEL_REGISTRY = { "my-finetuned-sql": { "name": "My Fine-Tuned SQL Model", - "api_url": "http://26.64.77.145:1234/v1/chat/completions", # 这里的IP如果是固定的可以留着,如果是变动的建议放config + # 1. 填入云服务器地址 (保留 /v1/chat/completions) + "api_url": "http://1.92.127.206:8080/v1/chat/completions", "model_id": "codellama/CodeLlama-13b-Instruct-hf", - "api_key": "dummy-key", # 本地模型通常不需要 Key + # 2. 填入真实密钥 + "api_key": "sk-2025texttosql", "type": "local_finetune" }, + "xiyan-sql": { "name": "XiYan-SQL (QwenCoder-32B)", "api_url": "https://api-inference.modelscope.cn/v1/chat/completions", "model_id": "XGenerationLab/XiYanSQL-QwenCoder-32B-2504", - "api_key": settings.ai.modelscope_api_key, # <--- 从配置读取 + "api_key": settings.ai.modelscope_api_key, "type": "general_llm" }, "qwen-coder-32b": { "name": "Qwen2.5-Coder-32B", "api_url": "https://api-inference.modelscope.cn/v1/chat/completions", "model_id": "Qwen/Qwen2.5-Coder-32B-Instruct", - "api_key": settings.ai.modelscope_api_key, # <--- 从配置读取 + "api_key": settings.ai.modelscope_api_key, "type": "general_llm" }, "deepseek-v3": { "name": "DeepSeek V3.1", "api_url": "https://api-inference.modelscope.cn/v1/chat/completions", "model_id": "deepseek-ai/DeepSeek-V3.1", - "api_key": settings.ai.modelscope_api_key, # <--- 从配置读取 + "api_key": settings.ai.modelscope_api_key, "type": "general_llm" } } @@ -72,43 +74,27 @@ MODEL_REGISTRY = { DEFAULT_MODEL = "my-finetuned-sql" # ========================================================= -# 2. 辅助函数 (逻辑拆分) +# 2. 辅助函数 # ========================================================= - def _format_schema_to_text(schema_data: Any) -> str: - """ - 将 Schema JSON 转换为模型易读的文本格式。 - - 兼容 List 和 Dict 两种结构。 - - Args: - schema_data (Any): 原始 Schema 数据 (List or Dict)。 - - Returns: - str: 格式化后的 Schema 文本描述。 - """ + """将 Schema JSON 转换为模型易读的文本格式 (兼容 List/Dict)""" if not schema_data: return "" try: - # 1. 如果是字符串,先转成对象 if isinstance(schema_data, str): schema_data = json.loads(schema_data) - # 2. 如果是字典且包含 'tables' 键,提取出列表 if isinstance(schema_data, dict) and "tables" in schema_data: schema_data = schema_data["tables"] - # 3. 现在的 schema_data 应该是一个列表了,开始遍历 lines = [] if isinstance(schema_data, list): for table in schema_data: - # 兼容 table 可能是 dict 或者 object 的情况 if isinstance(table, dict): t_name = table.get("table_name", "unknown") cols = table.get("columns", []) else: - # 万一数据很怪,做一个容错 continue if isinstance(cols, str): @@ -116,31 +102,16 @@ def _format_schema_to_text(schema_data: Any) -> str: col_str = ", ".join(str(c) for c in cols) lines.append(f"Table: {t_name}, columns = [{col_str}]") else: - # 如果结构实在太乱,直接转字符串兜底 return str(schema_data) return "\n".join(lines) except Exception as e: log.error(f"Schema format error: {e}") - # 出错了也不要崩,把原始数据给 AI,看它能不能看懂 return str(schema_data) def _build_ai_messages(model_config: Dict, schema_text: str, question: str) -> List[Dict]: - """ - 根据模型类型构建对应的 Prompt 策略。 - - 解决“函数过长”问题,将 Prompt 逻辑抽离。 - - Args: - model_config (Dict): 模型配置字典。 - schema_text (str): Schema 文本描述。 - question (str): 用户问题。 - - Returns: - List[Dict]: 构建好的消息列表 (role/content)。 - """ + """构建 Prompt 策略""" if model_config["type"] == "local_finetune": - # 微调模型 (严格格式) prompt_content = f"""I want you to act as a SQL terminal in front of an database. Here is the schema: {schema_text} @@ -155,7 +126,6 @@ I want you to answer the following question. {"role": "user", "content": prompt_content} ] else: - # 通用大模型 (思维链与规则引导) system_prompt = f"""You are a generic SQL expert. Your task is to generate valid SQL queries based on the provided database schema and user question. @@ -175,24 +145,10 @@ Your task is to generate valid SQL queries based on the provided database schema ] async def _verify_session_ownership(db: AsyncSession, session_id: int, user_id: int) -> int: - """ - 验证会话所有权,防止越权访问 (IDOR)。 - - Args: - db (AsyncSession): 数据库会话。 - session_id (int): 会话 ID。 - user_id (int): 用户 ID。 - - Returns: - int: 关联的项目 ID (project_id)。 - - Raises: - HTTPException: 会话不存在、项目不存在或无权限时抛出。 - """ - # 联表查询:Session -> Project,检查 Project.user_id 是否匹配 + """验证会话所有权,防止越权""" stmt = ( select(SessionModel) - .options(selectinload(SessionModel.project)) # 预加载 Project 避免 N+1 + .options(selectinload(SessionModel.project)) .where(SessionModel.session_id == session_id) ) result = await db.execute(stmt) @@ -204,29 +160,15 @@ async def _verify_session_ownership(db: AsyncSession, session_id: int, user_id: if not session.project: raise HTTPException(status_code=404, detail="Project not found for this session") - # 【关键安全检查】 if session.project.user_id != user_id: log.warning(f"Security Alert: User {user_id} tried to access session {session_id} belonging to user {session.project.user_id}") raise HTTPException(status_code=403, detail="Permission denied: You do not own this session") return session.project_id -# ========================================================= -# 3. 核心业务逻辑 -# ========================================================= - async def call_ai_agent(schema_text: str, question: str, model_key: str = None) -> str: - """ - 调用 AI 接口生成 SQL。 - - Args: - schema_text (str): 数据库 Schema 描述。 - question (str): 用户问题。 - model_key (str, optional): 模型标识 Key。 - - Returns: - str: 生成的 SQL 语句。 - """ + """调用 AI 接口生成 SQL""" + # 1. 确定配置 if not model_key or model_key not in MODEL_REGISTRY: model_key = DEFAULT_MODEL @@ -234,47 +176,64 @@ async def call_ai_agent(schema_text: str, question: str, model_key: str = None) config = MODEL_REGISTRY[model_key] log.info(f"Using AI Model: {config['name']} ({config['model_id']})") - # 2. 构建 Prompt messages = _build_ai_messages(config, schema_text, question) - # 3. 构建请求 Payload + # 2. 构建 Payload payload = { "model": config["model_id"], "messages": messages, "temperature": 0.1, - "stream": False + "stream": False, + "max_tokens": 512, + # 【新增】告诉模型看到这些符号就闭嘴 + "stop": ["<|im_end|>", "<|im_start|>", "User:", "Assistant:"] } - - if config["type"] == "general_llm": - payload["max_tokens"] = 1024 headers = { "Authorization": f"Bearer {config['api_key']}", "Content-Type": "application/json" } - # 4. 执行网络请求 try: - async with httpx.AsyncClient(timeout=60) as client: + async with httpx.AsyncClient(timeout=300.0) as client: resp = await client.post(config["api_url"], json=payload, headers=headers) resp.raise_for_status() raw = resp.json() - # 解析响应 content = "" if "choices" in raw and len(raw["choices"]) > 0: content = raw["choices"][0]["message"]["content"] - # 清理结果 + # ===================================================== + # 【关键修复】清洗数据,截断废话 + # ===================================================== + # 1. 如果模型输出了停止符,只取前面的部分 + if "<|im_end|>" in content: + content = content.split("<|im_end|>")[0] + if "<|im_start|>" in content: + content = content.split("<|im_start|>")[0] + + # 2. 有时候模型会把 SQL 写在 Markdown 块里,先去 Markdown clean_sql = content.strip().replace("```sql", "").replace("```", "").strip() + + # 3. 如果还是有多行,且第一行就是完整的 SQL (以分号结尾),就只取第一行 + # 防止它在 SQL 后面通过换行继续自言自语 + if ";\n" in clean_sql: + clean_sql = clean_sql.split(";\n")[0] + ";" + elif clean_sql.count(";") > 1: + # 如果有多条 SQL,只取第一条 + clean_sql = clean_sql.split(";")[0] + ";" + return clean_sql except Exception as e: log.error(f"AI Call Error ({model_key}): {e}") - # 这里返回错误字符串是可以的,让用户知道 AI 挂了,而不是整个页面崩溃 return f"-- AI Service Error: {str(e)}" +# ========================================================= +# 3. 核心业务逻辑 (含模型记忆) +# ========================================================= async def process_chat( db: AsyncSession, @@ -285,34 +244,59 @@ async def process_chat( ) -> ChatResponse: """ 处理用户聊天请求的主流程。 - - Args: - db (AsyncSession): 数据库会话。 - session_id (int): 会话 ID。 - user_input (str): 用户输入。 - user_id (int): 用户 ID。 - selected_model (str, optional): 选择的模型。 - - Returns: - ChatResponse: 聊天响应对象。 """ - # 1. 安全检查:确认会话属于当前用户,并获取 project_id + # 1. 验证会话权限 project_id = await _verify_session_ownership(db, session_id, user_id) - # 2. 存用户消息 (先存库,保证有记录) + # 2. 获取 Session 对象以处理模型记忆逻辑 + stmt = select(SessionModel).where(SessionModel.session_id == session_id) + result = await db.execute(stmt) + session_obj = result.scalar_one_or_none() + + if not session_obj: + raise HTTPException(status_code=404, detail="Session lost") + + # ========================================================= + # 【核心逻辑优化】模型选择优先级策略 + # ========================================================= + final_model_key = DEFAULT_MODEL # 兜底 + + if selected_model: + # A. 如果用户本次明确指定了模型 -> 使用它,并更新到数据库(记忆) + final_model_key = selected_model + if session_obj.current_model != selected_model: + session_obj.current_model = selected_model + db.add(session_obj) + await db.commit() # 保存记忆 + log.info(f"Session {session_id} model switched to: {selected_model}") + + elif session_obj.current_model: + # B. 如果用户没指定,但数据库里有记忆 -> 使用记忆的模型 + final_model_key = session_obj.current_model + log.info(f"Session {session_id} using stored model: {final_model_key}") + + else: + # C. 既没指定也没记忆 -> 使用系统默认,并保存到数据库作为初始记忆 + final_model_key = DEFAULT_MODEL + session_obj.current_model = DEFAULT_MODEL + db.add(session_obj) + await db.commit() + # ========================================================= + + # 3. 存用户消息 await crud_message.create_message(db, session_id, user_input, role="user") - # 3. 获取 Schema 上下文 + # 4. 获取 Schema 上下文 project = await crud_project.get(db, project_id) if not project or not project.schema_definition: schema_text = "No schema defined." else: schema_text = _format_schema_to_text(project.schema_definition) - # 4. 调用 AI 生成 SQL - sql_text = await call_ai_agent(schema_text, user_input, model_key=selected_model) + # 5. 调用 AI 生成 SQL (使用记忆或指定的模型) + sql_text = await call_ai_agent(schema_text, user_input, model_key=final_model_key) - # 5. 简单解析 SQL 类型 + # 6. 简单解析 SQL 类型 sql_type = "UNKNOWN" try: if sql_text and not sql_text.startswith("--"): @@ -322,20 +306,20 @@ async def process_chat( except Exception: pass - # 6. 存 AI 回复消息 + # 7. 存 AI 回复消息 reply_content = f"已生成查询语句:\n{sql_text}" ai_message = await crud_message.create_message(db, session_id, reply_content, role="assistant") + # 8. 尝试执行 SQL data = [] - # 7. 尝试执行 SQL try: database_instance = await crud_database_instance.get(db, project.instance_id) - + # 注意:这里调用的是 execute_sql_with_user_check,它会检查 SQL 是否安全 data = await execute_sql_with_user_check(sql_text, sql_type, database_instance) except Exception as e: - log.info(f"SQL Execution Error: {str(e)}", exc_info=True) + log.info(f"SQL Execution Error (Safe to ignore if SQL is invalid): {str(e)}") - # 8. 构造响应 + # 9. 构造响应 return ChatResponse( message_id=ai_message.message_id, content=reply_content, diff --git a/src/backend/config.yaml b/src/backend/config.yaml index c57b955..1442e3c 100644 --- a/src/backend/config.yaml +++ b/src/backend/config.yaml @@ -104,7 +104,7 @@ dev: # SSL ssl: false # plugin - plugin: sha256_password + plugin: mysql_native_password # 数据库连接驱动 driver: mysql+aiomysql # sqlalchemy连接池配置 -- 2.34.1