diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index 01b3cc77d12..daf953888b8 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -182,6 +182,16 @@ SUPPORTED_TABLE_TYPES = ["temp", "temporary", "transient"] +PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING = ( + "Calling pivot() with the `value` parameter set to None or to a Snowpark " + + "DataFrame is in private preview since v1.15.0. Do not use this feature " + + "in production." +) +PIVOT_DEFAULT_ON_NULL_WARNING = ( + "Calling pivot() with a non-None value for `default_on_null` is in " + + "private preview since v1.15.0. Do not use this feature in production." +) + class TempObjectType(Enum): TABLE = "TABLE" @@ -876,6 +886,9 @@ def escape_quotes(unescaped: str) -> str: return unescaped.replace(DOUBLE_QUOTE, DOUBLE_QUOTE + DOUBLE_QUOTE) +should_warn_dynamic_pivot_is_in_private_preview = True + + def prepare_pivot_arguments( df: "snowflake.snowpark.DataFrame", df_name: str, @@ -896,17 +909,17 @@ def prepare_pivot_arguments( """ from snowflake.snowpark.dataframe import DataFrame - if values is None or isinstance(values, DataFrame): - warning( - df_name, - "Parameter values is Optional or DataFrame is in private preview since v1.15.0. Do not use it in production.", - ) - - if default_on_null is not None: - warning( - df_name, - "Parameter default_on_null is not None is in private preview since v1.15.0. Do not use it in production.", - ) + if should_warn_dynamic_pivot_is_in_private_preview: + if values is None or isinstance(values, DataFrame): + warning( + df_name, + PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING, + ) + if default_on_null is not None: + warning( + df_name, + PIVOT_DEFAULT_ON_NULL_WARNING, + ) if values is not None and not values: raise ValueError("values cannot be empty") diff --git a/src/snowflake/snowpark/modin/plugin/__init__.py b/src/snowflake/snowpark/modin/plugin/__init__.py index bd7cab8b4d9..35042684921 100644 --- a/src/snowflake/snowpark/modin/plugin/__init__.py +++ b/src/snowflake/snowpark/modin/plugin/__init__.py @@ -7,6 +7,8 @@ from packaging import version +import snowflake.snowpark._internal.utils + if sys.version_info.major == 3 and sys.version_info.minor == 8: raise RuntimeError( "Snowpark pandas does not support Python 3.8. Please update to Python 3.9 or later." @@ -60,3 +62,14 @@ from snowflake.snowpark.modin.plugin import docstrings # isort: skip # noqa: E402 DocModule.put(docstrings.__name__) + + +# Don't warn the user about our internal usage of private preview pivot +# features. The user should have already been warned that Snowpark pandas +# is in public or private preview. They likely don't know or care that we are +# using Snowpark DataFrame pivot() internally, let alone that we are using +# private preview features of Snowpark Python. + +snowflake.snowpark._internal.utils.should_warn_dynamic_pivot_is_in_private_preview = ( + False +) diff --git a/tests/conftest.py b/tests/conftest.py index eeb0173d651..9caecda182d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,8 @@ import pytest +from snowflake.snowpark._internal.utils import warning_dict + logging.getLogger("snowflake.connector").setLevel(logging.ERROR) # TODO: SNOW-1305522: Enable Modin doctests for the below frontend files @@ -91,3 +93,11 @@ def cte_optimization_enabled(pytestconfig): def pytest_sessionstart(session): os.environ["SNOWPARK_LOCAL_TESTING_INTERNAL_TELEMETRY"] = "1" + + +@pytest.fixture(autouse=True) +def clear_warning_dict(): + yield + # clear the warning dict so that warnings from one test don't affect + # warnings from other tests. + warning_dict.clear() diff --git a/tests/integ/modin/frame/test_transpose.py b/tests/integ/modin/frame/test_transpose.py index aba9e784fd1..7e01cf0c1b7 100644 --- a/tests/integ/modin/frame/test_transpose.py +++ b/tests/integ/modin/frame/test_transpose.py @@ -10,6 +10,10 @@ import pytest import snowflake.snowpark.modin.plugin # noqa: F401 +from snowflake.snowpark._internal.utils import ( + PIVOT_DEFAULT_ON_NULL_WARNING, + PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING, +) from snowflake.snowpark.modin.plugin._internal.unpivot_utils import ( UNPIVOT_NULL_REPLACE_VALUE, ) @@ -348,3 +352,12 @@ def test_dataframe_transpose_args_warning_log(caplog, score_test_data): "Transpose ignores args in Snowpark pandas API." in [r.msg for r in caplog.records] ) + + +@sql_count_checker(query_count=1, union_count=1) +def test_transpose_does_not_raise_pivot_warning_snow_1344848(caplog): + # Test transpose, which calls snowflake.snowpark.dataframe.pivot() with + # the `values` parameter as None or a Snowpark DataFrame. + pd.DataFrame([1]).T.to_pandas() + assert PIVOT_DEFAULT_ON_NULL_WARNING not in caplog.text + assert PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING not in caplog.text diff --git a/tests/integ/modin/strings/test_get_dummies_series.py b/tests/integ/modin/strings/test_get_dummies_series.py index a6b004f577a..401e30ee338 100644 --- a/tests/integ/modin/strings/test_get_dummies_series.py +++ b/tests/integ/modin/strings/test_get_dummies_series.py @@ -6,6 +6,10 @@ import pytest import snowflake.snowpark.modin.plugin # noqa: F401 +from snowflake.snowpark._internal.utils import ( + PIVOT_DEFAULT_ON_NULL_WARNING, + PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING, +) from tests.integ.modin.sql_counter import sql_count_checker from tests.integ.modin.utils import assert_snowpark_pandas_equal_to_pandas @@ -52,3 +56,12 @@ def test_get_dummies_series_negative(data): native_pd.get_dummies(pandas_ser), check_dtype=False, ) + + +@sql_count_checker(query_count=1) +def test_get_dummies_does_not_raise_pivot_warning_snow_1344848(caplog): + # Test get_dummies, which uses the `default_on_null` parameter of + # snowflake.snowpark.dataframe.pivot() + pd.get_dummies(pd.Series(["a"])).to_pandas() + assert PIVOT_DEFAULT_ON_NULL_WARNING not in caplog.text + assert PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING not in caplog.text diff --git a/tests/integ/scala/test_dataframe_aggregate_suite.py b/tests/integ/scala/test_dataframe_aggregate_suite.py index 5d786975b3b..5bdf12a7846 100644 --- a/tests/integ/scala/test_dataframe_aggregate_suite.py +++ b/tests/integ/scala/test_dataframe_aggregate_suite.py @@ -3,6 +3,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import sys from decimal import Decimal from math import sqrt from typing import NamedTuple @@ -10,7 +11,11 @@ import pytest from snowflake.snowpark import GroupingSets, Row -from snowflake.snowpark._internal.utils import TempObjectType +from snowflake.snowpark._internal.utils import ( + PIVOT_DEFAULT_ON_NULL_WARNING, + PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING, + TempObjectType, +) from snowflake.snowpark.column import Column from snowflake.snowpark.exceptions import ( SnowparkDataframeException, @@ -97,7 +102,7 @@ def test_group_by_pivot(session): ).agg([sum(col("amount")), avg(col("amount"))]) -def test_group_by_pivot_dynamic_any(session): +def test_group_by_pivot_dynamic_any(session, caplog): Utils.check_answer( TestData.monthly_sales_with_team(session) .group_by("empid") @@ -111,6 +116,11 @@ def test_group_by_pivot_dynamic_any(session): sort=False, ) + if "snowflake.snowpark.modin.plugin" not in sys.modules: + # Snowpark pandas users don't get warnings about dynamic pivot + # features. See SNOW-1344848. + assert PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING in caplog.text + Utils.check_answer( TestData.monthly_sales_with_team(session) .group_by(["empid", "team"]) @@ -292,7 +302,7 @@ def test_pivot_dynamic_subquery_with_bad_subquery(session): assert "Pivot subquery must select single column" in str(ex_info.value) -def test_pivot_default_on_none(session): +def test_pivot_default_on_none(session, caplog): class MonthlySales(NamedTuple): empid: int amount: int @@ -326,6 +336,11 @@ class MonthlySales(NamedTuple): sort=False, ) + if "snowflake.snowpark.modin.plugin" not in sys.modules: + # Snowpark pandas users don't get warnings about dynamic pivot + # features. See SNOW-1344848. + assert PIVOT_DEFAULT_ON_NULL_WARNING in caplog.text + @pytest.mark.localtest def test_rel_grouped_dataframe_agg(session): diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index 2379f371cda..f0aaad082d6 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -30,7 +30,7 @@ from snowflake.snowpark import Column, Row, Window from snowflake.snowpark._internal.analyzer.analyzer_utils import result_scan_statement from snowflake.snowpark._internal.analyzer.expression import Attribute, Interval, Star -from snowflake.snowpark._internal.utils import TempObjectType, warning_dict +from snowflake.snowpark._internal.utils import TempObjectType from snowflake.snowpark.exceptions import ( SnowparkColumnException, SnowparkCreateDynamicTableException, @@ -2719,8 +2719,6 @@ def test_write_temp_table_no_breaking_change( Utils.assert_table_type(session, table_name, "temp") finally: Utils.drop_table(session, table_name) - # clear the warning dict otherwise it will affect the future tests - warning_dict.clear() @pytest.mark.localtest diff --git a/tests/integ/test_pandas_to_df.py b/tests/integ/test_pandas_to_df.py index 712cf1f1588..b0d6b7da81f 100644 --- a/tests/integ/test_pandas_to_df.py +++ b/tests/integ/test_pandas_to_df.py @@ -23,7 +23,6 @@ TempObjectType, is_in_stored_procedure, random_name_for_temp_object, - warning_dict, ) from snowflake.snowpark.exceptions import SnowparkPandasException from tests.utils import Utils @@ -304,8 +303,6 @@ def test_write_temp_table_no_breaking_change(session, table_type, caplog): Utils.assert_table_type(session, table_name, "temp") finally: Utils.drop_table(session, table_name) - # clear the warning dict otherwise it will affect the future tests - warning_dict.clear() @pytest.mark.localtest diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index a76a620401d..d810a67783d 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -45,10 +45,7 @@ from snowflake.connector.version import VERSION as SNOWFLAKE_CONNECTOR_VERSION from snowflake.snowpark import Row, Session -from snowflake.snowpark._internal.utils import ( - unwrap_stage_location_single_quote, - warning_dict, -) +from snowflake.snowpark._internal.utils import unwrap_stage_location_single_quote from snowflake.snowpark.exceptions import ( SnowparkInvalidObjectNameException, SnowparkSQLException, @@ -2177,15 +2174,11 @@ def test_deprecate_call_udf_with_list(session, caplog): return_type=IntegerType(), input_types=[IntegerType(), IntegerType()], ) - try: - with caplog.at_level(logging.WARNING): - add_udf(["a", "b"]) - assert ( - "Passing arguments to a UDF with a list or tuple is deprecated" - in caplog.text - ) - finally: - warning_dict.clear() + with caplog.at_level(logging.WARNING): + add_udf(["a", "b"]) + assert ( + "Passing arguments to a UDF with a list or tuple is deprecated" in caplog.text + ) def test_strict_udf(session): diff --git a/tests/unit/scala/test_utils_suite.py b/tests/unit/scala/test_utils_suite.py index f4ad5944867..e483877025c 100644 --- a/tests/unit/scala/test_utils_suite.py +++ b/tests/unit/scala/test_utils_suite.py @@ -28,7 +28,6 @@ unwrap_stage_location_single_quote, validate_object_name, warning, - warning_dict, zip_file_or_directory_to_stream, ) from tests.utils import IS_WINDOWS, TestFiles @@ -445,43 +444,36 @@ def test_warning(caplog): def f(): return 1 - try: - with caplog.at_level(logging.WARNING): - warning("aaa", "bbb", 2) - warning("aaa", "bbb", 2) - warning("aaa", "bbb", 2) - assert caplog.text.count("bbb") == 2 - with caplog.at_level(logging.WARNING): - warning(f.__qualname__, "ccc", 2) - warning(f.__qualname__, "ccc", 2) - warning(f.__qualname__, "ccc", 2) - assert caplog.text.count("ccc") == 2 - finally: - warning_dict.clear() + with caplog.at_level(logging.WARNING): + warning("aaa", "bbb", 2) + warning("aaa", "bbb", 2) + warning("aaa", "bbb", 2) + assert caplog.text.count("bbb") == 2 + with caplog.at_level(logging.WARNING): + warning(f.__qualname__, "ccc", 2) + warning(f.__qualname__, "ccc", 2) + warning(f.__qualname__, "ccc", 2) + assert caplog.text.count("ccc") == 2 @pytest.mark.parametrize("decorator", [deprecated, experimental]) def test_func_decorator(caplog, decorator): - try: - - @decorator( - version="1.0.0", - extra_warning_text="extra_warning_text", - extra_doc_string="extra_doc_string", - ) - def f(): - return 1 + @decorator( + version="1.0.0", + extra_warning_text="extra_warning_text", + extra_doc_string="extra_doc_string", + ) + def f(): + return 1 - assert "extra_doc_string" in f.__doc__ - with caplog.at_level(logging.WARNING): - f() - f() + assert "extra_doc_string" in f.__doc__ + with caplog.at_level(logging.WARNING): + f() + f() - assert decorator.__name__ in caplog.text - assert caplog.text.count("1.0.0") == 1 - assert caplog.text.count("extra_warning_text") == 1 - finally: - warning_dict.clear() + assert decorator.__name__ in caplog.text + assert caplog.text.count("1.0.0") == 1 + assert caplog.text.count("extra_warning_text") == 1 def test_is_sql_select_statement(): @@ -564,18 +556,14 @@ def foo(): pass caplog.clear() - warning_dict.clear() - try: - with caplog.at_level(logging.WARNING): - foo() - assert extra_doc in foo.__doc__ - assert expected_warning_text in caplog.messages - caplog.clear() - with caplog.at_level(logging.WARNING): - foo() - assert expected_warning_text not in caplog.text - finally: - warning_dict.clear() + with caplog.at_level(logging.WARNING): + foo() + assert extra_doc in foo.__doc__ + assert expected_warning_text in caplog.messages + caplog.clear() + with caplog.at_level(logging.WARNING): + foo() + assert expected_warning_text not in caplog.text @pytest.mark.parametrize("function", [result_set_to_iter, result_set_to_rows])