Skip to content

Commit

Permalink
feat(ai): add aiContext prop to AIConversation (#6090)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbanksdesign authored Nov 14, 2024
1 parent 022b586 commit a25e5ed
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 8 deletions.
21 changes: 21 additions & 0 deletions .changeset/beige-pugs-drive.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
---
"@aws-amplify/ui-react-ai": minor
---

feat(ai): add aiContext prop to AIConversation

```tsx
<AIConversation
messages={messages}
isLoading={isLoading}
handleSendMessage={sendMessage}
// This will let the LLM know about the current state of this application
// so it can better respond to questions, you can put any information
// in this object that might be helpful
aiContext={() => {
return {
currentTime: new Date().toLocaleTimeString(),
};
}}
/>
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import * as React from 'react';
import { Amplify } from 'aws-amplify';
import { signOut } from 'aws-amplify/auth';
import { createAIHooks, AIConversation } from '@aws-amplify/ui-react-ai';
import { generateClient } from 'aws-amplify/api';
import '@aws-amplify/ui-react/styles.css';

import outputs from './amplify_outputs';
import type { Schema } from '@environments/ai/gen2/amplify/data/resource';
import { Authenticator, Button, Card, Flex } from '@aws-amplify/ui-react';

const client = generateClient<Schema>({ authMode: 'userPool' });
const { useAIConversation } = createAIHooks(client);

Amplify.configure(outputs);

function Chat() {
const { data } = React.useContext(AIContext);
const [
{
data: { messages },
isLoading,
},
sendMessage,
] = useAIConversation('pirateChat');

return (
<AIConversation
messages={messages}
isLoading={isLoading}
handleSendMessage={sendMessage}
// This will let the LLM know about the current state of this application
// so it can better respond to questions
aiContext={() => {
return {
...data,
currentTime: new Date().toLocaleTimeString(),
};
}}
/>
);
}

function Counter() {
const { data, setData } = React.useContext(AIContext);
const count = data.count ?? 0;
return (
<Button onClick={() => setData({ ...data, count: count + 1 })}>
{count}
</Button>
);
}

const AIContext = React.createContext<{
data: any;
setData: (value: React.SetStateAction<any>) => void;
}>({ data: {}, setData: () => {} });

const AIContextProvider = ({
children,
}: {
children?: React.ReactNode;
}): JSX.Element => {
const [data, setData] = React.useState({});
return (
<AIContext.Provider value={{ data, setData }}>
{children}
</AIContext.Provider>
);
};

export default function Example() {
return (
<Authenticator>
<AIContextProvider>
<Flex direction="column" alignItems="flex-start">
<Button
onClick={() => {
signOut();
}}
>
Sign out
</Button>
<Card
flex="1"
variation="outlined"
// height="400px"
width="100%"
margin="large"
>
<Chat />
</Card>
<Counter />
</Flex>
</AIContextProvider>
</Authenticator>
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
WelcomeMessageProvider,
FallbackComponentProvider,
MessageRendererProvider,
AIContextProvider,
} from './context';
import { AttachmentProvider } from './context/AttachmentContext';

Expand All @@ -29,6 +30,7 @@ export interface AIConversationProviderProps
}

export const AIConversationProvider = ({
aiContext,
actions,
allowAttachments,
avatars,
Expand Down Expand Up @@ -72,9 +74,16 @@ export const AIConversationProvider = ({
<ActionsProvider actions={actions}>
<MessageVariantProvider variant={variant}>
<MessagesProvider messages={messages}>
<LoadingContextProvider isLoading={isLoading}>
{children}
</LoadingContextProvider>
{/* aiContext should be as close as possible to the bottom */}
{/* because the intent is users should update the context */}
{/* without it affecting the already rendered messages */}
<AIContextProvider aiContext={aiContext}>
<LoadingContextProvider
isLoading={isLoading}
>
{children}
</LoadingContextProvider>
</AIContextProvider>
</MessagesProvider>
</MessageVariantProvider>
</ActionsProvider>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import React from 'react';

export const AIContextContext = React.createContext<(() => object) | undefined>(
undefined
);

export const AIContextProvider = ({
children,
aiContext,
}: {
children?: React.ReactNode;
aiContext?: () => object;
}): JSX.Element => {
return (
<AIContextContext.Provider value={aiContext}>
{children}
</AIContextContext.Provider>
);
};
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export { AIContextContext, AIContextProvider } from './AIContextContext';
export { ActionsContext, ActionsProvider } from './ActionsContext';
export { AvatarsContext, AvatarsProvider } from './AvatarsContext';
export {
Expand Down
1 change: 1 addition & 0 deletions packages/react-ai/src/components/AIConversation/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export interface AIConversationProps {
handleSendMessage: SendMessage;
avatars?: Avatars;
isLoading?: boolean;
aiContext?: () => object;
}

export interface AIConversation<
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import React from 'react';

import { withBaseElementProps } from '@aws-amplify/ui-react-core/elements';
import { ConversationInputContext } from '../../context';
import { AIContextContext, ConversationInputContext } from '../../context';
import { AIConversationElements } from '../../context/elements';
import { AttachFileControl } from './AttachFileControl';
import { MessagesContext } from '../../context';
Expand All @@ -16,6 +16,7 @@ import { ControlsContext } from '../../context/ControlsContext';
import { getImageTypeFromMimeType } from '../../utils';
import { LoadingContext } from '../../context/LoadingContext';
import { AttachmentContext } from '../../context/AttachmentContext';
import { isFunction } from '@aws-amplify/ui';

const {
Button,
Expand Down Expand Up @@ -150,8 +151,9 @@ export const FormControl: FormControl = () => {
const { input, setInput } = React.useContext(ConversationInputContext);
const handleSendMessage = React.useContext(SendMessageContext);
const allowAttachments = React.useContext(AttachmentContext);
const ref = React.useRef<HTMLFormElement | null>(null);
const responseComponents = React.useContext(ResponseComponentsContext);
const aiContext = React.useContext(AIContextContext);
const ref = React.useRef<HTMLFormElement | null>(null);
const controls = React.useContext(ControlsContext);
const [composing, setComposing] = React.useState(false);

Expand Down Expand Up @@ -181,6 +183,7 @@ export const FormControl: FormControl = () => {
if (handleSendMessage) {
handleSendMessage({
content: submittedContent,
aiContext: isFunction(aiContext) ? aiContext() : undefined,
toolConfiguration:
convertResponseComponentsToToolConfiguration(responseComponents),
});
Expand All @@ -198,7 +201,7 @@ export const FormControl: FormControl = () => {
) => {
const { key, shiftKey } = event;

if (key === 'Enter' && !shiftKey && !composing ) {
if (key === 'Enter' && !shiftKey && !composing) {
event.preventDefault();

const hasInput =
Expand Down Expand Up @@ -232,8 +235,8 @@ export const FormControl: FormControl = () => {
<VisuallyHidden>
<Label />
</VisuallyHidden>
<TextInput
onKeyDown={handleOnKeyDown}
<TextInput
onKeyDown={handleOnKeyDown}
onCompositionStart={() => setComposing(true)}
onCompositionEnd={() => setComposing(false)}
/>
Expand Down

0 comments on commit a25e5ed

Please sign in to comment.