# RGproject/AI-Writing-main/llm-api/utils.py
"""LLM API工具函数模块
提供与原有代码兼容的便捷函数
"""
from typing import Any, Optional, Type, Callable, Dict, Tuple
from pydantic import BaseModel
try:
    from .client import LLMClient
    from .models import get_model_info, LLMModel
    from .exceptions import LLMAPIError
except ImportError:
    from client import LLMClient
    from models import get_model_info, LLMModel
    from exceptions import LLMAPIError

# Global client instance
_global_client = None


def get_client() -> LLMClient:
    """Return the global LLMClient instance, creating it lazily on first use."""
    global _global_client
    if _global_client is None:
        _global_client = LLMClient()
    return _global_client


def call_llm(
    prompt: Any,
    pydantic_model: Type[BaseModel],
    agent_name: Optional[str] = None,
    state: Optional[Any] = None,
    max_retries: int = 3,
    default_factory: Optional[Callable] = None,
    model_name: Optional[str] = None,
    provider: Optional[str] = None,
) -> BaseModel:
    """Call the LLM and return structured output (compatible with the original interface).

    Args:
        prompt: Prompt text.
        pydantic_model: Pydantic model class describing the structured output.
        agent_name: Agent name, used to pull model configuration from ``state``.
        state: State object.
        max_retries: Maximum number of retries.
        default_factory: Factory for a fallback value when the call fails.
        model_name: Model name.
        provider: Model provider.

    Returns:
        An instance of ``pydantic_model``.
    """
    client = get_client()

    # Pull the model configuration from the state, if one was provided
    if state and agent_name and not model_name:
        extracted_model_name, extracted_provider = get_agent_model_config(state, agent_name)
        model_name = model_name or extracted_model_name
        provider = provider or extracted_provider

    # Fall back to defaults
    model_name = model_name or "gpt-4o"
    provider = provider or "OpenAI"

    try:
        return client.chat_with_structured_output(
            message=str(prompt),
            pydantic_model=pydantic_model,
            model=model_name,
            provider=provider,
            max_retries=max_retries,
        )
    except Exception as e:
        print(f"LLM call failed: {e}")
        if default_factory:
            return default_factory()
        return client._create_default_response(pydantic_model)
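
# Hedged usage sketch for call_llm. ``Summary`` below is a hypothetical Pydantic
# model and the prompt is illustrative; real models and prompts come from the
# calling agent code, and the model/provider values are just the defaults used above.
#
#     class Summary(BaseModel):
#         title: str
#         key_points: list[str]
#
#     result = call_llm(
#         prompt="Summarize the following draft as structured JSON: ...",
#         pydantic_model=Summary,
#         model_name="gpt-4o",
#         provider="OpenAI",
#         default_factory=lambda: Summary(title="", key_points=[]),
#     )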


def get_agent_model_config(state: Any, agent_name: str) -> Tuple[str, str]:
    """Get an agent's model configuration from the state (compatible with the original interface).

    Args:
        state: State object.
        agent_name: Agent name.

    Returns:
        A ``(model_name, provider)`` tuple.
    """
    try:
        request = state.get("metadata", {}).get("request")

        if agent_name == 'portfolio_manager':
            # Read the model and provider from the state metadata
            model_name = state.get("metadata", {}).get("model_name", "gpt-4o")
            provider = state.get("metadata", {}).get("model_provider", "OpenAI")
            return model_name, provider

        if request and hasattr(request, 'get_agent_model_config'):
            # Agent-specific model configuration
            model_name, provider = request.get_agent_model_config(agent_name)
            return model_name, provider.value if hasattr(provider, 'value') else str(provider)

        # Fall back to the global configuration
        model_name = state.get("metadata", {}).get("model_name", "gpt-4o")
        provider = state.get("metadata", {}).get("model_provider", "OpenAI")

        # Convert enum values to strings
        if hasattr(provider, 'value'):
            provider = provider.value
        return model_name, provider
    except Exception:
        # On any error, return the defaults
        return "gpt-4o", "OpenAI"


def create_default_response(model_class: Type[BaseModel]) -> BaseModel:
    """Create a default response (compatible with the original interface)."""
    client = get_client()
    return client._create_default_response(model_class)


def extract_json_from_response(content: str) -> Optional[Dict[str, Any]]:
    """Extract JSON from a response (compatible with the original interface)."""
    client = get_client()
    return client._extract_json_from_response(content)
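
# Hedged sketch: this helper is typically applied to raw model text that wraps
# JSON in a Markdown code fence. The exact parsing behaviour lives in
# LLMClient._extract_json_from_response; the call pattern below is an assumption.
#
#     raw = '```json\n{"title": "Draft", "score": 0.9}\n```'
#     data = extract_json_from_response(raw)  # e.g. {"title": "Draft", "score": 0.9} or None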


# Convenience functions
def chat(
    message: str,
    model: Optional[str] = None,
    provider: Optional[str] = None,
    system_message: Optional[str] = None,
    **kwargs
) -> str:
    """Simple chat interface."""
    client = get_client()
    return client.chat(
        message=message,
        model=model,
        provider=provider,
        system_message=system_message,
        **kwargs
    )
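
# Hedged usage sketch for chat; the system message and model/provider values are
# illustrative, and extra keyword arguments are passed straight through to
# LLMClient.chat.
#
#     reply = chat(
#         "Summarize this paragraph in one sentence: ...",
#         model="gpt-4o",
#         provider="OpenAI",
#         system_message="You are a concise technical editor.",
#     )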


def get_model(model_name: str, provider: str):
    """Get a model instance (compatible with the original interface)."""
    client = get_client()
    return client.get_model(model_name, provider)


def list_models():
    """List all available models."""
    client = get_client()
    return client.list_available_models()
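

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming provider credentials are already
    # configured for LLMClient: list the available models, then send a
    # one-line chat request with illustrative model/provider values.
    print(list_models())
    print(chat("Say hello in one sentence.", model="gpt-4o", provider="OpenAI"))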