From f11fff6f90f2c36b28222e8d9a8fd15a47139c4f Mon Sep 17 00:00:00 2001 From: Sergei Sokolov Date: Mon, 27 May 2024 16:31:41 +0700 Subject: [PATCH] Added basic functionality for read and write to HUAWEI Object Storage Service (OBS) --- README.rst | 49 ++++ setup.py | 4 +- smart_open/obs.py | 382 ++++++++++++++++++++++++++++ smart_open/tests/test_obs.py | 153 +++++++++++ smart_open/tests/test_smart_open.py | 91 +++++-- smart_open/tests/test_utils.py | 9 + smart_open/transport.py | 1 + smart_open/utils.py | 16 ++ 8 files changed, 689 insertions(+), 16 deletions(-) create mode 100644 smart_open/obs.py create mode 100644 smart_open/tests/test_obs.py diff --git a/README.rst b/README.rst index c7060131..3efdd47d 100644 --- a/README.rst +++ b/README.rst @@ -93,6 +93,7 @@ Other examples of URLs that ``smart_open`` accepts:: s3://my_key:my_secret@my_server:my_port@my_bucket/my_key gs://my_bucket/my_blob azure://my_bucket/my_blob + obs://bucket_id.server:port/object_key hdfs:///path/file hdfs://path/file webhdfs://host:port/path/file @@ -290,6 +291,7 @@ Transport-specific Options - WebHDFS - GCS - Azure Blob Storage +- OBS (Huawei Object Storage) Each option involves setting up its own set of parameters. For example, for accessing S3, you often need to set up authentication, like API keys or a profile name. @@ -455,6 +457,53 @@ Additional keyword arguments can be propagated to the ``commit_block_list`` meth kwargs = {'metadata': {'version': 2}} fout = open('azure://container/key', 'wb', transport_params={'blob_kwargs': kwargs}) +OBS Credentials +--------------- +``smart_open`` uses the ``esdk-obs-python`` library to talk to OBS. +Please see `esdk-obs-python docs `__. + +There are several ways to provide Access key, Secret Key and Security Token +- Using env variables +- Using custom client params + +AK, SK, ST can be encrypted in this case You need install and configure `security provider `__. + + +OBS Advanced Usage +-------------------- +- Supported env variables: +OBS_ACCESS_KEY_ID, +OBS_SECRET_ACCESS_KEY, +OBS_SECURITY_TOKEN, +SMART_OPEN_OBS_USE_CLIENT_WRITE_MODE, +SMART_OPEN_OBS_DECRYPT_AK_SK, +SMART_OPEN_OBS_SCC_LIB_PATH, +SMART_OPEN_OBS_SCC_CONF_PATH + +- Configuration via code +.. code-block:: python + + client = {'access_key_id': 'ak', 'secret_access_key': 'sk', 'security_token': 'st', 'server': 'server_url'} + headers = [] + transport_params = { + >>> # client can be dict with parameters supported by the obs.ObsClient or instance of the obs.ObsClient + >>> 'client': client, + >>> # additional header for request, please see esdk-obs-python docs + >>> 'headers': headers, + >>> # if True obs.ObsClient will be take write method argument as readable object to get bytes. For writing mode only. + >>> # Please see docs for ObsClient.putContent api. + >>> 'use_obs_client_write_mode': True, + >>> # True if need decrypt Ak, Sk, St + >>> # It required to install CryptoAPI libs. + >>> # https://support.huawei.com/enterprise/en/software/260510077-ESW2000847337 + >>> 'decrypt_ak_sk' : True, + >>> # path to python libs of the Crypto provider + >>> 'scc_lib_path': '/usr/lib/scc', + >>> # path to config file of the Crypto provider + >>> 'scc_conf_path': '/home/user/scc.conf'} + + fout = open('obs://bucket_id.server:port/object_key', 'wb', transport_params=transport_params) + Drop-in replacement of ``pathlib.Path.open`` -------------------------------------------- diff --git a/setup.py b/setup.py index a9a4fc53..7e57b40a 100644 --- a/setup.py +++ b/setup.py @@ -42,8 +42,9 @@ def read(fname): http_deps = ['requests'] ssh_deps = ['paramiko'] zst_deps = ['zstandard'] +obs_deps = ['esdk-obs-python'] -all_deps = aws_deps + gcs_deps + azure_deps + http_deps + ssh_deps + zst_deps +all_deps = aws_deps + gcs_deps + azure_deps + http_deps + ssh_deps + zst_deps + obs_deps tests_require = all_deps + [ 'moto[server]', 'responses', @@ -83,6 +84,7 @@ def read(fname): 'webhdfs': http_deps, 'ssh': ssh_deps, 'zst': zst_deps, + 'obs': obs_deps, }, python_requires=">=3.7,<4.0", diff --git a/smart_open/obs.py b/smart_open/obs.py new file mode 100644 index 00000000..30288a9a --- /dev/null +++ b/smart_open/obs.py @@ -0,0 +1,382 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 Sergei Sokolov +# +# This code is distributed under the terms and conditions +# from the MIT License (MIT). +# +"""Implements file-like objects for reading and writing from/to HUAWEI Object Storage Service (OBS).""" +from __future__ import annotations + +import io +import logging +import os +import struct +import sys +from typing import Optional, Tuple, List + +from smart_open.utils import set_defaults + +try: + import obs.client + from obs.searchmethod import get_token + from obs import loadtoken +except ImportError: + MISSING_DEPS = True + +import smart_open.bytebuffer +import smart_open.utils + +from smart_open import constants + +logger = logging.getLogger(__name__) + +SCHEMES = ('obs',) + +URI_EXAMPLES = ( + 'obs://bucket_id.server:port/object_key', +) + +DEFAULT_CHUNK_SIZE = 65536 +DEFAULT_HTTP_PROTOCOL = 'https' +DEFAULT_SECURITY_PROVIDER_POLICY = 'ENV' + +ENV_VAR_USE_CLIENT_WRITE_MODE = 'SMART_OPEN_OBS_USE_CLIENT_WRITE_MODE' +ENV_VAR_DECRYPT_AK_SK = 'SMART_OPEN_OBS_DECRYPT_AK_SK' +ENV_VAR_SCC_LIB_PATH = 'SMART_OPEN_OBS_SCC_LIB_PATH' +ENV_VAR_SCC_CONF_PATH = 'SMART_OPEN_OBS_SCC_CONF_PATH' + +default_client_kwargs = { + 'security_provider_policy': DEFAULT_SECURITY_PROVIDER_POLICY, +} + + +def parse_uri(uri_as_string): + split_uri = smart_open.utils.safe_urlsplit(uri_as_string) + assert split_uri.scheme in SCHEMES + + bucket_id, server = split_uri.netloc.split('.', 1) + object_key = split_uri.path[1:] + + return dict( + scheme=split_uri.scheme, + bucket_id=bucket_id, + object_key=object_key, + server=server, + ) + + +def open_uri(uri, mode, transport_params): + parsed_uri = parse_uri(uri) + kwargs = _prepare_open_kwargs(parsed_uri=parsed_uri, + transport_params=transport_params) + return open(parsed_uri['bucket_id'], parsed_uri['object_key'], mode, **kwargs) + + +def _prepare_open_kwargs(parsed_uri: dict, transport_params: dict) -> dict: + kwargs = smart_open.utils.check_kwargs(open, transport_params) + + http_protocol = transport_params.get('http_protocol', DEFAULT_HTTP_PROTOCOL) + client_kwargs = { + 'server': f'{http_protocol}://{parsed_uri["server"]}', + } + client_kwargs.update(default_client_kwargs) + + kwargs['client'] = transport_params.get('client', client_kwargs) + + default_kwarg = { + 'use_obs_client_write_mode': + os.environ.get(ENV_VAR_USE_CLIENT_WRITE_MODE, 'false').lower() in ('true'), + 'decrypt_ak_sk': + os.environ.get(ENV_VAR_DECRYPT_AK_SK, 'false').lower() in ('true'), + 'scc_lib_path': + os.environ.get(ENV_VAR_SCC_LIB_PATH, None), + 'scc_conf_path': + os.environ.get(ENV_VAR_SCC_CONF_PATH, None), + } + + set_defaults(kwargs, default_kwarg) + + return kwargs + + +def open( + bucket_id, + object_key, + mode, + buffer_size=DEFAULT_CHUNK_SIZE, + client: Optional[obs.ObsClient | dict] = None, + headers: Optional[List[Tuple[obs.PutObjectHeader | obs.GetObjectHeader], str]] = None, + use_obs_client_write_mode: bool = False, + decrypt_ak_sk: bool = False, + scc_lib_path: Optional[str] = None, + scc_conf_path: Optional[str] = None): + """Open an OBS object for reading or writing. + + Parameters + ---------- + bucket_id: str + The name of the bucket this object resides in. + object_key: str + The name of the key within the bucket. + mode: str + The mode for opening the object. Must be either "rb" or "wb". + buffer_size: int + The buffer size to use when performing I/O. + client: Optional[obs.ObsClient | dict] + The initialized OBS client or dict with args that will be supplied to obs.ObsClient constructor. + Please see docs for esdk-obs-python. + headers: Optional[List[Tuple]] + The optional additional headers of the request. + Please see docs for esdk-obs-python. + use_obs_client_write_mode: bool + True if we will use readable object to get bytes. For writing mode only. + Please see docs for ObsClient.putContent api + decrypt_ak_sk: bool + True if we need decrypt Access key, Secret key and Security token. + It required to install CryptoAPI libs. + https://support.huawei.com/enterprise/en/software/260510077-ESW2000847337 + scc_lib_path: Optional[str] + The path to CryptoAPI libs. + scc_conf_path: Optional[str] + The path to scc.conf. + """ + + logger.debug('%r', locals()) + if mode not in constants.BINARY_MODES: + raise NotImplementedError('bad mode: %r expected one of %r' % (mode, constants.BINARY_MODES)) + + _client = client if isinstance(client, obs.ObsClient) else create_obs_client( + client_config=client, + decrypt_ak_sk=decrypt_ak_sk, + scc_lib_path=scc_lib_path, + scc_conf_path=scc_conf_path) + + if mode == constants.READ_BINARY: + fileobj = ObsReader(bucket_id=bucket_id, + object_key=object_key, + client=_client, + headers=headers) + elif mode == constants.WRITE_BINARY: + fileobj = ObsWriter(bucket_id=bucket_id, + object_key=object_key, + client=_client, + headers=headers, + use_obs_client_write_mode=use_obs_client_write_mode) + else: + assert False, 'unexpected mode: %r' % mode + return fileobj + + +def create_obs_client(client_config: dict, + decrypt_ak_sk: bool = False, + scc_lib_path: Optional[str] = None, + scc_conf_path: Optional[str] = None) -> obs.ObsClient: + """Initializes the ObsClient. + """ + if not decrypt_ak_sk: + return obs.ObsClient(**client_config) + + decrypted_config = _decrypt_ak_sk(client_config=client_config, + scc_lib_path=scc_lib_path, + scc_conf_path=scc_conf_path) + + set_defaults(decrypted_config, client_config) + return obs.ObsClient(**decrypted_config) + + +def _decrypt_ak_sk(client_config: dict, + scc_lib_path: Optional[str] = None, + scc_conf_path: Optional[str] = None) -> dict: + crypto_provider = CryptoProvider(scc_lib_path=scc_lib_path, + scc_conf_path=scc_conf_path) + + if 'access_key_id' in client_config: + access_key_id = client_config.get('access_key_id') + secret_access_key = client_config.get('secret_access_key') + security_token = client_config.get('security_token', None) + else: + tokens = get_token(security_providers=loadtoken.ENV) + access_key_id = tokens.get('accessKey') + secret_access_key = tokens.get('secretKey') + security_token = tokens.get('securityToken') + + return { + access_key_id: crypto_provider.decrypt(access_key_id), + secret_access_key: crypto_provider.decrypt(secret_access_key), + security_token: crypto_provider.decrypt(security_token), + } + + +class ObsReader(io.RawIOBase): + """Read an OBS Object. + """ + + def __init__(self, + bucket_id: str, + object_key: str, + client: obs.ObsClient, + headers: Optional[obs.GetObjectHeader] = None, + buffer_size: int = DEFAULT_CHUNK_SIZE): + self.name = object_key + self.bucket_id = bucket_id + self.object_key = object_key + self.buffer_size = buffer_size + self._client = client + self._buffer = smart_open.bytebuffer.ByteBuffer(buffer_size) + self._resp = self._client.getObject(bucketName=bucket_id, + objectKey=object_key, + headers=headers) + if self._resp.status >= 300: + raise RuntimeError( + f'Failed to read: {self.object_key}! ' + f'errorCode: {self._resp.errorCode}, ' + f'errorMessage: {self._resp.errorMessage}') + + def readinto(self, __buffer): + data = self.read(len(__buffer)) + if not data: + return 0 + __buffer[:len(data)] = data + return len(data) + + def readinto1(self, __buffer): + return self.readinto(__buffer) + + def read(self, size=-1): + if size == 0: + return b'' + + if self._resp is None: + raise RuntimeError(f'No response received while reading: {self.object_key}') + + if size > 0: + chunk = self._resp.body.response.read(size) + return chunk + else: + while True: + chunk = self._resp.body.response.read(self.buffer_size) + if not chunk: + break + self._buffer.fill(struct.unpack(str(len(chunk)) + 'c', chunk)) + return self._buffer.read() + + def read1(self, size=-1): + return self.read(size) + + def close(self): + self.__del__() + + def seekable(self): + return False + + def detach(self): + """Unsupported.""" + raise io.UnsupportedOperation + + def __del__(self): + try: + if self._client: + self._resp = None + self._client.close() + self._client = None + except Exception as ex: + logger.warning(ex) + + +class ObsWriter(io.RawIOBase): + """Write an OBS Object. + + If use_obs_client_write_mode set to False: + this class buffers all of its input in memory until its `close` method is called. + Only then the data will be written to OBS and the buffer is released. + + If use_obs_client_write_mode set to True: + `write` method of the ObsWriter will accept any readable object or path to file. + In this case will be used internal implementation in obs.ObsClient.putContent to read bytes + Write to OBS will be triggered in `close` method. + """ + + def __init__(self, + bucket_id: str, + object_key: str, + client: obs.ObsClient, + headers: Optional[obs.PutObjectHeader] = None, + use_obs_client_write_mode: bool = False + ): + self.name = object_key + self.bucket_id = bucket_id + self.object_key = object_key + self._client = client + self._headers = headers + self._content: Optional[str | io.BytesIO | io.BufferedReader] = None + self.use_obs_client_write_mode = use_obs_client_write_mode + + def write(self, __buffer): + if not __buffer: + return None + + if self.use_obs_client_write_mode: + self._content = __buffer + else: + if not self._content: + self._content = io.BytesIO() + self._content.write(__buffer) + return None + + def close(self): + if not self._content: + self._client.close() + return + + if isinstance(self._content, io.BytesIO): + self._content.seek(0) + + self._client.putContent(bucketName=self.bucket_id, + objectKey=self.object_key, + content=self._content, + headers=self._headers) + self._content = None + + def seekable(self): + return False + + def writable(self): + return self._content is not None + + def detach(self): + """Unsupported.""" + raise io.UnsupportedOperation + + +class CryptoProvider: + """Decrypt Access Key, Secret Key, Security Token. + + This class use Huawei CloudGuard CSP seccomponent to decrypt AK, SK and ST. + """ + + def __init__(self, scc_lib_path: Optional[str] = None, scc_conf_path: Optional[str] = None): + self._scc_lib_path = scc_lib_path + self._scc_conf_path = scc_conf_path + + if scc_lib_path and scc_lib_path not in sys.path: + sys.path.append(scc_lib_path) + + try: + from CryptoAPI import CryptoAPI + except ImportError: + raise RuntimeError('Failed to use CryptoAPI module. Please install CloudGuard CSP seccomponent.') + + self._api = CryptoAPI() + + if self._scc_conf_path: + self._api.initialize(self._scc_conf_path) + else: + self._api.initialize() + + def __del__(self): + if self._api: + self._api.finalize() + + def decrypt(self, encrypted: Optional[str]) -> Optional[str]: + return self._api.decrypt(encrypted) if encrypted else None diff --git a/smart_open/tests/test_obs.py b/smart_open/tests/test_obs.py new file mode 100644 index 00000000..e92c4dcd --- /dev/null +++ b/smart_open/tests/test_obs.py @@ -0,0 +1,153 @@ +import io +import os +import unittest +import uuid +from unittest.mock import patch + +import obs + +import smart_open +from smart_open.obs import ObsReader + +BUCKET_ID = 'test-smartopen-{}'.format(uuid.uuid4().hex) +OBJECT_KEY = 'hello.txt' + + +class ReadTest(unittest.TestCase): + + def setUp(self): + self.test_string = u'ветер по морю гуляет...' + + response_wrapper = obs.model.ResponseWrapper(conn=None, + connHolder=None, + result=io.BytesIO(self.test_string.encode('utf-8'))) + body = obs.model.ObjectStream(response=response_wrapper) + self.response = obs.model.GetResult(status=200, body=body) + + def test_read_never_returns_none(self): + with patch.object(obs.ObsClient, 'getObject', return_value=self.response): + reader = ObsReader(bucket_id=BUCKET_ID, object_key=OBJECT_KEY, + client=obs.ObsClient(server='server')) + + self.assertEqual(reader.read(), self.test_string.encode("utf-8")) + self.assertEqual(reader.read(), b'') + self.assertEqual(reader.read(), b'') + + +class WriteTest(unittest.TestCase): + + def setUp(self): + self.texst_text = 'ветер по морю гуляет...' + response_wrapper = obs.model.ResponseWrapper(conn=None, connHolder=None, result=io.BytesIO(b'ok')) + body = obs.model.ObjectStream(response=response_wrapper) + self.response = obs.model.GetResult(status=200, body=body) + + def test_write(self): + with patch.object(obs.ObsClient, 'putContent', return_value=self.response) as mock_method: + writer = smart_open.obs.ObsWriter(bucket_id=BUCKET_ID, + object_key=OBJECT_KEY, + client=obs.ObsClient(server='server'), + headers=obs.PutObjectHeader(contentType='text/plain')) + writer.write(u'ветер по морю '.encode('utf-8')) + writer.write(u'гуляет...'.encode('utf-8')) + writer.close() + + kwargs = mock_method.call_args.kwargs + self.assertEqual(kwargs['bucketName'], BUCKET_ID) + self.assertEqual(kwargs['objectKey'], OBJECT_KEY) + self.assertEqual(kwargs['headers']['contentType'], 'text/plain') + self.assertEqual(kwargs['content'].read(), self.texst_text.encode('utf-8')) + + def test_write_use_obs_client_write_mode(self): + test_bytes = io.BytesIO(self.texst_text.encode('utf-8')) + + with patch.object(obs.ObsClient, 'putContent', return_value=self.response) as mock_method: + writer = smart_open.obs.ObsWriter(bucket_id=BUCKET_ID, + object_key=OBJECT_KEY, + client=obs.ObsClient(server='server'), + headers=obs.PutObjectHeader(contentType='text/plain'), + use_obs_client_write_mode=True) + writer.write(test_bytes) + writer.close() + + kwargs = mock_method.call_args.kwargs + self.assertEqual(kwargs['bucketName'], BUCKET_ID) + self.assertEqual(kwargs['objectKey'], OBJECT_KEY) + self.assertEqual(kwargs['headers']['contentType'], 'text/plain') + self.assertEqual(kwargs['content'].read(), self.texst_text.encode('utf-8')) + self.assertEqual(id(kwargs['content']), id(test_bytes)) + + +class PrepareOpenKwargsTest(unittest.TestCase): + def setUp(self): + self.parsed_uri = dict( + scheme='obs', + bucket_id='bucket_id', + object_key='object_key', + server='server', + ) + + def tearDown(self): + if os.environ.get(smart_open.obs.ENV_VAR_USE_CLIENT_WRITE_MODE, None): + del os.environ[smart_open.obs.ENV_VAR_USE_CLIENT_WRITE_MODE] + if os.environ.get(smart_open.obs.ENV_VAR_DECRYPT_AK_SK, None): + del os.environ[smart_open.obs.ENV_VAR_DECRYPT_AK_SK] + if os.environ.get(smart_open.obs.ENV_VAR_SCC_LIB_PATH, None): + del os.environ[smart_open.obs.ENV_VAR_SCC_LIB_PATH] + if os.environ.get(smart_open.obs.ENV_VAR_SCC_CONF_PATH, None): + del os.environ[smart_open.obs.ENV_VAR_SCC_CONF_PATH] + + def test_prepare_open_kwargs_defaults(self): + transport_parpams = {} + actual = smart_open.obs._prepare_open_kwargs(parsed_uri=self.parsed_uri, + transport_params=transport_parpams) + + self.assertEqual(actual['decrypt_ak_sk'], False) + self.assertIsNone(actual['scc_lib_path']) + self.assertIsNone(actual['scc_conf_path']) + self.assertIsNotNone(actual.get('client', None)) + self.assertEqual(actual['client'].get('server', None), f'https://{self.parsed_uri["server"]}') + self.assertEqual(actual['client'].get('security_provider_policy', None), 'ENV') + self.assertFalse(actual.get('use_obs_client_write_mode')) + + def test_prepare_open_kwargs_override(self): + transport_parpams = { + 'decrypt_ak_sk': True, + 'scc_lib_path': 'scc_lib_path', + 'scc_conf_path': 'scc_conf_path', + 'use_obs_client_write_mode': True, + 'client': { + 'security_provider_policy': 'ECS', + 'server': 'https://server1' + } + } + + actual = smart_open.obs._prepare_open_kwargs(parsed_uri=self.parsed_uri, + transport_params=transport_parpams) + + self.assertEqual(actual['decrypt_ak_sk'], True) + self.assertEqual(actual['scc_lib_path'], 'scc_lib_path') + self.assertEqual(actual['scc_conf_path'], 'scc_conf_path') + self.assertIsNotNone(actual.get('client', None)) + self.assertEqual(actual['client'].get('server', None), f'https://server1') + self.assertEqual(actual['client'].get('security_provider_policy', None), 'ECS') + self.assertTrue(actual.get('use_obs_client_write_mode')) + + def test_prepare_open_kwargs_override_env(self): + os.environ[smart_open.obs.ENV_VAR_USE_CLIENT_WRITE_MODE] = 'True' + os.environ[smart_open.obs.ENV_VAR_DECRYPT_AK_SK] = 'True' + os.environ[smart_open.obs.ENV_VAR_SCC_LIB_PATH] = 'scc_lib_path' + os.environ[smart_open.obs.ENV_VAR_SCC_CONF_PATH] = 'scc_conf_path' + + transport_parpams = {} + + actual = smart_open.obs._prepare_open_kwargs(parsed_uri=self.parsed_uri, + transport_params=transport_parpams) + + self.assertEqual(actual['decrypt_ak_sk'], True) + self.assertEqual(actual['scc_lib_path'], 'scc_lib_path') + self.assertEqual(actual['scc_conf_path'], 'scc_conf_path') + self.assertIsNotNone(actual.get('client', None)) + self.assertEqual(actual['client'].get('server', None), f'https://{self.parsed_uri["server"]}') + self.assertEqual(actual['client'].get('security_provider_policy', None), 'ENV') + self.assertTrue(actual.get('use_obs_client_write_mode')) diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py index e31f48b7..4ef05a1e 100644 --- a/smart_open/tests/test_smart_open.py +++ b/smart_open/tests/test_smart_open.py @@ -7,24 +7,26 @@ # import bz2 -import csv import contextlib +import csv import functools -import io import gzip import hashlib +import io import logging import os -from smart_open.compression import INFER_FROM_EXTENSION, NO_COMPRESSION import tempfile import unittest -from unittest import mock import warnings +from unittest import mock +from unittest.mock import patch import boto3 import pytest import responses +from smart_open.compression import INFER_FROM_EXTENSION, NO_COMPRESSION + # See https://github.com/piskvorky/smart_open/issues/800 # This supports moto 4 & 5 until v4 is no longer used by distros. try: @@ -37,6 +39,7 @@ from smart_open import webhdfs from smart_open.smart_open_lib import patch_pathlib, _patch_pathlib from smart_open.tests.test_s3 import patch_invalid_range_response +import obs logger = logging.getLogger(__name__) @@ -106,6 +109,7 @@ class ParseUriTest(unittest.TestCase): Test ParseUri class. """ + def test_scheme(self): """Do URIs schemes parse correctly?""" # supported schemes @@ -417,6 +421,13 @@ def test_azure_blob_uri_contains_slash(self): self.assertEqual(parsed_uri.container_id, "mycontainer") self.assertEqual(parsed_uri.blob_id, "mydir/myblob") + def test_obs_parse_uri(self): + parsed_uri = smart_open_lib._parse_uri("obs://bucketid.server.com:123/folder/file.tgz") + self.assertEqual(parsed_uri.scheme, "obs") + self.assertEqual(parsed_uri.bucket_id, "bucketid") + self.assertEqual(parsed_uri.object_key, "folder/file.tgz") + self.assertEqual(parsed_uri.server, "server.com:123") + def test_pathlib_monkeypatch(self): from smart_open.smart_open_lib import pathlib @@ -458,6 +469,7 @@ class SmartOpenHttpTest(unittest.TestCase): Test reading from HTTP connections in various ways. """ + @mock.patch('smart_open.ssh.open', return_value=open(__file__)) def test_read_ssh(self, mock_open): """Is SSH line iterator called correctly?""" @@ -918,14 +930,14 @@ def test_file(self, mock_smart_open): prefix = "file://" full_path = '/tmp/test.txt' read_mode = "rb" - smart_open_object = smart_open.open(prefix+full_path, read_mode) + smart_open_object = smart_open.open(prefix + full_path, read_mode) smart_open_object.__iter__() # called with the correct path? mock_smart_open.assert_called_with(full_path, read_mode, buffering=-1) full_path = '/tmp/test#hash##more.txt' read_mode = "rb" - smart_open_object = smart_open.open(prefix+full_path, read_mode) + smart_open_object = smart_open.open(prefix + full_path, read_mode) smart_open_object.__iter__() # called with the correct path? mock_smart_open.assert_called_with(full_path, read_mode, buffering=-1) @@ -948,7 +960,7 @@ def test_file_errors(self, mock_smart_open): short_path = "~/tmp/test.txt" full_path = os.path.expanduser(short_path) - smart_open_object = smart_open.open(prefix+short_path, read_mode, errors='strict') + smart_open_object = smart_open.open(prefix + short_path, read_mode, errors='strict') smart_open_object.__iter__() # called with the correct expanded path? mock_smart_open.assert_called_with(full_path, read_mode, buffering=-1, errors='strict') @@ -1023,13 +1035,13 @@ def test_webhdfs_read(self): def test_s3_iter_moto(self): """Are S3 files iterated over correctly?""" # a list of strings to test with - expected = [b"*" * 5 * 1024**2] + [b'0123456789'] * 1024 + [b"test"] + expected = [b"*" * 5 * 1024 ** 2] + [b'0123456789'] * 1024 + [b"test"] # create fake bucket and fake key s3 = _resource('s3') s3.create_bucket(Bucket='mybucket') - tp = dict(s3_min_part_size=5 * 1024**2) + tp = dict(s3_min_part_size=5 * 1024 ** 2) with smart_open.open("s3://mybucket/mykey", "wb", transport_params=tp) as fout: # write a single huge line (=full multipart upload) fout.write(expected[0] + b'\n') @@ -1157,6 +1169,7 @@ class SmartOpenTest(unittest.TestCase): Test reading and writing from/into files. """ + def setUp(self): self.as_text = u'куда идём мы с пятачком - большой большой секрет' self.as_bytes = self.as_text.encode('utf-8') @@ -1780,6 +1793,54 @@ def test_respects_endpoint_url_write(self, mock_open): self.assertEqual(mock_open.call_args[1]['client_kwargs']['S3.Client'], expected) +class ObsOpenTest(unittest.TestCase): + + def setUp(self): + self.test_string = u'ветер по морю гуляет...' + self.bucket_id = 'bucketId' + self.object_key = 'objectKey' + + response_wrapper = obs.model.ResponseWrapper(conn=None, + connHolder=None, + result=io.BytesIO(self.test_string.encode('utf-8'))) + body = obs.model.ObjectStream(response=response_wrapper) + self.response = obs.model.GetResult(status=200, body=body) + + def test_open(self): + with patch.object(obs.ObsClient, 'getObject', return_value=self.response): + with smart_open.open(f'obs://{self.bucket_id}.server/{self.object_key}', + 'rb', transport_params=dict(client=obs.ObsClient(server='server'))) as file: + self.assertEqual(file.read(), self.test_string.encode("utf-8")) + self.assertEqual(file.read(), b'') + self.assertEqual(file.read(), b'') + + +class ObsWriteTest(unittest.TestCase): + def setUp(self): + self.texst_text = 'ветер по морю гуляет...' + self.bucket_id = 'bucketId' + self.object_key = 'objectKey' + + response_wrapper = obs.model.ResponseWrapper(conn=None, connHolder=None, result=io.BytesIO(b'ok')) + body = obs.model.ObjectStream(response=response_wrapper) + self.response = obs.model.GetResult(status=200, body=body) + + def test_write(self): + with patch.object(obs.ObsClient, 'putContent', return_value=self.response) as mock_method: + with smart_open.open(f'obs://{self.bucket_id}.server/{self.object_key}', + 'wb', transport_params=dict(client=obs.ObsClient(server='server'), + headers=obs.PutObjectHeader( + contentType='text/plain'))) as file: + file.write(u'ветер по морю '.encode('utf-8')) + file.write(u'гуляет...'.encode('utf-8')) + + kwargs = mock_method.call_args.kwargs + self.assertEqual(kwargs['bucketName'], self.bucket_id) + self.assertEqual(kwargs['objectKey'], self.object_key) + self.assertEqual(kwargs['headers']['contentType'], 'text/plain') + self.assertEqual(kwargs['content'].read(), self.texst_text.encode('utf-8')) + + def function(a, b, c, foo='bar', baz='boz'): pass @@ -1913,12 +1974,12 @@ def test_get_binary_mode(mode, expected): @pytest.mark.parametrize( 'mode', [ - ('rw', ), - ('rwa', ), - ('rbt', ), - ('r++', ), - ('+', ), - ('x', ), + ('rw',), + ('rwa',), + ('rbt',), + ('r++',), + ('+',), + ('x',), ] ) def test_get_binary_mode_bad(mode): diff --git a/smart_open/tests/test_utils.py b/smart_open/tests/test_utils.py index c6be9a2d..37610604 100644 --- a/smart_open/tests/test_utils.py +++ b/smart_open/tests/test_utils.py @@ -59,3 +59,12 @@ def test_check_kwargs(): def test_safe_urlsplit(url, expected): actual = smart_open.utils.safe_urlsplit(url) assert actual == urllib.parse.SplitResult(*expected) + + +def test_save_defaults(): + current = {'key1': 1, 'key2': 2} + defaults = {'key1': 1, 'key3': 3} + expected = {'key1': 1, 'key2': 2, 'key3': 3} + smart_open.utils.set_defaults(current, defaults) + assert len(expected) == len(current) + assert all((current.get(k) == v for k, v in expected.items())) diff --git a/smart_open/transport.py b/smart_open/transport.py index 086ea2b0..229db169 100644 --- a/smart_open/transport.py +++ b/smart_open/transport.py @@ -104,6 +104,7 @@ def get_transport(scheme): register_transport("smart_open.s3") register_transport("smart_open.ssh") register_transport("smart_open.webhdfs") +register_transport("smart_open.obs") SUPPORTED_SCHEMES = tuple(sorted(_REGISTRY.keys())) """The transport schemes that the local installation of ``smart_open`` supports.""" diff --git a/smart_open/utils.py b/smart_open/utils.py index 2be57d19..fc5279f0 100644 --- a/smart_open/utils.py +++ b/smart_open/utils.py @@ -221,3 +221,19 @@ def __exit__(self, *args, **kwargs): def __next__(self): return self.__wrapped__.__next__() + + +def set_defaults(first: dict, second: dict): + """Sets the values in the first dictionary from the second dictionary, + preserving the existing values in the first. + + Parameters + ---------- + first: dict + The dict that will be updated. + second: dict + The dict with default values. + """ + for key, val in second.items(): + if key not in first: + first[key] = val