You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
gpustack/tests/server/test_catalog.py

101 lines
3.7 KiB

import os
import time
import pytest
from tenacity import retry, stop_after_attempt, wait_fixed
from gpustack.scheduler.scheduler import SourceEnum
from gpustack.server.catalog import get_model_set_specs, init_model_catalog
from gpustack.utils.hub import match_hugging_face_files, match_model_scope_file_paths
from gpustack.utils.compat_importlib import pkg_resources
from huggingface_hub import HfApi
from modelscope.hub.api import HubApi
@pytest.mark.skipif(
    os.getenv("HF_TOKEN") is None,
    reason="Skipped by default unless HF_TOKEN is set. Unauthed requests are rate limited.",
)
def test_model_catalog():
    """Verify that every Hugging Face entry in the default model catalog resolves.

    The default catalog is loaded via ``init_model_catalog()``. Every model
    set must have a non-None id and at least one spec. For each spec whose
    source is ``SourceEnum.HUGGING_FACE``:

    * with no ``huggingface_filename``, ``HfApi.model_info`` must return a
      non-None result for the repo id;
    * with a ``huggingface_filename``, ``match_hugging_face_files`` must
      return at least one match.

    If the environment variable ``TEST_CATALOG_MODEL_NAME_FILTER`` is set,
    only repo ids containing that substring are checked.
    """
    init_model_catalog()
    catalog = get_model_set_specs()
    hf_api = HfApi()
    name_filter = os.getenv("TEST_CATALOG_MODEL_NAME_FILTER")

    for set_id, specs in catalog.items():
        assert set_id is not None
        assert len(specs) > 0

        for spec in specs:
            if spec.source != SourceEnum.HUGGING_FACE:
                continue

            repo_id = spec.huggingface_repo_id
            if name_filter is not None and name_filter not in repo_id:
                continue

            # Brief pause between Hub requests to mitigate rate limiting.
            time.sleep(0.01)
            filename = spec.huggingface_filename
            print(repo_id, filename)

            if filename is None:
                # No specific file: just confirm the repo itself is reachable.
                assert hf_api.model_info(repo_id) is not None
                continue

            matched = match_hugging_face_files(repo_id, filename)
            assert (
                len(matched) > 0
            ), f"Failed to find model files: {repo_id}, {filename}"
@pytest.mark.skipif(
    os.getenv("HF_TOKEN") is None,
    reason="Skipped by default unless HF_TOKEN is set. Unauthed requests are rate limited.",
)
def test_model_catalog_modelscope():
    """Verify that every ModelScope entry in the ModelScope catalog resolves.

    The catalog is loaded from the packaged asset
    ``gpustack.assets/model-catalog-modelscope.yaml``. Every model set must
    have a non-None id and at least one spec. For each spec whose source is
    ``SourceEnum.MODEL_SCOPE``:

    * with no ``model_scope_file_path``, ``HubApi.get_model`` must return a
      non-None result for the model id;
    * with a ``model_scope_file_path``, the retrying file matcher must return
      at least one match.

    If the environment variable ``TEST_CATALOG_MODEL_NAME_FILTER`` is set,
    only model ids containing that substring are checked.

    NOTE(review): this test is gated on HF_TOKEN even though it only talks to
    ModelScope. Presumably that is deliberate, to keep network tests opt-in;
    confirm.
    """
    catalog_path = pkg_resources.files("gpustack.assets").joinpath(
        "model-catalog-modelscope.yaml"
    )
    init_model_catalog(str(catalog_path))
    catalog = get_model_set_specs()
    ms_api = HubApi()
    name_filter = os.getenv("TEST_CATALOG_MODEL_NAME_FILTER")

    for set_id, specs in catalog.items():
        assert set_id is not None
        assert len(specs) > 0

        for spec in specs:
            if spec.source != SourceEnum.MODEL_SCOPE:
                continue

            model_id = spec.model_scope_model_id
            if name_filter is not None and name_filter not in model_id:
                continue

            file_path = spec.model_scope_file_path
            print(model_id, file_path)

            if file_path is None:
                # No specific file: just confirm the model itself is reachable.
                assert ms_api.get_model(model_id) is not None
                continue

            matched = match_model_scope_file_paths_with_retry(model_id, file_path)
            assert (
                len(matched) > 0
            ), f"Failed to find model files: {model_id}, {file_path}"
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
def match_model_scope_file_paths_with_retry(
    model_scope_model_id, model_scope_file_path
):
    """Call ``match_model_scope_file_paths``, allowing up to three attempts.

    Attempts are spaced 1 second apart. Under tenacity's default retry
    policy, a new attempt is made only when the call raises; an empty result
    is returned as-is for the caller to assert on. This is presumably meant
    to absorb transient ModelScope network/API errors.

    ``reraise=True`` makes the final failure propagate as the original
    exception instead of being wrapped in ``tenacity.RetryError``. That way
    the pytest report shows the real cause.

    Args:
        model_scope_model_id: ModelScope model id to search.
        model_scope_file_path: File path (or pattern) to match within the model.

    Returns:
        Whatever ``match_model_scope_file_paths`` returns. The caller checks
        it with ``len(...) > 0``.
    """
    return match_model_scope_file_paths(model_scope_model_id, model_scope_file_path)