Skip to content

Commit

Permalink
remove viz action and parse viz contents
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Fontanier committed Jul 31, 2024
1 parent feabc1a commit 5baa734
Show file tree
Hide file tree
Showing 37 changed files with 506 additions and 1,062 deletions.
5 changes: 0 additions & 5 deletions front/components/actions/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import { DustAppRunActionDetails } from "@app/components/actions/dust_app_run/Du
import { ProcessActionDetails } from "@app/components/actions/process/ProcessActionDetails";
import { RetrievalActionDetails } from "@app/components/actions/retrieval/RetrievalActionDetails";
import { TablesQueryActionDetails } from "@app/components/actions/tables_query/TablesQueryActionDetails";
import { VisualizationActionDetails } from "@app/components/actions/visualization/VisualizationActionDetails";
import { WebsearchActionDetails } from "@app/components/actions/websearch/WebsearchActionDetails";

export interface ActionDetailsComponentBaseProps<
Expand Down Expand Up @@ -51,10 +50,6 @@ const actionsSpecification: ActionSpecifications = {
detailsComponent: BrowseActionDetails,
runningLabel: "Browsing page",
},
visualization_action: {
detailsComponent: VisualizationActionDetails,
runningLabel: "Analyzing request",
},
};

export function getActionSpecification<T extends ActionType>(
Expand Down

This file was deleted.

15 changes: 0 additions & 15 deletions front/components/assistant/AssistantDetails.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import {
Page,
PlanetIcon,
ServerIcon,
ShapesIcon,
Spinner,
Tree,
} from "@dust-tt/sparkle";
Expand All @@ -34,7 +33,6 @@ import {
isProcessConfiguration,
isRetrievalConfiguration,
isTablesQueryConfiguration,
isVisualizationConfiguration,
isWebsearchConfiguration,
} from "@dust-tt/types";
import { useCallback, useContext, useEffect, useMemo, useState } from "react";
Expand Down Expand Up @@ -315,19 +313,6 @@ export function AssistantDetails({
</div>
) : isBrowseConfiguration(action) ? (
false
) : isVisualizationConfiguration(action) ? (
<div className="flex flex-col gap-2" key={`action-${index}`}>
<div className="text-lg font-bold text-element-800">
Visualization
</div>
<div className="flex items-center gap-2">
<Icon visual={ShapesIcon} size="xs" />
<div>
Assistant can generate graphs to visually represent your
data.
</div>
</div>
</div>
) : (
!isRetrievalConfiguration(action) && assertNever(action)
)
Expand Down
105 changes: 57 additions & 48 deletions front/components/assistant/conversation/AgentMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@ import type {
import {
assertNever,
isRetrievalActionType,
isVisualizationActionType,
isWebsearchActionType,
removeNulls,
} from "@dust-tt/types";
import assert from "assert";
import Link from "next/link";
import { useRouter } from "next/router";
import { useCallback, useContext, useEffect, useRef, useState } from "react";
Expand Down Expand Up @@ -88,8 +86,8 @@ export function AgentMessage({
const [streamedAgentMessage, setStreamedAgentMessage] =
useState<AgentMessageType>(message);

const [streamedVisualizations, setStreamedVisualizations] = useState<
{ actionId: number; visualization: string }[]
const [visualizations, setVisualizations] = useState<
{ visualization: string; complete: boolean }[]
>([]);

const [isRetryHandlerProcessing, setIsRetryHandlerProcessing] =
Expand All @@ -103,6 +101,17 @@ export function AgentMessage({
{ index: number; document: RetrievalDocumentType | WebsearchResultType }[]
>([]);

useEffect(() => {
if (message.status === "succeeded") {
setVisualizations(
message.visualizations.map((v) => ({
visualization: v,
complete: true,
}))
);
}
}, [message.status, message.visualizations]);

const shouldStream = (() => {
if (message.status !== "created") {
return false;
Expand Down Expand Up @@ -180,7 +189,6 @@ export function AgentMessage({
case "process_params":
case "websearch_params":
case "browse_params":
case "visualization_params":
setStreamedAgentMessage((m) => {
return updateMessageWithAction(m, event.action);
});
Expand All @@ -203,13 +211,27 @@ export function AgentMessage({
...event.message,
};
});
setStreamedVisualizations([]);
break;
}

case "generation_tokens": {
switch (event.classification) {
case "closing_delimiter":
if (event.delimiterClassification === "visualization") {
// If we receive a closing delimiter for a visualization, we can
// consider the last viz to be complete.
setVisualizations((v) => {
const lastViz = v[v.length - 1];
if (lastViz) {
return [
...v.slice(0, v.length - 1),
{ ...lastViz, complete: true },
];
}
return v;
});
}
break;
case "opening_delimiter":
break;
case "tokens":
Expand All @@ -230,32 +252,28 @@ export function AgentMessage({
});
break;
case "visualization":
// TODO(@fontanierh): support viz tokens
// Append new content to the last viz, or create a new one if there
// is none or the last one is complete.
setVisualizations((v) => {
const lastViz = v[v.length - 1];
if (lastViz && !lastViz.complete) {
return [
...v.slice(0, v.length - 1),
{
visualization: lastViz.visualization + event.text,
complete: false,
},
];
}
return [...v, { visualization: event.text, complete: false }];
});
break;
default:
assertNever(event.classification);
assertNever(event);
}
break;
}

case "visualization_generation_tokens":
setStreamedVisualizations((m) => {
const actionId = event.actionId;
const tokens = event.text;
const index = m.findIndex((v) => v.actionId === actionId);
if (index === -1) {
return [...m, { actionId, visualization: tokens }];
} else {
return m.map((v) => {
if (v.actionId === actionId) {
return { ...v, visualization: v.visualization + tokens };
}
return v;
});
}
});
break;

default:
assertNever(event);
}
Expand Down Expand Up @@ -472,7 +490,7 @@ export function AgentMessage({
references: references,
streaming: shouldStream,
lastTokenClassification: lastTokenClassification,
streamedVisualizations,
visualizations,
})}
</div>
{/* Invisible div to act as a scroll anchor for detecting when the user has scrolled to the bottom */}
Expand All @@ -485,13 +503,13 @@ export function AgentMessage({
references,
streaming,
lastTokenClassification,
streamedVisualizations,
visualizations,
}: {
agentMessage: AgentMessageType;
references: { [key: string]: RetrievalDocumentType | WebsearchResultType };
streaming: boolean;
lastTokenClassification: null | "tokens" | "chain_of_thought";
streamedVisualizations: { actionId: number; visualization: string }[];
visualizations: { visualization: string; complete: boolean }[];
}) {
if (agentMessage.status === "failed") {
return (
Expand All @@ -513,25 +531,16 @@ export function AgentMessage({
<div className="flex flex-col gap-y-4">
<AgentMessageActions agentMessage={agentMessage} size={size} />
<>
{agentMessage.actions
.filter((a) => isVisualizationActionType(a))
.map((a, i) => {
const streamingViz = streamedVisualizations.find(
(sv) => sv.actionId === a.id
);
assert(isVisualizationActionType(a));
return (
<VisualizationActionIframe
action={a}
conversationId={conversationId}
isStreaming={!!streamingViz}
key={i}
onRetry={() => retryHandler(agentMessage)}
owner={owner}
streamedCode={streamingViz?.visualization || null}
/>
);
})}
{visualizations.map((v, i) => {
return (
<VisualizationActionIframe
visualization={{ ...v, index: i }}
key={i}
onRetry={() => retryHandler(agentMessage)}
owner={owner}
/>
);
})}
</>

{agentMessage.chainOfThought?.length ? (
Expand Down
Loading

0 comments on commit 5baa734

Please sign in to comment.