feat: add benchmark

pull/1666/head
gitlawr 12 months ago committed by Lawrence Li
parent 291ff064f0
commit b4bc99802c

@ -0,0 +1,307 @@
import asyncio
import time
import httpx
import numpy
import logging
import argparse
import json
import random
from openai import AsyncOpenAI
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Avoid client-side connection errors under high concurrency: https://github.com/encode/httpx/discussions/3084
http_client = httpx.AsyncClient(
limits=httpx.Limits(
max_connections=10000, max_keepalive_connections=10000, keepalive_expiry=30
)
)
SAMPLE_PROMPTS = [
"Explain how blockchain technology works, and provide a real-world example of its application outside of cryptocurrency.",
"Compare and contrast the philosophies of Nietzsche and Kant, including their views on morality and human nature.",
"Imagine you're a travel blogger. Write a detailed post describing a week-long adventure through rural Japan.",
"Write a fictional letter from Albert Einstein to a modern-day physicist, discussing the current state of quantum mechanics.",
"Provide a comprehensive explanation of how transformers work in machine learning, including attention mechanisms and positional encoding.",
"Draft a business proposal for launching a new AI-powered productivity app, including target audience, key features, and a monetization strategy.",
"Simulate a panel discussion between Elon Musk, Marie Curie, and Sun Tzu on the topic of 'Leadership in Times of Crisis'.",
"Describe the process of photosynthesis in depth, and explain its importance in the global carbon cycle.",
"Analyze the impact of social media on political polarization, citing relevant studies or historical examples.",
"Write a short science fiction story where humans discover a parallel universe that operates under different physical laws.",
"Explain the role of the Federal Reserve in the U.S. economy and how it manages inflation and unemployment.",
"Describe the architecture of a modern web application, from frontend to backend, including databases, APIs, and deployment.",
"Write an essay discussing whether artificial general intelligence (AGI) poses an existential threat to humanity.",
"Summarize the key events and consequences of the Cuban Missile Crisis, and reflect on lessons for modern diplomacy.",
"Create a guide for beginners on how to train a custom LLM using open-source tools and publicly available datasets.",
]
async def process_stream(stream):
first_token_time = None
total_tokens = 0
    async for chunk in stream:
        # Some backends emit chunks without choices (e.g., a final usage chunk).
        if not chunk.choices:
            continue
        if chunk.choices[0].delta.content:
            # Measure TTFT at the first content-bearing chunk, not at a
            # role-only delta; counting chunks approximates the token count.
            if first_token_time is None:
                first_token_time = time.time()
            total_tokens += 1
        if chunk.choices[0].finish_reason is not None:
            break
return first_token_time, total_tokens
async def make_request(
client: AsyncOpenAI, model, max_completion_tokens, request_timeout
):
start_time = time.time()
content = random.choice(SAMPLE_PROMPTS)
try:
stream = await client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": content}],
max_completion_tokens=max_completion_tokens,
stream=True,
)
first_token_time, total_tokens = await asyncio.wait_for(
process_stream(stream), timeout=request_timeout
)
end_time = time.time()
elapsed_time = end_time - start_time
ttft = first_token_time - start_time if first_token_time else None
tokens_per_second = total_tokens / elapsed_time if elapsed_time > 0 else 0
return total_tokens, elapsed_time, tokens_per_second, ttft
except asyncio.TimeoutError:
logging.warning(f"Request timed out after {request_timeout} seconds")
return None
except Exception as e:
logging.error(f"Error during request: {str(e)}")
return None
async def worker(
client,
model,
semaphore,
queue,
results,
max_completion_tokens,
request_timeout,
):
while True:
async with semaphore:
task_id = await queue.get()
if task_id is None:
queue.task_done()
break
logging.info(f"Starting request {task_id}")
result = await make_request(
client, model, max_completion_tokens, request_timeout
)
if result:
results.append(result)
else:
logging.warning(f"Request {task_id} failed")
queue.task_done()
logging.info(f"Finished request {task_id}")
def calculate_percentile(values, percentile, reverse=False):
    # For throughput, higher is better, so the p99 "tail" is the low end of
    # the distribution; reverse=True maps p99 to the 1st percentile.
    if not values:
        return None
    if reverse:
        return numpy.percentile(values, 100 - percentile)
    return numpy.percentile(values, percentile)
async def preflight_check(client, model) -> bool:
result = await make_request(client, model, 16, 5)
return result is not None
async def main(
model,
num_requests,
concurrency,
request_timeout,
max_completion_tokens,
server_url,
api_key,
):
client = AsyncOpenAI(
base_url=f"{server_url}/v1", api_key=api_key, http_client=http_client
)
if not await preflight_check(client, model):
logging.error(
"Preflight check failed. Please check configuration and the service status."
)
return
semaphore = asyncio.Semaphore(concurrency)
queue = asyncio.Queue()
results = []
# Add tasks to the queue
for i in range(num_requests):
await queue.put(i)
# Add sentinel values to stop workers
for _ in range(concurrency):
await queue.put(None)
# Create worker tasks
workers = [
asyncio.create_task(
worker(
client,
model,
semaphore,
queue,
results,
max_completion_tokens,
request_timeout,
)
)
for _ in range(concurrency)
]
start_time = time.time()
# Wait for all tasks to complete
await queue.join()
await asyncio.gather(*workers)
end_time = time.time()
# Calculate metrics
total_elapsed_time = end_time - start_time
total_tokens = sum(tokens for tokens, _, _, _ in results if tokens is not None)
latencies = [
elapsed_time for _, elapsed_time, _, _ in results if elapsed_time is not None
]
tokens_per_second_list = [tps for _, _, tps, _ in results if tps is not None]
ttft_list = [ttft for _, _, _, ttft in results if ttft is not None]
successful_requests = len(results)
requests_per_second = (
successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0
)
avg_latency = sum(latencies) / len(latencies) if latencies else 0
avg_tokens_per_second = (
sum(tokens_per_second_list) / len(tokens_per_second_list)
if tokens_per_second_list
else 0
)
avg_ttft = sum(ttft_list) / len(ttft_list) if ttft_list else 0
# Calculate percentiles
percentiles = [50, 95, 99]
latency_percentiles = [calculate_percentile(latencies, p) for p in percentiles]
tps_percentiles = [
calculate_percentile(tokens_per_second_list, p, reverse=True)
for p in percentiles
]
ttft_percentiles = [calculate_percentile(ttft_list, p) for p in percentiles]
return {
"model": model,
"total_requests": num_requests,
"successful_requests": successful_requests,
"concurrency": concurrency,
"request_timeout": request_timeout,
"max_completion_tokens": max_completion_tokens,
"total_time": total_elapsed_time,
"requests_per_second": requests_per_second,
"total_completion_tokens": total_tokens,
"latency": {
"average": avg_latency,
"p50": latency_percentiles[0],
"p95": latency_percentiles[1],
"p99": latency_percentiles[2],
},
"tokens_per_second": {
"average": avg_tokens_per_second,
"p50": tps_percentiles[0],
"p95": tps_percentiles[1],
"p99": tps_percentiles[2],
},
"time_to_first_token": {
"average": avg_ttft,
"p50": ttft_percentiles[0],
"p95": ttft_percentiles[1],
"p99": ttft_percentiles[2],
},
}
def output_results(results, result_file=None):
if result_file:
with open(result_file, "w") as f:
json.dump(results, f, indent=2)
logging.info(f"Results saved to {result_file}")
else:
print(json.dumps(results, indent=2))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark Chat Completions API")
parser.add_argument(
"-m", "--model", type=str, required=True, help="Name of the model"
)
parser.add_argument(
"-n",
"--num-requests",
type=int,
default=100,
help="Number of requests to make (default: 100)",
)
parser.add_argument(
"-c",
"--concurrency",
type=int,
default=10,
help="Number of concurrent requests (default: 10)",
)
parser.add_argument(
"--request-timeout",
type=int,
default=300,
help="Timeout for each request in seconds (default: 300)",
)
parser.add_argument(
"--max-completion-tokens",
type=int,
default=128,
help="Maximum number of tokens in the completion (default: 128)",
)
parser.add_argument(
"--server-url",
type=str,
default="http://127.0.0.1",
help="URL of the GPUStack server",
)
parser.add_argument("--api-key", type=str, default="fake", help="GPUStack API key")
parser.add_argument(
"--result-file",
type=str,
help="Result file path to save benchmark json results",
)
args = parser.parse_args()
results = asyncio.run(
main(
args.model,
args.num_requests,
args.concurrency,
args.request_timeout,
args.max_completion_tokens,
args.server_url,
args.api_key,
)
)
    if results:
        output_results(results, args.result_file)
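For reference, a typical invocation of the new script against a local GPUStack server looks like the following (the script name and model are placeholders; the diff view does not show the file path):

    python benchmark.py -m llama3 -n 200 -c 20 --max-completion-tokens 256 --api-key gpustack_xxx --result-file bench.json

The result is a single JSON document with request counts plus average/p50/p95/p99 figures for latency, tokens per second, and time to first token.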

@ -1,7 +1,7 @@
from datetime import datetime, timezone
import logging
from fastapi import Depends, Request
from gpustack.config.config import Config
from gpustack.schemas.api_keys import ApiKey
from gpustack.server.db import get_session
from typing import Annotated, Optional
from fastapi.security import (
@ -12,13 +12,20 @@ from fastapi.security import (
HTTPBearer,
)
from sqlmodel.ext.asyncio.session import AsyncSession
from gpustack.api.exceptions import ForbiddenException, UnauthorizedException
from gpustack.api.exceptions import (
ForbiddenException,
InternalServerErrorException,
UnauthorizedException,
)
from gpustack.schemas.users import User
from gpustack.security import (
API_KEY_PREFIX,
JWTManager,
verify_hashed_secret,
)
from gpustack.server.services import APIKeyService, UserService
logger = logging.getLogger(__name__)
SESSION_COOKIE_NAME = "gpustack_session"
SYSTEM_USER_PREFIX = "system/"
@ -58,8 +65,10 @@ async def get_current_user(
if user is None and request.client.host == "127.0.0.1":
server_config: Config = request.app.state.server_config
if not server_config.force_auth_localhost:
user = await User.first_by_field(session, "is_admin", True)
try:
user = await User.first_by_field(session, "is_admin", True)
except Exception as e:
raise InternalServerErrorException(message=f"Failed to get user: {e}")
if user:
request.state.user = user
return user
@ -124,14 +133,17 @@ async def get_user_from_jwt_token(
payload = jwt_manager.decode_jwt_token(access_token)
username = payload.get("sub")
except Exception:
logger.error("Failed to decode JWT token")
return None
if username is None:
return None
user = await User.one_by_field(session, "username", username)
if not user:
return None
try:
user = await UserService(session).get_by_username(username)
except Exception as e:
raise InternalServerErrorException(message=f"Failed to get user: {e}")
return user
@ -143,7 +155,7 @@ async def get_user_from_bearer_token(
if len(parts) == 3 and parts[0] == API_KEY_PREFIX:
access_key = parts[1]
secret_key = parts[2]
api_key = await ApiKey.one_by_field(session, "access_key", access_key)
api_key = await APIKeyService(session).get_by_access_key(access_key)
if (
api_key is not None
and verify_hashed_secret(api_key.hashed_secret_key, secret_key)
@ -152,11 +164,11 @@ async def get_user_from_bearer_token(
or api_key.expires_at > datetime.now(timezone.utc)
)
):
user = await User.one_by_id(session, api_key.user_id)
user = await UserService(session).get_by_id(api_key.user_id)
if user is not None:
return user
except Exception:
return None
except Exception as e:
raise InternalServerErrorException(message=f"Failed to get user: {e}")
return None
@ -164,7 +176,7 @@ async def get_user_from_bearer_token(
async def authenticate_user(
session: AsyncSession, username: str, password: str
) -> User:
user = await User.one_by_field(session, "username", username)
user = await UserService(session).get_by_username(username)
if not user:
raise UnauthorizedException(message="Incorrect username or password")

@ -27,6 +27,8 @@ from gpustack.api.auth import SESSION_COOKIE_NAME
from gpustack.server.db import get_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from gpustack.server.services import ModelUsageService
logger = logging.getLogger(__name__)
@ -169,14 +171,14 @@ async def record_model_usage(
request_count=1,
)
async with AsyncSession(get_engine()) as session:
current_model_usage = await ModelUsage.one_by_fields(session, fields)
model_usage_service = ModelUsageService(session)
current_model_usage = await model_usage_service.get_by_fields(fields)
if current_model_usage:
current_model_usage.completion_token_count += completion_tokens
current_model_usage.prompt_token_count += prompt_tokens
current_model_usage.request_count += 1
await current_model_usage.update(session)
await model_usage_service.update(
current_model_usage, completion_tokens, prompt_tokens
)
else:
await ModelUsage.create(session, model_usage)
await model_usage_service.create(model_usage)
async def handle_streaming_response(

@ -39,6 +39,7 @@ def setup_logging(debug: bool = False):
"httpcore.proxy",
"httpx",
"asyncio",
"aiocache.base",
"aiosqlite",
"urllib3.connectionpool",
"multipart.multipart",

@ -3,6 +3,7 @@ import aiohttp
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import PlainTextResponse, StreamingResponse
from gpustack.server.services import ModelInstanceService
from gpustack.worker.logs import LogOptionsDep
from gpustack.api.exceptions import (
InternalServerErrorException,
@ -136,7 +137,7 @@ async def update_model_instance(
raise NotFoundException(message="Model instance not found")
try:
await model_instance.update(session, model_instance_in)
await ModelInstanceService(session).update(model_instance, model_instance_in)
except Exception as e:
raise InternalServerErrorException(
message=f"Failed to update model instance: {e}"
@ -152,7 +153,7 @@ async def delete_model_instance(session: SessionDep, id: int):
raise NotFoundException(message="Model instance not found")
try:
await model_instance.delete(session)
await ModelInstanceService(session).delete(model_instance)
except Exception as e:
raise InternalServerErrorException(
message=f"Failed to delete model instance: {e}"

@ -31,6 +31,7 @@ from gpustack.schemas.models import (
ModelPublic,
ModelsPublic,
)
from gpustack.server.services import ModelService
from gpustack.utils.command import find_parameter
from gpustack.utils.convert import safe_int
from gpustack.utils.gpu import parse_gpu_id
@ -281,7 +282,7 @@ async def update_model(session: SessionDep, id: int, model_in: ModelUpdate):
await validate_model_in(session, model_in)
try:
await model.update(session, model_in)
await ModelService(session).update(model, model_in)
except Exception as e:
raise InternalServerErrorException(message=f"Failed to update model: {e}")
@ -295,6 +296,6 @@ async def delete_model(session: SessionDep, id: int):
raise NotFoundException(message="Model not found")
try:
await model.delete(session)
await ModelService(session).delete(model)
except Exception as e:
raise InternalServerErrorException(message=f"Failed to delete model: {e}")

@ -24,11 +24,10 @@ from gpustack.routes.models import build_pg_category_condition
from gpustack.schemas.models import (
CategoryEnum,
Model,
ModelInstance,
ModelInstanceStateEnum,
)
from gpustack.schemas.workers import Worker
from gpustack.server.db import get_session_context
from gpustack.server.deps import SessionDep
from gpustack.server.services import ModelInstanceService, ModelService, WorkerService
logger = logging.getLogger(__name__)
@ -40,38 +39,38 @@ aliasable_router = APIRouter()
@aliasable_router.post("/chat/completions")
async def chat_completions(session: SessionDep, request: Request):
return await proxy_request_by_model(request, session, "chat/completions")
async def chat_completions(request: Request):
return await proxy_request_by_model(request, "chat/completions")
@aliasable_router.post("/completions")
async def completions(session: SessionDep, request: Request):
return await proxy_request_by_model(request, session, "completions")
async def completions(request: Request):
return await proxy_request_by_model(request, "completions")
@aliasable_router.post("/embeddings")
async def embeddings(session: SessionDep, request: Request):
return await proxy_request_by_model(request, session, "embeddings")
async def embeddings(request: Request):
return await proxy_request_by_model(request, "embeddings")
@aliasable_router.post("/images/generations")
async def images_generations(session: SessionDep, request: Request):
return await proxy_request_by_model(request, session, "images/generations")
async def images_generations(request: Request):
return await proxy_request_by_model(request, "images/generations")
@aliasable_router.post("/images/edits")
async def images_edits(session: SessionDep, request: Request):
return await proxy_request_by_model(request, session, "images/edits")
async def images_edits(request: Request):
return await proxy_request_by_model(request, "images/edits")
@aliasable_router.post("/audio/speech")
async def audio_speech(session: SessionDep, request: Request):
return await proxy_request_by_model(request, session, "audio/speech")
async def audio_speech(request: Request):
return await proxy_request_by_model(request, "audio/speech")
@aliasable_router.post("/audio/transcriptions")
async def audio_transcriptions(session: SessionDep, request: Request):
return await proxy_request_by_model(request, session, "audio/transcriptions")
async def audio_transcriptions(request: Request):
return await proxy_request_by_model(request, "audio/transcriptions")
router = APIRouter()
@ -167,31 +166,32 @@ async def list_models(
return result
async def proxy_request_by_model(request: Request, session: SessionDep, endpoint: str):
async def proxy_request_by_model(request: Request, endpoint: str):
"""
Proxy the request to the model instance that is running the model specified in the
request body.
"""
model, stream, body_json, form_data, form_files = await parse_request_body(
request, session
)
if not model:
raise NotFoundException(
message="Model not found",
is_openai_exception=True,
async with get_session_context() as session:
model, stream, body_json, form_data, form_files = await parse_request_body(
request, session
)
request.state.model = model
request.state.stream = stream
if not model:
raise NotFoundException(
message="Model not found",
is_openai_exception=True,
)
instance = await get_running_instance(session, model.id)
worker = await Worker.one_by_id(session, instance.worker_id)
if not worker:
raise InternalServerErrorException(
message=f"Worker with ID {instance.worker_id} not found",
is_openai_exception=True,
)
request.state.model = model
request.state.stream = stream
instance = await get_running_instance(session, model.id)
worker = await WorkerService(session).get_by_id(instance.worker_id)
if not worker:
raise InternalServerErrorException(
message=f"Worker with ID {instance.worker_id} not found",
is_openai_exception=True,
)
url = f"http://{instance.worker_ip}:{worker.port}/proxy/v1/{endpoint}"
token = request.app.state.server_config.token
@ -254,7 +254,7 @@ async def parse_request_body(request: Request, session: SessionDep):
# stream may be set in form data, e.g., image edits.
stream = True
model = await get_model(session, model_name)
model = await ModelService(session).get_by_name(model_name)
return model, stream, body_json, form_data, form_files
@ -294,10 +294,6 @@ async def parse_json_body(request: Request):
)
async def get_model(session: SessionDep, model_name: Optional[str]):
return await Model.one_by_field(session=session, field="name", value=model_name)
async def handle_streaming_request(
request: Request,
url: str,
@ -405,12 +401,9 @@ def filter_headers(headers):
async def get_running_instance(session: AsyncSession, model_id: int):
model_instances = await ModelInstance.all_by_field(
session=session, field="model_id", value=model_id
running_instances = await ModelInstanceService(session).get_running_instances(
model_id
)
running_instances = [
inst for inst in model_instances if inst.state == ModelInstanceStateEnum.RUNNING
]
if not running_instances:
raise ServiceUnavailableException(
message="No running instances available",

@ -14,6 +14,7 @@ from gpustack.schemas.workers import (
WorkersPublic,
Worker,
)
from gpustack.server.services import WorkerService
router = APIRouter()
@ -77,7 +78,7 @@ async def update_worker(session: SessionDep, id: int, worker_in: WorkerUpdate):
try:
worker_in.compute_state()
await worker.update(session, worker_in)
await WorkerService(session).update(worker, worker_in)
except Exception as e:
raise InternalServerErrorException(message=f"Failed to update worker: {e}")
@ -91,6 +92,6 @@ async def delete_worker(session: SessionDep, id: int):
raise NotFoundException(message="worker not found")
try:
await worker.delete(session)
await WorkerService(session).delete(worker)
except Exception as e:
raise InternalServerErrorException(message=f"Failed to delete worker: {e}")

@ -50,6 +50,7 @@ from gpustack.scheduler.calculator import (
GPUOffloadEnum,
calculate_model_resource_claim,
)
from gpustack.server.services import ModelInstanceService, ModelService
from gpustack.utils.gpu import parse_gpu_ids_by_worker
from gpustack.utils.hub import get_pretrained_config
from gpustack.utils.task import run_in_thread
@ -142,7 +143,7 @@ class Scheduler:
if instance.state != ModelInstanceStateEnum.ANALYZING:
instance.state = ModelInstanceStateEnum.ANALYZING
instance.state_message = "Evaluating resource requirements"
await instance.update(session)
await ModelInstanceService(session).update(instance)
if model.source == SourceEnum.LOCAL_PATH and not os.path.exists(
model.local_path
@ -166,14 +167,14 @@ class Scheduler:
should_update_model = await evaluate_pretrained_config(model)
if should_update_model:
await model.update(session)
await ModelService(session).update(model)
await self._queue.put(instance)
except Exception as e:
try:
instance.state = ModelInstanceStateEnum.ERROR
instance.state_message = str(e)
await instance.update(session)
await ModelInstanceService(session).update(instance)
except Exception as ue:
logger.error(
f"Failed to update model instance: {ue}. Original error: {e}"
@ -193,7 +194,7 @@ class Scheduler:
instance.state_message = (
"The model is not distributable to multiple workers."
)
await instance.update(session)
await ModelInstanceService(session).update(instance)
return True
return False
@ -282,7 +283,7 @@ class Scheduler:
if state_message != "":
model_instance.state_message = state_message
await model_instance.update(session, model_instance)
await ModelInstanceService(session).update(model_instance)
logger.debug(
f"No suitable workers for model instance {model_instance.name}, state: {model_instance.state}"
)
@ -302,7 +303,7 @@ class Scheduler:
ray_actors=candidate.ray_actors,
)
await model_instance.update(session, model_instance)
await ModelInstanceService(session).update(model_instance)
logger.debug(
f"Scheduled model instance {model_instance.name} to worker "

@ -27,6 +27,7 @@ from gpustack.schemas.models import (
from gpustack.schemas.workers import Worker, WorkerStateEnum
from gpustack.server.bus import Event, EventType
from gpustack.server.db import get_engine
from gpustack.server.services import ModelInstanceService, ModelService
logger = logging.getLogger(__name__)
@ -131,7 +132,7 @@ async def set_default_worker_selector(session: AsyncSession, model: Model):
):
# vLLM models are only supported on Linux
model.worker_selector = {"os": "linux"}
await model.update(session)
await ModelService(session).update(model)
async def sync_replicas(session: AsyncSession, model: Model, cfg: Config):
@ -162,7 +163,7 @@ async def sync_replicas(session: AsyncSession, model: Model, cfg: Config):
state=ModelInstanceStateEnum.PENDING,
)
await ModelInstance.create(session, instance)
await ModelInstanceService(session).create(instance)
logger.debug(f"Created model instance for model {model.name}")
elif len(instances) > model.replicas:
@ -172,7 +173,7 @@ async def sync_replicas(session: AsyncSession, model: Model, cfg: Config):
if scale_down_count > 0:
for candidate in candidates[:scale_down_count]:
instance = candidate.model_instance
await instance.delete(session)
await ModelInstanceService(session).delete(instance)
logger.debug(f"Deleted model instance {instance.name}")
@ -346,7 +347,7 @@ async def sync_ready_replicas(session: AsyncSession, model: Model):
if model.ready_replicas != ready_replicas:
model.ready_replicas = ready_replicas
await model.update(session)
await ModelService(session).update(model)
async def get_meta_from_running_instance(mi: ModelInstance) -> Dict[str, Any]:
@ -434,7 +435,7 @@ class WorkerController:
):
instance_names = [instance.name for instance in instances]
for instance in instances:
await instance.delete(session)
await ModelInstanceService(session).delete(instance)
if instance_names:
state = (
@ -474,7 +475,7 @@ class WorkerController:
instance.state = new_state
instance.state_message = new_state_message
await instance.update(session)
await ModelInstanceService(session).update(instance)
if instance_names:
logger.debug(
f"Marked instance {', '.join(instance_names)} {new_state} "
@ -544,6 +545,7 @@ async def sync_main_model_file_state(
f"progress: {file.download_progress}, message: {file.state_message}, instance state: {instance.state}"
)
need_update = False
if (
file.state == ModelFileStateEnum.DOWNLOADING
and instance.state == ModelInstanceStateEnum.INITIALIZING
@ -551,7 +553,8 @@ async def sync_main_model_file_state(
# Download started
instance.state = ModelInstanceStateEnum.DOWNLOADING
instance.download_progress = 0
await instance.update(session)
instance.state_message = ""
need_update = True
elif (
file.state == ModelFileStateEnum.DOWNLOADING
and instance.state == ModelInstanceStateEnum.DOWNLOADING
@ -559,7 +562,7 @@ async def sync_main_model_file_state(
):
# Update the download progress
instance.download_progress = file.download_progress
await instance.update(session)
need_update = True
elif file.state == ModelFileStateEnum.READY and (
instance.state == ModelInstanceStateEnum.DOWNLOADING
@ -571,12 +574,15 @@ async def sync_main_model_file_state(
if model_instance_download_completed(instance):
# All files are downloaded
instance.state = ModelInstanceStateEnum.STARTING
await instance.update(session)
need_update = True
elif file.state == ModelFileStateEnum.ERROR:
# Download failed
instance.state = ModelInstanceStateEnum.ERROR
instance.state_message = file.state_message
await instance.update(session)
need_update = True
if need_update:
await ModelInstanceService(session).update(instance)
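Routing these writes through ModelInstanceService.update rather than instance.update matters beyond style: state transitions must also evict the cached get_running_instances result, and the need_update flag consolidates the branches so each file-state event produces at most one write and one cache eviction.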
async def sync_distributed_model_file_state(
@ -623,7 +629,7 @@ async def sync_distributed_model_file_state(
if need_update:
instance.distributed_servers.ray_actors = ray_actors
flag_modified(instance, "distributed_servers")
await instance.update(session)
await ModelInstanceService(session).update(instance)
def model_instance_download_completed(instance: ModelInstance):

@ -1,3 +1,5 @@
from contextlib import asynccontextmanager
import os
import re
from sqlalchemy.ext.asyncio import (
AsyncEngine,
@ -22,6 +24,11 @@ from gpustack.schemas.stmt import (
_engine = None
DB_ECHO = os.getenv("GPUSTACK_DB_ECHO", "false").lower() == "true"
DB_POOL_SIZE = int(os.getenv("GPUSTACK_DB_POOL_SIZE", 5))
DB_MAX_OVERFLOW = int(os.getenv("GPUSTACK_DB_MAX_OVERFLOW", 10))
DB_POOL_TIMEOUT = int(os.getenv("GPUSTACK_DB_POOL_TIMEOUT", 30))
def get_engine():
return _engine
@ -32,6 +39,12 @@ async def get_session():
yield session
@asynccontextmanager
async def get_session_context():
async with AsyncSession(_engine) as session:
yield session
async def init_db(db_url: str):
global _engine, _session_maker
if _engine is None:
@ -45,7 +58,14 @@ async def init_db(db_url: str):
else:
raise Exception(f"Unsupported database URL: {db_url}")
_engine = create_async_engine(db_url, echo=False, connect_args=connect_args)
_engine = create_async_engine(
db_url,
echo=DB_ECHO,
pool_size=DB_POOL_SIZE,
max_overflow=DB_MAX_OVERFLOW,
pool_timeout=DB_POOL_TIMEOUT,
connect_args=connect_args,
)
listen_events(_engine)
await create_db_and_tables(_engine)
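The new pool knobs are plain environment variables read at import time, and get_session_context lets code outside the FastAPI dependency graph scope a session to exactly the work that needs it. A minimal sketch, with illustrative values and a hypothetical caller:

    # Illustrative environment, not defaults from this change:
    #   GPUSTACK_DB_ECHO=true        GPUSTACK_DB_POOL_SIZE=20
    #   GPUSTACK_DB_MAX_OVERFLOW=40  GPUSTACK_DB_POOL_TIMEOUT=60

    from gpustack.server.db import get_session_context

    async def lookup_example():  # hypothetical caller, not part of this diff
        # The pooled connection is released as soon as the block exits,
        # rather than living for an entire request or stream.
        async with get_session_context() as session:
            ...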

@ -23,6 +23,7 @@ from gpustack.scheduler.scheduler import Scheduler
from gpustack.ray.manager import RayManager
from gpustack.server.system_load import SystemLoadCollector
from gpustack.server.update_check import UpdateChecker
from gpustack.server.usage_buffer import flush_usage_to_db
from gpustack.server.worker_syncer import WorkerSyncer
from gpustack.utils.process import add_signal_handlers_in_loop
@ -65,6 +66,7 @@ class Server:
self._start_system_load_collector()
self._start_worker_syncer()
self._start_update_checker()
self._start_model_usage_flusher()
self._start_ray()
port = 80
@ -170,6 +172,11 @@ class Server:
logger.debug("Worker syncer started.")
def _start_model_usage_flusher(self):
self._create_async_task(flush_usage_to_db())
logger.debug("Model usage flusher started.")
def _start_update_checker(self):
if self._config.disable_update_check:
return

@ -0,0 +1,218 @@
import asyncio
import logging
from typing import Any, Callable, Dict, List, Optional, Union
from aiocache import Cache, cached
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
from gpustack.schemas.api_keys import ApiKey
from gpustack.schemas.model_usage import ModelUsage
from gpustack.schemas.models import Model, ModelInstance, ModelInstanceStateEnum
from gpustack.schemas.users import User
from gpustack.schemas.workers import Worker
from gpustack.server.usage_buffer import usage_flush_buffer
logger = logging.getLogger(__name__)
cache = Cache(Cache.MEMORY)
def build_cache_key(func: Callable, *args, **kwargs):
if kwargs is None:
kwargs = {}
ordered_kwargs = sorted(kwargs.items())
return func.__qualname__ + str(args) + str(ordered_kwargs)
async def delete_cache_by_key(func, *args, **kwargs):
key = build_cache_key(func, *args, **kwargs)
logger.trace(f"Deleting cache for key: {key}")
await cache.delete(key)
async def set_cache_by_key(key: str, value: Any):
logger.trace(f"Set cache for key: {key}")
await cache.set(key, value)
_cache_locks: Dict[str, asyncio.Lock] = {}
class locked_cached(cached):
async def decorator(self, f, *args, **kwargs):
        # Build the key from the call arguments, skipping the bound instance (args[0]).
key = build_cache_key(f, *args[1:], **kwargs)
value = await self.get_from_cache(key)
if value is not None:
return value
lock = _cache_locks.setdefault(key, asyncio.Lock())
async with lock:
value = await self.get_from_cache(key)
if value is not None:
return value
logger.trace(f"cache miss for key: {key}")
result = await f(*args, **kwargs)
await self.set_in_cache(key, result)
return result
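# Note: None results are treated as cache misses above, so negative lookups
# (e.g., an unknown username) are re-queried on every call instead of being
# cached for the TTL. The per-key asyncio.Lock ensures that a burst of
# concurrent misses for one key issues a single underlying query
# (cache-stampede protection).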
class UserService:
def __init__(self, session: AsyncSession):
self.session = session
@locked_cached(ttl=60)
async def get_by_id(self, user_id: int) -> Optional[User]:
return await User.one_by_id(self.session, user_id)
@locked_cached(ttl=60)
async def get_by_username(self, username: str) -> Optional[User]:
return await User.one_by_field(self.session, "username", username)
async def create(self, user: User):
return await User.create(self.session, user)
async def update(self, user: User, source: Union[dict, SQLModel, None] = None):
result = await user.update(self.session, source)
await delete_cache_by_key(self.get_by_id, user.id)
await delete_cache_by_key(self.get_by_username, user.username)
return result
async def delete(self, user: User):
result = await user.delete(self.session)
await delete_cache_by_key(self.get_by_id, user.id)
await delete_cache_by_key(self.get_by_username, user.username)
return result
class APIKeyService:
def __init__(self, session: AsyncSession):
self.session = session
@locked_cached(ttl=60)
async def get_by_access_key(self, access_key: str) -> Optional[ApiKey]:
return await ApiKey.one_by_field(self.session, "access_key", access_key)
async def delete(self, api_key: ApiKey):
result = await api_key.delete(self.session)
await delete_cache_by_key(self.get_by_access_key, api_key.access_key)
return result
class WorkerService:
def __init__(self, session: AsyncSession):
self.session = session
@locked_cached(ttl=60)
async def get_by_id(self, worker_id: int) -> Optional[Worker]:
return await Worker.one_by_id(self.session, worker_id)
@locked_cached(ttl=60)
async def get_by_name(self, name: str) -> Optional[Worker]:
return await Worker.one_by_field(self.session, "name", name)
async def update(self, worker: Worker, source: Union[dict, SQLModel, None] = None):
result = await worker.update(self.session, source)
await delete_cache_by_key(self.get_by_id, worker.id)
await delete_cache_by_key(self.get_by_name, worker.name)
return result
async def delete(self, worker: Worker):
result = await worker.delete(self.session)
await delete_cache_by_key(self.get_by_id, worker.id)
await delete_cache_by_key(self.get_by_name, worker.name)
return result
class ModelService:
def __init__(self, session: AsyncSession):
self.session = session
@locked_cached(ttl=60)
async def get_by_id(self, model_id: int) -> Optional[Model]:
return await Model.one_by_id(self.session, model_id)
@locked_cached(ttl=60)
async def get_by_name(self, name: str) -> Optional[Model]:
return await Model.one_by_field(self.session, "name", name)
async def update(self, model: Model, source: Union[dict, SQLModel, None] = None):
result = await model.update(self.session, source)
await delete_cache_by_key(self.get_by_id, model.id)
await delete_cache_by_key(self.get_by_name, model.name)
return result
async def delete(self, model: Model):
result = await model.delete(self.session)
await delete_cache_by_key(self.get_by_id, model.id)
await delete_cache_by_key(self.get_by_name, model.name)
return result
class ModelInstanceService:
def __init__(self, session: AsyncSession):
self.session = session
@locked_cached(ttl=60)
async def get_running_instances(self, model_id: int) -> List[ModelInstance]:
return await ModelInstance.all_by_fields(
self.session,
fields={"model_id": model_id, "state": ModelInstanceStateEnum.RUNNING},
)
async def create(self, model_instance):
result = await ModelInstance.create(self.session, model_instance)
await delete_cache_by_key(self.get_running_instances, model_instance.model_id)
return result
async def update(
self, model_instance: ModelInstance, source: Union[dict, SQLModel, None] = None
):
result = await model_instance.update(self.session, source)
await delete_cache_by_key(self.get_running_instances, model_instance.model_id)
return result
async def delete(self, model_instance: ModelInstance):
result = await model_instance.delete(self.session)
await delete_cache_by_key(self.get_running_instances, model_instance.model_id)
return result
class ModelUsageService:
def __init__(self, session: AsyncSession):
self.session = session
@locked_cached(ttl=60)
async def get_by_fields(self, fields: dict) -> ModelUsage:
return await ModelUsage.one_by_fields(
self.session,
fields=fields,
)
async def create(self, model_usage: ModelUsage):
return await ModelUsage.create(self.session, model_usage)
async def update(
self,
model_usage: ModelUsage,
completion_token_count: int,
prompt_token_count: int,
):
model_usage.completion_token_count += completion_token_count
model_usage.prompt_token_count += prompt_token_count
model_usage.request_count += 1
        # Mirror the dict-based key that locked_cached builds for get_by_fields;
        # positional values would produce a different key that is never read.
        key = build_cache_key(
            self.get_by_fields,
            {
                "user_id": model_usage.user_id,
                "model_id": model_usage.model_id,
                "operation": model_usage.operation,
                "date": model_usage.date,
            },
        )
await set_cache_by_key(key, model_usage)
usage_flush_buffer[key] = model_usage
return model_usage
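Taken together, the services pair 60-second read caching with write-through invalidation. A sketch of the intended flow (session setup elided; the update payload is illustrative):

    svc = UserService(session)
    user = await svc.get_by_username("admin")   # DB query; result cached for 60s
    user = await svc.get_by_username("admin")   # served from the in-memory cache
    await svc.update(user, {"is_admin": True})  # persists and evicts both cache keys
    user = await svc.get_by_username("admin")   # cache miss; fresh DB read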

@ -0,0 +1,39 @@
import asyncio
import logging
from typing import Dict
from sqlmodel.ext.asyncio.session import AsyncSession
from gpustack.schemas.model_usage import ModelUsage
from gpustack.server.db import get_engine
logger = logging.getLogger(__name__)
usage_flush_buffer: Dict[str, ModelUsage] = {}
async def flush_usage_to_db():
"""
Flush model usage records to the database periodically.
"""
while True:
await asyncio.sleep(5)
if not usage_flush_buffer:
continue
local_buffer = dict(usage_flush_buffer)
usage_flush_buffer.clear()
try:
async with AsyncSession(get_engine()) as session:
for key, usage in local_buffer.items():
                    to_update = await ModelUsage.one_by_id(session, usage.id)
                    if to_update is None:
                        # The row may have been deleted since it was buffered.
                        continue
                    to_update.prompt_token_count = usage.prompt_token_count
                    to_update.completion_token_count = usage.completion_token_count
                    to_update.request_count = usage.request_count
                    session.add(to_update)
await session.commit()
logger.debug(f"Flushed {len(local_buffer)} usage records to DB")
except Exception as e:
logger.error(f"Error flushing usage to DB: {e}")

@ -3,6 +3,7 @@ import logging
from sqlmodel.ext.asyncio.session import AsyncSession
from gpustack.schemas.workers import Worker, WorkerStateEnum
from gpustack.server.db import get_engine
from gpustack.server.services import WorkerService
from gpustack.utils.network import is_url_reachable
logger = logging.getLogger(__name__)
@ -55,7 +56,7 @@ class WorkerSyncer:
state_to_worker_name[worker.state].append(worker.name)
for worker in should_update_workers:
await worker.update(session, worker)
await WorkerService(session).update(worker)
for state, worker_names in state_to_worker_name.items():
if worker_names:

poetry.lock generated

@ -53,6 +53,22 @@ files = [
{file = "addict-2.4.0.tar.gz", hash = "sha256:b3b2210e0e067a281f5646c8c5db92e99b7231ea8b0eb5f74dbdf9e259d4e494"},
]
[[package]]
name = "aiocache"
version = "0.12.3"
description = "multi backend asyncio cache"
optional = false
python-versions = "*"
files = [
{file = "aiocache-0.12.3-py2.py3-none-any.whl", hash = "sha256:889086fc24710f431937b87ad3720a289f7fc31c4fd8b68e9f918b9bacd8270d"},
{file = "aiocache-0.12.3.tar.gz", hash = "sha256:f528b27bf4d436b497a1d0d1a8f59a542c153ab1e37c3621713cb376d44c4713"},
]
[package.extras]
memcached = ["aiomcache (>=0.5.2)"]
msgpack = ["msgpack (>=0.5.5)"]
redis = ["redis (>=4.2.0)"]
[[package]]
name = "aiofiles"
version = "23.2.1"
@ -8867,4 +8883,4 @@ vllm = ["bitsandbytes", "mistral_common", "vllm"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "446f471bce8d67056e9e79b21b04b56e5446804693de22fe9ebdbb5029144436"
content-hash = "c07930551e489277eb3327bb398dac6d0106ed7ed5180f04760a7d2de950eac5"

@ -51,6 +51,7 @@ psycopg2-binary = "^2.9.10"
vox-box = {version = "0.0.11", optional = true}
tenacity = "^9.0.0"
aiocache = "^0.12.3"
aiofiles = "^23.2.1"
aiohttp = "^3.11.2"
bitsandbytes = {version = "^0.45.2", optional = true}
