|
|
|
|
@ -1,35 +1,21 @@
|
|
|
|
|
import asyncio
|
|
|
|
|
import logging
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
from enum import Enum
|
|
|
|
|
import hashlib
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import (
|
|
|
|
|
TYPE_CHECKING,
|
|
|
|
|
Annotated,
|
|
|
|
|
Any,
|
|
|
|
|
Dict,
|
|
|
|
|
List,
|
|
|
|
|
Optional,
|
|
|
|
|
Union,
|
|
|
|
|
Callable,
|
|
|
|
|
AsyncGenerator,
|
|
|
|
|
)
|
|
|
|
|
from typing import TYPE_CHECKING, Annotated, Any, Dict, List, Optional, Union
|
|
|
|
|
from pydantic import BaseModel, ConfigDict, model_validator, Field as PydanticField
|
|
|
|
|
from sqlalchemy import JSON, Column
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
|
|
|
from sqlmodel import Field, Relationship, SQLModel, Text
|
|
|
|
|
|
|
|
|
|
from gpustack.schemas.common import PaginatedList, UTCDateTime, pydantic_column_type
|
|
|
|
|
from gpustack.mixins import BaseModelMixin
|
|
|
|
|
from gpustack.schemas.links import ModelInstanceModelFileLink
|
|
|
|
|
from gpustack.server.bus import EventType, Event
|
|
|
|
|
from gpustack.utils.command import find_parameter
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from gpustack.schemas.model_files import ModelFile
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
# Models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -394,82 +380,6 @@ class ModelInstance(ModelInstanceBase, BaseModelMixin, table=True):
|
|
|
|
|
def __hash__(self):
|
|
|
|
|
return self.id
|
|
|
|
|
|
|
|
|
|
@classmethod
async def _enrich_instance_with_gpu_names(
    cls, event: Event, gpu_list: List[Dict]
) -> None:
    """Replace ``event.data`` with a ``ModelInstancePublicWithExtra`` that
    carries the GPU names resolved from ``gpu_list``.

    Best-effort: on any failure the event is left unchanged and a warning
    is logged, so streaming never breaks because of GPU-name lookup.

    :param event: event whose ``data`` is a model-instance-like object.
    :param gpu_list: list of GPU dicts (with ``worker_name``/``index``/``name``).
    """
    instance = event.data
    worker_ip = getattr(instance, "worker_ip", None)
    worker_name = getattr(instance, "worker_name", None)
    gpu_indexes = getattr(instance, "gpu_indexes", [])
    try:
        gpu_names = get_gpu_names_by_gpus_worker_name(
            gpu_list, worker_name, gpu_indexes
        )

        # Collect GPU names from distributed subordinate workers, if any.
        subordinate_gpu_names = []
        if (
            instance.distributed_servers
            and instance.distributed_servers.subordinate_workers
        ):
            for (
                subordinate_worker
            ) in instance.distributed_servers.subordinate_workers:
                subordinate_gpu_names += get_gpu_names_by_gpus_worker_name(
                    gpu_list,
                    subordinate_worker.worker_name,
                    subordinate_worker.gpu_indexes,
                )

        # Merge main-worker and subordinate-worker GPU names.
        # Bug fix: the original evaluated `gpu_names + subordinate_gpu_names`
        # and discarded the result, so subordinate names were never included.
        gpu_names = gpu_names + subordinate_gpu_names

        public_instance = ModelInstancePublicWithExtra.model_validate(instance)
        public_instance.gpu_names = gpu_names
        event.data = public_instance
    except Exception as e:
        logger.warning(
            f"Failed to inject gpu for ModelInstance worker_ip={worker_ip}, gpu_indexes={gpu_indexes}: {e}"
        )
|
|
|
|
|
|
|
|
|
|
@classmethod
async def streaming(
    cls,
    engine: AsyncEngine,
    fields: Optional[dict] = None,
    fuzzy_fields: Optional[dict] = None,
    filter_func: Optional[Callable[[Any], bool]] = None,
    gpu_list: Optional[List[Dict]] = None,
) -> AsyncGenerator[str, None]:
    """Stream serialized instance events, injecting ``gpu_names`` into each one.

    Wraps the base streaming flow: subscribes via the existing mechanism
    (topic, filtering and conversion are unchanged) and, when ``gpu_list``
    is provided, enriches every instance event with its GPU names before
    formatting it for the client.
    """
    try:
        # Reuse the existing subscribe mechanism so topic, filtering and
        # conversion behavior stay identical to the base implementation.
        async for evt in cls.subscribe(engine):
            # Heartbeats are passed through as bare keep-alive frames.
            if evt.type == EventType.HEARTBEAT:
                yield "\n\n"
                continue

            # Drop events that fail any of the configured filters.
            skip = (
                not cls._match_fields(evt, fields)
                or not cls._match_fuzzy_fields(evt, fuzzy_fields)
                or (filter_func is not None and not filter_func(evt.data))
            )
            if skip:
                continue

            if gpu_list:
                await cls._enrich_instance_with_gpu_names(evt, gpu_list)

            # Serialize the event for the stream.
            yield cls._format_event(evt)
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.error(f"Error in streaming {cls.__name__}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelInstanceCreate(ModelInstanceBase):
    """Creation payload for a model instance; adds nothing beyond the base fields."""
|
|
|
|
|
@ -487,12 +397,6 @@ class ModelInstancePublic(
|
|
|
|
|
updated_at: datetime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelInstancePublicWithExtra(ModelInstancePublic):
    """Public instance schema extended with the resolved GPU names.

    Populated by ``_enrich_instance_with_gpu_names`` during event streaming.
    """

    # Names of the GPUs this instance runs on; empty when none are resolved.
    gpu_names: List[str] = Field(default_factory=list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Paginated response type for the GPU-name-enriched public schema.
ModelInstancesPublicWithExtra = PaginatedList[ModelInstancePublicWithExtra]

# Paginated response type for the plain public schema.
ModelInstancesPublic = PaginatedList[ModelInstancePublic]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -594,13 +498,3 @@ def get_mmproj_filename(model: Union[Model, ModelSource]) -> Optional[str]:
|
|
|
|
|
return mmproj
|
|
|
|
|
|
|
|
|
|
return "*mmproj*.gguf"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_gpu_names_by_gpus_worker_name(
    gpus: List[Dict], worker_name: Optional[str], gpu_indexes: Optional[List[int]]
) -> List[str]:
    """Return the names of the GPUs on ``worker_name`` at ``gpu_indexes``.

    :param gpus: GPU dicts carrying ``worker_name``, ``index`` and ``name`` keys.
    :param worker_name: worker to match; ``None`` matches nothing.
    :param gpu_indexes: GPU indexes to match; ``None``/empty matches nothing.
    :return: matching GPU names, in the order they appear in ``gpus``.
    """
    # Guard against None/empty inputs: instances without a GPU assignment
    # carry gpu_indexes=None, which previously raised TypeError on `in None`.
    if not gpus or not worker_name or not gpu_indexes:
        return []
    # Set membership keeps the scan O(len(gpus)) instead of O(n*m).
    wanted = set(gpu_indexes)
    return [
        gpu.get("name")
        for gpu in gpus
        if gpu.get("worker_name") == worker_name and gpu.get("index") in wanted
    ]
|
|
|
|
|
|