diff --git a/CHANGELOG.md b/CHANGELOG.md
index b3848309..aace28d3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,13 @@ release (0.9.0), unrecognized arguments/keywords for these methods of creating a
 instead of being passed as ClickHouse server settings. This is in conjunction with some refactoring in Client construction.
 The supported method of passing ClickHouse server settings is to prefix such arguments/query parameters with`ch_`.
 
+## 0.8.2, 2024-10-04
+### Bug Fix
+- Ensure lz4 compression does not exit on an empty block. May fix https://github.com/ClickHouse/clickhouse-connect/issues/403.
+
+### Improvement
+- Compress Arrow inserts (using pyarrow compression) if compression is set to `lz4` or `zstd`. Closes https://github.com/ClickHouse/clickhouse-connect/issues/267.
+
 ## 0.8.1, 2024-09-29
 ### Bug Fix
 - Fixed an edge case where the HTTP buffer could theoretically return empty blocks.
diff --git a/clickhouse_connect/__version__.py b/clickhouse_connect/__version__.py
index 398cfc4c..c400d681 100644
--- a/clickhouse_connect/__version__.py
+++ b/clickhouse_connect/__version__.py
@@ -1 +1 @@
-version = '0.8.1'
+version = '0.8.2'
diff --git a/clickhouse_connect/driver/client.py b/clickhouse_connect/driver/client.py
index fe11c278..d6b84885 100644
--- a/clickhouse_connect/driver/client.py
+++ b/clickhouse_connect/driver/client.py
@@ -59,9 +59,15 @@ def __init__(self,
         """
         self.query_limit = coerce_int(query_limit)
         self.query_retries = coerce_int(query_retries)
+        if database and not database == '__default__':
+            self.database = database
         if show_clickhouse_errors is not None:
             self.show_clickhouse_errors = coerce_bool(show_clickhouse_errors)
         self.server_host_name = server_host_name
+        self.uri = uri
+        self._init_common_settings(apply_server_timezone)
+
+    def _init_common_settings(self, apply_server_timezone:Optional[Union[str, bool]] ):
         self.server_tz, dst_safe = pytz.UTC, True
         self.server_version, server_tz = \
             tuple(self.command('SELECT version(), timezone()', use_database=False))
@@ -83,8 +89,7 @@ def __init__(self,
         readonly = common.get_setting('readonly')
         server_settings = self.query(f'SELECT name, value, {readonly} as readonly FROM system.settings LIMIT 10000')
         self.server_settings = {row['name']: SettingDef(**row) for row in server_settings.named_results()}
-        if database and not database == '__default__':
-            self.database = database
+
         if self.min_version(CH_VERSION_WITH_PROTOCOL):
             # Unfortunately we have to validate that the client protocol version is actually used by ClickHouse
             # since the query parameter could be stripped off (in particular, by CHProxy)
@@ -95,7 +100,9 @@ def __init__(self,
                 self.protocol_version = PROTOCOL_VERSION_WITH_LOW_CARD
         if self._setting_status('date_time_input_format').is_writable:
             self.set_client_setting('date_time_input_format', 'best_effort')
-        self.uri = uri
+        if self._setting_status('allow_experimental_json_type').is_set:
+            self.set_client_setting('cast_string_to_dynamic_use_inference', '1')
+
 
     def _validate_settings(self, settings: Optional[Dict[str, Any]]) -> Dict[str, str]:
         """
@@ -655,7 +662,8 @@ def insert_df(self, table: str = None,
                            settings=settings, context=context)
 
     def insert_arrow(self, table: str,
-                     arrow_table, database: str = None,
+                     arrow_table,
+                     database: str = None,
                      settings: Optional[Dict] = None) -> QuerySummary:
         """
         Insert a PyArrow table DataFrame into ClickHouse using raw Arrow format
@@ -666,7 +674,8 @@ def insert_arrow(self, table: str,
         :return: QuerySummary with summary information, throws exception if insert fails
         """
         full_table = table if '.' in table or not database else f'{database}.{table}'
-        column_names, insert_block = arrow_buffer(arrow_table)
+        compression = self.write_compression if self.write_compression in ('zstd', 'lz4') else None
+        column_names, insert_block = arrow_buffer(arrow_table, compression)
         return self.raw_insert(full_table, column_names, insert_block, settings, 'Arrow')
 
     def create_insert_context(self,
diff --git a/clickhouse_connect/driver/httputil.py b/clickhouse_connect/driver/httputil.py
index 58b5460a..558d66f6 100644
--- a/clickhouse_connect/driver/httputil.py
+++ b/clickhouse_connect/driver/httputil.py
@@ -244,7 +244,8 @@ def buffered():
                 else:
                     chunk = chunks.popleft()
                     current_size -= len(chunk)
-                yield chunk
+                if chunk:
+                    yield chunk
 
         self.gen = buffered()
 
diff --git a/clickhouse_connect/driver/query.py b/clickhouse_connect/driver/query.py
index 54edbeff..bd10270e 100644
--- a/clickhouse_connect/driver/query.py
+++ b/clickhouse_connect/driver/query.py
@@ -374,9 +374,12 @@ def to_arrow_batches(buffer: IOBase) -> StreamContext:
     return StreamContext(buffer, reader)
 
 
-def arrow_buffer(table) -> Tuple[Sequence[str], bytes]:
+def arrow_buffer(table, compression: Optional[str] = None) -> Tuple[Sequence[str], bytes]:
     pyarrow = check_arrow()
+    options = None
+    if compression in ('zstd', 'lz4'):
+        options = pyarrow.ipc.IpcWriteOptions(compression=pyarrow.Codec(compression=compression))
     sink = pyarrow.BufferOutputStream()
-    with pyarrow.RecordBatchFileWriter(sink, table.schema) as writer:
+    with pyarrow.RecordBatchFileWriter(sink, table.schema, options=options) as writer:
         writer.write(table)
     return table.schema.names, sink.getvalue()
diff --git a/tests/integration_tests/test_arrow.py b/tests/integration_tests/test_arrow.py
index c7f80ff5..aa681cbe 100644
--- a/tests/integration_tests/test_arrow.py
+++ b/tests/integration_tests/test_arrow.py
@@ -14,8 +14,8 @@ def test_arrow(test_client: Client, table_context: Callable):
     if not test_client.min_version('21'):
         pytest.skip(f'PyArrow is not supported in this server version {test_client.server_version}')
     with table_context('test_arrow_insert', ['animal String', 'legs Int64']):
-        n_legs = arrow.array([2, 4, 5, 100])
-        animals = arrow.array(['Flamingo', 'Horse', 'Brittle stars', 'Centipede'])
+        n_legs = arrow.array([2, 4, 5, 100] * 50)
+        animals = arrow.array(['Flamingo', 'Horse', 'Brittle stars', 'Centipede'] * 50)
         names = ['legs', 'animal']
         insert_table = arrow.Table.from_arrays([n_legs, animals], names=names)
         test_client.insert_arrow('test_arrow_insert', insert_table)
@@ -26,7 +26,7 @@ def test_arrow(test_client: Client, table_context: Callable):
     assert arrow_schema.field(1).name == 'legs'
     assert arrow_schema.field(1).type == arrow.int64()
     # pylint: disable=no-member
-    assert arrow.compute.sum(result_table['legs']).as_py() == 111
+    assert arrow.compute.sum(result_table['legs']).as_py() == 5550
    assert len(result_table.columns) == 2
 
     arrow_table = test_client.query_arrow('SELECT number from system.numbers LIMIT 500',
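
A minimal usage sketch of the compressed Arrow insert path changed above, assuming a reachable ClickHouse server and an existing `test_arrow_insert` table; the host and table names are placeholders, not part of the change. With the client-level `compress` setting as `lz4` or `zstd`, `insert_arrow` now passes the matching codec to `arrow_buffer`, so the Arrow IPC payload is compressed by pyarrow before upload.

```python
import pyarrow as pa
import clickhouse_connect

# A client compression of 'lz4' or 'zstd' is what insert_arrow now forwards
# to arrow_buffer (see the client.py and query.py hunks above).
client = clickhouse_connect.get_client(host='localhost', compress='lz4')

# Same shape as the integration test table: animal String, legs Int64
arrow_table = pa.table({'animal': ['Flamingo', 'Horse'], 'legs': [2, 4]})

# The Arrow IPC file is written with pyarrow's lz4 codec before it is sent
client.insert_arrow('test_arrow_insert', arrow_table)
```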