TDL-26718: Implement token chaining (#123)
* Initial commit

* Enhance token chaining

* Fix unit tests

* Fix pylint issue

* Add unit tests

* Fix start_date tests

* Fix tap-tester issues

* Enhance tap testers

* Fix pylint

* Fix bookmark cursor tests for inventory

* Remove debug log

* Run tests for all streams

* Remove unnecessary changes

* Address review comments

* Remove returned_quantities field from all_fields test

* Resolve review comments

* Replace double quotes with single quotes
prijendev authored Jan 10, 2025
1 parent d724866 commit e0ac6c9
Showing 18 changed files with 450 additions and 236 deletions.
2 changes: 1 addition & 1 deletion tap_square/__init__.py
@@ -20,7 +20,7 @@ def main():
    if args.discover:
        write_catalog(catalog)
    else:
        sync(args.config, args.state, catalog)
        sync(args.config, args.config_path, args.state, catalog)

if __name__ == '__main__':
    main()
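
The only change here is threading the config file path into sync() so the client can write refreshed tokens back to the same file. For context, a minimal sketch of how the entry point is expected to obtain both values — assuming singer-python's utils.parse_args exposes the raw --config path as args.config_path alongside the parsed args.config (that behaviour is an assumption, not shown in this diff):

```python
# Sketch only, not the tap's actual __init__.py. Assumes singer-python's
# parse_args() sets args.config (parsed JSON) and args.config_path (raw path).
from singer import utils
from tap_square.sync import sync

REQUIRED_CONFIG_KEYS = ['refresh_token', 'client_id', 'client_secret']  # illustrative

def run():
    args = utils.parse_args(REQUIRED_CONFIG_KEYS)
    # config_path is what eventually reaches write_config() in client.py,
    # so rotated tokens are persisted to the same file the tap was started with.
    sync(args.config, args.config_path, args.state, args.catalog)
```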
96 changes: 74 additions & 22 deletions tap_square/client.py
@@ -1,15 +1,15 @@
from datetime import timedelta
import urllib.parse
import os
import json
import requests
from square.client import Client
from singer import utils
import singer
import requests
import backoff


LOGGER = singer.get_logger()

REFRESH_TOKEN_BEFORE = 22

def get_batch_token_from_headers(headers):
    link = headers.get('link')
@@ -39,43 +39,91 @@ def log_backoff(details):
    LOGGER.warning('Error receiving data from square. Sleeping %.1f seconds before trying again', details['wait'])


def write_config(config, config_path, data):
    '''
    Updates the provided filepath with json format of the `data` object
    '''
    config.update(data)
    with open(config_path, "w") as tap_config:
        json.dump(config, tap_config, indent=2)
    return config


def require_new_access_token(access_token, client):
    '''
    Checks if the access token needs to be refreshed
    '''
    # If there is no access token, we need to generate a new one
    if not access_token:
        return True

    authorization = f"Bearer {access_token}"

    with singer.http_request_timer('Check access token expiry'):
        response = client.o_auth.retrieve_token_status(authorization)

    if response.is_error():
        error_message = response.errors if response.errors else response.body
        LOGGER.error(error_message)
        return True

    # Parse the token expiry date
    token_expiry_date = singer.utils.strptime_with_tz(response.body['expires_at'])
    now = utils.now()
    return (token_expiry_date - now).days <= REFRESH_TOKEN_BEFORE
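
Since REFRESH_TOKEN_BEFORE is 22, the helper above treats a token as stale once 22 or fewer days remain before the expires_at reported by retrieve_token_status. A worked check with made-up dates:

```python
# Illustrative only: the dates below are invented to show the threshold logic.
from singer import utils

expires_at = utils.strptime_with_tz('2025-01-25T00:00:00Z')  # hypothetical token expiry
now = utils.strptime_with_tz('2025-01-10T00:00:00Z')         # pretend "today"

days_left = (expires_at - now).days  # 15
print(days_left <= 22)               # True -> require_new_access_token() asks for a refresh
```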


class RetryableError(Exception):
    pass


class SquareClient():
    def __init__(self, config):
    def __init__(self, config, config_path):
        self._refresh_token = config['refresh_token']
        self._client_id = config['client_id']
        self._client_secret = config['client_secret']

        self._environment = 'sandbox' if config.get('sandbox') == 'true' else 'production'

        self._access_token = self._get_access_token()
        self._access_token = self._get_access_token(config, config_path)
        self._client = Client(access_token=self._access_token, environment=self._environment)

    def _get_access_token(self):
        if "TAP_SQUARE_ACCESS_TOKEN" in os.environ.keys():
            LOGGER.info("Using access token from environment, not creating the new one")
            return os.environ["TAP_SQUARE_ACCESS_TOKEN"]
    def _get_access_token(self, config, config_path):
        '''
        Retrieves the access token from the config file. If the access token is expired, it will refresh it.
        Otherwise, it will return the cached access token.
        '''
        access_token = config.get("access_token")
        client = Client(environment=self._environment)

        body = {
            'client_id': self._client_id,
            'client_secret': self._client_secret,
            'grant_type': 'refresh_token',
            'refresh_token': self._refresh_token
        }
        # Check if the access token needs to be refreshed
        if require_new_access_token(access_token, client):
            LOGGER.info('Refreshing access token...')
            body = {
                'client_id': self._client_id,
                'client_secret': self._client_secret,
                'grant_type': 'refresh_token',
                'refresh_token': self._refresh_token
            }

        client = Client(environment=self._environment)
            with singer.http_request_timer('GET access token'):
                result = client.o_auth.obtain_token(body)

        with singer.http_request_timer('GET access token'):
            result = client.o_auth.obtain_token(body)
            if result.is_error():
                error_message = result.errors if result.errors else result.body
                raise RuntimeError(error_message)

        if result.is_error():
            error_message = result.errors if result.errors else result.body
            raise RuntimeError(error_message)
            access_token = result.body['access_token']
            write_config(
                config,
                config_path,
                {
                    'access_token': access_token,
                    'refresh_token': result.body['refresh_token'],
                },
            )

        return result.body['access_token']
        return access_token

    @staticmethod
    @backoff.on_exception(
@@ -158,6 +206,10 @@ def get_customers(self, start_time, end_time):
                        'end_at': end_time # Exclusive on end_at
                    }
                },
                'sort': {
                    'field': 'CREATED_AT',
                    'order': 'ASC'
                }
            }
        }

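Taken together, _get_access_token now reuses the access_token cached in the config, refreshes it via OAuth only when require_new_access_token says so, and chains the rotated refresh_token back into the config file through write_config. (The get_customers change simply adds a CREATED_AT ascending sort to the search query, presumably so paginated results come back in a stable order.) A minimal usage sketch of the client, with a hypothetical config path:

```python
# Minimal usage sketch of the token-chaining flow; 'config.json' is a hypothetical path.
import json
from tap_square.client import SquareClient

config_path = 'config.json'
with open(config_path) as f:
    config = json.load(f)  # expects refresh_token, client_id, client_secret, optionally access_token

# If config['access_token'] is missing or expires within 22 days, the client calls
# obtain_token() and rewrites config.json with the new access_token/refresh_token
# pair, so the next run starts from the rotated credentials.
client = SquareClient(config, config_path)
```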
4 changes: 2 additions & 2 deletions tap_square/sync.py
@@ -6,8 +6,8 @@

LOGGER = singer.get_logger()

def sync(config, state, catalog): # pylint: disable=too-many-statements
    client = SquareClient(config)
def sync(config, config_path, state, catalog): # pylint: disable=too-many-statements
    client = SquareClient(config, config_path)

    with Transformer() as transformer:
        for stream in catalog.get_selected_streams(state):
50 changes: 43 additions & 7 deletions tests/base.py
@@ -42,17 +42,19 @@ class TestSquareBase(ABC, TestCase):
    STATIC_START_DATE = "2020-07-13T00:00:00Z"
    START_DATE = ""
    PRODUCTION_ONLY_STREAMS = {'bank_accounts', 'payouts'}
    TEST_NAME_SANDBOX = 'tap_tester_square_sandbox_tests'
    TEST_NAME_PROD = 'tap_tester_square_prod_tests'
    test_name = TEST_NAME_SANDBOX

    DEFAULT_BATCH_LIMIT = 1000
    API_LIMIT = {
        'items': DEFAULT_BATCH_LIMIT,
        'inventories': 100,
        'inventories': DEFAULT_BATCH_LIMIT,
        'categories': DEFAULT_BATCH_LIMIT,
        'discounts': DEFAULT_BATCH_LIMIT,
        'taxes': DEFAULT_BATCH_LIMIT,
        'cash_drawer_shifts': DEFAULT_BATCH_LIMIT,
        'locations': None, # Api does not accept a cursor and documents no limit, see https://developer.squareup.com/reference/square/locations/list-locations
        'roles': 100,
        'refunds': 100,
        'payments': 100,
        'payouts': 100,
Expand Down Expand Up @@ -83,15 +85,37 @@ def get_type():
def tap_name():
return "tap-square"

    @staticmethod
    def name():
        return TestSquareBaseParent.TestSquareBase.test_name

    def set_environment(self, env):
        """
        Change the Square App Environment.
        Requires re-instantiating TestClient and setting env var.
        """
        os.environ['TAP_SQUARE_ENVIRONMENT'] = env
        self.set_access_token_in_env()
        self.client = TestClient(env=env)
        self.SQUARE_ENVIRONMENT = env

    def set_access_token_in_env(self):
        '''
        Fetch the access token from the existing connection and set it in the env.
        This is used to avoid rate limiting issues when running tests.
        '''
        existing_connections = connections.fetch_existing_connections(self)
        if not existing_connections:
            os.environ.pop('TAP_SQUARE_ACCESS_TOKEN', None)
            return

        conn_with_creds = connections.fetch_existing_connection_with_creds(existing_connections[0]['id'])
        access_token = conn_with_creds['credentials'].get('access_token')
        if not access_token:
            LOGGER.warning('No access token found in env')
        else:
            os.environ['TAP_SQUARE_ACCESS_TOKEN'] = access_token

    @staticmethod
    def get_environment():
        return os.environ['TAP_SQUARE_ENVIRONMENT']
@@ -116,12 +140,23 @@ def get_credentials():
                'refresh_token': os.getenv('TAP_SQUARE_REFRESH_TOKEN') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_REFRESH_TOKEN'),
                'client_id': os.getenv('TAP_SQUARE_APPLICATION_ID') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_APPLICATION_ID'),
                'client_secret': os.getenv('TAP_SQUARE_APPLICATION_SECRET') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_APPLICATION_SECRET'),
                'access_token': os.environ['TAP_SQUARE_ACCESS_TOKEN']
            }
        else:
            raise Exception("Square Environment: {} is not supported.".format(environment))

        return creds

    @staticmethod
    def preserve_access_token(existing_conns, payload):
        '''This method is used to get the access token from an existing connection'''
        if not existing_conns:
            return payload

        conn_with_creds = connections.fetch_existing_connection_with_creds(existing_conns[0]['id'])
        payload['properties']['access_token'] = conn_with_creds['credentials'].get('access_token')
        return payload

    def expected_check_streams(self):
        return set(self.expected_metadata().keys()).difference(set())

@@ -490,7 +525,8 @@ def create_test_data(self, testable_streams, start_date, start_date_2=None, min_
            else:
                raise NotImplementedError("created_records unknown type: {}".format(created_records))

            print("Adjust expectations for stream: {}".format(stream))
            stream_to_expected_records[stream] = self.client.get_all(stream, start_date)
            LOGGER.info('Adjust expectations for stream: {}'.format(stream))
            self.modify_expected_records(stream_to_expected_records[stream])

        return stream_to_expected_records
@@ -546,7 +582,7 @@ def run_and_verify_check_mode(self, conn_id):
        found_catalog_names = found_catalog_names - {'settlements'}
        diff = self.expected_check_streams().symmetric_difference(found_catalog_names)
        self.assertEqual(len(diff), 0, msg="discovered schemas do not match: {}".format(diff))
        print("discovered schemas are OK")
        LOGGER.info("discovered schemas are OK")

        return found_catalogs

@@ -568,7 +604,7 @@ def perform_and_verify_table_and_field_selection(self, conn_id, found_catalogs,
            catalog_entry = menagerie.get_annotated_schema(conn_id, cat['stream_id'])
            # Verify all testable streams are selected
            selected = catalog_entry.get('annotated-schema').get('selected')
            print("Validating selection on {}: {}".format(cat['stream_name'], selected))
            LOGGER.info('Validating selection on {}: {}'.format(cat['stream_name'], selected))
            if cat['stream_name'] not in streams_to_select:
                self.assertFalse(selected, msg="Stream selected, but not testable.")
                continue # Skip remaining assertions if we aren't selecting this stream
@@ -578,7 +614,7 @@
                # Verify all fields within each selected stream are selected
                for field, field_props in catalog_entry.get('annotated-schema').get('properties').items():
                    field_selected = field_props.get('selected')
                    print("\tValidating selection on {}.{}: {}".format(cat['stream_name'], field, field_selected))
                    LOGGER.info('\tValidating selection on {}.{}: {}'.format(cat['stream_name'], field, field_selected))
                    self.assertTrue(field_selected, msg="Field not selected.")
            else:
                # Verify only automatic fields are selected
@@ -662,7 +698,7 @@ def assertRecordsEqual(self, stream, expected_record, sync_record):
        Certain Square streams cannot be compared directly with assertDictEqual().
        So we handle that logic here.
        """
        if stream not in ['refunds', 'orders', 'customers']:
        if stream not in ['refunds', 'orders', 'customers', 'locations']:
            expected_record.pop('created_at', None)
        if stream == 'payments':
            self.assertDictEqualWithOffKeys(expected_record, sync_record, {'updated_at'})
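
On the test side, preserve_access_token copies the access token off an existing connection into the connection payload, and set_access_token_in_env exports it so get_credentials can hand it to the tap, avoiding a fresh token grant on every run. A hedged sketch of how a test might plug preserve_access_token into connection creation — the payload_hook keyword mirrors the pattern used in other tap-tester based suites and is an assumption here, not something shown in this diff:

```python
# Sketch only. ASSUMPTION: connections.ensure_connection accepts a payload_hook
# callable that can rewrite the connection payload before the connection is
# created; verify against the tap-tester version in use.
from tap_tester import connections

def create_connection_reusing_token(test):
    # preserve_access_token copies 'access_token' from the first existing
    # connection's credentials into payload['properties'].
    return connections.ensure_connection(test, payload_hook=test.preserve_access_token)
```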
(Diffs for the remaining 14 of the 18 changed files are not shown here.)
