From bdcd7337f2ead9737320df34e44d7b59a404fcc8 Mon Sep 17 00:00:00 2001 From: echo Date: Tue, 30 Dec 2025 13:45:14 +0000 Subject: [PATCH] =?UTF-8?q?=E8=81=8A=E5=A4=A9=E8=AE=B0=E5=BD=95=E5=88=86?= =?UTF-8?q?=E7=94=A8=E6=88=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/routers/ai.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/backend/app/routers/ai.py b/backend/app/routers/ai.py index 224b290..8dcdc5b 100644 --- a/backend/app/routers/ai.py +++ b/backend/app/routers/ai.py @@ -39,6 +39,10 @@ class HistoryReq(BaseModel): def _get_username(u) -> str: return getattr(u, "username", None) or (u.get("username") if isinstance(u, dict) else None) or "system" +def _get_internal_session_id(user, session_id: str) -> str: + uname = _get_username(user) + return f"{uname}:{session_id}" + @router.post("/ai/diagnose-repair") async def diagnose_repair(req: DiagnoseRepairReq, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)): @@ -67,7 +71,8 @@ async def diagnose_repair(req: DiagnoseRepairReq, user=Depends(get_current_user) @router.get("/ai/history") async def get_history(sessionId: str, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)): """获取会话历史""" - stmt = select(ChatMessage).where(ChatMessage.session_id == sessionId).order_by(ChatMessage.created_at.asc()) + internal_id = _get_internal_session_id(user, sessionId) + stmt = select(ChatMessage).where(ChatMessage.session_id == internal_id).order_by(ChatMessage.created_at.asc()) rows = (await db.execute(stmt)).scalars().all() messages = [{"role": r.role, "content": r.content} for r in rows] return {"messages": messages} @@ -75,10 +80,13 @@ async def get_history(sessionId: str, user=Depends(get_current_user), db: AsyncS @router.post("/ai/chat") async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)): try: - session_stmt = select(ChatSession).where(ChatSession.id == req.sessionId) + internal_id = _get_internal_session_id(user, req.sessionId) + user_id = user.get("id") if isinstance(user, dict) else getattr(user, "id", None) + + session_stmt = select(ChatSession).where(ChatSession.id == internal_id) session = (await db.execute(session_stmt)).scalars().first() if not session: - session = ChatSession(id=req.sessionId, user_id=getattr(user, "id", None), title=req.message[:20]) + session = ChatSession(id=internal_id, user_id=user_id, title=req.message[:20]) db.add(session) system_prompt = "You are a helpful Hadoop diagnostic assistant." @@ -88,7 +96,7 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession if req.context.get("node"): system_prompt += f" You are currently analyzing node: {req.context['node']}." - hist_stmt = select(ChatMessage).where(ChatMessage.session_id == req.sessionId).order_by(ChatMessage.created_at.desc()).limit(12) + hist_stmt = select(ChatMessage).where(ChatMessage.session_id == internal_id).order_by(ChatMessage.created_at.desc()).limit(12) hist_rows = (await db.execute(hist_stmt)).scalars().all() hist_rows = hist_rows[::-1] @@ -97,7 +105,7 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession messages.append({"role": r.role, "content": r.content}) messages.append({"role": "user", "content": req.message}) - user_msg = ChatMessage(session_id=req.sessionId, role="user", content=req.message) + user_msg = ChatMessage(session_id=internal_id, role="user", content=req.message) db.add(user_msg) llm = LLMClient() @@ -108,7 +116,7 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession chat_tools = [t for t in tools if t["function"]["name"] == "web_search"] if req.stream and not web_search_enabled: - return await handle_streaming_chat(llm, messages, req.sessionId, db, tools=None) + return await handle_streaming_chat(llm, messages, internal_id, db, tools=None) resp = await llm.chat(messages, tools=chat_tools, stream=False) choices = resp.get("choices") or [] @@ -140,21 +148,21 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession }) if req.stream: - return await handle_streaming_chat(llm, messages, req.sessionId, db, tools=chat_tools) + return await handle_streaming_chat(llm, messages, internal_id, db, tools=chat_tools) else: resp = await llm.chat(messages, tools=chat_tools, stream=False) choices = resp.get("choices") or [] if not choices: - raise HTTPException(status_code=502, detail="llm_unavailable_after_tool") + raise HTTPException(status_code=502, detail="llm_unavailable_after_tool") msg = choices[0].get("message") or {} else: if req.stream: - return await handle_streaming_chat(llm, messages, req.sessionId, db, tools=chat_tools) + return await handle_streaming_chat(llm, messages, internal_id, db, tools=chat_tools) reply = msg.get("content") or "" reasoning = msg.get("reasoning_content") or "" - asst_msg = ChatMessage(session_id=req.sessionId, role="assistant", content=reply) + asst_msg = ChatMessage(session_id=internal_id, role="assistant", content=reply) db.add(asst_msg) await db.commit()