流式输出优化

pull/48/head
echo 4 months ago
parent a570f82e97
commit 742b1d06e3

@@ -75,21 +75,12 @@ async def get_history(sessionId: str, user=Depends(get_current_user), db: AsyncS
@router.post("/ai/chat")
async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
try:
# 1. Ensure Session Exists
session_stmt = select(ChatSession).where(ChatSession.id == req.sessionId)
session = (await db.execute(session_stmt)).scalars().first()
if not session:
session = ChatSession(id=req.sessionId, user_id=getattr(user, "id", None), title=req.message[:20])
db.add(session)
await db.commit()
await db.refresh(session)
# 2. Save User Message
user_msg = ChatMessage(session_id=req.sessionId, role="user", content=req.message)
db.add(user_msg)
await db.commit()
# 3. Build Context & History for LLM
system_prompt = "You are a helpful Hadoop diagnostic assistant."
if req.context:
if req.context.get("agent"):
@@ -97,29 +88,35 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
if req.context.get("node"):
system_prompt += f" You are currently analyzing node: {req.context['node']}."
hist_stmt = select(ChatMessage).where(ChatMessage.session_id == req.sessionId).order_by(ChatMessage.created_at.desc()).limit(20)
hist_stmt = select(ChatMessage).where(ChatMessage.session_id == req.sessionId).order_by(ChatMessage.created_at.desc()).limit(12)
hist_rows = (await db.execute(hist_stmt)).scalars().all()
hist_rows = hist_rows[::-1]
messages = [{"role": "system", "content": system_prompt}]
for r in hist_rows:
messages.append({"role": r.role, "content": r.content})
messages.append({"role": "user", "content": req.message})
# 4. Call LLM
user_msg = ChatMessage(session_id=req.sessionId, role="user", content=req.message)
db.add(user_msg)
llm = LLMClient()
tools = openai_tools_schema()
chat_tools = [t for t in tools if t["function"]["name"] == "web_search"]
web_search_enabled = bool(req.context and req.context.get("webSearch"))
chat_tools = None
if web_search_enabled:
tools = openai_tools_schema()
chat_tools = [t for t in tools if t["function"]["name"] == "web_search"]
# We always do the first call without streaming to handle tool calls easily
if req.stream and not web_search_enabled:
return await handle_streaming_chat(llm, messages, req.sessionId, db, tools=None)
resp = await llm.chat(messages, tools=chat_tools, stream=False)
choices = resp.get("choices") or []
if not choices:
raise HTTPException(status_code=502, detail="llm_unavailable")
msg = choices[0].get("message") or {}
tool_calls = msg.get("tool_calls") or []
# Tool Loop
if tool_calls:
messages.append(msg)
for tc in tool_calls:
@@ -142,9 +139,8 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
"content": json.dumps(tool_result, ensure_ascii=False)
})
# After tool calls, we decide whether to stream the final response
if req.stream:
return await handle_streaming_chat(llm, messages, req.sessionId, db)
return await handle_streaming_chat(llm, messages, req.sessionId, db, tools=chat_tools)
else:
resp = await llm.chat(messages, tools=chat_tools, stream=False)
choices = resp.get("choices") or []
@@ -152,15 +148,9 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
raise HTTPException(status_code=502, detail="llm_unavailable_after_tool")
msg = choices[0].get("message") or {}
else:
# No tool calls initially
if req.stream:
# If we want to stream the first response, we need to call it again with stream=True
# because we already called it with stream=False to check for tool calls.
# Alternatively, we could have streamed the first call and checked for tool calls in the stream.
# But for simplicity, we just re-call it.
return await handle_streaming_chat(llm, messages, req.sessionId, db)
return await handle_streaming_chat(llm, messages, req.sessionId, db, tools=chat_tools)
# Normal (non-streaming) response
reply = msg.get("content") or ""
reasoning = msg.get("reasoning_content") or ""
@@ -173,43 +163,35 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
except HTTPException:
raise
except Exception as e:
print(f"Chat Error: {e}")
raise HTTPException(status_code=500, detail="server_error")
async def handle_streaming_chat(llm: LLMClient, messages: list, session_id: str, db: AsyncSession):
async def handle_streaming_chat(llm: LLMClient, messages: list, session_id: str, db: AsyncSession, tools=None):
async def event_generator():
full_reply = ""
full_reasoning = ""
# Start streaming from LLM
stream_gen = await llm.chat(messages, stream=True)
async for chunk in stream_gen:
choices = chunk.get("choices") or []
if not choices:
continue
delta = choices[0].get("delta") or {}
content = delta.get("content") or ""
reasoning = delta.get("reasoning_content") or ""
if content:
full_reply += content
if reasoning:
full_reasoning += reasoning
try:
stream_gen = await llm.chat(messages, tools=tools, stream=True)
async for chunk in stream_gen:
choices = chunk.get("choices") or []
if not choices:
continue
yield f"data: {json.dumps({'content': content, 'reasoning': reasoning}, ensure_ascii=False)}\n\n"
# After stream ends, save to DB
if full_reply:
# Note: We need a new session or be careful with the current one
# as this runs in a generator which might outlive the request scope
# but FastAPI handles this correctly if we use the db from depends.
# However, committed changes in a generator might be tricky.
# Let's use a fresh session if possible or just commit here.
delta = choices[0].get("delta") or {}
content = delta.get("content") or ""
reasoning = delta.get("reasoning_content") or ""
if content:
full_reply += content
if reasoning:
full_reasoning += reasoning
yield f"data: {json.dumps({'content': content, 'reasoning': reasoning}, ensure_ascii=False)}\n\n"
finally:
try:
asst_msg = ChatMessage(session_id=session_id, role="assistant", content=full_reply)
db.add(asst_msg)
if full_reply:
asst_msg = ChatMessage(session_id=session_id, role="assistant", content=full_reply)
db.add(asst_msg)
await db.commit()
except Exception as e:
print(f"Error saving stream to DB: {e}")

@@ -10,6 +10,20 @@ except Exception: # pragma: no cover
load_dotenv()
_shared_async_client: Any = None
def _get_async_client() -> Any:
    """Return a lazily-created, process-wide shared ``httpx.AsyncClient``.

    Sharing one client lets keep-alive connections be pooled and reused
    across LLM requests instead of being torn down per call.  Per-request
    timeouts are supplied at the call sites, so none is fixed here.

    Returns:
        The shared ``httpx.AsyncClient``, or ``None`` when ``httpx`` could
        not be imported (it is an optional dependency in this module).
    """
    global _shared_async_client
    if httpx is None:
        return None
    if _shared_async_client is None:
        limits = httpx.Limits(max_keepalive_connections=20, max_connections=50)
        try:
            # http2=True requires the optional `h2` package; fall back to
            # HTTP/1.1 when it is missing rather than crashing the first
            # request with an ImportError.
            _shared_async_client = httpx.AsyncClient(limits=limits, http2=True)
        except ImportError:
            _shared_async_client = httpx.AsyncClient(limits=limits)
    return _shared_async_client
_DEFAULT_ENDPOINTS: Dict[str, str] = {
"openai": "https://api.openai.com/v1/chat/completions",
"siliconflow": "https://api.siliconflow.cn/v1/chat/completions",
@@ -82,8 +96,8 @@ class LLMClient:
if stream:
async def _stream_gen():
async with httpx.AsyncClient(timeout=self.timeout) as client:
async with client.stream("POST", self.endpoint, headers=self._headers(), json=payload) as resp:
client = _get_async_client()
async with client.stream("POST", self.endpoint, headers=self._headers(), json=payload, timeout=self.timeout) as resp:
resp.raise_for_status()
async for line in resp.aiter_lines():
if not line or not line.startswith("data: "):
@@ -97,7 +111,7 @@
continue
return _stream_gen()
async with httpx.AsyncClient(timeout=self.timeout) as client:
resp = await client.post(self.endpoint, headers=self._headers(), json=payload)
resp.raise_for_status()
return resp.json()
client = _get_async_client()
resp = await client.post(self.endpoint, headers=self._headers(), json=payload, timeout=self.timeout)
resp.raise_for_status()
return resp.json()

@@ -21,7 +21,7 @@ async def main():
print(f"Endpoint: {llm.endpoint}")
print(f"Model: {llm.model}")
messages = [{"role": "user", "content": "你好,请简单介绍一下你自己。"}]
messages = [{"role": "user", "content": ""}]
print("Sending streaming request...")
stream_gen = await llm.chat(messages, stream=True)

Loading…
Cancel
Save