perf: embedding calc in worker process

main
jialin 11 months ago
parent ec45f23485
commit f288c9a9ad

@ -28,6 +28,9 @@ export default defineConfig({
logLevel: 'info',
defaultSizes: 'parsed' // stat // gzip
},
mfsu: {
exclude: ['lodash', 'ml-pca']
},
base: process.env.npm_config_base || '/',
...(isProduction
? {
@ -52,7 +55,7 @@ export default defineConfig({
]);
config.module
.rule('worker')
.test(/\.worker\.js$/)
.test(/\.worker\.(js|ts)$/)
.use('worker-loader')
.loader('worker-loader');
config.output

@ -78,6 +78,7 @@
"@types/react": "^18.3.1",
"@types/react-dom": "^18.3.0",
"@umijs/case-sensitive-paths-webpack-plugin": "^1.0.1",
"@umijs/plugins": "^4.4.11",
"babel-plugin-named-asset-import": "^0.3.8",
"case-sensitive-paths-webpack-plugin": "^2.4.0",
"compression-webpack-plugin": "^11.1.0",

File diff suppressed because it is too large Load Diff

@ -148,4 +148,4 @@ const RowTextarea: React.FC<SystemMessageProps> = (props) => {
);
};
export default React.memo(RowTextarea);
export default RowTextarea;

@ -16,7 +16,10 @@ interface ErrorResultProps {
function isChunkLoadError(msg?: string): boolean {
if (typeof msg !== 'string') return false;
return msg.includes('Loading chunk') && msg.includes('failed');
const jsChunkFailed = msg.includes('Loading chunk');
const cssChunkFailed = msg.includes('Loading CSS chunk');
return (jsChunkFailed || cssChunkFailed) && msg.includes('failed');
}
const ErrorResult: React.FC<ErrorResultProps> = ({ extra }) => {

@ -21,7 +21,7 @@ export default {
'playground.params.temperature.tips':
'Controls randomness: Lowering results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive.',
'playground.params.maxtokens.tips':
"The maximum number of tokens to generated. The total length of input tokens and generated tokens is limited by the model's context length.",
"The maximum number of tokens to generate. The total length of input tokens and generated tokens is limited by the model's context length.",
'playground.params.topp.tips':
'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.',
'playground.params.seed.tips':

@ -16,7 +16,6 @@ import { useIntl } from '@umijs/max';
import { Button, Checkbox, Form, Segmented, Spin, Tabs, Tooltip } from 'antd';
import classNames from 'classnames';
import _ from 'lodash';
import { PCA } from 'ml-pca';
import 'overlayscrollbars/overlayscrollbars.css';
import { Resizable } from 're-resizable';
import React, {
@ -31,6 +30,7 @@ import React, {
import { EMBEDDING_API, handleEmbedding } from '../apis';
import { extractErrorMessage } from '../config';
import { LLM_METAKEYS } from '../hooks/config';
import useEmbeddingWorker from '../hooks/use-embedding-worker';
import { useInitLLmMeta } from '../hooks/use-init-meta';
import '../style/ground-left.less';
import '../style/rerank.less';
@ -51,6 +51,8 @@ const GroundEmbedding: React.FC<MessageProps> = forwardRef((props, ref) => {
const messageId = useRef<number>(0);
const intl = useIntl();
const { workerRef, createWorker, postMessage, terminateWorker } =
useEmbeddingWorker();
const requestSource = useRequestToken();
const [show, setShow] = useState(false);
const [loading, setLoading] = useState(false);
@ -154,46 +156,6 @@ const GroundEmbedding: React.FC<MessageProps> = forwardRef((props, ref) => {
return list.length < 2;
}, [textList, fileList]);
const generateEmbedding = useCallback(
(embeddings: any[]) => {
try {
const dataList = embeddings.map((item) => {
return item.embedding;
});
const pca = new PCA(dataList);
const pcadata = pca.predict(dataList, { nComponents: 2 }).to2DArray();
const input = [
...textList.map((item) => item.text).filter((item) => item),
...fileList.map((item) => item.text).filter((item) => item)
];
const list = pcadata.map((item: number[], index: number) => {
return {
value: item,
name: index + 1,
text: input[index]
};
});
setScatterData(list);
const embeddingJson = embeddings.map((o, index) => {
const item = _.cloneDeep(o);
item.embedding = item.embedding.slice(0, 5);
item.embedding.push(null);
return item;
});
setEmbeddingData({
code: JSON.stringify(embeddingJson, null, 2).replace(/null/g, '...'),
copyValue: JSON.stringify(embeddings, null, 2)
});
} catch (e) {
console.log('error:', e);
}
},
[textList, fileList]
);
const setMessageId = () => {
messageId.current = messageId.current + 1;
};
@ -245,12 +207,29 @@ const GroundEmbedding: React.FC<MessageProps> = forwardRef((props, ref) => {
}
);
console.log('result=========', result);
setTokenResult(result.usage);
const embeddingsList = result.data || [];
generateEmbedding(embeddingsList);
createWorker();
workerRef.current!.onmessage = (event: MessageEvent) => {
const { scatterData, embeddingData } = event.data;
console.log('worker result:', scatterData, embeddingData);
setScatterData(scatterData);
setEmbeddingData(embeddingData);
};
postMessage({
embeddings: embeddingsList,
textList: textList,
fileList: fileList
});
// generateEmbedding(embeddingsList);
} catch (error: any) {
console.log('result=========error', error);
setTokenResult({
error: true,
errorMessage: extractErrorMessage(error.response)
@ -324,12 +303,11 @@ const GroundEmbedding: React.FC<MessageProps> = forwardRef((props, ref) => {
if (!multiplePasteEnable.current) return;
const text = e.clipboardData.getData('text');
if (text) {
const currentContent = textList[index].text;
const dataLlist = text.split('\n').map((item: string) => {
return {
text: item?.trim(),
uid: inputListRef.current?.setMessageId(),
name: ''
name: '',
uid: setMessageId()
};
});
dataLlist[0].text = `${selectionTextRef.current?.beforeText || ''}${dataLlist[0].text}${selectionTextRef.current?.afterText || ''}`;
@ -337,14 +315,8 @@ const GroundEmbedding: React.FC<MessageProps> = forwardRef((props, ref) => {
...textList.slice(0, index),
...dataLlist,
...textList.slice(index + 1)
]
.filter((item) => item.text)
.map((item, index) => {
return {
...item,
uid: inputListRef.current?.setMessageId()
};
});
].filter((item) => item.text);
setTextList(result);
}
},

@ -111,7 +111,6 @@ const GroundReranker: React.FC<MessageProps> = forwardRef((props, ref) => {
name: ''
}
]);
const [sortIndexMap, setSortIndexMap] = useState<number[]>([]);
const [queryValue, setQueryValue] = useState<string>('');
const selectionTextRef = useRef<any>(null);
@ -237,8 +236,6 @@ const GroundReranker: React.FC<MessageProps> = forwardRef((props, ref) => {
setLoading(true);
setMessageId();
setTokenResult(null);
setSortIndexMap([]);
requestToken.current?.cancel?.();
requestToken.current = requestSource();
@ -278,8 +275,6 @@ const GroundReranker: React.FC<MessageProps> = forwardRef((props, ref) => {
let sortMap: number[] = [];
result.results?.forEach((item: any, sIndex: number) => {
sortMap.push(item.index);
newTextList[item.index] = {
...newTextList[item.index],
uid: setMessageId(),
@ -295,7 +290,6 @@ const GroundReranker: React.FC<MessageProps> = forwardRef((props, ref) => {
});
newTextList = _.sortBy(newTextList, 'rank');
setSortIndexMap(sortMap);
setTextList(newTextList);
} catch (error: any) {
setTokenResult({
@ -360,8 +354,11 @@ const GroundReranker: React.FC<MessageProps> = forwardRef((props, ref) => {
const dataLlist = text.split('\n').map((item: string) => {
return {
text: item?.trim(),
name: '',
uid: setMessageId(),
name: ''
percent: undefined,
score: undefined,
rank: undefined
};
});
dataLlist[0].text = `${selectionTextRef.current?.beforeText || ''}${dataLlist[0].text}${selectionTextRef.current?.afterText || ''}`;
@ -369,36 +366,14 @@ const GroundReranker: React.FC<MessageProps> = forwardRef((props, ref) => {
...textList.slice(0, index),
...dataLlist,
...textList.slice(index + 1)
]
.filter((item) => item.text)
.map((item, index) => {
item.percent = undefined;
item.score = undefined;
item.rank = undefined;
return {
...item,
uid: setMessageId()
};
});
].filter((item) => item.text);
setTextList(result);
}
},
[textList]
);
const handleOnSort = useCallback(
(list: { text: string; uid: number | string; name: string }[]) => {
const newList = list?.map((item) => {
return {
...item,
uid: setMessageId()
};
});
setTextList(newList);
},
[]
);
const renderExtra = useMemo(() => {
if (modelMeta?.n_ctx && modelMeta?.n_slot) {
return (
@ -538,14 +513,11 @@ const GroundReranker: React.FC<MessageProps> = forwardRef((props, ref) => {
<div className="docs-wrapper">
<InputList
key={messageId.current}
sortIndex={sortIndexMap}
ref={inputListRef}
textList={textList}
showLabel={false}
sortable={false}
height={46}
onChange={handleTextListChange}
onSort={handleOnSort}
extra={renderPercent}
onSelect={handleonSelect}
onPaste={handleOnPaste}

@ -2,14 +2,7 @@ import RowTextarea from '@/components/seal-form/row-textarea';
import { DeleteOutlined } from '@ant-design/icons';
import { useIntl } from '@umijs/max';
import { Button, Tooltip } from 'antd';
import _ from 'lodash';
import React, {
forwardRef,
useCallback,
useEffect,
useImperativeHandle,
useRef
} from 'react';
import React, { forwardRef, useImperativeHandle, useRef } from 'react';
import '../style/input-list.less';
interface InputListProps {
@ -17,20 +10,15 @@ interface InputListProps {
height?: number;
extra?: (data: any) => React.ReactNode;
showLabel?: boolean;
sortIndex?: number[];
textList: {
text: string;
uid: number | string;
name: string;
}[];
sortable?: boolean;
onChange?: (
textList: { text: string; uid: number | string; name: string }[]
) => void;
onPaste?: (e: any, index: number) => void;
onSort?: (
textList: { text: string; uid: number | string; name: string }[]
) => void;
onSelect?: (data: {
start: number;
end: number;
@ -42,152 +30,12 @@ interface InputListProps {
const InputList: React.FC<InputListProps> = forwardRef(
(
{
textList,
showLabel = true,
sortIndex = [],
sortable,
height,
onSort,
onChange,
extra,
onPaste,
onSelect
},
{ textList, showLabel = true, height, onChange, extra, onPaste, onSelect },
ref
) => {
const intl = useIntl();
const messageId = useRef(0);
const containerRef = useRef<any>(null);
const childListRef = useRef<any[]>([]);
const getContainerChildList = () => {
childListRef.current = Array.from(containerRef.current?.children || []);
};
const getOffsetUsingBoundingClientRect = useCallback(
(element: HTMLElement, targetElement: HTMLElement) => {
const currentRect = element.getBoundingClientRect();
const targetRect = targetElement.getBoundingClientRect();
return {
x: targetRect.left - currentRect.left,
y: targetRect.top - currentRect.top
};
},
[]
);
// move item from fromIndex to toIndex
const moveItem = useCallback((child: any, toIndex: number) => {
const container = containerRef.current;
if (!container) return;
const children = Array.from(container.children);
if (toIndex >= children.length) {
if (container.firstChild) {
container.insertBefore(child, container.firstChild);
} else {
container.appendChild(child);
}
} else {
container.insertBefore(child, children[toIndex]);
}
}, []);
const moveElement = (arr: any[], fromIndex: number, toIndex: number) => {
if (
fromIndex === toIndex ||
fromIndex < 0 ||
toIndex < 0 ||
fromIndex >= arr.length ||
toIndex > arr.length
) {
return arr;
}
const element = arr.splice(fromIndex, 1)[0];
arr.splice(toIndex > fromIndex ? toIndex - 1 : toIndex, 0, element);
return arr;
};
const sort = useCallback(() => {
if (!sortable) return;
getContainerChildList();
const container = containerRef.current;
if (!container) return;
const newOrder = [...textList];
const offsets = sortIndex.map((fromIndex, toIndex) => {
const currentElement = childListRef.current[fromIndex];
const targetElement = childListRef.current[toIndex];
if (!currentElement || !targetElement) return null;
const offset = getOffsetUsingBoundingClientRect(
currentElement,
targetElement
);
return {
element: currentElement,
offset: {
x: offset.x,
y: offset.y
},
fromIndex,
toIndex
};
});
const moveSequentially = async () => {
for (const [index, data] of offsets.entries()) {
if (!data) continue;
const { element, offset, fromIndex, toIndex } = data;
console.log('sort+++++++', {
textList,
element,
offset,
fromIndex,
toIndex
});
await new Promise((resolve) => {
element.style.opacity = 0.5;
element.style.transform = `translate(${offset.x}px, ${offset.y}px)`;
element.style.transition = 'transform 0.8s,opacity 0.8s';
element.addEventListener(
'transitionend',
() => {
moveItem(element, toIndex);
moveElement(newOrder, fromIndex, toIndex);
getContainerChildList();
element.style.opacity = 1;
element.style.transform = '';
console.log('sort++++++++++end');
resolve(null);
},
{ once: true }
);
});
}
onSort?.(newOrder);
};
moveSequentially();
}, [
sortable,
sortIndex,
textList,
onSort,
getContainerChildList,
getOffsetUsingBoundingClientRect
]);
const setMessageId = () => {
messageId.current = messageId.current + 1;
@ -221,14 +69,6 @@ const InputList: React.FC<InputListProps> = forwardRef(
dataList[index].text = value;
onChange?.(dataList);
};
const debounceSort = _.debounce(sort, 100);
useEffect(() => {
if (sortIndex?.length) {
console.log('sort++++2+++');
debounceSort();
}
}, [sortIndex]);
useImperativeHandle(ref, () => ({
handleAdd,
@ -277,4 +117,4 @@ const InputList: React.FC<InputListProps> = forwardRef(
}
);
export default React.memo(InputList);
export default InputList;

@ -0,0 +1,69 @@
import _ from 'lodash';
import { PCA } from 'ml-pca';
self.onmessage = (
event: MessageEvent<{
embeddings: any[];
fileList: { text: string; name: string; uid: number | string }[];
textList: { text: string; name: string; uid: number | string }[];
}>
) => {
const { embeddings, fileList, textList } = event.data;
try {
const dataList = embeddings.map((item) => {
return item.embedding;
});
console.log('list:', dataList);
const pca = new PCA(dataList);
console.log('list:', pca);
const pcadata = pca.predict(dataList, { nComponents: 2 }).to2DArray();
const input = [
...textList.map((item) => item.text).filter((item) => item),
...fileList.map((item) => item.text).filter((item) => item)
];
const list = pcadata.map((item: number[], index: number) => {
return {
value: item,
name: index + 1,
text: input[index]
};
});
const embeddingJson = embeddings.map((o, index) => {
const item = _.cloneDeep(o);
item.embedding = item.embedding.slice(0, 5);
item.embedding.push(null);
return item;
});
const embeddingData = {
code: JSON.stringify(embeddingJson, null, 2).replace(/null/g, '...'),
copyValue: JSON.stringify(embeddings, null, 2)
};
self.postMessage({
scatterData: list,
embeddingData: embeddingData
});
} catch (e) {
console.log('error:', e);
self.postMessage({
scatterData: [],
embeddingData: {
code: '',
copyValue: ''
}
});
}
};
self.onerror = (e) => {
console.log('error:', e);
};
self.onmessageerror = (e) => {
console.log('message error:', e);
};

@ -0,0 +1,42 @@
import { useEffect, useRef } from 'react';
export default function useEmbeddingWorker() {
const workerRef = useRef<Worker | null>(null);
const createWorker = () => {
if (workerRef.current) {
workerRef.current.terminate();
}
workerRef.current = new Worker(
new URL('../config/embedding-worker.worker.ts', import.meta.url),
{
type: 'module'
}
);
};
const postMessage = (params: {
embeddings: any[];
textList: { text: string; name: string; uid: number | string }[];
fileList: { text: string; name: string; uid: number | string }[];
}) => {
if (workerRef.current) {
workerRef.current.postMessage(params);
}
};
const terminateWorker = () => {
if (workerRef.current) {
workerRef.current.terminate();
workerRef.current = null;
}
};
useEffect(() => {
return () => {
terminateWorker();
};
}, []);
return { workerRef, createWorker, postMessage, terminateWorker };
}

@ -27,5 +27,6 @@
"src/components/logs-viewer/parse-worker.ts",
"src/components/image-editor/invert-worker.ts",
"src/components/image-editor/offscreen-worker.ts"
// "src/pages/playground/config/embedding-worker.ts"
]
}

Loading…
Cancel
Save