Skip to content

Commit

Permalink
chore(ai): add graphql errors to useAIGeneration (#5900)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbanksdesign authored Oct 15, 2024
1 parent 5e20643 commit 1421dde
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 47 deletions.
5 changes: 5 additions & 0 deletions .changeset/fresh-parents-invent.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@aws-amplify/ui-react-ai": patch
---

chore(ai): add graphql errors to useAIGeneration
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,44 @@ import '@aws-amplify/ui-react-ai/ai-conversation-styles.css';

import outputs from './amplify_outputs.js';
import type { Schema } from '@environments/ai/gen2/amplify/data/resource';
import { Authenticator } from '@aws-amplify/ui-react';
import {
Button,
Flex,
Loader,
TextField,
withAuthenticator,
} from '@aws-amplify/ui-react';
import React from 'react';

const client = generateClient<Schema>();
const { useAIGeneration } = createAIHooks(client);

Amplify.configure(outputs);

export default function Example() {
const [{ data }, handler] = useAIGeneration('generateRecipe');
function Example() {
const [{ data, isLoading, hasError, messages }, handler] =
useAIGeneration('generateRecipe');
const handleSubmit = (e: React.FormEvent<HTMLFormElement>) => {
e.preventDefault();
const formData = new FormData(e.currentTarget);
const description = formData.get('description') as string;
handler({ description });
};
return (
<Authenticator>
{({ user }) => {
return (
<>
<h1>Hello {user.username}</h1>
<div>{JSON.stringify(data)}</div>
<button
onClick={() => {
handler({
description:
'I want a recipe for a gluten-free chocolate cake.',
});
}}
>
generate
</button>
</>
);
}}
</Authenticator>
<Flex direction="column" gap="medium">
<Flex direction="row" as="form" onSubmit={handleSubmit}>
<TextField label="description" name="description" />
<Button type="submit">generate</Button>
</Flex>
{isLoading ? <Loader /> : null}
{hasError ? (
<div>
{messages?.map(({ message }, i) => <div key={i}>{message}</div>)}
</div>
) : null}
{data ? <div>{JSON.stringify(data)}</div> : null}
</Flex>
);
}

export default withAuthenticator(Example);
14 changes: 12 additions & 2 deletions packages/react-ai/src/hooks/__tests__/createAIHooks.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,14 @@ describe('createAIHooks', () => {
};
const generateReturn = {
data: expectedResult,
errors: ['this is just one error'],
errors: [
{
errorType: '',
locations: [],
path: ['generateRecipe'],
message: 'this is just one error',
},
],
};
generateRecipeMock.mockResolvedValueOnce(generateReturn);
const { useAIGeneration } = createAIHooks(client);
Expand All @@ -202,7 +209,10 @@ describe('createAIHooks', () => {

const [awaitedState] = hookResult.current;
expect(awaitedState.data).toStrictEqual(expectedResult);
expect(awaitedState.graphqlErrors).toHaveLength(1);
expect(awaitedState.isLoading).toBeFalsy();
expect(awaitedState.hasError).toBeTruthy();
expect(awaitedState.messages).toHaveLength(1);
expect(awaitedState.messages?.[0].message).toContain('error');
});
});
});
67 changes: 45 additions & 22 deletions packages/react-ai/src/hooks/useAIGeneration.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { DataState, useDataState } from '@aws-amplify/ui-react-core';
import * as React from 'react';
import { DataState } from '@aws-amplify/ui-react-core';
import { V6Client } from '@aws-amplify/api-graphql';
import { getSchema } from '../types';

Expand All @@ -9,7 +10,7 @@ export interface UseAIGenerationHookWrapper<
useAIGeneration: <U extends Key>(
routeName: U
) => [
Awaited<GenerateState<Schema[U]['returnType']>>,
Awaited<GenerationState<Schema[U]['returnType']>>,
(input: Schema[U]['args']) => void,
];
}
Expand All @@ -20,7 +21,7 @@ export type UseAIGenerationHook<
> = (
routeName: Key
) => [
Awaited<GenerateState<Schema[Key]['returnType']>>,
Awaited<GenerationState<Schema[Key]['returnType']>>,
(input: Schema[Key]['args']) => void,
];

Expand All @@ -42,10 +43,19 @@ type SingularReturnValue<T> = {
errors?: GraphQLFormattedError[];
};

type GenerateState<T> = DataState<T> & {
graphqlErrors?: GraphQLFormattedError[];
type GenerationState<T> = Omit<DataState<T>, 'message'> & {
messages?: GraphQLFormattedError[];
};

// default state
const INITIAL_STATE = {
hasError: false,
isLoading: false,
messages: undefined,
};
const LOADING_STATE = { hasError: false, isLoading: true, messages: undefined };
const ERROR_STATE = { hasError: true, isLoading: false };

export function createUseAIGeneration<
Client extends Record<'generations' | 'conversations', Record<string, any>>,
Schema extends getSchema<Client>,
Expand All @@ -55,29 +65,42 @@ export function createUseAIGeneration<
>(
routeName: Key
): [
state: GenerateState<Schema[Key]['returnType']>,
state: GenerationState<Schema[Key]['returnType']>,
handleAction: (input: Schema[Key]['args']) => void,
] => {
const handleGenerate = (
client.generations as AIGenerationClient<Schema>['generations']
)[routeName];
const [dataState, setDataState] = React.useState<
GenerationState<Schema[Key]['returnType']>
>(() => ({
...INITIAL_STATE,
data: undefined,
}));

const updateAIGenerationStateAction = async (
_prev: Schema[Key]['returnType'],
input: Schema[Key]['args']
): Promise<Schema[Key]['returnType']> => {
return await handleGenerate(input);
};
const handleGeneration = React.useCallback(
async (input: Schema[Key]['args']) => {
setDataState(({ data }) => ({ ...LOADING_STATE, data }));

const [result, handler] = useDataState(
updateAIGenerationStateAction,
undefined
);
const result = await (
client.generations as AIGenerationClient<Schema>['generations']
)[routeName](input);

const { data, errors } =
(result?.data as SingularReturnValue<Schema[Key]['returnType']>) ?? {};
const { data, errors } = result as SingularReturnValue<
Schema[Key]['returnType']
>;

if (errors) {
setDataState({
...ERROR_STATE,
data,
messages: errors,
});
} else {
setDataState({ ...INITIAL_STATE, data });
}
},
[routeName]
);

return [{ ...result, data, graphqlErrors: errors }, handler];
return [dataState, handleGeneration];
};

return useAIGeneration;
Expand Down

0 comments on commit 1421dde

Please sign in to comment.