Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(api-graphql): server side iam auth mode is ineffective #12992

Merged
merged 2 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
317 changes: 194 additions & 123 deletions packages/api-graphql/__tests__/GraphQLAPI.test.ts

Large diffs are not rendered by default.

Large diffs are not rendered by default.

12 changes: 10 additions & 2 deletions packages/api-graphql/__tests__/internals/generateClient.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,18 @@ function makeAppSyncStreams() {
*/
function normalizePostGraphqlCalls(spy: jest.SpyInstance<any, any>) {
return spy.mock.calls.map((call: any) => {
const [postOptions] = call;
// The 1st param in `call` is an instance of `AmplifyClassV6`
// The 2nd param in `call` is the actual `postOptions`
const [_, postOptions] = call;
const userAgent = postOptions?.options?.headers?.['x-amz-user-agent'];
if (userAgent) {
const staticUserAgent = userAgent.replace(/\/[\d.]+/g, '/latest');
postOptions.options.headers['x-amz-user-agent'] = staticUserAgent;
}
return call;
// Calling of `post` API with an instance of `AmplifyClassV6` has been
// unit tested in other test suites. To reduce the noise in the generated
// snapshot, we hide the details of the instance here.
return ['AmplifyClassV6', postOptions];
});
}

Expand Down Expand Up @@ -2956,6 +2961,7 @@ describe('generateClient', () => {

// Request headers should overwrite client headers:
expect(spy).toHaveBeenCalledWith(
expect.any(AmplifyClassV6),
expect.objectContaining({
options: expect.objectContaining({
headers: expect.not.objectContaining({
Expand Down Expand Up @@ -3211,6 +3217,7 @@ describe('generateClient', () => {
expect(normalizePostGraphqlCalls(spy)).toMatchSnapshot();

expect(spy).toHaveBeenCalledWith(
expect.any(AmplifyClassV6),
expect.objectContaining({
options: expect.objectContaining({
body: expect.objectContaining({
Expand Down Expand Up @@ -3890,6 +3897,7 @@ describe('generateClient', () => {
expect(normalizePostGraphqlCalls(spy)).toMatchSnapshot();

expect(spy).toHaveBeenCalledWith(
expect.any(AmplifyClassV6),
expect.objectContaining({
options: expect.objectContaining({
body: expect.objectContaining({
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as raw from '../../../src';
import { Amplify, ResourcesConfig } from '@aws-amplify/core';
import { Amplify, AmplifyClassV6, ResourcesConfig } from '@aws-amplify/core';
import { generateClientWithAmplifyInstance } from '../../../src/internals/server';
import configFixture from '../../fixtures/modeled/amplifyconfiguration';
import { Schema } from '../../fixtures/modeled/schema';
Expand Down Expand Up @@ -97,14 +97,15 @@ describe('server generateClient', () => {
});

expect(spy).toHaveBeenCalledWith(
expect.any(AmplifyClassV6),
expect.objectContaining({
options: expect.objectContaining({
headers: expect.objectContaining({
'X-Api-Key': 'FAKE-KEY',
}),
body: {
query: expect.stringContaining(
'listTodos(filter: $filter, limit: $limit, nextToken: $nextToken)'
'listTodos(filter: $filter, limit: $limit, nextToken: $nextToken)',
),
variables: {
filter: {
Expand All @@ -115,18 +116,19 @@ describe('server generateClient', () => {
},
},
}),
})
}),
);

expect(spy).toHaveBeenCalledWith(
expect.any(AmplifyClassV6),
expect.objectContaining({
options: expect.objectContaining({
body: expect.objectContaining({
// match nextToken in selection set
query: expect.stringMatching(/^\s*nextToken\s*$/m),
}),
}),
})
}),
);

expect(data.length).toBe(1);
Expand All @@ -137,7 +139,7 @@ describe('server generateClient', () => {
owner: 'wirejobviously',
name: 'some name',
description: 'something something',
})
}),
);
});

Expand Down Expand Up @@ -176,14 +178,15 @@ describe('server generateClient', () => {
});

expect(spy).toHaveBeenCalledWith(
expect.any(AmplifyClassV6),
expect.objectContaining({
options: expect.objectContaining({
headers: expect.objectContaining({
'X-Api-Key': 'FAKE-KEY',
}),
body: {
query: expect.stringContaining(
'listTodos(filter: $filter, limit: $limit, nextToken: $nextToken)'
'listTodos(filter: $filter, limit: $limit, nextToken: $nextToken)',
),
variables: {
filter: {
Expand All @@ -195,18 +198,19 @@ describe('server generateClient', () => {
},
},
}),
})
}),
);

expect(spy).toHaveBeenCalledWith(
expect.any(AmplifyClassV6),
expect.objectContaining({
options: expect.objectContaining({
body: expect.objectContaining({
// match nextToken in selection set
query: expect.stringMatching(/^\s*nextToken\s*$/m),
}),
}),
})
}),
);
});

Expand Down Expand Up @@ -245,14 +249,15 @@ describe('server generateClient', () => {
});

expect(spy).toHaveBeenCalledWith(
expect.any(AmplifyClassV6),
expect.objectContaining({
options: expect.objectContaining({
headers: expect.objectContaining({
'X-Api-Key': 'FAKE-KEY',
}),
body: {
query: expect.stringContaining(
'listTodos(filter: $filter, limit: $limit, nextToken: $nextToken)'
'listTodos(filter: $filter, limit: $limit, nextToken: $nextToken)',
),
variables: {
filter: {
Expand All @@ -264,18 +269,19 @@ describe('server generateClient', () => {
},
},
}),
})
}),
);

expect(spy).toHaveBeenCalledWith(
expect.any(AmplifyClassV6),
expect.objectContaining({
options: expect.objectContaining({
body: expect.objectContaining({
// match nextToken in selection set
query: expect.stringMatching(/^\s*nextToken\s*$/m),
}),
}),
})
}),
);
});
});
Expand Down Expand Up @@ -321,7 +327,7 @@ describe('server generateClient', () => {
expect.objectContaining({
query: expect.stringContaining('listNotes'),
}),
{}
{},
);
});
});
Expand Down
26 changes: 16 additions & 10 deletions packages/api-graphql/__tests__/utils/expects.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,23 @@ export function expectGet(
opName: string,
item: Record<string, any>
) {
expect(spy).toHaveBeenCalledWith({
abortController: expect.any(AbortController),
url: new URL('https://localhost/graphql'),
options: expect.objectContaining({
headers: expect.objectContaining({ 'X-Api-Key': 'FAKE-KEY' }),
body: expect.objectContaining({
query: expect.stringContaining(`${opName}(id: $id)`),
variables: expect.objectContaining(item),
expect(spy).toHaveBeenCalledWith(
expect.objectContaining({
Auth: expect.any(Object),
configure: expect.any(Function),
getConfig: expect.any(Function),
}), {
abortController: expect.any(AbortController),
url: new URL('https://localhost/graphql'),
options: expect.objectContaining({
headers: expect.objectContaining({ 'X-Api-Key': 'FAKE-KEY' }),
body: expect.objectContaining({
query: expect.stringContaining(`${opName}(id: $id)`),
variables: expect.objectContaining(item),
}),
}),
}),
});
}
);
}

/**
Expand Down
43 changes: 24 additions & 19 deletions packages/api-graphql/src/internals/InternalGraphQLAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ const logger = new ConsoleLogger('GraphQLAPI');
const isAmplifyInstance = (
amplify:
| AmplifyClassV6
| ((fn: (amplify: any) => Promise<any>) => Promise<AmplifyClassV6>)
| ((fn: (amplify: any) => Promise<any>) => Promise<AmplifyClassV6>),
): amplify is AmplifyClassV6 => {
return typeof amplify !== 'function';
};
Expand Down Expand Up @@ -68,7 +68,7 @@ export class InternalGraphQLAPIClass {
private async _headerBasedAuth(
amplify: AmplifyClassV6,
authMode: GraphQLAuthMode,
additionalHeaders: Record<string, string> = {}
additionalHeaders: Record<string, string> = {},
) {
const {
region: region,
Expand Down Expand Up @@ -159,15 +159,15 @@ export class InternalGraphQLAPIClass {
| ((fn: (amplify: any) => Promise<any>) => Promise<AmplifyClassV6>),
{ query: paramQuery, variables = {}, authMode, authToken }: GraphQLOptions,
additionalHeaders?: CustomHeaders,
customUserAgentDetails?: CustomUserAgentDetails
customUserAgentDetails?: CustomUserAgentDetails,
): Observable<GraphQLResult<T>> | Promise<GraphQLResult<T>> {
const query =
typeof paramQuery === 'string'
? parse(paramQuery)
: parse(print(paramQuery));

const [operationDef = {}] = query.definitions.filter(
def => def.kind === 'OperationDefinition'
def => def.kind === 'OperationDefinition',
);
const { operation: operationType } =
operationDef as OperationDefinitionNode;
Expand All @@ -188,27 +188,32 @@ export class InternalGraphQLAPIClass {
headers,
abortController,
customUserAgentDetails,
authToken
authToken,
);
} else {
const wrapper = (amplifyInstance: AmplifyClassV6) =>
this._graphql<T>(
// NOTE: this wrapper function must be await-able so the Amplify server context manager can
// destroy the context only after it completes
const wrapper = async (amplifyInstance: AmplifyClassV6) => {
const result = await this._graphql<T>(
amplifyInstance,
{ query, variables, authMode },
headers,
abortController,
customUserAgentDetails,
authToken
authToken,
);

return result;
};

responsePromise = amplify(wrapper) as unknown as Promise<
GraphQLResult<T>
>;
}

this._api.updateRequestToBeCancellable(
responsePromise,
abortController
abortController,
);
return responsePromise;
case 'subscription':
Expand All @@ -217,7 +222,7 @@ export class InternalGraphQLAPIClass {
{ query, variables, authMode },
headers,
customUserAgentDetails,
authToken
authToken,
);
default:
throw new Error(`invalid operation type: ${operationType}`);
Expand All @@ -230,7 +235,7 @@ export class InternalGraphQLAPIClass {
additionalHeaders: CustomHeaders = {},
abortController: AbortController,
customUserAgentDetails?: CustomUserAgentDetails,
authToken?: string
authToken?: string,
): Promise<GraphQLResult<T>> {
const {
region: region,
Expand Down Expand Up @@ -263,7 +268,7 @@ export class InternalGraphQLAPIClass {
const requestOptions: RequestOptions = {
method: 'POST',
url: new AmplifyUrl(
customEndpoint || appSyncGraphqlEndpoint || ''
customEndpoint || appSyncGraphqlEndpoint || '',
).toString(),
queryString: print(query as DocumentNode),
};
Expand All @@ -287,7 +292,7 @@ export class InternalGraphQLAPIClass {
(await this._headerBasedAuth(
amplify,
authMode!,
additionalCustomHeaders
additionalCustomHeaders,
))),
/**
* Custom endpoint headers.
Expand All @@ -300,7 +305,7 @@ export class InternalGraphQLAPIClass {
? await this._headerBasedAuth(
amplify,
authMode!,
additionalCustomHeaders
additionalCustomHeaders,
)
: {})) ||
{}),
Expand Down Expand Up @@ -361,7 +366,7 @@ export class InternalGraphQLAPIClass {
let response: any;

try {
const { body: responseBody } = await this._api.post({
const { body: responseBody } = await this._api.post(amplify, {
url: new AmplifyUrl(endpoint),
options: {
headers,
Expand Down Expand Up @@ -392,7 +397,7 @@ export class InternalGraphQLAPIClass {
null,
null,
null,
err as any
err as any,
),
],
};
Expand Down Expand Up @@ -430,7 +435,7 @@ export class InternalGraphQLAPIClass {
{ query, variables, authMode }: GraphQLOptions,
additionalHeaders: CustomHeaders = {},
customUserAgentDetails?: CustomUserAgentDetails,
authToken?: string
authToken?: string,
): Observable<any> {
const config = resolveConfig(amplify);

Expand All @@ -457,15 +462,15 @@ export class InternalGraphQLAPIClass {
authToken,
libraryConfigHeaders,
},
customUserAgentDetails
customUserAgentDetails,
)
.pipe(
catchError(e => {
if (e.errors) {
throw repackageUnauthError(e);
}
throw e;
})
}),
);
}
}
Expand Down
Loading
Loading