feat: image edit

main
jialin 1 year ago
parent 972ee69fac
commit 91c5e7e06c

@ -0,0 +1,619 @@
import AlertInfo from '@/components/alert-info';
import FieldComponent from '@/components/seal-form/field-component';
import SealSelect from '@/components/seal-form/seal-select';
import useOverlayScroller from '@/hooks/use-overlay-scroller';
import ThumbImg from '@/pages/playground/components/thumb-img';
import { generateRandomNumber } from '@/utils';
import {
fetchChunkedData,
readLargeStreamData as readStreamData
} from '@/utils/fetch-chunk-data';
import { FileImageOutlined, SwapOutlined } from '@ant-design/icons';
import { useIntl, useSearchParams } from '@umijs/max';
import { Button, Form, Tooltip } from 'antd';
import classNames from 'classnames';
import _ from 'lodash';
import 'overlayscrollbars/overlayscrollbars.css';
import React, {
forwardRef,
memo,
useCallback,
useEffect,
useImperativeHandle,
useMemo,
useRef,
useState
} from 'react';
import { CREAT_IMAGE_API } from '../apis';
import { promptList } from '../config';
import {
ImageAdvancedParamsConfig,
ImageCustomSizeConfig,
ImageconstExtraConfig,
ImageEidtParamsConfig as paramsConfig
} from '../config/params-config';
import { MessageItem, ParamsSchema } from '../config/types';
import '../style/ground-left.less';
import '../style/system-message-wrap.less';
import { generateImageCode, generateOpenaiImageCode } from '../view-code/image';
import DynamicParams from './dynamic-params';
import MessageInput from './message-input';
import ViewCommonCode from './view-common-code';
interface MessageProps {
modelList: Global.BaseOption<string>[];
loaded?: boolean;
ref?: any;
}
const advancedFieldsDefaultValus = {
seed: null,
sampler: 'euler_a',
cfg_scale: 4.5,
sample_steps: 10,
negative_prompt: null,
schedule: 'discrete'
};
const openaiCompatibleFieldsDefaultValus = {
quality: 'standard',
style: null
};
const initialValues = {
n: 1,
size: '512x512',
...advancedFieldsDefaultValus
};
const GroundImages: React.FC<MessageProps> = forwardRef((props, ref) => {
const { modelList } = props;
const messageId = useRef<number>(0);
const [isOpenaiCompatible, setIsOpenaiCompatible] = useState<boolean>(false);
const [imageList, setImageList] = useState<
{
dataUrl: string;
height: number | string;
width: string | number;
maxHeight: string | number;
maxWidth: string | number;
uid: number;
span?: number;
loading?: boolean;
progress?: number;
}[]
>([]);
const intl = useIntl();
const [searchParams] = useSearchParams();
const selectModel = searchParams.get('model') || '';
const [parameters, setParams] = useState<any>({});
const [show, setShow] = useState(false);
const [loading, setLoading] = useState(false);
const [tokenResult, setTokenResult] = useState<any>(null);
const [collapse, setCollapse] = useState(false);
const scroller = useRef<any>(null);
const paramsRef = useRef<any>(null);
const messageListLengthCache = useRef<number>(0);
const requestToken = useRef<any>(null);
const [currentPrompt, setCurrentPrompt] = useState<string>('');
const form = useRef<any>(null);
const inputRef = useRef<any>(null);
const size = Form.useWatch('size', form.current?.form);
const { initialize, updateScrollerPosition } = useOverlayScroller();
const { initialize: innitializeParams } = useOverlayScroller();
useImperativeHandle(ref, () => {
return {
viewCode() {
setShow(true);
},
setCollapse() {
setCollapse(!collapse);
},
collapse: collapse
};
});
const generateNumber = (min: number, max: number) => {
return Math.floor(Math.random() * (max - min + 1) + min);
};
const handleRandomPrompt = useCallback(() => {
const randomIndex = generateNumber(0, promptList.length - 1);
const randomPrompt = promptList[randomIndex];
inputRef.current?.handleInputChange({
target: {
value: randomPrompt
}
});
}, []);
const setImageSize = useCallback(() => {
let size: Record<string, string | number> = {
span: 12
};
if (parameters.n === 1) {
size.span = 24;
}
if (parameters.n === 2) {
size.span = 12;
}
if (parameters.n === 3) {
size.span = 12;
}
if (parameters.n === 4) {
size.span = 12;
}
return size;
}, [parameters.n]);
const finalParameters = useMemo(() => {
if (parameters.size === 'custom') {
return {
..._.omit(parameters, ['width', 'height']),
size:
parameters.width && parameters.height
? `${parameters.width}x${parameters.height}`
: ''
};
}
return {
..._.omit(parameters, ['width', 'height', 'random_seed'])
};
}, [parameters]);
const viewCodeContent = useMemo(() => {
if (isOpenaiCompatible) {
return generateOpenaiImageCode({
api: '/v1-openai/images/generations',
parameters: {
...finalParameters,
prompt: currentPrompt
}
});
}
return generateImageCode({
api: '/v1-openai/images/generations',
parameters: {
...finalParameters,
prompt: currentPrompt
}
});
}, [finalParameters, currentPrompt, parameters.size]);
const setMessageId = () => {
messageId.current = messageId.current + 1;
return messageId.current;
};
const handleStopConversation = () => {
requestToken.current?.abort?.();
setLoading(false);
};
const submitMessage = async (current?: { content: string }) => {
try {
await form.current?.form?.validateFields();
if (!parameters.model) return;
const size: any = setImageSize();
setLoading(true);
setMessageId();
setTokenResult(null);
setCurrentPrompt(current?.content || '');
const imgSize = _.split(finalParameters.size, 'x');
let newImageList = Array(parameters.n)
.fill({})
.map((item, index: number) => {
return {
dataUrl: 'data:image/png;base64,',
...size,
progress: 0,
height: imgSize[1],
width: imgSize[0],
loading: true,
uid: setMessageId()
};
});
setImageList(newImageList);
requestToken.current?.abort?.();
requestToken.current = new AbortController();
const params = {
..._.omitBy(finalParameters, (value: string) => !value),
seed: parameters.random_seed ? generateRandomNumber() : parameters.seed,
stream: true,
stream_options: {
chunk_size: 16 * 1024,
chunk_results: true
},
prompt: current?.content || currentPrompt || ''
};
setParams({
...parameters,
seed: params.seed
});
form.current?.form?.setFieldValue('seed', params.seed);
const result: any = await fetchChunkedData({
data: params,
url: `${CREAT_IMAGE_API}?t=${Date.now()}`,
signal: requestToken.current.signal
});
if (result.error) {
setTokenResult({
error: true,
errorMessage:
result?.data?.error?.message || result?.data?.error || ''
});
setImageList([]);
return;
}
const { reader, decoder } = result;
await readStreamData(reader, decoder, (chunk: any) => {
if (chunk?.error) {
setTokenResult({
error: true,
errorMessage: chunk?.error?.message || chunk?.message || ''
});
return;
}
chunk?.data?.forEach((item: any) => {
const imgItem = newImageList[item.index];
if (item.b64_json) {
imgItem.dataUrl += item.b64_json;
}
const progress = _.round(item.progress, 0);
newImageList[item.index] = {
dataUrl: imgItem.dataUrl,
height: imgSize[1],
width: imgSize[0],
maxHeight: `${imgSize[1]}px`,
maxWidth: `${imgSize[0]}px`,
uid: imgItem.uid,
span: imgItem.span,
loading: progress < 100,
progress: progress
};
});
setImageList([...newImageList]);
});
} catch (error) {
console.log('error:', error);
requestToken.current?.abort?.();
setImageList([]);
} finally {
setLoading(false);
}
};
const handleClear = () => {
setMessageId();
setImageList([]);
setTokenResult(null);
};
const handleInputChange = (e: any) => {
setCurrentPrompt(e.target.value);
};
const handleSendMessage = (message: Omit<MessageItem, 'uid'>) => {
const currentMessage = message.content ? message : undefined;
submitMessage(currentMessage);
};
const handleCloseViewCode = () => {
setShow(false);
};
const handleToggleParamsStyle = () => {
if (isOpenaiCompatible) {
form.current?.form?.setFieldsValue({
...advancedFieldsDefaultValus
});
setParams((pre: object) => {
return {
..._.omit(pre, _.keys(openaiCompatibleFieldsDefaultValus)),
...advancedFieldsDefaultValus
};
});
} else {
form.current?.form?.setFieldsValue({
...openaiCompatibleFieldsDefaultValus
});
setParams((pre: object) => {
return {
...openaiCompatibleFieldsDefaultValus,
..._.omit(pre, _.keys(advancedFieldsDefaultValus))
};
});
}
setIsOpenaiCompatible(!isOpenaiCompatible);
};
const renderExtra = useMemo(() => {
if (!isOpenaiCompatible) {
return [];
}
return ImageconstExtraConfig.map((item: ParamsSchema) => {
return (
<Form.Item name={item.name} rules={item.rules} key={item.name}>
<SealSelect
{...item.attrs}
options={item.options}
label={
item.label.isLocalized
? intl.formatMessage({ id: item.label.text })
: item.label.text
}
></SealSelect>
</Form.Item>
);
});
}, [ImageconstExtraConfig, isOpenaiCompatible, intl]);
const handleFieldChange = (e: any) => {
if (e.target.id.indexOf('random_seed') > -1) {
form.current?.form?.setFieldValue('random_seed', e.target.checked);
setParams((pre: object) => {
return {
...pre,
random_seed: e.target.checked
};
});
}
};
const renderAdvanced = useMemo(() => {
if (isOpenaiCompatible) {
return [];
}
const formValues = form.current?.form?.getFieldsValue();
return ImageAdvancedParamsConfig.map((item: ParamsSchema) => {
return (
<Form.Item
name={item.name}
rules={item.rules}
key={item.name}
noStyle={item.name === 'random_seed'}
>
<FieldComponent
style={item.name === 'random_seed' ? { marginBottom: 20 } : {}}
disabled={
item.disabledConfig
? item.disabledConfig?.when?.(formValues)
: item.disabled
}
onChange={item.name === 'random_seed' ? handleFieldChange : null}
{..._.omit(item, ['name', 'rules', 'disabledConfig'])}
></FieldComponent>
</Form.Item>
);
});
}, [ImageAdvancedParamsConfig, isOpenaiCompatible, intl, form.current]);
const renderCustomSize = useMemo(() => {
if (size === 'custom') {
return ImageCustomSizeConfig.map((item: ParamsSchema) => {
return (
<Form.Item
name={item.name}
rules={[
{
message: intl.formatMessage(
{ id: 'common.form.rule.input' },
{ name: intl.formatMessage({ id: item.label.text }) }
),
required: true
}
]}
key={item.name}
>
<FieldComponent
label={
item.label.isLocalized
? intl.formatMessage({ id: item.label.text })
: item.label.text
}
description={
item.description?.isLocalized
? intl.formatMessage({ id: item.description.text })
: item.description?.text
}
{...item.attrs}
{..._.omit(item, [
'name',
'description',
'rules',
'disabledConfig'
])}
></FieldComponent>
</Form.Item>
);
});
}
return null;
}, [size, intl]);
useEffect(() => {
return () => {
requestToken.current?.abort?.();
};
}, []);
useEffect(() => {
if (size === 'custom') {
form.current?.form?.setFieldsValue({
width: 512,
height: 512
});
setParams((pre: object) => {
return {
...pre,
width: 512,
height: 512
};
});
}
}, [size]);
useEffect(() => {
if (scroller.current) {
initialize(scroller.current);
}
}, [scroller.current, initialize]);
useEffect(() => {
if (paramsRef.current) {
innitializeParams(paramsRef.current);
}
}, [paramsRef.current, innitializeParams]);
useEffect(() => {
if (loading) {
updateScrollerPosition();
}
}, [imageList, loading]);
useEffect(() => {
if (imageList.length > messageListLengthCache.current) {
updateScrollerPosition();
}
messageListLengthCache.current = imageList.length;
}, [imageList.length]);
return (
<div className="ground-left-wrapper">
<div className="ground-left">
<div
className="message-list-wrap"
ref={scroller}
style={{ paddingBottom: 16 }}
>
<>
<div className="content" style={{ height: '100%' }}>
<ThumbImg
style={{
padding: 0,
height: '100%',
justifyContent: 'center',
flexDirection: 'column',
flexWrap: 'unset',
alignItems: 'center'
}}
autoBgColor={false}
editable={false}
dataList={imageList}
loading={loading}
responseable={true}
gutter={[8, 16]}
autoSize={true}
></ThumbImg>
{!imageList.length && (
<div className="flex-column font-size-14 flex-center gap-20 justify-center hold-wrapper">
<span>
<FileImageOutlined className="font-size-32 text-secondary" />
</span>
<span>
{intl.formatMessage({ id: 'playground.params.empty.tips' })}
</span>
</div>
)}
</div>
</>
</div>
{tokenResult && (
<div style={{ height: 40 }}>
<AlertInfo
type="danger"
message={tokenResult?.errorMessage}
></AlertInfo>
</div>
)}
<div className="ground-left-footer">
<MessageInput
ref={inputRef}
placeholer={intl.formatMessage({
id: 'playground.input.prompt.holder'
})}
actions={['clear']}
defaultSize={{
minRows: 5,
maxRows: 5
}}
title={
<span className="font-600">
{intl.formatMessage({ id: 'playground.image.prompt' })}
</span>
}
loading={loading}
disabled={!parameters.model}
isEmpty={!imageList.length}
handleSubmit={handleSendMessage}
handleAbortFetch={handleStopConversation}
onInputChange={handleInputChange}
shouldResetMessage={false}
clearAll={handleClear}
/>
</div>
</div>
<div
className={classNames('params-wrapper', {
collapsed: collapse
})}
ref={paramsRef}
>
<div className="box">
<DynamicParams
ref={form}
parametersTitle={
<div className="flex-between flex-center">
<span>
{intl.formatMessage({ id: 'playground.parameters' })}
</span>
<Tooltip
title={intl.formatMessage({
id: 'playground.image.params.custom.tips'
})}
>
<Button
size="middle"
type="text"
icon={<SwapOutlined />}
onClick={handleToggleParamsStyle}
>
{isOpenaiCompatible
? intl.formatMessage({
id: 'playground.image.params.custom'
})
: intl.formatMessage({
id: 'playground.image.params.openai'
})}
</Button>
</Tooltip>
</div>
}
setParams={setParams}
paramsConfig={paramsConfig}
initialValues={initialValues}
params={parameters}
selectedModel={selectModel}
modelList={modelList}
extra={[renderCustomSize, ...renderExtra, ...renderAdvanced]}
/>
</div>
</div>
<ViewCommonCode
open={show}
viewCodeContent={viewCodeContent}
onCancel={handleCloseViewCode}
title={intl.formatMessage({ id: 'playground.viewcode' })}
></ViewCommonCode>
</div>
);
});
export default memo(GroundImages);

@ -128,6 +128,67 @@ export const ImageParamsConfig: ParamsSchema[] = [
}
];
export const ImageEidtParamsConfig: ParamsSchema[] = [
// {
// type: 'Slider',
// name: 'brush_size',
// label: {
// text: 'Brush Size',
// isLocalized: false
// },
// attrs: {
// min: 25,
// max: 80
// },
// rules: [
// {
// required: false
// }
// ]
// },
{
type: 'InputNumber',
name: 'n',
label: {
text: 'playground.params.counts',
isLocalized: true
},
attrs: {
min: 1,
max: 4
},
rules: [
{
required: false
}
]
},
{
type: 'Select',
name: 'size',
options: [
{ label: 'playground.params.custom', value: 'custom', locale: true },
{ label: '512x512', value: '512x512' },
{ label: '768x1024', value: '768x1024' },
{ label: '1024x1024', value: '1024x1024' }
],
description: {
text: 'playground.params.size.description',
html: true,
isLocalized: true
},
label: {
text: 'playground.params.size',
isLocalized: true
},
rules: [
{
required: false
}
]
}
];
export const ImageconstExtraConfig: ParamsSchema[] = [
{
type: 'Select',

@ -2,31 +2,76 @@ import IconFont from '@/components/icon-font';
import breakpoints from '@/config/breakpoints';
import HotKeys from '@/config/hotkeys';
import useWindowResize from '@/hooks/use-window-resize';
import { DiffOutlined, HighlightOutlined } from '@ant-design/icons';
import { PageContainer } from '@ant-design/pro-components';
import { useIntl } from '@umijs/max';
import { Button, Space } from 'antd';
import { Button, Segmented, Space, Tabs, TabsProps } from 'antd';
import classNames from 'classnames';
import _ from 'lodash';
import { useCallback, useEffect, useRef, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { queryModelsList } from './apis';
import GroundImages from './components/ground-images';
import ImageEdit from './components/image-edit';
import './style/play-ground.less';
const TabsValueMap = {
Tab1: 'generate',
Tab2: 'edit'
};
const TextToImages: React.FC = () => {
const intl = useIntl();
const { size } = useWindowResize();
const [activeKey, setActiveKey] = useState(TabsValueMap.Tab1);
const groundTabRef1 = useRef<any>(null);
const groundTabRef2 = useRef<any>(null);
const [modelList, setModelList] = useState<Global.BaseOption<string>[]>([]);
const [loaded, setLoaded] = useState(false);
const optionsList = [
{
label: 'Generate',
value: TabsValueMap.Tab1,
icon: <DiffOutlined />
},
{
label: 'Edit',
value: TabsValueMap.Tab2,
icon: <HighlightOutlined />
}
];
const handleViewCode = useCallback(() => {
groundTabRef1.current?.viewCode?.();
}, []);
if (activeKey === TabsValueMap.Tab1) {
groundTabRef1.current?.viewCode?.();
} else if (activeKey === TabsValueMap.Tab2) {
groundTabRef2.current?.viewCode?.();
}
}, [activeKey]);
const handleToggleCollapse = useCallback(() => {
groundTabRef1.current?.setCollapse?.();
}, []);
if (activeKey === TabsValueMap.Tab1) {
groundTabRef1.current?.setCollapse?.();
return;
}
groundTabRef2.current?.setCollapse?.();
}, [activeKey]);
const items: TabsProps['items'] = [
{
key: TabsValueMap.Tab1,
label: 'Generate',
children: (
<GroundImages ref={groundTabRef1} modelList={modelList}></GroundImages>
)
},
{
key: TabsValueMap.Tab2,
label: 'Edit',
children: <ImageEdit modelList={modelList} ref={groundTabRef2} />
}
];
useEffect(() => {
if (size.width < breakpoints.lg) {
@ -105,7 +150,21 @@ const TextToImages: React.FC = () => {
<PageContainer
ghost
header={{
title: intl.formatMessage({ id: 'menu.playground.text2images' }),
title: (
<div className="flex items-center">
<span className="font-600">
{intl.formatMessage({ id: 'menu.playground.text2images' })}
</span>
{
<Segmented
options={optionsList}
size="middle"
className="m-l-40"
onChange={(key) => setActiveKey(key)}
></Segmented>
}
</div>
),
breadcrumb: {}
}}
extra={renderExtra()}
@ -113,10 +172,7 @@ const TextToImages: React.FC = () => {
>
<div className="play-ground">
<div className="chat">
<GroundImages
ref={groundTabRef1}
modelList={modelList}
></GroundImages>
<Tabs items={items} activeKey={activeKey}></Tabs>
</div>
</div>
</PageContainer>

Loading…
Cancel
Save