diff --git a/src/pages/playground/components/image-edit.tsx b/src/pages/playground/components/image-edit.tsx new file mode 100644 index 00000000..33993d03 --- /dev/null +++ b/src/pages/playground/components/image-edit.tsx @@ -0,0 +1,619 @@ +import AlertInfo from '@/components/alert-info'; +import FieldComponent from '@/components/seal-form/field-component'; +import SealSelect from '@/components/seal-form/seal-select'; +import useOverlayScroller from '@/hooks/use-overlay-scroller'; +import ThumbImg from '@/pages/playground/components/thumb-img'; +import { generateRandomNumber } from '@/utils'; +import { + fetchChunkedData, + readLargeStreamData as readStreamData +} from '@/utils/fetch-chunk-data'; +import { FileImageOutlined, SwapOutlined } from '@ant-design/icons'; +import { useIntl, useSearchParams } from '@umijs/max'; +import { Button, Form, Tooltip } from 'antd'; +import classNames from 'classnames'; +import _ from 'lodash'; +import 'overlayscrollbars/overlayscrollbars.css'; +import React, { + forwardRef, + memo, + useCallback, + useEffect, + useImperativeHandle, + useMemo, + useRef, + useState +} from 'react'; +import { CREAT_IMAGE_API } from '../apis'; +import { promptList } from '../config'; +import { + ImageAdvancedParamsConfig, + ImageCustomSizeConfig, + ImageconstExtraConfig, + ImageEidtParamsConfig as paramsConfig +} from '../config/params-config'; +import { MessageItem, ParamsSchema } from '../config/types'; +import '../style/ground-left.less'; +import '../style/system-message-wrap.less'; +import { generateImageCode, generateOpenaiImageCode } from '../view-code/image'; +import DynamicParams from './dynamic-params'; +import MessageInput from './message-input'; +import ViewCommonCode from './view-common-code'; + +interface MessageProps { + modelList: Global.BaseOption[]; + loaded?: boolean; + ref?: any; +} +const advancedFieldsDefaultValus = { + seed: null, + sampler: 'euler_a', + cfg_scale: 4.5, + sample_steps: 10, + negative_prompt: null, + schedule: 'discrete' +}; + +const openaiCompatibleFieldsDefaultValus = { + quality: 'standard', + style: null +}; + +const initialValues = { + n: 1, + size: '512x512', + ...advancedFieldsDefaultValus +}; + +const GroundImages: React.FC = forwardRef((props, ref) => { + const { modelList } = props; + const messageId = useRef(0); + const [isOpenaiCompatible, setIsOpenaiCompatible] = useState(false); + const [imageList, setImageList] = useState< + { + dataUrl: string; + height: number | string; + width: string | number; + maxHeight: string | number; + maxWidth: string | number; + uid: number; + span?: number; + loading?: boolean; + progress?: number; + }[] + >([]); + + const intl = useIntl(); + const [searchParams] = useSearchParams(); + const selectModel = searchParams.get('model') || ''; + const [parameters, setParams] = useState({}); + const [show, setShow] = useState(false); + const [loading, setLoading] = useState(false); + const [tokenResult, setTokenResult] = useState(null); + const [collapse, setCollapse] = useState(false); + const scroller = useRef(null); + const paramsRef = useRef(null); + const messageListLengthCache = useRef(0); + const requestToken = useRef(null); + const [currentPrompt, setCurrentPrompt] = useState(''); + const form = useRef(null); + const inputRef = useRef(null); + + const size = Form.useWatch('size', form.current?.form); + + const { initialize, updateScrollerPosition } = useOverlayScroller(); + const { initialize: innitializeParams } = useOverlayScroller(); + + useImperativeHandle(ref, () => { + return { + viewCode() { + setShow(true); + }, + setCollapse() { + setCollapse(!collapse); + }, + collapse: collapse + }; + }); + + const generateNumber = (min: number, max: number) => { + return Math.floor(Math.random() * (max - min + 1) + min); + }; + + const handleRandomPrompt = useCallback(() => { + const randomIndex = generateNumber(0, promptList.length - 1); + const randomPrompt = promptList[randomIndex]; + inputRef.current?.handleInputChange({ + target: { + value: randomPrompt + } + }); + }, []); + + const setImageSize = useCallback(() => { + let size: Record = { + span: 12 + }; + if (parameters.n === 1) { + size.span = 24; + } + if (parameters.n === 2) { + size.span = 12; + } + if (parameters.n === 3) { + size.span = 12; + } + if (parameters.n === 4) { + size.span = 12; + } + return size; + }, [parameters.n]); + + const finalParameters = useMemo(() => { + if (parameters.size === 'custom') { + return { + ..._.omit(parameters, ['width', 'height']), + size: + parameters.width && parameters.height + ? `${parameters.width}x${parameters.height}` + : '' + }; + } + return { + ..._.omit(parameters, ['width', 'height', 'random_seed']) + }; + }, [parameters]); + + const viewCodeContent = useMemo(() => { + if (isOpenaiCompatible) { + return generateOpenaiImageCode({ + api: '/v1-openai/images/generations', + parameters: { + ...finalParameters, + prompt: currentPrompt + } + }); + } + return generateImageCode({ + api: '/v1-openai/images/generations', + parameters: { + ...finalParameters, + prompt: currentPrompt + } + }); + }, [finalParameters, currentPrompt, parameters.size]); + + const setMessageId = () => { + messageId.current = messageId.current + 1; + return messageId.current; + }; + + const handleStopConversation = () => { + requestToken.current?.abort?.(); + setLoading(false); + }; + + const submitMessage = async (current?: { content: string }) => { + try { + await form.current?.form?.validateFields(); + if (!parameters.model) return; + const size: any = setImageSize(); + setLoading(true); + setMessageId(); + setTokenResult(null); + setCurrentPrompt(current?.content || ''); + const imgSize = _.split(finalParameters.size, 'x'); + + let newImageList = Array(parameters.n) + .fill({}) + .map((item, index: number) => { + return { + dataUrl: 'data:image/png;base64,', + ...size, + progress: 0, + height: imgSize[1], + width: imgSize[0], + loading: true, + uid: setMessageId() + }; + }); + setImageList(newImageList); + + requestToken.current?.abort?.(); + requestToken.current = new AbortController(); + + const params = { + ..._.omitBy(finalParameters, (value: string) => !value), + seed: parameters.random_seed ? generateRandomNumber() : parameters.seed, + stream: true, + stream_options: { + chunk_size: 16 * 1024, + chunk_results: true + }, + prompt: current?.content || currentPrompt || '' + }; + setParams({ + ...parameters, + seed: params.seed + }); + form.current?.form?.setFieldValue('seed', params.seed); + + const result: any = await fetchChunkedData({ + data: params, + url: `${CREAT_IMAGE_API}?t=${Date.now()}`, + signal: requestToken.current.signal + }); + if (result.error) { + setTokenResult({ + error: true, + errorMessage: + result?.data?.error?.message || result?.data?.error || '' + }); + setImageList([]); + return; + } + + const { reader, decoder } = result; + + await readStreamData(reader, decoder, (chunk: any) => { + if (chunk?.error) { + setTokenResult({ + error: true, + errorMessage: chunk?.error?.message || chunk?.message || '' + }); + return; + } + chunk?.data?.forEach((item: any) => { + const imgItem = newImageList[item.index]; + if (item.b64_json) { + imgItem.dataUrl += item.b64_json; + } + const progress = _.round(item.progress, 0); + newImageList[item.index] = { + dataUrl: imgItem.dataUrl, + height: imgSize[1], + width: imgSize[0], + maxHeight: `${imgSize[1]}px`, + maxWidth: `${imgSize[0]}px`, + uid: imgItem.uid, + span: imgItem.span, + loading: progress < 100, + progress: progress + }; + }); + setImageList([...newImageList]); + }); + } catch (error) { + console.log('error:', error); + requestToken.current?.abort?.(); + setImageList([]); + } finally { + setLoading(false); + } + }; + const handleClear = () => { + setMessageId(); + setImageList([]); + setTokenResult(null); + }; + + const handleInputChange = (e: any) => { + setCurrentPrompt(e.target.value); + }; + + const handleSendMessage = (message: Omit) => { + const currentMessage = message.content ? message : undefined; + submitMessage(currentMessage); + }; + + const handleCloseViewCode = () => { + setShow(false); + }; + + const handleToggleParamsStyle = () => { + if (isOpenaiCompatible) { + form.current?.form?.setFieldsValue({ + ...advancedFieldsDefaultValus + }); + setParams((pre: object) => { + return { + ..._.omit(pre, _.keys(openaiCompatibleFieldsDefaultValus)), + ...advancedFieldsDefaultValus + }; + }); + } else { + form.current?.form?.setFieldsValue({ + ...openaiCompatibleFieldsDefaultValus + }); + setParams((pre: object) => { + return { + ...openaiCompatibleFieldsDefaultValus, + ..._.omit(pre, _.keys(advancedFieldsDefaultValus)) + }; + }); + } + setIsOpenaiCompatible(!isOpenaiCompatible); + }; + + const renderExtra = useMemo(() => { + if (!isOpenaiCompatible) { + return []; + } + return ImageconstExtraConfig.map((item: ParamsSchema) => { + return ( + + + + ); + }); + }, [ImageconstExtraConfig, isOpenaiCompatible, intl]); + + const handleFieldChange = (e: any) => { + if (e.target.id.indexOf('random_seed') > -1) { + form.current?.form?.setFieldValue('random_seed', e.target.checked); + setParams((pre: object) => { + return { + ...pre, + random_seed: e.target.checked + }; + }); + } + }; + const renderAdvanced = useMemo(() => { + if (isOpenaiCompatible) { + return []; + } + const formValues = form.current?.form?.getFieldsValue(); + return ImageAdvancedParamsConfig.map((item: ParamsSchema) => { + return ( + + + + ); + }); + }, [ImageAdvancedParamsConfig, isOpenaiCompatible, intl, form.current]); + + const renderCustomSize = useMemo(() => { + if (size === 'custom') { + return ImageCustomSizeConfig.map((item: ParamsSchema) => { + return ( + + + + ); + }); + } + return null; + }, [size, intl]); + + useEffect(() => { + return () => { + requestToken.current?.abort?.(); + }; + }, []); + + useEffect(() => { + if (size === 'custom') { + form.current?.form?.setFieldsValue({ + width: 512, + height: 512 + }); + setParams((pre: object) => { + return { + ...pre, + width: 512, + height: 512 + }; + }); + } + }, [size]); + + useEffect(() => { + if (scroller.current) { + initialize(scroller.current); + } + }, [scroller.current, initialize]); + + useEffect(() => { + if (paramsRef.current) { + innitializeParams(paramsRef.current); + } + }, [paramsRef.current, innitializeParams]); + + useEffect(() => { + if (loading) { + updateScrollerPosition(); + } + }, [imageList, loading]); + + useEffect(() => { + if (imageList.length > messageListLengthCache.current) { + updateScrollerPosition(); + } + messageListLengthCache.current = imageList.length; + }, [imageList.length]); + + return ( +
+
+
+ <> +
+ + {!imageList.length && ( +
+ + + + + {intl.formatMessage({ id: 'playground.params.empty.tips' })} + +
+ )} +
+ +
+ {tokenResult && ( +
+ +
+ )} +
+ + {intl.formatMessage({ id: 'playground.image.prompt' })} + + } + loading={loading} + disabled={!parameters.model} + isEmpty={!imageList.length} + handleSubmit={handleSendMessage} + handleAbortFetch={handleStopConversation} + onInputChange={handleInputChange} + shouldResetMessage={false} + clearAll={handleClear} + /> +
+
+
+
+ + + {intl.formatMessage({ id: 'playground.parameters' })} + + + + +
+ } + setParams={setParams} + paramsConfig={paramsConfig} + initialValues={initialValues} + params={parameters} + selectedModel={selectModel} + modelList={modelList} + extra={[renderCustomSize, ...renderExtra, ...renderAdvanced]} + /> +
+
+ + + ); +}); + +export default memo(GroundImages); diff --git a/src/pages/playground/config/params-config.ts b/src/pages/playground/config/params-config.ts index 8a73790a..9efc43c0 100644 --- a/src/pages/playground/config/params-config.ts +++ b/src/pages/playground/config/params-config.ts @@ -128,6 +128,67 @@ export const ImageParamsConfig: ParamsSchema[] = [ } ]; +export const ImageEidtParamsConfig: ParamsSchema[] = [ + // { + // type: 'Slider', + // name: 'brush_size', + // label: { + // text: 'Brush Size', + // isLocalized: false + // }, + // attrs: { + // min: 25, + // max: 80 + // }, + // rules: [ + // { + // required: false + // } + // ] + // }, + { + type: 'InputNumber', + name: 'n', + label: { + text: 'playground.params.counts', + isLocalized: true + }, + attrs: { + min: 1, + max: 4 + }, + rules: [ + { + required: false + } + ] + }, + { + type: 'Select', + name: 'size', + options: [ + { label: 'playground.params.custom', value: 'custom', locale: true }, + { label: '512x512', value: '512x512' }, + { label: '768x1024', value: '768x1024' }, + { label: '1024x1024', value: '1024x1024' } + ], + description: { + text: 'playground.params.size.description', + html: true, + isLocalized: true + }, + label: { + text: 'playground.params.size', + isLocalized: true + }, + rules: [ + { + required: false + } + ] + } +]; + export const ImageconstExtraConfig: ParamsSchema[] = [ { type: 'Select', diff --git a/src/pages/playground/images.tsx b/src/pages/playground/images.tsx index e26499cf..1df28119 100644 --- a/src/pages/playground/images.tsx +++ b/src/pages/playground/images.tsx @@ -2,31 +2,76 @@ import IconFont from '@/components/icon-font'; import breakpoints from '@/config/breakpoints'; import HotKeys from '@/config/hotkeys'; import useWindowResize from '@/hooks/use-window-resize'; +import { DiffOutlined, HighlightOutlined } from '@ant-design/icons'; import { PageContainer } from '@ant-design/pro-components'; import { useIntl } from '@umijs/max'; -import { Button, Space } from 'antd'; +import { Button, Segmented, Space, Tabs, TabsProps } from 'antd'; import classNames from 'classnames'; import _ from 'lodash'; import { useCallback, useEffect, useRef, useState } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { queryModelsList } from './apis'; import GroundImages from './components/ground-images'; +import ImageEdit from './components/image-edit'; import './style/play-ground.less'; +const TabsValueMap = { + Tab1: 'generate', + Tab2: 'edit' +}; + const TextToImages: React.FC = () => { const intl = useIntl(); const { size } = useWindowResize(); + const [activeKey, setActiveKey] = useState(TabsValueMap.Tab1); const groundTabRef1 = useRef(null); + const groundTabRef2 = useRef(null); const [modelList, setModelList] = useState[]>([]); const [loaded, setLoaded] = useState(false); + const optionsList = [ + { + label: 'Generate', + value: TabsValueMap.Tab1, + icon: + }, + { + label: 'Edit', + value: TabsValueMap.Tab2, + icon: + } + ]; + const handleViewCode = useCallback(() => { - groundTabRef1.current?.viewCode?.(); - }, []); + if (activeKey === TabsValueMap.Tab1) { + groundTabRef1.current?.viewCode?.(); + } else if (activeKey === TabsValueMap.Tab2) { + groundTabRef2.current?.viewCode?.(); + } + }, [activeKey]); const handleToggleCollapse = useCallback(() => { - groundTabRef1.current?.setCollapse?.(); - }, []); + if (activeKey === TabsValueMap.Tab1) { + groundTabRef1.current?.setCollapse?.(); + return; + } + groundTabRef2.current?.setCollapse?.(); + }, [activeKey]); + + const items: TabsProps['items'] = [ + { + key: TabsValueMap.Tab1, + label: 'Generate', + children: ( + + ) + }, + { + key: TabsValueMap.Tab2, + label: 'Edit', + children: + } + ]; useEffect(() => { if (size.width < breakpoints.lg) { @@ -105,7 +150,21 @@ const TextToImages: React.FC = () => { + + {intl.formatMessage({ id: 'menu.playground.text2images' })} + + { + setActiveKey(key)} + > + } + + ), breadcrumb: {} }} extra={renderExtra()} @@ -113,10 +172,7 @@ const TextToImages: React.FC = () => { >
- +