|
|
|
|
@ -1,40 +1,63 @@
|
|
|
|
|
import os
|
|
|
|
|
from typing import Any, Dict, Iterable, List, Optional
|
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
import httpx
|
|
|
|
|
except Exception: # pragma: no cover
|
|
|
|
|
httpx = None
|
|
|
|
|
|
|
|
|
|
load_dotenv()
|
|
|
|
|
|
|
|
|
|
class LLMClient:
|
|
|
|
|
"""供应商大模型客户端,封装聊天与函数调用。
|
|
|
|
|
_DEFAULT_ENDPOINTS: Dict[str, str] = {
|
|
|
|
|
"openai": "https://api.openai.com/v1/chat/completions",
|
|
|
|
|
"siliconflow": "https://api.siliconflow.cn/v1/chat/completions",
|
|
|
|
|
"deepseek": "https://api.deepseek.com/v1/chat/completions",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_DEFAULT_MODELS: Dict[str, str] = {
|
|
|
|
|
"openai": "gpt-4o-mini",
|
|
|
|
|
"siliconflow": "deepseek-ai/DeepSeek-R1",
|
|
|
|
|
"deepseek": "deepseek-ai/DeepSeek-R1",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def _clean_str(s: str) -> str:
|
|
|
|
|
if s is None:
|
|
|
|
|
return ""
|
|
|
|
|
s = s.strip()
|
|
|
|
|
if (s.startswith("`") and s.endswith("`")) or (s.startswith('"') and s.endswith('"')) or (s.startswith("'") and s.endswith("'")):
|
|
|
|
|
s = s[1:-1].strip()
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
def _normalize_endpoint(ep: str) -> str:
|
|
|
|
|
if not ep:
|
|
|
|
|
return ep
|
|
|
|
|
s = _clean_str(ep).rstrip("/")
|
|
|
|
|
if s.endswith("/v1"):
|
|
|
|
|
return s + "/chat/completions"
|
|
|
|
|
if s.endswith("/chat/completions"):
|
|
|
|
|
return s
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
- 通过环境变量配置:LLM_PROVIDER/LLM_ENDPOINT/LLM_MODEL/LLM_API_KEY
|
|
|
|
|
- 提供 chat(messages, tools, stream) 接口,返回供应商原始响应字典
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
class LLMClient:
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self.provider = os.getenv("LLM_PROVIDER", "openai")
|
|
|
|
|
self.endpoint = os.getenv("LLM_ENDPOINT", "https://api.openai.com/v1/chat/completions")
|
|
|
|
|
self.model = os.getenv("LLM_MODEL", "gpt-4o-mini")
|
|
|
|
|
self.api_key = os.getenv("LLM_API_KEY", "")
|
|
|
|
|
self.provider = os.getenv("LLM_PROVIDER", "openai").strip().lower()
|
|
|
|
|
raw_endpoint = os.getenv("LLM_ENDPOINT", "") or _DEFAULT_ENDPOINTS.get(self.provider, _DEFAULT_ENDPOINTS["openai"])
|
|
|
|
|
self.endpoint = _normalize_endpoint(raw_endpoint)
|
|
|
|
|
self.model = _clean_str(os.getenv("LLM_MODEL", _DEFAULT_MODELS.get(self.provider, "gpt-4o-mini")))
|
|
|
|
|
api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY") or os.getenv("DEEPSEEK_API_KEY") or os.getenv("SILICONFLOW_API_KEY") or ""
|
|
|
|
|
self.api_key = api_key
|
|
|
|
|
self.simulate = os.getenv("LLM_SIMULATE", "false").lower() == "true"
|
|
|
|
|
self.timeout = int(os.getenv("LLM_TIMEOUT", "30"))
|
|
|
|
|
|
|
|
|
|
def _headers(self) -> Dict[str, str]:
|
|
|
|
|
"""构造 HTTP 请求头。"""
|
|
|
|
|
return {
|
|
|
|
|
"Authorization": f"Bearer {self.api_key}" if self.api_key else "",
|
|
|
|
|
"Content-Type": "application/json",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def chat(self, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None, stream: bool = False) -> Dict[str, Any]:
|
|
|
|
|
"""调用供应商聊天接口,支持函数调用工具描述。
|
|
|
|
|
|
|
|
|
|
- messages:OpenAI 兼容的消息列表
|
|
|
|
|
- tools:OpenAI 兼容的函数调用工具定义(JSON Schema)
|
|
|
|
|
- stream:是否流式;此处返回一次性结果,SSE/WebSocket 由路由层实现
|
|
|
|
|
"""
|
|
|
|
|
if self.simulate or httpx is None:
|
|
|
|
|
return {
|
|
|
|
|
"choices": [
|
|
|
|
|
@ -51,8 +74,7 @@ class LLMClient:
|
|
|
|
|
if tools:
|
|
|
|
|
payload["tools"] = tools
|
|
|
|
|
payload["tool_choice"] = "auto"
|
|
|
|
|
with httpx.Client(timeout=30) as client:
|
|
|
|
|
with httpx.Client(timeout=self.timeout) as client:
|
|
|
|
|
resp = client.post(self.endpoint, headers=self._headers(), json=payload)
|
|
|
|
|
resp.raise_for_status()
|
|
|
|
|
return resp.json()
|
|
|
|
|
|
|
|
|
|
|