Skip to content

Commit

Permalink
5757 adding a goal or criteria for model comparisons (#5761)
Browse files Browse the repository at this point in the history
Co-authored-by: Cole Blanchard <[email protected]>
Co-authored-by: Cole Blanchard <[email protected]>
Co-authored-by: Yohann Paris <[email protected]>
  • Loading branch information
4 people authored Dec 5, 2024
1 parent bb368bd commit b0e734a
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
@focus="onFocus"
@blur="onBlur"
@input="updateValue"
@keydown.space.stop
/>
</main>
<aside v-if="getErrorMessage"><i class="pi pi-exclamation-circle" /> {{ getErrorMessage }}</aside>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ export interface ModelComparisonOperationState extends BaseState {
hasCodeRun: boolean;
comparisonImageIds: string[];
comparisonPairs: string[][];
goal: string;
hasRun: boolean;
}

export const ModelComparisonOperation: Operation = {
Expand All @@ -28,7 +30,9 @@ export const ModelComparisonOperation: Operation = {
notebookHistory: [],
hasCodeRun: false,
comparisonImageIds: [],
comparisonPairs: []
comparisonPairs: [],
goal: '',
hasRun: false
};
return init;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,23 @@
<section class="comparison-overview">
<Accordion :activeIndex="currentActiveIndicies">
<AccordionTab header="Overview">
<p v-if="isEmpty(overview)" class="subdued">
<template #header>
<tera-input-text
class="ml-auto w-4"
placeholder="What is your goal? (Optional)"
:model-value="goalQuery"
@blur="onUpdateGoalQuery"
/>
<Button
class="ml-4"
label="Compare"
@click.stop="processCompareModels"
size="small"
icon="pi pi-sparkles"
:loading="isProcessingComparison"
/>
</template>
<p v-if="isProcessingComparison" class="subdued">
<i class="pi pi-spin pi-spinner mr-1" />
Analyzing models metadata to generate a detailed comparison analysis...
</p>
Expand Down Expand Up @@ -152,7 +168,7 @@ import { ClientEvent, ClientEventType, type Model, TaskResponse, TaskStatus } fr
import { OperatorStatus, WorkflowNode, WorkflowPortStatus } from '@/types/workflow';
import { logger } from '@/utils/logger';
import Button from 'primevue/button';
import { onMounted, onUnmounted, ref } from 'vue';
import { computed, onMounted, onUnmounted, ref } from 'vue';
import { VAceEditor } from 'vue3-ace-editor';
import { VAceEditorInstance } from 'vue3-ace-editor/types';
Expand All @@ -165,6 +181,7 @@ import { addImage, deleteImages, getImages } from '@/services/image';
import TeraColumnarPanel from '@/components/widgets/tera-columnar-panel.vue';
import { b64DecodeUnicode } from '@/utils/binary';
import { useClientEvent } from '@/composables/useClientEvent';
import TeraInputText from '@/components/widgets/tera-input-text.vue';
import { ModelComparisonOperationState } from './model-comparison-operation';
const props = defineProps<{
Expand All @@ -173,6 +190,13 @@ const props = defineProps<{
const emit = defineEmits(['update-state', 'update-status', 'close']);
// Local copy of the goal text, seeded from the node's persisted state.
const goalQuery = ref(props.node.state.goal);
// True while a comparison request/task is considered in flight; drives the loading UI.
const isProcessingComparison = ref(false);
// First value of every connected input port, recomputed when the node's inputs change.
const modelIds = computed(() => {
  const connectedInputs = props.node.inputs.filter(({ status }) => status === WorkflowPortStatus.CONNECTED);
  return connectedInputs.map(({ value }) => value?.[0]);
});
enum Tabs {
Wizard = 'Wizard',
Notebook = 'Notebook'
Expand Down Expand Up @@ -249,6 +273,12 @@ function resetNotebook() {
emptyImages();
updateCodeState();
}
/**
 * Stores the user's comparison goal locally and persists it on the node state.
 *
 * Wired to the goal input's `blur` event in the template.
 *
 * @param goal - The goal text to remember.
 */
function onUpdateGoalQuery(goal: string) {
  goalQuery.value = goal;
  // Work on a deep copy so the incoming prop is never mutated; the parent applies the update.
  const nextState = Object.assign(cloneDeep(props.node.state), { goal });
  emit('update-state', nextState);
}
function runCode() {
const messageContent = {
Expand Down Expand Up @@ -461,25 +491,31 @@ function generateOverview(output: string) {
}
// Create a task to compare the models
/**
 * Requests a comparison of the currently connected models, passing along the
 * user's goal, and records on the node state that a comparison has been run.
 *
 * Reads `modelIds` and `goalQuery` directly, so it takes no arguments and can be
 * used as a click handler as-is.
 *
 * If the task response is already `Success` the overview is generated right away;
 * otherwise the result is expected to arrive through the client-event listener,
 * matched by `compareModelsTaskId`.
 *
 * Errors from `compareModels` still propagate to the caller, but the loading flag
 * is always cleared so the UI cannot get stuck in a "processing" state.
 */
const processCompareModels = async () => {
  isProcessingComparison.value = true;
  try {
    const taskRes = await compareModels(modelIds.value, goalQuery.value, props.node.workflowId, props.node.id);
    // Remember the task so later client events can be correlated with this request.
    compareModelsTaskId = taskRes.id;
    if (taskRes.status === TaskStatus.Success) {
      generateOverview(taskRes.output);
    }

    // Persist that a comparison was requested; only reached when the request itself succeeded.
    const state = cloneDeep(props.node.state);
    state.hasRun = true;
    emit('update-state', state);
  } finally {
    // Reset even when the request throws; task events may switch the flag back on
    // while the task is still queued or running.
    isProcessingComparison.value = false;
  }
};
// Listen for the task completion event
// Tracks the lifecycle of the comparison task started by this component:
//  - in-flight statuses keep the loading indicator on,
//  - success regenerates the overview from the task output,
//  - failure/cancellation just clears the loading indicator.
// Events for other tasks are ignored. There is deliberately no "overview is already
// populated" guard here: a re-run (e.g. with a different goal) must be able to
// replace the previous overview.
useClientEvent(ClientEventType.TaskGollmCompareModel, (event: ClientEvent<TaskResponse>) => {
  const task = event.data;
  if (!task || task.id !== compareModelsTaskId) return;

  if ([TaskStatus.Queued, TaskStatus.Running, TaskStatus.Cancelling].includes(task.status)) {
    isProcessingComparison.value = true;
  } else if (task.status === TaskStatus.Success) {
    // Generate the overview exactly once per successful event.
    generateOverview(task.output);
    isProcessingComparison.value = false;
  } else if ([TaskStatus.Failed, TaskStatus.Cancelled].includes(task.status)) {
    isProcessingComparison.value = false;
  }
});
onMounted(async () => {
Expand All @@ -489,18 +525,16 @@ onMounted(async () => {
isLoadingStructuralComparisons.value = false;
}
const modelIds: string[] = props.node.inputs
.filter((input) => input.status === WorkflowPortStatus.CONNECTED)
.map((input) => input.value?.[0]);
modelsToCompare.value = (await Promise.all(modelIds.map(async (modelId) => getModel(modelId)))).filter(
modelsToCompare.value = (await Promise.all(modelIds.value.map(async (modelId) => getModel(modelId)))).filter(
Boolean
) as Model[];
modelCardsToCompare.value = modelsToCompare.value.map(({ metadata }) => metadata?.gollmCard);
fields.value = [...new Set(modelCardsToCompare.value.flatMap((card) => (card ? Object.keys(card) : [])))];
await buildJupyterContext();
await processCompareModels(modelIds);
if (props.node.state.hasRun) {
processCompareModels();
}
});
onUnmounted(() => {
Expand Down
8 changes: 7 additions & 1 deletion packages/client/hmi-client/src/services/goLLM.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,16 @@ export async function configureModelFromDataset(
return data;
}

export async function compareModels(modelIds: string[], workflowId?: string, nodeId?: string): Promise<TaskResponse> {
export async function compareModels(
modelIds: string[],
goal?: string,
workflowId?: string,
nodeId?: string
): Promise<TaskResponse> {
const { data } = await API.get<TaskResponse>('/gollm/compare-models', {
params: {
'model-ids': modelIds.join(','),
goal,
'workflow-id': workflowId,
'node-id': nodeId
}
Expand Down
2 changes: 1 addition & 1 deletion packages/gollm/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ModelCardModel(BaseModel):

class ModelCompareModel(BaseModel):
    # Fields supplied for a model comparison (presumably the request payload — confirm against caller).
    amrs: List[str] # expects AMRs to be a stringified JSON object

    # User-supplied goal for the comparison; defaults to None when not provided.
    # NOTE(review): annotated as `str` but the default is None — `Optional[str]` would
    # state the intent; confirm `Optional` is imported in this module before changing.
    goal: str = None

class EquationsCleanup(BaseModel):
equations: List[str]
Expand Down
11 changes: 10 additions & 1 deletion packages/gollm/gollm_openai/prompts/model_meta_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@
Answer with accuracy, precision, and without repetition.
Ensure that you compare EVERY supplied model. Do not leave any model out of the comparison.
---MODELS START---
{amrs}
---MODELS END---
Comparison:
"""

# Prompt fragment asking the model to take the user's goal into account in its
# final conclusion. `{goal}` is a placeholder for the goal text.
# NOTE(review): presumably added to the comparison prompt only when a goal was
# supplied — confirm against the code that builds the prompt.
MODEL_METADATA_COMPARE_GOAL_PROMPT = """
When creating your final conclusion, consider the following goal of the user when comparing the models:
{goal}
"""
Loading

0 comments on commit b0e734a

Please sign in to comment.