diff --git a/config/routes.ts b/config/routes.ts index 5885a09a..35cd2bd2 100644 --- a/config/routes.ts +++ b/config/routes.ts @@ -33,6 +33,14 @@ export default [ icon: 'Comment', component: './playground/speech' }, + { + name: 'text2images', + title: 'Text2Images', + path: '/playground/text-to-images', + key: 'text2images', + icon: 'Comment', + component: './playground/images' + }, { name: 'embedding', title: 'embedding', diff --git a/src/components/auto-image/index.less b/src/components/auto-image/index.less new file mode 100644 index 00000000..9298a78a --- /dev/null +++ b/src/components/auto-image/index.less @@ -0,0 +1,21 @@ +.toolbar-wrapper { + padding: 0 24px; + color: rgba(255, 255, 255, 65%); + font-size: 16px; + background-color: rgba(0, 0, 0, 10%); + border-radius: 100px; +} + +.toolbar-wrapper .anticon { + padding: 12px; + cursor: pointer; +} + +.toolbar-wrapper .anticon[disabled] { + cursor: not-allowed; + opacity: 0.3; +} + +.toolbar-wrapper .anticon:hover { + opacity: 0.3; +} diff --git a/src/components/auto-image/index.tsx b/src/components/auto-image/index.tsx index db8c21a2..059d7848 100644 --- a/src/components/auto-image/index.tsx +++ b/src/components/auto-image/index.tsx @@ -1,6 +1,17 @@ -import { Image as AntImage, ImageProps } from 'antd'; +import { + DownloadOutlined, + EyeOutlined, + RotateLeftOutlined, + RotateRightOutlined, + SwapOutlined, + UndoOutlined, + ZoomInOutlined, + ZoomOutOutlined +} from '@ant-design/icons'; +import { Image as AntImage, ImageProps, Space } from 'antd'; import { round } from 'lodash'; import React, { useCallback, useEffect, useState } from 'react'; +import './index.less'; const AutoImage: React.FC = (props) => { const { height = 100, ...rest } = props; @@ -24,11 +35,58 @@ const AutoImage: React.FC = (props) => { setWidth(height * ratio); }, [getImgRatio, height, props.src]); + const onDownload = () => { + const url = props.src || ''; + const filename = Date.now() + ''; + + const link = document.createElement('a'); + link.href = url; + link.download = filename; + document.body.appendChild(link); + link.click(); + link.remove(); + }; + useEffect(() => { handleOnLoad(); }, [handleOnLoad]); - return ; + return ( + , + toolbarRender: ( + _, + { + transform: { scale }, + actions: { + onFlipY, + onFlipX, + onRotateLeft, + onRotateRight, + onZoomOut, + onZoomIn, + onReset + } + } + ) => ( + + + + + + + + + + + ) + }} + /> + ); }; export default AutoImage; diff --git a/src/locales/en-US/menu.ts b/src/locales/en-US/menu.ts index 4e48291f..dad668bb 100644 --- a/src/locales/en-US/menu.ts +++ b/src/locales/en-US/menu.ts @@ -5,6 +5,7 @@ export default { 'menu.playground.embedding': 'Embedding', 'menu.playground.chat': 'Chat', 'menu.playground.speech': 'Speech', + 'menu.playground.text2images': 'Text to Images', 'menu.compare': 'Compare', 'menu.models': 'Models', 'menu.resources': 'Resources', diff --git a/src/locales/zh-CN/menu.ts b/src/locales/zh-CN/menu.ts index a9141b61..735195d7 100644 --- a/src/locales/zh-CN/menu.ts +++ b/src/locales/zh-CN/menu.ts @@ -5,6 +5,7 @@ export default { 'menu.playground.embedding': '文本嵌入', 'menu.playground.chat': '对话', 'menu.playground.speech': '语音', + 'menu.playground.text2images': '文生图', 'menu.compare': '多模型对比', 'menu.models': '模型', 'menu.resources': '资源', diff --git a/src/pages/playground/apis/index.ts b/src/pages/playground/apis/index.ts index 8b4703c5..435da7d0 100644 --- a/src/pages/playground/apis/index.ts +++ b/src/pages/playground/apis/index.ts @@ -2,6 +2,8 @@ import { request } from '@umijs/max'; export const CHAT_API = '/v1-openai/chat/completions'; +export const CREAT_IMAGE_API = '/v1-openai/images/generations'; + export const EMBEDDING_API = '/v1-openai/embeddings'; export const OPENAI_MODELS = '/v1-openai/models'; @@ -51,3 +53,19 @@ export const handleEmbedding = async ( cancelToken: options?.cancelToken }); }; + +export const createImages = async ( + params: { + model: string; + prompt: string; + n: number; + size: string; + }, + options?: any +) => { + return request(`${CREAT_IMAGE_API}`, { + method: 'POST', + data: params, + cancelToken: options?.cancelToken + }); +}; diff --git a/src/pages/playground/components/ground-images.tsx b/src/pages/playground/components/ground-images.tsx new file mode 100644 index 00000000..5596d93a --- /dev/null +++ b/src/pages/playground/components/ground-images.tsx @@ -0,0 +1,310 @@ +import useOverlayScroller from '@/hooks/use-overlay-scroller'; +import useRequestToken from '@/hooks/use-request-token'; +import { useIntl, useSearchParams } from '@umijs/max'; +import { Spin } from 'antd'; +import classNames from 'classnames'; +import _ from 'lodash'; +import 'overlayscrollbars/overlayscrollbars.css'; +import { + forwardRef, + memo, + useEffect, + useImperativeHandle, + useMemo, + useRef, + useState +} from 'react'; +import { createImages } from '../apis'; +import { Roles, generateMessages } from '../config'; +import { ImageParamsConfig as paramsConfig } from '../config/params-config'; +import { MessageItem } from '../config/types'; +import '../style/ground-left.less'; +import '../style/system-message-wrap.less'; +import MessageInput from './message-input'; +import MessageContent from './multiple-chat/message-content'; +import ReferenceParams from './reference-params'; +import RerankerParams from './reranker-params'; +import ViewCodeModal from './view-code-modal'; + +interface MessageProps { + modelList: Global.BaseOption[]; + loaded?: boolean; + ref?: any; +} + +const initialValues = { + n: 1, + size: '512x512', + quality: 'standard', + style: 'vivid' +}; + +const GroundImages: React.FC = forwardRef((props, ref) => { + const { modelList } = props; + const messageId = useRef(0); + const [messageList, setMessageList] = useState([]); + + const intl = useIntl(); + const requestSource = useRequestToken(); + const [searchParams] = useSearchParams(); + const selectModel = searchParams.get('model') || ''; + const [parameters, setParams] = useState({}); + const [systemMessage, setSystemMessage] = useState(''); + const [show, setShow] = useState(false); + const [loading, setLoading] = useState(false); + const [tokenResult, setTokenResult] = useState(null); + const [collapse, setCollapse] = useState(false); + const contentRef = useRef(''); + const scroller = useRef(null); + const currentMessageRef = useRef(null); + const paramsRef = useRef(null); + const messageListLengthCache = useRef(0); + const requestToken = useRef(null); + + const { initialize, updateScrollerPosition } = useOverlayScroller(); + const { initialize: innitializeParams } = useOverlayScroller(); + + useImperativeHandle(ref, () => { + return { + viewCode() { + setShow(true); + }, + setCollapse() { + setCollapse(!collapse); + }, + collapse: collapse + }; + }); + + const viewCodeMessage = useMemo(() => { + return generateMessages([ + { role: Roles.System, content: systemMessage }, + ...messageList + ]); + }, [messageList, systemMessage]); + + const setMessageId = () => { + messageId.current = messageId.current + 1; + return messageId.current; + }; + + const handleNewMessage = (message?: { role: string; content: string }) => { + const newMessage = message || { + role: + _.last(messageList)?.role === Roles.User ? Roles.Assistant : Roles.User, + content: '' + }; + messageList.push({ + ...newMessage, + uid: messageId.current + 1 + }); + setMessageId(); + setMessageList([...messageList]); + }; + + const handleStopConversation = () => { + requestToken.current?.cancel?.(); + setLoading(false); + }; + + const submitMessage = async (current?: { role: string; content: string }) => { + if (!parameters.model) return; + try { + setLoading(true); + setMessageId(); + + requestToken.current?.cancel?.(); + requestToken.current = requestSource(); + + currentMessageRef.current = current + ? [ + { + ...current, + uid: messageId.current + } + ] + : []; + + contentRef.current = ''; + setMessageList((pre) => { + return [...pre, ...currentMessageRef.current]; + }); + + const params = { + prompt: current?.content || '', + ...parameters + }; + + const result = await createImages(params, { + cancelToken: requestToken.current.token + }); + + const imgList = _.map(result.data, (item: any, index: number) => { + return { + dataUrl: `data:image/png;base64,${item.b64_json}`, + created: result.created, + uid: index + }; + }); + setMessageList((pre) => { + return [ + ...pre, + { + content: '', + role: Roles.Assistant, + imgs: imgList, + uid: messageId.current + } + ]; + }); + console.log('result:', imgList); + + setMessageId(); + } catch (error) { + // console.log('error:', error); + } finally { + setLoading(false); + } + }; + const handleClear = () => { + if (!messageList.length) { + return; + } + setMessageId(); + setMessageList([]); + setTokenResult(null); + }; + + const handleSendMessage = (message: Omit) => { + console.log('message:', message); + const currentMessage = + message.content || message.imgs?.length ? message : undefined; + submitMessage(currentMessage); + }; + + const handleCloseViewCode = () => { + setShow(false); + }; + + const handleSelectModel = () => {}; + + const handlePresetPrompt = (list: { role: string; content: string }[]) => { + const sysMsg = list.filter((item) => item.role === 'system'); + const userMsg = list + .filter((item) => item.role === 'user') + .map((item) => { + setMessageId(); + return { + ...item, + uid: messageId.current + }; + }); + setSystemMessage(sysMsg[0]?.content || ''); + setMessageList(userMsg); + }; + + useEffect(() => { + if (scroller.current) { + initialize(scroller.current); + } + }, [scroller.current, initialize]); + + useEffect(() => { + if (paramsRef.current) { + innitializeParams(paramsRef.current); + } + }, [paramsRef.current, innitializeParams]); + + useEffect(() => { + if (loading) { + updateScrollerPosition(); + } + }, [messageList, loading]); + + useEffect(() => { + if (messageList.length > messageListLengthCache.current) { + updateScrollerPosition(); + } + messageListLengthCache.current = messageList.length; + }, [messageList.length]); + + return ( +
+
+
+ <> +
+ + {loading && ( + +
+
+ )} +
+ +
+ {tokenResult && ( +
+ +
+ )} +
+ +
+
+
+
+ +
+
+ + +
+ ); +}); + +export default memo(GroundImages); diff --git a/src/pages/playground/components/ground-reranker.tsx b/src/pages/playground/components/ground-reranker.tsx index f7e24304..fff6745e 100644 --- a/src/pages/playground/components/ground-reranker.tsx +++ b/src/pages/playground/components/ground-reranker.tsx @@ -281,6 +281,7 @@ const GroundReranker: React.FC = forwardRef((props, ref) => {
} loading={loading} disabled={!parameters.model} diff --git a/src/pages/playground/components/ground-tts.tsx b/src/pages/playground/components/ground-tts.tsx index bce84aa1..edf99753 100644 --- a/src/pages/playground/components/ground-tts.tsx +++ b/src/pages/playground/components/ground-tts.tsx @@ -33,6 +33,12 @@ interface MessageProps { ref?: any; } +const initialValues = { + voice: 'Alloy', + response_format: 'mp3', + speed: 1 +}; + const GroundLeft: React.FC = forwardRef((props, ref) => { const { modelList } = props; const messageId = useRef(0); @@ -57,12 +63,6 @@ const GroundLeft: React.FC = forwardRef((props, ref) => { const { initialize, updateScrollerPosition } = useOverlayScroller(); const { initialize: innitializeParams } = useOverlayScroller(); - const initialValues = { - voice: 'Alloy', - response_format: 'mp3', - speed: 1 - }; - useImperativeHandle(ref, () => { return { viewCode() { diff --git a/src/pages/playground/components/message-input.tsx b/src/pages/playground/components/message-input.tsx index 6426fea4..64e964e9 100644 --- a/src/pages/playground/components/message-input.tsx +++ b/src/pages/playground/components/message-input.tsx @@ -15,6 +15,8 @@ import UploadImg from './upload-img'; type CurrentMessage = Omit; +type ActionType = 'clear' | 'layout' | 'role' | 'upload' | 'add' | 'paste'; + const layoutOptions = [ { label: '2 columns', @@ -77,6 +79,7 @@ interface MessageInputProps { placeholer?: string; shouldResetMessage?: boolean; style?: React.CSSProperties; + actions?: ActionType[]; } const MessageInput: React.FC = ({ @@ -97,7 +100,8 @@ const MessageInput: React.FC = ({ placeholer, tools, style, - shouldResetMessage = true + shouldResetMessage = true, + actions = ['clear', 'layout', 'role', 'upload', 'add', 'paste'] }) => { const { TextArea } = Input; const intl = useIntl(); @@ -256,25 +260,8 @@ const MessageInput: React.FC = ({ }; const handleOnPaste = (e: any) => { - // e.preventDefault(); - const text = e.clipboardData.getData('text'); - if (text) { - // const startPos = e.target.selectionStart; - // const endPos = e.target.selectionEnd; - // setMessage?.({ - // ...message, - // content: - // message.content.slice(0, startPos) + - // text + - // message.content.slice(endPos) - // }); - // if (endPos !== startPos) { - // setTimeout(() => { - // e.target.setSelectionRange(endPos, endPos); - // }, 0); - // } - } else { + if (!text) { e.preventDefault(); getPasteContent(e); } @@ -352,26 +339,30 @@ const MessageInput: React.FC = ({
{tools} - {scope !== 'reranker' && ( + { <> - - - {message.role === Roles.User && ( + {actions.includes('role') && ( + <> + + + + )} + {actions.includes('upload') && message.role === Roles.User && ( )} - )} - {scope !== 'reranker' && ( + } + {actions.includes('clear') && ( @@ -383,7 +374,7 @@ const MessageInput: React.FC = ({ > )} - {updateLayout && ( + {actions.includes('layout') && updateLayout && ( <> {layoutOptions.map((option) => ( @@ -418,7 +409,7 @@ const MessageInput: React.FC = ({ > )} - {scope !== 'reranker' && ( + {actions.includes('add') && ( @@ -471,7 +462,7 @@ const MessageInput: React.FC = ({ onDelete={handleDeleteImg} >
- {scope !== 'reranker' ? ( + {actions.includes('paste') ? (