From 33818d3b599e0a26369f7009e2093c639884d9da Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Tue, 5 Dec 2023 18:44:31 +0100 Subject: [PATCH] fixes Signed-off-by: Anatoly Myachev --- .../implementations/pandas_on_dask/io/io.py | 4 +- .../dispatching/factories/factories.py | 58 +++++++++---------- .../implementations/pandas_on_ray/io/io.py | 4 +- .../pandas_on_unidist/io/io.py | 4 +- .../core/io/sql/sql_dispatcher.py | 2 +- 5 files changed, 36 insertions(+), 36 deletions(-) diff --git a/modin/core/execution/dask/implementations/pandas_on_dask/io/io.py b/modin/core/execution/dask/implementations/pandas_on_dask/io/io.py index 1c8f8e5982f..090bbcffaee 100644 --- a/modin/core/execution/dask/implementations/pandas_on_dask/io/io.py +++ b/modin/core/execution/dask/implementations/pandas_on_dask/io/io.py @@ -92,7 +92,9 @@ def __make_write(*classes, build_args=build_args): read_custom_text = __make_read( ExperimentalCustomTextParser, ExperimentalCustomTextDispatcher ) - read_sql_distributed = __make_read(ExperimentalSQLDispatcher) + read_sql_distributed = __make_read( + ExperimentalSQLDispatcher, build_args={**build_args, "base_read": read_sql} + ) del __make_read # to not pollute class namespace del __make_write # to not pollute class namespace diff --git a/modin/core/execution/dispatching/factories/factories.py b/modin/core/execution/dispatching/factories/factories.py index e37458717fd..d77dff5ecc6 100644 --- a/modin/core/execution/dispatching/factories/factories.py +++ b/modin/core/execution/dispatching/factories/factories.py @@ -329,33 +329,6 @@ def _read_pickle(cls, **kwargs): method="read_sql", ) def _read_sql(cls, **kwargs): - if IsExperimental.get(): - supported_engines = ("Ray", "Unidist", "Dask") - if Engine.get() not in supported_engines: - if "partition_column" in kwargs: - if kwargs["partition_column"] is not None: - warnings.warn( - f"Distributed read_sql() was only implemented for {', '.join(supported_engines)} engines." - ) - del kwargs["partition_column"] - if "lower_bound" in kwargs: - if kwargs["lower_bound"] is not None: - warnings.warn( - f"Distributed read_sql() was only implemented for {', '.join(supported_engines)} engines." - ) - del kwargs["lower_bound"] - if "upper_bound" in kwargs: - if kwargs["upper_bound"] is not None: - warnings.warn( - f"Distributed read_sql() was only implemented for {', '.join(supported_engines)} engines." - ) - del kwargs["upper_bound"] - if "max_sessions" in kwargs: - if kwargs["max_sessions"] is not None: - warnings.warn( - f"Distributed read_sql() was only implemented for {', '.join(supported_engines)} engines." - ) - del kwargs["max_sessions"] return cls.io_cls.read_sql(**kwargs) @classmethod @@ -490,11 +463,32 @@ def _read_pickle_distributed(cls, **kwargs): params=_doc_io_method_kwargs_params, ) def _read_sql_distributed(cls, **kwargs): - current_execution = get_current_execution() - if current_execution not in supported_execution: - raise NotImplementedError( - f"`_read_sql_distributed()` is not implemented for {current_execution} execution." - ) + supported_engines = ("Ray", "Unidist", "Dask") + if Engine.get() not in supported_engines: + if "partition_column" in kwargs: + if kwargs["partition_column"] is not None: + warnings.warn( + f"Distributed read_sql() was only implemented for {', '.join(supported_engines)} engines." + ) + del kwargs["partition_column"] + if "lower_bound" in kwargs: + if kwargs["lower_bound"] is not None: + warnings.warn( + f"Distributed read_sql() was only implemented for {', '.join(supported_engines)} engines." + ) + del kwargs["lower_bound"] + if "upper_bound" in kwargs: + if kwargs["upper_bound"] is not None: + warnings.warn( + f"Distributed read_sql() was only implemented for {', '.join(supported_engines)} engines." + ) + del kwargs["upper_bound"] + if "max_sessions" in kwargs: + if kwargs["max_sessions"] is not None: + warnings.warn( + f"Distributed read_sql() was only implemented for {', '.join(supported_engines)} engines." + ) + del kwargs["max_sessions"] return cls.io_cls.read_sql_distributed(**kwargs) @classmethod diff --git a/modin/core/execution/ray/implementations/pandas_on_ray/io/io.py b/modin/core/execution/ray/implementations/pandas_on_ray/io/io.py index 9bfa796d1fe..f08dc739716 100644 --- a/modin/core/execution/ray/implementations/pandas_on_ray/io/io.py +++ b/modin/core/execution/ray/implementations/pandas_on_ray/io/io.py @@ -94,7 +94,9 @@ def __make_write(*classes, build_args=build_args): read_custom_text = __make_read( ExperimentalCustomTextParser, ExperimentalCustomTextDispatcher ) - read_sql_distributed = __make_read(ExperimentalSQLDispatcher) + read_sql_distributed = __make_read( + ExperimentalSQLDispatcher, build_args={**build_args, "base_read": read_sql} + ) del __make_read # to not pollute class namespace del __make_write # to not pollute class namespace diff --git a/modin/core/execution/unidist/implementations/pandas_on_unidist/io/io.py b/modin/core/execution/unidist/implementations/pandas_on_unidist/io/io.py index 0c879df9089..ff5fbf05a29 100644 --- a/modin/core/execution/unidist/implementations/pandas_on_unidist/io/io.py +++ b/modin/core/execution/unidist/implementations/pandas_on_unidist/io/io.py @@ -93,7 +93,9 @@ def __make_write(*classes, build_args=build_args): read_custom_text = __make_read( ExperimentalCustomTextParser, ExperimentalCustomTextDispatcher ) - read_sql_distributed = __make_read(ExperimentalSQLDispatcher) + read_sql_distributed = __make_read( + ExperimentalSQLDispatcher, build_args={**build_args, "base_read": read_sql} + ) del __make_read # to not pollute class namespace del __make_write # to not pollute class namespace diff --git a/modin/experimental/core/io/sql/sql_dispatcher.py b/modin/experimental/core/io/sql/sql_dispatcher.py index bc6285c4107..d2fe256c131 100644 --- a/modin/experimental/core/io/sql/sql_dispatcher.py +++ b/modin/experimental/core/io/sql/sql_dispatcher.py @@ -72,7 +72,7 @@ def _read( message = "Defaulting to Modin core implementation; \ 'partition_column', 'lower_bound', 'upper_bound' must be different from None" warnings.warn(message) - return cls.base_io.read_sql( + return cls.base_read( sql, con, index_col,