From 281d8bddc198bf1f9881e2ff2cccd2c6ef99c0a6 Mon Sep 17 00:00:00 2001 From: jialin Date: Tue, 10 Jun 2025 19:23:16 +0800 Subject: [PATCH] feat: add pagination for searching models --- .env | 1 + src/components/echarts/mix-line-bar.tsx | 2 +- src/pages/llmodels/apis/index.ts | 8 +- .../llmodels/components/deploy-modal.tsx | 25 ++- .../llmodels/components/hf-model-file.tsx | 8 +- .../llmodels/components/search-model.tsx | 200 +++++++++++++----- src/pages/llmodels/hooks/index.ts | 3 +- 7 files changed, 183 insertions(+), 64 deletions(-) diff --git a/.env b/.env index a5c95d33..de8c0657 100644 --- a/.env +++ b/.env @@ -1,3 +1,4 @@ PORT=9000 UMI_DEV_SERVER_COMPRESS=none +DID_YOU_KNOW=none diff --git a/src/components/echarts/mix-line-bar.tsx b/src/components/echarts/mix-line-bar.tsx index 4d776bec..afb2752d 100644 --- a/src/components/echarts/mix-line-bar.tsx +++ b/src/components/echarts/mix-line-bar.tsx @@ -46,7 +46,7 @@ const MixLineBarChart: React.FC< grid: { ...grid, top: 20, - bottom: 20 + bottom: 10 }, tooltip: { ...tooltip, diff --git a/src/pages/llmodels/apis/index.ts b/src/pages/llmodels/apis/index.ts index b254553f..3d9fc1d8 100644 --- a/src/pages/llmodels/apis/index.ts +++ b/src/pages/llmodels/apis/index.ts @@ -191,11 +191,11 @@ export async function queryModelScopeModels( 'Content-Type': 'application/json' }, body: JSON.stringify({ + PageSize: 10, + PageNumber: 1, ...params, ...Criterion, - Name: `${params.Name}`, - PageSize: 10, - PageNumber: 1 + Name: `${params.Name}` }) }); if (!res.ok) { @@ -256,7 +256,7 @@ export async function queryHuggingfaceModels( for await (const model of listModels({ ...params, ...options, - limit: 10, + limit: 500, additionalFields: ['sha', 'tags'], fetch(_url: string, config: any) { const url = params.search.sort diff --git a/src/pages/llmodels/components/deploy-modal.tsx b/src/pages/llmodels/components/deploy-modal.tsx index 6ffe8012..99e78f0d 100644 --- a/src/pages/llmodels/components/deploy-modal.tsx +++ b/src/pages/llmodels/components/deploy-modal.tsx @@ -129,6 +129,15 @@ const AddModal: FC = (props) => { const evaluateStateRef = useRef<{ state: EvaluateProccessType }>({ state: 'form' }); + const requestModelIdRef = useRef(0); + + /** + * Update the request model id to distinguish + * the evaluate request. + */ + const updateRequestModelId = () => { + requestModelIdRef.current += 1; + }; /** * @@ -194,12 +203,19 @@ const AddModal: FC = (props) => { }); if (item.fakeName) { + const currentModelId = requestModelIdRef.current; setEvaluteState(EvaluateProccess.file); const evaluateRes = await handleEvaluateOnChange?.({ changedValues: {}, allValues: form.current?.form?.getFieldsValue?.(), source: props.source }); + + if (currentModelId !== requestModelIdRef.current) { + // if the request model id has changed, do not update the form + return; + } + const defaultSpec = getDefaultSpec({ evaluateResult: evaluateRes }); @@ -224,7 +240,11 @@ const AddModal: FC = (props) => { }; const handleOnSelectModel = (item: any, evaluate?: boolean) => { - // when select a model not from the evaluate result, + /** + * evaluate: false means select a new model + * evaluate: true means select a model file from the evaluate result + */ + updateRequestModelId(); if (!evaluate) { setEvaluteState(EvaluateProccess.model); setSelectedModel(item); @@ -241,7 +261,7 @@ const AddModal: FC = (props) => { const modelInfo = onSelectModel(item, props.source); if ( !isHolderRef.current.model && - evaluateStateRef.current.state === 'model' + evaluateStateRef.current.state === EvaluateProccess.model ) { handleShowCompatibleAlert(item.evaluateResult); form.current?.setFieldsValue?.({ @@ -268,7 +288,6 @@ const AddModal: FC = (props) => { }; const handleSetIsGGUF = async (flag: boolean) => { - console.log('handleSetIsGGUF', flag); setIsGGUF(flag); await new Promise((resolve) => { setTimeout(() => { diff --git a/src/pages/llmodels/components/hf-model-file.tsx b/src/pages/llmodels/components/hf-model-file.tsx index 0fe8634f..624a2661 100644 --- a/src/pages/llmodels/components/hf-model-file.tsx +++ b/src/pages/llmodels/components/hf-model-file.tsx @@ -54,13 +54,7 @@ const includeReg = /\.(safetensors|gguf)$/i; const filterRegGGUF = /\.(gguf)$/i; const HFModelFile: React.FC = forwardRef((props, ref) => { - const { - collapsed, - modelSource, - isDownload, - gpuOptions, - displayEvaluateStatus - } = props; + const { collapsed, modelSource, isDownload, displayEvaluateStatus } = props; const intl = useIntl(); const [isEvaluating, setIsEvaluating] = useState(false); const [dataSource, setDataSource] = useState({ diff --git a/src/pages/llmodels/components/search-model.tsx b/src/pages/llmodels/components/search-model.tsx index e2e7e03f..0e1d8868 100644 --- a/src/pages/llmodels/components/search-model.tsx +++ b/src/pages/llmodels/components/search-model.tsx @@ -1,15 +1,9 @@ import { createAxiosToken } from '@/hooks/use-chunk-request'; import { QuestionCircleOutlined } from '@ant-design/icons'; import { useIntl } from '@umijs/max'; -import { Checkbox, Select, Tooltip } from 'antd'; +import { Checkbox, Pagination, Select, Tooltip } from 'antd'; import _ from 'lodash'; -import React, { - useCallback, - useEffect, - useMemo, - useRef, - useState -} from 'react'; +import React, { useEffect, useMemo, useRef, useState } from 'react'; import styled from 'styled-components'; import { evaluationsModelSpec, @@ -86,6 +80,12 @@ const SearchModel: React.FC = (props) => { const filterGGUFRef = useRef(!hasLinuxWorker); const filterTaskRef = useRef(''); const timer = useRef(null); + const requestIdRef = useRef(0); + const [query, setQuery] = useState({ + page: 1, + perPage: 10, + total: 100 + }); const modelFilesSortOptions = useRef([ { label: intl.formatMessage({ id: 'models.sort.trending' }), @@ -105,6 +105,11 @@ const SearchModel: React.FC = (props) => { } ]); + const updateRequestId = () => { + requestIdRef.current += 1; + return requestIdRef.current; + }; + const checkIsGGUF = (item: any) => { const isGGUF = _.some(item.tags, (tag: string) => { return tag.toLowerCase() === 'gguf'; @@ -122,7 +127,7 @@ const SearchModel: React.FC = (props) => { }; // huggeface - const getModelsFromHuggingface = useCallback(async (sort: string) => { + const getModelsFromHuggingface = async (sort: string) => { try { const task: any = searchInputRef.current ? '' : 'text-generation'; const params = { @@ -149,10 +154,14 @@ const SearchModel: React.FC = (props) => { } catch (error) { return []; } - }, []); + }; - // modelscope - const getModelsFromModelscope = useCallback(async (sort: string) => { + // modelscope, only modelscope has page and perPage + const getModelsFromModelscope = async (queryParams: { + sortType: string; + page: number; + perPage?: number; + }) => { try { const params = { Name: `${searchInputRef.current}`, @@ -160,7 +169,9 @@ const SearchModel: React.FC = (props) => { tasks: filterTaskRef.current ? ([ModelscopeTaskMap[filterTaskRef.current]] as string[]) : [], - SortBy: ModelScopeSortType[sort] + SortBy: ModelScopeSortType[queryParams.sortType], + PageNumber: queryParams.page, + PageSize: queryParams.perPage }; const data = await queryModelScopeModels(params, { signal: axiosTokenRef.current.signal @@ -179,18 +190,28 @@ const SearchModel: React.FC = (props) => { task: item.Tasks?.map((sItem: any) => sItem.Name).join(','), tags: item.Tags, libraries: item.Libraries, - isGGUF: checkIsGGUF({ tags: item.Tags, libraries: item.Libraries }), + isGGUF: checkIsGGUF({ + tags: item.Tags, + libraries: item.Libraries + }), source: modelSource }; }); + setQuery((prev) => { + return { + ...prev, + page: queryParams.page, + total: _.get(data, 'Data.Model.TotalCount', 0) + }; + }); return list; } catch (error) { return []; } - }, []); + }; - const getEvaluateResults = useCallback(async (repoList: any[]) => { + const getEvaluateResults = async (repoList: any[]) => { try { checkTokenRef.current?.cancel?.(); checkTokenRef.current = createAxiosToken(); @@ -206,12 +227,13 @@ const SearchModel: React.FC = (props) => { } catch (error) { return []; } - }, []); + }; const handleEvaluate = async (list: any[]) => { if (isDownload) { return; } + const currentRequestId = updateRequestId(); try { const repoList = list.map((item) => { const res = handleRecognizeAudioModel(item, modelSource); @@ -241,8 +263,12 @@ const SearchModel: React.FC = (props) => { }) }; }); + setIsEvaluating(true); const evaluations = await getEvaluateResults(repoList); + if (requestIdRef.current !== currentRequestId) { + return; + } const resultList = list.map((item, index) => { return { ...item, @@ -262,7 +288,10 @@ const SearchModel: React.FC = (props) => { (item) => item.id === currentRef.current ); - // if item is GGUF, the evaluating would be do after selecting the model file. + /** + * if item is GGUF, the evaluating would be do after selecting the model file, Or the evaluate status of model would be overrided the + * file evaluate status. + */ if (currentItem && !currentItem.isGGUF) { displayEvaluateStatus?.({ show: false, @@ -276,11 +305,23 @@ const SearchModel: React.FC = (props) => { handleOnSelectModel(currentItem, true); } } catch (error) { - setIsEvaluating(false); + if (requestIdRef.current === currentRequestId) { + setIsEvaluating(false); + } } }; - const handleOnSearchRepo = async (sortType?: string) => { + const getCurrentPage = (page: number) => { + const start = (page - 1) * query.perPage; + const end = start + query.perPage; + return cacheRepoOptions.current.slice(start, end); + }; + + const handleOnSearchRepo = async (params: { + sortType: string; + page: number; + perPage: number; + }) => { if (!SUPPORTEDSOURCE.includes(modelSource)) { return; } @@ -290,7 +331,7 @@ const SearchModel: React.FC = (props) => { if (timer.current) { clearTimeout(timer.current); } - const sort = sortType ?? dataSource.sortType; + const sort = params.sortType; try { setDataSource((pre) => { pre.loading = true; @@ -300,11 +341,22 @@ const SearchModel: React.FC = (props) => { cacheRepoOptions.current = []; let list: any[] = []; if (modelSource === modelSourceMap.huggingface_value) { - list = await getModelsFromHuggingface(sort); + const resultList = await getModelsFromHuggingface(sort); + cacheRepoOptions.current = resultList; + + // hf has no page and perPage, so we need to slice the resultList + list = getCurrentPage(params.page); + setQuery((prev) => { + return { + ...prev, + page: params.page, + total: resultList.length + }; + }); } else if (modelSource === modelSourceMap.modelscope_value) { - list = await getModelsFromModelscope(sort); + list = await getModelsFromModelscope(params); + cacheRepoOptions.current = list; } - cacheRepoOptions.current = list; setDataSource({ repoOptions: list, @@ -340,11 +392,18 @@ const SearchModel: React.FC = (props) => { cacheRepoOptions.current = []; } }; - const handleSearchInputChange = useCallback((e: any) => { + const handleSearchInputChange = (e: any) => { searchInputRef.current = e.target.value; - }, []); - - const handlerSearchModels = _.debounce(() => handleOnSearchRepo(), 100); + }; + const handlerSearchModels = _.debounce( + () => + handleOnSearchRepo({ + sortType: dataSource.sortType, + page: 1, + perPage: query.perPage + }), + 100 + ); const handleOnOpen = () => { if ( @@ -352,17 +411,62 @@ const SearchModel: React.FC = (props) => { !cacheRepoOptions.current.length && SUPPORTEDSOURCE.includes(modelSource) ) { - handleOnSearchRepo(); + handleOnSearchRepo({ + sortType: dataSource.sortType, + page: 1, + perPage: query.perPage + }); } }; const handleSortChange = (value: string) => { - handleOnSearchRepo(value || ''); + handleOnSearchRepo({ + sortType: value, + page: 1, + perPage: query.perPage + }); }; const handleFilterGGUFChange = (e: any) => { filterGGUFRef.current = e.target.checked; - handleOnSearchRepo(); + handleOnSearchRepo({ + sortType: dataSource.sortType, + page: 1, + perPage: query.perPage + }); + }; + + const handleOnPageChange = (page: number) => { + if (modelSource === modelSourceMap.huggingface_value) { + const currentList = getCurrentPage(page); + setQuery((prev) => { + return { + ...prev, + page: page + }; + }); + setDataSource((pre) => { + return { + ...pre, + repoOptions: currentList + }; + }); + displayEvaluateStatus?.({ + show: true, + flag: { + model: true + } + }); + console.log('isEvaluating:', isEvaluating); + handleOnSelectModel(currentList[0]); + handleEvaluate(currentList); + } else if (modelSource === modelSourceMap.modelscope_value) { + handleOnSearchRepo({ + sortType: dataSource.sortType, + page: page, + perPage: query.perPage + }); + } }; const renderGGUFTips = useMemo(() => { @@ -402,23 +506,6 @@ const SearchModel: React.FC = (props) => {
- - {intl.formatMessage( - { id: 'models.search.result' }, - { count: dataSource.repoOptions.length } - )} - - - - - - {renderGGUFTips} - - + + {renderGGUFTips} + +
); @@ -457,6 +560,7 @@ const SearchModel: React.FC = (props) => { return (
{renderHFSearch()}
+ { if (currentRequestId === requestIdRef.current) { handleShowCompatibleAlert?.(evalutionData); + return evalutionData; } - return evalutionData; + return null; }; const checkRequiredValue = (allValues: any) => {