diff --git a/front/components/assistant/RenderMessageMarkdown.tsx b/front/components/assistant/RenderMessageMarkdown.tsx index da7e0041e8db..3a1fe8941591 100644 --- a/front/components/assistant/RenderMessageMarkdown.tsx +++ b/front/components/assistant/RenderMessageMarkdown.tsx @@ -6,7 +6,7 @@ import { Tooltip, } from "@dust-tt/sparkle"; import dynamic from "next/dynamic"; -import React, { useState } from "react"; +import React, { useEffect, useState } from "react"; import ReactMarkdown from "react-markdown"; import { ReactMarkdownProps } from "react-markdown/lib/complex-types"; import remarkDirective from "remark-directive"; @@ -121,24 +121,42 @@ function addClosingBackticks(str: string): string { return str; } -export const ReferencesContext = React.createContext<{ - [key: string]: RetrievalDocumentType; -}>({}); +type CitationsContextType = { + references: { + [key: string]: RetrievalDocumentType; + }; + updateActiveReferences: (doc: RetrievalDocumentType, index: number) => void; + setHoveredReference: (index: number | null) => void; +}; + +export const CitationsContext = React.createContext({ + references: {}, + updateActiveReferences: () => null, + setHoveredReference: () => null, +}); export function RenderMessageMarkdown({ content, blinkingCursor, - references, agentConfigurations, + citationsContext, }: { content: string; blinkingCursor: boolean; - references?: { [key: string]: RetrievalDocumentType }; agentConfigurations?: AgentConfigurationType[]; + citationsContext?: CitationsContextType; }) { return (
- + null, + setHoveredReference: () => null, + } + } + > {addClosingBackticks(content)} - +
); } @@ -213,21 +231,32 @@ function isCiteProps(props: ReactMarkdownProps): props is ReactMarkdownProps & { } function CiteBlock(props: ReactMarkdownProps) { - const references = React.useContext(ReferencesContext); - - if (isCiteProps(props) && props.references) { - const refs = ( - JSON.parse(props.references) as { - counter: number; - ref: string; - }[] - ).filter((r) => r.ref in references); + const { references, updateActiveReferences, setHoveredReference } = + React.useContext(CitationsContext); + const refs = + isCiteProps(props) && props.references + ? ( + JSON.parse(props.references) as { + counter: number; + ref: string; + }[] + ).filter((r) => r.ref in references) + : undefined; + + useEffect(() => { + if (refs) { + refs.forEach((r) => { + const document = references[r.ref]; + updateActiveReferences(document, r.counter); + }); + } + }, [refs, references, updateActiveReferences]); + if (refs) { return ( <> {refs.map((r, i) => { const document = references[r.ref]; - const provider = providerFromDocument(document); const title = titleFromDocument(document); const link = linkFromDocument(document); @@ -263,6 +292,7 @@ function CiteBlock(props: ReactMarkdownProps) { target="_blank" rel="noopener noreferrer" className={citeClassNames} + onMouseEnter={() => setHoveredReference(r.counter)} > {r.counter} diff --git a/front/components/assistant/conversation/AgentMessage.tsx b/front/components/assistant/conversation/AgentMessage.tsx index 3a63dc22cf3f..b58571926254 100644 --- a/front/components/assistant/conversation/AgentMessage.tsx +++ b/front/components/assistant/conversation/AgentMessage.tsx @@ -4,15 +4,25 @@ import { Chip, ClipboardIcon, DocumentDuplicateIcon, + DocumentTextIcon, DropdownMenu, + ExternalLinkIcon, EyeIcon, + IconButton, Spinner, } from "@dust-tt/sparkle"; -import { useCallback, useContext, useEffect, useState } from "react"; +import Link from "next/link"; +import { useCallback, useContext, useEffect, useRef, useState } from "react"; import { AgentAction } from "@app/components/assistant/conversation/AgentAction"; import { ConversationMessage } from "@app/components/assistant/conversation/ConversationMessage"; import { GenerationContext } from "@app/components/assistant/conversation/GenerationContextProvider"; +import { + linkFromDocument, + PROVIDER_LOGO_PATH, + providerFromDocument, + titleFromDocument, +} from "@app/components/assistant/conversation/RetrievalAction"; import { RenderMessageMarkdown } from "@app/components/assistant/RenderMessageMarkdown"; import { useEventSource } from "@app/hooks/useEventSource"; import { @@ -24,6 +34,7 @@ import { AgentMessageSuccessEvent, } from "@app/lib/api/assistant/agent"; import { GenerationTokensEvent } from "@app/lib/api/assistant/generation"; +import { classNames } from "@app/lib/utils"; import { isRetrievalActionType, RetrievalDocumentType, @@ -224,7 +235,21 @@ export function AgentMessage({ const [references, setReferences] = useState<{ [key: string]: RetrievalDocumentType; }>({}); - + const [activeReferences, setActiveReferences] = useState< + { index: number; document: RetrievalDocumentType }[] + >([]); + function updateActiveReferences( + document: RetrievalDocumentType, + index: number + ) { + const existingIndex = activeReferences.find((r) => r.index === index); + if (!existingIndex) { + setActiveReferences([...activeReferences, { index, document }]); + } + } + const [lastHoveredReference, setLastHoveredReference] = useState< + number | null + >(null); useEffect(() => { if ( agentMessageToRender.action && @@ -311,7 +336,15 @@ export function AgentMessage({ + )} @@ -339,6 +372,95 @@ export function AgentMessage({ } } +function Citations({ + activeReferences, + lastHoveredReference, +}: { + activeReferences: { index: number; document: RetrievalDocumentType }[]; + lastHoveredReference: number | null; +}) { + const citationContainer = useRef(null); + + useEffect(() => { + if (citationContainer.current) { + if (lastHoveredReference !== null) { + citationContainer.current.scrollTo({ + left: citationsScrollOffset(lastHoveredReference), + behavior: "smooth", + }); + } + } + }, [lastHoveredReference]); + + function citationsScrollOffset(reference: number | null) { + if (!citationContainer.current || reference === null) { + return 0; + } + const offset = ( + citationContainer.current.firstElementChild + ?.firstElementChild as HTMLElement + ).offsetLeft; + const scrolling = + (citationContainer.current.firstElementChild?.firstElementChild + ?.scrollWidth || 0) * + (reference - 2); + return scrolling - offset; + } + + activeReferences.sort((a, b) => a.index - b.index); + return ( +
+
+ {activeReferences.map(({ document, index }) => { + const provider = providerFromDocument(document); + return ( +
+
+
+ {index} +
+
+ {provider === "none" ? ( + + ) : ( + + )} +
+
+ + + +
+
+ {titleFromDocument(document)} +
+
+ ); + })} +
+
+
+ ); +} + function ErrorMessage({ error, retryHandler, diff --git a/front/components/assistant/conversation/ConversationMessage.tsx b/front/components/assistant/conversation/ConversationMessage.tsx index 3b4a643c20a4..018f1630d946 100644 --- a/front/components/assistant/conversation/ConversationMessage.tsx +++ b/front/components/assistant/conversation/ConversationMessage.tsx @@ -115,7 +115,7 @@ export function ConversationMessage({ <> {/* SMALL SIZE SCREEN*/}
-
+
- {/* COLUMN 2: CONTENT */} + {/* COLUMN 2: CONTENT + * min-w-0 prevents the content from overflowing the container + */}
{name}
diff --git a/front/components/sparkle/AppLayout.tsx b/front/components/sparkle/AppLayout.tsx index 24dfb267a9fb..ef4d465c1686 100644 --- a/front/components/sparkle/AppLayout.tsx +++ b/front/components/sparkle/AppLayout.tsx @@ -361,7 +361,7 @@ export default function AppLayout({ >
diff --git a/front/styles/global.css b/front/styles/global.css index 35e172bc1a3a..343c92ea0fa6 100644 --- a/front/styles/global.css +++ b/front/styles/global.css @@ -67,3 +67,13 @@ main { transform: translate3d(8px, 3px, 0); } } + +@keyframes bgblink { + 0%, + 100% { + @apply bg-structure-0; + } + 50% { + @apply bg-structure-200; + } +}