
Pyarrow serialize #121

Merged · 13 commits · Aug 10, 2023
16 changes: 15 additions & 1 deletion materializationengine/blueprints/client/api2.py
@@ -84,6 +84,14 @@
"(faster), false returns json"
),
)
query_parser.add_argument(
"arrow_format",
type=inputs.boolean,
default=False,
required=False,
location="args",
help=("whether to convert dataframe to pyarrow ipc batch format"),
)
query_parser.add_argument(
"split_positions",
type=inputs.boolean,
@@ -121,6 +129,7 @@
    not be relevant and the user might not be getting data back that they expect, but it will not error.",
)

+
@cached(cache=TTLCache(maxsize=64, ttl=600))
def get_relevant_datastack_info(datastack_name):
    ds_info = get_datastack_info(datastack_name=datastack_name)
@@ -1004,7 +1013,10 @@ def post(self, datastack_name: str):
user_data["desired_resolution"] = des_res

modified_user_data, query_map, remap_warnings = remap_query(
user_data, chosen_timestamp, cg_client, allow_invalid_root_ids,
user_data,
chosen_timestamp,
cg_client,
allow_invalid_root_ids,
)

mat_df, column_names, mat_warnings = execute_materialized_query(
@@ -1044,6 +1056,7 @@ def post(self, datastack_name: str):
            column_names=column_names,
            desired_resolution=user_data["desired_resolution"],
            return_pyarrow=args["return_pyarrow"],
+            arrow_format=args["arrow_format"],
        )


@@ -1273,6 +1286,7 @@ def post(
            column_names=column_names,
            desired_resolution=data["desired_resolution"],
            return_pyarrow=args["return_pyarrow"],
+            arrow_format=args["arrow_format"],
        )
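For context on what the new `arrow_format` flag changes on the wire, here is a minimal client-side sketch (not part of this PR; the URL, datastack name, and payload are hypothetical placeholders) that requests the IPC stream and reads it back with pyarrow:

```python
import pyarrow as pa
import requests

# Hypothetical endpoint and payload, for illustration only.
url = "https://example.com/materialize/api/v2/datastack/my_datastack/query"
payload = {"table": "my_table"}

resp = requests.post(
    url,
    params={"return_pyarrow": "true", "arrow_format": "true"},
    headers={"Accept-Encoding": "lz4"},  # opt in to LZ4-compressed IPC buffers
    json=payload,
)

# Compression is applied inside the Arrow IPC frames, so open_stream
# decompresses transparently; no HTTP-level content decoding is needed.
df = pa.ipc.open_stream(resp.content).read_all().to_pandas()
```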


2 changes: 2 additions & 0 deletions materializationengine/blueprints/client/common.py
@@ -240,6 +240,7 @@ def handle_simple_query(
        column_names=column_names,
        desired_resolution=data["desired_resolution"],
        return_pyarrow=args["return_pyarrow"],
+        arrow_format=args["arrow_format"],
    )


@@ -393,4 +394,5 @@ def handle_complex_query(
        column_names=column_names,
        desired_resolution=data["desired_resolution"],
        return_pyarrow=args["return_pyarrow"],
+        arrow_format=args["arrow_format"],
    )
35 changes: 32 additions & 3 deletions materializationengine/blueprints/client/utils.py
@@ -1,6 +1,7 @@
import pyarrow as pa
-from flask import Response, request
+from flask import Response, request, send_file
from cloudfiles import compression
+from io import BytesIO


def collect_crud_columns(column_names):
@@ -61,17 +62,45 @@ def update_notice_text_warnings(ann_md, warnings, table_name):


def create_query_response(
-    df, warnings, desired_resolution, column_names, return_pyarrow=True
+    df,
+    warnings,
+    desired_resolution,
+    column_names,
+    return_pyarrow=True,
+    arrow_format=False,
):
+    accept_encoding = request.headers.get("Accept-Encoding", "")

    headers = add_warnings_to_headers({}, warnings)
    if desired_resolution is not None:
        headers["dataframe_resolution"] = desired_resolution
    headers["column_names"] = column_names
    if return_pyarrow:
+        if arrow_format:
+            batch = pa.RecordBatch.from_pandas(df)
+            sink = pa.BufferOutputStream()
+            if "lz4" in accept_encoding:
+                compression = "LZ4_FRAME"
+            elif "zstd" in accept_encoding:
+                compression = "ZSTD"
+            else:
+                compression = None
+            opt = pa.ipc.IpcWriteOptions(compression=compression)
+            with pa.ipc.new_stream(sink, batch.schema, options=opt) as writer:
+                writer.write_batch(batch)
+            response = send_file(BytesIO(sink.getvalue().to_pybytes()), "data.arrow")
+            response.headers.update(headers)
+            return after_request(response)
+        # headers = add_warnings_to_headers(
+        #     headers,
+        #     [
+        #         "Using deprecated pyarrow serialization method, please upgrade CAVEClient with pip install --upgrade caveclient"
+        #     ],
+        # )
        context = pa.default_serialization_context()
        serialized = context.serialize(df).to_buffer().to_pybytes()
-        return Response(serialized, headers=headers, mimetype="x-application/pyarrow")
+        response = Response(serialized, headers=headers, mimetype="x-application/pyarrow")
+        return after_request(response)
    else:
        dfjson = df.to_json(orient="records")
        response = Response(dfjson, headers=headers, mimetype="application/json")
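And a quick self-contained round trip of the new write path (again not part of the diff), checking that a record batch written with `IpcWriteOptions` body compression reads back intact; the handler above picks "ZSTD", "LZ4_FRAME", or no codec from the request's Accept-Encoding header, and the reader never needs to be told which one was used:

```python
import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"root_id": [10, 20, 30], "valid": [True, True, False]})

# Mirrors the arrow_format branch above: DataFrame -> batch -> compressed IPC stream.
batch = pa.RecordBatch.from_pandas(df)
sink = pa.BufferOutputStream()
opt = pa.ipc.IpcWriteOptions(compression="zstd")
with pa.ipc.new_stream(sink, batch.schema, options=opt) as writer:
    writer.write_batch(batch)

# The codec is recorded in the stream itself, so open_stream needs no hints.
round_tripped = pa.ipc.open_stream(sink.getvalue()).read_all().to_pandas()
assert round_tripped.equals(df)
```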