diff --git a/CHANGELOG.md b/CHANGELOG.md index 00a63df..0537a5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## Unreleased +**New features:** +* Support SSL encrypted connection to Tarantool EE (closes [#22](https://github.com/igorcoding/asynctnt/issues/22)) + ## v2.0.1 * Fixed an issue with encoding datetimes less than 01-01-1970 (fixes [#29](https://github.com/igorcoding/asynctnt/issues/29)) * Fixed "Edit on Github" links in docs (fixes [#26](https://github.com/igorcoding/asynctnt/issues/26)) diff --git a/asynctnt/__init__.py b/asynctnt/__init__.py index d0aad74..e366826 100644 --- a/asynctnt/__init__.py +++ b/asynctnt/__init__.py @@ -1,3 +1,4 @@ +from .const import Transport from .connection import Connection, connect from .iproto.protocol import ( Iterator, Response, TarantoolTuple, PushIterator, diff --git a/asynctnt/connection.py b/asynctnt/connection.py index 5ad289d..145d82e 100644 --- a/asynctnt/connection.py +++ b/asynctnt/connection.py @@ -1,16 +1,18 @@ import asyncio import enum import functools +import ssl import os from typing import Optional, Union from .api import Api +from .const import Transport from .exceptions import TarantoolDatabaseError, \ - ErrorCode, TarantoolError + ErrorCode, TarantoolError, SSLError from .iproto import protocol from .log import logger from .stream import Stream -from .utils import get_running_loop +from .utils import get_running_loop, PY_37 __all__ = ( 'Connection', 'connect', 'ConnectionState' @@ -27,11 +29,13 @@ class ConnectionState(enum.IntEnum): class Connection(Api): __slots__ = ( - '_host', '_port', '_username', '_password', - '_fetch_schema', '_auto_refetch_schema', '_initial_read_buffer_size', - '_encoding', '_connect_timeout', '_reconnect_timeout', - '_request_timeout', '_ping_timeout', '_loop', '_state', '_state_prev', - '_transport', '_protocol', + '_host', '_port', '_parameter_transport', '_ssl_key_file', + '_ssl_cert_file', '_ssl_ca_file', '_ssl_ciphers', + '_username', '_password', '_fetch_schema', + '_auto_refetch_schema', '_initial_read_buffer_size', + '_encoding', '_connect_timeout', '_ssl_handshake_timeout', + '_reconnect_timeout', '_request_timeout', '_ping_timeout', + '_loop', '_state', '_state_prev', '_transport', '_protocol', '_disconnect_waiter', '_reconnect_task', '_connect_lock', '_disconnect_lock', '_ping_task', '__create_task' @@ -40,11 +44,17 @@ class Connection(Api): def __init__(self, *, host: str = '127.0.0.1', port: Union[int, str] = 3301, + transport: Optional[Transport] = Transport.DEFAULT, + ssl_key_file: Optional[str] = None, + ssl_cert_file: Optional[str] = None, + ssl_ca_file: Optional[str] = None, + ssl_ciphers: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None, fetch_schema: bool = True, auto_refetch_schema: bool = True, connect_timeout: float = 3., + ssl_handshake_timeout: float = 3., request_timeout: float = -1., reconnect_timeout: float = 1. / 3., ping_timeout: float = 5., @@ -78,6 +88,22 @@ def __init__(self, *, :param port: Tarantool port (pass ``/path/to/sockfile`` to connect ot unix socket) + :param transport: + This parameter can be used to configure traffic encryption. + Pass ``asynctnt.Transport.SSL`` value to enable SSL + encryption (by default there is no encryption) + :param ssl_key_file: + A path to a private SSL key file. + Optional, mandatory if server uses CA file + :param ssl_cert_file: + A path to an SSL certificate file. + Optional, mandatory if server uses CA file + :param ssl_ca_file: + A path to a trusted certificate authorities (CA) file. + Optional + :param ssl_ciphers: + A colon-separated (:) list of SSL cipher suites + the connection can use. Optional :param username: Username to use for auth (if ``None`` you are connected as a guest) @@ -93,6 +119,10 @@ def __init__(self, *, be checked by Tarantool, so no errors will occur :param connect_timeout: Time in seconds how long to wait for connecting to socket + :param ssl_handshake_timeout: + Time in seconds to wait for the TLS handshake to complete + before aborting the connection (used only for a TLS + connection). Supported for Python 3.7 or newer :param request_timeout: Request timeout (in seconds) for all requests (by default there is no timeout) @@ -116,6 +146,13 @@ def __init__(self, *, super().__init__() self._host = host self._port = port + + self._parameter_transport = transport + self._ssl_key_file = ssl_key_file + self._ssl_cert_file = ssl_cert_file + self._ssl_ca_file = ssl_ca_file + self._ssl_ciphers = ssl_ciphers + self._username = username self._password = password self._fetch_schema = False if fetch_schema is None else fetch_schema @@ -131,6 +168,7 @@ def __init__(self, *, self._encoding = encoding or 'utf-8' self._connect_timeout = connect_timeout + self._ssl_handshake_timeout = ssl_handshake_timeout self._reconnect_timeout = reconnect_timeout or 0 self._request_timeout = request_timeout self._ping_timeout = ping_timeout or 0 @@ -220,6 +258,54 @@ def protocol_factory(self, on_connection_lost=self.connection_lost, loop=self._loop) + def _create_ssl_context(self): + try: + if hasattr(ssl, 'TLSVersion'): + # Since python 3.7 + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + # Reset to default OpenSSL values. + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + # Require TLSv1.2, because other protocol versions don't seem + # to support the GOST cipher. + context.minimum_version = ssl.TLSVersion.TLSv1_2 + context.maximum_version = ssl.TLSVersion.TLSv1_2 + else: + # Deprecated, but it works for python < 3.7 + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + + if self._ssl_cert_file: + # If the password argument is not specified and a password is + # required, OpenSSL’s built-in password prompting mechanism + # will be used to interactively prompt the user for a password. + # + # We should disable this behaviour, because a python + # application that uses the connector unlikely assumes + # interaction with a human + a Tarantool implementation does + # not support this at least for now. + def password_raise_error(): + raise SSLError("a password for decrypting the private " + + "key is unsupported") + context.load_cert_chain(certfile=self._ssl_cert_file, + keyfile=self._ssl_key_file, + password=password_raise_error) + + if self._ssl_ca_file: + context.load_verify_locations(cafile=self._ssl_ca_file) + context.verify_mode = ssl.CERT_REQUIRED + # A Tarantool implementation does not check hostname. We don't + # do that too. As a result we don't set here: + # context.check_hostname = True + + if self._ssl_ciphers: + context.set_ciphers(self._ssl_ciphers) + + return context + except SSLError as e: + raise + except Exception as e: + raise SSLError(e) + async def _connect(self, return_exceptions: bool = True): if self._loop is None: self._loop = get_running_loop() @@ -246,6 +332,12 @@ async def full_connect(): while True: connected_fut = _create_future(self._loop) + ssl_context = None + ssl_handshake_timeout = None + if self._parameter_transport == Transport.SSL: + ssl_context = self._create_ssl_context() + ssl_handshake_timeout = self._ssl_handshake_timeout + if self._host.startswith('unix/'): unix_path = self._port assert isinstance(unix_path, str), \ @@ -257,16 +349,34 @@ async def full_connect(): 'Unix socket `{}` not found'.format( unix_path) - conn = self._loop.create_unix_connection( - functools.partial(self.protocol_factory, - connected_fut), - unix_path - ) + if PY_37: + conn = self._loop.create_unix_connection( + functools.partial(self.protocol_factory, + connected_fut), + unix_path, + ssl=ssl_context, + ssl_handshake_timeout=ssl_handshake_timeout) + else: + conn = self._loop.create_unix_connection( + functools.partial(self.protocol_factory, + connected_fut), + unix_path, + ssl=ssl_context) + else: - conn = self._loop.create_connection( - functools.partial(self.protocol_factory, - connected_fut), - self._host, self._port) + if PY_37: + conn = self._loop.create_connection( + functools.partial(self.protocol_factory, + connected_fut), + self._host, self._port, + ssl=ssl_context, + ssl_handshake_timeout=ssl_handshake_timeout) + else: + conn = self._loop.create_connection( + functools.partial(self.protocol_factory, + connected_fut), + self._host, self._port, + ssl=ssl_context) tr, pr = await conn @@ -337,6 +447,8 @@ async def full_connect(): if return_exceptions: self._reconnect_task = None + if isinstance(e, ssl.SSLError): + e = SSLError(e) raise e logger.exception(e) diff --git a/asynctnt/const.py b/asynctnt/const.py new file mode 100644 index 0000000..ae597b6 --- /dev/null +++ b/asynctnt/const.py @@ -0,0 +1,5 @@ +import enum + +class Transport(enum.IntEnum): + DEFAULT = 1 + SSL = 2 diff --git a/asynctnt/exceptions.py b/asynctnt/exceptions.py index 00fa09a..ceead6a 100644 --- a/asynctnt/exceptions.py +++ b/asynctnt/exceptions.py @@ -42,6 +42,12 @@ class TarantoolNotConnectedError(TarantoolNetworkError): """ pass +class SSLError(TarantoolError): + """ + Raised when something is wrong with encrypted connection + """ + pass + class ErrorCode(enum.IntEnum): """ diff --git a/docs/examples.md b/docs/examples.md index 480861f..e1e51d7 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -65,3 +65,32 @@ async def main(): asyncio.run(main()) ``` + +## Connect with SSL encryption +```python +import asyncio +import asynctnt + + +async def main(): + conn = asynctnt.Connection(host='127.0.0.1', + port=3301, + transport=asynctnt.Transport.SSL, + ssl_key_file='./ssl/host.key', + ssl_cert_file='./ssl/host.crt', + ssl_ca_file='./ssl/ca.crt', + ssl_ciphers='ECDHE-RSA-AES256-GCM-SHA384') + await conn.connect() + + resp = await conn.ping() + print(resp) + + await conn.disconnect() + +asyncio.run(main()) +``` + +Stdout: +``` + +```