Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TDL-26718: Implement token chaining #123

Merged
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tap_square/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def main():
if args.discover:
write_catalog(catalog)
else:
sync(args.config, args.state, catalog)
sync(args.config, args.config_path, args.state, catalog)

if __name__ == '__main__':
main()
96 changes: 74 additions & 22 deletions tap_square/client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from datetime import timedelta
import urllib.parse
import os
import json
import requests
from square.client import Client
from singer import utils
import singer
import requests
import backoff


LOGGER = singer.get_logger()

REFRESH_TOKEN_BEFORE = 22

def get_batch_token_from_headers(headers):
link = headers.get('link')
Expand Down Expand Up @@ -39,43 +39,91 @@ def log_backoff(details):
LOGGER.warning('Error receiving data from square. Sleeping %.1f seconds before trying again', details['wait'])


def write_config(config, config_path, data):
"""
Updates the provided filepath with json format of the `data` object
"""
prijendev marked this conversation as resolved.
Show resolved Hide resolved
config.update(data)
with open(config_path, "w") as tap_config:
json.dump(config, tap_config, indent=2)
return config


def require_new_access_token(access_token, client):
"""
Checks if the access token needs to be refreshed
"""
# If there is no access token, we need to generate a new one
if not access_token:
return True

authorization = f"Bearer {access_token}"

with singer.http_request_timer("Check access token expiry"):
response = client.o_auth.retrieve_token_status(authorization)

if response.is_error():
error_message = response.errors if response.errors else response.body
LOGGER.error(error_message)
return True

# Parse the token expiry date
token_expiry_date = singer.utils.strptime_with_tz(response.body["expires_at"])
now = utils.now()
return (token_expiry_date - now).days <= REFRESH_TOKEN_BEFORE


class RetryableError(Exception):
pass


class SquareClient():
def __init__(self, config):
def __init__(self, config, config_path):
self._refresh_token = config['refresh_token']
self._client_id = config['client_id']
self._client_secret = config['client_secret']

self._environment = 'sandbox' if config.get('sandbox') == 'true' else 'production'

self._access_token = self._get_access_token()
self._access_token = self._get_access_token(config, config_path)
self._client = Client(access_token=self._access_token, environment=self._environment)

def _get_access_token(self):
if "TAP_SQUARE_ACCESS_TOKEN" in os.environ.keys():
LOGGER.info("Using access token from environment, not creating the new one")
return os.environ["TAP_SQUARE_ACCESS_TOKEN"]
def _get_access_token(self, config, config_path):
"""
Retrieves the access token from the config file. If the access token is expired, it will refresh it.
Otherwise, it will return the cached access token.
"""
access_token = config.get("access_token")
client = Client(environment=self._environment)

body = {
'client_id': self._client_id,
'client_secret': self._client_secret,
'grant_type': 'refresh_token',
'refresh_token': self._refresh_token
}
# Check if the access token needs to be refreshed
if require_new_access_token(access_token, client):
LOGGER.info("Refreshing access token...")
body = {
'client_id': self._client_id,
'client_secret': self._client_secret,
'grant_type': 'refresh_token',
'refresh_token': self._refresh_token
}

client = Client(environment=self._environment)
with singer.http_request_timer('GET access token'):
result = client.o_auth.obtain_token(body)

with singer.http_request_timer('GET access token'):
result = client.o_auth.obtain_token(body)
if result.is_error():
error_message = result.errors if result.errors else result.body
raise RuntimeError(error_message)

if result.is_error():
error_message = result.errors if result.errors else result.body
raise RuntimeError(error_message)
access_token = result.body['access_token']
write_config(
config,
config_path,
{
"access_token": access_token,
"refresh_token": result.body["refresh_token"],
},
)

return result.body['access_token']
return access_token

@staticmethod
@backoff.on_exception(
Expand Down Expand Up @@ -158,6 +206,10 @@ def get_customers(self, start_time, end_time):
'end_at': end_time # Exclusive on end_at
}
},
"sort": {
"field": "CREATED_AT",
"order": "ASC"
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions tap_square/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

LOGGER = singer.get_logger()

def sync(config, state, catalog): # pylint: disable=too-many-statements
client = SquareClient(config)
def sync(config, config_path, state, catalog): # pylint: disable=too-many-statements
client = SquareClient(config, config_path)

with Transformer() as transformer:
for stream in catalog.get_selected_streams(state):
Expand Down
50 changes: 43 additions & 7 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,19 @@ class TestSquareBase(ABC, TestCase):
STATIC_START_DATE = "2020-07-13T00:00:00Z"
START_DATE = ""
PRODUCTION_ONLY_STREAMS = {'bank_accounts', 'payouts'}
TEST_NAME_SANDBOX = "tap_tester_square_sandbox_tests"
TEST_NAME_PROD = "tap_tester_square_prod_tests"
test_name = TEST_NAME_SANDBOX

DEFAULT_BATCH_LIMIT = 1000
API_LIMIT = {
'items': DEFAULT_BATCH_LIMIT,
'inventories': 100,
'inventories': DEFAULT_BATCH_LIMIT,
'categories': DEFAULT_BATCH_LIMIT,
'discounts': DEFAULT_BATCH_LIMIT,
'taxes': DEFAULT_BATCH_LIMIT,
'cash_drawer_shifts': DEFAULT_BATCH_LIMIT,
'locations': None, # Api does not accept a cursor and documents no limit, see https://developer.squareup.com/reference/square/locations/list-locations
'roles': 100,
'refunds': 100,
'payments': 100,
'payouts': 100,
Expand Down Expand Up @@ -83,15 +85,37 @@ def get_type():
def tap_name():
return "tap-square"

@staticmethod
def name():
return TestSquareBaseParent.TestSquareBase.test_name

def set_environment(self, env):
"""
Change the Square App Environmnet.
Requires re-instatiating TestClient and setting env var.
"""
os.environ['TAP_SQUARE_ENVIRONMENT'] = env
self.set_access_token_in_env()
self.client = TestClient(env=env)
self.SQUARE_ENVIRONMENT = env

def set_access_token_in_env(self):
"""
Fetch the access token from the existing connection and set it in the env.
This is used to avoid rate limiting issues when running tests.
"""
existing_connections = connections.fetch_existing_connections(self)
if not existing_connections:
os.environ.pop('TAP_SQUARE_ACCESS_TOKEN', None)
return

conn_with_creds = connections.fetch_existing_connection_with_creds(existing_connections[0]['id'])
access_token = conn_with_creds['credentials'].get('access_token')
if not access_token:
LOGGER.warning("No access token found in env")
else:
os.environ['TAP_SQUARE_ACCESS_TOKEN'] = access_token

@staticmethod
def get_environment():
return os.environ['TAP_SQUARE_ENVIRONMENT']
Expand All @@ -116,12 +140,23 @@ def get_credentials():
'refresh_token': os.getenv('TAP_SQUARE_REFRESH_TOKEN') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_REFRESH_TOKEN'),
'client_id': os.getenv('TAP_SQUARE_APPLICATION_ID') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_APPLICATION_ID'),
'client_secret': os.getenv('TAP_SQUARE_APPLICATION_SECRET') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_APPLICATION_SECRET'),
'access_token': os.environ["TAP_SQUARE_ACCESS_TOKEN"]
}
else:
raise Exception("Square Environment: {} is not supported.".format(environment))

return creds

@staticmethod
def preserve_access_token(existing_conns, payload):
"""This method is used get the access token from an existing refresh token"""
if not existing_conns:
return payload

conn_with_creds = connections.fetch_existing_connection_with_creds(existing_conns[0]['id'])
payload['properties']['access_token'] = conn_with_creds['credentials'].get('access_token')
return payload

def expected_check_streams(self):
return set(self.expected_metadata().keys()).difference(set())

Expand Down Expand Up @@ -490,7 +525,8 @@ def create_test_data(self, testable_streams, start_date, start_date_2=None, min_
else:
raise NotImplementedError("created_records unknown type: {}".format(created_records))

print("Adjust expectations for stream: {}".format(stream))
stream_to_expected_records[stream] = self.client.get_all(stream, start_date)
LOGGER.info("Adjust expectations for stream: {}".format(stream))
self.modify_expected_records(stream_to_expected_records[stream])

return stream_to_expected_records
Expand Down Expand Up @@ -546,7 +582,7 @@ def run_and_verify_check_mode(self, conn_id):
found_catalog_names = found_catalog_names - {'settlements'}
diff = self.expected_check_streams().symmetric_difference(found_catalog_names)
self.assertEqual(len(diff), 0, msg="discovered schemas do not match: {}".format(diff))
print("discovered schemas are OK")
LOGGER.info("discovered schemas are OK")

return found_catalogs

Expand All @@ -568,7 +604,7 @@ def perform_and_verify_table_and_field_selection(self, conn_id, found_catalogs,
catalog_entry = menagerie.get_annotated_schema(conn_id, cat['stream_id'])
# Verify all testable streams are selected
selected = catalog_entry.get('annotated-schema').get('selected')
print("Validating selection on {}: {}".format(cat['stream_name'], selected))
LOGGER.info("Validating selection on {}: {}".format(cat['stream_name'], selected))
if cat['stream_name'] not in streams_to_select:
self.assertFalse(selected, msg="Stream selected, but not testable.")
continue # Skip remaining assertions if we aren't selecting this stream
Expand All @@ -578,7 +614,7 @@ def perform_and_verify_table_and_field_selection(self, conn_id, found_catalogs,
# Verify all fields within each selected stream are selected
for field, field_props in catalog_entry.get('annotated-schema').get('properties').items():
field_selected = field_props.get('selected')
print("\tValidating selection on {}.{}: {}".format(cat['stream_name'], field, field_selected))
LOGGER.info("\tValidating selection on {}.{}: {}".format(cat['stream_name'], field, field_selected))
self.assertTrue(field_selected, msg="Field not selected.")
else:
# Verify only automatic fields are selected
Expand Down Expand Up @@ -662,7 +698,7 @@ def assertRecordsEqual(self, stream, expected_record, sync_record):
Certain Square streams cannot be compared directly with assertDictEqual().
So we handle that logic here.
"""
if stream not in ['refunds', 'orders', 'customers']:
if stream not in ['refunds', 'orders', 'customers', 'locations']:
expected_record.pop('created_at', None)
if stream == 'payments':
self.assertDictEqualWithOffKeys(expected_record, sync_record, {'updated_at'})
Expand Down
Loading