diff --git a/materializationengine/blueprints/client/api2.py b/materializationengine/blueprints/client/api2.py index 18a44704..46ff6d3d 100644 --- a/materializationengine/blueprints/client/api2.py +++ b/materializationengine/blueprints/client/api2.py @@ -84,6 +84,14 @@ "(faster), false returns json" ), ) +query_parser.add_argument( + "arrow_format", + type=inputs.boolean, + default=False, + required=False, + location="args", + help=("whether to convert dataframe to pyarrow ipc batch format"), +) query_parser.add_argument( "split_positions", type=inputs.boolean, @@ -121,6 +129,7 @@ not be relevant and the user might not be getting data back that they expect, but it will not error.", ) + @cached(cache=TTLCache(maxsize=64, ttl=600)) def get_relevant_datastack_info(datastack_name): ds_info = get_datastack_info(datastack_name=datastack_name) @@ -1004,7 +1013,10 @@ def post(self, datastack_name: str): user_data["desired_resolution"] = des_res modified_user_data, query_map, remap_warnings = remap_query( - user_data, chosen_timestamp, cg_client, allow_invalid_root_ids, + user_data, + chosen_timestamp, + cg_client, + allow_invalid_root_ids, ) mat_df, column_names, mat_warnings = execute_materialized_query( @@ -1044,6 +1056,7 @@ def post(self, datastack_name: str): column_names=column_names, desired_resolution=user_data["desired_resolution"], return_pyarrow=args["return_pyarrow"], + arrow_format=args["arrow_format"], ) @@ -1273,6 +1286,7 @@ def post( column_names=column_names, desired_resolution=data["desired_resolution"], return_pyarrow=args["return_pyarrow"], + arrow_format=args["arrow_format"], ) diff --git a/materializationengine/blueprints/client/common.py b/materializationengine/blueprints/client/common.py index 1db92c9d..401bb412 100644 --- a/materializationengine/blueprints/client/common.py +++ b/materializationengine/blueprints/client/common.py @@ -240,6 +240,7 @@ def handle_simple_query( column_names=column_names, desired_resolution=data["desired_resolution"], 
return_pyarrow=args["return_pyarrow"], + arrow_format=args["arrow_format"], ) @@ -393,4 +394,5 @@ def handle_complex_query( column_names=column_names, desired_resolution=data["desired_resolution"], return_pyarrow=args["return_pyarrow"], + arrow_format=args["arrow_format"], ) diff --git a/materializationengine/blueprints/client/utils.py b/materializationengine/blueprints/client/utils.py index 3c0aafce..1bc3c7e1 100644 --- a/materializationengine/blueprints/client/utils.py +++ b/materializationengine/blueprints/client/utils.py @@ -1,6 +1,7 @@ import pyarrow as pa -from flask import Response, request +from flask import Response, request, send_file from cloudfiles import compression +from io import BytesIO def collect_crud_columns(column_names): @@ -61,17 +62,45 @@ def update_notice_text_warnings(ann_md, warnings, table_name): def create_query_response( - df, warnings, desired_resolution, column_names, return_pyarrow=True + df, + warnings, + desired_resolution, + column_names, + return_pyarrow=True, + arrow_format=False, ): + accept_encoding = request.headers.get("Accept-Encoding", "") headers = add_warnings_to_headers({}, warnings) if desired_resolution is not None: headers["dataframe_resolution"] = desired_resolution headers["column_names"] = column_names if return_pyarrow: + if arrow_format: + batch = pa.RecordBatch.from_pandas(df) + sink = pa.BufferOutputStream() + if "lz4" in accept_encoding: + codec = "LZ4_FRAME" + elif "zstd" in accept_encoding: + codec = "ZSTD" + else: + codec = None + opt = pa.ipc.IpcWriteOptions(compression=codec) + with pa.ipc.new_stream(sink, batch.schema, options=opt) as writer: + writer.write_batch(batch) + response = send_file(BytesIO(sink.getvalue().to_pybytes()), "data.arrow") + response.headers.update(headers) + return after_request(response) + # headers = add_warnings_to_headers( + # headers, + # [ + # "Using deprecated pyarrow serialization method, please upgrade CAVEClient with pip install --upgrade
caveclient" + # ], + # ) context = pa.default_serialization_context() serialized = context.serialize(df).to_buffer().to_pybytes() - return Response(serialized, headers=headers, mimetype="x-application/pyarrow") + response = Response(serialized, headers=headers, mimetype="x-application/pyarrow") + return after_request(response) else: dfjson = df.to_json(orient="records") response = Response(dfjson, headers=headers, mimetype="application/json")