diff --git a/benchmarks/benchmark_llm.py b/benchmarks/benchmark_llm.py new file mode 100644 index 0000000..165869a --- /dev/null +++ b/benchmarks/benchmark_llm.py @@ -0,0 +1,307 @@ +import asyncio +import time +import httpx +import numpy +import logging +import argparse +import json +import random +from openai import AsyncOpenAI + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +# Avoid client side connection error: https://github.com/encode/httpx/discussions/3084 +http_client = httpx.AsyncClient( + limits=httpx.Limits( + max_connections=10000, max_keepalive_connections=10000, keepalive_expiry=30 + ) +) + +SAMPLE_PROMPTS = [ + "Explain how blockchain technology works, and provide a real-world example of its application outside of cryptocurrency.", + "Compare and contrast the philosophies of Nietzsche and Kant, including their views on morality and human nature.", + "Imagine you're a travel blogger. Write a detailed post describing a week-long adventure through rural Japan.", + "Write a fictional letter from Albert Einstein to a modern-day physicist, discussing the current state of quantum mechanics.", + "Provide a comprehensive explanation of how transformers work in machine learning, including attention mechanisms and positional encoding.", + "Draft a business proposal for launching a new AI-powered productivity app, including target audience, key features, and a monetization strategy.", + "Simulate a panel discussion between Elon Musk, Marie Curie, and Sun Tzu on the topic of 'Leadership in Times of Crisis'.", + "Describe the process of photosynthesis in depth, and explain its importance in the global carbon cycle.", + "Analyze the impact of social media on political polarization, citing relevant studies or historical examples.", + "Write a short science fiction story where humans discover a parallel universe that operates under different physical laws.", + "Explain the role of the Federal Reserve in the U.S. 
economy and how it manages inflation and unemployment.", + "Describe the architecture of a modern web application, from frontend to backend, including databases, APIs, and deployment.", + "Write an essay discussing whether artificial general intelligence (AGI) poses an existential threat to humanity.", + "Summarize the key events and consequences of the Cuban Missile Crisis, and reflect on lessons for modern diplomacy.", + "Create a guide for beginners on how to train a custom LLM using open-source tools and publicly available datasets.", +] + + +async def process_stream(stream): + first_token_time = None + total_tokens = 0 + async for chunk in stream: + if first_token_time is None: + first_token_time = time.time() + if chunk.choices[0].delta.content: + total_tokens += 1 + if chunk.choices[0].finish_reason is not None: + break + return first_token_time, total_tokens + + +async def make_request( + client: AsyncOpenAI, model, max_completion_tokens, request_timeout +): + start_time = time.time() + content = random.choice(SAMPLE_PROMPTS) + + try: + stream = await client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": content}], + max_completion_tokens=max_completion_tokens, + stream=True, + ) + first_token_time, total_tokens = await asyncio.wait_for( + process_stream(stream), timeout=request_timeout + ) + + end_time = time.time() + elapsed_time = end_time - start_time + ttft = first_token_time - start_time if first_token_time else None + tokens_per_second = total_tokens / elapsed_time if elapsed_time > 0 else 0 + return total_tokens, elapsed_time, tokens_per_second, ttft + + except asyncio.TimeoutError: + logging.warning(f"Request timed out after {request_timeout} seconds") + return None + except Exception as e: + logging.error(f"Error during request: {str(e)}") + return None + + +async def worker( + client, + model, + semaphore, + queue, + results, + max_completion_tokens, + request_timeout, +): + while True: + async with semaphore: + task_id = await queue.get() + if task_id is None: + queue.task_done() + break + logging.info(f"Starting request {task_id}") + result = await make_request( + client, model, max_completion_tokens, request_timeout + ) + if result: + results.append(result) + else: + logging.warning(f"Request {task_id} failed") + queue.task_done() + logging.info(f"Finished request {task_id}") + + +def calculate_percentile(values, percentile, reverse=False): + if not values: + return None + if reverse: + return numpy.percentile(values, 100 - percentile) + return numpy.percentile(values, percentile) + + +async def preflight_check(client, model) -> bool: + result = await make_request(client, model, 16, 5) + return result is not None + + +async def main( + model, + num_requests, + concurrency, + request_timeout, + max_completion_tokens, + server_url, + api_key, +): + client = AsyncOpenAI( + base_url=f"{server_url}/v1", api_key=api_key, http_client=http_client + ) + + if not await preflight_check(client, model): + logging.error( + "Preflight check failed. Please check configuration and the service status." 
+ ) + return + + semaphore = asyncio.Semaphore(concurrency) + queue = asyncio.Queue() + results = [] + + # Add tasks to the queue + for i in range(num_requests): + await queue.put(i) + + # Add sentinel values to stop workers + for _ in range(concurrency): + await queue.put(None) + + # Create worker tasks + workers = [ + asyncio.create_task( + worker( + client, + model, + semaphore, + queue, + results, + max_completion_tokens, + request_timeout, + ) + ) + for _ in range(concurrency) + ] + + start_time = time.time() + + # Wait for all tasks to complete + await queue.join() + await asyncio.gather(*workers) + + end_time = time.time() + + # Calculate metrics + total_elapsed_time = end_time - start_time + total_tokens = sum(tokens for tokens, _, _, _ in results if tokens is not None) + latencies = [ + elapsed_time for _, elapsed_time, _, _ in results if elapsed_time is not None + ] + tokens_per_second_list = [tps for _, _, tps, _ in results if tps is not None] + ttft_list = [ttft for _, _, _, ttft in results if ttft is not None] + + successful_requests = len(results) + requests_per_second = ( + successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0 + ) + avg_latency = sum(latencies) / len(latencies) if latencies else 0 + avg_tokens_per_second = ( + sum(tokens_per_second_list) / len(tokens_per_second_list) + if tokens_per_second_list + else 0 + ) + avg_ttft = sum(ttft_list) / len(ttft_list) if ttft_list else 0 + + # Calculate percentiles + percentiles = [50, 95, 99] + latency_percentiles = [calculate_percentile(latencies, p) for p in percentiles] + tps_percentiles = [ + calculate_percentile(tokens_per_second_list, p, reverse=True) + for p in percentiles + ] + ttft_percentiles = [calculate_percentile(ttft_list, p) for p in percentiles] + + return { + "model": model, + "total_requests": num_requests, + "successful_requests": successful_requests, + "concurrency": concurrency, + "request_timeout": request_timeout, + "max_completion_tokens": max_completion_tokens, + "total_time": total_elapsed_time, + "requests_per_second": requests_per_second, + "total_completion_tokens": total_tokens, + "latency": { + "average": avg_latency, + "p50": latency_percentiles[0], + "p95": latency_percentiles[1], + "p99": latency_percentiles[2], + }, + "tokens_per_second": { + "average": avg_tokens_per_second, + "p50": tps_percentiles[0], + "p95": tps_percentiles[1], + "p99": tps_percentiles[2], + }, + "time_to_first_token": { + "average": avg_ttft, + "p50": ttft_percentiles[0], + "p95": ttft_percentiles[1], + "p99": ttft_percentiles[2], + }, + } + + +def output_results(results, result_file=None): + if result_file: + with open(result_file, "w") as f: + json.dump(results, f, indent=2) + logging.info(f"Results saved to {result_file}") + else: + print(json.dumps(results, indent=2)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark Chat Completions API") + parser.add_argument( + "-m", "--model", type=str, required=True, help="Name of the model" + ) + parser.add_argument( + "-n", + "--num-requests", + type=int, + default=100, + help="Number of requests to make (default: 100)", + ) + parser.add_argument( + "-c", + "--concurrency", + type=int, + default=10, + help="Number of concurrent requests (default: 10)", + ) + parser.add_argument( + "--request-timeout", + type=int, + default=300, + help="Timeout for each request in seconds (default: 300)", + ) + parser.add_argument( + "--max-completion-tokens", + type=int, + default=128, + help="Maximum number of tokens in the 
completion (default: 128)", + ) + parser.add_argument( + "--server-url", + type=str, + default="http://127.0.0.1", + help="URL of the GPUStack server", + ) + parser.add_argument("--api-key", type=str, default="fake", help="GPUStack API key") + parser.add_argument( + "--result-file", + type=str, + help="Result file path to save benchmark json results", + ) + args = parser.parse_args() + + results = asyncio.run( + main( + args.model, + args.num_requests, + args.concurrency, + args.request_timeout, + args.max_completion_tokens, + args.server_url, + args.api_key, + ) + ) + output_results(results, args.result_file) diff --git a/gpustack/api/auth.py b/gpustack/api/auth.py index 5e55226..9ec7759 100644 --- a/gpustack/api/auth.py +++ b/gpustack/api/auth.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone +import logging from fastapi import Depends, Request from gpustack.config.config import Config -from gpustack.schemas.api_keys import ApiKey from gpustack.server.db import get_session from typing import Annotated, Optional from fastapi.security import ( @@ -12,13 +12,20 @@ from fastapi.security import ( HTTPBearer, ) from sqlmodel.ext.asyncio.session import AsyncSession -from gpustack.api.exceptions import ForbiddenException, UnauthorizedException +from gpustack.api.exceptions import ( + ForbiddenException, + InternalServerErrorException, + UnauthorizedException, +) from gpustack.schemas.users import User from gpustack.security import ( API_KEY_PREFIX, JWTManager, verify_hashed_secret, ) +from gpustack.server.services import APIKeyService, UserService + +logger = logging.getLogger(__name__) SESSION_COOKIE_NAME = "gpustack_session" SYSTEM_USER_PREFIX = "system/" @@ -58,8 +65,10 @@ async def get_current_user( if user is None and request.client.host == "127.0.0.1": server_config: Config = request.app.state.server_config if not server_config.force_auth_localhost: - user = await User.first_by_field(session, "is_admin", True) - + try: + user = await User.first_by_field(session, "is_admin", True) + except Exception as e: + raise InternalServerErrorException(message=f"Failed to get user: {e}") if user: request.state.user = user return user @@ -124,14 +133,17 @@ async def get_user_from_jwt_token( payload = jwt_manager.decode_jwt_token(access_token) username = payload.get("sub") except Exception: + logger.error("Failed to decode JWT token") return None if username is None: return None - user = await User.one_by_field(session, "username", username) - if not user: - return None + try: + user = await UserService(session).get_by_username(username) + except Exception as e: + raise InternalServerErrorException(message=f"Failed to get user: {e}") + return user @@ -143,7 +155,7 @@ async def get_user_from_bearer_token( if len(parts) == 3 and parts[0] == API_KEY_PREFIX: access_key = parts[1] secret_key = parts[2] - api_key = await ApiKey.one_by_field(session, "access_key", access_key) + api_key = await APIKeyService(session).get_by_access_key(access_key) if ( api_key is not None and verify_hashed_secret(api_key.hashed_secret_key, secret_key) @@ -152,11 +164,11 @@ async def get_user_from_bearer_token( or api_key.expires_at > datetime.now(timezone.utc) ) ): - user = await User.one_by_id(session, api_key.user_id) + user = await UserService(session).get_by_id(api_key.user_id) if user is not None: return user - except Exception: - return None + except Exception as e: + raise InternalServerErrorException(message=f"Failed to get user: {e}") return None @@ -164,7 +176,7 @@ async def get_user_from_bearer_token( async 
def authenticate_user( session: AsyncSession, username: str, password: str ) -> User: - user = await User.one_by_field(session, "username", username) + user = await UserService(session).get_by_username(username) if not user: raise UnauthorizedException(message="Incorrect username or password") diff --git a/gpustack/api/middlewares.py b/gpustack/api/middlewares.py index 3ae17ee..f5aadc8 100644 --- a/gpustack/api/middlewares.py +++ b/gpustack/api/middlewares.py @@ -27,6 +27,8 @@ from gpustack.api.auth import SESSION_COOKIE_NAME from gpustack.server.db import get_engine from sqlmodel.ext.asyncio.session import AsyncSession +from gpustack.server.services import ModelUsageService + logger = logging.getLogger(__name__) @@ -169,14 +171,14 @@ async def record_model_usage( request_count=1, ) async with AsyncSession(get_engine()) as session: - current_model_usage = await ModelUsage.one_by_fields(session, fields) + model_usage_service = ModelUsageService(session) + current_model_usage = await model_usage_service.get_by_fields(fields) if current_model_usage: - current_model_usage.completion_token_count += completion_tokens - current_model_usage.prompt_token_count += prompt_tokens - current_model_usage.request_count += 1 - await current_model_usage.update(session) + await model_usage_service.update( + current_model_usage, completion_tokens, prompt_tokens + ) else: - await ModelUsage.create(session, model_usage) + await model_usage_service.create(model_usage) async def handle_streaming_response( diff --git a/gpustack/logging.py b/gpustack/logging.py index bb26d4c..7466dc2 100644 --- a/gpustack/logging.py +++ b/gpustack/logging.py @@ -39,6 +39,7 @@ def setup_logging(debug: bool = False): "httpcore.proxy", "httpx", "asyncio", + "aiocache.base", "aiosqlite", "urllib3.connectionpool", "multipart.multipart", diff --git a/gpustack/routes/model_instances.py b/gpustack/routes/model_instances.py index f158863..7e6bd8b 100644 --- a/gpustack/routes/model_instances.py +++ b/gpustack/routes/model_instances.py @@ -3,6 +3,7 @@ import aiohttp from fastapi import APIRouter, HTTPException, Request from fastapi.responses import PlainTextResponse, StreamingResponse +from gpustack.server.services import ModelInstanceService from gpustack.worker.logs import LogOptionsDep from gpustack.api.exceptions import ( InternalServerErrorException, @@ -136,7 +137,7 @@ async def update_model_instance( raise NotFoundException(message="Model instance not found") try: - await model_instance.update(session, model_instance_in) + await ModelInstanceService(session).update(model_instance, model_instance_in) except Exception as e: raise InternalServerErrorException( message=f"Failed to update model instance: {e}" @@ -152,7 +153,7 @@ async def delete_model_instance(session: SessionDep, id: int): raise NotFoundException(message="Model instance not found") try: - await model_instance.delete(session) + await ModelInstanceService(session).delete(model_instance) except Exception as e: raise InternalServerErrorException( message=f"Failed to delete model instance: {e}" diff --git a/gpustack/routes/models.py b/gpustack/routes/models.py index 19ab1b0..7171680 100644 --- a/gpustack/routes/models.py +++ b/gpustack/routes/models.py @@ -31,6 +31,7 @@ from gpustack.schemas.models import ( ModelPublic, ModelsPublic, ) +from gpustack.server.services import ModelService from gpustack.utils.command import find_parameter from gpustack.utils.convert import safe_int from gpustack.utils.gpu import parse_gpu_id @@ -281,7 +282,7 @@ async def update_model(session: 
SessionDep, id: int, model_in: ModelUpdate): await validate_model_in(session, model_in) try: - await model.update(session, model_in) + await ModelService(session).update(model, model_in) except Exception as e: raise InternalServerErrorException(message=f"Failed to update model: {e}") @@ -295,6 +296,6 @@ async def delete_model(session: SessionDep, id: int): raise NotFoundException(message="Model not found") try: - await model.delete(session) + await ModelService(session).delete(model) except Exception as e: raise InternalServerErrorException(message=f"Failed to delete model: {e}") diff --git a/gpustack/routes/openai.py b/gpustack/routes/openai.py index e31cb69..3302cd3 100644 --- a/gpustack/routes/openai.py +++ b/gpustack/routes/openai.py @@ -24,11 +24,10 @@ from gpustack.routes.models import build_pg_category_condition from gpustack.schemas.models import ( CategoryEnum, Model, - ModelInstance, - ModelInstanceStateEnum, ) -from gpustack.schemas.workers import Worker +from gpustack.server.db import get_session_context from gpustack.server.deps import SessionDep +from gpustack.server.services import ModelInstanceService, ModelService, WorkerService logger = logging.getLogger(__name__) @@ -40,38 +39,38 @@ aliasable_router = APIRouter() @aliasable_router.post("/chat/completions") -async def chat_completions(session: SessionDep, request: Request): - return await proxy_request_by_model(request, session, "chat/completions") +async def chat_completions(request: Request): + return await proxy_request_by_model(request, "chat/completions") @aliasable_router.post("/completions") -async def completions(session: SessionDep, request: Request): - return await proxy_request_by_model(request, session, "completions") +async def completions(request: Request): + return await proxy_request_by_model(request, "completions") @aliasable_router.post("/embeddings") -async def embeddings(session: SessionDep, request: Request): - return await proxy_request_by_model(request, session, "embeddings") +async def embeddings(request: Request): + return await proxy_request_by_model(request, "embeddings") @aliasable_router.post("/images/generations") -async def images_generations(session: SessionDep, request: Request): - return await proxy_request_by_model(request, session, "images/generations") +async def images_generations(request: Request): + return await proxy_request_by_model(request, "images/generations") @aliasable_router.post("/images/edits") -async def images_edits(session: SessionDep, request: Request): - return await proxy_request_by_model(request, session, "images/edits") +async def images_edits(request: Request): + return await proxy_request_by_model(request, "images/edits") @aliasable_router.post("/audio/speech") -async def audio_speech(session: SessionDep, request: Request): - return await proxy_request_by_model(request, session, "audio/speech") +async def audio_speech(request: Request): + return await proxy_request_by_model(request, "audio/speech") @aliasable_router.post("/audio/transcriptions") -async def audio_transcriptions(session: SessionDep, request: Request): - return await proxy_request_by_model(request, session, "audio/transcriptions") +async def audio_transcriptions(request: Request): + return await proxy_request_by_model(request, "audio/transcriptions") router = APIRouter() @@ -167,31 +166,32 @@ async def list_models( return result -async def proxy_request_by_model(request: Request, session: SessionDep, endpoint: str): +async def proxy_request_by_model(request: Request, endpoint: str): """ Proxy the 
request to the model instance that is running the model specified in the request body. """ - model, stream, body_json, form_data, form_files = await parse_request_body( - request, session - ) - - if not model: - raise NotFoundException( - message="Model not found", - is_openai_exception=True, + async with get_session_context() as session: + model, stream, body_json, form_data, form_files = await parse_request_body( + request, session ) - request.state.model = model - request.state.stream = stream + if not model: + raise NotFoundException( + message="Model not found", + is_openai_exception=True, + ) - instance = await get_running_instance(session, model.id) - worker = await Worker.one_by_id(session, instance.worker_id) - if not worker: - raise InternalServerErrorException( - message=f"Worker with ID {instance.worker_id} not found", - is_openai_exception=True, - ) + request.state.model = model + request.state.stream = stream + + instance = await get_running_instance(session, model.id) + worker = await WorkerService(session).get_by_id(instance.worker_id) + if not worker: + raise InternalServerErrorException( + message=f"Worker with ID {instance.worker_id} not found", + is_openai_exception=True, + ) url = f"http://{instance.worker_ip}:{worker.port}/proxy/v1/{endpoint}" token = request.app.state.server_config.token @@ -254,7 +254,7 @@ async def parse_request_body(request: Request, session: SessionDep): # stream may be set in form data, e.g., image edits. stream = True - model = await get_model(session, model_name) + model = await ModelService(session).get_by_name(model_name) return model, stream, body_json, form_data, form_files @@ -294,10 +294,6 @@ async def parse_json_body(request: Request): ) -async def get_model(session: SessionDep, model_name: Optional[str]): - return await Model.one_by_field(session=session, field="name", value=model_name) - - async def handle_streaming_request( request: Request, url: str, @@ -405,12 +401,9 @@ def filter_headers(headers): async def get_running_instance(session: AsyncSession, model_id: int): - model_instances = await ModelInstance.all_by_field( - session=session, field="model_id", value=model_id + running_instances = await ModelInstanceService(session).get_running_instances( + model_id ) - running_instances = [ - inst for inst in model_instances if inst.state == ModelInstanceStateEnum.RUNNING - ] if not running_instances: raise ServiceUnavailableException( message="No running instances available", diff --git a/gpustack/routes/workers.py b/gpustack/routes/workers.py index 1b4aaa5..f441522 100644 --- a/gpustack/routes/workers.py +++ b/gpustack/routes/workers.py @@ -14,6 +14,7 @@ from gpustack.schemas.workers import ( WorkersPublic, Worker, ) +from gpustack.server.services import WorkerService router = APIRouter() @@ -77,7 +78,7 @@ async def update_worker(session: SessionDep, id: int, worker_in: WorkerUpdate): try: worker_in.compute_state() - await worker.update(session, worker_in) + await WorkerService(session).update(worker, worker_in) except Exception as e: raise InternalServerErrorException(message=f"Failed to update worker: {e}") @@ -91,6 +92,6 @@ async def delete_worker(session: SessionDep, id: int): raise NotFoundException(message="worker not found") try: - await worker.delete(session) + await WorkerService(session).delete(worker) except Exception as e: raise InternalServerErrorException(message=f"Failed to delete worker: {e}") diff --git a/gpustack/scheduler/scheduler.py b/gpustack/scheduler/scheduler.py index 6d08517..3537934 100644 --- 
a/gpustack/scheduler/scheduler.py +++ b/gpustack/scheduler/scheduler.py @@ -50,6 +50,7 @@ from gpustack.scheduler.calculator import ( GPUOffloadEnum, calculate_model_resource_claim, ) +from gpustack.server.services import ModelInstanceService, ModelService from gpustack.utils.gpu import parse_gpu_ids_by_worker from gpustack.utils.hub import get_pretrained_config from gpustack.utils.task import run_in_thread @@ -142,7 +143,7 @@ class Scheduler: if instance.state != ModelInstanceStateEnum.ANALYZING: instance.state = ModelInstanceStateEnum.ANALYZING instance.state_message = "Evaluating resource requirements" - await instance.update(session) + await ModelInstanceService(session).update(instance) if model.source == SourceEnum.LOCAL_PATH and not os.path.exists( model.local_path @@ -166,14 +167,14 @@ class Scheduler: should_update_model = await evaluate_pretrained_config(model) if should_update_model: - await model.update(session) + await ModelService(session).update(model) await self._queue.put(instance) except Exception as e: try: instance.state = ModelInstanceStateEnum.ERROR instance.state_message = str(e) - await instance.update(session) + await ModelInstanceService(session).update(instance) except Exception as ue: logger.error( f"Failed to update model instance: {ue}. Original error: {e}" @@ -193,7 +194,7 @@ class Scheduler: instance.state_message = ( "The model is not distributable to multiple workers." ) - await instance.update(session) + await ModelInstanceService(session).update(instance) return True return False @@ -282,7 +283,7 @@ class Scheduler: if state_message != "": model_instance.state_message = state_message - await model_instance.update(session, model_instance) + await ModelInstanceService(session).update(model_instance) logger.debug( f"No suitable workers for model instance {model_instance.name}, state: {model_instance.state}" ) @@ -302,7 +303,7 @@ class Scheduler: ray_actors=candidate.ray_actors, ) - await model_instance.update(session, model_instance) + await ModelInstanceService(session).update(model_instance) logger.debug( f"Scheduled model instance {model_instance.name} to worker " diff --git a/gpustack/server/controllers.py b/gpustack/server/controllers.py index ecf2920..d4b600f 100644 --- a/gpustack/server/controllers.py +++ b/gpustack/server/controllers.py @@ -27,6 +27,7 @@ from gpustack.schemas.models import ( from gpustack.schemas.workers import Worker, WorkerStateEnum from gpustack.server.bus import Event, EventType from gpustack.server.db import get_engine +from gpustack.server.services import ModelInstanceService, ModelService logger = logging.getLogger(__name__) @@ -131,7 +132,7 @@ async def set_default_worker_selector(session: AsyncSession, model: Model): ): # vLLM models are only supported on Linux model.worker_selector = {"os": "linux"} - await model.update(session) + await ModelService(session).update(model) async def sync_replicas(session: AsyncSession, model: Model, cfg: Config): @@ -162,7 +163,7 @@ async def sync_replicas(session: AsyncSession, model: Model, cfg: Config): state=ModelInstanceStateEnum.PENDING, ) - await ModelInstance.create(session, instance) + await ModelInstanceService(session).create(instance) logger.debug(f"Created model instance for model {model.name}") elif len(instances) > model.replicas: @@ -172,7 +173,7 @@ async def sync_replicas(session: AsyncSession, model: Model, cfg: Config): if scale_down_count > 0: for candidate in candidates[:scale_down_count]: instance = candidate.model_instance - await instance.delete(session) + await 
ModelInstanceService(session).delete(instance) logger.debug(f"Deleted model instance {instance.name}") @@ -346,7 +347,7 @@ async def sync_ready_replicas(session: AsyncSession, model: Model): if model.ready_replicas != ready_replicas: model.ready_replicas = ready_replicas - await model.update(session) + await ModelService(session).update(model) async def get_meta_from_running_instance(mi: ModelInstance) -> Dict[str, Any]: @@ -434,7 +435,7 @@ class WorkerController: ): instance_names = [instance.name for instance in instances] for instance in instances: - await instance.delete(session) + await ModelInstanceService(session).delete(instance) if instance_names: state = ( @@ -474,7 +475,7 @@ class WorkerController: instance.state = new_state instance.state_message = new_state_message - await instance.update(session) + await ModelInstanceService(session).update(instance) if instance_names: logger.debug( f"Marked instance {', '.join(instance_names)} {new_state} " @@ -544,6 +545,7 @@ async def sync_main_model_file_state( f"progress: {file.download_progress}, message: {file.state_message}, instance state: {instance.state}" ) + need_update = False if ( file.state == ModelFileStateEnum.DOWNLOADING and instance.state == ModelInstanceStateEnum.INITIALIZING @@ -551,7 +553,8 @@ async def sync_main_model_file_state( # Download started instance.state = ModelInstanceStateEnum.DOWNLOADING instance.download_progress = 0 - await instance.update(session) + instance.state_message = "" + need_update = True elif ( file.state == ModelFileStateEnum.DOWNLOADING and instance.state == ModelInstanceStateEnum.DOWNLOADING @@ -559,7 +562,7 @@ async def sync_main_model_file_state( ): # Update the download progress instance.download_progress = file.download_progress - await instance.update(session) + need_update = True elif file.state == ModelFileStateEnum.READY and ( instance.state == ModelInstanceStateEnum.DOWNLOADING @@ -571,12 +574,15 @@ async def sync_main_model_file_state( if model_instance_download_completed(instance): # All files are downloaded instance.state = ModelInstanceStateEnum.STARTING - await instance.update(session) + need_update = True elif file.state == ModelFileStateEnum.ERROR: # Download failed instance.state = ModelInstanceStateEnum.ERROR instance.state_message = file.state_message - await instance.update(session) + need_update = True + + if need_update: + await ModelInstanceService(session).update(instance) async def sync_distributed_model_file_state( @@ -623,7 +629,7 @@ async def sync_distributed_model_file_state( if need_update: instance.distributed_servers.ray_actors = ray_actors flag_modified(instance, "distributed_servers") - await instance.update(session) + await ModelInstanceService(session).update(instance) def model_instance_download_completed(instance: ModelInstance): diff --git a/gpustack/server/db.py b/gpustack/server/db.py index 1b2a253..256d797 100644 --- a/gpustack/server/db.py +++ b/gpustack/server/db.py @@ -1,3 +1,5 @@ +from contextlib import asynccontextmanager +import os import re from sqlalchemy.ext.asyncio import ( AsyncEngine, @@ -22,6 +24,11 @@ from gpustack.schemas.stmt import ( _engine = None +DB_ECHO = os.getenv("GPUSTACK_DB_ECHO", "false").lower() == "true" +DB_POOL_SIZE = int(os.getenv("GPUSTACK_DB_POOL_SIZE", 5)) +DB_MAX_OVERFLOW = int(os.getenv("GPUSTACK_DB_MAX_OVERFLOW", 10)) +DB_POOL_TIMEOUT = int(os.getenv("GPUSTACK_DB_POOL_TIMEOUT", 30)) + def get_engine(): return _engine @@ -32,6 +39,12 @@ async def get_session(): yield session +@asynccontextmanager +async def 
get_session_context(): + async with AsyncSession(_engine) as session: + yield session + + async def init_db(db_url: str): global _engine, _session_maker if _engine is None: @@ -45,7 +58,14 @@ async def init_db(db_url: str): else: raise Exception(f"Unsupported database URL: {db_url}") - _engine = create_async_engine(db_url, echo=False, connect_args=connect_args) + _engine = create_async_engine( + db_url, + echo=DB_ECHO, + pool_size=DB_POOL_SIZE, + max_overflow=DB_MAX_OVERFLOW, + pool_timeout=DB_POOL_TIMEOUT, + connect_args=connect_args, + ) listen_events(_engine) await create_db_and_tables(_engine) diff --git a/gpustack/server/server.py b/gpustack/server/server.py index 48a11b5..f979d52 100644 --- a/gpustack/server/server.py +++ b/gpustack/server/server.py @@ -23,6 +23,7 @@ from gpustack.scheduler.scheduler import Scheduler from gpustack.ray.manager import RayManager from gpustack.server.system_load import SystemLoadCollector from gpustack.server.update_check import UpdateChecker +from gpustack.server.usage_buffer import flush_usage_to_db from gpustack.server.worker_syncer import WorkerSyncer from gpustack.utils.process import add_signal_handlers_in_loop @@ -65,6 +66,7 @@ class Server: self._start_system_load_collector() self._start_worker_syncer() self._start_update_checker() + self._start_model_usage_flusher() self._start_ray() port = 80 @@ -170,6 +172,11 @@ class Server: logger.debug("Worker syncer started.") + def _start_model_usage_flusher(self): + self._create_async_task(flush_usage_to_db()) + + logger.debug("Model usage flusher started.") + def _start_update_checker(self): if self._config.disable_update_check: return diff --git a/gpustack/server/services.py b/gpustack/server/services.py new file mode 100644 index 0000000..e98fec2 --- /dev/null +++ b/gpustack/server/services.py @@ -0,0 +1,218 @@ +import asyncio +import logging +from typing import Any, Callable, Dict, List, Optional, Union +from aiocache import Cache, cached +from sqlmodel import SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession + +from gpustack.schemas.api_keys import ApiKey +from gpustack.schemas.model_usage import ModelUsage +from gpustack.schemas.models import Model, ModelInstance, ModelInstanceStateEnum +from gpustack.schemas.users import User +from gpustack.schemas.workers import Worker +from gpustack.server.usage_buffer import usage_flush_buffer + +logger = logging.getLogger(__name__) +cache = Cache(Cache.MEMORY) + + +def build_cache_key(func: Callable, *args, **kwargs): + if kwargs is None: + kwargs = {} + ordered_kwargs = sorted(kwargs.items()) + return func.__qualname__ + str(args) + str(ordered_kwargs) + + +async def delete_cache_by_key(func, *args, **kwargs): + key = build_cache_key(func, *args, **kwargs) + logger.trace(f"Deleting cache for key: {key}") + await cache.delete(key) + + +async def set_cache_by_key(key: str, value: Any): + logger.trace(f"Set cache for key: {key}") + await cache.set(key, value) + + +_cache_locks: Dict[str, asyncio.Lock] = {} + + +class locked_cached(cached): + async def decorator(self, f, *args, **kwargs): + # no self arg + key = build_cache_key(f, *args[1:], **kwargs) + value = await self.get_from_cache(key) + if value is not None: + return value + + lock = _cache_locks.setdefault(key, asyncio.Lock()) + + async with lock: + value = await self.get_from_cache(key) + if value is not None: + return value + + logger.trace(f"cache miss for key: {key}") + result = await f(*args, **kwargs) + + await self.set_in_cache(key, result) + + return result + + +class UserService: + 
+ def __init__(self, session: AsyncSession): + self.session = session + + @locked_cached(ttl=60) + async def get_by_id(self, user_id: int) -> Optional[User]: + return await User.one_by_id(self.session, user_id) + + @locked_cached(ttl=60) + async def get_by_username(self, username: str) -> Optional[User]: + return await User.one_by_field(self.session, "username", username) + + async def create(self, user: User): + return await User.create(self.session, user) + + async def update(self, user: User, source: Union[dict, SQLModel, None] = None): + result = await user.update(self.session, source) + await delete_cache_by_key(self.get_by_id, user.id) + await delete_cache_by_key(self.get_by_username, user.username) + return result + + async def delete(self, user: User): + result = await user.delete(self.session) + await delete_cache_by_key(self.get_by_id, user.id) + await delete_cache_by_key(self.get_by_username, user.username) + return result + + +class APIKeyService: + def __init__(self, session: AsyncSession): + self.session = session + + @locked_cached(ttl=60) + async def get_by_access_key(self, access_key: str) -> Optional[ApiKey]: + return await ApiKey.one_by_field(self.session, "access_key", access_key) + + async def delete(self, api_key: ApiKey): + result = await api_key.delete(self.session) + await delete_cache_by_key(self.get_by_access_key, api_key.access_key) + return result + + +class WorkerService: + def __init__(self, session: AsyncSession): + self.session = session + + @locked_cached(ttl=60) + async def get_by_id(self, worker_id: int) -> Optional[Worker]: + return await Worker.one_by_id(self.session, worker_id) + + @locked_cached(ttl=60) + async def get_by_name(self, name: str) -> Optional[Worker]: + return await Worker.one_by_field(self.session, "name", name) + + async def update(self, worker: Worker, source: Union[dict, SQLModel, None] = None): + result = await worker.update(self.session, source) + await delete_cache_by_key(self.get_by_id, worker.id) + await delete_cache_by_key(self.get_by_name, worker.name) + return result + + async def delete(self, worker: Worker): + result = await worker.delete(self.session) + await delete_cache_by_key(self.get_by_id, worker.id) + await delete_cache_by_key(self.get_by_name, worker.name) + return result + + +class ModelService: + def __init__(self, session: AsyncSession): + self.session = session + + @locked_cached(ttl=60) + async def get_by_id(self, model_id: int) -> Optional[Model]: + return await Model.one_by_id(self.session, model_id) + + @locked_cached(ttl=60) + async def get_by_name(self, name: str) -> Optional[Model]: + return await Model.one_by_field(self.session, "name", name) + + async def update(self, model: Model, source: Union[dict, SQLModel, None] = None): + result = await model.update(self.session, source) + await delete_cache_by_key(self.get_by_id, model.id) + await delete_cache_by_key(self.get_by_name, model.name) + return result + + async def delete(self, model: Model): + result = await model.delete(self.session) + await delete_cache_by_key(self.get_by_id, model.id) + await delete_cache_by_key(self.get_by_name, model.name) + return result + + +class ModelInstanceService: + def __init__(self, session: AsyncSession): + self.session = session + + @locked_cached(ttl=60) + async def get_running_instances(self, model_id: int) -> List[ModelInstance]: + return await ModelInstance.all_by_fields( + self.session, + fields={"model_id": model_id, "state": ModelInstanceStateEnum.RUNNING}, + ) + + async def create(self, model_instance): + 
result = await ModelInstance.create(self.session, model_instance) + await delete_cache_by_key(self.get_running_instances, model_instance.model_id) + return result + + async def update( + self, model_instance: ModelInstance, source: Union[dict, SQLModel, None] = None + ): + result = await model_instance.update(self.session, source) + await delete_cache_by_key(self.get_running_instances, model_instance.model_id) + return result + + async def delete(self, model_instance: ModelInstance): + result = await model_instance.delete(self.session) + await delete_cache_by_key(self.get_running_instances, model_instance.model_id) + return result + + +class ModelUsageService: + def __init__(self, session: AsyncSession): + self.session = session + + @locked_cached(ttl=60) + async def get_by_fields(self, fields: dict) -> ModelUsage: + return await ModelUsage.one_by_fields( + self.session, + fields=fields, + ) + + async def create(self, model_usage: ModelUsage): + return await ModelUsage.create(self.session, model_usage) + + async def update( + self, + model_usage: ModelUsage, + completion_token_count: int, + prompt_token_count: int, + ): + model_usage.completion_token_count += completion_token_count + model_usage.prompt_token_count += prompt_token_count + model_usage.request_count += 1 + + key = build_cache_key( + self.get_by_fields, + model_usage.user_id, + model_usage.model_id, + model_usage.operation, + model_usage.date, + ) + await set_cache_by_key(key, model_usage) + usage_flush_buffer[key] = model_usage + return model_usage diff --git a/gpustack/server/usage_buffer.py b/gpustack/server/usage_buffer.py new file mode 100644 index 0000000..624b1b5 --- /dev/null +++ b/gpustack/server/usage_buffer.py @@ -0,0 +1,39 @@ +import asyncio +import logging +from typing import Dict +from sqlmodel.ext.asyncio.session import AsyncSession + +from gpustack.schemas.model_usage import ModelUsage +from gpustack.server.db import get_engine + +logger = logging.getLogger(__name__) + +usage_flush_buffer: Dict[str, ModelUsage] = {} + + +async def flush_usage_to_db(): + """ + Flush model usage records to the database periodically. 
+ """ + while True: + await asyncio.sleep(5) + + if not usage_flush_buffer: + continue + + local_buffer = dict(usage_flush_buffer) + usage_flush_buffer.clear() + + try: + async with AsyncSession(get_engine()) as session: + for key, usage in local_buffer.items(): + to_update = await ModelUsage.one_by_id(session, usage.id) + to_update.prompt_token_count = usage.prompt_token_count + to_update.completion_token_count = usage.completion_token_count + to_update.request_count = usage.request_count + session.add(to_update) + + await session.commit() + logger.debug(f"Flushed {len(local_buffer)} usage records to DB") + except Exception as e: + logger.error(f"Error flushing usage to DB: {e}") diff --git a/gpustack/server/worker_syncer.py b/gpustack/server/worker_syncer.py index b4f99c9..d6c4fed 100644 --- a/gpustack/server/worker_syncer.py +++ b/gpustack/server/worker_syncer.py @@ -3,6 +3,7 @@ import logging from sqlmodel.ext.asyncio.session import AsyncSession from gpustack.schemas.workers import Worker, WorkerStateEnum from gpustack.server.db import get_engine +from gpustack.server.services import WorkerService from gpustack.utils.network import is_url_reachable logger = logging.getLogger(__name__) @@ -55,7 +56,7 @@ class WorkerSyncer: state_to_worker_name[worker.state].append(worker.name) for worker in should_update_workers: - await worker.update(session, worker) + await WorkerService(session).update(worker) for state, worker_names in state_to_worker_name.items(): if worker_names: diff --git a/poetry.lock b/poetry.lock index 0945fa4..9ca794c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -53,6 +53,22 @@ files = [ {file = "addict-2.4.0.tar.gz", hash = "sha256:b3b2210e0e067a281f5646c8c5db92e99b7231ea8b0eb5f74dbdf9e259d4e494"}, ] +[[package]] +name = "aiocache" +version = "0.12.3" +description = "multi backend asyncio cache" +optional = false +python-versions = "*" +files = [ + {file = "aiocache-0.12.3-py2.py3-none-any.whl", hash = "sha256:889086fc24710f431937b87ad3720a289f7fc31c4fd8b68e9f918b9bacd8270d"}, + {file = "aiocache-0.12.3.tar.gz", hash = "sha256:f528b27bf4d436b497a1d0d1a8f59a542c153ab1e37c3621713cb376d44c4713"}, +] + +[package.extras] +memcached = ["aiomcache (>=0.5.2)"] +msgpack = ["msgpack (>=0.5.5)"] +redis = ["redis (>=4.2.0)"] + [[package]] name = "aiofiles" version = "23.2.1" @@ -8867,4 +8883,4 @@ vllm = ["bitsandbytes", "mistral_common", "vllm"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "446f471bce8d67056e9e79b21b04b56e5446804693de22fe9ebdbb5029144436" +content-hash = "c07930551e489277eb3327bb398dac6d0106ed7ed5180f04760a7d2de950eac5" diff --git a/pyproject.toml b/pyproject.toml index 1c53f71..61b574d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ psycopg2-binary = "^2.9.10" vox-box = {version = "0.0.11", optional = true} tenacity = "^9.0.0" +aiocache = "^0.12.3" aiofiles = "^23.2.1" aiohttp = "^3.11.2" bitsandbytes = {version = "^0.45.2", optional = true}
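Usage note: the new benchmarks/benchmark_llm.py is an argparse CLI, but it also exposes main() and output_results(), so a run can be driven programmatically as in the minimal sketch below. The model name, the result file path, and the assumption that the script is importable from the working directory are illustrative values, not part of the patch; the numeric arguments mirror the CLI defaults defined above.

    # A sketch, assuming benchmarks/benchmark_llm.py is on the import path.
    import asyncio

    from benchmark_llm import main, output_results

    results = asyncio.run(
        main(
            model="my-model",                 # example value: any model name served by the GPUStack instance
            num_requests=100,                 # these values mirror the CLI defaults (-n, -c, etc.)
            concurrency=10,
            request_timeout=300,
            max_completion_tokens=128,
            server_url="http://127.0.0.1",
            api_key="fake",
        )
    )
    if results:  # main() returns None when the preflight check fails
        output_results(results, result_file="benchmark.json")  # omit result_file to print JSON to stdout

    # Roughly equivalent CLI invocation (example values):
    #   python benchmarks/benchmark_llm.py -m my-model --result-file benchmark.json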