|
|
|
|
@ -1,5 +1,8 @@
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
|
from fastapi import APIRouter
|
|
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
|
from gpustack.schemas.models import ModelInstance
|
|
|
|
|
|
|
|
|
|
from gpustack.api.exceptions import (
|
|
|
|
|
AlreadyExistsException,
|
|
|
|
|
@ -11,15 +14,16 @@ from gpustack.schemas.workers import (
|
|
|
|
|
WorkerCreate,
|
|
|
|
|
WorkerPublic,
|
|
|
|
|
WorkerUpdate,
|
|
|
|
|
WorkersPublic,
|
|
|
|
|
Worker,
|
|
|
|
|
WorkerWithInstancesPublic,
|
|
|
|
|
WorkerWithInstances,
|
|
|
|
|
)
|
|
|
|
|
from gpustack.server.services import WorkerService
|
|
|
|
|
|
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("", response_model=WorkersPublic)
|
|
|
|
|
@router.get("", response_model=WorkerWithInstancesPublic)
|
|
|
|
|
async def get_workers(
|
|
|
|
|
engine: EngineDep,
|
|
|
|
|
session: SessionDep,
|
|
|
|
|
@ -44,7 +48,7 @@ async def get_workers(
|
|
|
|
|
media_type="text/event-stream",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return await Worker.paginated_by_query(
|
|
|
|
|
worker_list = await Worker.paginated_by_query(
|
|
|
|
|
session=session,
|
|
|
|
|
fields=fields,
|
|
|
|
|
fuzzy_fields=fuzzy_fields,
|
|
|
|
|
@ -52,6 +56,18 @@ async def get_workers(
|
|
|
|
|
per_page=params.perPage,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
enriched_workers = []
|
|
|
|
|
for worker in worker_list.items:
|
|
|
|
|
model_instances = await get_model_instances_by_worker(session, worker.id)
|
|
|
|
|
enriched_worker = WorkerWithInstances(
|
|
|
|
|
**worker.model_dump(), model_instances=model_instances
|
|
|
|
|
)
|
|
|
|
|
enriched_workers.append(enriched_worker)
|
|
|
|
|
|
|
|
|
|
return WorkerWithInstancesPublic(
|
|
|
|
|
items=enriched_workers, pagination=worker_list.pagination
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/{id}", response_model=WorkerPublic)
|
|
|
|
|
async def get_worker(session: SessionDep, id: int):
|
|
|
|
|
@ -101,3 +117,14 @@ async def delete_worker(session: SessionDep, id: int):
|
|
|
|
|
await WorkerService(session).delete(worker)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise InternalServerErrorException(message=f"Failed to delete worker: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_model_instances_by_worker(session: SessionDep, worker_id: int):
|
|
|
|
|
fields = {
|
|
|
|
|
"worker_id": worker_id,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
model_instances: List[ModelInstance] = await ModelInstance.all_by_fields(
|
|
|
|
|
session, fields=fields
|
|
|
|
|
)
|
|
|
|
|
return model_instances
|
|
|
|
|
|