|
|
|
|
@ -1,14 +1,6 @@
|
|
|
|
|
import asyncio
|
|
|
|
|
import logging
|
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from typing import Dict, Optional, AsyncGenerator, Callable, Any
|
|
|
|
|
|
|
|
|
|
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
|
|
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
|
|
|
|
|
|
from gpustack.schemas import ModelInstancePublic
|
|
|
|
|
from gpustack.schemas.models import ModelInstance
|
|
|
|
|
from typing import Dict, Optional
|
|
|
|
|
from pydantic import ConfigDict, BaseModel
|
|
|
|
|
from sqlmodel import Field, SQLModel, JSON, Column, Text
|
|
|
|
|
|
|
|
|
|
@ -17,10 +9,7 @@ from gpustack.schemas.common import PaginatedList, UTCDateTime, pydantic_column_
|
|
|
|
|
from typing import List
|
|
|
|
|
from sqlalchemy.orm import declarative_base
|
|
|
|
|
|
|
|
|
|
from gpustack.server.bus import EventType, Event
|
|
|
|
|
|
|
|
|
|
Base = declarative_base()
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UtilizationInfo(BaseModel):
|
|
|
|
|
@ -220,79 +209,6 @@ class Worker(WorkerBase, BaseModelMixin, table=True):
|
|
|
|
|
__tablename__ = 'workers'
|
|
|
|
|
id: Optional[int] = Field(default=None, primary_key=True)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
async def _enrich_worker_with_model_instances(
|
|
|
|
|
cls, event: Event, engine: AsyncEngine
|
|
|
|
|
) -> None:
|
|
|
|
|
"""
|
|
|
|
|
查询 model_instances 并注入到 worker 对象中。
|
|
|
|
|
"""
|
|
|
|
|
worker = event.data
|
|
|
|
|
worker_id = getattr(worker, "id", None)
|
|
|
|
|
|
|
|
|
|
if not worker_id:
|
|
|
|
|
if isinstance(worker, dict):
|
|
|
|
|
worker_id = worker.get("id")
|
|
|
|
|
if not worker_id:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
async with AsyncSession(engine) as session:
|
|
|
|
|
try:
|
|
|
|
|
model_instances = await ModelInstance.all_by_fields(
|
|
|
|
|
session, fields={"worker_id": worker_id}
|
|
|
|
|
)
|
|
|
|
|
model_instances_public = []
|
|
|
|
|
for model_instance in model_instances:
|
|
|
|
|
model_instance_public = ModelInstancePublic.model_validate(
|
|
|
|
|
model_instance
|
|
|
|
|
)
|
|
|
|
|
model_instances_public.append(model_instance_public)
|
|
|
|
|
|
|
|
|
|
public_worker = WorkerPublicWithInstances.model_validate(worker)
|
|
|
|
|
public_worker.model_instances = model_instances_public
|
|
|
|
|
event.data = public_worker
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"Failed to inject model_instances for worker {worker_id}: {e}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@classmethod # noqa: C901
|
|
|
|
|
async def streaming(
|
|
|
|
|
cls,
|
|
|
|
|
engine: AsyncEngine,
|
|
|
|
|
fields: Optional[dict] = None,
|
|
|
|
|
fuzzy_fields: Optional[dict] = None,
|
|
|
|
|
filter_func: Optional[Callable[[Any], bool]] = None,
|
|
|
|
|
) -> AsyncGenerator[str, None]:
|
|
|
|
|
"""
|
|
|
|
|
在原有 streaming 基础上,为每个 Worker 事件注入 model_instances。
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
# 使用原有 subscribe 机制,保证 topic、过滤、转换逻辑不变
|
|
|
|
|
async for event in cls.subscribe(engine):
|
|
|
|
|
if event.type == EventType.HEARTBEAT:
|
|
|
|
|
yield "\n\n"
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if not cls._match_fields(event, fields):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if not cls._match_fuzzy_fields(event, fuzzy_fields):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if filter_func and not filter_func(event.data):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
await cls._enrich_worker_with_model_instances(event, engine)
|
|
|
|
|
# 格式化输出
|
|
|
|
|
yield cls._format_event(event)
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
|
|
pass
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error in streaming {cls.__name__}: {e}")
|
|
|
|
|
|
|
|
|
|
def __hash__(self):
|
|
|
|
|
return hash(self.id)
|
|
|
|
|
|
|
|
|
|
@ -318,10 +234,4 @@ class WorkerPublic(
|
|
|
|
|
updated_at: datetime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WorkerPublicWithInstances(WorkerPublic):
|
|
|
|
|
model_instances: List[ModelInstancePublic] = Field(default_factory=list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
WorkersPublicWithInstances = PaginatedList[WorkerPublicWithInstances]
|
|
|
|
|
|
|
|
|
|
WorkersPublic = PaginatedList[WorkerPublic]
|
|
|
|
|
|