diff --git a/codebuild_specs/e2e_workflow.yml b/codebuild_specs/e2e_workflow.yml index 21e9002352..b00481d570 100644 --- a/codebuild_specs/e2e_workflow.yml +++ b/codebuild_specs/e2e_workflow.yml @@ -1005,7 +1005,7 @@ batch: variables: TEST_SUITE: >- src/__tests__/amplify-table-5.test.ts|src/__tests__/add-resources.test.ts|src/__tests__/deploy-velocity-temporarily-disabled/single-gsi-single-record.test.ts|src/__tests__/deploy-velocity-temporarily-disabled/single-gsi-empty-table.test.ts - CLI_REGION: ap-east-1 + CLI_REGION: ap-northeast-1 depend-on: - publish_to_local_registry - identifier: >- @@ -1016,7 +1016,7 @@ batch: variables: TEST_SUITE: >- src/__tests__/deploy-velocity-temporarily-disabled/single-gsi-1k-records.test.ts|src/__tests__/deploy-velocity-temporarily-disabled/single-gsi-10k-records.test.ts|src/__tests__/deploy-velocity-temporarily-disabled/replace-2-gsis-update-attr-single-record.test.ts|src/__tests__/deploy-velocity-temporarily-disabled/replace-2-gsis-update-attr-empty-table.test.ts - CLI_REGION: ap-northeast-1 + CLI_REGION: ap-northeast-2 depend-on: - publish_to_local_registry - identifier: >- @@ -1027,7 +1027,7 @@ batch: variables: TEST_SUITE: >- src/__tests__/deploy-velocity-temporarily-disabled/replace-2-gsis-update-attr-1k-records.test.ts|src/__tests__/deploy-velocity-temporarily-disabled/replace-2-gsis-update-attr-10k-records.test.ts|src/__tests__/deploy-velocity-temporarily-disabled/replace-2-gsis-single-record.test.ts|src/__tests__/deploy-velocity-temporarily-disabled/replace-2-gsis-empty-table.test.ts - CLI_REGION: ap-northeast-2 + CLI_REGION: ap-south-1 depend-on: - publish_to_local_registry - identifier: >- @@ -1038,7 +1038,7 @@ batch: variables: TEST_SUITE: >- src/__tests__/deploy-velocity-temporarily-disabled/replace-2-gsis-1k-records.test.ts|src/__tests__/deploy-velocity-temporarily-disabled/replace-2-gsis-10k-records.test.ts|src/__tests__/deploy-velocity-temporarily-disabled/3-gsis-single-record.test.ts|src/__tests__/deploy-velocity-temporarily-disabled/3-gsis-empty-table.test.ts - CLI_REGION: ap-south-1 + CLI_REGION: ap-southeast-1 depend-on: - publish_to_local_registry - identifier: 3_gsis_1k_records_3_gsis_10k_records @@ -1048,7 +1048,7 @@ batch: variables: TEST_SUITE: >- src/__tests__/deploy-velocity-temporarily-disabled/3-gsis-1k-records.test.ts|src/__tests__/deploy-velocity-temporarily-disabled/3-gsis-10k-records.test.ts - CLI_REGION: ap-southeast-1 + CLI_REGION: ap-southeast-2 depend-on: - publish_to_local_registry - identifier: sql_pg_models @@ -1755,13 +1755,23 @@ batch: CLI_REGION: ap-southeast-2 depend-on: - publish_to_local_registry + - identifier: generation + buildspec: codebuild_specs/run_cdk_tests.yml + env: + compute-type: BUILD_GENERAL1_SMALL + variables: + TEST_SUITE: src/__tests__/generations/generation.test.ts + CLI_REGION: us-west-2 + USE_PARENT_ACCOUNT: 1 + depend-on: + - publish_to_local_registry - identifier: single_gsi_100k_records buildspec: codebuild_specs/run_cdk_tests.yml env: compute-type: BUILD_GENERAL1_SMALL variables: TEST_SUITE: src/__tests__/deploy-velocity/single-gsi-100k-records.test.ts - CLI_REGION: ap-southeast-2 + CLI_REGION: ca-central-1 depend-on: - publish_to_local_registry - identifier: replace_2_gsis_update_attr_100k_records @@ -1771,7 +1781,7 @@ batch: variables: TEST_SUITE: >- src/__tests__/deploy-velocity/replace-2-gsis-update-attr-100k-records.test.ts - CLI_REGION: ca-central-1 + CLI_REGION: eu-central-1 depend-on: - publish_to_local_registry - identifier: replace_2_gsis_100k_records @@ -1780,7 +1790,7 @@ batch: compute-type: BUILD_GENERAL1_SMALL variables: TEST_SUITE: src/__tests__/deploy-velocity/replace-2-gsis-100k-records.test.ts - CLI_REGION: eu-central-1 + CLI_REGION: eu-north-1 depend-on: - publish_to_local_registry - identifier: 3_gsis_100k_records @@ -1789,7 +1799,7 @@ batch: compute-type: BUILD_GENERAL1_SMALL variables: TEST_SUITE: src/__tests__/deploy-velocity/3-gsis-100k-records.test.ts - CLI_REGION: eu-north-1 + CLI_REGION: eu-south-1 depend-on: - publish_to_local_registry - identifier: >- diff --git a/package.json b/package.json index e7728c96a0..843b84f51f 100755 --- a/package.json +++ b/package.json @@ -112,6 +112,7 @@ "@aws-amplify/graphql-api-construct/@aws-amplify/graphql-default-value-transformer", "@aws-amplify/graphql-api-construct/@aws-amplify/graphql-directives", "@aws-amplify/graphql-api-construct/@aws-amplify/graphql-function-transformer", + "@aws-amplify/graphql-api-construct/@aws-amplify/graphql-generation-transformer", "@aws-amplify/graphql-api-construct/@aws-amplify/graphql-http-transformer", "@aws-amplify/graphql-api-construct/@aws-amplify/graphql-index-transformer", "@aws-amplify/graphql-api-construct/@aws-amplify/graphql-maps-to-transformer", @@ -149,6 +150,7 @@ "@aws-amplify/data-construct/@aws-amplify/graphql-auth-transformer", "@aws-amplify/data-construct/@aws-amplify/graphql-default-value-transformer", "@aws-amplify/data-construct/@aws-amplify/graphql-directives", + "@aws-amplify/data-construct/@aws-amplify/graphql-generation-transformer", "@aws-amplify/data-construct/@aws-amplify/graphql-function-transformer", "@aws-amplify/data-construct/@aws-amplify/graphql-http-transformer", "@aws-amplify/data-construct/@aws-amplify/graphql-index-transformer", diff --git a/packages/amplify-data-construct/.jsii b/packages/amplify-data-construct/.jsii index e114c9fe2e..fce90ab130 100644 --- a/packages/amplify-data-construct/.jsii +++ b/packages/amplify-data-construct/.jsii @@ -12,6 +12,7 @@ "@aws-amplify/graphql-default-value-transformer": "3.0.0", "@aws-amplify/graphql-directives": "2.0.0", "@aws-amplify/graphql-function-transformer": "3.0.0", + "@aws-amplify/graphql-generation-transformer": "0.1.0", "@aws-amplify/graphql-http-transformer": "3.0.0", "@aws-amplify/graphql-index-transformer": "3.0.0", "@aws-amplify/graphql-maps-to-transformer": "4.0.0", @@ -3886,5 +3887,5 @@ }, "types": {}, "version": "1.9.6", - "fingerprint": "SdJZCjnS8qNDxzovyPRKBjmX8BKQnFxJLzNMF5SN3Zs=" + "fingerprint": "4GG4kJbY8BkoSdR5tyKQKBR+c7lqlaZWxFJJjeWILUA=" } \ No newline at end of file diff --git a/packages/amplify-data-construct/package.json b/packages/amplify-data-construct/package.json index 7012c4eab6..f1136db6f3 100644 --- a/packages/amplify-data-construct/package.json +++ b/packages/amplify-data-construct/package.json @@ -51,6 +51,7 @@ "@aws-amplify/graphql-relational-transformer", "@aws-amplify/graphql-searchable-transformer", "@aws-amplify/graphql-sql-transformer", + "@aws-amplify/graphql-generation-transformer", "@aws-amplify/platform-core", "@aws-amplify/plugin-types", "fs-extra", @@ -91,6 +92,7 @@ "@aws-amplify/graphql-transformer": "2.0.0", "@aws-amplify/graphql-transformer-core": "3.0.0", "@aws-amplify/graphql-transformer-interfaces": "4.0.0", + "@aws-amplify/graphql-generation-transformer": "0.1.0", "@aws-amplify/platform-core": "^0.2.0", "@aws-amplify/plugin-types": "^0.4.1", "charenc": "^0.0.2", diff --git a/packages/amplify-graphql-api-construct-tests/src/__tests__/generations/API.ts b/packages/amplify-graphql-api-construct-tests/src/__tests__/generations/API.ts new file mode 100644 index 0000000000..2795eea455 --- /dev/null +++ b/packages/amplify-graphql-api-construct-tests/src/__tests__/generations/API.ts @@ -0,0 +1,78 @@ +/* tslint:disable */ +/* eslint-disable */ +// This file was automatically generated and should not be edited. + +export type Recipe = { + __typename: 'Recipe'; + ingredients?: Array | null; + instructions?: string | null; + name?: string | null; +}; + +export type Todo = { + __typename: 'Todo'; + content?: string | null; + createdAt: string; + id: string; + isDone?: boolean | null; + updatedAt: string; +}; + +export type GenerateRecipeQueryVariables = { + description?: string | null; +}; + +export type GenerateRecipeQuery = { + generateRecipe?: { + __typename: 'Recipe'; + ingredients?: Array | null; + instructions?: string | null; + name?: string | null; + } | null; +}; + +export type GetTodoQueryVariables = { + id: string; +}; + +export type GetTodoQuery = { + getTodo?: { + __typename: 'Todo'; + content?: string | null; + createdAt: string; + id: string; + isDone?: boolean | null; + updatedAt: string; + } | null; +}; + +export type MakeTodoQueryVariables = { + description: string; +}; + +export type MakeTodoQuery = { + makeTodo?: { + __typename: 'Todo'; + content?: string | null; + createdAt: string; + id: string; + isDone?: boolean | null; + updatedAt: string; + } | null; +}; + +export type SummarizeQueryVariables = { + input?: string | null; +}; + +export type SummarizeQuery = { + summarize?: string | null; +}; + +export type SolveEquationQueryVariables = { + equation?: string | null; +}; + +export type SolveEquation = { + solveEquation?: number | null; +}; diff --git a/packages/amplify-graphql-api-construct-tests/src/__tests__/generations/generation.test.ts b/packages/amplify-graphql-api-construct-tests/src/__tests__/generations/generation.test.ts new file mode 100644 index 0000000000..fee2f093e6 --- /dev/null +++ b/packages/amplify-graphql-api-construct-tests/src/__tests__/generations/generation.test.ts @@ -0,0 +1,145 @@ +/* eslint-disable import/namespace */ +import * as path from 'path'; +import * as fs from 'fs-extra'; +import { DURATION_20_MINUTES } from '../../utils/duration-constants'; +import { createNewProjectDir, deleteProjectDir } from 'amplify-category-api-e2e-core'; +import { cdkDeploy, cdkDestroy, initCDKProject } from '../../commands'; +import { DDB_AMPLIFY_MANAGED_DATASOURCE_STRATEGY } from '@aws-amplify/graphql-transformer-core'; +import { doAppSyncGraphqlQuery, TestDefinition, writeStackConfig, writeTestDefinitions } from '../../utils'; +import { generateRecipe, makeTodo, solveEquation, summarize } from './graphql/queries'; + +jest.setTimeout(DURATION_20_MINUTES); + +describe('generation', () => { + const baseProjFolderName = path.basename(__filename, '.test.ts'); + + describe('Generation Model', () => { + const projFolderName = `${baseProjFolderName}-model`; + let apiEndpoint: string; + let apiKey: string; + let projRoot: string; + + beforeAll(async () => { + projRoot = await createNewProjectDir(projFolderName); + const templatePath = path.resolve(path.join(__dirname, '..', 'backends', 'configurable-stack')); + const name = await initCDKProject(projRoot, templatePath); + + const generationSchemaPath = path.resolve(path.join(__dirname, 'graphql', 'schema-generation.graphql')); + const generationSchema = fs.readFileSync(generationSchemaPath).toString(); + + const testDefinitions: Record = { + generation: { + schema: [generationSchema].join('\n'), + strategy: DDB_AMPLIFY_MANAGED_DATASOURCE_STRATEGY, + }, + }; + + writeStackConfig(projRoot, { prefix: 'Gen', useSandbox: true }); + writeTestDefinitions(testDefinitions, projRoot); + + const outputs = await cdkDeploy(projRoot, '--all'); + apiEndpoint = outputs[name].awsAppsyncApiEndpoint; + apiKey = outputs[name].awsAppsyncApiKey; + }); + + afterAll(async () => { + try { + await cdkDestroy(projRoot, '--all'); + } catch (err) { + console.log(`Error invoking 'cdk destroy': ${err}`); + } + deleteProjectDir(projRoot); + }); + + describe('Generation type', () => { + test('should generate a type', async () => { + const args = { + apiEndpoint, + auth: { apiKey }, + }; + + const variables = { + description: `I'd like to make a gluten-free chocolate cake.`, + }; + + const generateRecipeResult = await doAppSyncGraphqlQuery({ ...args, query: generateRecipe, variables }); + const recipe = generateRecipeResult.body.data.generateRecipe; + expect(recipe.name).toBeDefined(); + }); + }); + + describe('Generation scalar', () => { + test('should generate a string scalar type', async () => { + const args = { + apiEndpoint, + auth: { apiKey }, + }; + + const variables = { + input: ` + Two roads diverged in a yellow wood, + And sorry I could not travel both + And be one traveler, long I stood + And looked down one as far as I could + To where it bent in the undergrowth; + + Then took the other, as just as fair, + And having perhaps the better claim, + Because it was grassy and wanted wear; + Though as for that the passing there + Had worn them really about the same, + + And both that morning equally lay + In leaves no step had trodden black. + Oh, I kept the first for another day! + Yet knowing how way leads on to way, + I doubted if I should ever come back. + + I shall be telling this with a sigh + Somewhere ages and ages hence: + Two roads diverged in a wood, and I— + I took the one less traveled by, + And that has made all the difference.`, + }; + const summarizeResult = await doAppSyncGraphqlQuery({ ...args, query: summarize, variables }); + const summary = summarizeResult.body.data.summarize; + expect(summary).toBeDefined(); + }); + }); + + test('should generate an int scalar type', async () => { + const args = { + apiEndpoint, + auth: { apiKey }, + }; + + const variables = { + equation: ` + There is a three-digit number. The second digit is four times as big as the third digit, + while the first digit is three less than the second digit. What is the number? + `, + }; + const solveEquationResult = await doAppSyncGraphqlQuery({ ...args, query: solveEquation, variables }); + const solution = solveEquationResult.body.data.solveEquation; + expect(solution).toBeDefined(); + }); + + describe('Generation model', () => { + // TODO: This currently doesn't work because LLMs are not great at following regex pattern requirements, they'll sometimes return "" + // which fails GraphQL type validation for implicitly generated required model values like id, createdAt, updatedAt. + xtest('should generate a model', async () => { + const args = { + apiEndpoint, + auth: { apiKey }, + }; + const variables = { + description: + 'I have to pick up the kids from school. One goes to soccer practice at 3:30pm and the other to swim practice at 4:30pm.', + }; + const makeTodoResult = await doAppSyncGraphqlQuery({ ...args, query: makeTodo, variables }); + const todo = makeTodoResult.body.data.makeTodo; + expect(todo.content).toBeDefined(); + }); + }); + }); +}); diff --git a/packages/amplify-graphql-api-construct-tests/src/__tests__/generations/graphql/queries.ts b/packages/amplify-graphql-api-construct-tests/src/__tests__/generations/graphql/queries.ts new file mode 100644 index 0000000000..39ff593a2f --- /dev/null +++ b/packages/amplify-graphql-api-construct-tests/src/__tests__/generations/graphql/queries.ts @@ -0,0 +1,50 @@ +/* tslint:disable */ +/* eslint-disable */ +// this is an auto generated file. This will be overwritten + +import * as APITypes from '../API'; +type GeneratedQuery = string & { + __generatedQueryInput: InputType; + __generatedQueryOutput: OutputType; +}; + +export const generateRecipe = /* GraphQL */ `query GenerateRecipe($description: String) { + generateRecipe(description: $description) { + ingredients + instructions + name + __typename + } +} +` as GeneratedQuery; +export const getTodo = /* GraphQL */ `query GetTodo($id: ID!) { + getTodo(id: $id) { + content + createdAt + id + isDone + updatedAt + __typename + } +} +` as GeneratedQuery; +export const makeTodo = /* GraphQL */ `query MakeTodo($description: String!) { + makeTodo(description: $description) { + content + createdAt + id + isDone + updatedAt + __typename + } +} +` as GeneratedQuery; +export const summarize = /* GraphQL */ `query Summarize($input: String) { + summarize(input: $input) +} +` as GeneratedQuery; + +export const solveEquation = /* GraphQL */ `query SolveEquation($equation: String) { + solveEquation(equation: $equation) +} +` as GeneratedQuery; diff --git a/packages/amplify-graphql-api-construct-tests/src/__tests__/generations/graphql/schema-generation.graphql b/packages/amplify-graphql-api-construct-tests/src/__tests__/generations/graphql/schema-generation.graphql new file mode 100644 index 0000000000..695a22f1aa --- /dev/null +++ b/packages/amplify-graphql-api-construct-tests/src/__tests__/generations/graphql/schema-generation.graphql @@ -0,0 +1,23 @@ +type Todo @model(queries: { list: null }, mutations: null, subscriptions: null) { + content: String + isDone: Boolean +} + +type Recipe { + name: String + ingredients: [String] + instructions: String +} + +type Query { + makeTodo(description: String!): Todo + @generation(aiModel: "anthropic.claude-3-haiku-20240307-v1:0", systemPrompt: "Make a list of todo items based on the description.") + + summarize(input: String): String @generation(aiModel: "anthropic.claude-3-haiku-20240307-v1:0", systemPrompt: "summarize the input.") + + generateRecipe(description: String): Recipe + @generation(aiModel: "anthropic.claude-3-haiku-20240307-v1:0", systemPrompt: "You are a 3 star michelin chef that generates recipes.") + + solveEquation(equation: String): Int + @generation(aiModel: "anthropic.claude-3-haiku-20240307-v1:0", systemPrompt: "Solve the equation and return the result.") +} diff --git a/packages/amplify-graphql-api-construct/.jsii b/packages/amplify-graphql-api-construct/.jsii index 1db991c7e6..eb32f02e8a 100644 --- a/packages/amplify-graphql-api-construct/.jsii +++ b/packages/amplify-graphql-api-construct/.jsii @@ -12,6 +12,7 @@ "@aws-amplify/graphql-default-value-transformer": "3.0.0", "@aws-amplify/graphql-directives": "2.0.0", "@aws-amplify/graphql-function-transformer": "3.0.0", + "@aws-amplify/graphql-generation-transformer": "0.1.0", "@aws-amplify/graphql-http-transformer": "3.0.0", "@aws-amplify/graphql-index-transformer": "3.0.0", "@aws-amplify/graphql-maps-to-transformer": "4.0.0", @@ -8803,5 +8804,5 @@ } }, "version": "1.11.6", - "fingerprint": "jJBbP9Aj3KaF2GpI7kUkS7tzKs7wX3M5l9q2wtnSSjs=" + "fingerprint": "y1xQmViCyD5+ag1HJuKzy9fnaVPN3Wtfz+iDk01/mwQ=" } \ No newline at end of file diff --git a/packages/amplify-graphql-api-construct/package.json b/packages/amplify-graphql-api-construct/package.json index 39db8636ee..d016eb2c7e 100644 --- a/packages/amplify-graphql-api-construct/package.json +++ b/packages/amplify-graphql-api-construct/package.json @@ -40,6 +40,7 @@ "@aws-amplify/graphql-default-value-transformer", "@aws-amplify/graphql-directives", "@aws-amplify/graphql-function-transformer", + "@aws-amplify/graphql-generation-transformer", "@aws-amplify/graphql-http-transformer", "@aws-amplify/graphql-index-transformer", "@aws-amplify/graphql-maps-to-transformer", @@ -80,6 +81,7 @@ "@aws-amplify/graphql-default-value-transformer": "3.0.0", "@aws-amplify/graphql-directives": "2.0.0", "@aws-amplify/graphql-function-transformer": "3.0.0", + "@aws-amplify/graphql-generation-transformer": "0.1.0", "@aws-amplify/graphql-http-transformer": "3.0.0", "@aws-amplify/graphql-index-transformer": "3.0.0", "@aws-amplify/graphql-maps-to-transformer": "4.0.0", diff --git a/packages/amplify-graphql-generation-transformer/.npmignore b/packages/amplify-graphql-generation-transformer/.npmignore new file mode 100644 index 0000000000..3ee5d55b0b --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/.npmignore @@ -0,0 +1,5 @@ +**/__mocks__/** +**/__tests__/** +src +tsconfig.json +tsconfig.tsbuildinfo diff --git a/packages/amplify-graphql-generation-transformer/API.md b/packages/amplify-graphql-generation-transformer/API.md new file mode 100644 index 0000000000..30ba24ddc6 --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/API.md @@ -0,0 +1,31 @@ +## API Report File for "@aws-amplify/graphql-generation-transformer" + +> Do not edit this file. It is a report generated by [API Extractor](https://api-extractor.com/). + +```ts + +import { DirectiveNode } from 'graphql'; +import { FieldDefinitionNode } from 'graphql'; +import { InterfaceTypeDefinitionNode } from 'graphql'; +import { ObjectTypeDefinitionNode } from 'graphql'; +import { TransformerContextProvider } from '@aws-amplify/graphql-transformer-interfaces'; +import { TransformerPluginBase } from '@aws-amplify/graphql-transformer-core'; +import { TransformerSchemaVisitStepContextProvider } from '@aws-amplify/graphql-transformer-interfaces'; + +// @public (undocumented) +export class GenerationTransformer extends TransformerPluginBase { + constructor(); + // (undocumented) + field: ( + parent: ObjectTypeDefinitionNode | InterfaceTypeDefinitionNode, + definition: FieldDefinitionNode, + directive: DirectiveNode, + context: TransformerSchemaVisitStepContextProvider, + ) => void; + // (undocumented) + generateResolvers: (ctx: TransformerContextProvider) => void; +} + +// (No @packageDocumentation comment for this package) + +``` diff --git a/packages/amplify-graphql-generation-transformer/README.md b/packages/amplify-graphql-generation-transformer/README.md new file mode 100644 index 0000000000..0dd6aaa9a2 --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/README.md @@ -0,0 +1,85 @@ +# Amplify GraphQL Generation Transformer + +The Amplify GraphQL Generation Transformer is a tool that enables the quick and easy creation of AI-powered Generation routes within your AWS AppSync API. This transformer can be leveraged by using the `@generation` directive to configure AI models and system prompts for generating content. + +## Installation + +```bash +npm install @aws-amplify/graphql-generation-transformer +``` + +## Directive Definition + +The `@generation` directive is defined as follows: + +```graphql +directive @generation(aiModel: String!, systemPrompt: String!, inferenceConfiguration: GenerationInferenceConfiguration) on FIELD_DEFINITION +``` + +## Features + +1. AI Model Integration: Specify the AI model to be used for generation. +2. System Prompt Configuration: Define a system prompt to guide the AI's output. +3. Inference Configuration: Fine-tune generation parameters like max tokens, temperature, and top-p. +4. Integrates with `@auth` Directive: Supports existing auth modes like IAM, API key, and Amazon Cognito User Pools. +5. Resolver Creation: Generates resolvers with tool definitions based on the Query field's return type to interact with the specified AI model. +6. Bedrock HTTP Data Source Creation: Creates a AppSync HTTP Data Source for Bedrock to interact with the specified AI model. + +## Examples + +### Basic Usage + +#### Scalar Type Generation + +```graphql +type Query { + generateStory(topic: String!): String + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0" + systemPrompt: "You are a creative storyteller. Generate a short story based on the given topic." + ) +} +``` + +#### Complex Type Generation + +```graphql +type Recipe { + name: String! + ingredients: [String!]! + instructions: [String!]! + prepTime: Int! + cookTime: Int! + servings: Int! + difficulty: String! +} + +type Query { + generateRecipe(cuisine: String!, dietaryRestrictions: [String]): Recipe + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0" + systemPrompt: "You are a professional chef specializing in creating recipes. Generate a detailed recipe based on the given cuisine and dietary restrictions." + ) +} +``` + +### Advanced Configuration + +```graphql +type Query { + generateCode(description: String!): String + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0" + systemPrompt: "You are an expert programmer. Generate code based on the given description." + inferenceConfiguration: { maxTokens: 500, temperature: 0.7, topP: 0.9 } + ) +} +``` + +## Limitations + +- The `@generation` directive can only be used on Query fields. +- The AI model specified must: + - be supported by Amazon Bedrock's /converse API + - support tool usage +- Some AppSync scalar types are not currently supported. diff --git a/packages/amplify-graphql-generation-transformer/package.json b/packages/amplify-graphql-generation-transformer/package.json new file mode 100644 index 0000000000..7933f52a07 --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/package.json @@ -0,0 +1,78 @@ +{ + "name": "@aws-amplify/graphql-generation-transformer", + "version": "0.1.0", + "description": "Amplify GraphQL @generation transformer", + "repository": { + "type": "git", + "url": "https://github.com/aws-amplify/amplify-category-api.git", + "directory": "packages/amplify-graphql-generation-transformer" + }, + "author": "Amazon Web Services", + "license": "Apache-2.0", + "main": "lib/index.js", + "types": "lib/index.d.ts", + "keywords": [], + "publishConfig": { + "access": "public" + }, + "scripts": { + "build": "tsc", + "watch": "tsc -w", + "clean": "rimraf ./lib", + "test": "jest", + "extract-api": "ts-node ../../scripts/extract-api.ts" + }, + "dependencies": { + "@aws-amplify/graphql-directives": "2.0.0", + "@aws-amplify/graphql-transformer-core": "3.0.0", + "@aws-amplify/graphql-transformer-interfaces": "4.0.0", + "graphql": "^15.5.0", + "graphql-mapping-template": "5.0.0", + "graphql-transformer-common": "5.0.0", + "immer": "^9.0.12" + }, + "devDependencies": { + "@aws-amplify/graphql-transformer-test-utils": "0.6.0" + }, + "peerDependencies": { + "aws-cdk-lib": "^2.152.0", + "constructs": "^10.3.0" + }, + "jest": { + "transform": { + "^.+\\.(ts|tsx)?$": "ts-jest" + }, + "testRegex": "(src/__tests__/.*.test.ts)$", + "moduleFileExtensions": [ + "ts", + "tsx", + "js", + "jsx", + "json", + "node" + ], + "collectCoverage": true, + "coverageProvider": "v8", + "coverageThreshold": { + "global": { + "branches": 90, + "functions": 100, + "lines": 99 + } + }, + "coverageReporters": [ + "clover", + "text" + ], + "modulePathIgnorePatterns": [ + "overrides" + ], + "testEnvironment": "../../FixJestEnvironment.js", + "collectCoverageFrom": [ + "src/**/*.ts" + ], + "coveragePathIgnorePatterns": [ + "/__tests__/" + ] + } +} diff --git a/packages/amplify-graphql-generation-transformer/src/__tests__/__snapshots__/amplify-graphql-generation-transformer.test.ts.snap b/packages/amplify-graphql-generation-transformer/src/__tests__/__snapshots__/amplify-graphql-generation-transformer.test.ts.snap new file mode 100644 index 0000000000..277e4297d9 --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/src/__tests__/__snapshots__/amplify-graphql-generation-transformer.test.ts.snap @@ -0,0 +1,401 @@ +// Jest Snapshot v1, https://goo.gl/fbAQLP + +exports[`generation route all scalar types 1`] = ` +Object { + "Fn::Join": Array [ + "", + Array [ + "export const request = (ctx) => { + ctx.stash.typeName = \\"Query\\"; + ctx.stash.fieldName = \\"makeBox\\"; + ctx.stash.conditions = []; + ctx.stash.metadata = {}; + ctx.stash.metadata.dataSourceType = \\"HTTP\\"; + ctx.stash.metadata.apiId = \\"", + Object { + "Fn::GetAtt": Array [ + "GraphQLAPI", + "ApiId", + ], + }, + "\\"; + ctx.stash.connectionAttributes = {}; + ctx.stash.endpoint = \\"https://bedrock-runtime.", + Object { + "Ref": "AWS::Region", + }, + ".amazonaws.com\\"; + ctx.stash.authRole = \\"arn:aws:sts::", + Object { + "Ref": "AWS::AccountId", + }, + ":assumed-role/", + Object { + "Ref": "authRoleName", + }, + "/CognitoIdentityCredentials\\"; + ctx.stash.unauthRole = \\"arn:aws:sts::", + Object { + "Ref": "AWS::AccountId", + }, + ":assumed-role/", + Object { + "Ref": "unauthRoleName", + }, + "/CognitoIdentityCredentials\\"; + ctx.stash.adminRoles = []; + return {}; +} + +export const response = (ctx) => { + return ctx.prev.result; +};", + ], + ], +} +`; + +exports[`generation route all scalar types 2`] = ` +"export function request(ctx) { + const toolConfig = {\\"tools\\":[{\\"toolSpec\\":{\\"name\\":\\"responseType\\",\\"description\\":\\"Generate a response type for the given field\\",\\"inputSchema\\":{\\"json\\":{\\"type\\":\\"object\\",\\"properties\\":{\\"value\\":{\\"type\\":\\"object\\",\\"properties\\":{\\"int\\":{\\"type\\":\\"number\\",\\"description\\":\\"A signed 32-bit integer value.\\"},\\"float\\":{\\"type\\":\\"number\\",\\"description\\":\\"An IEEE 754 floating point value.\\"},\\"string\\":{\\"type\\":\\"string\\",\\"description\\":\\"A UTF-8 character sequence.\\"},\\"id\\":{\\"type\\":\\"string\\",\\"description\\":\\"A unique identifier for an object. This scalar is serialized like a String but isn't meant to be human-readable.\\"},\\"boolean\\":{\\"type\\":\\"boolean\\",\\"description\\":\\"A boolean value.\\"},\\"awsjson\\":{\\"type\\":\\"string\\",\\"description\\":\\"A JSON string. Any valid JSON construct is automatically parsed and loaded in the resolver code as maps, lists, or scalar values rather than as the literal input strings. Unquoted strings or otherwise invalid JSON result in a GraphQL validation error.\\"},\\"awsemail\\":{\\"type\\":\\"string\\",\\"description\\":\\"An email address in the format local-part@domain-part as defined by RFC 822.\\",\\"pattern\\":\\"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\\\\\\\.[a-zA-Z]{2,}$\\"},\\"awsdate\\":{\\"type\\":\\"string\\",\\"description\\":\\"An extended ISO 8601 date string in the format YYYY-MM-DD.\\",\\"pattern\\":\\"^\\\\\\\\d{4}-d{2}-d{2}$\\"},\\"awstime\\":{\\"type\\":\\"string\\",\\"description\\":\\"An extended ISO 8601 time string in the format hh:mm:ss.sss.\\",\\"pattern\\":\\"^\\\\\\\\d{2}:\\\\\\\\d{2}:\\\\\\\\d{2}\\\\\\\\.\\\\\\\\d{3}$\\"},\\"awsdatetime\\":{\\"type\\":\\"string\\",\\"description\\":\\"An extended ISO 8601 date and time string in the format YYYY-MM-DDThh:mm:ss.sssZ.\\",\\"pattern\\":\\"^\\\\\\\\d{4}-\\\\\\\\d{2}-\\\\\\\\d{2}T\\\\\\\\d{2}:\\\\\\\\d{2}:\\\\\\\\d{2}\\\\\\\\.\\\\\\\\d{3}Z$\\"},\\"awstimestamp\\":{\\"type\\":\\"string\\",\\"description\\":\\"An integer value representing the number of seconds before or after 1970-01-01-T00:00Z.\\",\\"pattern\\":\\"^\\\\\\\\d+$\\"},\\"awsphone\\":{\\"type\\":\\"string\\",\\"description\\":\\"A phone number. This value is stored as a string. Phone numbers can contain either spaces or hyphens to separate digit groups. Phone numbers without a country code are assumed to be US/North American numbers adhering to the North American Numbering Plan (NANP).\\",\\"pattern\\":\\"^\\\\\\\\d{3}-d{3}-d{4}$\\"},\\"awsurl\\":{\\"type\\":\\"string\\",\\"description\\":\\"A URL as defined by RFC 1738. For example, https://www.amazon.com/dp/B000NZW3KC/ or mailto:example@example.com. URLs must contain a schema (http, mailto) and can't contain two forward slashes (//) in the path part.\\",\\"pattern\\":\\"^(https?|mailto)://[^s/$.?#].[^s]*$\\"},\\"awsipaddress\\":{\\"type\\":\\"string\\",\\"description\\":\\"A valid IPv4 or IPv6 address. IPv4 addresses are expected in quad-dotted notation (123.12.34.56). IPv6 addresses are expected in non-bracketed, colon-separated format (1a2b:3c4b::1234:4567). You can include an optional CIDR suffix (123.45.67.89/16) to indicate subnet mask.\\"}},\\"required\\":[]}},\\"required\\":[\\"value\\"]}}}}],\\"toolChoice\\":{\\"tool\\":{\\"name\\":\\"responseType\\"}}}; + const prompt = \\"\\"; + const args = JSON.stringify(ctx.args); + + return { + resourcePath: '/model/anthropic.claude-3-haiku-20240307-v1:0/converse', + method: 'POST', + params: { + headers: { 'Content-Type': 'application/json' }, + body: { + messages: [{ + role: 'user', + content: [{ text: args }], + }], + system: [{ text: prompt }], + toolConfig, + // default inference config + } + } + } +} + +export function response(ctx) { + if (ctx.error) { + util.error(ctx.error.message, ctx.error.type); + } + const body = JSON.parse(ctx.result.body); + const { content } = body.output.message; + + if (content.length < 1) { + util.error('No content block in assistant response.', 'error'); + } + + const toolUse = content[0].toolUse; + if (!toolUse) { + util.error('Missing tool use block in assistant response.', 'error'); + } + + const response = toolUse.input.value; + return response; +}" +`; + +exports[`generation route custom query 1`] = ` +Object { + "Fn::Join": Array [ + "", + Array [ + "export const request = (ctx) => { + ctx.stash.typeName = \\"Query\\"; + ctx.stash.fieldName = \\"generateRecipe\\"; + ctx.stash.conditions = []; + ctx.stash.metadata = {}; + ctx.stash.metadata.dataSourceType = \\"HTTP\\"; + ctx.stash.metadata.apiId = \\"", + Object { + "Fn::GetAtt": Array [ + "GraphQLAPI", + "ApiId", + ], + }, + "\\"; + ctx.stash.connectionAttributes = {}; + ctx.stash.endpoint = \\"https://bedrock-runtime.", + Object { + "Ref": "AWS::Region", + }, + ".amazonaws.com\\"; + ctx.stash.authRole = \\"arn:aws:sts::", + Object { + "Ref": "AWS::AccountId", + }, + ":assumed-role/", + Object { + "Ref": "authRoleName", + }, + "/CognitoIdentityCredentials\\"; + ctx.stash.unauthRole = \\"arn:aws:sts::", + Object { + "Ref": "AWS::AccountId", + }, + ":assumed-role/", + Object { + "Ref": "unauthRoleName", + }, + "/CognitoIdentityCredentials\\"; + ctx.stash.adminRoles = []; + return {}; +} + +export const response = (ctx) => { + return ctx.prev.result; +};", + ], + ], +} +`; + +exports[`generation route custom query 2`] = ` +"export function request(ctx) { + const toolConfig = {\\"tools\\":[{\\"toolSpec\\":{\\"name\\":\\"responseType\\",\\"description\\":\\"Generate a response type for the given field\\",\\"inputSchema\\":{\\"json\\":{\\"type\\":\\"object\\",\\"properties\\":{\\"value\\":{\\"type\\":\\"object\\",\\"properties\\":{\\"name\\":{\\"type\\":\\"string\\",\\"description\\":\\"A UTF-8 character sequence.\\"},\\"ingredients\\":{\\"type\\":\\"array\\",\\"items\\":{\\"type\\":\\"string\\",\\"description\\":\\"A UTF-8 character sequence.\\"}},\\"instructions\\":{\\"type\\":\\"string\\",\\"description\\":\\"A UTF-8 character sequence.\\"},\\"meal\\":{\\"type\\":\\"object\\",\\"properties\\":{\\"Meal\\":{\\"type\\":\\"string\\",\\"enum\\":[\\"BREAKFAST\\",\\"LUNCH\\",\\"DINNER\\"]}},\\"required\\":[]}},\\"required\\":[]}},\\"required\\":[\\"value\\"]}}}}],\\"toolChoice\\":{\\"tool\\":{\\"name\\":\\"responseType\\"}}}; + const prompt = \\"You are a helpful assistant that generates recipes.\\"; + const args = JSON.stringify(ctx.args); + + return { + resourcePath: '/model/anthropic.claude-3-haiku-20240307-v1:0/converse', + method: 'POST', + params: { + headers: { 'Content-Type': 'application/json' }, + body: { + messages: [{ + role: 'user', + content: [{ text: args }], + }], + system: [{ text: prompt }], + toolConfig, + // default inference config + } + } + } +} + +export function response(ctx) { + if (ctx.error) { + util.error(ctx.error.message, ctx.error.type); + } + const body = JSON.parse(ctx.result.body); + const { content } = body.output.message; + + if (content.length < 1) { + util.error('No content block in assistant response.', 'error'); + } + + const toolUse = content[0].toolUse; + if (!toolUse) { + util.error('Missing tool use block in assistant response.', 'error'); + } + + const response = toolUse.input.value; + return response; +}" +`; + +exports[`generation route model type with null timestamps 1`] = ` +Object { + "Fn::Join": Array [ + "", + Array [ + "export const request = (ctx) => { + ctx.stash.typeName = \\"Query\\"; + ctx.stash.fieldName = \\"makeTodo\\"; + ctx.stash.conditions = []; + ctx.stash.metadata = {}; + ctx.stash.metadata.dataSourceType = \\"HTTP\\"; + ctx.stash.metadata.apiId = \\"", + Object { + "Fn::GetAtt": Array [ + "GraphQLAPI", + "ApiId", + ], + }, + "\\"; + ctx.stash.connectionAttributes = {}; + ctx.stash.endpoint = \\"https://bedrock-runtime.", + Object { + "Ref": "AWS::Region", + }, + ".amazonaws.com\\"; + ctx.stash.authRole = \\"arn:aws:sts::", + Object { + "Ref": "AWS::AccountId", + }, + ":assumed-role/", + Object { + "Ref": "authRoleName", + }, + "/CognitoIdentityCredentials\\"; + ctx.stash.unauthRole = \\"arn:aws:sts::", + Object { + "Ref": "AWS::AccountId", + }, + ":assumed-role/", + Object { + "Ref": "unauthRoleName", + }, + "/CognitoIdentityCredentials\\"; + ctx.stash.adminRoles = []; + return {}; +} + +export const response = (ctx) => { + return ctx.prev.result; +};", + ], + ], +} +`; + +exports[`generation route model type with null timestamps 2`] = ` +"export function request(ctx) { + const toolConfig = {\\"tools\\":[{\\"toolSpec\\":{\\"name\\":\\"responseType\\",\\"description\\":\\"Generate a response type for the given field\\",\\"inputSchema\\":{\\"json\\":{\\"type\\":\\"object\\",\\"properties\\":{\\"value\\":{\\"type\\":\\"object\\",\\"properties\\":{\\"content\\":{\\"type\\":\\"string\\",\\"description\\":\\"A UTF-8 character sequence.\\"},\\"isDone\\":{\\"type\\":\\"boolean\\",\\"description\\":\\"A boolean value.\\"},\\"id\\":{\\"type\\":\\"string\\",\\"description\\":\\"A unique identifier for an object. This scalar is serialized like a String but isn't meant to be human-readable.\\"}},\\"required\\":[\\"id\\"]}},\\"required\\":[\\"value\\"]}}}}],\\"toolChoice\\":{\\"tool\\":{\\"name\\":\\"responseType\\"}}}; + const prompt = \\"Make a string based on the description.\\"; + const args = JSON.stringify(ctx.args); + + return { + resourcePath: '/model/anthropic.claude-3-haiku-20240307-v1:0/converse', + method: 'POST', + params: { + headers: { 'Content-Type': 'application/json' }, + body: { + messages: [{ + role: 'user', + content: [{ text: args }], + }], + system: [{ text: prompt }], + toolConfig, + // default inference config + } + } + } +} + +export function response(ctx) { + if (ctx.error) { + util.error(ctx.error.message, ctx.error.type); + } + const body = JSON.parse(ctx.result.body); + const { content } = body.output.message; + + if (content.length < 1) { + util.error('No content block in assistant response.', 'error'); + } + + const toolUse = content[0].toolUse; + if (!toolUse) { + util.error('Missing tool use block in assistant response.', 'error'); + } + + const response = toolUse.input.value; + return response; +}" +`; + +exports[`generation route scalar type 1`] = ` +Object { + "Fn::Join": Array [ + "", + Array [ + "export const request = (ctx) => { + ctx.stash.typeName = \\"Query\\"; + ctx.stash.fieldName = \\"makeTodo\\"; + ctx.stash.conditions = []; + ctx.stash.metadata = {}; + ctx.stash.metadata.dataSourceType = \\"HTTP\\"; + ctx.stash.metadata.apiId = \\"", + Object { + "Fn::GetAtt": Array [ + "GraphQLAPI", + "ApiId", + ], + }, + "\\"; + ctx.stash.connectionAttributes = {}; + ctx.stash.endpoint = \\"https://bedrock-runtime.", + Object { + "Ref": "AWS::Region", + }, + ".amazonaws.com\\"; + ctx.stash.authRole = \\"arn:aws:sts::", + Object { + "Ref": "AWS::AccountId", + }, + ":assumed-role/", + Object { + "Ref": "authRoleName", + }, + "/CognitoIdentityCredentials\\"; + ctx.stash.unauthRole = \\"arn:aws:sts::", + Object { + "Ref": "AWS::AccountId", + }, + ":assumed-role/", + Object { + "Ref": "unauthRoleName", + }, + "/CognitoIdentityCredentials\\"; + ctx.stash.adminRoles = []; + return {}; +} + +export const response = (ctx) => { + return ctx.prev.result; +};", + ], + ], +} +`; + +exports[`generation route scalar type 2`] = ` +"export function request(ctx) { + const toolConfig = {\\"tools\\":[{\\"toolSpec\\":{\\"name\\":\\"responseType\\",\\"description\\":\\"Generate a response type for the given field\\",\\"inputSchema\\":{\\"json\\":{\\"type\\":\\"object\\",\\"properties\\":{\\"value\\":{\\"type\\":\\"string\\",\\"description\\":\\"A UTF-8 character sequence.\\"}},\\"required\\":[\\"value\\"]}}}}],\\"toolChoice\\":{\\"tool\\":{\\"name\\":\\"responseType\\"}}}; + const prompt = \\"Make a string based on the description.\\"; + const args = JSON.stringify(ctx.args); + + return { + resourcePath: '/model/anthropic.claude-3-haiku-20240307-v1:0/converse', + method: 'POST', + params: { + headers: { 'Content-Type': 'application/json' }, + body: { + messages: [{ + role: 'user', + content: [{ text: args }], + }], + system: [{ text: prompt }], + toolConfig, + // default inference config + } + } + } +} + +export function response(ctx) { + if (ctx.error) { + util.error(ctx.error.message, ctx.error.type); + } + const body = JSON.parse(ctx.result.body); + const { content } = body.output.message; + + if (content.length < 1) { + util.error('No content block in assistant response.', 'error'); + } + + const toolUse = content[0].toolUse; + if (!toolUse) { + util.error('Missing tool use block in assistant response.', 'error'); + } + + const response = toolUse.input.value; + return response; +}" +`; diff --git a/packages/amplify-graphql-generation-transformer/src/__tests__/amplify-graphql-generation-transformer.test.ts b/packages/amplify-graphql-generation-transformer/src/__tests__/amplify-graphql-generation-transformer.test.ts new file mode 100644 index 0000000000..14b33f7940 --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/src/__tests__/amplify-graphql-generation-transformer.test.ts @@ -0,0 +1,351 @@ +import { AuthTransformer } from '@aws-amplify/graphql-auth-transformer'; +import { IndexTransformer, PrimaryKeyTransformer } from '@aws-amplify/graphql-index-transformer'; +import { ModelTransformer } from '@aws-amplify/graphql-model-transformer'; +import { DDB_AMPLIFY_MANAGED_DATASOURCE_STRATEGY, validateModelSchema } from '@aws-amplify/graphql-transformer-core'; +import { AppSyncAuthConfiguration, ModelDataSourceStrategy } from '@aws-amplify/graphql-transformer-interfaces'; +import { DeploymentResources, testTransform } from '@aws-amplify/graphql-transformer-test-utils'; +import { parse } from 'graphql'; +import { GenerationTransformer } from '..'; +import { BelongsToTransformer, HasManyTransformer, HasOneTransformer } from '@aws-amplify/graphql-relational-transformer'; + +const todoModel = ` +type Todo @model { + content: String + isDone: Boolean +}`; + +test('generation route model list response type', () => { + const queryName = 'makeTodos'; + const inputSchema = ` + ${todoModel} + + type Query { + ${queryName}(description: String!): [Todo] + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0", + systemPrompt: "Make a list of todo items based on the description." + ) + @auth(rules: [{ allow: public, provider: iam }]) + } + `; + + // Models are not currently supported for generation routes. + // This test can fail on `createdAt` or `updatedAt` fields. Hence the generalized error message assertion. + expect(() => transform(inputSchema)).toThrow(/Disallowed required field type/); +}); + +test('generation route scalar type', () => { + const queryName = 'makeTodo'; + const inputSchema = ` + type Query { + ${queryName}(description: String!): String + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0", + systemPrompt: "Make a string based on the description.", + ) + @auth(rules: [{ allow: public, provider: iam }]) + } + `; + const out = transform(inputSchema); + expect(out).toBeDefined(); + + const resolverCode = getResolverResource(queryName, out.rootStack.Resources)['Properties']['Code']; + expect(resolverCode).toBeDefined(); + expect(resolverCode).toMatchSnapshot(); + + const resolverFnCode = getResolverFnResource(queryName, out.rootStack.Resources)['Properties']['Code']; + expect(resolverFnCode).toBeDefined(); + expect(resolverFnCode).toMatchSnapshot(); + + const schema = parse(out.schema); + validateModelSchema(schema); +}); + +test('generation route custom query', () => { + const queryName = 'generateRecipe'; + const inputSchema = ` + type Recipe { + name: String + ingredients: [String] + instructions: String + meal: Meal + } + + enum Meal { + BREAKFAST + LUNCH + DINNER + } + + type Query { + ${queryName}(description: String!): Recipe + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0", + systemPrompt: "You are a helpful assistant that generates recipes.", + ) + } + `; + const out = transform(inputSchema); + expect(out).toBeDefined(); + + const resolverCode = getResolverResource(queryName, out.rootStack.Resources)['Properties']['Code']; + expect(resolverCode).toBeDefined(); + expect(resolverCode).toMatchSnapshot(); + + const resolverFnCode = getResolverFnResource(queryName, out.rootStack.Resources)['Properties']['Code']; + expect(resolverFnCode).toBeDefined(); + expect(resolverFnCode).toMatchSnapshot(); + + const schema = parse(out.schema); + validateModelSchema(schema); +}); + +test('generation route model type with null timestamps', () => { + const queryName = 'makeTodo'; + const inputSchema = ` + type Todo @model(timestamps: {createdAt: null, updatedAt: null}) { + content: String + isDone: Boolean + } + + type Query { + ${queryName}(description: String!): Todo + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0", + systemPrompt: "Make a string based on the description.", + ) + } + `; + const out = transform(inputSchema, { + Todo: DDB_AMPLIFY_MANAGED_DATASOURCE_STRATEGY, + }); + + const resolverCode = getResolverResource(queryName, out.rootStack.Resources)['Properties']['Code']; + expect(resolverCode).toBeDefined(); + expect(resolverCode).toMatchSnapshot(); + + const resolverFnCode = getResolverFnResource(queryName, out.rootStack.Resources)['Properties']['Code']; + expect(resolverFnCode).toBeDefined(); + expect(resolverFnCode).toMatchSnapshot(); + + const schema = parse(out.schema); + validateModelSchema(schema); +}); + +test('generation route required model type required field', () => { + const queryName = 'makeTodo'; + const inputSchema = ` + ${todoModel} + + type Query { + ${queryName}(description: String!): Todo! + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0", + systemPrompt: "Make a string based on the description.", + ) + } + `; + + // Models are not currently supported for generation routes. + // This test can fail on `createdAt` or `updatedAt` fields. Hence the generalized error message assertion. + expect(() => transform(inputSchema)).toThrow(/Disallowed required field type/); +}); + +test('generation route invalid field type in response type', () => { + const inputSchema = ` + union Foo = Bar | Baz + type Bar { + value: String + } + + type Baz { + value: Int + } + + type Query { + makeFoo(description: String!): Foo + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0", + systemPrompt: "", + ) + } + `; + + expect(() => transform(inputSchema)).toThrow('Unsupported type definition: UnionTypeDefinition'); +}); + +test('generation route invalid parent type', () => { + const inputSchema = ` + type Thing { + int: Int + } + + type Mutation { + makeThing(description: String!): Thing + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0", + systemPrompt: "", + ) + } + `; + + expect(() => transform(inputSchema)).toThrow('@generation directive must be used on Query field.'); +}); + +test('generation route all scalar types', () => { + const queryName = 'makeBox'; + const inputSchema = ` + type Box { + int: Int + float: Float + string: String + id: ID + boolean: Boolean + awsjson: AWSJSON + awsemail: AWSEmail + awsdate: AWSDate + awstime: AWSTime + awsdatetime: AWSDateTime + awstimestamp: AWSTimestamp + awsphone: AWSPhone + awsurl: AWSURL + awsipaddress: AWSIPAddress + } + + type Query { + makeBox(description: String!): Box + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0", + systemPrompt: "", + ) + } + `; + + const out = transform(inputSchema); + const resolverCode = getResolverResource(queryName, out.rootStack.Resources)['Properties']['Code']; + expect(resolverCode).toBeDefined(); + expect(resolverCode).toMatchSnapshot(); + + const resolverFnCode = getResolverFnResource(queryName, out.rootStack.Resources)['Properties']['Code']; + expect(resolverFnCode).toBeDefined(); + expect(resolverFnCode).toMatchSnapshot(); + + const schema = parse(out.schema); + validateModelSchema(schema); +}); + +describe('generation route invalid inference configuration', () => { + const maxTokens = 'inferenceConfiguration: { maxTokens: 0 }'; + const temperature = { + over: 'inferenceConfiguration: { temperature: 1.1 }', + under: 'inferenceConfiguration: { temperature: -0.1 }', + }; + const topP = { + over: 'inferenceConfiguration: { topP: 1.1 }', + under: 'inferenceConfiguration: { topP: -0.1 }', + }; + + const generationRoute = (invalidInferenceConfig: string): string => { + return ` + type Query { + generate(description: String!): String + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0", + systemPrompt: "Make a string based on the description.", + ${invalidInferenceConfig} + ) + } + `; + }; + + test('maxTokens invalid', () => { + expect(() => transform(generationRoute(maxTokens))).toThrow( + '@generation directive maxTokens valid range: Minimum value of 1. Provided: 0', + ); + }); + + test('temperature over', () => { + expect(() => transform(generationRoute(temperature.over))).toThrow( + '@generation directive temperature valid range: Minimum value of 0. Maximum value of 1. Provided: 1.1', + ); + }); + + test('topP over', () => { + expect(() => transform(generationRoute(topP.over))).toThrow( + '@generation directive topP valid range: Minimum value of 0. Maximum value of 1. Provided: 1.1', + ); + }); + + test('temperature under', () => { + expect(() => transform(generationRoute(temperature.under))).toThrow( + '@generation directive temperature valid range: Minimum value of 0. Maximum value of 1. Provided: -0.1', + ); + }); + + test('topP under', () => { + expect(() => transform(generationRoute(topP.under))).toThrow( + '@generation directive topP valid range: Minimum value of 0. Maximum value of 1. Provided: -0.1', + ); + }); +}); + +const getResolverResource = (queryName: string, resources?: Record): Record => { + const resolverName = `Query${queryName}Resolver`; + return resources?.[resolverName]; +}; + +const getResolverFnResource = (queryName: string, resources?: Record): Record => { + const capitalizedQueryName = queryName.charAt(0).toUpperCase() + queryName.slice(1); + const resourcePrefix = `Query${capitalizedQueryName}DataResolverFn`; + if (!resources) { + fail('No resources found.'); + } + const resource = Object.entries(resources).find(([key, _]) => { + return key.startsWith(resourcePrefix); + })?.[1]; + + if (!resource) { + fail(`Resource named with prefix ${resourcePrefix} not found.`); + } + return resource; +}; + +const defaultAuthConfig: AppSyncAuthConfiguration = { + defaultAuthentication: { + authenticationType: 'AWS_IAM', + }, + additionalAuthenticationProviders: [], +}; + +function transform( + inputSchema: string, + dataSourceStrategies?: Record, + authConfig: AppSyncAuthConfiguration = defaultAuthConfig, +): DeploymentResources { + const modelTransformer = new ModelTransformer(); + const authTransformer = new AuthTransformer(); + const indexTransformer = new IndexTransformer(); + const hasOneTransformer = new HasOneTransformer(); + const belongsToTransformer = new BelongsToTransformer(); + const hasManyTransformer = new HasManyTransformer(); + + const transformers = [ + modelTransformer, + new PrimaryKeyTransformer(), + indexTransformer, + hasManyTransformer, + hasOneTransformer, + belongsToTransformer, + new GenerationTransformer(), + authTransformer, + ]; + + const out = testTransform({ + schema: inputSchema, + authConfig, + transformers, + dataSourceStrategies, + }); + + return out; +} diff --git a/packages/amplify-graphql-generation-transformer/src/grapqhl-generation-transformer.ts b/packages/amplify-graphql-generation-transformer/src/grapqhl-generation-transformer.ts new file mode 100644 index 0000000000..63a025349a --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/src/grapqhl-generation-transformer.ts @@ -0,0 +1,221 @@ +import { GenerationDirective } from '@aws-amplify/graphql-directives'; +import { + DirectiveWrapper, + TransformerPluginBase, + generateGetArgumentsInput, + TransformerResolver, +} from '@aws-amplify/graphql-transformer-core'; +import { + MappingTemplateProvider, + TransformerContextProvider, + TransformerSchemaVisitStepContextProvider, +} from '@aws-amplify/graphql-transformer-interfaces'; +import { DirectiveNode, FieldDefinitionNode, InterfaceTypeDefinitionNode, ObjectTypeDefinitionNode } from 'graphql'; +import { HttpResourceIDs, ResolverResourceIDs } from 'graphql-transformer-common'; +import { ToolConfig, createResponseTypeTool } from './utils/tools'; +import * as cdk from 'aws-cdk-lib'; +import { createInvokeBedrockResolverFunction } from './resolvers/invoke-bedrock'; +import * as iam from 'aws-cdk-lib/aws-iam'; +import { Construct } from 'constructs'; +import { validate } from './validation'; +import { toUpper } from 'graphql-transformer-common'; + +export type GenerationDirectiveConfiguration = { + parent: ObjectTypeDefinitionNode; + directive: DirectiveNode; + aiModel: string; + field: FieldDefinitionNode; + systemPrompt: string; + inferenceConfiguration: InferenceConfiguration; +}; + +export type InferenceConfiguration = { + maxTokens?: number; + temperature?: number; + topP?: number; +}; + +export type GenerationConfigurationWithToolConfig = GenerationDirectiveConfiguration & { + toolConfig: ToolConfig; +}; + +export class GenerationTransformer extends TransformerPluginBase { + private directives: GenerationDirectiveConfiguration[] = []; + + constructor() { + super('amplify-generation-transformer', GenerationDirective.definition); + } + + field = ( + parent: ObjectTypeDefinitionNode | InterfaceTypeDefinitionNode, + definition: FieldDefinitionNode, + directive: DirectiveNode, + context: TransformerSchemaVisitStepContextProvider, + ): void => { + const directiveWrapped = new DirectiveWrapper(directive); + const config = directiveWrapped.getArguments( + { + parent, + directive, + field: definition, + inferenceConfiguration: {}, + } as GenerationDirectiveConfiguration, + generateGetArgumentsInput(context.transformParameters), + ); + + validate(config, context as TransformerContextProvider); + this.directives.push(config); + }; + + generateResolvers = (ctx: TransformerContextProvider): void => { + if (this.directives.length === 0) return; + + this.directives.forEach((directive) => { + const { parent, field } = directive; + const fieldName = field.name.value; + const parentName = parent.name.value; + + const directiveWithToolConfig: GenerationConfigurationWithToolConfig = { + ...directive, + toolConfig: createResponseTypeTool(field, ctx), + }; + + const stackName = this.bedrockDataSourceName(fieldName) + 'Stack'; + const stack = this.createStack(ctx, stackName); + + const resolverResourceId = ResolverResourceIDs.ResolverResourceID(parentName, fieldName); + const httpDataSourceId = HttpResourceIDs.HttpDataSourceID(this.bedrockDataSourceName(fieldName)); + const dataSource = this.createBedrockDataSource(ctx, directive, stack.region, stackName, httpDataSourceId); + const invokeBedrockFunction = createInvokeBedrockResolverFunction(directiveWithToolConfig); + + this.createPipelineResolver(ctx, parentName, fieldName, resolverResourceId, invokeBedrockFunction, dataSource); + }); + }; + + /** + * Creates a new CDK stack for the Generation transformer. + * @param {TransformerContextProvider} ctx - The transformer context provider. + * @param {string} stackName - The name of the stack to create. + * @returns {cdk.Stack} The created CDK stack. + */ + private createStack(ctx: TransformerContextProvider, stackName: string): cdk.Stack { + const stack = ctx.stackManager.createStack(stackName); + stack.templateOptions.templateFormatVersion = '2010-09-09'; + stack.templateOptions.description = 'An auto-generated nested stack for the @generation directive.'; + return stack; + } + + /** + * Creates a pipeline resolver for the Generation transformer. + * @param {TransformerContextProvider} ctx - The transformer context provider. + * @param {string} parentName - The name of the parent type. + * @param {string} fieldName - The name of the field. + * @param {string} resolverResourceId - The ID for the resolver resource. + * @param {MappingTemplateProvider} invokeBedrockFunction - The invoke Bedrock function. + */ + private createPipelineResolver( + ctx: TransformerContextProvider, + parentName: string, + fieldName: string, + resolverResourceId: string, + invokeBedrockFunction: { req: MappingTemplateProvider; res: MappingTemplateProvider }, + dataSource: cdk.aws_appsync.HttpDataSource, + ): void { + const conversationPipelineResolver = new TransformerResolver( + parentName, + fieldName, + resolverResourceId, + invokeBedrockFunction.req, + invokeBedrockFunction.res, + ['auth'], + [], + dataSource as any, + { name: 'APPSYNC_JS', runtimeVersion: '1.0.0' }, + ); + + ctx.resolvers.addResolver(parentName, fieldName, conversationPipelineResolver); + } + + /** + * Creates a Bedrock data source for the Generation transformer. + * @param {TransformerContextProvider} ctx - The transformer context provider. + * @param {GenerationDirectiveConfiguration} directive - The directive configuration. + * @param {string} region - The AWS region for the Bedrock service. + * @param {string} stackName - The name of the stack. + * @param {string} httpDataSourceId - The ID for the HTTP data source. + * @returns {MappingTemplateProvider} The created Bedrock data source. + */ + private createBedrockDataSource( + ctx: TransformerContextProvider, + directive: GenerationDirectiveConfiguration, + region: string, + stackName: string, + httpDataSourceId: string, + ): cdk.aws_appsync.HttpDataSource { + const { + field: { + name: { value: fieldName }, + }, + aiModel, + } = directive; + + const bedrockUrl = `https://bedrock-runtime.${region}.amazonaws.com`; + + const dataSourceScope = ctx.stackManager.getScopeFor(httpDataSourceId, stackName); + const dataSource = ctx.api.host.addHttpDataSource( + httpDataSourceId, + bedrockUrl, + { + authorizationConfig: { + signingRegion: region, + signingServiceName: 'bedrock', + }, + }, + dataSourceScope, + ); + + // This follows the existing pattern of generating logical IDs and names for IAM roles. + const roleLogicalId = this.bedrockDataSourceName(fieldName) + 'IAMRole'; + const roleName = ctx.resourceHelper.generateIAMRoleName(roleLogicalId); + const role = this.createBedrockDataSourceRole(dataSourceScope, roleLogicalId, roleName, region, aiModel); + dataSource.ds.serviceRoleArn = role.roleArn; + return dataSource; + } + + /** + * Creates an IAM role for the Bedrock service. + * @param {Construct} dataSourceScope - The construct scope for the IAM role. + * @param {string} fieldName - The name of the field. + * @param {string} roleName - The name of the IAM role. + * @param {string} region - The AWS region for the Bedrock service. + * @param {string} bedrockModelId - The ID for the Bedrock model. + * @returns {iam.Role} The created IAM role. + */ + private createBedrockDataSourceRole( + dataSourceScope: Construct, + roleLogicalId: string, + roleName: string, + region: string, + bedrockModelId: string, + ): cdk.aws_iam.Role { + return new iam.Role(dataSourceScope, roleLogicalId, { + roleName, + assumedBy: new iam.ServicePrincipal('appsync.amazonaws.com'), + inlinePolicies: { + BedrockRuntimeAccess: new iam.PolicyDocument({ + statements: [ + new iam.PolicyStatement({ + effect: iam.Effect.ALLOW, + actions: ['bedrock:InvokeModel'], + resources: [`arn:aws:bedrock:${region}::foundation-model/${bedrockModelId}`], + }), + ], + }), + }, + }); + } + + private bedrockDataSourceName(fieldName: string): string { + return `GenerationBedrockDataSource${toUpper(fieldName)}`; + } +} diff --git a/packages/amplify-graphql-generation-transformer/src/index.ts b/packages/amplify-graphql-generation-transformer/src/index.ts new file mode 100644 index 0000000000..953f5bae1d --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/src/index.ts @@ -0,0 +1 @@ +export { GenerationTransformer } from './grapqhl-generation-transformer'; diff --git a/packages/amplify-graphql-generation-transformer/src/resolvers/invoke-bedrock.ts b/packages/amplify-graphql-generation-transformer/src/resolvers/invoke-bedrock.ts new file mode 100644 index 0000000000..2e6c4b28ca --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/src/resolvers/invoke-bedrock.ts @@ -0,0 +1,103 @@ +import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; +import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; +import { dedent } from 'ts-dedent'; +import { GenerationConfigurationWithToolConfig, InferenceConfiguration } from '../grapqhl-generation-transformer'; + +/** + * Creates the resolver functions for invoking Amazon Bedrock. + * + * @param {GenerationConfigurationWithToolConfig} config - The configuration object containing AI model details, tool config, and inference settings. + * @returns {Object} An object containing request and response resolver functions. + */ + +export const createInvokeBedrockResolverFunction = ( + config: GenerationConfigurationWithToolConfig, +): { req: MappingTemplateProvider; res: MappingTemplateProvider } => { + const req = createInvokeBedrockRequestFunction(config); + const res = createInvokeBedrockResponseFunction(); + return { req, res }; +}; + +/** + * Creates the request function for the Bedrock resolver. + * + * @param {GenerationConfigurationWithToolConfig} config - The configuration object for the resolver. + * @returns {MappingTemplateProvider} A MappingTemplateProvider for the request function. + */ +const createInvokeBedrockRequestFunction = (config: GenerationConfigurationWithToolConfig): MappingTemplateProvider => { + const { aiModel, toolConfig, inferenceConfiguration } = config; + const stringifiedToolConfig = JSON.stringify(toolConfig); + const stringifiedSystemPrompt = JSON.stringify(config.systemPrompt); + // TODO: add stopReason: max_tokens error handling + const inferenceConfig = getInferenceConfigResolverDefinition(inferenceConfiguration); + const requestFunctionString = ` + export function request(ctx) { + const toolConfig = ${stringifiedToolConfig}; + const prompt = ${stringifiedSystemPrompt}; + const args = JSON.stringify(ctx.args); + + return { + resourcePath: '/model/${aiModel}/converse', + method: 'POST', + params: { + headers: { 'Content-Type': 'application/json' }, + body: { + messages: [{ + role: 'user', + content: [{ text: args }], + }], + system: [{ text: prompt }], + toolConfig, + ${inferenceConfig} + } + } + } + } +`; + + return MappingTemplate.inlineTemplateFromString(dedent(requestFunctionString)); +}; + +/** + * Creates the response function for the Bedrock resolver. + * + * @returns {MappingTemplateProvider} A MappingTemplateProvider for the response function. + */ +const createInvokeBedrockResponseFunction = (): MappingTemplateProvider => { + // TODO: add stopReason: max_tokens error handling + const responseFunctionString = ` + export function response(ctx) { + if (ctx.error) { + util.error(ctx.error.message, ctx.error.type); + } + const body = JSON.parse(ctx.result.body); + const { content } = body.output.message; + + if (content.length < 1) { + util.error('No content block in assistant response.', 'error'); + } + + const toolUse = content[0].toolUse; + if (!toolUse) { + util.error('Missing tool use block in assistant response.', 'error'); + } + + const response = toolUse.input.value; + return response; + } +`; + + return MappingTemplate.inlineTemplateFromString(dedent(responseFunctionString)); +}; + +/** + * Generates the inference configuration string for the resolver definition. + * + * @param {InferenceConfiguration | undefined} inferenceConfiguration - The inference configuration object. + * @returns {string} A string representation of the inference configuration for use in the resolver definition. + */ +const getInferenceConfigResolverDefinition = (inferenceConfiguration?: InferenceConfiguration): string => { + return inferenceConfiguration && Object.keys(inferenceConfiguration).length > 0 + ? `inferenceConfig: ${JSON.stringify(inferenceConfiguration)},` + : '// default inference config'; +}; diff --git a/packages/amplify-graphql-generation-transformer/src/utils/graphql-json-schema-type.ts b/packages/amplify-graphql-generation-transformer/src/utils/graphql-json-schema-type.ts new file mode 100644 index 0000000000..b270c89d15 --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/src/utils/graphql-json-schema-type.ts @@ -0,0 +1,219 @@ +import { TransformerContextProvider } from '@aws-amplify/graphql-transformer-interfaces'; +import { + EnumTypeDefinitionNode, + Kind, + ListTypeNode, + NamedTypeNode, + NonNullTypeNode, + ObjectTypeDefinitionNode, + TypeNode, + TypeSystemDefinitionNode, +} from 'graphql'; +import { getBaseType, isScalar } from 'graphql-transformer-common'; +import { GraphQLScalarJSONSchemaDefinition, isDisallowedScalarType, supportedScalarTypes } from './graphql-scalar-json-schema-definitions'; + +export type JSONLike = string | number | boolean | null | { [key: string]: JSONLike } | JSONLike[]; + +export type JSONSchema = { + type: string; + properties?: Record; + required?: string[]; + items?: JSONSchema; + enum?: (string | number | boolean | null)[]; + minimum?: number; + maximum?: number; + minLength?: number; + maxLength?: number; + pattern?: string; + format?: string; + description?: string; + default?: JSONLike; + additionalProperties?: boolean | JSONSchema; +}; + +/** + * Generates a JSON Schema from a GraphQL TypeNode. + * @param {TypeNode} typeNode - The GraphQL TypeNode to convert. + * @param {TransformerContextProvider} ctx - The transformer context. + * @param {JSONSchema} [schema={ type: '' }] - The initial schema object. + * @returns {JSONSchema} The generated JSON Schema. + */ +export const generateJSONSchemaFromTypeNode = ( + typeNode: TypeNode, + ctx: TransformerContextProvider, + schema: JSONSchema = { type: '' }, +): JSONSchema => { + switch (typeNode.kind) { + case Kind.NAMED_TYPE: + return handleNamedType(typeNode, ctx, schema); + case Kind.NON_NULL_TYPE: + return handleNonNullType(typeNode, ctx, schema); + case Kind.LIST_TYPE: + return handleListType(typeNode, ctx); + } +}; + +/** + * Handles the conversion of a NamedTypeNode to JSON Schema. + * @param {NamedTypeNode} typeNode - The NamedTypeNode to process. + * @param {TransformerContextProvider} ctx - The transformer context. + * @param {JSONSchema} schema - The current schema object. + * @returns {JSONSchema} The updated JSON Schema. + */ +const handleNamedType = (typeNode: NamedTypeNode, ctx: TransformerContextProvider, schema: JSONSchema): JSONSchema => { + const namedTypeSchema = processNamedType(typeNode); + Object.assign(schema, namedTypeSchema); + + if (isScalar(typeNode)) { + return schema; + } + + const baseTypeName = getBaseType(typeNode); + const typeDef = ctx.output.getType(baseTypeName); + if (!typeDef) { + throw new Error(`Type ${baseTypeName} not found`); + } + + schema.properties = generateJSONSchemaForDef(typeDef, ctx, schema); + return schema; +}; + +/** + * Handles the conversion of a NonNullTypeNode to JSON Schema. + * @param {NonNullTypeNode} typeNode - The NonNullTypeNode to process. + * @param {TransformerContextProvider} ctx - The transformer context. + * @param {JSONSchema} schema - The current schema object. + * @returns {JSONSchema} The updated JSON Schema. + * @throws {Error} If the field type is disallowed for required fields without a default value. + */ +const handleNonNullType = (typeNode: NonNullTypeNode, ctx: TransformerContextProvider, schema: JSONSchema): JSONSchema => { + const baseType = getBaseType(typeNode); + if (isDisallowedScalarType(baseType)) { + throw new Error(` + Disallowed required field type ${baseType} without a default value. + Use one of the supported scalar types for generation routes: [${supportedScalarTypes.join(', ')}] + `); + } + return generateJSONSchemaFromTypeNode(typeNode.type, ctx, schema); +}; + +/** + * Handles the conversion of a ListTypeNode to JSON Schema. + * @param {ListTypeNode} typeNode - The ListTypeNode to process. + * @param {TransformerContextProvider} ctx - The transformer context. + * @returns {JSONSchema} The JSON Schema representing the list type. + */ +const handleListType = (typeNode: ListTypeNode, ctx: TransformerContextProvider): JSONSchema => { + return { + type: 'array', + items: generateJSONSchemaFromTypeNode(typeNode.type, ctx), + }; +}; + +/** + * Generates JSON Schema for a GraphQL type definition. + * @param {TypeSystemDefinitionNode} def - The GraphQL type definition node. + * @param {TransformerContextProvider} ctx - The transformer context. + * @param {JSONSchema} schema - The current schema object. + * @returns {Record} A record of field names to their JSON Schema representations. + * @throws {Error} If an unsupported type definition is encountered. + */ +const generateJSONSchemaForDef = ( + def: TypeSystemDefinitionNode, + ctx: TransformerContextProvider, + schema: JSONSchema, +): Record => { + switch (def.kind) { + case 'ObjectTypeDefinition': + return handleObjectTypeDefinition(def, ctx, schema); + case 'EnumTypeDefinition': + return handleEnumTypeDefinition(def); + default: + throw new Error(`Unsupported type definition: ${def.kind}`); + } +}; + +/** + * Handles the conversion of an ObjectTypeDefinition to JSON Schema. + * @param {ObjectTypeDefinitionNode} def - The ObjectTypeDefinition node to process. + * @param {TransformerContextProvider} ctx - The transformer context. + * @param {JSONSchema} schema - The current schema object. + * @returns {Record} A record of field names to their JSON Schema representations. + * @throws {Error} If the object type has no fields. + */ +const handleObjectTypeDefinition = ( + def: ObjectTypeDefinitionNode, + ctx: TransformerContextProvider, + schema: JSONSchema, +): Record => { + const properties = (def.fields ?? []).reduce((acc: Record, field) => { + acc[field.name.value] = generateJSONSchemaFromTypeNode(field.type, ctx, { type: '' }); + + // Add required fields to the schema + if (field.type.kind === Kind.NON_NULL_TYPE) { + schema.required = [...(schema.required || []), field.name.value]; + } + + return acc; + }, {}); + + return properties; +}; + +/** + * Handles the conversion of an EnumTypeDefinition to JSON Schema. + * @param {EnumTypeDefinitionNode} def - The EnumTypeDefinition node to process. + * @returns {Record} A record containing the enum name and its JSON Schema representation. + */ +const handleEnumTypeDefinition = (def: EnumTypeDefinitionNode): Record => { + return { + [def.name.value]: { + type: 'string', + enum: def.values?.map((value) => value.name.value), + }, + }; +}; + +/** + * Processes a NamedTypeNode and returns the corresponding JSON Schema. + * @param {NamedTypeNode} namedType - The NamedTypeNode to process. + * @returns {JSONSchema} The JSON Schema representation of the named type. + */ +function processNamedType(namedType: NamedTypeNode): JSONSchema { + switch (namedType.name.value) { + case 'Int': + return GraphQLScalarJSONSchemaDefinition.Int; + case 'Float': + return GraphQLScalarJSONSchemaDefinition.Float; + case 'String': + return GraphQLScalarJSONSchemaDefinition.String; + case 'ID': + return GraphQLScalarJSONSchemaDefinition.ID; + case 'Boolean': + return GraphQLScalarJSONSchemaDefinition.Boolean; + case 'AWSJSON': + return GraphQLScalarJSONSchemaDefinition.AWSJSON; + case 'AWSEmail': + return GraphQLScalarJSONSchemaDefinition.AWSEmail; + case 'AWSDate': + return GraphQLScalarJSONSchemaDefinition.AWSDate; + case 'AWSTime': + return GraphQLScalarJSONSchemaDefinition.AWSTime; + case 'AWSDateTime': + return GraphQLScalarJSONSchemaDefinition.AWSDateTime; + case 'AWSTimestamp': + return GraphQLScalarJSONSchemaDefinition.AWSTimestamp; + case 'AWSPhone': + return GraphQLScalarJSONSchemaDefinition.AWSPhone; + case 'AWSURL': + return GraphQLScalarJSONSchemaDefinition.AWSURL; + case 'AWSIPAddress': + return GraphQLScalarJSONSchemaDefinition.AWSIPAddress; + default: + return { + type: 'object', + properties: {}, + required: [], + }; + } +} diff --git a/packages/amplify-graphql-generation-transformer/src/utils/graphql-scalar-json-schema-definitions.ts b/packages/amplify-graphql-generation-transformer/src/utils/graphql-scalar-json-schema-definitions.ts new file mode 100644 index 0000000000..0cafb0b0ca --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/src/utils/graphql-scalar-json-schema-definitions.ts @@ -0,0 +1,154 @@ +import { JSONSchema } from './graphql-json-schema-type'; + +/** + * JSON Schema definitions for GraphQL scalar types and AWS AppSync custom scalar types. + * These definitions are used to create valid JSON schema for tool definitions in the + * context of AI model interactions for generation routes, ensuring that generated responses + * conform to the expected GraphQL types. + * + * Each constant represents a JSON Schema object that describes the structure and + * constraints of a specific GraphQL scalar type. These scalar types are utilized when + * generating tool configurations for AI models, allowing for accurate type validation + * and response generation. + */ + +/** JSON Schema definition for GraphQL Boolean type */ +const Boolean: JSONSchema = { + type: 'boolean', + description: 'A boolean value.', +}; + +/** JSON Schema definition for GraphQL Int type */ +const Int: JSONSchema = { + type: 'number', + description: 'A signed 32-bit integer value.', +}; + +/** JSON Schema definition for GraphQL Float type */ +const Float: JSONSchema = { + type: 'number', + description: 'An IEEE 754 floating point value.', +}; + +/** JSON Schema definition for GraphQL String type */ +const String: JSONSchema = { + type: 'string', + description: 'A UTF-8 character sequence.', +}; + +/** JSON Schema definition for GraphQL ID type */ +const ID: JSONSchema = { + type: 'string', + description: "A unique identifier for an object. This scalar is serialized like a String but isn't meant to be human-readable.", +}; + +/** JSON Schema definition for AWS AppSync AWSJSON type */ +const AWSJSON: JSONSchema = { + type: 'string', + description: + 'A JSON string. Any valid JSON construct is automatically parsed and loaded in the resolver code as maps, lists, or scalar values rather than as the literal input strings. Unquoted strings or otherwise invalid JSON result in a GraphQL validation error.', +}; + +/** JSON Schema definition for AWS AppSync AWSEmail type */ +const AWSEmail: JSONSchema = { + type: 'string', + description: 'An email address in the format local-part@domain-part as defined by RFC 822.', + pattern: '^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$', +}; + +/** JSON Schema definition for AWS AppSync AWSDate type */ +const AWSDate: JSONSchema = { + type: 'string', + description: 'An extended ISO 8601 date string in the format YYYY-MM-DD.', + pattern: '^\\d{4}-d{2}-d{2}$', +}; + +/** JSON Schema definition for AWS AppSync AWSTime type */ +const AWSTime: JSONSchema = { + type: 'string', + description: 'An extended ISO 8601 time string in the format hh:mm:ss.sss.', + pattern: '^\\d{2}:\\d{2}:\\d{2}\\.\\d{3}$', +}; + +/** JSON Schema definition for AWS AppSync AWSDateTime type */ +const AWSDateTime: JSONSchema = { + type: 'string', + description: 'An extended ISO 8601 date and time string in the format YYYY-MM-DDThh:mm:ss.sssZ.', + pattern: '^\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2}\\.\\d{3}Z$', +}; + +/** JSON Schema definition for AWS AppSync AWSTimestamp type */ +const AWSTimestamp: JSONSchema = { + type: 'string', + description: 'An integer value representing the number of seconds before or after 1970-01-01-T00:00Z.', + pattern: '^\\d+$', +}; + +/** JSON Schema definition for AWS AppSync AWSPhone type */ +const AWSPhone: JSONSchema = { + type: 'string', + description: + 'A phone number. This value is stored as a string. Phone numbers can contain either spaces or hyphens to separate digit groups. Phone numbers without a country code are assumed to be US/North American numbers adhering to the North American Numbering Plan (NANP).', + pattern: '^\\d{3}-d{3}-d{4}$', +}; + +/** JSON Schema definition for AWS AppSync AWSURL type */ +const AWSURL: JSONSchema = { + type: 'string', + description: + "A URL as defined by RFC 1738. For example, https://www.amazon.com/dp/B000NZW3KC/ or mailto:example@example.com. URLs must contain a schema (http, mailto) and can't contain two forward slashes (//) in the path part.", + pattern: '^(https?|mailto)://[^s/$.?#].[^s]*$', +}; + +/** JSON Schema definition for AWS AppSync AWSIPAddress type */ +const AWSIPAddress: JSONSchema = { + type: 'string', + description: + 'A valid IPv4 or IPv6 address. IPv4 addresses are expected in quad-dotted notation (123.12.34.56). IPv6 addresses are expected in non-bracketed, colon-separated format (1a2b:3c4b::1234:4567). You can include an optional CIDR suffix (123.45.67.89/16) to indicate subnet mask.', +}; + +/** + * List of scalar types that are not allowed for required fields. + * + * @remarks + * LLMs are not great at following regex pattern requirements, they'll sometimes return "\" for required fields. + * This leads to AppSync type validation failures. This is particularly problematic for required `@model` fields like `createdAt` and `updatedAt`. + * + * @todo + * Explore ways to lift this constraint. Current thoughts: + * - Improve prompt engineering for better handling of these types + * - Refine regex patterns in JSON Schema tool definitions + * - Implement special case handling for models (e.g., omitting createdAt, updatedAt, and id in tool definition, + * and populating them in the resolver) + */ +const disallowedScalarTypes = ['AWSEmail', 'AWSDate', 'AWSTime', 'AWSDateTime', 'AWSTimestamp', 'AWSPhone', 'AWSURL', 'AWSIPAddress']; + +/** List of supported scalar types */ +export const supportedScalarTypes = ['Boolean', 'Int', 'Float', 'String', 'ID', 'AWSJSON']; + +/** + * Checks if a given type is a disallowed scalar type + * @param type - The type to check + * @returns True if the type is disallowed, false otherwise + */ +export const isDisallowedScalarType = (type: string): boolean => { + return disallowedScalarTypes.includes(type); +}; + +/** Object containing JSON Schema definitions for GraphQL scalar types */ +export const GraphQLScalarJSONSchemaDefinition = { + Boolean, + Int, + Float, + String, + AWSDateTime, + ID, + AWSJSON, + AWSEmail, + AWSDate, + AWSTime, + AWSTimestamp, + AWSPhone, + AWSURL, + AWSIPAddress, +}; diff --git a/packages/amplify-graphql-generation-transformer/src/utils/tools.ts b/packages/amplify-graphql-generation-transformer/src/utils/tools.ts new file mode 100644 index 0000000000..ca0591982c --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/src/utils/tools.ts @@ -0,0 +1,73 @@ +import { TransformerContextProvider } from '@aws-amplify/graphql-transformer-interfaces'; +import { FieldDefinitionNode } from 'graphql'; +import { JSONSchema, generateJSONSchemaFromTypeNode } from './graphql-json-schema-type'; + +export type Tool = { + toolSpec: ToolSpec; +}; + +export type Tools = { + tools: Tool[]; +}; + +export type ToolConfig = { + tools: Tool[]; + toolChoice: { + tool: { + name: string; + }; + }; +}; + +type ToolSpec = { + name: string; + description: string; + inputSchema: { + json: JSONSchema; + }; +}; + +/** + * Creates a tool configuration for generating a response type based on a GraphQL field definition. + * + * This function generates a JSON schema from the field's type and wraps it in a tool specification. + * The tool can be used by AI models to generate responses that conform to the field's type structure. + * + * @param {FieldDefinitionNode} field - The GraphQL field definition node for which to create the response type tool. + * @param {TransformerContextProvider} ctx - The transformer context provider, which supplies necessary context for schema generation. + * @returns {ToolConfig} A tool configuration object containing: + * - tools: An array with a single tool specification for the response type. + * - toolChoice: An object specifying the name of the tool to be used. + * + * The returned tool configuration can be used with AI models that support tool-based interactions, + * ensuring that generated responses match the expected structure of the GraphQL field. + */ +export const createResponseTypeTool = (field: FieldDefinitionNode, ctx: TransformerContextProvider): ToolConfig => { + const { type } = field; + const schema = generateJSONSchemaFromTypeNode(type, ctx); + + // We box the schema to support scalar return types. + // Bedrock only supports object types in tool definitions. + const boxedSchema = { + type: 'object', + properties: { + value: schema, + }, + required: ['value'], + }; + const tools: Tool[] = [ + { + toolSpec: { + name: 'responseType', + description: 'Generate a response type for the given field', + inputSchema: { + json: boxedSchema, + }, + }, + }, + ]; + const toolChoice = { tool: { name: 'responseType' } }; + const toolConfig = { tools, toolChoice }; + + return toolConfig; +}; diff --git a/packages/amplify-graphql-generation-transformer/src/validation.ts b/packages/amplify-graphql-generation-transformer/src/validation.ts new file mode 100644 index 0000000000..bb8fa9a1eb --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/src/validation.ts @@ -0,0 +1,60 @@ +import { InvalidDirectiveError } from '@aws-amplify/graphql-transformer-core'; +import { TransformerContextProvider } from '@aws-amplify/graphql-transformer-interfaces'; +import { GenerationDirectiveConfiguration } from './grapqhl-generation-transformer'; +import { isQueryType } from '@aws-amplify/graphql-transformer-core/src/utils'; + +/** + * Validates the configuration for the `@generation` directive. + * + * This function performs validation checks on the provided configuration + * to ensure it meets the requirements for the `@generation` directive. + * + * @param {GenerationDirectiveConfiguration} config - The configuration object for the `@generation` directive. + * @param {TransformerContextProvider} ctx - The transformer context provider. + * @throws {InvalidDirectiveError} If the configuration is invalid. + */ + +export const validate = (config: GenerationDirectiveConfiguration, ctx: TransformerContextProvider): void => { + validateFieldType(config); + validateInferenceConfig(config); +}; + +/** + * Validates the field type for the `@generation` directive. + * + * This function checks if the parent of the field with the `@generation` directive + * is of type 'Query'. If not, it throws an InvalidDirectiveError. + * + * @param {GenerationDirectiveConfiguration} config - The configuration object for the `@generation` directive. + * @throws {InvalidDirectiveError} If the parent is not of type 'Query'. + */ +const validateFieldType = (config: GenerationDirectiveConfiguration): void => { + const { parent } = config; + if (!isQueryType(parent.name.value)) { + throw new InvalidDirectiveError('@generation directive must be used on Query field.'); + } +}; + +/** + * Validates the inference configuration for the `@generation` directive according to the Bedrock API docs. + * {@link https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html} + * @param config The generation directive configuration to validate. + */ +const validateInferenceConfig = (config: GenerationDirectiveConfiguration): void => { + const { maxTokens, temperature, topP } = config.inferenceConfiguration; + + // dealing with possible 0 values, so we check for undefined. + if (maxTokens !== undefined && maxTokens < 1) { + throw new InvalidDirectiveError(`@generation directive maxTokens valid range: Minimum value of 1. Provided: ${maxTokens}`); + } + + if (temperature !== undefined && (temperature < 0 || temperature > 1)) { + throw new InvalidDirectiveError( + `@generation directive temperature valid range: Minimum value of 0. Maximum value of 1. Provided: ${temperature}`, + ); + } + + if (topP !== undefined && (topP < 0 || topP > 1)) { + throw new InvalidDirectiveError(`@generation directive topP valid range: Minimum value of 0. Maximum value of 1. Provided: ${topP}`); + } +}; diff --git a/packages/amplify-graphql-generation-transformer/tsconfig.json b/packages/amplify-graphql-generation-transformer/tsconfig.json new file mode 100644 index 0000000000..bf970b21a9 --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/tsconfig.json @@ -0,0 +1,14 @@ +{ + "extends": "../../tsconfig.base.json", + "compilerOptions": { + "rootDir": "src", + "outDir": "lib" + }, + "references": [ + { "path": "../amplify-graphql-directives" }, + { "path": "../amplify-graphql-transformer-core" }, + { "path": "../amplify-graphql-transformer-interfaces" }, + { "path": "../graphql-mapping-template" }, + { "path": "../graphql-transformer-common" } + ] +} diff --git a/packages/amplify-graphql-transformer-interfaces/API.md b/packages/amplify-graphql-transformer-interfaces/API.md index 17a9448ce6..fdf19717f0 100644 --- a/packages/amplify-graphql-transformer-interfaces/API.md +++ b/packages/amplify-graphql-transformer-interfaces/API.md @@ -25,6 +25,7 @@ import { FieldNode } from 'graphql'; import { Grant } from 'aws-cdk-lib/aws-iam'; import { GraphqlApiBase } from 'aws-cdk-lib/aws-appsync'; import { HttpDataSource } from 'aws-cdk-lib/aws-appsync'; +import { HttpDataSourceOptions } from 'aws-cdk-lib/aws-appsync'; import { IamResource } from 'aws-cdk-lib/aws-appsync'; import { IAsset } from 'aws-cdk-lib'; import { IConstruct } from 'constructs'; @@ -872,7 +873,7 @@ export interface TransformHostProvider { // (undocumented) addDynamoDbDataSource(name: string, table: ITable, options?: DynamoDbDataSourceOptions, scope?: Construct): DynamoDbDataSource; // (undocumented) - addHttpDataSource(name: string, endpoint: string, options?: DataSourceOptions, scope?: Construct): HttpDataSource; + addHttpDataSource(name: string, endpoint: string, options?: HttpDataSourceOptions, scope?: Construct): HttpDataSource; // (undocumented) addLambdaDataSource(name: string, lambdaFunction: IFunction, options?: DataSourceOptions, scope?: Construct): LambdaDataSource; // (undocumented) diff --git a/packages/amplify-graphql-transformer-interfaces/src/transform-host-provider.ts b/packages/amplify-graphql-transformer-interfaces/src/transform-host-provider.ts index f12a41b45c..7baec664a0 100644 --- a/packages/amplify-graphql-transformer-interfaces/src/transform-host-provider.ts +++ b/packages/amplify-graphql-transformer-interfaces/src/transform-host-provider.ts @@ -8,6 +8,7 @@ import { NoneDataSource, CfnResolver, CfnFunctionConfiguration, + HttpDataSourceOptions, } from 'aws-cdk-lib/aws-appsync'; import { ITable } from 'aws-cdk-lib/aws-dynamodb'; import { IFunction, ILayerVersion, Runtime } from 'aws-cdk-lib/aws-lambda'; @@ -31,7 +32,7 @@ export interface DynamoDbDataSourceOptions extends DataSourceOptions { export interface TransformHostProvider { setAPI(api: GraphqlApiBase): void; - addHttpDataSource(name: string, endpoint: string, options?: DataSourceOptions, scope?: Construct): HttpDataSource; + addHttpDataSource(name: string, endpoint: string, options?: HttpDataSourceOptions, scope?: Construct): HttpDataSource; addDynamoDbDataSource(name: string, table: ITable, options?: DynamoDbDataSourceOptions, scope?: Construct): DynamoDbDataSource; addNoneDataSource(name: string, options?: DataSourceOptions, scope?: Construct): NoneDataSource; addLambdaDataSource(name: string, lambdaFunction: IFunction, options?: DataSourceOptions, scope?: Construct): LambdaDataSource; diff --git a/packages/amplify-graphql-transformer/package.json b/packages/amplify-graphql-transformer/package.json index 4f696eb306..b2fd81c080 100644 --- a/packages/amplify-graphql-transformer/package.json +++ b/packages/amplify-graphql-transformer/package.json @@ -32,6 +32,7 @@ "@aws-amplify/graphql-auth-transformer": "4.0.0", "@aws-amplify/graphql-default-value-transformer": "3.0.0", "@aws-amplify/graphql-function-transformer": "3.0.0", + "@aws-amplify/graphql-generation-transformer": "0.1.0", "@aws-amplify/graphql-http-transformer": "3.0.0", "@aws-amplify/graphql-index-transformer": "3.0.0", "@aws-amplify/graphql-maps-to-transformer": "4.0.0", diff --git a/packages/amplify-graphql-transformer/src/__tests__/graphql-transformer.test.ts b/packages/amplify-graphql-transformer/src/__tests__/graphql-transformer.test.ts index 8c3a056921..e124981d4e 100644 --- a/packages/amplify-graphql-transformer/src/__tests__/graphql-transformer.test.ts +++ b/packages/amplify-graphql-transformer/src/__tests__/graphql-transformer.test.ts @@ -20,7 +20,7 @@ import { defaultPrintTransformerLog, } from '../graphql-transformer'; -const numOfTransformers = 16; +const numOfTransformers = 17; describe('constructTransformerChain', () => { it(`returns ${numOfTransformers} transformers when no custom transformers are provided`, () => { expect(constructTransformerChain().length).toEqual(numOfTransformers); diff --git a/packages/amplify-graphql-transformer/src/graphql-transformer.ts b/packages/amplify-graphql-transformer/src/graphql-transformer.ts index 563b91efdd..424f12b10f 100644 --- a/packages/amplify-graphql-transformer/src/graphql-transformer.ts +++ b/packages/amplify-graphql-transformer/src/graphql-transformer.ts @@ -33,6 +33,7 @@ import type { import { GraphQLTransform, ResolverConfig, UserDefinedSlot } from '@aws-amplify/graphql-transformer-core'; import { Construct } from 'constructs'; import { IFunction } from 'aws-cdk-lib/aws-lambda'; +import { GenerationTransformer } from '@aws-amplify/graphql-generation-transformer'; /** * Arguments passed into a TransformerFactory @@ -74,6 +75,7 @@ export const constructTransformerChain = (options?: TransformerFactoryArgs): Tra hasOneTransformer, new ManyToManyTransformer(modelTransformer, indexTransformer, hasOneTransformer, authTransformer), new BelongsToTransformer(), + new GenerationTransformer(), new DefaultValueTransformer(), authTransformer, new MapsToTransformer(), diff --git a/scripts/split-e2e-tests.ts b/scripts/split-e2e-tests.ts index 3de2f1348b..9ec76f9e9e 100644 --- a/scripts/split-e2e-tests.ts +++ b/scripts/split-e2e-tests.ts @@ -75,6 +75,7 @@ type CandidateJob = { const FORCE_REGION_MAP = { interactions: 'us-west-2', containers: 'us-east-1', + generation: 'us-west-2', }; // some tests require additional time, the parent account can handle longer tests (up to 90 minutes) @@ -82,6 +83,7 @@ const USE_PARENT_ACCOUNT = [ 'src/__tests__/graphql-v2/searchable-datastore', 'src/__tests__/schema-searchable', 'src/__tests__/FunctionTransformerTestsV2.e2e.test.ts', + 'src/__tests__/generations/generation.test.ts', ]; const TEST_TIMINGS_PATH = join(REPO_ROOT, 'scripts', 'test-timings.data.json'); const CODEBUILD_CONFIG_BASE_PATH = join(REPO_ROOT, 'codebuild_specs', 'e2e_workflow_base.yml'); @@ -145,6 +147,8 @@ const RUN_SOLO: (string | RegExp)[] = [ /src\/__tests__\/owner-auth\/.*\.test\.ts/, /src\/__tests__\/relationships\/.*\.test\.ts/, /src\/__tests__\/restricted-field-auth\/.*\.test\.ts/, + // Generation tests + 'src/__tests__/generations/generation.test.ts', ]; const RUN_IN_ALL_REGIONS = [