Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1344848: Suppress dynamic pivot warnings #1434

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,16 @@

SUPPORTED_TABLE_TYPES = ["temp", "temporary", "transient"]

# Warning texts that prepare_pivot_arguments() passes to warning() when a
# caller uses a private-preview pivot feature. They are module-level constants
# so that tests can assert on the exact message instead of duplicating it.

# Emitted when `values` is None or a Snowpark DataFrame (dynamic pivot).
# The parameter is named `values`, so the message must say `values`.
PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING = (
    "Calling pivot() with the `values` parameter set to None or to a Snowpark "
    "DataFrame is in private preview since v1.15.0. Do not use this feature "
    "in production."
)
# Emitted when `default_on_null` is given a non-None value.
PIVOT_DEFAULT_ON_NULL_WARNING = (
    "Calling pivot() with a non-None value for `default_on_null` is in "
    "private preview since v1.15.0. Do not use this feature in production."
)


class TempObjectType(Enum):
TABLE = "TABLE"
Expand Down Expand Up @@ -876,6 +886,9 @@ def escape_quotes(unescaped: str) -> str:
return unescaped.replace(DOUBLE_QUOTE, DOUBLE_QUOTE + DOUBLE_QUOTE)


should_warn_dynamic_pivot_is_in_private_preview = True


def prepare_pivot_arguments(
df: "snowflake.snowpark.DataFrame",
df_name: str,
Expand All @@ -896,17 +909,17 @@ def prepare_pivot_arguments(
"""
from snowflake.snowpark.dataframe import DataFrame

if values is None or isinstance(values, DataFrame):
warning(
df_name,
"Parameter values is Optional or DataFrame is in private preview since v1.15.0. Do not use it in production.",
)

if default_on_null is not None:
warning(
df_name,
"Parameter default_on_null is not None is in private preview since v1.15.0. Do not use it in production.",
)
if should_warn_dynamic_pivot_is_in_private_preview:
if values is None or isinstance(values, DataFrame):
warning(
df_name,
PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING,
)
if default_on_null is not None:
warning(
df_name,
PIVOT_DEFAULT_ON_NULL_WARNING,
)

if values is not None and not values:
raise ValueError("values cannot be empty")
Expand Down
13 changes: 13 additions & 0 deletions src/snowflake/snowpark/modin/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from packaging import version

import snowflake.snowpark._internal.utils

if sys.version_info.major == 3 and sys.version_info.minor == 8:
raise RuntimeError(
"Snowpark pandas does not support Python 3.8. Please update to Python 3.9 or later."
Expand Down Expand Up @@ -60,3 +62,14 @@
from snowflake.snowpark.modin.plugin import docstrings # isort: skip # noqa: E402

DocModule.put(docstrings.__name__)


# Don't warn the user about our internal usage of private preview pivot
# features. The user should have already been warned that Snowpark pandas
# is in public or private preview. They likely don't know or care that we are
# using Snowpark DataFrame pivot() internally, let alone that we are using
# private preview features of Snowpark Python.

snowflake.snowpark._internal.utils.should_warn_dynamic_pivot_is_in_private_preview = (
False
)
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import pytest

from snowflake.snowpark._internal.utils import warning_dict

logging.getLogger("snowflake.connector").setLevel(logging.ERROR)

# TODO: SNOW-1305522: Enable Modin doctests for the below frontend files
Expand Down Expand Up @@ -91,3 +93,11 @@ def cte_optimization_enabled(pytestconfig):

def pytest_sessionstart(session):
os.environ["SNOWPARK_LOCAL_TESTING_INTERNAL_TELEMETRY"] = "1"


@pytest.fixture(autouse=True)
def clear_warning_dict():
yield
# Clear the warning dict after each test so that warnings recorded by one
# test don't affect the warnings emitted in other tests.
warning_dict.clear()
13 changes: 13 additions & 0 deletions tests/integ/modin/frame/test_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
import pytest

import snowflake.snowpark.modin.plugin # noqa: F401
from snowflake.snowpark._internal.utils import (
PIVOT_DEFAULT_ON_NULL_WARNING,
PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING,
)
from snowflake.snowpark.modin.plugin._internal.unpivot_utils import (
UNPIVOT_NULL_REPLACE_VALUE,
)
Expand Down Expand Up @@ -348,3 +352,12 @@ def test_dataframe_transpose_args_warning_log(caplog, score_test_data):
"Transpose ignores args in Snowpark pandas API."
in [r.msg for r in caplog.records]
)


@sql_count_checker(query_count=1, union_count=1)
def test_transpose_does_not_raise_pivot_warning_snow_1344848(caplog):
# Test transpose, which calls snowflake.snowpark.dataframe.pivot() with
# the `values` parameter as None or a Snowpark DataFrame.
pd.DataFrame([1]).T.to_pandas()
assert PIVOT_DEFAULT_ON_NULL_WARNING not in caplog.text
assert PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING not in caplog.text
13 changes: 13 additions & 0 deletions tests/integ/modin/strings/test_get_dummies_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import pytest

import snowflake.snowpark.modin.plugin # noqa: F401
from snowflake.snowpark._internal.utils import (
PIVOT_DEFAULT_ON_NULL_WARNING,
PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING,
)
from tests.integ.modin.sql_counter import sql_count_checker
from tests.integ.modin.utils import assert_snowpark_pandas_equal_to_pandas

Expand Down Expand Up @@ -52,3 +56,12 @@ def test_get_dummies_series_negative(data):
native_pd.get_dummies(pandas_ser),
check_dtype=False,
)


@sql_count_checker(query_count=1)
def test_get_dummies_does_not_raise_pivot_warning_snow_1344848(caplog):
# Test get_dummies, which uses the `default_on_null` parameter of
# snowflake.snowpark.dataframe.pivot()
pd.get_dummies(pd.Series(["a"])).to_pandas()
assert PIVOT_DEFAULT_ON_NULL_WARNING not in caplog.text
assert PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING not in caplog.text
21 changes: 18 additions & 3 deletions tests/integ/scala/test_dataframe_aggregate_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

import sys
from decimal import Decimal
from math import sqrt
from typing import NamedTuple

import pytest

from snowflake.snowpark import GroupingSets, Row
from snowflake.snowpark._internal.utils import TempObjectType
from snowflake.snowpark._internal.utils import (
PIVOT_DEFAULT_ON_NULL_WARNING,
PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING,
TempObjectType,
)
from snowflake.snowpark.column import Column
from snowflake.snowpark.exceptions import (
SnowparkDataframeException,
Expand Down Expand Up @@ -97,7 +102,7 @@ def test_group_by_pivot(session):
).agg([sum(col("amount")), avg(col("amount"))])


def test_group_by_pivot_dynamic_any(session):
def test_group_by_pivot_dynamic_any(session, caplog):
Utils.check_answer(
TestData.monthly_sales_with_team(session)
.group_by("empid")
Expand All @@ -111,6 +116,11 @@ def test_group_by_pivot_dynamic_any(session):
sort=False,
)

if "snowflake.snowpark.modin.plugin" not in sys.modules:
# Snowpark pandas users don't get warnings about dynamic pivot
# features. See SNOW-1344848.
assert PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING in caplog.text

Utils.check_answer(
TestData.monthly_sales_with_team(session)
.group_by(["empid", "team"])
Expand Down Expand Up @@ -292,7 +302,7 @@ def test_pivot_dynamic_subquery_with_bad_subquery(session):
assert "Pivot subquery must select single column" in str(ex_info.value)


def test_pivot_default_on_none(session):
def test_pivot_default_on_none(session, caplog):
class MonthlySales(NamedTuple):
empid: int
amount: int
Expand Down Expand Up @@ -326,6 +336,11 @@ class MonthlySales(NamedTuple):
sort=False,
)

if "snowflake.snowpark.modin.plugin" not in sys.modules:
# Snowpark pandas users don't get warnings about dynamic pivot
# features. See SNOW-1344848.
assert PIVOT_DEFAULT_ON_NULL_WARNING in caplog.text


@pytest.mark.localtest
def test_rel_grouped_dataframe_agg(session):
Expand Down
4 changes: 1 addition & 3 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from snowflake.snowpark import Column, Row, Window
from snowflake.snowpark._internal.analyzer.analyzer_utils import result_scan_statement
from snowflake.snowpark._internal.analyzer.expression import Attribute, Interval, Star
from snowflake.snowpark._internal.utils import TempObjectType, warning_dict
from snowflake.snowpark._internal.utils import TempObjectType
from snowflake.snowpark.exceptions import (
SnowparkColumnException,
SnowparkCreateDynamicTableException,
Expand Down Expand Up @@ -2719,8 +2719,6 @@ def test_write_temp_table_no_breaking_change(
Utils.assert_table_type(session, table_name, "temp")
finally:
Utils.drop_table(session, table_name)
# clear the warning dict otherwise it will affect the future tests
warning_dict.clear()


@pytest.mark.localtest
Expand Down
3 changes: 0 additions & 3 deletions tests/integ/test_pandas_to_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
TempObjectType,
is_in_stored_procedure,
random_name_for_temp_object,
warning_dict,
)
from snowflake.snowpark.exceptions import SnowparkPandasException
from tests.utils import Utils
Expand Down Expand Up @@ -304,8 +303,6 @@ def test_write_temp_table_no_breaking_change(session, table_type, caplog):
Utils.assert_table_type(session, table_name, "temp")
finally:
Utils.drop_table(session, table_name)
# clear the warning dict otherwise it will affect the future tests
warning_dict.clear()


@pytest.mark.localtest
Expand Down
19 changes: 6 additions & 13 deletions tests/integ/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@

from snowflake.connector.version import VERSION as SNOWFLAKE_CONNECTOR_VERSION
from snowflake.snowpark import Row, Session
from snowflake.snowpark._internal.utils import (
unwrap_stage_location_single_quote,
warning_dict,
)
from snowflake.snowpark._internal.utils import unwrap_stage_location_single_quote
from snowflake.snowpark.exceptions import (
SnowparkInvalidObjectNameException,
SnowparkSQLException,
Expand Down Expand Up @@ -2177,15 +2174,11 @@ def test_deprecate_call_udf_with_list(session, caplog):
return_type=IntegerType(),
input_types=[IntegerType(), IntegerType()],
)
try:
with caplog.at_level(logging.WARNING):
add_udf(["a", "b"])
assert (
"Passing arguments to a UDF with a list or tuple is deprecated"
in caplog.text
)
finally:
warning_dict.clear()
with caplog.at_level(logging.WARNING):
add_udf(["a", "b"])
assert (
"Passing arguments to a UDF with a list or tuple is deprecated" in caplog.text
)


def test_strict_udf(session):
Expand Down
76 changes: 32 additions & 44 deletions tests/unit/scala/test_utils_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
unwrap_stage_location_single_quote,
validate_object_name,
warning,
warning_dict,
zip_file_or_directory_to_stream,
)
from tests.utils import IS_WINDOWS, TestFiles
Expand Down Expand Up @@ -445,43 +444,36 @@ def test_warning(caplog):
def f():
return 1

try:
with caplog.at_level(logging.WARNING):
warning("aaa", "bbb", 2)
warning("aaa", "bbb", 2)
warning("aaa", "bbb", 2)
assert caplog.text.count("bbb") == 2
with caplog.at_level(logging.WARNING):
warning(f.__qualname__, "ccc", 2)
warning(f.__qualname__, "ccc", 2)
warning(f.__qualname__, "ccc", 2)
assert caplog.text.count("ccc") == 2
finally:
warning_dict.clear()
with caplog.at_level(logging.WARNING):
warning("aaa", "bbb", 2)
warning("aaa", "bbb", 2)
warning("aaa", "bbb", 2)
assert caplog.text.count("bbb") == 2
with caplog.at_level(logging.WARNING):
warning(f.__qualname__, "ccc", 2)
warning(f.__qualname__, "ccc", 2)
warning(f.__qualname__, "ccc", 2)
assert caplog.text.count("ccc") == 2


@pytest.mark.parametrize("decorator", [deprecated, experimental])
def test_func_decorator(caplog, decorator):
try:

@decorator(
version="1.0.0",
extra_warning_text="extra_warning_text",
extra_doc_string="extra_doc_string",
)
def f():
return 1
@decorator(
version="1.0.0",
extra_warning_text="extra_warning_text",
extra_doc_string="extra_doc_string",
)
def f():
return 1

assert "extra_doc_string" in f.__doc__
with caplog.at_level(logging.WARNING):
f()
f()
assert "extra_doc_string" in f.__doc__
with caplog.at_level(logging.WARNING):
f()
f()

assert decorator.__name__ in caplog.text
assert caplog.text.count("1.0.0") == 1
assert caplog.text.count("extra_warning_text") == 1
finally:
warning_dict.clear()
assert decorator.__name__ in caplog.text
assert caplog.text.count("1.0.0") == 1
assert caplog.text.count("extra_warning_text") == 1


def test_is_sql_select_statement():
Expand Down Expand Up @@ -564,18 +556,14 @@ def foo():
pass

caplog.clear()
warning_dict.clear()
try:
with caplog.at_level(logging.WARNING):
foo()
assert extra_doc in foo.__doc__
assert expected_warning_text in caplog.messages
caplog.clear()
with caplog.at_level(logging.WARNING):
foo()
assert expected_warning_text not in caplog.text
finally:
warning_dict.clear()
with caplog.at_level(logging.WARNING):
foo()
assert extra_doc in foo.__doc__
assert expected_warning_text in caplog.messages
caplog.clear()
with caplog.at_level(logging.WARNING):
foo()
assert expected_warning_text not in caplog.text


@pytest.mark.parametrize("function", [result_set_to_iter, result_set_to_rows])
Expand Down
Loading