diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py index ed4388aeb3..798d3a8e6f 100644 --- a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py @@ -3,6 +3,7 @@ from typing import Callable, List, Literal, Optional, Sequence, Union, cast import dill +import pandas import pandas as pd import pyarrow from tqdm import tqdm @@ -178,9 +179,9 @@ def _materialize_one( self.repo_config.batch_engine.partitions ) - spark_df.foreachPartition( - lambda x: _process_by_partition(x, spark_serialized_artifacts) - ) + spark_df.mapInPandas( + lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int" + ).count() # dummy action to force evaluation return SparkMaterializationJob( job_id=job_id, status=MaterializationJobStatus.SUCCEEDED @@ -225,38 +226,40 @@ def unserialize(self): return feature_view, online_store, repo_config -def _process_by_partition(rows, spark_serialized_artifacts: _SparkSerializedArtifacts): - """Load pandas df to online store""" - - # convert to pyarrow table - dicts = [] - for row in rows: - dicts.append(row.asDict()) +def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArtifacts): + for pdf in iterator: + if pdf.shape[0] == 0: + print("Skipping") + return - df = pd.DataFrame.from_records(dicts) - if df.shape[0] == 0: - print("Skipping") - return + table = pyarrow.Table.from_pandas(pdf) - table = pyarrow.Table.from_pandas(df) + ( + feature_view, + online_store, + repo_config, + ) = spark_serialized_artifacts.unserialize() + + if feature_view.batch_source.field_mapping is not None: + table = _run_pyarrow_field_mapping( + table, feature_view.batch_source.field_mapping + ) - # unserialize artifacts - feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize() + join_key_to_value_type = { + entity.name: entity.dtype.to_value_type() + for entity in feature_view.entity_columns + } - if feature_view.batch_source.field_mapping is not None: - table = _run_pyarrow_field_mapping( - table, feature_view.batch_source.field_mapping + rows_to_write = _convert_arrow_to_proto( + table, feature_view, join_key_to_value_type + ) + online_store.online_write_batch( + repo_config, + feature_view, + rows_to_write, + lambda x: None, ) - join_key_to_value_type = { - entity.name: entity.dtype.to_value_type() - for entity in feature_view.entity_columns - } - - rows_to_write = _convert_arrow_to_proto(table, feature_view, join_key_to_value_type) - online_store.online_write_batch( - repo_config, - feature_view, - rows_to_write, - lambda x: None, - ) + yield pd.DataFrame( + [pd.Series(range(1, 2))] + ) # dummy result because mapInPandas needs to return something