From 5a3b0497c1f4a2c3eced8a9eaeb4833c4911a6e8 Mon Sep 17 00:00:00 2001 From: Jong Eun Lee Date: Mon, 6 Jan 2025 17:09:44 +0800 Subject: [PATCH] improve image generation model endpoint generation and related UX --- react/src/components/ChatContent.tsx | 10 +++++++- react/src/components/ModelCardChat.tsx | 1 + .../lablupTalkativotUI/ChatUIModal.tsx | 10 +++++++- .../lablupTalkativotUI/LLMChatCard.tsx | 24 +++++++++++++++---- 4 files changed, 39 insertions(+), 6 deletions(-) diff --git a/react/src/components/ChatContent.tsx b/react/src/components/ChatContent.tsx index ff6f8f159c..b66d97b367 100644 --- a/react/src/components/ChatContent.tsx +++ b/react/src/components/ChatContent.tsx @@ -13,12 +13,14 @@ import { useLazyLoadQuery } from 'react-relay/hooks'; interface ChatContentProps { endpointId: string; endpointUrl: string; + endpointName: string; basePath: string; } const ChatContent: React.FC = ({ endpointId, endpointUrl, + endpointName, basePath, }) => { const { t } = useTranslation(); @@ -56,6 +58,7 @@ const ChatContent: React.FC = ({ fetchPolicy: 'network-only', }, ); + const isTextToImageModel = _.includes(endpointName, 'stable-diffusion'); const newestValidToken = _.orderBy(endpoint_token_list?.items, ['valid_until'], ['desc'])[0] @@ -85,7 +88,12 @@ const ChatContent: React.FC = ({ return ( ({ id: m.id, name: m.id, diff --git a/react/src/components/ModelCardChat.tsx b/react/src/components/ModelCardChat.tsx index cd4728e443..5b1d6c6600 100644 --- a/react/src/components/ModelCardChat.tsx +++ b/react/src/components/ModelCardChat.tsx @@ -67,6 +67,7 @@ const ModelCardChat: React.FC = ({ ) : ( diff --git a/react/src/components/lablupTalkativotUI/ChatUIModal.tsx b/react/src/components/lablupTalkativotUI/ChatUIModal.tsx index 48d897b601..154aad9c6c 100644 --- a/react/src/components/lablupTalkativotUI/ChatUIModal.tsx +++ b/react/src/components/lablupTalkativotUI/ChatUIModal.tsx @@ -71,6 +71,7 @@ const EndpointChatContent: React.FC = ({ graphql` fragment 
ChatUIModalFragment on Endpoint { endpoint_id + name url status } @@ -90,6 +91,7 @@ const EndpointChatContent: React.FC = ({ `, endpointTokenFrgmt, ); + const isTextToImageModel = _.includes(endpoint?.name, 'stable-diffusion'); const newestToken = _.maxBy( endpointTokenList?.items, @@ -124,7 +126,12 @@ const EndpointChatContent: React.FC = ({ ) : ( ({ id: m.id, name: m.id, @@ -133,6 +140,7 @@ const EndpointChatContent: React.FC = ({ style={{ flex: 1 }} allowCustomModel={_.isEmpty(modelsResult?.data)} alert={ + !isTextToImageModel && _.isEmpty(modelsResult?.data) && ( = ({ setLoadingImageGeneration(true); try { const response = await fetch( - 'https://stable-diffusion-3m.asia03.app.backend.ai/generate-image', + customModelFormRef.current?.getFieldValue('baseURL'), { method: 'POST', headers: { @@ -244,7 +244,9 @@ const LLMChatCard: React.FC = ({ ); if (response.ok) { const responseData = await response.json(); - return 'data:image/png;base64,' + responseData.image_base64; + return _.startsWith(responseData.image_base64, 'data:image/png;base64,') + ? responseData.image_base64 + : 'data:image/png;base64,' + responseData.image_base64; } else { throw new Error('Error generating image'); } @@ -414,8 +416,8 @@ const LLMChatCard: React.FC = ({ }); if (isImageGeneration) { + const generationId = _.uniqueId(); try { - const imageBase64 = await generateImage(input, 'accessKey'); setMessages((prevMessages) => [ ...prevMessages, { @@ -424,7 +426,20 @@ const LLMChatCard: React.FC = ({ content: input, }, { - id: _.uniqueId(), + id: generationId, + role: 'assistant', + content: 'Processing...', + }, + ]); + setInput(''); + const imageBase64 = await generateImage(input, 'accessKey'); + setMessages((prevMessages) => [ + ..._.filter( + prevMessages, + (message) => message.id !== generationId, + ), + { + id: generationId, role: 'assistant', content: '', experimental_attachments: [ @@ -510,6 +525,7 @@ const LLMChatCard: React.FC = ({ required: true, }, ]} + hidden={isImageGeneration} >