@@ -39,6 +39,10 @@ class HistoryReq(BaseModel):
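+# Resolve a username from either an ORM user object or a dict payload,
+# falling back to "system" when neither carries one.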
+def _get_username(u) -> str:
+    return getattr(u, "username", None) or (u.get("username") if isinstance(u, dict) else None) or "system"
+
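+# Prefix the client-supplied session id with the username so chat sessions
+# are scoped per user rather than shared across users.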
+def _get_internal_session_id(user, session_id: str) -> str:
+    uname = _get_username(user)
+    return f"{uname}:{session_id}"
+
@router.post("/ai/diagnose-repair")
|
|
|
|
|
async def diagnose_repair(req: DiagnoseRepairReq, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
|
|
|
|
|
@@ -67,7 +71,8 @@ async def diagnose_repair(req: DiagnoseRepairReq, user=Depends(get_current_user)
 @router.get("/ai/history")
 async def get_history(sessionId: str, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""获取会话历史"""
|
|
|
|
|
-    stmt = select(ChatMessage).where(ChatMessage.session_id == sessionId).order_by(ChatMessage.created_at.asc())
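+    # Resolve the caller-supplied id to the per-user namespaced id before querying.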
+    internal_id = _get_internal_session_id(user, sessionId)
+    stmt = select(ChatMessage).where(ChatMessage.session_id == internal_id).order_by(ChatMessage.created_at.asc())
     rows = (await db.execute(stmt)).scalars().all()
     messages = [{"role": r.role, "content": r.content} for r in rows]
     return {"messages": messages}
@@ -75,10 +80,13 @@ async def get_history(sessionId: str, user=Depends(get_current_user), db: AsyncS
 @router.post("/ai/chat")
 async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
     try:
-        session_stmt = select(ChatSession).where(ChatSession.id == req.sessionId)
+        internal_id = _get_internal_session_id(user, req.sessionId)
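+        # get_current_user may return a dict or an ORM object; handle both shapes.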
user_id = user.get("id") if isinstance(user, dict) else getattr(user, "id", None)
|
|
|
|
|
|
|
|
|
|
session_stmt = select(ChatSession).where(ChatSession.id == internal_id)
|
|
|
|
|
session = (await db.execute(session_stmt)).scalars().first()
|
|
|
|
|
if not session:
|
|
|
|
|
-            session = ChatSession(id=req.sessionId, user_id=getattr(user, "id", None), title=req.message[:20])
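+            # Store the session under the namespaced id; the title is seeded
+            # from the first 20 characters of the opening message.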
+            session = ChatSession(id=internal_id, user_id=user_id, title=req.message[:20])
             db.add(session)

         system_prompt = "You are a helpful Hadoop diagnostic assistant."
@@ -88,7 +96,7 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
         if req.context.get("node"):
             system_prompt += f" You are currently analyzing node: {req.context['node']}."

-        hist_stmt = select(ChatMessage).where(ChatMessage.session_id == req.sessionId).order_by(ChatMessage.created_at.desc()).limit(12)
+        hist_stmt = select(ChatMessage).where(ChatMessage.session_id == internal_id).order_by(ChatMessage.created_at.desc()).limit(12)
         hist_rows = (await db.execute(hist_stmt)).scalars().all()
         hist_rows = hist_rows[::-1]
@@ -97,7 +105,7 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
             messages.append({"role": r.role, "content": r.content})
         messages.append({"role": "user", "content": req.message})

-        user_msg = ChatMessage(session_id=req.sessionId, role="user", content=req.message)
+        user_msg = ChatMessage(session_id=internal_id, role="user", content=req.message)
         db.add(user_msg)

         llm = LLMClient()
@@ -108,7 +116,7 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
         chat_tools = [t for t in tools if t["function"]["name"] == "web_search"]

         if req.stream and not web_search_enabled:
-            return await handle_streaming_chat(llm, messages, req.sessionId, db, tools=None)
+            return await handle_streaming_chat(llm, messages, internal_id, db, tools=None)

         resp = await llm.chat(messages, tools=chat_tools, stream=False)
         choices = resp.get("choices") or []
@@ -140,21 +148,21 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
                 })

             if req.stream:
-                return await handle_streaming_chat(llm, messages, req.sessionId, db, tools=chat_tools)
+                return await handle_streaming_chat(llm, messages, internal_id, db, tools=chat_tools)
             else:
                 resp = await llm.chat(messages, tools=chat_tools, stream=False)
                 choices = resp.get("choices") or []
                 if not choices:
                     raise HTTPException(status_code=502, detail="llm_unavailable_after_tool")
                 msg = choices[0].get("message") or {}
         else:
             if req.stream:
-                return await handle_streaming_chat(llm, messages, req.sessionId, db, tools=chat_tools)
+                return await handle_streaming_chat(llm, messages, internal_id, db, tools=chat_tools)

         reply = msg.get("content") or ""
         reasoning = msg.get("reasoning_content") or ""

-        asst_msg = ChatMessage(session_id=req.sessionId, role="assistant", content=reply)
+        asst_msg = ChatMessage(session_id=internal_id, role="assistant", content=reply)
         db.add(asst_msg)
         await db.commit()