Skip to content

Commit

Permalink
feat: call generate image API if isImageGeneration is true.
Browse files Browse the repository at this point in the history
  • Loading branch information
agatha197 committed Jan 6, 2025
1 parent 6dabc24 commit 420fab9
Showing 1 changed file with 75 additions and 11 deletions.
86 changes: 75 additions & 11 deletions react/src/components/lablupTalkativotUI/LLMChatCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
DeleteOutlined,
LinkOutlined,
MoreOutlined,
PictureOutlined,
RocketOutlined,
} from '@ant-design/icons';
import { Attachments, AttachmentsProps, Sender } from '@ant-design/x';
Expand All @@ -32,6 +33,7 @@ import {
MenuProps,
Tag,
theme,
Tooltip,
Typography,
} from 'antd';
import _ from 'lodash';
Expand Down Expand Up @@ -70,6 +72,7 @@ export interface LLMChatCardProps extends CardProps {
onSubmitChange?: () => void;
showCompareMenuItem?: boolean;
modelToken?: string;
isImageGeneration?: boolean;
}

const LLMChatCard: React.FC<LLMChatCardProps> = ({
Expand All @@ -89,10 +92,12 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
onSubmitChange,
showCompareMenuItem,
modelToken,
isImageGeneration,
...cardProps
}) => {
const webuiNavigate = useWebUINavigate();
const [isOpenAttachments, setIsOpenAttachments] = useState(false);
const [loadingImageGeneration, setLoadingImageGeneration] = useState(false);

const [modelId, setModelId] = useControllableValue(cardProps, {
valuePropName: 'modelId',
Expand Down Expand Up @@ -221,6 +226,33 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
},
]);

const generateImage = async (prompt: string, accessKey: string) => {
setLoadingImageGeneration(true);
try {
const response = await fetch(
'https://stable-diffusion-3m.asia03.app.backend.ai/generate-image',
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
prompt: prompt,
access_key: accessKey,
}),
},
);
if (response.ok) {
const responseData = await response.json();
return 'data:image/png;base64,' + responseData.image_base64;
} else {
throw new Error('Error generating image');
}
} finally {
setLoadingImageGeneration(false);
}
};

return (
<Card
ref={cardRef}
Expand Down Expand Up @@ -326,6 +358,11 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
/>
</Sender.Header>
}
styles={{
prefix: {
alignSelf: 'center',
},
}}
prefix={
<Attachments
beforeUpload={() => false}
Expand Down Expand Up @@ -359,11 +396,11 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
onInputChange(v);
}
}}
loading={isLoading}
loading={isLoading || loadingImageGeneration}
onStop={() => {
stop();
}}
onSend={() => {
onSend={async () => {
if (input || !_.isEmpty(files)) {
const fileList = _.map(
files,
Expand All @@ -376,15 +413,42 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
dataTransfer.items.add(file);
});

append(
{
role: 'user',
content: input,
},
{
experimental_attachments: dataTransfer.files,
},
);
if (isImageGeneration) {
try {
const imageBase64 = await generateImage(input, 'accessKey');
setMessages((prevMessages) => [
...prevMessages,
{
id: _.uniqueId(),
role: 'user',
content: input,
},
{
id: _.uniqueId(),
role: 'assistant',
content: '',
experimental_attachments: [
{
contentType: 'image/png',
url: imageBase64,
},
],
},
]);
} catch (error) {
console.error(error);
}
} else {
append(
{
role: 'user',
content: input,
},
{
experimental_attachments: dataTransfer.files,
},
);
}

setTimeout(() => {
setInput('');
Expand Down

0 comments on commit 420fab9

Please sign in to comment.