Pyarrow serialize (#121)
* adding arrow format option

* adding arrow format option

* adding compression

* removing 64bit

* modifying compression

* making compression optional

* adding zstd option and fixing uncompressed

* fixing headers

* formatting

* fixing uncompressed

* fixing headers

* fixing bug in compression

* adding gzip compression
fcollman authored Aug 10, 2023
1 parent 17c60d9 commit 887da5b
Showing 3 changed files with 49 additions and 4 deletions.
materializationengine/blueprints/client/api2.py (15 additions, 1 deletion)
@@ -84,6 +84,14 @@
         "(faster), false returns json"
     ),
 )
+query_parser.add_argument(
+    "arrow_format",
+    type=inputs.boolean,
+    default=False,
+    required=False,
+    location="args",
+    help=("whether to convert dataframe to pyarrow ipc batch format"),
+)
 query_parser.add_argument(
     "split_positions",
     type=inputs.boolean,
@@ -121,6 +129,7 @@
     not be relevant and the user might not be getting data back that they expect, but it will not error.",
 )
 
+
 @cached(cache=TTLCache(maxsize=64, ttl=600))
 def get_relevant_datastack_info(datastack_name):
     ds_info = get_datastack_info(datastack_name=datastack_name)
@@ -1004,7 +1013,10 @@ def post(self, datastack_name: str):
         user_data["desired_resolution"] = des_res
 
         modified_user_data, query_map, remap_warnings = remap_query(
-            user_data, chosen_timestamp, cg_client, allow_invalid_root_ids,
+            user_data,
+            chosen_timestamp,
+            cg_client,
+            allow_invalid_root_ids,
         )
 
         mat_df, column_names, mat_warnings = execute_materialized_query(
@@ -1044,6 +1056,7 @@ def post(self, datastack_name: str):
             column_names=column_names,
             desired_resolution=user_data["desired_resolution"],
             return_pyarrow=args["return_pyarrow"],
+            arrow_format=args["arrow_format"],
         )
 
 
@@ -1273,6 +1286,7 @@ def post(
             column_names=column_names,
             desired_resolution=data["desired_resolution"],
             return_pyarrow=args["return_pyarrow"],
+            arrow_format=args["arrow_format"],
         )
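For context: with return_pyarrow left at its default of true, the new arrow_format flag switches the response body to an Arrow IPC record batch stream whose buffer compression is negotiated through the Accept-Encoding header. A minimal client-side sketch follows; the service URL, datastack name, and request body are hypothetical, and only the query arguments and the encoding negotiation come from the diff above:

import pyarrow as pa
import requests

# Hypothetical endpoint; not taken from this repository.
resp = requests.post(
    "https://materialize.example.com/api/v2/datastack/my_datastack/query",
    params={
        "return_pyarrow": "true",
        "arrow_format": "true",  # the flag added in this commit
    },
    # The server maps Accept-Encoding onto IPC buffer compression:
    # "lz4" -> LZ4_FRAME, "zstd" -> ZSTD, anything else -> uncompressed.
    headers={"Accept-Encoding": "lz4"},
    json={"table": "my_table"},
)

# The codec is recorded inside the IPC stream itself, so the reader
# needs no compression argument.
df = pa.ipc.open_stream(resp.content).read_all().to_pandas()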
materializationengine/blueprints/client/common.py (2 additions, 0 deletions)
@@ -240,6 +240,7 @@ def handle_simple_query(
         column_names=column_names,
         desired_resolution=data["desired_resolution"],
         return_pyarrow=args["return_pyarrow"],
+        arrow_format=args['arrow_format']
     )
 
 
@@ -393,4 +394,5 @@ def handle_complex_query(
         column_names=column_names,
         desired_resolution=data["desired_resolution"],
         return_pyarrow=args["return_pyarrow"],
+        arrow_format=args['arrow_format']
     )
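Both handlers simply thread the parsed flag through to create_query_response. For readers unfamiliar with the wiring, a condensed sketch of how the reqparse argument defined in api2.py reaches these helpers (flask-restx style, simplified, not verbatim from the repo):

from flask_restx import inputs, reqparse

query_parser = reqparse.RequestParser()
query_parser.add_argument(
    "arrow_format",
    type=inputs.boolean,  # parses "true"/"false" query strings into bools
    default=False,
    required=False,
    location="args",
)

# Inside a request handler (requires an active Flask request context):
# args = query_parser.parse_args()
# args["arrow_format"] is then passed through handle_simple_query /
# handle_complex_query into create_query_response, as in the diff above.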
materializationengine/blueprints/client/utils.py (32 additions, 3 deletions)
@@ -1,6 +1,7 @@
 import pyarrow as pa
-from flask import Response, request
+from flask import Response, request, send_file
 from cloudfiles import compression
+from io import BytesIO
 
 
 def collect_crud_columns(column_names):
@@ -61,17 +62,45 @@ def update_notice_text_warnings(ann_md, warnings, table_name):
 
 
 def create_query_response(
-    df, warnings, desired_resolution, column_names, return_pyarrow=True
+    df,
+    warnings,
+    desired_resolution,
+    column_names,
+    return_pyarrow=True,
+    arrow_format=False,
 ):
+    accept_encoding = request.headers.get("Accept-Encoding", "")
 
     headers = add_warnings_to_headers({}, warnings)
     if desired_resolution is not None:
         headers["dataframe_resolution"] = desired_resolution
     headers["column_names"] = column_names
     if return_pyarrow:
+        if arrow_format:
+            batch = pa.RecordBatch.from_pandas(df)
+            sink = pa.BufferOutputStream()
+            if "lz4" in accept_encoding:
+                compression = "LZ4_FRAME"
+            elif "zstd" in accept_encoding:
+                compression = "ZSTD"
+            else:
+                compression = None
+            opt = pa.ipc.IpcWriteOptions(compression=compression)
+            with pa.ipc.new_stream(sink, batch.schema, options=opt) as writer:
+                writer.write_batch(batch)
+            response = send_file(BytesIO(sink.getvalue().to_pybytes()), "data.arrow")
+            response.headers.update(headers)
+            return after_request(response)
+        # headers = add_warnings_to_headers(
+        #     headers,
+        #     [
+        #         "Using deprecated pyarrow serialization method, please upgrade CAVEClient with pip install --upgrade caveclient"
+        #     ],
+        # )
         context = pa.default_serialization_context()
         serialized = context.serialize(df).to_buffer().to_pybytes()
-        return Response(serialized, headers=headers, mimetype="x-application/pyarrow")
+        response = Response(serialized, headers=headers, mimetype="x-application/pyarrow")
+        return after_request(response)
     else:
         dfjson = df.to_json(orient="records")
         response = Response(dfjson, headers=headers, mimetype="application/json")

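Since create_query_response carries the heart of the change, a standalone round trip of the same IPC-stream mechanism may help. This is a sketch under the commit's assumptions (pyarrow accepts the "LZ4_FRAME"/"ZSTD" codec names used above), not code from the repository:

import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"root_id": [1, 2, 3], "valid": [True, True, False]})

# Serialize: one RecordBatch into an IPC stream, compressing record batch
# buffers with LZ4; the commit also supports ZSTD or no compression.
batch = pa.RecordBatch.from_pandas(df)
sink = pa.BufferOutputStream()
opt = pa.ipc.IpcWriteOptions(compression="LZ4_FRAME")
with pa.ipc.new_stream(sink, batch.schema, options=opt) as writer:
    writer.write_batch(batch)
payload = sink.getvalue().to_pybytes()

# Deserialize: the codec is stored in the stream, so the reader needs
# no extra options to decompress.
roundtrip = pa.ipc.open_stream(payload).read_pandas()
assert roundtrip.equals(df)

Note that the fallback branch still calls pa.default_serialization_context(), an API pyarrow has long deprecated and newer releases have removed; the commented-out warning text in the diff was evidently drafted to steer CAVEclient users toward upgrading to the IPC path.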