@@ -1,6 +1,14 @@
import asyncio
import logging
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, Optional, AsyncGenerator, Callable, Any

from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlmodel.ext.asyncio.session import AsyncSession

from gpustack.schemas import ModelInstancePublic
from gpustack.schemas.models import ModelInstance
from pydantic import ConfigDict, BaseModel
from sqlmodel import Field, SQLModel, JSON, Column, Text

@@ -9,7 +17,10 @@ from gpustack.schemas.common import PaginatedList, UTCDateTime, pydantic_column_
from typing import List

from sqlalchemy.orm import declarative_base

from gpustack.server.bus import EventType, Event

Base = declarative_base()

logger = logging.getLogger(__name__)


class UtilizationInfo(BaseModel):

@@ -209,6 +220,129 @@ class Worker(WorkerBase, BaseModelMixin, table=True):
    __tablename__ = 'workers'
    id: Optional[int] = Field(default=None, primary_key=True)

    @classmethod
    async def _enrich_worker_with_model_instances(
        cls, event: Event, engine: AsyncEngine
    ) -> None:
        """
        Query model_instances and inject them into the worker object.
        """
        worker = event.data
        worker_id = getattr(worker, "id", None)
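
        # The event payload may be an ORM object or a plain dict: attribute
        # access is tried first, with a dict lookup as the fallback.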
        if not worker_id:
            if isinstance(worker, dict):
                worker_id = worker.get("id")
            if not worker_id:
                return

        async with AsyncSession(engine) as session:
            try:
                all_model_instances = await ModelInstance.paginated_by_query(
                    session=session
                )
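
                # Pair each instance with synthetic copies for any distributed
                # subordinate workers it spans.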
                instances_with_subordinate = []
                for model_instance in all_model_instances.items:
                    main_instance = ModelInstancePublic.model_validate(model_instance)
                    instance_with_subordinate = {
                        "main_instance": main_instance,
                        "subordinate_instances": [],
                    }
                    if (
                        model_instance.distributed_servers
                        and model_instance.distributed_servers.subordinate_workers
                    ):
                        for (
                            subordinate_worker
                        ) in model_instance.distributed_servers.subordinate_workers:
                            subordinate_instance = ModelInstancePublic.model_validate(
                                model_instance
                            )
                            subordinate_instance.name = (
                                model_instance.name + "(subordinate)"
                            )
                            subordinate_instance.computed_resource_claim = (
                                subordinate_worker.computed_resource_claim
                            )
                            subordinate_instance.worker_name = (
                                subordinate_worker.worker_name
                            )
                            subordinate_instance.worker_id = (
                                subordinate_worker.worker_id
                            )
                            subordinate_instance.worker_ip = (
                                subordinate_worker.worker_ip
                            )
                            subordinate_instance.gpu_indexes = (
                                subordinate_worker.gpu_indexes
                            )
                            subordinate_instance.state = subordinate_worker.state
                            instance_with_subordinate["subordinate_instances"].append(
                                subordinate_instance
                            )
                    instances_with_subordinate.append(instance_with_subordinate)
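
                # Keep only the instances whose main copy or subordinate copy
                # is scheduled on this worker.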
                model_instances_public = []
                for instance_with_subordinate in instances_with_subordinate:
                    if (
                        instance_with_subordinate.get("main_instance").worker_id
                        == worker_id
                    ):
                        model_instances_public.append(
                            instance_with_subordinate.get("main_instance")
                        )

                    for subordinate_instance in instance_with_subordinate.get(
                        "subordinate_instances"
                    ):
                        if subordinate_instance.worker_id == worker_id:
                            model_instances_public.append(subordinate_instance)
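
                # Swap the event payload for an enriched public view carrying
                # the matched instances.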
                public_worker = WorkerPublicWithInstances.model_validate(worker)
                public_worker.model_instances = model_instances_public
                event.data = public_worker

            except Exception as e:
                logger.warning(
                    f"Failed to inject model_instances for worker {worker_id}: {e}"
                )

    @classmethod
    async def streaming(  # noqa: C901
        cls,
        engine: AsyncEngine,
        fields: Optional[dict] = None,
        fuzzy_fields: Optional[dict] = None,
        filter_func: Optional[Callable[[Any], bool]] = None,
    ) -> AsyncGenerator[str, None]:
        """
        On top of the original streaming, inject model_instances into every
        Worker event.
        """
        try:
            # Reuse the existing subscribe mechanism so the topic, filtering,
            # and transformation logic stay unchanged.
            async for event in cls.subscribe(engine):
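                # Heartbeat events are passed through as blank frames so the
                # client connection stays alive.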
                if event.type == EventType.HEARTBEAT:
                    yield "\n\n"
                    continue

                if not cls._match_fields(event, fields):
                    continue

                if not cls._match_fuzzy_fields(event, fuzzy_fields):
                    continue

                if filter_func and not filter_func(event.data):
                    continue
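
                # Inject this worker's model instances before serializing.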
                await cls._enrich_worker_with_model_instances(event, engine)
                # Format the event for output
                yield cls._format_event(event)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Error in streaming {cls.__name__}: {e}")

    def __hash__(self):
        return hash(self.id)

@@ -234,4 +368,10 @@ class WorkerPublic(
    updated_at: datetime


class WorkerPublicWithInstances(WorkerPublic):
    model_instances: List[ModelInstancePublic] = Field(default_factory=list)


WorkersPublicWithInstances = PaginatedList[WorkerPublicWithInstances]

WorkersPublic = PaginatedList[WorkerPublic]
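
# Usage sketch (illustrative, not part of this change): `streaming` is an
# async generator of formatted event strings, so a FastAPI route could expose
# it as server-sent events. `app`, the route path, and `get_engine` are
# assumed names for illustration, not symbols defined in this module.
#
#   from fastapi import FastAPI
#   from fastapi.responses import StreamingResponse
#
#   app = FastAPI()
#
#   @app.get("/v1/workers/stream")
#   async def stream_workers():
#       engine = get_engine()  # assumed helper returning an AsyncEngine
#       return StreamingResponse(
#           Worker.streaming(engine, fields={"state": "ready"}),
#           media_type="text/event-stream",
#       )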