diff --git a/package.json b/package.json index 790927a9ca8..8ff6be823ea 100644 --- a/package.json +++ b/package.json @@ -102,6 +102,7 @@ "eslint-plugin-prettier": "^5.1.3", "eslint-plugin-promise": "^6.1.1", "eslint-plugin-unused-imports": "^3.0.0", + "expect": "^29.7.0", "glob": "^10.3.10", "jest": "^29.7.0", "jest-environment-jsdom": "^29.7.0", diff --git a/packages/storage/__tests__/providers/s3/apis/copy.test.ts b/packages/storage/__tests__/providers/s3/apis/copy.test.ts index 52eaf7c902f..4ddbe0ded81 100644 --- a/packages/storage/__tests__/providers/s3/apis/copy.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/copy.test.ts @@ -13,6 +13,7 @@ import { CopyOutput, CopyWithPathOutput, } from '../../../../src/providers/s3/types'; +import './testUtils'; jest.mock('../../../../src/providers/s3/utils/client'); jest.mock('@aws-amplify/core', () => ({ @@ -185,7 +186,7 @@ describe('copy API', () => { }); expect(key).toEqual(destinationKey); expect(copyObject).toHaveBeenCalledTimes(1); - expect(copyObject).toHaveBeenCalledWith(copyObjectClientConfig, { + await expect(copyObject).toBeLastCalledWithConfigAndInput(copyObjectClientConfig, { ...copyObjectClientBaseParams, CopySource: expectedSourceKey, Key: expectedDestinationKey, @@ -238,7 +239,7 @@ describe('copy API', () => { }); expect(path).toEqual(expectedDestinationPath); expect(copyObject).toHaveBeenCalledTimes(1); - expect(copyObject).toHaveBeenCalledWith(copyObjectClientConfig, { + await expect(copyObject).toBeLastCalledWithConfigAndInput(copyObjectClientConfig, { ...copyObjectClientBaseParams, CopySource: `${bucket}/${expectedSourcePath}`, Key: expectedDestinationPath, @@ -269,7 +270,7 @@ describe('copy API', () => { }); } catch (error: any) { expect(copyObject).toHaveBeenCalledTimes(1); - expect(copyObject).toHaveBeenCalledWith(copyObjectClientConfig, { + await expect(copyObject).toBeLastCalledWithConfigAndInput(copyObjectClientConfig, { ...copyObjectClientBaseParams, CopySource: `${bucket}/public/${sourceKey}`, Key: `public/${destinationKey}`, diff --git a/packages/storage/__tests__/providers/s3/apis/downloadData.test.ts b/packages/storage/__tests__/providers/s3/apis/downloadData.test.ts index 0c9c4a3d007..104de9173d1 100644 --- a/packages/storage/__tests__/providers/s3/apis/downloadData.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/downloadData.test.ts @@ -22,6 +22,7 @@ import { ItemWithKey, ItemWithPath, } from '../../../../src/providers/s3/types/outputs'; +import './testUtils'; jest.mock('../../../../src/providers/s3/utils/client'); jest.mock('../../../../src/providers/s3/utils'); @@ -141,7 +142,7 @@ describe('downloadData with key', () => { body: 'body', }); expect(getObject).toHaveBeenCalledTimes(1); - expect(getObject).toHaveBeenCalledWith( + await expect(getObject).toBeLastCalledWithConfigAndInput( { credentials, region, @@ -287,7 +288,7 @@ describe('downloadData with path', () => { body: 'body', }); expect(getObject).toHaveBeenCalledTimes(1); - expect(getObject).toHaveBeenCalledWith( + await expect(getObject).toBeLastCalledWithConfigAndInput( { credentials, region, diff --git a/packages/storage/__tests__/providers/s3/apis/getProperties.test.ts b/packages/storage/__tests__/providers/s3/apis/getProperties.test.ts index 191802f04f9..8ec8743fd96 100644 --- a/packages/storage/__tests__/providers/s3/apis/getProperties.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/getProperties.test.ts @@ -11,6 +11,7 @@ import { GetPropertiesOutput, GetPropertiesWithPathOutput, } from '../../../../src/providers/s3/types'; +import './testUtils'; jest.mock('../../../../src/providers/s3/utils/client'); jest.mock('@aws-amplify/core', () => ({ @@ -144,7 +145,7 @@ describe('getProperties with key', () => { ...expectedResult, }); expect(headObject).toHaveBeenCalledTimes(1); - expect(headObject).toHaveBeenCalledWith(config, headObjectOptions); + await expect(headObject).toBeLastCalledWithConfigAndInput(config, headObjectOptions); }, ); }); @@ -165,7 +166,7 @@ describe('getProperties with key', () => { await getPropertiesWrapper({ key: inputKey }); } catch (error: any) { expect(headObject).toHaveBeenCalledTimes(1); - expect(headObject).toHaveBeenCalledWith( + await expect(headObject).toBeLastCalledWithConfigAndInput( { credentials, region: 'region', @@ -264,7 +265,7 @@ describe('Happy cases: With path', () => { ...expectedResult, }); expect(headObject).toHaveBeenCalledTimes(1); - expect(headObject).toHaveBeenCalledWith(config, headObjectOptions); + await expect(headObject).toBeLastCalledWithConfigAndInput(config, headObjectOptions); }, ); }); @@ -285,7 +286,7 @@ describe('Happy cases: With path', () => { await getPropertiesWrapper({ path: inputPath }); } catch (error: any) { expect(headObject).toHaveBeenCalledTimes(1); - expect(headObject).toHaveBeenCalledWith( + await expect(headObject).toBeLastCalledWithConfigAndInput( { credentials, region: 'region', diff --git a/packages/storage/__tests__/providers/s3/apis/getUrl.test.ts b/packages/storage/__tests__/providers/s3/apis/getUrl.test.ts index 8f56299d943..df975ded020 100644 --- a/packages/storage/__tests__/providers/s3/apis/getUrl.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/getUrl.test.ts @@ -14,6 +14,7 @@ import { GetUrlOutput, GetUrlWithPathOutput, } from '../../../../src/providers/s3/types'; +import './testUtils'; jest.mock('../../../../src/providers/s3/utils/client'); jest.mock('@aws-amplify/core', () => ({ @@ -30,8 +31,8 @@ jest.mock('@aws-amplify/core', () => ({ const bucket = 'bucket'; const region = 'region'; -const mockFetchAuthSession = Amplify.Auth.fetchAuthSession as jest.Mock; -const mockGetConfig = Amplify.getConfig as jest.Mock; +const mockFetchAuthSession = jest.mocked(Amplify.Auth.fetchAuthSession); +const mockGetConfig = jest.mocked(Amplify.getConfig); const credentials: AWSCredentials = { accessKeyId: 'accessKeyId', sessionToken: 'sessionToken', @@ -67,7 +68,7 @@ describe('getUrl test with key', () => { }; const key = 'key'; beforeEach(() => { - (headObject as jest.MockedFunction).mockResolvedValue({ + jest.mocked(headObject).mockResolvedValue({ ContentLength: 100, ContentType: 'text/plain', ETag: 'etag', @@ -75,11 +76,7 @@ describe('getUrl test with key', () => { Metadata: { meta: 'value' }, $metadata: {} as any, }); - ( - getPresignedGetObjectUrl as jest.MockedFunction< - typeof getPresignedGetObjectUrl - > - ).mockResolvedValue(mockURL); + jest.mocked(getPresignedGetObjectUrl).mockResolvedValue(mockURL); }); afterEach(() => { jest.clearAllMocks(); @@ -130,7 +127,7 @@ describe('getUrl test with key', () => { }; expect(getPresignedGetObjectUrl).toHaveBeenCalledTimes(1); expect(headObject).toHaveBeenCalledTimes(1); - expect(headObject).toHaveBeenCalledWith(config, headObjectOptions); + expect(headObject).toBeLastCalledWithConfigAndInput(config, headObjectOptions); expect({ url, expiresAt }).toEqual(expectedResult); }, ); @@ -186,7 +183,7 @@ describe('getUrl test with path', () => { userAgentValue: expect.any(String), }; beforeEach(() => { - (headObject as jest.MockedFunction).mockResolvedValue({ + jest.mocked(headObject).mockResolvedValue({ ContentLength: 100, ContentType: 'text/plain', ETag: 'etag', @@ -194,11 +191,7 @@ describe('getUrl test with path', () => { Metadata: { meta: 'value' }, $metadata: {} as any, }); - ( - getPresignedGetObjectUrl as jest.MockedFunction< - typeof getPresignedGetObjectUrl - > - ).mockResolvedValue(mockURL); + jest.mocked(getPresignedGetObjectUrl).mockResolvedValue(mockURL); }); afterEach(() => { jest.clearAllMocks(); @@ -228,7 +221,7 @@ describe('getUrl test with path', () => { }); expect(getPresignedGetObjectUrl).toHaveBeenCalledTimes(1); expect(headObject).toHaveBeenCalledTimes(1); - expect(headObject).toHaveBeenCalledWith(config, headObjectOptions); + expect(headObject).toBeLastCalledWithConfigAndInput(config, headObjectOptions); expect({ url, expiresAt }).toEqual({ url: mockURL, expiresAt: expect.any(Date), diff --git a/packages/storage/__tests__/providers/s3/apis/list.test.ts b/packages/storage/__tests__/providers/s3/apis/list.test.ts index 21ad76cdc33..a16f6900d0f 100644 --- a/packages/storage/__tests__/providers/s3/apis/list.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/list.test.ts @@ -15,6 +15,7 @@ import { ListPaginateOutput, ListPaginateWithPathOutput, } from '../../../../src/providers/s3/types'; +import './testUtils'; jest.mock('../../../../src/providers/s3/utils/client'); jest.mock('@aws-amplify/core', () => ({ @@ -170,7 +171,9 @@ describe('list API', () => { }); expect(response.nextToken).toEqual(nextToken); expect(listObjectsV2).toHaveBeenCalledTimes(1); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { + await expect( + listObjectsV2, + ).toBeLastCalledWithConfigAndInput(listObjectClientConfig, { Bucket: bucket, MaxKeys: 1000, Prefix: expectedKey, @@ -208,7 +211,9 @@ describe('list API', () => { }); expect(response.nextToken).toEqual(nextToken); expect(listObjectsV2).toHaveBeenCalledTimes(1); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { + await expect( + listObjectsV2, + ).toBeLastCalledWithConfigAndInput(listObjectClientConfig, { Bucket: bucket, Prefix: expectedKey, ContinuationToken: nextToken, @@ -234,7 +239,9 @@ describe('list API', () => { expect(response.items).toEqual([]); expect(response.nextToken).toEqual(undefined); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { + await expect( + listObjectsV2, + ).toBeLastCalledWithConfigAndInput(listObjectClientConfig, { Bucket: bucket, MaxKeys: 1000, Prefix: expectedKey, @@ -266,7 +273,7 @@ describe('list API', () => { expect(listObjectsV2).toHaveBeenCalledTimes(3); // first input recieves undefined as the Continuation Token - expect(listObjectsV2).toHaveBeenNthCalledWith( + await expect(listObjectsV2).toHaveBeenNthCalledWithConfigAndInput( 1, listObjectClientConfig, { @@ -277,7 +284,7 @@ describe('list API', () => { }, ); // last input recieves TEST_TOKEN as the Continuation Token - expect(listObjectsV2).toHaveBeenNthCalledWith( + await expect(listObjectsV2).toHaveBeenNthCalledWithConfigAndInput( 3, listObjectClientConfig, { @@ -340,11 +347,14 @@ describe('list API', () => { }); expect(response.nextToken).toEqual(nextToken); expect(listObjectsV2).toHaveBeenCalledTimes(1); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { - Bucket: bucket, - MaxKeys: 1000, - Prefix: resolvePath(inputPath), - }); + expect(listObjectsV2).toBeLastCalledWithConfigAndInput( + listObjectClientConfig, + { + Bucket: bucket, + MaxKeys: 1000, + Prefix: resolvePath(inputPath), + }, + ); }, ); @@ -379,12 +389,15 @@ describe('list API', () => { }); expect(response.nextToken).toEqual(nextToken); expect(listObjectsV2).toHaveBeenCalledTimes(1); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { - Bucket: bucket, - Prefix: resolvePath(inputPath), - ContinuationToken: nextToken, - MaxKeys: customPageSize, - }); + expect(listObjectsV2).toBeLastCalledWithConfigAndInput( + listObjectClientConfig, + { + Bucket: bucket, + Prefix: resolvePath(inputPath), + ContinuationToken: nextToken, + MaxKeys: customPageSize, + }, + ); }, ); @@ -400,11 +413,14 @@ describe('list API', () => { expect(response.items).toEqual([]); expect(response.nextToken).toEqual(undefined); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { - Bucket: bucket, - MaxKeys: 1000, - Prefix: resolvePath(path), - }); + expect(listObjectsV2).toBeLastCalledWithConfigAndInput( + listObjectClientConfig, + { + Bucket: bucket, + MaxKeys: 1000, + Prefix: resolvePath(path), + }, + ); }, ); @@ -432,7 +448,7 @@ describe('list API', () => { expect(listObjectsV2).toHaveBeenCalledTimes(3); // first input recieves undefined as the Continuation Token - expect(listObjectsV2).toHaveBeenNthCalledWith( + await expect(listObjectsV2).toHaveBeenNthCalledWithConfigAndInput( 1, listObjectClientConfig, { @@ -443,7 +459,7 @@ describe('list API', () => { }, ); // last input recieves TEST_TOKEN as the Continuation Token - expect(listObjectsV2).toHaveBeenNthCalledWith( + await expect(listObjectsV2).toHaveBeenNthCalledWithConfigAndInput( 3, listObjectClientConfig, { @@ -473,11 +489,14 @@ describe('list API', () => { } catch (error: any) { expect.assertions(3); expect(listObjectsV2).toHaveBeenCalledTimes(1); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { - Bucket: bucket, - MaxKeys: 1000, - Prefix: 'public/', - }); + expect(listObjectsV2).toBeLastCalledWithConfigAndInput( + listObjectClientConfig, + { + Bucket: bucket, + MaxKeys: 1000, + Prefix: 'public/', + }, + ); expect(error.$metadata.httpStatusCode).toBe(404); } }); diff --git a/packages/storage/__tests__/providers/s3/apis/remove.test.ts b/packages/storage/__tests__/providers/s3/apis/remove.test.ts index 0c8662492ac..1779c5af259 100644 --- a/packages/storage/__tests__/providers/s3/apis/remove.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/remove.test.ts @@ -12,6 +12,7 @@ import { RemoveOutput, RemoveWithPathOutput, } from '../../../../src/providers/s3/types'; +import './testUtils'; jest.mock('../../../../src/providers/s3/utils/client'); jest.mock('@aws-amplify/core', () => ({ @@ -104,7 +105,7 @@ describe('remove API', () => { }); expect(key).toEqual(inputKey); expect(deleteObject).toHaveBeenCalledTimes(1); - expect(deleteObject).toHaveBeenCalledWith(deleteObjectClientConfig, { + expect(deleteObject).toBeLastCalledWithConfigAndInput(deleteObjectClientConfig, { Bucket: bucket, Key: expectedKey, }); @@ -143,7 +144,7 @@ describe('remove API', () => { const { path } = await removeWrapper({ path: inputPath }); expect(path).toEqual(resolvedPath); expect(deleteObject).toHaveBeenCalledTimes(1); - expect(deleteObject).toHaveBeenCalledWith(deleteObjectClientConfig, { + await expect(deleteObject).toBeLastCalledWithConfigAndInput(deleteObjectClientConfig, { Bucket: bucket, Key: resolvedPath, }); @@ -169,7 +170,7 @@ describe('remove API', () => { await remove({ key }); } catch (error: any) { expect(deleteObject).toHaveBeenCalledTimes(1); - expect(deleteObject).toHaveBeenCalledWith(deleteObjectClientConfig, { + await expect(deleteObject).toBeLastCalledWithConfigAndInput(deleteObjectClientConfig, { Bucket: bucket, Key: `public/${key}`, }); diff --git a/packages/storage/__tests__/providers/s3/apis/testUtils.ts b/packages/storage/__tests__/providers/s3/apis/testUtils.ts new file mode 100644 index 00000000000..c1094836b91 --- /dev/null +++ b/packages/storage/__tests__/providers/s3/apis/testUtils.ts @@ -0,0 +1,115 @@ +import { expect } from '@jest/globals'; +import { type MatcherFunction } from 'expect'; + +const toBeLastCalledWithConfigAndInput: MatcherFunction< + [config: { credentials: unknown }, input: unknown] +> = async function (actualHandler, expectedConfig, expectedInput) { + if (!jest.isMockFunction(actualHandler)) { + return { + message: () => + `expected custom client handler to be a mock function, got ${actualHandler}`, + pass: false, + }; + } + const actualConfig = actualHandler.mock.lastCall?.[0]; + const actualInput = actualHandler.mock.lastCall?.[1]; + const actualConfigWithResolvedCredentials = + typeof actualConfig?.credentials === 'function' + ? { + ...actualConfig, + credentials: await actualConfig.credentials(), + } + : actualConfig; + if ( + this.equals(actualConfigWithResolvedCredentials, expectedConfig) && + this.equals(actualInput, expectedInput) + ) { + return { + message: () => '', + pass: true, + }; + } + return { + message: () => + `expected ${actualConfig} to equal ${expectedConfig} and ${actualInput} to equal ${expectedInput}`, + pass: false, + }; +}; + +const toHaveBeenNthCalledWithConfigAndInput: MatcherFunction< + [nthCall: number, config: unknown, input: unknown] +> = async function (actualHandler, nthCall, expectedConfig, expectedInput) { + if (!jest.isMockFunction(actualHandler)) { + return { + message: () => + `expected custom client handler to be a mock function, got ${actualHandler}`, + pass: false, + }; + } + const actualConfig = actualHandler.mock.calls[nthCall - 1]?.[0]; + const actualInput = actualHandler.mock.calls[nthCall - 1]?.[1]; + const actualConfigWithResolvedCredentials = + typeof actualConfig?.credentials === 'function' + ? { + ...actualConfig, + credentials: await actualConfig.credentials(), + } + : actualConfig; + if ( + this.equals(actualConfigWithResolvedCredentials, expectedConfig) && + this.equals(actualInput, expectedInput) + ) { + return { + message: () => '', + pass: true, + }; + } + return { + message: () => + `expected ${JSON.stringify(actualConfig)} to equal ${JSON.stringify(JSON.stringify(expectedConfig))} and ${JSON.stringify(JSON.stringify(actualInput))} to equal ${JSON.stringify(expectedInput)}`, + pass: false, + }; +}; + +expect.extend({ + toBeLastCalledWithConfigAndInput, + toHaveBeenNthCalledWithConfigAndInput, +}); + +declare global { + namespace jest { + interface AsymmetricMatchers { + toBeLastCalledWithConfigAndInput( + expectedConfig: unknown, + expectedInput: any, + ): void; + toHaveBeenNthCalledWithConfigAndInput( + nthCall: number, + expectedConfig: unknown, + expectedInput: any, + ): void; + } + interface Matchers { + /** + * Asynchronously asserts mocked custom client handler to be last called with expected config and input. + * If the actual client config has a credential that is a provider function, it will be resolved to static + * credential object and matched against the supplied config credentials. + */ + toBeLastCalledWithConfigAndInput( + expectedConfig: unknown, + expectedInput: any, + ): R; + + /** + * Asynchronously asserts mocked custom client handler to be Nth called with expected config and input. + * If the actual client config has a credential that is a provider function, it will be resolved to static + * credential object and matched against the supplied config credentials. + */ + toHaveBeenNthCalledWithConfigAndInput( + nthCall: number, + expectedConfig: unknown, + expectedInput: any, + ): R; + } + } +} diff --git a/packages/storage/__tests__/providers/s3/apis/uploadData/multipartHandlers.test.ts b/packages/storage/__tests__/providers/s3/apis/uploadData/multipartHandlers.test.ts index 302d76beaa8..bb0076d48ee 100644 --- a/packages/storage/__tests__/providers/s3/apis/uploadData/multipartHandlers.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/uploadData/multipartHandlers.test.ts @@ -20,6 +20,7 @@ import { UPLOADS_STORAGE_KEY } from '../../../../../src/providers/s3/utils/const import { byteLength } from '../../../../../src/providers/s3/apis/uploadData/byteLength'; import { CanceledError } from '../../../../../src/errors/CanceledError'; import { StorageOptions } from '../../../../../src/types'; +import '../testUtils'; jest.mock('@aws-amplify/core'); jest.mock('../../../../../src/providers/s3/utils/client'); @@ -39,12 +40,12 @@ const defaultCacheKey = '8388608_application/octet-stream_bucket_public_key'; const testPath = 'testPath/object'; const testPathCacheKey = `8388608_${defaultContentType}_${bucket}_custom_${testPath}`; -const mockCreateMultipartUpload = createMultipartUpload as jest.Mock; -const mockUploadPart = uploadPart as jest.Mock; -const mockCompleteMultipartUpload = completeMultipartUpload as jest.Mock; -const mockAbortMultipartUpload = abortMultipartUpload as jest.Mock; -const mockListParts = listParts as jest.Mock; -const mockHeadObject = headObject as jest.Mock; +const mockCreateMultipartUpload = jest.mocked(createMultipartUpload); +const mockUploadPart = jest.mocked(uploadPart); +const mockCompleteMultipartUpload = jest.mocked(completeMultipartUpload); +const mockAbortMultipartUpload = jest.mocked(abortMultipartUpload); +const mockListParts = jest.mocked(listParts); +const mockHeadObject = jest.mocked(headObject); const disableAssertion = true; @@ -54,20 +55,23 @@ const mockMultipartUploadSuccess = (disableAssertion?: boolean) => { let totalSize = 0; mockCreateMultipartUpload.mockResolvedValueOnce({ UploadId: 'uploadId', + $metadata: {} }); + // @ts-expect-error Special mock to make uploadPart return input part number mockUploadPart.mockImplementation(async (s3Config, input) => { if (!disableAssertion) { expect(input.UploadId).toEqual('uploadId'); } // mock 2 invocation of onProgress callback to simulate progress - s3Config?.onUploadProgress({ - transferredBytes: input.Body.byteLength / 2, - totalBytes: input.Body.byteLength, + const body = input.Body as ArrayBuffer; + s3Config?.onUploadProgress?.({ + transferredBytes: body.byteLength / 2, + totalBytes: body.byteLength, }); - s3Config?.onUploadProgress({ - transferredBytes: input.Body.byteLength, - totalBytes: input.Body.byteLength, + s3Config?.onUploadProgress?.({ + transferredBytes: body.byteLength, + totalBytes: body.byteLength, }); totalSize += byteLength(input.Body)!; @@ -79,9 +83,11 @@ const mockMultipartUploadSuccess = (disableAssertion?: boolean) => { }); mockCompleteMultipartUpload.mockResolvedValueOnce({ ETag: 'etag', + $metadata: {} }); mockHeadObject.mockResolvedValueOnce({ ContentLength: totalSize, + $metadata: {} }); }; @@ -90,8 +96,10 @@ const mockMultipartUploadCancellation = ( ) => { mockCreateMultipartUpload.mockImplementation(async ({ abortSignal }) => ({ UploadId: 'uploadId', + $metadata: {} })); + // @ts-expect-error Only need partial mock mockUploadPart.mockImplementation(async ({ abortSignal }, { PartNumber }) => { beforeUploadPartResponseCallback?.(); if (abortSignal?.aborted) { @@ -103,10 +111,13 @@ const mockMultipartUploadCancellation = ( }; }); - mockAbortMultipartUpload.mockResolvedValueOnce({}); + mockAbortMultipartUpload.mockResolvedValueOnce({ + $metadata: {} + }); // Mock resumed upload and completed upload successfully mockCompleteMultipartUpload.mockResolvedValueOnce({ ETag: 'etag', + $metadata: {} }); }; @@ -192,7 +203,7 @@ describe('getMultipartUploadHandlers with key', () => { options: options as StorageOptions, }); const result = await multipartUploadJob(); - expect(mockCreateMultipartUpload).toHaveBeenCalledWith( + await expect(mockCreateMultipartUpload).toBeLastCalledWithConfigAndInput( expect.objectContaining({ credentials, region, @@ -255,7 +266,7 @@ describe('getMultipartUploadHandlers with key', () => { expect(mockCreateMultipartUpload).toHaveBeenCalledTimes(1); expect(mockUploadPart).toHaveBeenCalledTimes(10_000); expect(mockCompleteMultipartUpload).toHaveBeenCalledTimes(1); - expect(mockUploadPart.mock.calls[0][1].Body.byteLength).toEqual(10 * MB); // The part size should be adjusted from default 5MB to 10MB. + expect((mockUploadPart.mock.calls[0][1].Body as ArrayBuffer).byteLength).toEqual(10 * MB); // The part size should be adjusted from default 5MB to 10MB. }); it('should throw error when remote and local file sizes do not match upon completed upload', async () => { @@ -264,6 +275,7 @@ describe('getMultipartUploadHandlers with key', () => { mockHeadObject.mockReset(); mockHeadObject.mockResolvedValue({ ContentLength: 1, + $metadata: {} }); const { multipartUploadJob } = getMultipartUploadHandlers( @@ -315,6 +327,7 @@ describe('getMultipartUploadHandlers with key', () => { mockUploadPart.mockReset(); mockUploadPart.mockResolvedValueOnce({ ETag: `etag-1`, + // @ts-expect-error Special mock to make uploadPart return input part number. PartNumber: 1, }); mockUploadPart.mockRejectedValueOnce(new Error('error')); @@ -367,7 +380,7 @@ describe('getMultipartUploadHandlers with key', () => { }), ); mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -385,7 +398,7 @@ describe('getMultipartUploadHandlers with key', () => { it('should cache the upload with file including file lastModified property', async () => { mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -420,7 +433,7 @@ describe('getMultipartUploadHandlers with key', () => { }), ); mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -438,7 +451,7 @@ describe('getMultipartUploadHandlers with key', () => { it('should cache upload task if new upload task is created', async () => { mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -465,7 +478,7 @@ describe('getMultipartUploadHandlers with key', () => { it('should remove from cache if upload task is completed', async () => { mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -487,7 +500,7 @@ describe('getMultipartUploadHandlers with key', () => { it('should remove from cache if upload task is canceled', async () => { expect.assertions(2); mockMultipartUploadSuccess(disableAssertion); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob, onCancel } = getMultipartUploadHandlers( { @@ -609,6 +622,7 @@ describe('getMultipartUploadHandlers with key', () => { ); mockListParts.mockResolvedValue({ Parts: [{ PartNumber: 1 }], + $metadata: {}, }); const onProgress = jest.fn(); @@ -698,7 +712,7 @@ describe('getMultipartUploadHandlers with path', () => { data: twoPartsPayload, }); const result = await multipartUploadJob(); - expect(mockCreateMultipartUpload).toHaveBeenCalledWith( + await expect(mockCreateMultipartUpload).toBeLastCalledWithConfigAndInput( expect.objectContaining({ credentials, region, @@ -761,7 +775,7 @@ describe('getMultipartUploadHandlers with path', () => { expect(mockCreateMultipartUpload).toHaveBeenCalledTimes(1); expect(mockUploadPart).toHaveBeenCalledTimes(10_000); expect(mockCompleteMultipartUpload).toHaveBeenCalledTimes(1); - expect(mockUploadPart.mock.calls[0][1].Body.byteLength).toEqual(10 * MB); // The part size should be adjusted from default 5MB to 10MB. + expect((mockUploadPart.mock.calls[0][1].Body as ArrayBuffer).byteLength).toEqual(10 * MB); // The part size should be adjusted from default 5MB to 10MB. }); it('should throw error when remote and local file sizes do not match upon completed upload', async () => { @@ -770,6 +784,7 @@ describe('getMultipartUploadHandlers with path', () => { mockHeadObject.mockReset(); mockHeadObject.mockResolvedValue({ ContentLength: 1, + $metadata: {} }); const { multipartUploadJob } = getMultipartUploadHandlers( @@ -821,6 +836,7 @@ describe('getMultipartUploadHandlers with path', () => { mockUploadPart.mockReset(); mockUploadPart.mockResolvedValueOnce({ ETag: `etag-1`, + // @ts-expect-error Special mock to make uploadPart return input part number. PartNumber: 1, }); mockUploadPart.mockRejectedValueOnce(new Error('error')); @@ -873,7 +889,7 @@ describe('getMultipartUploadHandlers with path', () => { }), ); mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -891,7 +907,7 @@ describe('getMultipartUploadHandlers with path', () => { it('should cache the upload with file including file lastModified property', async () => { mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -929,7 +945,7 @@ describe('getMultipartUploadHandlers with path', () => { }), ); mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -947,7 +963,7 @@ describe('getMultipartUploadHandlers with path', () => { it('should cache upload task if new upload task is created', async () => { mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -972,7 +988,7 @@ describe('getMultipartUploadHandlers with path', () => { it('should remove from cache if upload task is completed', async () => { mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -994,7 +1010,7 @@ describe('getMultipartUploadHandlers with path', () => { it('should remove from cache if upload task is canceled', async () => { expect.assertions(2); mockMultipartUploadSuccess(disableAssertion); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob, onCancel } = getMultipartUploadHandlers( { @@ -1116,6 +1132,7 @@ describe('getMultipartUploadHandlers with path', () => { ); mockListParts.mockResolvedValue({ Parts: [{ PartNumber: 1 }], + $metadata: {} }); const onProgress = jest.fn(); diff --git a/packages/storage/__tests__/providers/s3/apis/uploadData/putObjectJob.test.ts b/packages/storage/__tests__/providers/s3/apis/uploadData/putObjectJob.test.ts index b03822946da..a9e0bc76452 100644 --- a/packages/storage/__tests__/providers/s3/apis/uploadData/putObjectJob.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/uploadData/putObjectJob.test.ts @@ -6,6 +6,7 @@ import { Amplify } from '@aws-amplify/core'; import { putObject } from '../../../../../src/providers/s3/utils/client'; import { calculateContentMd5 } from '../../../../../src/providers/s3/utils'; import { putObjectJob } from '../../../../../src/providers/s3/apis/uploadData/putObjectJob'; +import '../testUtils'; jest.mock('../../../../../src/providers/s3/utils/client'); jest.mock('../../../../../src/providers/s3/utils', () => { @@ -33,14 +34,14 @@ const credentials: AWSCredentials = { secretAccessKey: 'secretAccessKey', }; const identityId = 'identityId'; -const mockFetchAuthSession = Amplify.Auth.fetchAuthSession as jest.Mock; -const mockPutObject = putObject as jest.Mock; +const mockFetchAuthSession = jest.mocked(Amplify.Auth.fetchAuthSession); +const mockPutObject = jest.mocked(putObject); mockFetchAuthSession.mockResolvedValue({ credentials, identityId, }); -(Amplify.getConfig as jest.Mock).mockReturnValue({ +jest.mocked(Amplify.getConfig).mockReturnValue({ Storage: { S3: { bucket: 'bucket', @@ -51,6 +52,7 @@ mockFetchAuthSession.mockResolvedValue({ mockPutObject.mockResolvedValue({ ETag: 'eTag', VersionId: 'versionId', + $metadata: {} }); /* TODO Remove suite when `key` parameter is removed */ @@ -90,7 +92,7 @@ describe('putObjectJob with key', () => { metadata: { key: 'value' }, size: undefined, }); - expect(mockPutObject).toHaveBeenCalledWith( + await expect(mockPutObject).toBeLastCalledWithConfigAndInput( { credentials, region: 'region', @@ -178,7 +180,7 @@ describe('putObjectJob with path', () => { metadata: { key: 'value' }, size: undefined, }); - expect(mockPutObject).toHaveBeenCalledWith( + await expect(mockPutObject).toBeLastCalledWithConfigAndInput( { credentials, region: 'region', diff --git a/packages/storage/__tests__/providers/s3/apis/utils/resolveS3ConfigAndInput.test.ts b/packages/storage/__tests__/providers/s3/apis/utils/resolveS3ConfigAndInput.test.ts index 5db56bed1ed..4ec756710b8 100644 --- a/packages/storage/__tests__/providers/s3/apis/utils/resolveS3ConfigAndInput.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/utils/resolveS3ConfigAndInput.test.ts @@ -54,17 +54,19 @@ describe('resolveS3ConfigAndInput', () => { }); it('should call fetchAuthSession with forceRefresh false for credentials and identityId', async () => { - await resolveS3ConfigAndInput(Amplify, {}); - expect(mockFetchAuthSession).toHaveBeenCalledWith({ - forceRefresh: false, - }); + const { s3Config: { credentials } } = await resolveS3ConfigAndInput(Amplify, {}); + expect(credentials).toBeInstanceOf(Function); + expect(mockFetchAuthSession).toHaveBeenCalled(); }); it('should throw if credentials are not available', async () => { - mockFetchAuthSession.mockResolvedValueOnce({ + mockFetchAuthSession.mockResolvedValue({ identityId: targetIdentityId, }); - await expect(resolveS3ConfigAndInput(Amplify, {})).rejects.toMatchObject( + const { s3Config: { credentials } } = await resolveS3ConfigAndInput(Amplify, {}); + expect(credentials).toBeInstanceOf(Function); + // @ts-expect-error Already validated credentials being function. + await expect(credentials()).rejects.toMatchObject( validationErrorMap[StorageValidationErrorCode.NoCredentials], ); }); diff --git a/packages/storage/__tests__/providers/s3/utils/client/S3/cases/shared.ts b/packages/storage/__tests__/providers/s3/utils/client/S3/cases/shared.ts index 6a1d8431825..83acecf2bd8 100644 --- a/packages/storage/__tests__/providers/s3/utils/client/S3/cases/shared.ts +++ b/packages/storage/__tests__/providers/s3/utils/client/S3/cases/shared.ts @@ -19,8 +19,16 @@ export const expectedMetadata = { export const defaultConfig = { region: 'us-east-1', + credentials: async () => ({ + accessKeyId: 'key', + secretAccessKey: 'secret', + }), +}; + +export const defaultConfigWithStaticCredentials = { + ...defaultConfig, credentials: { accessKeyId: 'key', secretAccessKey: 'secret', }, -}; +} diff --git a/packages/storage/__tests__/providers/s3/utils/client/S3/getPresignedGetObjectUrl.test.ts b/packages/storage/__tests__/providers/s3/utils/client/S3/getPresignedGetObjectUrl.test.ts index 0108bd37084..b67bb871b3f 100644 --- a/packages/storage/__tests__/providers/s3/utils/client/S3/getPresignedGetObjectUrl.test.ts +++ b/packages/storage/__tests__/providers/s3/utils/client/S3/getPresignedGetObjectUrl.test.ts @@ -3,7 +3,7 @@ import { presignUrl } from '@aws-amplify/core/internals/aws-client-utils'; import { getPresignedGetObjectUrl } from '../../../../../../src/providers/s3/utils/client'; -import { defaultConfig } from './cases/shared'; +import { defaultConfigWithStaticCredentials } from './cases/shared'; jest.mock('@aws-amplify/core/internals/aws-client-utils', () => { const original = jest.requireActual( @@ -23,8 +23,8 @@ describe('serializeGetObjectRequest', () => { it('should return get object API request', async () => { const actual = await getPresignedGetObjectUrl( { - ...defaultConfig, - signingRegion: defaultConfig.region, + ...defaultConfigWithStaticCredentials, + signingRegion: defaultConfigWithStaticCredentials.region, signingService: 's3', expiration: 900, userAgentValue: 'UA', @@ -36,7 +36,7 @@ describe('serializeGetObjectRequest', () => { ); const actualUrl = actual; expect(actualUrl.hostname).toEqual( - `bucket.s3.${defaultConfig.region}.amazonaws.com`, + `bucket.s3.${defaultConfigWithStaticCredentials.region}.amazonaws.com`, ); expect(actualUrl.pathname).toEqual('/key'); expect(actualUrl.searchParams.get('X-Amz-Expires')).toEqual('900'); @@ -49,8 +49,8 @@ describe('serializeGetObjectRequest', () => { it('should call presignUrl with uriEscapePath param set to false', async () => { await getPresignedGetObjectUrl( { - ...defaultConfig, - signingRegion: defaultConfig.region, + ...defaultConfigWithStaticCredentials, + signingRegion: defaultConfigWithStaticCredentials.region, signingService: 's3', expiration: 900, userAgentValue: 'UA', diff --git a/packages/storage/src/providers/s3/apis/internal/getUrl.ts b/packages/storage/src/providers/s3/apis/internal/getUrl.ts index a2de5d3f770..4f866ef80b3 100644 --- a/packages/storage/src/providers/s3/apis/internal/getUrl.ts +++ b/packages/storage/src/providers/s3/apis/internal/getUrl.ts @@ -46,7 +46,11 @@ export const getUrl = async ( let urlExpirationInSec = getUrlOptions?.expiresIn ?? DEFAULT_PRESIGN_EXPIRATION; - const awsCredExpiration = s3Config.credentials?.expiration; + const resolvedCredential = + typeof s3Config.credentials === 'function' + ? await s3Config.credentials() + : s3Config.credentials; + const awsCredExpiration = resolvedCredential.expiration; if (awsCredExpiration) { const awsCredExpirationInSec = Math.floor( (awsCredExpiration.getTime() - Date.now()) / 1000, @@ -64,6 +68,7 @@ export const getUrl = async ( url: await getPresignedGetObjectUrl( { ...s3Config, + credentials: resolvedCredential, expiration: urlExpirationInSec, }, { diff --git a/packages/storage/src/providers/s3/types/options.ts b/packages/storage/src/providers/s3/types/options.ts index 4d0af341f52..b2b7dfd0ddc 100644 --- a/packages/storage/src/providers/s3/types/options.ts +++ b/packages/storage/src/providers/s3/types/options.ts @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 import { StorageAccessLevel } from '@aws-amplify/core'; -import { AWSCredentials } from '@aws-amplify/core/internals/utils'; +import { SigningOptions } from '@aws-amplify/core/internals/aws-client-utils'; import { TransferProgressEvent } from '../../../types'; import { @@ -176,9 +176,8 @@ export type CopyDestinationOptionsWithKey = WriteOptions & { * * @internal */ -export interface ResolvedS3Config { - region: string; - credentials: AWSCredentials; +export interface ResolvedS3Config + extends Pick { customEndpoint?: string; forcePathStyle?: boolean; useAccelerateEndpoint?: boolean; diff --git a/packages/storage/src/providers/s3/utils/resolveS3ConfigAndInput.ts b/packages/storage/src/providers/s3/utils/resolveS3ConfigAndInput.ts index 701c046d52f..a9c5c06b3c5 100644 --- a/packages/storage/src/providers/s3/utils/resolveS3ConfigAndInput.ts +++ b/packages/storage/src/providers/s3/utils/resolveS3ConfigAndInput.ts @@ -39,16 +39,30 @@ export const resolveS3ConfigAndInput = async ( amplify: AmplifyClassV6, apiOptions?: S3ApiOptions, ): Promise => { - // identityId is always cached in memory if forceRefresh is not set. So we can safely make calls here. - const { credentials, identityId } = await amplify.Auth.fetchAuthSession({ - forceRefresh: false, - }); - assertValidationError( - !!credentials, - StorageValidationErrorCode.NoCredentials, - ); + /** + * identityId is always cached in memory if forceRefresh is not set. So we can + * safely make calls here. The identityId should be stable even for + * unauthenticated users, regardless of credentials. + */ + const { identityId } = await amplify.Auth.fetchAuthSession(); assertValidationError(!!identityId, StorageValidationErrorCode.NoIdentityId); + /** + * A credentials provider function instead of a static credentials object is + * used because the long-running tasks like multipart upload may span over the + * credentials expiry. Auth.fetchAuthSession() automatically refreshes the + * credentials if they are expired. + */ + const credentialsProvider = async () => { + const { credentials } = await amplify.Auth.fetchAuthSession(); + assertValidationError( + !!credentials, + StorageValidationErrorCode.NoCredentials, + ); + + return credentials; + }; + const { bucket, region, dangerouslyConnectToHttpEndpointForTesting } = amplify.getConfig()?.Storage?.S3 ?? {}; assertValidationError(!!bucket, StorageValidationErrorCode.NoBucket); @@ -72,7 +86,7 @@ export const resolveS3ConfigAndInput = async ( return { s3Config: { - credentials, + credentials: credentialsProvider, region, useAccelerateEndpoint: apiOptions?.useAccelerateEndpoint, ...(dangerouslyConnectToHttpEndpointForTesting