From 1421ddef49215f232a580d464d13920b9213b698 Mon Sep 17 00:00:00 2001 From: Danny Banks Date: Tue, 15 Oct 2024 11:15:34 -0700 Subject: [PATCH] chore(ai): add graphql errors to useAIGeneration (#5900) --- .changeset/fresh-parents-invent.md | 5 ++ .../ai/use-ai-generation/index.page.tsx | 54 ++++++++------- .../hooks/__tests__/createAIHooks.test.tsx | 14 +++- .../react-ai/src/hooks/useAIGeneration.tsx | 67 +++++++++++++------ 4 files changed, 93 insertions(+), 47 deletions(-) create mode 100644 .changeset/fresh-parents-invent.md diff --git a/.changeset/fresh-parents-invent.md b/.changeset/fresh-parents-invent.md new file mode 100644 index 00000000000..76b58b69349 --- /dev/null +++ b/.changeset/fresh-parents-invent.md @@ -0,0 +1,5 @@ +--- +"@aws-amplify/ui-react-ai": patch +--- + +chore(ai): add graphql errors to useAIGeneration diff --git a/examples/next/pages/ui/components/ai/use-ai-generation/index.page.tsx b/examples/next/pages/ui/components/ai/use-ai-generation/index.page.tsx index 7421d29d0f0..7d5a051a429 100644 --- a/examples/next/pages/ui/components/ai/use-ai-generation/index.page.tsx +++ b/examples/next/pages/ui/components/ai/use-ai-generation/index.page.tsx @@ -6,7 +6,13 @@ import '@aws-amplify/ui-react-ai/ai-conversation-styles.css'; import outputs from './amplify_outputs.js'; import type { Schema } from '@environments/ai/gen2/amplify/data/resource'; -import { Authenticator } from '@aws-amplify/ui-react'; +import { + Button, + Flex, + Loader, + TextField, + withAuthenticator, +} from '@aws-amplify/ui-react'; import React from 'react'; const client = generateClient(); @@ -14,28 +20,30 @@ const { useAIGeneration } = createAIHooks(client); Amplify.configure(outputs); -export default function Example() { - const [{ data }, handler] = useAIGeneration('generateRecipe'); +function Example() { + const [{ data, isLoading, hasError, messages }, handler] = + useAIGeneration('generateRecipe'); + const handleSubmit = (e: React.FormEvent) => { + e.preventDefault(); + const formData = new FormData(e.currentTarget); + const description = formData.get('description') as string; + handler({ description }); + }; return ( - - {({ user }) => { - return ( - <> -

Hello {user.username}

-
{JSON.stringify(data)}
- - - ); - }} -
+ + + + + + {isLoading ? : null} + {hasError ? ( +
+ {messages?.map(({ message }, i) =>
{message}
)} +
+ ) : null} + {data ?
{JSON.stringify(data)}
: null} +
); } + +export default withAuthenticator(Example); diff --git a/packages/react-ai/src/hooks/__tests__/createAIHooks.test.tsx b/packages/react-ai/src/hooks/__tests__/createAIHooks.test.tsx index a2aa63fd3b0..9ff19fae396 100644 --- a/packages/react-ai/src/hooks/__tests__/createAIHooks.test.tsx +++ b/packages/react-ai/src/hooks/__tests__/createAIHooks.test.tsx @@ -179,7 +179,14 @@ describe('createAIHooks', () => { }; const generateReturn = { data: expectedResult, - errors: ['this is just one error'], + errors: [ + { + errorType: '', + locations: [], + path: ['generateRecipe'], + message: 'this is just one error', + }, + ], }; generateRecipeMock.mockResolvedValueOnce(generateReturn); const { useAIGeneration } = createAIHooks(client); @@ -202,7 +209,10 @@ describe('createAIHooks', () => { const [awaitedState] = hookResult.current; expect(awaitedState.data).toStrictEqual(expectedResult); - expect(awaitedState.graphqlErrors).toHaveLength(1); + expect(awaitedState.isLoading).toBeFalsy(); + expect(awaitedState.hasError).toBeTruthy(); + expect(awaitedState.messages).toHaveLength(1); + expect(awaitedState.messages?.[0].message).toContain('error'); }); }); }); diff --git a/packages/react-ai/src/hooks/useAIGeneration.tsx b/packages/react-ai/src/hooks/useAIGeneration.tsx index c9aa4c139ef..e5c5981ce88 100644 --- a/packages/react-ai/src/hooks/useAIGeneration.tsx +++ b/packages/react-ai/src/hooks/useAIGeneration.tsx @@ -1,4 +1,5 @@ -import { DataState, useDataState } from '@aws-amplify/ui-react-core'; +import * as React from 'react'; +import { DataState } from '@aws-amplify/ui-react-core'; import { V6Client } from '@aws-amplify/api-graphql'; import { getSchema } from '../types'; @@ -9,7 +10,7 @@ export interface UseAIGenerationHookWrapper< useAIGeneration: ( routeName: U ) => [ - Awaited>, + Awaited>, (input: Schema[U]['args']) => void, ]; } @@ -20,7 +21,7 @@ export type UseAIGenerationHook< > = ( routeName: Key ) => [ - Awaited>, + Awaited>, (input: Schema[Key]['args']) => void, ]; @@ -42,10 +43,19 @@ type SingularReturnValue = { errors?: GraphQLFormattedError[]; }; -type GenerateState = DataState & { - graphqlErrors?: GraphQLFormattedError[]; +type GenerationState = Omit, 'message'> & { + messages?: GraphQLFormattedError[]; }; +// default state +const INITIAL_STATE = { + hasError: false, + isLoading: false, + messages: undefined, +}; +const LOADING_STATE = { hasError: false, isLoading: true, messages: undefined }; +const ERROR_STATE = { hasError: true, isLoading: false }; + export function createUseAIGeneration< Client extends Record<'generations' | 'conversations', Record>, Schema extends getSchema, @@ -55,29 +65,42 @@ export function createUseAIGeneration< >( routeName: Key ): [ - state: GenerateState, + state: GenerationState, handleAction: (input: Schema[Key]['args']) => void, ] => { - const handleGenerate = ( - client.generations as AIGenerationClient['generations'] - )[routeName]; + const [dataState, setDataState] = React.useState< + GenerationState + >(() => ({ + ...INITIAL_STATE, + data: undefined, + })); - const updateAIGenerationStateAction = async ( - _prev: Schema[Key]['returnType'], - input: Schema[Key]['args'] - ): Promise => { - return await handleGenerate(input); - }; + const handleGeneration = React.useCallback( + async (input: Schema[Key]['args']) => { + setDataState(({ data }) => ({ ...LOADING_STATE, data })); - const [result, handler] = useDataState( - updateAIGenerationStateAction, - undefined - ); + const result = await ( + client.generations as AIGenerationClient['generations'] + )[routeName](input); - const { data, errors } = - (result?.data as SingularReturnValue) ?? {}; + const { data, errors } = result as SingularReturnValue< + Schema[Key]['returnType'] + >; + + if (errors) { + setDataState({ + ...ERROR_STATE, + data, + messages: errors, + }); + } else { + setDataState({ ...INITIAL_STATE, data }); + } + }, + [routeName] + ); - return [{ ...result, data, graphqlErrors: errors }, handler]; + return [dataState, handleGeneration]; }; return useAIGeneration;