diff --git a/tap_square/__init__.py b/tap_square/__init__.py
index 4a8311c4..e7213730 100644
--- a/tap_square/__init__.py
+++ b/tap_square/__init__.py
@@ -20,7 +20,7 @@ def main():
     if args.discover:
         write_catalog(catalog)
     else:
-        sync(args.config, args.state, catalog)
+        sync(args.config, args.config_path, args.state, catalog)
 
 if __name__ == '__main__':
     main()
diff --git a/tap_square/client.py b/tap_square/client.py
index 5fd54991..2120309c 100644
--- a/tap_square/client.py
+++ b/tap_square/client.py
@@ -1,15 +1,15 @@
 from datetime import timedelta
 import urllib.parse
-import os
+import json
+import requests
 
 from square.client import Client
 from singer import utils
 import singer
-import requests
 import backoff
 
 LOGGER = singer.get_logger()
-
+REFRESH_TOKEN_BEFORE = 22  # Refresh the access token if it expires within this many days
 
 def get_batch_token_from_headers(headers):
     link = headers.get('link')
@@ -39,43 +39,91 @@ def log_backoff(details):
     LOGGER.warning('Error receiving data from square. Sleeping %.1f seconds before trying again', details['wait'])
 
 
+def write_config(config, config_path, data):
+    '''
+    Merges the `data` object into the config and writes the result to `config_path` as JSON
+    '''
+    config.update(data)
+    with open(config_path, "w") as tap_config:
+        json.dump(config, tap_config, indent=2)
+    return config
+
+
+def require_new_access_token(access_token, client):
+    '''
+    Checks if the access token needs to be refreshed
+    '''
+    # If there is no access token, we need to generate a new one
+    if not access_token:
+        return True
+
+    authorization = f"Bearer {access_token}"
+
+    with singer.http_request_timer('Check access token expiry'):
+        response = client.o_auth.retrieve_token_status(authorization)
+
+    if response.is_error():
+        error_message = response.errors if response.errors else response.body
+        LOGGER.error(error_message)
+        return True
+
+    # Parse the token expiry date
+    token_expiry_date = singer.utils.strptime_with_tz(response.body['expires_at'])
+    now = utils.now()
+    return (token_expiry_date - now).days <= REFRESH_TOKEN_BEFORE
+
+
 class RetryableError(Exception):
     pass
 
 
 class SquareClient():
-    def __init__(self, config):
+    def __init__(self, config, config_path):
         self._refresh_token = config['refresh_token']
         self._client_id = config['client_id']
         self._client_secret = config['client_secret']
 
         self._environment = 'sandbox' if config.get('sandbox') == 'true' else 'production'
 
-        self._access_token = self._get_access_token()
+        self._access_token = self._get_access_token(config, config_path)
         self._client = Client(access_token=self._access_token, environment=self._environment)
 
-    def _get_access_token(self):
-        if "TAP_SQUARE_ACCESS_TOKEN" in os.environ.keys():
-            LOGGER.info("Using access token from environment, not creating the new one")
-            return os.environ["TAP_SQUARE_ACCESS_TOKEN"]
+    def _get_access_token(self, config, config_path):
+        '''
+        Retrieves the access token from the config file. If the access token is expired, it will refresh it.
+        Otherwise, it will return the cached access token.
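+        When a refresh happens, the new access token and rotated refresh token are
+        written back to the config file via write_config so later runs can reuse them.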
+        '''
+        access_token = config.get("access_token")
+        client = Client(environment=self._environment)
 
-        body = {
-            'client_id': self._client_id,
-            'client_secret': self._client_secret,
-            'grant_type': 'refresh_token',
-            'refresh_token': self._refresh_token
-        }
+        # Check if the access token needs to be refreshed
+        if require_new_access_token(access_token, client):
+            LOGGER.info('Refreshing access token...')
+            body = {
+                'client_id': self._client_id,
+                'client_secret': self._client_secret,
+                'grant_type': 'refresh_token',
+                'refresh_token': self._refresh_token
+            }
 
-        client = Client(environment=self._environment)
+            with singer.http_request_timer('GET access token'):
+                result = client.o_auth.obtain_token(body)
 
-        with singer.http_request_timer('GET access token'):
-            result = client.o_auth.obtain_token(body)
+            if result.is_error():
+                error_message = result.errors if result.errors else result.body
+                raise RuntimeError(error_message)
 
-        if result.is_error():
-            error_message = result.errors if result.errors else result.body
-            raise RuntimeError(error_message)
+            access_token = result.body['access_token']
+            write_config(
+                config,
+                config_path,
+                {
+                    'access_token': access_token,
+                    'refresh_token': result.body['refresh_token'],
+                },
+            )
 
-        return result.body['access_token']
+        return access_token
 
     @staticmethod
     @backoff.on_exception(
@@ -158,6 +206,10 @@ def get_customers(self, start_time, end_time):
                     'end_at': end_time # Exclusive on end_at
                 }
             },
+            'sort': {
+                'field': 'CREATED_AT',
+                'order': 'ASC'
+            }
         }
     }
diff --git a/tap_square/sync.py b/tap_square/sync.py
index b69a1a3c..b2480e7f 100644
--- a/tap_square/sync.py
+++ b/tap_square/sync.py
@@ -6,8 +6,8 @@
 LOGGER = singer.get_logger()
 
-def sync(config, state, catalog): # pylint: disable=too-many-statements
-    client = SquareClient(config)
+def sync(config, config_path, state, catalog): # pylint: disable=too-many-statements
+    client = SquareClient(config, config_path)
 
     with Transformer() as transformer:
         for stream in catalog.get_selected_streams(state):
diff --git a/tests/base.py b/tests/base.py
index 3ff765e6..b377543e 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -42,17 +42,19 @@ class TestSquareBase(ABC, TestCase):
     STATIC_START_DATE = "2020-07-13T00:00:00Z"
     START_DATE = ""
     PRODUCTION_ONLY_STREAMS = {'bank_accounts', 'payouts'}
+    TEST_NAME_SANDBOX = 'tap_tester_square_sandbox_tests'
+    TEST_NAME_PROD = 'tap_tester_square_prod_tests'
+    test_name = TEST_NAME_SANDBOX
 
     DEFAULT_BATCH_LIMIT = 1000
 
     API_LIMIT = {
         'items': DEFAULT_BATCH_LIMIT,
-        'inventories': 100,
+        'inventories': DEFAULT_BATCH_LIMIT,
         'categories': DEFAULT_BATCH_LIMIT,
         'discounts': DEFAULT_BATCH_LIMIT,
         'taxes': DEFAULT_BATCH_LIMIT,
         'cash_drawer_shifts': DEFAULT_BATCH_LIMIT,
         'locations': None, # Api does not accept a cursor and documents no limit, see https://developer.squareup.com/reference/square/locations/list-locations
-        'roles': 100,
         'refunds': 100,
         'payments': 100,
         'payouts': 100,
@@ -83,15 +85,37 @@ def get_type():
     def tap_name():
         return "tap-square"
 
+    @staticmethod
+    def name():
+        return TestSquareBaseParent.TestSquareBase.test_name
+
     def set_environment(self, env):
         """
         Change the Square App Environmnet.
        Requires re-instatiating TestClient and setting env var.
         """
         os.environ['TAP_SQUARE_ENVIRONMENT'] = env
+        self.set_access_token_in_env()
         self.client = TestClient(env=env)
         self.SQUARE_ENVIRONMENT = env
 
+    def set_access_token_in_env(self):
+        '''
+        Fetch the access token from the existing connection and set it in the env.
+        This is used to avoid rate limiting issues when running tests.
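+        If no existing connection is found, any previously cached
+        TAP_SQUARE_ACCESS_TOKEN is removed from the environment so a fresh token
+        is generated on the next sync.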
+        '''
+        existing_connections = connections.fetch_existing_connections(self)
+        if not existing_connections:
+            os.environ.pop('TAP_SQUARE_ACCESS_TOKEN', None)
+            return
+
+        conn_with_creds = connections.fetch_existing_connection_with_creds(existing_connections[0]['id'])
+        access_token = conn_with_creds['credentials'].get('access_token')
+        if not access_token:
+            LOGGER.warning('No access token found in connection credentials')
+        else:
+            os.environ['TAP_SQUARE_ACCESS_TOKEN'] = access_token
+
     @staticmethod
     def get_environment():
         return os.environ['TAP_SQUARE_ENVIRONMENT']
@@ -116,12 +140,23 @@ def get_credentials():
             'refresh_token': os.getenv('TAP_SQUARE_REFRESH_TOKEN') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_REFRESH_TOKEN'),
             'client_id': os.getenv('TAP_SQUARE_APPLICATION_ID') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_APPLICATION_ID'),
             'client_secret': os.getenv('TAP_SQUARE_APPLICATION_SECRET') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_APPLICATION_SECRET'),
+            'access_token': os.environ['TAP_SQUARE_ACCESS_TOKEN']
         }
     else:
         raise Exception("Square Environment: {} is not supported.".format(environment))
 
     return creds
 
+    @staticmethod
+    def preserve_access_token(existing_conns, payload):
+        '''This method is used to get the access token from an existing connection'''
+        if not existing_conns:
+            return payload
+
+        conn_with_creds = connections.fetch_existing_connection_with_creds(existing_conns[0]['id'])
+        payload['properties']['access_token'] = conn_with_creds['credentials'].get('access_token')
+        return payload
+
     def expected_check_streams(self):
         return set(self.expected_metadata().keys()).difference(set())
 
@@ -490,7 +525,8 @@ def create_test_data(self, testable_streams, start_date, start_date_2=None, min_
             else:
                 raise NotImplementedError("created_records unknown type: {}".format(created_records))
 
-            print("Adjust expectations for stream: {}".format(stream))
+            stream_to_expected_records[stream] = self.client.get_all(stream, start_date)
+            LOGGER.info('Adjust expectations for stream: {}'.format(stream))
             self.modify_expected_records(stream_to_expected_records[stream])
 
         return stream_to_expected_records
@@ -546,7 +582,7 @@ def run_and_verify_check_mode(self, conn_id):
         found_catalog_names = found_catalog_names - {'settlements'}
         diff = self.expected_check_streams().symmetric_difference(found_catalog_names)
         self.assertEqual(len(diff), 0, msg="discovered schemas do not match: {}".format(diff))
-        print("discovered schemas are OK")
+        LOGGER.info("discovered schemas are OK")
 
         return found_catalogs
 
@@ -568,7 +604,7 @@ def perform_and_verify_table_and_field_selection(self, conn_id, found_catalogs,
             catalog_entry = menagerie.get_annotated_schema(conn_id, cat['stream_id'])
 
             # Verify all testable streams are selected
             selected = catalog_entry.get('annotated-schema').get('selected')
-            print("Validating selection on {}: {}".format(cat['stream_name'], selected))
+            LOGGER.info('Validating selection on {}: {}'.format(cat['stream_name'], selected))
             if cat['stream_name'] not in streams_to_select:
                 self.assertFalse(selected, msg="Stream selected, but not testable.")
                 continue # Skip remaining assertions if we aren't selecting this stream
@@ -578,7 +614,7 @@ def perform_and_verify_table_and_field_selection(self, conn_id, found_catalogs,
                 # Verify all fields within each selected stream are selected
                 for field, field_props in catalog_entry.get('annotated-schema').get('properties').items():
                     field_selected = field_props.get('selected')
-                    print("\tValidating selection on {}.{}: {}".format(cat['stream_name'], field, field_selected))
+                    LOGGER.info('\tValidating selection on {}.{}: {}'.format(cat['stream_name'], field, field_selected))
                     self.assertTrue(field_selected, msg="Field not selected.")
             else:
                 # Verify only automatic fields are selected
@@ -662,7 +698,7 @@ def assertRecordsEqual(self, stream, expected_record, sync_record):
         Certain Square streams cannot be compared directly with assertDictEqual().
         So we handle that logic here.
         """
-        if stream not in ['refunds', 'orders', 'customers']:
+        if stream not in ['refunds', 'orders', 'customers', 'locations']:
             expected_record.pop('created_at', None)
         if stream == 'payments':
             self.assertDictEqualWithOffKeys(expected_record, sync_record, {'updated_at'})
diff --git a/tests/test_all_fields.py b/tests/test_all_fields.py
index 2cdecf04..614c165f 100644
--- a/tests/test_all_fields.py
+++ b/tests/test_all_fields.py
@@ -1,11 +1,11 @@
 import os
 from collections import namedtuple
-
+import singer
 import tap_tester.runner as runner
 import tap_tester.connections as connections
 
 from base import TestSquareBaseParent, DataType
-
+LOGGER = singer.get_logger()
 PaymentRecordDetails = namedtuple('PaymentRecordDetails', 'source_key, autocomplete, record')
 
 
@@ -14,10 +14,6 @@ class TestSquareAllFields(TestSquareBaseParent.TestSquareBase):
     """Test that with all fields selected for a stream we replicate data as expected"""
     TESTABLE_STREAMS = set()
 
-    @staticmethod
-    def name():
-        return "tap_tester_square_all_fields"
-
     def testable_streams_dynamic(self):
         return self.dynamic_data_streams().difference(self.untestable_streams())
 
@@ -43,7 +39,7 @@ def ensure_dict_object(self, resp_object):
 
     def create_specific_payments(self):
         """Create a record using each source type, and a record that will autocomplete."""
-        print("Creating a record using each source type, and the autocomplete flag.")
+        LOGGER.info('Creating a record using each source type, and the autocomplete flag.')
         payment_records = []
         descriptions = {
             ("card", False),
@@ -61,7 +57,7 @@ def update_specific_payments(self, payments_to_update):
         """Perform specifc updates on specific payment records."""
         updated_records = []
-        print("Updating payment records by completing, canceling and refunding them.")
+        LOGGER.info('Updating payment records by completing, canceling and refunding them.')
         # Update a completed payment by making a refund (payments must have a status of 'COMPLETED' to process a refund)
         source_key, autocomplete = ("card", True)
         description = "refund"
@@ -90,26 +86,28 @@ def update_specific_payments(self, payments_to_update):
 
     @classmethod
     def tearDownClass(cls):
-        print("\n\nTEST TEARDOWN\n\n")
+        LOGGER.info('\n\nTEST TEARDOWN\n\n')
 
     def test_run(self):
         """Instantiate start date according to the desired data set and run the test"""
-        print("\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
         self.START_DATE = self.get_properties().get('start_date')
         self.TESTABLE_STREAMS = self.testable_streams_dynamic().difference(self.production_streams())
         self.all_fields_test(self.SANDBOX, DataType.DYNAMIC)
 
-        print("\n\nTESTING WITH STATIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\nTESTING WITH STATIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
         self.START_DATE = self.STATIC_START_DATE
         self.TESTABLE_STREAMS = self.testable_streams_static().difference(self.production_streams())
         self.all_fields_test(self.SANDBOX, DataType.STATIC)
 
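+        # name() returns test_name, so switching it here points tap-tester at a
+        # separate connection (with its own preserved access token) for production runs.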
+        TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_PROD
         self.set_environment(self.PRODUCTION)
-        print("\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
         self.START_DATE = self.get_properties().get('start_date')
         self.TESTABLE_STREAMS = self.testable_streams_dynamic().difference(self.sandbox_streams())
         self.all_fields_test(self.PRODUCTION, DataType.DYNAMIC)
+        TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_SANDBOX
 
     def all_fields_test(self, environment, data_type):
         """
@@ -117,8 +115,8 @@ def all_fields_test(self, environment, data_type):
         and only the automatic fields are replicated.
         """
 
-        print("\n\nRUNNING {}".format(self.name()))
-        print("WITH STREAMS: {}\n\n".format(self.TESTABLE_STREAMS))
+        LOGGER.info('\n\nRUNNING {}_all_fields'.format(self.name()))
+        LOGGER.info('WITH STREAMS: {}\n\n'.format(self.TESTABLE_STREAMS))
 
         # execute specific creates and updates for the payments stream in addition to the standard create
         if 'payments' in self.TESTABLE_STREAMS:
@@ -129,7 +127,7 @@ def all_fields_test(self, environment, data_type):
         expected_records = self.create_test_data(self.TESTABLE_STREAMS, self.START_DATE, force_create_records=True)
 
         # instantiate connection
-        conn_id = connections.ensure_connection(self)
+        conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token)
 
         # run check mode
         found_catalogs = self.run_and_verify_check_mode(conn_id)
@@ -150,7 +148,7 @@ def all_fields_test(self, environment, data_type):
         for stream, count in first_record_count_by_stream.items():
             assert stream in self.expected_streams()
             self.assertGreater(count, 0, msg="failed to replicate any data for: {}".format(stream))
-        print("total replicated row count: {}".format(replicated_row_count))
+        LOGGER.info('total replicated row count: {}'.format(replicated_row_count))
 
         MISSING_FROM_EXPECTATIONS = { # this is acceptable, we can't generate test data for EVERYTHING
             'modifier_lists': {'absent_at_location_ids'},
@@ -165,7 +163,7 @@ def all_fields_test(self, environment, data_type):
             },
             'discounts': {'absent_at_location_ids'},
             'taxes': {'absent_at_location_ids'},
-            'customers': {'birthday', 'tax_ids', 'group_ids', 'reference_id', 'version', 'segment_ids'},
+            'customers': {'birthday', 'tax_ids', 'group_ids', 'reference_id', 'version', 'segment_ids', 'phone_number'},
             'payments': {
                 'customer_id', 'reference_id', 'cash_details', 'tip_money',
                 'external_details', 'device_details',
@@ -191,7 +189,7 @@ def all_fields_test(self, environment, data_type):
                 'category_data', 'amount_money', 'processing_fee', 'refund_ids',
                 'delayed_until', 'delay_duration', 'delay_action', 'note', 'status',
                 'order_id', 'type', 'source_type', 'payment_id', 'tax_data', 'receipt_number', 'receipt_url',
-                'discount_data', 'refunded_money', 'present_at_all_locations', 'card_details','returned_quantities'
+                'discount_data', 'refunded_money', 'present_at_all_locations', 'card_details', 'is_deleted', 'reason'},
             'discounts': {'created_at'},
             'items': {'created_at'},
diff --git a/tests/test_automatic_fields.py b/tests/test_automatic_fields.py
index 0ce83d57..80af8460 100644
--- a/tests/test_automatic_fields.py
+++ b/tests/test_automatic_fields.py
@@ -1,18 +1,14 @@
 import os
-
+import singer
 import tap_tester.connections as connections
 import tap_tester.runner as runner
 
 from base import TestSquareBaseParent, DataType
-
+LOGGER = singer.get_logger()
 
 class TestAutomaticFields(TestSquareBaseParent.TestSquareBase):
     """Test that with no fields selected for a stream automatic fields are still replicated"""
 
-    @staticmethod
-    def name():
-        return "tap_tester_square_automatic_fields"
-
     def testable_streams_dynamic(self):
         return self.dynamic_data_streams().difference(self.untestable_streams()).difference({
             'inventories', # No PK or rep key so no automatic fields to check
@@ -23,22 +19,24 @@ def testable_streams_static(self):
 
     def test_run(self):
         """Instantiate start date according to the desired data set and run the test"""
-        print("\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
         self.START_DATE = self.get_properties().get('start_date')
         self.TESTABLE_STREAMS = self.testable_streams_dynamic().difference(self.production_streams()) - {'customers', 'team_members'}
         self.auto_fields_test(self.SANDBOX, DataType.DYNAMIC)
 
-        print("\n\nTESTING WITH STATIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\nTESTING WITH STATIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
         self.START_DATE = self.STATIC_START_DATE
         self.TESTABLE_STREAMS = self.testable_streams_static().difference(self.production_streams()) - {'customers', 'team_members'}
         self.auto_fields_test(self.SANDBOX, DataType.STATIC)
 
+        TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_PROD
         self.set_environment(self.PRODUCTION)
-        print("\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
         self.START_DATE = self.get_properties().get('start_date')
         self.TESTABLE_STREAMS = self.testable_streams_dynamic().difference(self.sandbox_streams())
         self.auto_fields_test(self.PRODUCTION, DataType.DYNAMIC)
+        TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_SANDBOX
 
     def auto_fields_test(self, environment, data_type):
         """
@@ -46,8 +44,8 @@ def auto_fields_test(self, environment, data_type):
         and only the automatic fields are replicated.
""" - print("\n\nRUNNING {}".format(self.name())) - print("WITH STREAMS: {}\n\n".format(self.TESTABLE_STREAMS)) + LOGGER.info('\n\nRUNNING {}_automatic_fields'.format(self.name())) + LOGGER.info('WITH STREAMS: {}\n\n'.format(self.TESTABLE_STREAMS)) # ensure data exists and set expectatinos with automatic fields only expected_records_all_fields = self.create_test_data(self.TESTABLE_STREAMS, self.START_DATE) @@ -61,7 +59,7 @@ def auto_fields_test(self, environment, data_type): ) # instantiate connection - conn_id = connections.ensure_connection(self) + conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token) # run check mode found_catalogs = self.run_and_verify_check_mode(conn_id) @@ -82,7 +80,7 @@ def auto_fields_test(self, environment, data_type): for stream, count in first_record_count_by_stream.items(): assert stream in self.expected_streams() self.assertGreater(count, 0, msg="failed to replicate any data for: {}".format(stream)) - print("total replicated row count: {}".format(replicated_row_count)) + LOGGER.info('total replicated row count: {}'.format(replicated_row_count)) for stream in self.TESTABLE_STREAMS: with self.subTest(stream=stream): diff --git a/tests/test_bookmarks.py b/tests/test_bookmarks.py index ad9ace1b..775fd3a9 100644 --- a/tests/test_bookmarks.py +++ b/tests/test_bookmarks.py @@ -13,9 +13,6 @@ class TestSquareIncrementalReplication(TestSquareBaseParent.TestSquareBase): - @staticmethod - def name(): - return "tap_tester_square_incremental_replication" def testable_streams_dynamic(self): return self.dynamic_data_streams().difference(self.untestable_streams()) @@ -34,7 +31,7 @@ def cannot_update_streams(): @classmethod def tearDownClass(cls): - print("\n\nTEST TEARDOWN\n\n") + LOGGER.info('\n\nTEST TEARDOWN\n\n') def run_sync(self, conn_id): """ @@ -61,15 +58,17 @@ def test_run(self): self.START_DATE = self.get_properties().get('start_date') - print("\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT'))) + LOGGER.info('\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT'))) self.bookmarks_test(self.testable_streams_dynamic().intersection(self.sandbox_streams())) + TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_PROD self.set_environment(self.PRODUCTION) production_testable_streams = self.testable_streams_dynamic().intersection(self.production_streams()) if production_testable_streams: - print("\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT'))) + LOGGER.info('\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT'))) self.bookmarks_test(production_testable_streams) + TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_SANDBOX def bookmarks_test(self, testable_streams): """ @@ -84,13 +83,13 @@ def bookmarks_test(self, testable_streams): For EACH stream that is incrementally replicated there are multiple rows of data with different values for the replication key """ - print("\n\nRUNNING {}\n\n".format(self.name())) + LOGGER.info('\n\nRUNNING {}_bookmark\n\n'.format(self.name())) # Ensure tested streams have existing records expected_records_first_sync = self.create_test_data(testable_streams, self.START_DATE, force_create_records=True) # Instantiate connection with default start - conn_id = connections.ensure_connection(self) + conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token) # run in check mode 
         check_job_name = runner.run_check_mode(self, conn_id)
@@ -249,7 +248,6 @@ def bookmarks_test(self, testable_streams):
 
         assert updated_record, "Failed to update a {} record".format('payments')
         assert len(updated_record) == 1, "Updated too many {} records".format('payments')
 
-        expected_records_second_sync['payments'] += updated_record[0]
         updated_records['payments'] += updated_record[0]
 
         # adjust expectations for full table streams to include the expected records from sync 1
@@ -268,9 +266,9 @@ def bookmarks_test(self, testable_streams):
 
         # Adjust expectations for datetime format
         for record_desc, records in [("created", created_records), ("updated", updated_records),
                                      ("2nd sync expected records", expected_records_second_sync)]:
-            print("Adjusting epxectations for {} records".format(record_desc))
+            LOGGER.info('Adjusting expectations for {} records'.format(record_desc))
             for stream, expected_records in records.items():
-                print("\tadjusting for stream: {}".format(stream))
+                LOGGER.info('\tadjusting for stream: {}'.format(stream))
                 self.modify_expected_records(expected_records)
 
         # ensure validity of expected_records_second_sync
@@ -319,7 +317,7 @@ def bookmarks_test(self, testable_streams):
         PARENT_FIELD_MISSING_SUBFIELDS = {'payments': {'card_details'}}
 
         # BUG_2 | https://stitchdata.atlassian.net/browse/SRCE-5143
-        MISSING_FROM_SCHEMA = {'payments': {'capabilities', 'version_token', 'approved_money'}}
+        MISSING_FROM_SCHEMA = {'payments': {'capabilities', 'version_token', 'approved_money', 'refund_ids', 'refunded_money', 'processing_fee'}}
 
         # Loop first_sync_records and compare against second_sync_records
diff --git a/tests/test_bookmarks_cursor.py b/tests/test_bookmarks_cursor.py
index d19a8e8b..8ef3bfdf 100644
--- a/tests/test_bookmarks_cursor.py
+++ b/tests/test_bookmarks_cursor.py
@@ -11,10 +11,6 @@
 
 class TestSquareIncrementalReplicationCursor(TestSquareBaseParent.TestSquareBase):
 
-    @staticmethod
-    def name():
-        return "tap_tester_square_incremental_replication_cursor"
-
     def testable_streams_dynamic(self):
         all_testable_streams = self.dynamic_data_streams().intersection(
             self.expected_full_table_streams()).difference(
@@ -34,7 +30,7 @@ def testable_streams_static():
 
     @classmethod
     def tearDownClass(cls):
-        print("\n\nTEST TEARDOWN\n\n")
+        LOGGER.info('\n\nTEST TEARDOWN\n\n')
 
     def test_run(self):
         """Instantiate start date according to the desired data set and run the test"""
@@ -43,11 +39,14 @@ def test_run(self):
 
         self.bookmarks_test(self.testable_streams_dynamic().intersection(self.sandbox_streams()))
 
+        TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_PROD
         self.set_environment(self.PRODUCTION)
 
         production_testable_streams = self.testable_streams_dynamic().intersection(self.production_streams())
         if production_testable_streams:
             self.bookmarks_test(production_testable_streams)
 
+        TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_SANDBOX
+
     def bookmarks_test(self, testable_streams):
         """
         Verify for each stream that you can do a sync which records bookmark cursor
@@ -56,7 +55,7 @@ def bookmarks_test(self, testable_streams):
         PREREQUISITE
         For EACH stream that is interruptable with a bookmark cursor and not another one is replicated
         there are more than 1 page of data
         """
-        print("\n\nRUNNING {}\n\n".format(self.name()))
+        LOGGER.info('\n\nRUNNING {}_bookmark_cursor\n\n'.format(self.name()))
 
         # Ensure tested streams have existing records
         stream_to_expected_records_before_removing_first_page = self.create_test_data(testable_streams, self.START_DATE,
                                                                                          min_required_num_records_per_stream=self.API_LIMIT)
 
         # verify the expected test data exceeds API LIMIT for all testable streams
         for stream in testable_streams:
             record_count = len(stream_to_expected_records_before_removing_first_page[stream])
-            print("Verifying data is sufficient for stream {}. ".format(stream) +
+            LOGGER.info('Verifying data is sufficient for stream {}. '.format(stream) +
                   "\tRecord Count: {}\tAPI Limit: {} ".format(record_count, self.API_LIMIT.get(stream)))
 
             self.assertGreater(record_count, self.API_LIMIT.get(stream),
                                msg="Pagination not ensured.\n" +
@@ -84,7 +83,7 @@ def bookmarks_test(self, testable_streams):
         ]
 
         # Create connection but do not use default start date
-        conn_id = connections.ensure_connection(self, original_properties=False)
+        conn_id = connections.ensure_connection(self, original_properties=False, payload_hook=self.preserve_access_token)
 
         # run check mode
         found_catalogs = self.run_and_verify_check_mode(conn_id)
diff --git a/tests/test_bookmarks_static.py b/tests/test_bookmarks_static.py
index 3603994b..1e4d5799 100644
--- a/tests/test_bookmarks_static.py
+++ b/tests/test_bookmarks_static.py
@@ -1,20 +1,16 @@
 import os
 import unittest
-
+import singer
 import tap_tester.connections as connections
 import tap_tester.menagerie as menagerie
 import tap_tester.runner as runner
 
 from base import TestSquareBaseParent
+LOGGER = singer.get_logger()
 
-
-class TestSquareIncrementalReplication(TestSquareBaseParent.TestSquareBase):
+class TestSquareIncrementalReplicationStatic(TestSquareBaseParent.TestSquareBase):
     STATIC_START_DATE = "2020-07-13T00:00:00Z"
 
-    @staticmethod
-    def name():
-        return "tap_tester_square_incremental_replication"
-
     def testable_streams_static(self):
         return self.static_data_streams().difference(self.untestable_streams())
 
@@ -24,7 +20,7 @@ def testable_streams_dynamic():
 
     @classmethod
     def tearDownClass(cls):
-        print("\n\nTEST TEARDOWN\n\n")
+        LOGGER.info('\n\nTEST TEARDOWN\n\n')
 
     def run_sync(self, conn_id):
         """
@@ -57,9 +53,9 @@ def test_run(self):
         For EACH stream that is incrementally replicated there are multiple rows of data with
             different values for the replication key
         """
-        print("\n\nTESTING IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\nTESTING IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
 
-        print("\n\nRUNNING {}\n\n".format(self.name()))
+        LOGGER.info('\n\nRUNNING {}_bookmark_static\n\n'.format(self.name()))
 
         # Instatiate static start date
         self.START_DATE = self.STATIC_START_DATE
@@ -68,7 +64,7 @@ def test_run(self):
         expected_records_first_sync = self.create_test_data(self.testable_streams_static(), self.START_DATE)
 
         # Instantiate connection with default start
-        conn_id = connections.ensure_connection(self)
+        conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token)
 
         # run in check mode
         check_job_name = runner.run_check_mode(self, conn_id)
diff --git a/tests/test_client.py b/tests/test_client.py
index 9aed368e..d5e37e30 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -8,9 +8,10 @@
 import backoff
 import singer
 from square.client import Client
-
 LOGGER = singer.get_logger()
+REFRESH_TOKEN_BEFORE = 22  # Refresh the access token if it expires within this many days
 
 typesToKeyMap = {
     'ITEM': 'item_data',
@@ -38,6 +39,27 @@ def get_batch_token_from_headers(headers):
         return None
 
 
+def require_new_access_token(access_token, client):
+    '''
+    Checks if the access token needs to be refreshed
+    '''
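+    # Square OAuth access tokens are short-lived (roughly 30 days per Square's
+    # OAuth docs), so refresh once fewer than REFRESH_TOKEN_BEFORE days remain
+    # before expiry.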
+    # If there is no access token, we need to generate a new one
+    if not access_token:
+        return True
+
+    authorization = f'Bearer {access_token}'
+
+    with singer.http_request_timer('Check access token expiry'):
+        response = client.o_auth.retrieve_token_status(authorization)
+
+    if response.is_error():
+        error_message = response.errors if response.errors else response.body
+        LOGGER.error(error_message)
+        return True
+
+    token_expiry_date = singer.utils.strptime_with_tz(response.body['expires_at'])
+    now = singer.utils.now()
+    return (token_expiry_date - now).days <= REFRESH_TOKEN_BEFORE
+
+
 def log_backoff(details):
     '''
     Logs a backoff retry message
@@ -104,30 +126,34 @@ def __init__(self, env):
         self._client = Client(access_token=self._access_token, environment=self._environment)
 
     def _get_access_token(self):
-        if "TAP_SQUARE_ACCESS_TOKEN" in os.environ.keys():
-            LOGGER.info("Using access token from environment, not creating the new")
-            return os.environ["TAP_SQUARE_ACCESS_TOKEN"]
+        '''
+        Retrieves the access token from the env. If the access token is expired, it will refresh it.
+        Otherwise, it will return the cached access token.
+        '''
+        access_token = os.getenv('TAP_SQUARE_ACCESS_TOKEN')
+        client = Client(environment=self._environment)
 
-        body = {
-            'client_id': self._client_id,
-            'client_secret': self._client_secret,
-            'grant_type': 'refresh_token',
-            'refresh_token': self._refresh_token
-        }
+        # Check if the access token needs to be refreshed
+        if require_new_access_token(access_token, client):
+            body = {
+                'client_id': self._client_id,
+                'client_secret': self._client_secret,
+                'grant_type': 'refresh_token',
+                'refresh_token': self._refresh_token
+            }
 
-        client = Client(environment=self._environment)
+            with singer.http_request_timer('GET access token'):
+                result = client.o_auth.obtain_token(body)
 
-        with singer.http_request_timer('GET access token'):
-            result = client.o_auth.obtain_token(body)
+            if result.is_error():
+                error_message = result.errors if result.errors else result.body
+                LOGGER.info('error_message :-----------: %s', error_message)
+                raise RuntimeError(error_message)
 
-        if result.is_error():
-            error_message = result.errors if result.errors else result.body
-            LOGGER.info("error_message :-----------: %s",error_message)
-            raise RuntimeError(error_message)
+            LOGGER.info('Generating a new access token and setting it in the environment....')
+            os.environ["TAP_SQUARE_ACCESS_TOKEN"] = access_token = result.body['access_token']
 
-        LOGGER.info("Setting the access token in environment....")
-        os.environ["TAP_SQUARE_ACCESS_TOKEN"] = result.body['access_token']
-        return result.body['access_token']
+        return access_token
 
     ##########################################################################
     ### V1 INFO
@@ -339,11 +365,19 @@ def get_customers(self, start_time, bookmarked_cursor):
         if bookmarked_cursor:
             body = {
                 "cursor": bookmarked_cursor,
+                'sort': {
+                    'field': 'CREATED_AT',
+                    'order': 'ASC'
+                },
                 'limit': 100,
             }
         else:
             body = {
                 "query": {
+                    'sort': {
+                        'field': 'CREATED_AT',
+                        'order': 'ASC'
+                    },
                     "filter": {
                         "updated_at": {
                             "start_at": start_time
@@ -679,6 +713,13 @@ def get_catalog_object(self, obj_id):
 
         return response.body.get('object')
 
+    def get_customer(self, customer_id):
+        response = self._client.customers.retrieve_customer(customer_id)
+        if response.is_error():
+            raise RuntimeError(response.errors)
+
+        return response.body.get('customer')
+
     def _create_batch_inventory_adjustment(self, start_date, num_records=1):
         # Create item object(s)
         # This is needed to get an item ID in order to perform an item_variation
@@ -706,7 +747,7 @@ def _create_batch_inventory_adjustment(self, start_date, num_records=1):
             'ignore_unchanged_counts': False,
             'idempotency_key': str(uuid.uuid4())
         }
-        LOGGER.info("About to create %s inventory adjustments", len(change_chunk))
+        LOGGER.info('About to create %s inventory adjustments', len(change_chunk))
         response = self._client.inventory.batch_change_inventory(body)
         if response.is_error():
             LOGGER.error("response had error, body was %s", body)
@@ -724,7 +765,7 @@ def create_specific_inventory_adjustment(self, catalog_obj_id, location_id, from
             'ignore_unchanged_counts': False,
             'idempotency_key': str(uuid.uuid4())
         }
-        LOGGER.info("About to create %s inventory adjustments", len(change))
+        LOGGER.info('About to create %s inventory adjustments', len(change))
         response = self._client.inventory.batch_change_inventory(body)
         if response.is_error():
             LOGGER.error("response had error, body was %s", body)
@@ -764,6 +805,7 @@ def _inventory_adjustment_change(catalog_obj_id, location_id, from_state=None, t
 
     def create_refunds(self, start_date, num_records, payment_response=None):
         refunds = []
+
         for _ in range(num_records):
             (created_refund, _) = self.create_refund(start_date, payment_response)
             refunds += created_refund
@@ -792,6 +834,7 @@ def create_refund(self, start_date, payment_response=None):
         payment_obj = self.get_object_matching_conditions('payments', payment_response.get('id'), start_date=start_date,
                                                           status=payment_response.get('status'))[0]
+
         payment_id = payment_obj.get('id')
         payment_amount = payment_obj.get('amount_money').get('amount')
         upper_limit = 10 if payment_amount > 10 else payment_amount
@@ -850,8 +893,8 @@ def create_payment(self, autocomplete=False, source_key='card'):
         }
         new_payment = self._client.payments.create_payment(body)
         if new_payment.is_error():
-            print("body: {}".format(body))
-            print("response: {}".format(new_payment))
+            LOGGER.info('body: {}'.format(body))
+            LOGGER.info('response: {}'.format(new_payment))
             raise RuntimeError(new_payment.errors)
 
         response = new_payment.body.get('payment')
@@ -887,11 +930,11 @@ def create_customer(self):
         }
         response = self._client.customers.create_customer(body)
         if response.is_error():
-            print("body: {}".format(body))
-            print("response: {}".format(response))
+            LOGGER.info('body: {}'.format(body))
+            LOGGER.info('response: {}'.format(response))
             raise RuntimeError(response.errors)
 
-        return response.body.get('customer')
+        return self.get_customer(response.body.get('customer')['id'])
 
     def create_modifier_list(self, num_records):
         objects = []
@@ -1116,8 +1159,8 @@ def _create_location(self):
         }
         response = self.post_location(body)
         if response.is_error():
-            print("body: {}".format(body))
-            print("response: {}".format(response))
+            LOGGER.info('body: {}'.format(body))
+            LOGGER.info('response: {}'.format(response))
             raise RuntimeError(response.errors)
 
         location_id = response.body.get('location')['id']
@@ -1157,7 +1200,7 @@ def _create_orders(self, num_records, start_date):
         return created_orders
 
     def create_shift(self, start_date, end_date, num_records):
-        employee_id = self.get_or_create_first_found('employees', None)['id']
+        team_member_id = self.get_or_create_first_found('team_members', None)['id']
         all_location_ids = [location['id'] for location in self.get_all('locations', start_date)]
         all_shifts = self.get_all('shifts', start_date)
 
@@ -1175,7 +1218,7 @@ def create_shift(self, start_date, end_date, num_records):
             body = {
                 'idempotency_key': str(uuid.uuid4()),
                 'shift': {
-                    'employee_id': employee_id,
+                    'team_member_id': team_member_id,
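+                    # team_member_id supersedes the deprecated employee_id field
+                    # (Square's Employees API was replaced by the Team API)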
                     'location_id': location_id,
                     'start_at': start_at, # This can probably be derived from the test's start date
                     'end_at': end_at,
@@ -1292,7 +1335,7 @@ def update_payment(self, obj_id: str, action=None):
         if not action:
             action = random.choice(list(self.PAYMENT_ACTION_TO_STATUS.keys()))
 
-        print("PAYMENT UPDATE: status for payment {} change to {} ".format(obj_id, action))
+        LOGGER.info('PAYMENT UPDATE: status for payment {} change to {} '.format(obj_id, action))
         if action == 'cancel':
             response = self._client.payments.cancel_payment(obj_id)
             if response.is_error():
@@ -1387,7 +1430,7 @@ def _update_inventory_adjustment(self, catalog_obj):
         }
         response = self._client.inventory.batch_change_inventory(body)
         if response.is_error():
-            print(response.body.get('errors'))
+            LOGGER.error(response.body.get('errors'))
             raise RuntimeError(response.errors)
 
         all_counts = response.body.get('counts')
diff --git a/tests/test_config.py b/tests/test_config.py
deleted file mode 100644
index ca7e4bb1..00000000
--- a/tests/test_config.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import os
-
-potential_paths = [
-    'tests/',
-    '../tests/'
-    'tap-square/tests/',
-    '../tap-square/tests/',
-]
-
-
-def go_to_tests_directory():
-    for path in potential_paths:
-        if os.path.exists(path):
-            os.chdir(path)
-            return os.getcwd()
-    raise NotImplementedError("This check cannot run from {}".format(os.getcwd()))
-
-##########################################################################
-### TEST
-##########################################################################
-
-
-print("Acquiring path to tests directory.")
-cwd = go_to_tests_directory()
-
-print("Reading in filenames from tests directory.")
-files_in_dir = os.listdir(cwd)
-
-print("Dropping files that are not of the form 'test_.py'.")
-test_files_in_dir = [fn for fn in files_in_dir if fn.startswith('test_') and fn.endswith('.py')]
-
-print("Dropping test_client.py from test files.")
-test_files_in_dir.remove('test_client.py')
-
-print("Files found: {}".format(test_files_in_dir))
-
-print("Reading contents of circle config.")
-with open(cwd + "/../.circleci/config.yml", "r") as config:
-    contents = config.read()
-
-print("Parsing circle config for run blocks.")
-runs = contents.replace(' ', '').replace('\n', '').split('-run:')
-
-print("Verify all test files are executed in circle...")
-tests_not_found = set(test_files_in_dir)
-for filename in test_files_in_dir:
-    print("\tVerifying {} is running in circle.".format(filename))
-    if any([filename in run for run in runs]):
-        tests_not_found.remove(filename)
-assert tests_not_found == set(), "The following tests are not running in circle:\t{}".format(tests_not_found)
-print("\t SUCCESS: All tests are running in circle.")
diff --git a/tests/test_default_start_date.py b/tests/test_default_start_date.py
index 929be3ff..d5d365ba 100644
--- a/tests/test_default_start_date.py
+++ b/tests/test_default_start_date.py
@@ -1,19 +1,16 @@
 from datetime import datetime as dt
 from datetime import timedelta
-
+import singer
 import tap_tester.connections as connections
 
 from base import TestSquareBaseParent, DataType
-
+LOGGER = singer.get_logger()
 
 class TestSquareStartDateDefault(TestSquareBaseParent.TestSquareBase):
     """
     Test we can perform a successful sync for all streams with a start date of 1 year ago and older.
     This makes up for the time partions missed by `test_start_date.py`.
""" - @staticmethod - def name(): - return "tap_tester_start_date_default" def testable_streams_dynamic(self): return self.dynamic_data_streams() @@ -29,15 +26,15 @@ def run_standard_sync(self, environment, data_type, select_all_fields=True): Select all fields or no fields based on the select_all_fields param. Run a sync. """ - conn_id = connections.ensure_connection(self, original_properties=False) + conn_id = connections.ensure_connection(self, original_properties=False, payload_hook=self.preserve_access_token) found_catalogs = self.run_and_verify_check_mode(conn_id) streams_to_select = self.testable_streams(environment, data_type) - print("\n\nRUNNING {}".format(self.name())) - print("WITH STREAMS: {}".format(streams_to_select)) - print("WITH START DATE: {}\n\n".format(self.START_DATE)) + LOGGER.info('\n\nRUNNING {}_default_start_date'.format(self.name())) + LOGGER.info('WITH STREAMS: {}'.format(streams_to_select)) + LOGGER.info('WITH START DATE: {}\n\n'.format(self.START_DATE)) self.perform_and_verify_table_and_field_selection( conn_id, found_catalogs, streams_to_select, select_all_fields=select_all_fields @@ -56,9 +53,11 @@ def test_run(self): self.set_environment(self.SANDBOX) self.default_start_date_test(DataType.DYNAMIC, self.testable_streams_dynamic().intersection(self.sandbox_streams())) self.default_start_date_test(DataType.STATIC, self.testable_streams_static().intersection(self.sandbox_streams())) + TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_PROD self.set_environment(self.PRODUCTION) self.default_start_date_test(DataType.DYNAMIC, self.testable_streams_dynamic().intersection(self.production_streams())) self.default_start_date_test(DataType.STATIC, self.testable_streams_static().intersection(self.production_streams())) + TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_SANDBOX def default_start_date_test(self, data_type, testable_streams): streams_without_data = self.untestable_streams() diff --git a/tests/test_discovery.py b/tests/test_discovery.py index abee04b2..7284c73c 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -2,19 +2,15 @@ Test tap discovery """ import re - +import singer from tap_tester import menagerie, connections, runner from base import TestSquareBaseParent - +LOGGER = singer.get_logger() class DiscoveryTest(TestSquareBaseParent.TestSquareBase): """ Test the tap discovery """ - @staticmethod - def name(): - return "tap_tester_square_discovery_test" - @staticmethod def testable_streams_dynamic(): # Unused for discovery testing @@ -56,7 +52,8 @@ def discovery_test(self): are given the inclusion of automatic (metadata and annotated schema). 
         • verify that all other fields have inclusion of available (metadata and schema)
         """
-        conn_id = connections.ensure_connection(self)
+        LOGGER.info('\n\nRUNNING {}_discovery'.format(self.name()))
+        conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token)
         check_job_name = runner.run_check_mode(self, conn_id)
 
         #verify check exit codes
diff --git a/tests/test_pagination.py b/tests/test_pagination.py
index 578ee33b..75d8c448 100644
--- a/tests/test_pagination.py
+++ b/tests/test_pagination.py
@@ -15,10 +15,6 @@
 class TestSquarePagination(TestSquareBaseParent.TestSquareBase):
     """Test that we are paginating for streams when exceeding the API record limit of a single query"""
 
-    @staticmethod
-    def name():
-        return "tap_tester_square_pagination_test"
-
     def testable_streams_dynamic(self):
         return self.dynamic_data_streams().difference(self.untestable_streams())
 
@@ -29,43 +25,46 @@ def testable_streams_static(self):
 
     @classmethod
     def tearDownClass(cls):
-        cls.set_environment(cls, cls.SANDBOX)
+        cls.set_environment(cls(), cls.SANDBOX)
         cleanup = {'categories': 10000}
         client = TestClient(env=os.environ['TAP_SQUARE_ENVIRONMENT'])
         for stream, limit in cleanup.items():
-            print("Checking if cleanup is required.")
+            LOGGER.info('Checking if cleanup is required.')
             all_records = client.get_all(stream, start_date=cls.STATIC_START_DATE)
             all_ids = [rec.get('id') for rec in all_records if not rec.get('is_deleted')]
 
             if len(all_ids) > limit / 2:
                 chunk = int(len(all_ids) - (limit / 2))
-                print("Cleaning up {} excess records".format(chunk))
+                LOGGER.info('Cleaning up {} excess records'.format(chunk))
                 client.delete_catalog(all_ids[:chunk])
 
     def test_run(self):
         """Instantiate start date according to the desired data set and run the test"""
-        print("\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
         self.START_DATE = self.get_properties().get('start_date')
         self.TESTABLE_STREAMS = self.testable_streams_dynamic().difference(self.production_streams())
         self.pagination_test()
 
-        print("\n\n-- SKIPPING -- TESTING WITH STATIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\n-- SKIPPING -- TESTING WITH STATIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
         self.TESTABLE_STREAMS = self.testable_streams_static().difference(self.production_streams())
         self.assertEqual(set(), self.TESTABLE_STREAMS, msg="Testable streams exist for this category.")
-        print("\tThere are no testable streams.")
+        LOGGER.info('\tThere are no testable streams.')
 
+        TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_PROD
         self.set_environment(self.PRODUCTION)
-        print("\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
        self.START_DATE = self.get_properties().get('start_date')
         self.TESTABLE_STREAMS = self.testable_streams_dynamic().difference(self.sandbox_streams())
         self.pagination_test()
 
-        print("\n\n-- SKIPPING -- TESTING WITH STATIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\n-- SKIPPING -- TESTING WITH STATIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
         self.TESTABLE_STREAMS = self.testable_streams_static().difference(self.sandbox_streams())
         self.assertEqual(set(), self.TESTABLE_STREAMS, msg="Testable streams exist for this category.")
-        print("\tThere are no testable streams.")
+        LOGGER.info('\tThere are no testable streams.')
+        TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_SANDBOX
 
     def pagination_test(self):
         """
@@ -77,22 +76,22 @@ def pagination_test(self):
         fetch of data.  For instance if you have a limit of 250 records ensure
         that 251 (or more) records have been posted for that stream.
         """
-        print("\n\nRUNNING {}".format(self.name()))
-        print("WITH STREAMS: {}\n\n".format(self.TESTABLE_STREAMS))
+        LOGGER.info('\n\nRUNNING {}_pagination'.format(self.name()))
+        LOGGER.info('WITH STREAMS: {}\n\n'.format(self.TESTABLE_STREAMS))
 
         expected_records = self.create_test_data(self.TESTABLE_STREAMS, self.START_DATE, min_required_num_records_per_stream=self.API_LIMIT)
 
         # verify the expected test data exceeds API LIMIT for all testable streams
         for stream in self.TESTABLE_STREAMS:
             record_count = len(expected_records[stream])
-            print("Verifying data is sufficient for stream {}. ".format(stream) +
+            LOGGER.info('Verifying data is sufficient for stream {}. '.format(stream) +
                   "\tRecord Count: {}\tAPI Limit: {} ".format(record_count, self.API_LIMIT.get(stream)))
 
             self.assertGreater(record_count, self.API_LIMIT.get(stream),
                                msg="Pagination not ensured.\n" +
                                "{} does not have sufficient data in expecatations.\n ".format(stream))
 
         # Create connection but do not use default start date
-        conn_id = connections.ensure_connection(self, original_properties=False)
+        conn_id = connections.ensure_connection(self, original_properties=False, payload_hook=self.preserve_access_token)
 
         # run check mode
         found_catalogs = self.run_and_verify_check_mode(conn_id)
diff --git a/tests/test_start_date.py b/tests/test_start_date.py
index 20f0c5b3..0f614762 100644
--- a/tests/test_start_date.py
+++ b/tests/test_start_date.py
@@ -17,10 +17,6 @@ class TestSquareStartDate(TestSquareBaseParent.TestSquareBase):
     START_DATE_1 = ""
     START_DATE_2 = ""
 
-    @staticmethod
-    def name():
-        return "tap_tester_square_start_date_test"
-
     def testable_streams_dynamic(self):
         return self.dynamic_data_streams().difference(self.untestable_streams())
 
@@ -41,7 +37,7 @@ def timedelta_formatted(self, dtime, days=0):
 
     def test_run(self):
         """Instantiate start date according to the desired data set and run the test"""
-        print("\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
         self.START_DATE = self.get_properties().get('start_date') # Initialize start_date state to make assertions
         self.START_DATE_1 = self.START_DATE
         self.START_DATE_2 = dt.strftime(dt.utcnow(), self.START_DATE_FORMAT)
@@ -49,24 +45,25 @@ def test_run(self):
         self.start_date_test(self.get_environment(), DataType.DYNAMIC)
 
         # Locations does not respect start date and it's the only static data type (see above)
-        print("\n\n-- SKIPPING -- TESTING WITH STATIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
+        LOGGER.info('\n\n-- SKIPPING -- TESTING WITH STATIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT')))
         self.TESTABLE_STREAMS = self.testable_streams_static().difference(self.production_streams())
         self.assertEqual(set(), self.TESTABLE_STREAMS, msg="Testable streams exist for this category.")
streams.") + LOGGER.info('\tThere are no testable streams.') + TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_PROD self.set_environment(self.PRODUCTION) - print("\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}".format(os.getenv('TAP_SQUARE_ENVIRONMENT'))) + LOGGER.info('\n\nTESTING WITH DYNAMIC DATA IN SQUARE_ENVIRONMENT: {}'.format(os.getenv('TAP_SQUARE_ENVIRONMENT'))) self.START_DATE = self.get_properties().get('start_date') self.START_DATE_1 = self.START_DATE self.START_DATE_2 = dt.strftime(dt.utcnow(), self.START_DATE_FORMAT) self.TESTABLE_STREAMS = self.testable_streams_dynamic().difference(self.sandbox_streams()) - self.start_date_test(self.get_environment(), DataType.DYNAMIC) + TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_SANDBOX def start_date_test(self, environment, data_type): - print("\n\nRUNNING {}".format(self.name())) - print("WITH STREAMS: {}\n\n".format(self.TESTABLE_STREAMS)) + LOGGER.info('\n\nRUNNING {}_start_date'.format(self.name())) + LOGGER.info('WITH STREAMS: {}\n\n'.format(self.TESTABLE_STREAMS)) self.create_test_data(self.TESTABLE_STREAMS, self.START_DATE_1, self.START_DATE_2) @@ -75,7 +72,7 @@ def start_date_test(self, environment, data_type): ########################################################################## # instantiate connection - conn_id = connections.ensure_connection(self) + conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token) # run check mode found_catalogs = self.run_and_verify_check_mode(conn_id) @@ -91,7 +88,7 @@ def start_date_test(self, environment, data_type): replicated_row_count_1 = sum(first_record_count_by_stream.values()) self.assertGreater(replicated_row_count_1, 0, msg="failed to replicate any data: {}".format(first_record_count_by_stream)) - print("total replicated row count: {}".format(replicated_row_count_1)) + LOGGER.info('total replicated row count: {}'.format(replicated_row_count_1)) synced_records_1 = runner.get_records_from_target_output() ########################################################################## @@ -99,14 +96,14 @@ def start_date_test(self, environment, data_type): ########################################################################## self.START_DATE = self.START_DATE_2 - print("REPLICATION START DATE CHANGE: {} ===>>> {} ".format(self.START_DATE_1, self.START_DATE_2)) + LOGGER.info('REPLICATION START DATE CHANGE: {} ===>>> {} '.format(self.START_DATE_1, self.START_DATE_2)) ########################################################################## ### Second Sync ########################################################################## # create a new connection with the new start_date - conn_id = connections.ensure_connection(self, original_properties=False) + conn_id = connections.ensure_connection(self, original_properties=False, payload_hook=self.preserve_access_token) # run check mode found_catalogs = self.run_and_verify_check_mode(conn_id) @@ -122,7 +119,7 @@ def start_date_test(self, environment, data_type): replicated_row_count_2 = sum(record_count_by_stream_2.values()) self.assertGreater(replicated_row_count_2, 0, msg="failed to replicate any data") - print("total replicated row count: {}".format(replicated_row_count_2)) + LOGGER.info('total replicated row count: {}'.format(replicated_row_count_2)) synced_records_2 = runner.get_records_from_target_output() diff --git a/tests/test_sync_canary.py b/tests/test_sync_canary.py index 36e63ace..b3dc55e7 100644 --- a/tests/test_sync_canary.py +++ b/tests/test_sync_canary.py @@ 
@@ -1,13 +1,10 @@
 import tap_tester.connections as connections
-
+import singer
 from base import TestSquareBaseParent, DataType
-
+LOGGER = singer.get_logger()
 
 class TestSyncCanary(TestSquareBaseParent.TestSquareBase):
     """Test that sync code gets exercised for all streams regardless if we can't create data.
 
     Validates scopes, authorizations, sync code that can't yet be tested end-to-end."""
-    @staticmethod
-    def name():
-        return "tap_tester_sync_canary"
 
     def testable_streams_dynamic(self):
         return self.dynamic_data_streams()
@@ -26,14 +23,14 @@ def run_standard_sync(self, environment, data_type, select_all_fields=True):
         Select all fields or no fields based on the select_all_fields param.
         Run a sync.
         """
-        conn_id = connections.ensure_connection(self)
+        conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token)
 
         found_catalogs = self.run_and_verify_check_mode(conn_id)
 
         streams_to_select = self.testable_streams(environment, data_type)
 
-        print("\n\nRUNNING {}".format(self.name()))
-        print("WITH STREAMS: {}\n\n".format(streams_to_select))
+        LOGGER.info('\n\nRUNNING {}_sync_canary'.format(self.name()))
+        LOGGER.info('WITH STREAMS: {}\n\n'.format(streams_to_select))
 
         self.perform_and_verify_table_and_field_selection(
             conn_id, found_catalogs, streams_to_select, select_all_fields=select_all_fields
@@ -49,6 +46,8 @@ def test_run(self):
         self.run_standard_sync(self.get_environment(), DataType.DYNAMIC)
         self.run_standard_sync(self.get_environment(), DataType.STATIC)
 
+        TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_PROD
         self.set_environment(self.PRODUCTION)
         self.run_standard_sync(self.get_environment(), DataType.DYNAMIC)
         self.run_standard_sync(self.get_environment(), DataType.STATIC)
+        TestSquareBaseParent.TestSquareBase.test_name = self.TEST_NAME_SANDBOX
\ No newline at end of file
diff --git a/tests/unittests/test_refresh_access_token.py b/tests/unittests/test_refresh_access_token.py
new file mode 100644
index 00000000..2cc6f01d
--- /dev/null
+++ b/tests/unittests/test_refresh_access_token.py
@@ -0,0 +1,154 @@
+import unittest
+from unittest.mock import MagicMock, patch
+from datetime import timedelta
+from singer import utils
+from tap_square.client import require_new_access_token, SquareClient
+
+REFRESH_TOKEN_BEFORE = 22
+
+
+class TestRequireNewAccessToken(unittest.TestCase):
+    def setUp(self):
+        self.client = MagicMock()
+        self.access_token = 'test_access_token'
+
+    def test_no_access_token(self):
+        '''Return True when no access token is provided.'''
+        result = require_new_access_token(None, self.client)
+        self.assertTrue(result)
+
+    @patch('singer.http_request_timer')
+    def test_valid_access_token(self, mock_timer):
+        '''Return False when the token was generated less than 7 days ago.'''
+        token_expiry = utils.strftime(utils.now() + timedelta(days=26))
+        self.client.o_auth.retrieve_token_status.return_value = MagicMock(
+            is_error=MagicMock(return_value=False),
+            body={'expires_at': token_expiry},
+        )
+        result = require_new_access_token(self.access_token, self.client)
+        self.assertFalse(result)
+
+    @patch('singer.http_request_timer')
+    def test_access_token_older_than_7_days(self, mock_timer):
+        '''Return True when the token is older than 7 days.'''
+        token_expiry = utils.strftime(utils.now() + timedelta(days=20))
+        self.client.o_auth.retrieve_token_status.return_value = MagicMock(
+            is_error=MagicMock(return_value=False),
+            body={'expires_at': token_expiry},
+        )
+        result = require_new_access_token(self.access_token, self.client)
+        self.assertTrue(result)
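+    # Note: require_new_access_token returns True once (expiry - now).days is
+    # REFRESH_TOKEN_BEFORE (22) or less, i.e. fewer than ~23 days remain on the token.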
+    @patch('singer.http_request_timer')
+    def test_almost_expired_access_token(self, mock_timer):
+        '''Return False when the token was created exactly 7 days ago.'''
+        token_expiry = utils.strftime(utils.now() + timedelta(days=24))
+        self.client.o_auth.retrieve_token_status.return_value = MagicMock(
+            is_error=MagicMock(return_value=False),
+            body={'expires_at': token_expiry},
+        )
+        result = require_new_access_token(self.access_token, self.client)
+        self.assertFalse(result)
+
+    @patch('singer.http_request_timer')
+    def test_api_error(self, mock_timer):
+        '''Return True when the API returns an error.'''
+        self.client.o_auth.retrieve_token_status.return_value = MagicMock(
+            is_error=MagicMock(return_value=True), errors='API error'
+        )
+        result = require_new_access_token(self.access_token, self.client)
+        self.assertTrue(result)
+
+
+class TestGetAccessToken(unittest.TestCase):
+    def setUp(self):
+        self.config_path = '/path/to/config.json'
+
+        self.config = {
+            'client_id': 'test_client_id',
+            'client_secret': 'test_client_secret',
+            'refresh_token': 'test_refresh_token',
+            'environment': 'sandbox',
+            'access_token': 'cached_token',
+        }
+
+    @patch('tap_square.client.require_new_access_token')
+    @patch('tap_square.client.write_config')
+    @patch('singer.http_request_timer')
+    def test_get_access_token_no_refresh_needed(
+        self, mock_http_timer, mock_write_config, mock_require_new_access_token
+    ):
+        '''
+        Test the case where the access token does not need to be refreshed
+        '''
+        mock_require_new_access_token.return_value = False
+
+        _instance = SquareClient(self.config, self.config_path)
+
+        # Assertions
+        self.assertEqual(_instance._access_token, 'cached_token')
+        mock_write_config.assert_not_called()
+
+    @patch('tap_square.client.Client')
+    @patch('tap_square.client.require_new_access_token')
+    @patch('tap_square.client.write_config')
+    @patch('singer.http_request_timer')
+    def test_get_access_token_refresh_needed_success(
+        self,
+        mock_http_timer,
+        mock_write_config,
+        mock_require_new_access_token,
+        mock_client,
+    ):
+        '''
+        Test the case where the access token needs to be refreshed and the API returns a new token
+        '''
+        mock_require_new_access_token.return_value = True
+
+        # Mock the Client's o_auth.obtain_token method
+        mock_client_instance = mock_client.return_value
+        mock_client_instance.o_auth.obtain_token.return_value = MagicMock(
+            is_error=MagicMock(return_value=False),
+            body={'access_token': 'new_token', 'refresh_token': 'new_refresh_token'},
+        )
+
+        _instance = SquareClient(self.config, self.config_path)
+
+        # Assertions
+        self.assertEqual(_instance._access_token, 'new_token')
+        mock_write_config.assert_called_once_with(
+            {
+                'client_id': 'test_client_id',
+                'client_secret': 'test_client_secret',
+                'refresh_token': 'test_refresh_token',
+                'environment': 'sandbox',
+                'access_token': 'cached_token',
+            },
+            '/path/to/config.json',
+            {'access_token': 'new_token', 'refresh_token': 'new_refresh_token'},
+        )
+        mock_client_instance.o_auth.obtain_token.assert_called_once()
+
+    @patch('tap_square.client.Client')
+    @patch('tap_square.client.require_new_access_token')
+    @patch('singer.http_request_timer')
+    def test_get_access_token_refresh_needed_error(
+        self, mock_http_timer, mock_require_new_access_token, mock_client
+    ):
+        '''
+        Test the case where the API returns an error while refreshing the access token
+        '''
+        mock_require_new_access_token.return_value = True
+
+        # Mock the Client's o_auth.obtain_token method to return an error
+        mock_client_instance = mock_client.return_value
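+        # Simulate Square rejecting the refresh so SquareClient.__init__ surfaces
+        # the failure as a RuntimeError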
+        mock_client_instance.o_auth.obtain_token.return_value = MagicMock(
+            is_error=MagicMock(return_value=True), errors=['Invalid credentials']
+        )
+
+        # Call the method and check for exception
+        with self.assertRaises(RuntimeError) as context:
+            SquareClient(self.config, self.config_path)
+
+        self.assertIn('Invalid credentials', str(context.exception))
+        mock_client_instance.o_auth.obtain_token.assert_called_once()
diff --git a/tests/unittests/test_sync.py b/tests/unittests/test_sync.py
index b74c92d0..ceb55f6a 100644
--- a/tests/unittests/test_sync.py
+++ b/tests/unittests/test_sync.py
@@ -115,7 +115,7 @@ def test_search_team_members_sync(self, mocked_access_token, mocked_get_v2_objec
         """
         expected_return_value = expected_return_state
 
-        team_members_obj = TeamMembers(SquareClient(mock_config))
+        team_members_obj = TeamMembers(SquareClient(mock_config, 'config_path'))
         return_value = team_members_obj.sync(
             {"currently_syncing": "team_members"},
             stream_schema,