聊天记录分用户

pull/49/head
echo 2 weeks ago
parent b463bd9b34
commit bdcd7337f2

@ -39,6 +39,10 @@ class HistoryReq(BaseModel):
def _get_username(u) -> str:
return getattr(u, "username", None) or (u.get("username") if isinstance(u, dict) else None) or "system"
def _get_internal_session_id(user, session_id: str) -> str:
uname = _get_username(user)
return f"{uname}:{session_id}"
@router.post("/ai/diagnose-repair")
async def diagnose_repair(req: DiagnoseRepairReq, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
@ -67,7 +71,8 @@ async def diagnose_repair(req: DiagnoseRepairReq, user=Depends(get_current_user)
@router.get("/ai/history")
async def get_history(sessionId: str, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""获取会话历史"""
stmt = select(ChatMessage).where(ChatMessage.session_id == sessionId).order_by(ChatMessage.created_at.asc())
internal_id = _get_internal_session_id(user, sessionId)
stmt = select(ChatMessage).where(ChatMessage.session_id == internal_id).order_by(ChatMessage.created_at.asc())
rows = (await db.execute(stmt)).scalars().all()
messages = [{"role": r.role, "content": r.content} for r in rows]
return {"messages": messages}
@ -75,10 +80,13 @@ async def get_history(sessionId: str, user=Depends(get_current_user), db: AsyncS
@router.post("/ai/chat")
async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
try:
session_stmt = select(ChatSession).where(ChatSession.id == req.sessionId)
internal_id = _get_internal_session_id(user, req.sessionId)
user_id = user.get("id") if isinstance(user, dict) else getattr(user, "id", None)
session_stmt = select(ChatSession).where(ChatSession.id == internal_id)
session = (await db.execute(session_stmt)).scalars().first()
if not session:
session = ChatSession(id=req.sessionId, user_id=getattr(user, "id", None), title=req.message[:20])
session = ChatSession(id=internal_id, user_id=user_id, title=req.message[:20])
db.add(session)
system_prompt = "You are a helpful Hadoop diagnostic assistant."
@ -88,7 +96,7 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
if req.context.get("node"):
system_prompt += f" You are currently analyzing node: {req.context['node']}."
hist_stmt = select(ChatMessage).where(ChatMessage.session_id == req.sessionId).order_by(ChatMessage.created_at.desc()).limit(12)
hist_stmt = select(ChatMessage).where(ChatMessage.session_id == internal_id).order_by(ChatMessage.created_at.desc()).limit(12)
hist_rows = (await db.execute(hist_stmt)).scalars().all()
hist_rows = hist_rows[::-1]
@ -97,7 +105,7 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
messages.append({"role": r.role, "content": r.content})
messages.append({"role": "user", "content": req.message})
user_msg = ChatMessage(session_id=req.sessionId, role="user", content=req.message)
user_msg = ChatMessage(session_id=internal_id, role="user", content=req.message)
db.add(user_msg)
llm = LLMClient()
@ -108,7 +116,7 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
chat_tools = [t for t in tools if t["function"]["name"] == "web_search"]
if req.stream and not web_search_enabled:
return await handle_streaming_chat(llm, messages, req.sessionId, db, tools=None)
return await handle_streaming_chat(llm, messages, internal_id, db, tools=None)
resp = await llm.chat(messages, tools=chat_tools, stream=False)
choices = resp.get("choices") or []
@ -140,21 +148,21 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
})
if req.stream:
return await handle_streaming_chat(llm, messages, req.sessionId, db, tools=chat_tools)
return await handle_streaming_chat(llm, messages, internal_id, db, tools=chat_tools)
else:
resp = await llm.chat(messages, tools=chat_tools, stream=False)
choices = resp.get("choices") or []
if not choices:
raise HTTPException(status_code=502, detail="llm_unavailable_after_tool")
raise HTTPException(status_code=502, detail="llm_unavailable_after_tool")
msg = choices[0].get("message") or {}
else:
if req.stream:
return await handle_streaming_chat(llm, messages, req.sessionId, db, tools=chat_tools)
return await handle_streaming_chat(llm, messages, internal_id, db, tools=chat_tools)
reply = msg.get("content") or ""
reasoning = msg.get("reasoning_content") or ""
asst_msg = ChatMessage(session_id=req.sessionId, role="assistant", content=reply)
asst_msg = ChatMessage(session_id=internal_id, role="assistant", content=reply)
db.add(asst_msg)
await db.commit()

Loading…
Cancel
Save