You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ErrorDetecting/backend/app/services/llm.py

121 lines
4.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import json
from typing import Any, Dict, Iterable, List, Optional
from dotenv import load_dotenv
try:
import httpx
except Exception: # pragma: no cover
httpx = None
load_dotenv()
# Lazily created client shared by every caller in this module (see below).
_shared_async_client: Any = None


def _get_async_client() -> Any:
    """Return the module-wide ``httpx.AsyncClient``, building it on first use.

    Returns:
        The shared client instance, or ``None`` when ``httpx`` could not be
        imported (the module-level import fallback sets ``httpx`` to ``None``).
    """
    global _shared_async_client

    if httpx is None:
        return None
    if _shared_async_client is not None:
        return _shared_async_client

    # One client is reused so that connections can be pooled across requests.
    pool_limits = httpx.Limits(max_keepalive_connections=20, max_connections=50)
    _shared_async_client = httpx.AsyncClient(
        headers={},
        limits=pool_limits,
        http2=True,
    )
    return _shared_async_client
_DEFAULT_ENDPOINTS: Dict[str, str] = {
"openai": "https://api.openai.com/v1/chat/completions",
"siliconflow": "https://api.siliconflow.cn/v1/chat/completions",
"deepseek": "https://api.deepseek.com/v1/chat/completions",
}
_DEFAULT_MODELS: Dict[str, str] = {
"openai": "gpt-4o-mini",
"siliconflow": "deepseek-ai/DeepSeek-V3",
"deepseek": "deepseek-v3",
"r1": "Pro/deepseek-ai/DeepSeek-R1",
}
def _clean_str(s: str) -> str:
if s is None:
return ""
s = s.strip()
if (s.startswith("`") and s.endswith("`")) or (s.startswith('"') and s.endswith('"')) or (s.startswith("'") and s.endswith("'")):
s = s[1:-1].strip()
return s
def _normalize_endpoint(ep: str) -> str:
if not ep:
return ep
s = _clean_str(ep).rstrip("/")
if s.endswith("/v1"):
return s + "/chat/completions"
if s.endswith("/chat/completions"):
return s
return s
class LLMClient:
    """Small async client for OpenAI-compatible chat-completions endpoints.

    Configuration is read from the environment when the client is built:

    * ``LLM_PROVIDER`` -- key into the default endpoint/model tables
      (default ``"openai"``).
    * ``LLM_ENDPOINT`` -- explicit endpoint URL; when unset or blank the
      provider's default endpoint is used.
    * ``LLM_MODEL`` -- model name; when unset or blank the provider's default
      model is used.
    * ``LLM_API_KEY``, ``OPENAI_API_KEY``, ``DEEPSEEK_API_KEY``,
      ``SILICONFLOW_API_KEY`` -- the first non-empty one, in that order, is
      used as the bearer token.
    * ``LLM_SIMULATE`` -- ``"true"`` (case-insensitive) makes ``chat`` return
      canned responses without touching the network.
    * ``LLM_TIMEOUT`` -- integer passed to httpx as the request timeout
      (default 300).
    """

    def __init__(self):
        """Load provider, endpoint, model, key and flags from the environment.

        Raises:
            ValueError: if ``LLM_TIMEOUT`` is set to a non-blank value that is
                not an integer.
        """
        self.provider = os.getenv("LLM_PROVIDER", "openai").strip().lower()

        raw_endpoint = os.getenv("LLM_ENDPOINT", "") or _DEFAULT_ENDPOINTS.get(
            self.provider, _DEFAULT_ENDPOINTS["openai"]
        )
        self.endpoint = _normalize_endpoint(raw_endpoint)

        # A blank LLM_MODEL (e.g. `LLM_MODEL=` in a .env file) falls back to
        # the provider default, mirroring how LLM_ENDPOINT is handled above.
        self.model = _clean_str(os.getenv("LLM_MODEL", "")) or _DEFAULT_MODELS.get(
            self.provider, "gpt-4o-mini"
        )

        # NOTE(review): the lookup order is fixed and ignores self.provider, so
        # a key for one provider can be sent to another provider's endpoint
        # when several are set -- confirm this is intended.
        self.api_key = (
            os.getenv("LLM_API_KEY")
            or os.getenv("OPENAI_API_KEY")
            or os.getenv("DEEPSEEK_API_KEY")
            or os.getenv("SILICONFLOW_API_KEY")
            or ""
        )

        self.simulate = os.getenv("LLM_SIMULATE", "false").lower() == "true"

        # A blank LLM_TIMEOUT uses the default instead of raising in int("").
        raw_timeout = os.getenv("LLM_TIMEOUT", "").strip()
        self.timeout = int(raw_timeout) if raw_timeout else 300

    def _headers(self) -> Dict[str, str]:
        """Build the request headers.

        Returns:
            A dict with ``Content-Type: application/json`` and, only when an
            API key is configured, ``Authorization: Bearer <key>``. No empty
            ``Authorization`` header is sent for key-less endpoints.
        """
        headers: Dict[str, str] = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @staticmethod
    def _sse_data(line: str) -> Optional[str]:
        """Extract the payload of one ``data:`` line from an event stream.

        Args:
            line: A single line from the streamed response body.

        Returns:
            The stripped text after the ``data:`` prefix, or ``None`` when the
            line is empty or is not a ``data`` field. The space after the
            colon is optional, so ``data:{...}`` and ``data: {...}`` are both
            accepted.
        """
        if not line or not line.startswith("data:"):
            return None
        return line[len("data:"):].strip()

    async def chat(self, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None, stream: bool = False, model: Optional[str] = None) -> Any:
        """Send a chat-completions request.

        Args:
            messages: Chat messages placed in the request payload as-is.
            tools: Optional tool definitions; when non-empty they are sent
                with ``tool_choice`` set to ``"auto"``.
            stream: When true, return an async generator of parsed stream
                chunks instead of a single response.
            model: Optional per-call override of the configured model.

        Returns:
            With ``stream=False``, the decoded JSON response body. With
            ``stream=True``, an async generator yielding each decoded JSON
            chunk until the ``[DONE]`` marker. In simulate mode, or when
            httpx is unavailable, a canned response of the same shape.

        Raises:
            Whatever ``raise_for_status()`` raises for an HTTP error status,
            plus any transport error raised by httpx.
        """
        if self.simulate or httpx is None:
            if stream:
                async def _sim_stream():
                    yield {"choices": [{"delta": {"content": "模拟流式输出检测到错误日志建议重启或kill相关进程"}, "index": 0}]}
                return _sim_stream()
            return {
                "choices": [
                    {
                        "message": {
                            "role": "assistant",
                            "content": "模拟输出检测到错误日志建议重启或kill相关进程",
                            "tool_calls": [],
                        }
                    }
                ]
            }

        target_model = model or self.model
        payload: Dict[str, Any] = {"model": target_model, "messages": messages, "stream": stream}
        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = "auto"

        if stream:
            async def _stream_gen():
                client = _get_async_client()
                async with client.stream("POST", self.endpoint, headers=self._headers(), json=payload, timeout=self.timeout) as resp:
                    resp.raise_for_status()
                    async for line in resp.aiter_lines():
                        data_str = self._sse_data(line)
                        if data_str is None:
                            continue
                        if data_str == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data_str)
                        except json.JSONDecodeError:
                            # Best effort: skip chunks that are not valid JSON.
                            continue
                        # Yield outside the try block so that closing the
                        # generator early (GeneratorExit) is never swallowed.
                        yield chunk
            return _stream_gen()

        client = _get_async_client()
        resp = await client.post(self.endpoint, headers=self._headers(), json=payload, timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()