Skip to content

Commit

Permalink
some cleanup and inline docs
Browse files Browse the repository at this point in the history
  • Loading branch information
atierian committed Aug 28, 2024
1 parent 9344b57 commit 69171cb
Show file tree
Hide file tree
Showing 3 changed files with 314 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ import {
generateGetArgumentsInput,
TransformerResolver,
} from '@aws-amplify/graphql-transformer-core';
import { TransformerContextProvider, TransformerSchemaVisitStepContextProvider } from '@aws-amplify/graphql-transformer-interfaces';
import {
MappingTemplateProvider,
TransformerContextProvider,
TransformerSchemaVisitStepContextProvider,
} from '@aws-amplify/graphql-transformer-interfaces';
import { DirectiveNode, FieldDefinitionNode, InterfaceTypeDefinitionNode, ObjectTypeDefinitionNode } from 'graphql';
import { HttpResourceIDs, ResolverResourceIDs } from 'graphql-transformer-common';
import { ToolConfig, createResponseTypeTool } from './utils/tools';
Expand Down Expand Up @@ -65,47 +69,151 @@ export class GenerationTransformer extends TransformerPluginBase {
};

generateResolvers = (ctx: TransformerContextProvider): void => {
// If there are no directives, bail out to prevent creating an empty stack
if (this.directives.length === 0) {
return;
}
if (this.directives.length === 0) return;

for (const directive of this.directives) {
this.directives.forEach((directive) => {
const { parent, field } = directive;
const fieldName = field.name.value;
const parentName = parent.name.value;
// We're doing this here (as opposed to in the field method) to access generated queries
// and input definitions for @model queries.
directive.toolConfig = createResponseTypeTool(field, ctx);

const capitalizedFieldName = fieldName.charAt(0).toUpperCase() + fieldName.slice(1);
const stackName = `Generation${capitalizedFieldName}BedrockDataSourceStack`;

const stack: cdk.Stack = ctx.stackManager.createStack(stackName);
stack.templateOptions.templateFormatVersion = '2010-09-09';
stack.templateOptions.description = 'An auto-generated nested stack for the @generation directive.';
directive.toolConfig = createResponseTypeTool(field, ctx);

const stackName = `Generation${this.capitalizeFirstLetter(fieldName)}BedrockDataSourceStack`;
const stack = this.createStack(ctx, stackName);

const resolverResourceId = ResolverResourceIDs.ResolverResourceID(parentName, fieldName);
const httpDataSourceId = HttpResourceIDs.HttpDataSourceID(`GenerationBedrockDataSource-${fieldName}`);
const dataSource = createBedrockDataSource(ctx, directive, stack.region, stackName, httpDataSourceId);
const dataSource = this.createBedrockDataSource(ctx, directive, stack.region, stackName, httpDataSourceId);
const invokeBedrockFunction = invokeBedrockResolver(directive);
// pipeline resolver
const conversationPipelineResolver = new TransformerResolver(
parentName,
fieldName,
resolverResourceId,
invokeBedrockFunction.req,
invokeBedrockFunction.res,
['auth'],
[],
dataSource as any,
{ name: 'APPSYNC_JS', runtimeVersion: '1.0.0' },
);

ctx.resolvers.addResolver(parentName, fieldName, conversationPipelineResolver);
}

this.createPipelineResolver(ctx, parentName, fieldName, resolverResourceId, invokeBedrockFunction, dataSource);
});
};

private capitalizeFirstLetter(str: string): string {
return str.charAt(0).toUpperCase() + str.slice(1);
}

/**
* Creates a new CDK stack for the Generation transformer.
* @param {TransformerContextProvider} ctx - The transformer context provider.
* @param {string} stackName - The name of the stack to create.
* @returns {cdk.Stack} The created CDK stack.
*/
private createStack(ctx: TransformerContextProvider, stackName: string): cdk.Stack {
const stack = ctx.stackManager.createStack(stackName);
stack.templateOptions.templateFormatVersion = '2010-09-09';
stack.templateOptions.description = 'An auto-generated nested stack for the @generation directive.';
return stack;
}

/**
* Creates a pipeline resolver for the Generation transformer.
* @param {TransformerContextProvider} ctx - The transformer context provider.
* @param {string} parentName - The name of the parent resolver.
* @param {string} fieldName - The name of the field.
* @param {string} resolverResourceId - The ID for the resolver resource.
* @param {MappingTemplateProvider} invokeBedrockFunction - The invoke Bedrock function.
*/
private createPipelineResolver(
ctx: TransformerContextProvider,
parentName: string,
fieldName: string,
resolverResourceId: string,
invokeBedrockFunction: { req: MappingTemplateProvider; res: MappingTemplateProvider },
dataSource: cdk.aws_appsync.HttpDataSource,
): void {
const conversationPipelineResolver = new TransformerResolver(
parentName,
fieldName,
resolverResourceId,
invokeBedrockFunction.req,
invokeBedrockFunction.res,
['auth'],
[],
dataSource as any,
{ name: 'APPSYNC_JS', runtimeVersion: '1.0.0' },
);

ctx.resolvers.addResolver(parentName, fieldName, conversationPipelineResolver);
}

/**
* Creates a Bedrock data source for the Generation transformer.
* @param {TransformerContextProvider} ctx - The transformer context provider.
* @param {GenerationDirectiveConfiguration} directive - The directive configuration.
* @param {string} region - The AWS region for the Bedrock service.
* @param {string} stackName - The name of the stack.
* @param {string} httpDataSourceId - The ID for the HTTP data source.
* @returns {MappingTemplateProvider} The created Bedrock data source.
*/
private createBedrockDataSource(
ctx: TransformerContextProvider,
directive: GenerationDirectiveConfiguration,
region: string,
stackName: string,
httpDataSourceId: string,
): cdk.aws_appsync.HttpDataSource {
const {
field: {
name: { value: fieldName },
},
aiModel,
} = directive;

const bedrockUrl = `https://bedrock-runtime.${region}.amazonaws.com`;

const dataSourceScope = ctx.stackManager.getScopeFor(httpDataSourceId, stackName);
const dataSource = ctx.api.host.addHttpDataSource(
httpDataSourceId,
bedrockUrl,
{
authorizationConfig: {
signingRegion: region,
signingServiceName: 'bedrock',
},
},
dataSourceScope,
);

const roleName = ctx.resourceHelper.generateIAMRoleName(`GenerationBedrockDataSourceRole${fieldName}`);
const role = this.createBedrockDataSourceRole(dataSourceScope, fieldName, roleName, region, aiModel);
dataSource.ds.serviceRoleArn = role.roleArn;
return dataSource;
}

/**
* Creates an IAM role for the Bedrock service.
* @param {Construct} dataSourceScope - The construct scope for the IAM role.
* @param {string} fieldName - The name of the field.
* @param {string} roleName - The name of the IAM role.
* @param {string} region - The AWS region for the Bedrock service.
* @param {string} bedrockModelId - The ID for the Bedrock model.
* @returns {iam.Role} The created IAM role.
*/
private createBedrockDataSourceRole(
dataSourceScope: Construct,
fieldName: string,
roleName: string,
region: string,
bedrockModelId: string,
): cdk.aws_iam.Role {
return new iam.Role(dataSourceScope, `GenerationBedrockDataSourceRole${fieldName}`, {
roleName,
assumedBy: new iam.ServicePrincipal('appsync.amazonaws.com'),
inlinePolicies: {
BedrockRuntimeAccess: new iam.PolicyDocument({
statements: [
new iam.PolicyStatement({
effect: iam.Effect.ALLOW,
actions: ['bedrock:InvokeModel'],
resources: [`arn:aws:bedrock:${region}::foundation-model/${bedrockModelId}`],
}),
],
}),
},
});
}
}

const validate = (config: GenerationDirectiveConfiguration, ctx: TransformerContextProvider): void => {
Expand Down Expand Up @@ -139,62 +247,3 @@ const validateInferenceConfig = (config: GenerationDirectiveConfiguration): void
throw new InvalidDirectiveError(`@generation directive topP valid range: Minimum value of 0. Maximum value of 1. Provided: ${topP}`);
}
};

const createBedrockDataSource = (
ctx: TransformerContextProvider,
directive: GenerationDirectiveConfiguration,
region: string,
stackName: string,
httpDataSourceId: string,
): cdk.aws_appsync.HttpDataSource => {
const {
field: {
name: { value: fieldName },
},
aiModel,
} = directive;

const bedrockUrl = `https://bedrock-runtime.${region}.amazonaws.com`;

const dataSourceScope = ctx.stackManager.getScopeFor(httpDataSourceId, stackName);
const dataSource = ctx.api.host.addHttpDataSource(
httpDataSourceId,
bedrockUrl,
{
authorizationConfig: {
signingRegion: region,
signingServiceName: 'bedrock',
},
},
dataSourceScope,
);

const roleName = ctx.resourceHelper.generateIAMRoleName(`GenerationBedrockDataSourceRole${fieldName}`);
const role = createBedrockDataSourceRole(dataSourceScope, fieldName, roleName, region, aiModel);
dataSource.ds.serviceRoleArn = role.roleArn;
return dataSource;
};

const createBedrockDataSourceRole = (
dataSourceScope: Construct,
fieldName: string,
roleName: string,
region: string,
bedrockModelId: string,
): cdk.aws_iam.Role => {
return new iam.Role(dataSourceScope, `GenerationBedrockDataSourceRole${fieldName}`, {
roleName,
assumedBy: new iam.ServicePrincipal('appsync.amazonaws.com'),
inlinePolicies: {
BedrockRuntimeAccess: new iam.PolicyDocument({
statements: [
new iam.PolicyStatement({
effect: iam.Effect.ALLOW,
actions: ['bedrock:InvokeModel'],
resources: [`arn:aws:bedrock:${region}::foundation-model/${bedrockModelId}`],
}),
],
}),
},
});
};
Loading

0 comments on commit 69171cb

Please sign in to comment.