From ab939b27b4fcfad4f7c5dbedf15023985e9ec2f3 Mon Sep 17 00:00:00 2001 From: axel7083 <42176370+axel7083@users.noreply.github.com> Date: Tue, 6 Aug 2024 15:59:53 +0200 Subject: [PATCH] fix: remove incompatible models from playground selection (#1488) Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> --- .../src/pages/PlaygroundCreate.spec.ts | 71 ++++++++++++++----- .../src/pages/PlaygroundCreate.svelte | 3 +- 2 files changed, 57 insertions(+), 17 deletions(-) diff --git a/packages/frontend/src/pages/PlaygroundCreate.spec.ts b/packages/frontend/src/pages/PlaygroundCreate.spec.ts index 61c8894a0..bf1fb5d85 100644 --- a/packages/frontend/src/pages/PlaygroundCreate.spec.ts +++ b/packages/frontend/src/pages/PlaygroundCreate.spec.ts @@ -17,7 +17,7 @@ ***********************************************************************/ import '@testing-library/jest-dom/vitest'; -import { render, screen } from '@testing-library/svelte'; +import { render, within } from '@testing-library/svelte'; import { expect, test, vi, beforeEach } from 'vitest'; import { studioClient } from '../utils/client'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; @@ -27,6 +27,31 @@ import * as tasksStore from '/@/stores/tasks'; import * as modelsInfoStore from '/@/stores/modelsInfo'; import type { Task } from '@shared/src/models/ITask'; import PlaygroundCreate from './PlaygroundCreate.svelte'; +import { InferenceType } from '@shared/src/models/IInference'; + +const dummyLlamaCppModel: ModelInfo = { + id: 'llama-cpp-model-id', + name: 'Dummy LlamaCpp model', + file: { + file: 'file', + path: '/tmp/path', + }, + properties: {}, + description: '', + backend: InferenceType.LLAMA_CPP, +}; + +const dummyWhisperCppModel: ModelInfo = { + id: 'whisper-cpp-model-id', + name: 'Dummy Whisper model', + file: { + file: 'file', + path: '/tmp/path', + }, + properties: {}, + description: '', + backend: InferenceType.WHISPER_CPP, +}; vi.mock('../utils/client', async () => { return { @@ -57,34 +82,48 @@ vi.mock('/@/stores/modelsInfo', async () => { beforeEach(() => { window.HTMLElement.prototype.scrollIntoView = vi.fn(); -}); -test('should display error message if createPlayground fails', async () => { const tasksList = writable([]); vi.mocked(tasksStore).tasks = tasksList; +}); - const modelsInfoList = writable([ - { - id: 'id', - file: { - file: 'file', - path: '/tmp/path', - }, - } as unknown as ModelInfo, - ]); +test('model should be selected by default', () => { + const modelsInfoList = writable([dummyLlamaCppModel]); + vi.mocked(modelsInfoStore).modelsInfo = modelsInfoList; + + vi.mocked(studioClient.requestCreatePlayground).mockRejectedValue('error creating playground'); + + const { container } = render(PlaygroundCreate); + + const model = within(container).getByText(dummyLlamaCppModel.name); + expect(model).toBeInTheDocument(); +}); + +test('models with incompatible backend should not be listed', async () => { + const modelsInfoList = writable([dummyWhisperCppModel]); + vi.mocked(modelsInfoStore).modelsInfo = modelsInfoList; + + const { container } = render(PlaygroundCreate); + + const model = within(container).queryByText(dummyWhisperCppModel.name); + expect(model).toBeNull(); +}); + +test('should display error message if createPlayground fails', async () => { + const modelsInfoList = writable([dummyLlamaCppModel]); vi.mocked(modelsInfoStore).modelsInfo = modelsInfoList; vi.mocked(studioClient.requestCreatePlayground).mockRejectedValue('error creating playground'); - render(PlaygroundCreate); + const { container } = render(PlaygroundCreate); - const errorMessage = screen.queryByLabelText('Error Message Content'); + const errorMessage = within(container).queryByLabelText('Error Message Content'); expect(errorMessage).not.toBeInTheDocument(); - const createButton = screen.getByTitle('Create playground'); + const createButton = within(container).getByTitle('Create playground'); await userEvent.click(createButton); - const errorMessageAfterSubmit = screen.queryByLabelText('Error Message Content'); + const errorMessageAfterSubmit = within(container).queryByLabelText('Error Message Content'); expect(errorMessageAfterSubmit).toBeInTheDocument(); expect(errorMessageAfterSubmit?.textContent).equal('error creating playground'); }); diff --git a/packages/frontend/src/pages/PlaygroundCreate.svelte b/packages/frontend/src/pages/PlaygroundCreate.svelte index b02251348..edb329cae 100644 --- a/packages/frontend/src/pages/PlaygroundCreate.svelte +++ b/packages/frontend/src/pages/PlaygroundCreate.svelte @@ -13,9 +13,10 @@ import { filterByLabel } from '../utils/taskUtils'; import type { Unsubscriber } from 'svelte/store'; import { Button, ErrorMessage, FormPage, Input } from '@podman-desktop/ui-svelte'; import ModelSelect from '/@/lib/ModelSelect.svelte'; +import { InferenceType } from '@shared/src/models/IInference'; let localModels: ModelInfo[]; -$: localModels = $modelsInfo.filter(model => model.file); +$: localModels = $modelsInfo.filter(model => model.file && model.backend === InferenceType.LLAMA_CPP); $: availModels = $modelsInfo.filter(model => !model.file); let model: ModelInfo | undefined = undefined; let submitted: boolean = false;