Skip to content

Commit

Permalink
Add options to customize Thrift transport and requests kwargs (#140)
Browse files Browse the repository at this point in the history
Closes #122, closes #132, closes #135
  • Loading branch information
jingw committed Jul 17, 2017
1 parent 2b0fbd6 commit a064a49
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 47 deletions.
90 changes: 57 additions & 33 deletions pyhive/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,25 @@ def connect(*args, **kwargs):
class Connection(object):
"""Wraps a Thrift session"""

def __init__(self, host, port=10000, username=None, database='default', auth='NONE',
configuration=None, kerberos_service_name=None, password=None):
def __init__(self, host=None, port=None, username=None, database='default', auth=None,
configuration=None, kerberos_service_name=None, password=None,
thrift_transport=None):
"""Connect to HiveServer2
:param auth: The value of hive.server2.authentication used by HiveServer2
:param host: What host HiveServer2 runs on
:param port: What port HiveServer2 runs on. Defaults to 10000.
:param auth: The value of hive.server2.authentication used by HiveServer2.
Defaults to ``NONE``.
:param configuration: A dictionary of Hive settings (functionally same as the `set` command)
:param kerberos_service_name: Use with auth='KERBEROS' only
:param password: Use with auth='LDAP' only
:param thrift_transport: A ``TTransportBase`` for custom advanced usage.
Incompatible with host, port, auth, kerberos_service_name, and password.
The way to support LDAP and GSSAPI is originated from cloudera/Impyla:
https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62
/impala/_thrift_api.py#L152-L160
"""
socket = thrift.transport.TSocket.TSocket(host, port)
username = username or getpass.getuser()
configuration = configuration or {}

Expand All @@ -90,37 +95,56 @@ def __init__(self, host, port=10000, username=None, database='default', auth='NO
"Remove password or add auth='LDAP'")
if (kerberos_service_name is not None) != (auth == 'KERBEROS'):
raise ValueError("kerberos_service_name should be set if and only if in KERBEROS mode")
if thrift_transport is not None:
has_incompatible_arg = (
host is not None
or port is not None
or auth is not None
or kerberos_service_name is not None
or password is not None
)
if has_incompatible_arg:
raise ValueError("thrift_transport cannot be used with "
"host/port/auth/kerberos_service_name/password")

if auth == 'NOSASL':
# NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml
self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
elif auth in ('LDAP', 'KERBEROS', 'NONE'):
if auth == 'KERBEROS':
# KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library
sasl_auth = 'GSSAPI'
else:
sasl_auth = 'PLAIN'
if password is None:
# Password doesn't matter in NONE mode, just needs to be nonempty.
password = 'x'

def sasl_factory():
sasl_client = sasl.Client()
sasl_client.setAttr('host', host)
if sasl_auth == 'GSSAPI':
sasl_client.setAttr('service', kerberos_service_name)
elif sasl_auth == 'PLAIN':
sasl_client.setAttr('username', username)
sasl_client.setAttr('password', password)
else:
raise AssertionError
sasl_client.init()
return sasl_client
self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
if thrift_transport is not None:
self._transport = thrift_transport
else:
raise NotImplementedError(
"Only NONE, NOSASL, LDAP, KERBEROS "
"authentication are supported, got {}".format(auth))
if port is None:
port = 10000
if auth is None:
auth = 'NONE'
socket = thrift.transport.TSocket.TSocket(host, port)
if auth == 'NOSASL':
# NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml
self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
elif auth in ('LDAP', 'KERBEROS', 'NONE'):
if auth == 'KERBEROS':
# KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library
sasl_auth = 'GSSAPI'
else:
sasl_auth = 'PLAIN'
if password is None:
# Password doesn't matter in NONE mode, just needs to be nonempty.
password = 'x'

def sasl_factory():
sasl_client = sasl.Client()
sasl_client.setAttr('host', host)
if sasl_auth == 'GSSAPI':
sasl_client.setAttr('service', kerberos_service_name)
elif sasl_auth == 'PLAIN':
sasl_client.setAttr('username', username)
sasl_client.setAttr('password', password)
else:
raise AssertionError
sasl_client.init()
return sasl_client
self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
else:
raise NotImplementedError(
"Only NONE, NOSASL, LDAP, KERBEROS "
"authentication are supported, got {}".format(auth))

protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
self._client = TCLIService.Client(protocol)
Expand Down
40 changes: 29 additions & 11 deletions pyhive/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import absolute_import
from __future__ import unicode_literals

from builtins import object
from pyhive import common
from pyhive.common import DBAPITypeObject
Expand Down Expand Up @@ -79,7 +80,7 @@ class Cursor(common.DBAPICursor):

def __init__(self, host, port='8080', username=None, catalog='hive',
schema='default', poll_interval=1, source='pyhive', session_props=None,
protocol='http', password=None):
protocol='http', password=None, requests_session=None, requests_kwargs=None):
"""
:param host: hostname to connect to, e.g. ``presto.example.com``
:param port: int -- port, defaults to 8080
Expand All @@ -91,29 +92,45 @@ def __init__(self, host, port='8080', username=None, catalog='hive',
:param source: string -- arbitrary identifier (shows up in the Presto monitoring page)
:param protocol: string -- network protocol, valid options are ``http`` and ``https``.
defaults to ``http``
:param password: string -- defaults to ``None``, using BasicAuth, requires ``https``
:param password: string -- Deprecated. Defaults to ``None``.
Using BasicAuth, requires ``https``.
Prefer ``requests_kwargs={'auth': HTTPBasicAuth(username, password)}``.
May not be specified with ``requests_kwargs``.
:param requests_session: a ``requests.Session`` object for advanced usage. If absent, this
class will use the default requests behavior of making a new session per HTTP request.
Caller is responsible for closing session.
:param requests_kwargs: Additional ``**kwargs`` to pass to requests
"""
super(Cursor, self).__init__(poll_interval)
# Config
self._host = host
self._port = port
self._username = username or getpass.getuser()
self._password = password
self._catalog = catalog
self._schema = schema
self._arraysize = 1
self._poll_interval = poll_interval
self._source = source
self._session_props = session_props if session_props is not None else {}

if protocol not in ('http', 'https'):
raise ValueError("Protocol must be http/https, was {!r}".format(protocol))
self._protocol = protocol
if password is None:
self._auth = None
else:
self._auth = HTTPBasicAuth(username, self._password)

self._requests_session = requests_session or requests

if password is not None and requests_kwargs is not None:
raise ValueError("Cannot use both password and requests_kwargs")
requests_kwargs = dict(requests_kwargs) if requests_kwargs is not None else {}
for k in ('method', 'url', 'data', 'headers'):
if k in requests_kwargs:
raise ValueError("Cannot override requests argument {}".format(k))
if password is not None:
requests_kwargs['auth'] = HTTPBasicAuth(username, password)
if protocol != 'https':
raise ValueError("Protocol must be https when passing a password")
self._requests_kwargs = requests_kwargs

self._reset_state()

def _reset_state(self):
Expand Down Expand Up @@ -184,7 +201,8 @@ def execute(self, operation, parameters=None):
'{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None))
_logger.info('%s', sql)
_logger.debug("Headers: %s", headers)
response = requests.post(url, data=sql.encode('utf-8'), headers=headers, auth=self._auth)
response = self._requests_session.post(
url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs)
self._process_response(response)

def cancel(self):
Expand All @@ -194,7 +212,7 @@ def cancel(self):
assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None"
return

response = requests.delete(self._nextUri, auth=self._auth)
response = self._requests_session.delete(self._nextUri, **self._requests_kwargs)
if response.status_code != requests.codes.no_content:
fmt = "Unexpected status code after cancel {}\n{}"
raise OperationalError(fmt.format(response.status_code, response.content))
Expand All @@ -216,13 +234,13 @@ def poll(self):
if self._nextUri is None:
assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None"
return None
response = requests.get(self._nextUri, auth=self._auth)
response = self._requests_session.get(self._nextUri, **self._requests_kwargs)
self._process_response(response)
return response.json()

def _fetch_more(self):
"""Fetch the next URI and update state"""
self._process_response(requests.get(self._nextUri, auth=self._auth))
self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs))

def _decode_binary(self, rows):
# As of Presto 0.69, binary data is returned as the varbinary type in base64 format
Expand Down
35 changes: 33 additions & 2 deletions pyhive/tests/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@

import mock
import os
from TCLIService import ttypes
from pyhive import hive
import sasl
import thrift.transport.TSocket
import thrift.transport.TTransport
import thrift_sasl
from thrift.transport.TTransport import TTransportException

from TCLIService import ttypes
from pyhive import hive
from pyhive.tests.dbapi_test_case import DBAPITestCase
from pyhive.tests.dbapi_test_case import with_cursor

Expand Down Expand Up @@ -185,3 +189,30 @@ def test_invalid_kerberos_config(self):
lambda: hive.connect(_HOST, kerberos_service_name=''))
self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS',
lambda: hive.connect(_HOST, auth='KERBEROS'))

def test_invalid_transport(self):
    """A positional host argument combined with ``thrift_transport`` is rejected"""
    sock = thrift.transport.TSocket.TSocket('localhost', 10000)
    buffered = thrift.transport.TTransport.TBufferedTransport(sock)
    # Context-manager form of the same assertion: the connect call must raise.
    with self.assertRaisesRegexp(ValueError, 'thrift_transport cannot be used with'):
        hive.connect(_HOST, thrift_transport=buffered)

def test_custom_transport(self):
    """Query through a caller-built SASL PLAIN transport given as ``thrift_transport``"""
    def build_sasl_client():
        # Same three attributes, set in the same order, before init().
        client = sasl.Client()
        for attr, value in (
            ('host', 'localhost'),
            ('username', 'test_username'),
            ('password', 'x'),
        ):
            client.setAttr(attr, value)
        client.init()
        return client

    mechanism = 'PLAIN'
    sock = thrift.transport.TSocket.TSocket('localhost', 10000)
    transport = thrift_sasl.TSaslClientTransport(build_sasl_client, mechanism, sock)
    conn = hive.connect(thrift_transport=transport)
    with contextlib.closing(conn):
        with contextlib.closing(conn.cursor()) as cursor:
            cursor.execute('SELECT * FROM one_row')
            rows = cursor.fetchall()
            self.assertEqual(rows, [(1,)])
38 changes: 37 additions & 1 deletion pyhive/tests/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import contextlib
import os
import requests

from pyhive import exc
from pyhive import presto
Expand Down Expand Up @@ -157,7 +158,7 @@ def test_set_session(self, cursor):
session_prop = rows[0]
assert session_prop[1] != '1234m'

def test_set_session_in_consructor(self):
def test_set_session_in_constructor(self):
conn = presto.connect(
host=_HOST, source=self.id(), session_props={'query_max_run_time': '1234m'}
)
Expand All @@ -184,3 +185,38 @@ def test_invalid_protocol_config(self):
ValueError, 'Protocol.*https.*password', lambda: presto.connect(
host=_HOST, username='user', password='secret', protocol='http').cursor()
)

def test_invalid_password_and_kwargs(self):
    """Giving both ``password`` and ``requests_kwargs`` raises ``ValueError``"""
    connect_args = dict(
        host=_HOST, username='user', password='secret', protocol='https',
        requests_kwargs={},
    )
    # Even an empty requests_kwargs dict counts as "specified".
    with self.assertRaisesRegexp(ValueError, 'Cannot use both'):
        presto.connect(**connect_args).cursor()

def test_invalid_kwargs(self):
    """A ``url`` key in ``requests_kwargs`` is refused with ``ValueError``"""
    reserved = {'url': 'test'}
    with self.assertRaisesRegexp(ValueError, 'Cannot override'):
        presto.connect(host=_HOST, username='user', requests_kwargs=reserved).cursor()

def test_requests_kwargs(self):
    """With a ``proxies`` entry in ``requests_kwargs``, ``execute`` raises ``ProxyError``"""
    bad_proxies = {'http': 'localhost:99999'}
    cursor = presto.connect(
        host=_HOST, port=_PORT, source=self.id(),
        requests_kwargs={'proxies': bad_proxies},
    ).cursor()
    with self.assertRaises(requests.exceptions.ProxyError):
        cursor.execute('SELECT * FROM one_row')

def test_requests_session(self):
    """A caller-supplied ``requests.Session`` can be used to run a query"""
    with requests.Session() as session:
        cursor = presto.connect(
            host=_HOST, port=_PORT, source=self.id(), requests_session=session
        ).cursor()
        cursor.execute('SELECT * FROM one_row')
        rows = cursor.fetchall()
        self.assertEqual(rows, [(1,)])

0 comments on commit a064a49

Please sign in to comment.