From a064a496001192bccc47bafdef819ed90ea2529f Mon Sep 17 00:00:00 2001 From: Jing Wang Date: Sun, 16 Jul 2017 21:08:37 -0700 Subject: [PATCH] Add options to customize Thrift transport and requests kwargs (#140) Closes #122, closes #132, closes #135 --- pyhive/hive.py | 90 +++++++++++++++++++++++-------------- pyhive/presto.py | 40 ++++++++++++----- pyhive/tests/test_hive.py | 35 ++++++++++++++- pyhive/tests/test_presto.py | 38 +++++++++++++++- 4 files changed, 156 insertions(+), 47 deletions(-) diff --git a/pyhive/hive.py b/pyhive/hive.py index aa191b5a..811c9063 100644 --- a/pyhive/hive.py +++ b/pyhive/hive.py @@ -68,20 +68,25 @@ def connect(*args, **kwargs): class Connection(object): """Wraps a Thrift session""" - def __init__(self, host, port=10000, username=None, database='default', auth='NONE', - configuration=None, kerberos_service_name=None, password=None): + def __init__(self, host=None, port=None, username=None, database='default', auth=None, + configuration=None, kerberos_service_name=None, password=None, + thrift_transport=None): """Connect to HiveServer2 - :param auth: The value of hive.server2.authentication used by HiveServer2 + :param host: What host HiveServer2 runs on + :param port: What port HiveServer2 runs on. Defaults to 10000. + :param auth: The value of hive.server2.authentication used by HiveServer2. + Defaults to ``NONE``. :param configuration: A dictionary of Hive settings (functionally same as the `set` command) :param kerberos_service_name: Use with auth='KERBEROS' only :param password: Use with auth='LDAP' only + :param thrift_transport: A ``TTransportBase`` for custom advanced usage. + Incompatible with host, port, auth, kerberos_service_name, and password. The way to support LDAP and GSSAPI is originated from cloudera/Impyla: https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62 /impala/_thrift_api.py#L152-L160 """ - socket = thrift.transport.TSocket.TSocket(host, port) username = username or getpass.getuser() configuration = configuration or {} @@ -90,37 +95,56 @@ def __init__(self, host, port=10000, username=None, database='default', auth='NO "Remove password or add auth='LDAP'") if (kerberos_service_name is not None) != (auth == 'KERBEROS'): raise ValueError("kerberos_service_name should be set if and only if in KERBEROS mode") + if thrift_transport is not None: + has_incompatible_arg = ( + host is not None + or port is not None + or auth is not None + or kerberos_service_name is not None + or password is not None + ) + if has_incompatible_arg: + raise ValueError("thrift_transport cannot be used with " + "host/port/auth/kerberos_service_name/password") - if auth == 'NOSASL': - # NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml - self._transport = thrift.transport.TTransport.TBufferedTransport(socket) - elif auth in ('LDAP', 'KERBEROS', 'NONE'): - if auth == 'KERBEROS': - # KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library - sasl_auth = 'GSSAPI' - else: - sasl_auth = 'PLAIN' - if password is None: - # Password doesn't matter in NONE mode, just needs to be nonempty. - password = 'x' - - def sasl_factory(): - sasl_client = sasl.Client() - sasl_client.setAttr('host', host) - if sasl_auth == 'GSSAPI': - sasl_client.setAttr('service', kerberos_service_name) - elif sasl_auth == 'PLAIN': - sasl_client.setAttr('username', username) - sasl_client.setAttr('password', password) - else: - raise AssertionError - sasl_client.init() - return sasl_client - self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) + if thrift_transport is not None: + self._transport = thrift_transport else: - raise NotImplementedError( - "Only NONE, NOSASL, LDAP, KERBEROS " - "authentication are supported, got {}".format(auth)) + if port is None: + port = 10000 + if auth is None: + auth = 'NONE' + socket = thrift.transport.TSocket.TSocket(host, port) + if auth == 'NOSASL': + # NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml + self._transport = thrift.transport.TTransport.TBufferedTransport(socket) + elif auth in ('LDAP', 'KERBEROS', 'NONE'): + if auth == 'KERBEROS': + # KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library + sasl_auth = 'GSSAPI' + else: + sasl_auth = 'PLAIN' + if password is None: + # Password doesn't matter in NONE mode, just needs to be nonempty. + password = 'x' + + def sasl_factory(): + sasl_client = sasl.Client() + sasl_client.setAttr('host', host) + if sasl_auth == 'GSSAPI': + sasl_client.setAttr('service', kerberos_service_name) + elif sasl_auth == 'PLAIN': + sasl_client.setAttr('username', username) + sasl_client.setAttr('password', password) + else: + raise AssertionError + sasl_client.init() + return sasl_client + self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) + else: + raise NotImplementedError( + "Only NONE, NOSASL, LDAP, KERBEROS " + "authentication are supported, got {}".format(auth)) protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport) self._client = TCLIService.Client(protocol) diff --git a/pyhive/presto.py b/pyhive/presto.py index 293fd066..2be9d313 100644 --- a/pyhive/presto.py +++ b/pyhive/presto.py @@ -7,6 +7,7 @@ from __future__ import absolute_import from __future__ import unicode_literals + from builtins import object from pyhive import common from pyhive.common import DBAPITypeObject @@ -79,7 +80,7 @@ class Cursor(common.DBAPICursor): def __init__(self, host, port='8080', username=None, catalog='hive', schema='default', poll_interval=1, source='pyhive', session_props=None, - protocol='http', password=None): + protocol='http', password=None, requests_session=None, requests_kwargs=None): """ :param host: hostname to connect to, e.g. ``presto.example.com`` :param port: int -- port, defaults to 8080 @@ -91,29 +92,45 @@ def __init__(self, host, port='8080', username=None, catalog='hive', :param source: string -- arbitrary identifier (shows up in the Presto monitoring page) :param protocol: string -- network protocol, valid options are ``http`` and ``https``. defaults to ``http`` - :param password: string -- defaults to ``None``, using BasicAuth, requires ``https`` + :param password: string -- Deprecated. Defaults to ``None``. + Using BasicAuth, requires ``https``. + Prefer ``requests_kwargs={'auth': HTTPBasicAuth(username, password)}``. + May not be specified with ``requests_kwargs``. + :param requests_session: a ``requests.Session`` object for advanced usage. If absent, this + class will use the default requests behavior of making a new session per HTTP request. + Caller is responsible for closing session. + :param requests_kwargs: Additional ``**kwargs`` to pass to requests """ super(Cursor, self).__init__(poll_interval) # Config self._host = host self._port = port self._username = username or getpass.getuser() - self._password = password self._catalog = catalog self._schema = schema self._arraysize = 1 self._poll_interval = poll_interval self._source = source self._session_props = session_props if session_props is not None else {} + if protocol not in ('http', 'https'): raise ValueError("Protocol must be http/https, was {!r}".format(protocol)) self._protocol = protocol - if password is None: - self._auth = None - else: - self._auth = HTTPBasicAuth(username, self._password) + + self._requests_session = requests_session or requests + + if password is not None and requests_kwargs is not None: + raise ValueError("Cannot use both password and requests_kwargs") + requests_kwargs = dict(requests_kwargs) if requests_kwargs is not None else {} + for k in ('method', 'url', 'data', 'headers'): + if k in requests_kwargs: + raise ValueError("Cannot override requests argument {}".format(k)) + if password is not None: + requests_kwargs['auth'] = HTTPBasicAuth(username, password) if protocol != 'https': raise ValueError("Protocol must be https when passing a password") + self._requests_kwargs = requests_kwargs + self._reset_state() def _reset_state(self): @@ -184,7 +201,8 @@ def execute(self, operation, parameters=None): '{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None)) _logger.info('%s', sql) _logger.debug("Headers: %s", headers) - response = requests.post(url, data=sql.encode('utf-8'), headers=headers, auth=self._auth) + response = self._requests_session.post( + url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs) self._process_response(response) def cancel(self): @@ -194,7 +212,7 @@ def cancel(self): assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None" return - response = requests.delete(self._nextUri, auth=self._auth) + response = self._requests_session.delete(self._nextUri, **self._requests_kwargs) if response.status_code != requests.codes.no_content: fmt = "Unexpected status code after cancel {}\n{}" raise OperationalError(fmt.format(response.status_code, response.content)) @@ -216,13 +234,13 @@ def poll(self): if self._nextUri is None: assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None" return None - response = requests.get(self._nextUri, auth=self._auth) + response = self._requests_session.get(self._nextUri, **self._requests_kwargs) self._process_response(response) return response.json() def _fetch_more(self): """Fetch the next URI and update state""" - self._process_response(requests.get(self._nextUri, auth=self._auth)) + self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs)) def _decode_binary(self, rows): # As of Presto 0.69, binary data is returned as the varbinary type in base64 format diff --git a/pyhive/tests/test_hive.py b/pyhive/tests/test_hive.py index 4cf09dd8..c766f4bf 100644 --- a/pyhive/tests/test_hive.py +++ b/pyhive/tests/test_hive.py @@ -14,10 +14,14 @@ import mock import os -from TCLIService import ttypes -from pyhive import hive +import sasl +import thrift.transport.TSocket +import thrift.transport.TTransport +import thrift_sasl from thrift.transport.TTransport import TTransportException +from TCLIService import ttypes +from pyhive import hive from pyhive.tests.dbapi_test_case import DBAPITestCase from pyhive.tests.dbapi_test_case import with_cursor @@ -185,3 +189,30 @@ def test_invalid_kerberos_config(self): lambda: hive.connect(_HOST, kerberos_service_name='')) self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS', lambda: hive.connect(_HOST, auth='KERBEROS')) + + def test_invalid_transport(self): + """transport and auth are incompatible""" + socket = thrift.transport.TSocket.TSocket('localhost', 10000) + transport = thrift.transport.TTransport.TBufferedTransport(socket) + self.assertRaisesRegexp( + ValueError, 'thrift_transport cannot be used with', + lambda: hive.connect(_HOST, thrift_transport=transport) + ) + + def test_custom_transport(self): + socket = thrift.transport.TSocket.TSocket('localhost', 10000) + sasl_auth = 'PLAIN' + + def sasl_factory(): + sasl_client = sasl.Client() + sasl_client.setAttr('host', 'localhost') + sasl_client.setAttr('username', 'test_username') + sasl_client.setAttr('password', 'x') + sasl_client.init() + return sasl_client + transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) + conn = hive.connect(thrift_transport=transport) + with contextlib.closing(conn): + with contextlib.closing(conn.cursor()) as cursor: + cursor.execute('SELECT * FROM one_row') + self.assertEqual(cursor.fetchall(), [(1,)]) diff --git a/pyhive/tests/test_presto.py b/pyhive/tests/test_presto.py index 78d3e30c..2f8594d7 100644 --- a/pyhive/tests/test_presto.py +++ b/pyhive/tests/test_presto.py @@ -9,6 +9,7 @@ import contextlib import os +import requests from pyhive import exc from pyhive import presto @@ -157,7 +158,7 @@ def test_set_session(self, cursor): session_prop = rows[0] assert session_prop[1] != '1234m' - def test_set_session_in_consructor(self): + def test_set_session_in_constructor(self): conn = presto.connect( host=_HOST, source=self.id(), session_props={'query_max_run_time': '1234m'} ) @@ -184,3 +185,38 @@ def test_invalid_protocol_config(self): ValueError, 'Protocol.*https.*password', lambda: presto.connect( host=_HOST, username='user', password='secret', protocol='http').cursor() ) + + def test_invalid_password_and_kwargs(self): + """password and requests_kwargs are incompatible""" + self.assertRaisesRegexp( + ValueError, 'Cannot use both', lambda: presto.connect( + host=_HOST, username='user', password='secret', protocol='https', + requests_kwargs={} + ).cursor() + ) + + def test_invalid_kwargs(self): + """some kwargs are reserved""" + self.assertRaisesRegexp( + ValueError, 'Cannot override', lambda: presto.connect( + host=_HOST, username='user', requests_kwargs={'url': 'test'} + ).cursor() + ) + + def test_requests_kwargs(self): + connection = presto.connect( + host=_HOST, port=_PORT, source=self.id(), + requests_kwargs={'proxies': {'http': 'localhost:99999'}}, + ) + cursor = connection.cursor() + self.assertRaises(requests.exceptions.ProxyError, + lambda: cursor.execute('SELECT * FROM one_row')) + + def test_requests_session(self): + with requests.Session() as session: + connection = presto.connect( + host=_HOST, port=_PORT, source=self.id(), requests_session=session + ) + cursor = connection.cursor() + cursor.execute('SELECT * FROM one_row') + self.assertEqual(cursor.fetchall(), [(1,)])