From 468c6d9745044d51704216a4e44275d36065fcfc Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Tue, 24 Oct 2023 18:04:42 +0200 Subject: [PATCH] fix 'DataFrame.__reduce__' Signed-off-by: Anatoly Myachev --- modin/pandas/dataframe.py | 29 ++++++++++++++++++---- modin/pandas/series.py | 4 ++- modin/pandas/test/dataframe/test_pickle.py | 29 +++++++++++++++++++++- 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/modin/pandas/dataframe.py b/modin/pandas/dataframe.py index aa1fc8841ed..a57647cfa13 100644 --- a/modin/pandas/dataframe.py +++ b/modin/pandas/dataframe.py @@ -18,6 +18,7 @@ import datetime import functools import itertools +import os import re import sys import warnings @@ -3104,7 +3105,7 @@ def _getitem(self, key): # Persistance support methods - BEGIN @classmethod - def _inflate_light(cls, query_compiler): + def _inflate_light(cls, query_compiler, source_pid): """ Re-creates the object from previously-serialized lightweight representation. @@ -3114,16 +3115,23 @@ def _inflate_light(cls, query_compiler): ---------- query_compiler : BaseQueryCompiler Query compiler to use for object re-creation. + source_pid : int + Determines whether a Modin or pandas object needs to be created. + Modin objects are created only on the main process. Returns ------- DataFrame New ``DataFrame`` based on the `query_compiler`. """ + if os.getpid() != source_pid: + return query_compiler.to_pandas() + # The current logic does not involve creating Modin objects + # and manipulation with them in worker processes return cls(query_compiler=query_compiler) @classmethod - def _inflate_full(cls, pandas_df): + def _inflate_full(cls, pandas_df, source_pid): """ Re-creates the object from previously-serialized disk-storable representation. @@ -3131,18 +3139,29 @@ def _inflate_full(cls, pandas_df): ---------- pandas_df : pandas.DataFrame Data to use for object re-creation. + source_pid : int + Determines whether a Modin or pandas object needs to be created. 
+ Modin objects are created only on the main process. Returns ------- DataFrame New ``DataFrame`` based on the `pandas_df`. """ + if os.getpid() != source_pid: + return pandas_df + # The current logic does not involve creating Modin objects + # and manipulation with them in worker processes return cls(data=from_pandas(pandas_df)) def __reduce__(self): self._query_compiler.finalize() - if PersistentPickle.get(): - return self._inflate_full, (self._to_pandas(),) - return self._inflate_light, (self._query_compiler,) + pid = os.getpid() + if ( + PersistentPickle.get() + or not self._query_compiler.support_materialization_in_worker_process() + ): + return self._inflate_full, (self._to_pandas(), pid) + return self._inflate_light, (self._query_compiler, pid) # Persistance support methods - END diff --git a/modin/pandas/series.py b/modin/pandas/series.py index 98933d6f6f0..4964cf44622 100644 --- a/modin/pandas/series.py +++ b/modin/pandas/series.py @@ -2533,7 +2533,7 @@ def _inflate_full(cls, pandas_series, source_pid): pandas_series : pandas.Series Data to use for object re-creation. source_pid : int - Determines whether a Modin or Pandas object needs to be created. + Determines whether a Modin or pandas object needs to be created. Modin objects are created only on the main process. 
Returns @@ -2543,6 +2543,8 @@ def _inflate_full(cls, pandas_series, source_pid): """ if os.getpid() != source_pid: return pandas_series + # The current logic does not involve creating Modin objects + # and manipulation with them in worker processes return cls(data=pandas_series) def __reduce__(self): diff --git a/modin/pandas/test/dataframe/test_pickle.py b/modin/pandas/test/dataframe/test_pickle.py index a3ee843daa3..60dfd78f7fa 100644 --- a/modin/pandas/test/dataframe/test_pickle.py +++ b/modin/pandas/test/dataframe/test_pickle.py @@ -18,7 +18,7 @@ import modin.pandas as pd from modin.config import PersistentPickle -from modin.pandas.test.utils import df_equals +from modin.pandas.test.utils import create_test_dfs, df_equals @pytest.fixture @@ -47,6 +47,33 @@ def test_dataframe_pickle(modin_df, persistent): df_equals(modin_df, other) +def test__reduce__(): + # `DataFrame.__reduce__` will be called implicitly when lambda expressions are + # pre-processed for the distributed engine. + dataframe_data = ["Major League Baseball", "National Basketball Association"] + abbr_md, abbr_pd = create_test_dfs(dataframe_data, index=["MLB", "NBA"]) + # breakpoint() + + dataframe_data = { + "name": ["Mariners", "Lakers"] * 500, + "league_abbreviation": ["MLB", "NBA"] * 500, + } + teams_md, teams_pd = create_test_dfs(dataframe_data) + + result_md = ( + teams_md.set_index("name") + .league_abbreviation.apply(lambda abbr: abbr_md[0].loc[abbr]) + .rename("league") + ) + + result_pd = ( + teams_pd.set_index("name") + .league_abbreviation.apply(lambda abbr: abbr_pd[0].loc[abbr]) + .rename("league") + ) + df_equals(result_md, result_pd) + + def test_column_pickle(modin_column, modin_df, persistent): dmp = pickle.dumps(modin_column) other = pickle.loads(dmp)