From 2d9a9bee127ebc54b0b13a9d7c9a2ddea1b1458d Mon Sep 17 00:00:00 2001 From: andrey-snowflake <42752788+andrey-snowflake@users.noreply.github.com> Date: Wed, 5 Dec 2018 07:03:23 -0800 Subject: [PATCH] Adds dbconnect override for displaying URL in command line (#77) --- src/runners/helpers/db.py | 17 +++++- src/runners/helpers/dbconnect.py | 93 ++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 src/runners/helpers/dbconnect.py diff --git a/src/runners/helpers/db.py b/src/runners/helpers/db.py index 8c912dbcd..c750bbced 100644 --- a/src/runners/helpers/db.py +++ b/src/runners/helpers/db.py @@ -7,6 +7,9 @@ from . import log from .auth import load_pkb from .dbconfig import ACCOUNT, USER, WAREHOUSE, PRIVATE_KEY, PRIVATE_KEY_PASSWORD, TIMEOUT +from .dbconnect import snowflake_connect + +CACHED_CONNECTION = None def retry(f, E=Exception, n=3): @@ -29,11 +32,19 @@ def preflight_checks(ctx): def connect(run_preflight_checks=True): + global CACHED_CONNECTION + if CACHED_CONNECTION: + return CACHED_CONNECTION + + connect_db, authenticator, pk = (snowflake_connect, 'EXTERNALBROWSER', None) if PRIVATE_KEY is None else \ + (snowflake.connector.connect, None, load_pkb(PRIVATE_KEY, PRIVATE_KEY_PASSWORD)) + try: - connection = retry(lambda: snowflake.connector.connect( + connection = retry(lambda: connect_db( user=USER, account=ACCOUNT, - private_key=load_pkb(PRIVATE_KEY, PRIVATE_KEY_PASSWORD), + private_key=pk, + authenticator=authenticator, ocsp_response_cache_filename='/tmp/.cache/snowflake/ocsp_response_cache', network_timeout=TIMEOUT )) @@ -45,6 +56,8 @@ def connect(run_preflight_checks=True): log.fatal(e, "Failed to connect.") execute(connection, f'USE WAREHOUSE {WAREHOUSE};') + + CACHED_CONNECTION = connection return connection diff --git a/src/runners/helpers/dbconnect.py b/src/runners/helpers/dbconnect.py new file mode 100644 index 000000000..d20ed3fee --- /dev/null +++ b/src/runners/helpers/dbconnect.py @@ -0,0 +1,93 @@ +import os +import socket + +from snowflake.connector.auth import Auth +from snowflake.connector.auth_webbrowser import AuthByWebBrowser +from snowflake.connector.connection import SnowflakeConnection as BadSnowflakeConnection +from snowflake.connector.network import SnowflakeRestful + + +def snowflake_connect(**kwargs): + """The bad snowflake connection presumes too much! This one lets you override: + 1. how URL is displayed + 2. which port is listened on + """ + + class WebbrowserPkg(object): + @staticmethod + def open_new(url): + print(f'auth here', url, flush=True) + return True + + class SocketPkg(socket.socket): + def __init__(self, *args, **kwargs): + return super(SocketPkg, self).__init__(*args, **kwargs) + + def close(self, *args, **kwargs): + super(SocketPkg, self).close(*args, **kwargs) + + def bind(self, address): + return super(SocketPkg, self).bind(('0.0.0.0', 1901)) + + class SnowflakeConnection(BadSnowflakeConnection): + def __open_connection(self): + u""" + Opens a new network connection + """ + print("in right one") + self.converter = self._converter_class( + use_sfbinaryformat=False, + use_numpy=self._numpy) + + self._rest = SnowflakeRestful( + host=self.host, + port=self.port, + proxy_host=self.proxy_host, + proxy_port=self.proxy_port, + proxy_user=self.proxy_user, + proxy_password=self.proxy_password, + protocol=self._protocol, + inject_client_pause=self._inject_client_pause, + connection=self) + + if self.host.endswith(u".privatelink.snowflakecomputing.com"): + ocsp_cache_server = \ + u'http://ocsp{}/ocsp_response_cache.json'.format( + self.host[self.host.index('.'):]) + os.environ['SF_OCSP_RESPONSE_CACHE_SERVER_URL'] = ocsp_cache_server + else: + if 'SF_OCSP_RESPONSE_CACHE_SERVER_URL' in os.environ: + del os.environ['SF_OCSP_RESPONSE_CACHE_SERVER_URL'] + + auth_instance = AuthByWebBrowser( + self.rest, self.application, protocol=self._protocol, + host=self.host, port=self.port, webbrowser_pkg=WebbrowserPkg, socket_pkg=SocketPkg) + + if self._session_parameters is None: + self._session_parameters = {} + if self._autocommit is not None: + self._session_parameters['AUTOCOMMIT'] = self._autocommit + + if self._timezone is not None: + self._session_parameters['TIMEZONE'] = self._timezone + + if self.client_session_keep_alive: + self._session_parameters['CLIENT_SESSION_KEEP_ALIVE'] = True + + # enable storing temporary credential in a file + self._session_parameters['CLIENT_STORE_TEMPORARY_CREDENTIAL'] = True + + auth = Auth(self.rest) + if not auth.read_temporary_credential(self.account, self.user, self._session_parameters): + self.__authenticate(auth_instance) + else: + # set the current objects as the session is derived from the id + # token, and the current objects may be different. + self._set_current_objects() + + self._password = None # ensure password won't persist + + if self.client_session_keep_alive: + self._add_heartbeat() + + return SnowflakeConnection(**kwargs)