Skip to content

Commit

Permalink
Arrow compression plus another http buffer tweak (#404)
Browse files Browse the repository at this point in the history
  • Loading branch information
genzgd authored Oct 4, 2024
1 parent 5d85563 commit 5b13e6b
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 12 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ release (0.9.0), unrecognized arguments/keywords for these methods of creating a
instead of being passed as ClickHouse server settings. This is in conjunction with some refactoring in Client construction.
The supported method of passing ClickHouse server settings is to prefix such arguments/query parameters with`ch_`.

## 0.8.2, 2024-10-04
### Bug Fix
- Ensure lz4 compression does not exit on an empty block. May fix https://github.com/ClickHouse/clickhouse-connect/issues/403.

### Improvement
- Compress Arrow inserts (using pyarrow compression) if compression is set to `lz4` or `zstd`. Closes https://github.com/ClickHouse/clickhouse-connect/issues/267.

## 0.8.1, 2024-09-29
### Bug Fix
- Fixed an edge case where the HTTP buffer could theoretically return empty blocks.
Expand Down
2 changes: 1 addition & 1 deletion clickhouse_connect/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = '0.8.1'
version = '0.8.2'
19 changes: 14 additions & 5 deletions clickhouse_connect/driver/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,15 @@ def __init__(self,
"""
self.query_limit = coerce_int(query_limit)
self.query_retries = coerce_int(query_retries)
if database and not database == '__default__':
self.database = database
if show_clickhouse_errors is not None:
self.show_clickhouse_errors = coerce_bool(show_clickhouse_errors)
self.server_host_name = server_host_name
self.uri = uri
self._init_common_settings(apply_server_timezone)

def _init_common_settings(self, apply_server_timezone:Optional[Union[str, bool]] ):
self.server_tz, dst_safe = pytz.UTC, True
self.server_version, server_tz = \
tuple(self.command('SELECT version(), timezone()', use_database=False))
Expand All @@ -83,8 +89,7 @@ def __init__(self,
readonly = common.get_setting('readonly')
server_settings = self.query(f'SELECT name, value, {readonly} as readonly FROM system.settings LIMIT 10000')
self.server_settings = {row['name']: SettingDef(**row) for row in server_settings.named_results()}
if database and not database == '__default__':
self.database = database

if self.min_version(CH_VERSION_WITH_PROTOCOL):
# Unfortunately we have to validate that the client protocol version is actually used by ClickHouse
# since the query parameter could be stripped off (in particular, by CHProxy)
Expand All @@ -95,7 +100,9 @@ def __init__(self,
self.protocol_version = PROTOCOL_VERSION_WITH_LOW_CARD
if self._setting_status('date_time_input_format').is_writable:
self.set_client_setting('date_time_input_format', 'best_effort')
self.uri = uri
if self._setting_status('allow_experimental_json_type').is_set:
self.set_client_setting('cast_string_to_dynamic_use_inference', '1')


def _validate_settings(self, settings: Optional[Dict[str, Any]]) -> Dict[str, str]:
"""
Expand Down Expand Up @@ -655,7 +662,8 @@ def insert_df(self, table: str = None,
settings=settings, context=context)

def insert_arrow(self, table: str,
arrow_table, database: str = None,
arrow_table,
database: str = None,
settings: Optional[Dict] = None) -> QuerySummary:
"""
Insert a PyArrow table DataFrame into ClickHouse using raw Arrow format
Expand All @@ -666,7 +674,8 @@ def insert_arrow(self, table: str,
:return: QuerySummary with summary information, throws exception if insert fails
"""
full_table = table if '.' in table or not database else f'{database}.{table}'
column_names, insert_block = arrow_buffer(arrow_table)
compression = self.write_compression if self.write_compression in ('zstd', 'lz4') else None
column_names, insert_block = arrow_buffer(arrow_table, compression)
return self.raw_insert(full_table, column_names, insert_block, settings, 'Arrow')

def create_insert_context(self,
Expand Down
3 changes: 2 additions & 1 deletion clickhouse_connect/driver/httputil.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ def buffered():
else:
chunk = chunks.popleft()
current_size -= len(chunk)
yield chunk
if chunk:
yield chunk

self.gen = buffered()

Expand Down
7 changes: 5 additions & 2 deletions clickhouse_connect/driver/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,12 @@ def to_arrow_batches(buffer: IOBase) -> StreamContext:
return StreamContext(buffer, reader)


def arrow_buffer(table) -> Tuple[Sequence[str], bytes]:
def arrow_buffer(table, compression: Optional[str] = None) -> Tuple[Sequence[str], bytes]:
pyarrow = check_arrow()
options = None
if compression in ('zstd', 'lz4'):
options = pyarrow.ipc.IpcWriteOptions(compression=pyarrow.Codec(compression=compression))
sink = pyarrow.BufferOutputStream()
with pyarrow.RecordBatchFileWriter(sink, table.schema) as writer:
with pyarrow.RecordBatchFileWriter(sink, table.schema, options=options) as writer:
writer.write(table)
return table.schema.names, sink.getvalue()
6 changes: 3 additions & 3 deletions tests/integration_tests/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def test_arrow(test_client: Client, table_context: Callable):
if not test_client.min_version('21'):
pytest.skip(f'PyArrow is not supported in this server version {test_client.server_version}')
with table_context('test_arrow_insert', ['animal String', 'legs Int64']):
n_legs = arrow.array([2, 4, 5, 100])
animals = arrow.array(['Flamingo', 'Horse', 'Brittle stars', 'Centipede'])
n_legs = arrow.array([2, 4, 5, 100] * 50)
animals = arrow.array(['Flamingo', 'Horse', 'Brittle stars', 'Centipede'] * 50)
names = ['legs', 'animal']
insert_table = arrow.Table.from_arrays([n_legs, animals], names=names)
test_client.insert_arrow('test_arrow_insert', insert_table)
Expand All @@ -26,7 +26,7 @@ def test_arrow(test_client: Client, table_context: Callable):
assert arrow_schema.field(1).name == 'legs'
assert arrow_schema.field(1).type == arrow.int64()
# pylint: disable=no-member
assert arrow.compute.sum(result_table['legs']).as_py() == 111
assert arrow.compute.sum(result_table['legs']).as_py() == 5550
assert len(result_table.columns) == 2

arrow_table = test_client.query_arrow('SELECT number from system.numbers LIMIT 500',
Expand Down

0 comments on commit 5b13e6b

Please sign in to comment.