From 4c161332eb39d9542618dbf54da869ed4eb1810a Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 6 Feb 2024 21:09:33 +0000 Subject: [PATCH] fix formatting for spark materialization engine Signed-off-by: tokoko --- .../spark/spark_materialization_engine.py | 20 ++++++++++++------- setup.py | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py index 8bec8eab4a..798d3a8e6f 100644 --- a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py @@ -180,9 +180,8 @@ def _materialize_one( ) spark_df.mapInPandas( - lambda x: _map_by_partition(x, spark_serialized_artifacts), - "status int" - ).count() + lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int" + ).count() # dummy action to force evaluation return SparkMaterializationJob( job_id=job_id, status=MaterializationJobStatus.SUCCEEDED @@ -235,8 +234,11 @@ def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArti table = pyarrow.Table.from_pandas(pdf) - # unserialize artifacts - feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize() + ( + feature_view, + online_store, + repo_config, + ) = spark_serialized_artifacts.unserialize() if feature_view.batch_source.field_mapping is not None: table = _run_pyarrow_field_mapping( @@ -248,7 +250,9 @@ def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArti for entity in feature_view.entity_columns } - rows_to_write = _convert_arrow_to_proto(table, feature_view, join_key_to_value_type) + rows_to_write = _convert_arrow_to_proto( + table, feature_view, join_key_to_value_type + ) online_store.online_write_batch( repo_config, feature_view, @@ -256,4 +260,6 @@ def _map_by_partition(iterator, 
spark_serialized_artifacts: _SparkSerializedArti lambda x: None, ) - yield pd.DataFrame([pd.Series(range(1, 2))]) # dummy result because mapInPandas needs to return something + yield pd.DataFrame( + [pd.Series(range(1, 2))] + ) # dummy result because mapInPandas needs to return something diff --git a/setup.py b/setup.py index 4905a7697d..4901967329 100644 --- a/setup.py +++ b/setup.py @@ -155,7 +155,7 @@ "grpcio-testing>=1.56.2,<2", "minio==7.1.0", "mock==2.0.0", - "moto", + "moto<5", "mypy>=0.981,<0.990", "avro==1.10.0", "fsspec<2023.10.0",