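"""Manual integration test for LLMClient tool calling against the real ops tools.

Flow: send a Chinese "stop cluster" prompt, let the model emit tool_calls,
execute the matching tool (web_search / start_cluster / stop_cluster) with a
database session, feed the tool results back, and print the final reply.

Assumes a .env with the LLM provider settings in the backend directory; run it
directly, e.g. `python backend/tests/test_llm.py`.
"""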
import asyncio
import json
import os
import sys

# Add backend directory to sys.path to import app modules
# Current file: backend/tests/test_llm.py
# Parent: backend/tests
# Grandparent: backend
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.services.llm import LLMClient
from app.services.ops_tools import openai_tools_schema, tool_web_search, tool_start_cluster, tool_stop_cluster
from app.db import SessionLocal
from dotenv import load_dotenv


async def main():
    # Load .env from backend directory
    env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env")
    load_dotenv(env_path)

    print("Testing LLMClient with REAL Tools...")
    try:
        llm = LLMClient()
        print(f"Provider: {llm.provider}")
        print(f"Endpoint: {llm.endpoint}")
        print(f"Model: {llm.model}")
        print(f"Timeout: {llm.timeout}")

        # User prompt in Chinese: "Stop cluster 5c43a9c7-e2a9-4756-b75d-6813ac55d3ba"
        messages = [{"role": "user", "content": "停止集群 5c43a9c7-e2a9-4756-b75d-6813ac55d3ba"}]

        # 1. Get tool definitions
        chat_tools = openai_tools_schema()

        print(f"Tools loaded: {[t['function']['name'] for t in chat_tools]}")

        print("Sending initial request...")
        resp = await llm.chat(messages, tools=chat_tools)

if "choices" in resp and resp["choices"]:
|
|
msg = resp["choices"][0].get("message", {})
|
|
tool_calls = msg.get("tool_calls")
|
|
|
|
if tool_calls:
|
|
print(f"Tool calls triggered: {len(tool_calls)}")
|
|
# Append assistant message with tool_calls
|
|
messages.append(msg)
|
|
|
|
async with SessionLocal() as db:
|
|
for tc in tool_calls:
|
|
fn = tc.get("function", {})
|
|
name = fn.get("name")
|
|
args_str = fn.get("arguments", "{}")
|
|
print(f"Executing REAL tool: {name} with args: {args_str}")
|
|
|
|
                        if name == "web_search":
                            try:
                                args = json.loads(args_str)
                                tool_result = await tool_web_search(args.get("query"), args.get("max_results", 5))
                                messages.append({
                                    "role": "tool",
                                    "tool_call_id": tc.get("id"),
                                    "name": name,
                                    "content": json.dumps(tool_result, ensure_ascii=False)
                                })
                                print("Tool execution completed.")
                            except Exception as e:
                                print(f"Tool execution failed: {e}")
elif name == "start_cluster":
|
|
try:
|
|
args = json.loads(args_str)
|
|
cluster_uuid = args.get("cluster_uuid")
|
|
# Execute REAL tool
|
|
tool_result = await tool_start_cluster(db, "admin", cluster_uuid)
|
|
|
|
messages.append({
|
|
"role": "tool",
|
|
"tool_call_id": tc.get("id"),
|
|
"name": name,
|
|
"content": json.dumps(tool_result, ensure_ascii=False)
|
|
})
|
|
print(f"REAL tool start_cluster execution completed: {tool_result.get('status')}")
|
|
except Exception as e:
|
|
print(f"REAL tool execution failed: {e}")
|
|
elif name == "stop_cluster":
|
|
try:
|
|
args = json.loads(args_str)
|
|
cluster_uuid = args.get("cluster_uuid")
|
|
# Execute REAL tool
|
|
tool_result = await tool_stop_cluster(db, "admin", cluster_uuid)
|
|
|
|
messages.append({
|
|
"role": "tool",
|
|
"tool_call_id": tc.get("id"),
|
|
"name": name,
|
|
"content": json.dumps(tool_result, ensure_ascii=False)
|
|
})
|
|
print(f"REAL tool stop_cluster execution completed: {tool_result.get('status')}")
|
|
except Exception as e:
|
|
print(f"REAL tool execution failed: {e}")
|
|
|
|
                # 2. Send follow-up request with tool results
                print("Sending follow-up request...")
                resp = await llm.chat(messages, tools=chat_tools)
                if "choices" in resp and resp["choices"]:
                    final_msg = resp["choices"][0].get("message", {})
                    print("\nFinal Reply:")
                    print(final_msg.get('content'))
                    if "reasoning_content" in final_msg:
                        print(f"\nReasoning:\n{final_msg.get('reasoning_content')}")
            else:
                print("No tool calls triggered.")
                print(f"Reply: {msg.get('content')}")
        else:
            print(resp)

    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"Error: {repr(e)}")


if __name__ == "__main__":
    asyncio.run(main())