Skip to content

Commit

Permalink
improve image generation model endpoint generation and related UX
Browse files Browse the repository at this point in the history
  • Loading branch information
yomybaby authored and lizable committed Jan 7, 2025
1 parent 0feb16a commit 5a3b049
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 6 deletions.
10 changes: 9 additions & 1 deletion react/src/components/ChatContent.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import { useLazyLoadQuery } from 'react-relay/hooks';
interface ChatContentProps {
endpointId: string;
endpointUrl: string;
endpointName: string;
basePath: string;
}

const ChatContent: React.FC<ChatContentProps> = ({
endpointId,
endpointUrl,
endpointName,
basePath,
}) => {
const { t } = useTranslation();
Expand Down Expand Up @@ -56,6 +58,7 @@ const ChatContent: React.FC<ChatContentProps> = ({
fetchPolicy: 'network-only',
},
);
const isTextToImageModel = _.includes(endpointName, 'stable-diffusion');

const newestValidToken =
_.orderBy(endpoint_token_list?.items, ['valid_until'], ['desc'])[0]
Expand Down Expand Up @@ -85,7 +88,12 @@ const ChatContent: React.FC<ChatContentProps> = ({
return (
<LLMChatCard
endpointId={endpointId || ''}
baseURL={new URL(basePath, endpointUrl).toString()}
baseURL={
isTextToImageModel
? new URL('/generate-image', endpointUrl || '').toString()
: new URL(basePath, endpointUrl || '').toString()
}
isImageGeneration={isTextToImageModel}
models={_.map(modelsResult?.data, (m) => ({
id: m.id,
name: m.id,
Expand Down
1 change: 1 addition & 0 deletions react/src/components/ModelCardChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const ModelCardChat: React.FC<ModelCardChatProps> = ({
<ChatContent
endpointId={healthyEndpoint[0]?.endpoint_id as string}
endpointUrl={healthyEndpoint[0]?.url as string}
endpointName={healthyEndpoint[0]?.name as string}
basePath={basePath}
/>
) : (
Expand Down
10 changes: 9 additions & 1 deletion react/src/components/lablupTalkativotUI/ChatUIModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
graphql`
fragment ChatUIModalFragment on Endpoint {
endpoint_id
name
url
status
}
Expand All @@ -90,6 +91,7 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
`,
endpointTokenFrgmt,
);
const isTextToImageModel = _.includes(endpoint?.name, 'stable-diffusion');

const newestToken = _.maxBy(
endpointTokenList?.items,
Expand Down Expand Up @@ -124,7 +126,12 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
) : (
<LLMChatCard
endpointId={endpoint?.endpoint_id || ''}
baseURL={new URL(basePath, endpoint?.url || '').toString()}
baseURL={
isTextToImageModel
? new URL('/generate-image', endpoint?.url || '').toString()
: new URL(basePath, endpoint?.url || '').toString()
}
isImageGeneration={isTextToImageModel}
models={_.map(modelsResult?.data, (m) => ({
id: m.id,
name: m.id,
Expand All @@ -133,6 +140,7 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
style={{ flex: 1 }}
allowCustomModel={_.isEmpty(modelsResult?.data)}
alert={
!isTextToImageModel &&
_.isEmpty(modelsResult?.data) && (
<Alert
type="warning"
Expand Down
24 changes: 20 additions & 4 deletions react/src/components/lablupTalkativotUI/LLMChatCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
setLoadingImageGeneration(true);
try {
const response = await fetch(
'https://stable-diffusion-3m.asia03.app.backend.ai/generate-image',
customModelFormRef.current?.getFieldValue('baseURL'),
{
method: 'POST',
headers: {
Expand All @@ -244,7 +244,9 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
);
if (response.ok) {
const responseData = await response.json();
return 'data:image/png;base64,' + responseData.image_base64;
return _.startsWith(responseData.image_base64, 'data:image/png;base64,')
? responseData.image_base64
: 'data:image/png;base64,' + responseData.image_base64;
} else {
throw new Error('Error generating image');
}
Expand Down Expand Up @@ -414,8 +416,8 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
});

if (isImageGeneration) {
const generationId = _.uniqueId();
try {
const imageBase64 = await generateImage(input, 'accessKey');
setMessages((prevMessages) => [
...prevMessages,
{
Expand All @@ -424,7 +426,20 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
content: input,
},
{
id: _.uniqueId(),
id: generationId,
role: 'assistant',
content: 'Processing...',
},
]);
setInput('');
const imageBase64 = await generateImage(input, 'accessKey');
setMessages((prevMessages) => [
..._.filter(
prevMessages,
(message) => message.id !== generationId,
),
{
id: generationId,
role: 'assistant',
content: '',
experimental_attachments: [
Expand Down Expand Up @@ -510,6 +525,7 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
required: true,
},
]}
hidden={isImageGeneration}
>
<Input placeholder="llm-model" />
</Form.Item>
Expand Down

0 comments on commit 5a3b049

Please sign in to comment.