Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: call generate image API if isImageGeneration is true. #3001

Draft
wants to merge 4 commits into
base: feature/model-player-UI-updated
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion react/src/components/ChatContent.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import { useLazyLoadQuery } from 'react-relay/hooks';
// Props for ChatContent: identifies the target inference endpoint and the
// path used when constructing the chat request base URL.
interface ChatContentProps {
// Endpoint identifier forwarded to LLMChatCard.
endpointId: string;
// Base URL of the deployed model service endpoint.
endpointUrl: string;
// Endpoint name; inspected to detect text-to-image models
// (e.g. names containing 'stable-diffusion').
endpointName: string;
// Path resolved against endpointUrl for chat requests
// (replaced by '/generate-image' for text-to-image endpoints).
basePath: string;
}

const ChatContent: React.FC<ChatContentProps> = ({
endpointId,
endpointUrl,
endpointName,
basePath,
}) => {
const { t } = useTranslation();
Expand Down Expand Up @@ -56,6 +58,7 @@ const ChatContent: React.FC<ChatContentProps> = ({
fetchPolicy: 'network-only',
},
);
const isTextToImageModel = _.includes(endpointName, 'stable-diffusion');

const newestValidToken =
_.orderBy(endpoint_token_list?.items, ['valid_until'], ['desc'])[0]
Expand Down Expand Up @@ -85,7 +88,14 @@ const ChatContent: React.FC<ChatContentProps> = ({
return (
<LLMChatCard
endpointId={endpointId || ''}
baseURL={new URL(basePath, endpointUrl).toString()}
baseURL={
endpointUrl
? isTextToImageModel
? new URL('/generate-image', endpointUrl || '').toString()
: new URL(basePath, endpointUrl || '').toString()
: ''
}
isImageGeneration={isTextToImageModel}
models={_.map(modelsResult?.data, (m) => ({
id: m.id,
name: m.id,
Expand Down
1 change: 1 addition & 0 deletions react/src/components/ModelCardChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const ModelCardChat: React.FC<ModelCardChatProps> = ({
<ChatContent
endpointId={healthyEndpoint[0]?.endpoint_id as string}
endpointUrl={healthyEndpoint[0]?.url as string}
endpointName={healthyEndpoint[0]?.name as string}
basePath={basePath}
/>
) : (
Expand Down
3 changes: 2 additions & 1 deletion react/src/components/ModelCardModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ const ModelCardModal: React.FC<ModelCardModalProps> = ({
wrap="wrap"
align="stretch"
gap={'sm'}
style={{ width: '100%' }}
style={{ width: '100%', minHeight: '50vh' }}
>
<Flex
direction="row"
Expand Down Expand Up @@ -200,6 +200,7 @@ const ModelCardModal: React.FC<ModelCardModalProps> = ({
>
<ModelTryContent
modelStorageHost={model_card?.vfolder?.host as string}
modelStoreName={model_card?.vfolder?.name as string}
modelName={model_card?.name as string}
minAIAcclResource={(() => {
const minResource = _.toNumber(
Expand Down
20 changes: 13 additions & 7 deletions react/src/components/ModelTryContent.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import { useTranslation } from 'react-i18next';

interface ModelTryContentProps {
modelStorageHost?: string;
modelStoreName?: string;
modelName?: string;
minAIAcclResource: number;
title?: string;
Expand All @@ -36,6 +37,7 @@ interface ModelTryContentProps {
const ModelTryContent: React.FC<ModelTryContentProps> = ({
modelName,
modelStorageHost,
modelStoreName,
minAIAcclResource,
title,
...props
Expand Down Expand Up @@ -220,7 +222,7 @@ const ModelTryContent: React.FC<ModelTryContentProps> = ({
image: {
registry: 'cr.backend.ai',
name: (() => {
if (modelName?.includes('stable-diffusion')) {
if (['stable-diffusion', 'phi-4'].includes(modelName as string)) {
return 'testing/ngc-pytorch';
}
switch (runtimeVariant) {
Expand All @@ -233,7 +235,7 @@ const ModelTryContent: React.FC<ModelTryContentProps> = ({
}
})(),
tag: (() => {
if (modelName?.includes('stable-diffusion')) {
if (['stable-diffusion', 'phi-4'].includes(modelName as string)) {
return '24.07-pytorch2.4-py310-cuda12.5';
}
switch (runtimeVariant) {
Expand Down Expand Up @@ -264,7 +266,9 @@ const ModelTryContent: React.FC<ModelTryContentProps> = ({
version: '',
},
// FIXME: temporally hard-coded runtime variant
runtimeVariant: modelName?.includes('stable-diffusion')
runtimeVariant: ['stable-diffusion', 'phi-4'].includes(
modelName as string,
)
? 'custom'
: runtimeVariant,
cluster_size: 1,
Expand Down Expand Up @@ -308,10 +312,10 @@ const ModelTryContent: React.FC<ModelTryContentProps> = ({
cloneable: true,
permission: 'wd', // write-delete permission
target_host: modelStorageHost, // lowestUsageHost, // clone to accessible and lowest usage storage host
target_name: `${modelName === 'Talkativot UI' ? 'talkativot-standalone-1' : modelName}`,
target_name: `${modelName === 'Talkativot UI' ? 'talkativot-standalone-1' : modelName + '-1'}`,
usage_mode: 'model',
},
name: `${modelName === 'Talkativot UI' ? 'talkativot-standalone' : modelName}`,
name: `${modelName === 'Talkativot UI' ? 'talkativot-standalone' : modelStoreName}`,
},
{
onSuccess: (data) => {
Expand Down Expand Up @@ -536,7 +540,8 @@ const ModelTryContent: React.FC<ModelTryContentProps> = ({
type="primary"
disabled={
modelName?.includes('stable-diffusion') ||
modelName?.includes('Talkativot UI')
modelName?.includes('Talkativot UI') ||
modelName?.includes('phi-4')
}
onClick={() => {
cloneOrCreateModelService('vllm');
Expand All @@ -553,7 +558,8 @@ const ModelTryContent: React.FC<ModelTryContentProps> = ({
modelName?.includes('stable-diffusion') ||
modelName?.includes('gemma-2-27b-it') ||
modelName?.includes('Llama-3.2-11B-Vision-Instruct') ||
modelName?.includes('Talkativot UI')
modelName?.includes('Talkativot UI') ||
modelName?.includes('phi-4')
}
type="primary"
onClick={() => {
Expand Down
4 changes: 2 additions & 2 deletions react/src/components/lablupTalkativotUI/ChatMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ const ChatMessage: React.FC<{
src={attachment?.url}
alt={attachment?.name}
style={{
maxWidth: '50vw',
maxHeight: '12vh',
maxWidth: placement === 'left' ? 200 : 300,
maxHeight: placement === 'left' ? 200 : 300,
borderRadius: token.borderRadius,
}}
/>
Expand Down
12 changes: 11 additions & 1 deletion react/src/components/lablupTalkativotUI/ChatUIModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
graphql`
fragment ChatUIModalFragment on Endpoint {
endpoint_id
name
url
status
}
Expand All @@ -90,6 +91,7 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
`,
endpointTokenFrgmt,
);
const isTextToImageModel = _.includes(endpoint?.name, 'stable-diffusion');

const newestToken = _.maxBy(
endpointTokenList?.items,
Expand Down Expand Up @@ -124,7 +126,14 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
) : (
<LLMChatCard
endpointId={endpoint?.endpoint_id || ''}
baseURL={new URL(basePath, endpoint?.url || '').toString()}
baseURL={
endpoint?.url
? isTextToImageModel
? new URL('/generate-image', endpoint?.url || '').toString()
: new URL(basePath, endpoint?.url || '').toString()
: ''
}
isImageGeneration={isTextToImageModel}
models={_.map(modelsResult?.data, (m) => ({
id: m.id,
name: m.id,
Expand All @@ -133,6 +142,7 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
style={{ flex: 1 }}
allowCustomModel={_.isEmpty(modelsResult?.data)}
alert={
!isTextToImageModel &&
_.isEmpty(modelsResult?.data) && (
<Alert
type="warning"
Expand Down
13 changes: 10 additions & 3 deletions react/src/components/lablupTalkativotUI/EndpointLLMChatCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ const EndpointLLMChatCard: React.FC<EndpointLLMChatCardProps> = ({
fragment EndpointLLMChatCard_endpoint on Endpoint {
endpoint_id
url
name
}
`,
endpointFrgmt,
Expand All @@ -63,6 +64,8 @@ const EndpointLLMChatCard: React.FC<EndpointLLMChatCardProps> = ({
chatSubmitKeyInfoState,
);

const isTextToImageModel = _.includes(endpoint?.name, 'stable-diffusion');

const { data: modelsResult } = useSuspenseTanQuery<{
data: Array<Model>;
}>({
Expand Down Expand Up @@ -93,9 +96,12 @@ const EndpointLLMChatCard: React.FC<EndpointLLMChatCardProps> = ({
{...cardProps}
baseURL={
endpoint?.url
? new URL(basePath, endpoint?.url ?? undefined).toString()
: undefined
? isTextToImageModel
? new URL('/generate-image', endpoint?.url).toString()
: new URL(basePath, endpoint?.url).toString()
: ''
}
isImageGeneration={isTextToImageModel}
models={models}
fetchOnClient
leftExtra={
Expand Down Expand Up @@ -147,7 +153,8 @@ const EndpointLLMChatCard: React.FC<EndpointLLMChatCardProps> = ({
}
allowCustomModel={_.isEmpty(models)}
alert={
_.isEmpty(models) && (
!isTextToImageModel &&
_.isEmpty(modelsResult?.data) && (
<Alert
type="warning"
showIcon
Expand Down
Loading
Loading