Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1344848: Suppress dynamic pivot warnings #1434

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@

SUPPORTED_TABLE_TYPES = ["temp", "temporary", "transient"]

# Warning texts passed to warning() by prepare_pivot_arguments when a pivot
# call uses private-preview parameters. They are module-level constants so
# that other modules (log filters, tests) can match on the exact text.
PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING = "Parameter values is Optional or DataFrame is in private preview since v1.15.0. Do not use it in production."
sfc-gh-mvashishtha marked this conversation as resolved.
Show resolved Hide resolved
PIVOT_DEFAULT_ON_NULL_WARNING = "Parameter default_on_null is not None is in private preview since v1.15.0. Do not use it in production."


class TempObjectType(Enum):
TABLE = "TABLE"
Expand Down Expand Up @@ -899,13 +902,13 @@ def prepare_pivot_arguments(
if values is None or isinstance(values, DataFrame):
warning(
df_name,
"Parameter values is Optional or DataFrame is in private preview since v1.15.0. Do not use it in production.",
PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING,
)

if default_on_null is not None:
warning(
df_name,
"Parameter default_on_null is not None is in private preview since v1.15.0. Do not use it in production.",
PIVOT_DEFAULT_ON_NULL_WARNING,
)

if values is not None and not values:
Expand Down
45 changes: 37 additions & 8 deletions src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
ColumnOrSqlExpr,
LiteralType,
)
from snowflake.snowpark._internal.utils import parse_positional_args_to_list
from snowflake.snowpark._internal.utils import (
PIVOT_DEFAULT_ON_NULL_WARNING,
PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING,
parse_positional_args_to_list,
)
from snowflake.snowpark.column import Column
from snowflake.snowpark.dataframe import DataFrame as SnowparkDataFrame
from snowflake.snowpark.dataframe_writer import DataFrameWriter
Expand Down Expand Up @@ -818,16 +822,41 @@ def pivot(
See detailed docstring in Snowpark DataFrame's pivot.
"""
snowpark_dataframe = self.to_projected_snowpark_dataframe()

# Don't warn the user about our internal usage of private preview
sfc-gh-mvashishtha marked this conversation as resolved.
Show resolved Hide resolved
# pivot features. The user should have already been warned that
# Snowpark pandas is in public or private preview. They likely don't
# know or care that we are using Snowpark DataFrame pivot() internally,
# let alone that we are using private preview features of Snowpark
# Python.

class NoPivotWarningFilter(logging.Filter):
    """Logging filter that rejects records carrying a pivot preview warning."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Return True to keep ``record``, False to drop it.

        A record is dropped when its formatted message contains either of
        the two private-preview pivot warning texts.
        """
        rendered = record.getMessage()
        suppressed_texts = (
            PIVOT_DEFAULT_ON_NULL_WARNING,
            PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING,
        )
        return not any(text in rendered for text in suppressed_texts)

# Keep one filter instance so that __exit__ removes exactly the object that
# __enter__ installed. The name deliberately avoids shadowing the builtin
# `filter`.
pivot_warning_filter = NoPivotWarningFilter()


class NoPivotWarningContext:
    """Context manager that installs the pivot-warning filter temporarily.

    On entry the filter is attached to the "snowflake.snowpark" logger; on
    exit it is detached again, whether or not the body raised.
    """

    def __enter__(self):
        logging.getLogger("snowflake.snowpark").addFilter(pivot_warning_filter)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Implicitly returns None (falsy), so an exception raised inside the
        # `with` body still propagates to the caller.
        logging.getLogger("snowflake.snowpark").removeFilter(pivot_warning_filter)

with NoPivotWarningContext():
pivoted_snowpark_dataframe = snowpark_dataframe.pivot(
pivot_col=pivot_col,
values=values,
default_on_null=default_on_null,
)
return OrderedDataFrame(
# the pivot result columns for dynamic pivot are data dependent, a schema call is required
# to know all the quoted identifiers for the pivot result.
DataFrameReference(
snowpark_dataframe.pivot(
pivot_col=pivot_col,
values=values,
default_on_null=default_on_null,
).agg(*agg_exprs)
)
DataFrameReference(pivoted_snowpark_dataframe.agg(*agg_exprs))
)

def unpivot(
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import pytest

from snowflake.snowpark._internal.utils import warning_dict

logging.getLogger("snowflake.connector").setLevel(logging.ERROR)

# TODO: SNOW-1305522: Enable Modin doctests for the below frontend files
Expand Down Expand Up @@ -91,3 +93,11 @@ def cte_optimization_enabled(pytestconfig):

def pytest_sessionstart(session):
os.environ["SNOWPARK_LOCAL_TESTING_INTERNAL_TELEMETRY"] = "1"


# Autouse fixture: applied to every test without being requested explicitly.
# Nothing happens before the test body; the statements after `yield` run as
# teardown once the test has finished.
@pytest.fixture(autouse=True)
def clear_warning_dict():
yield
# clear the warning dict so that warnings from one test don't affect
sfc-gh-mvashishtha marked this conversation as resolved.
Show resolved Hide resolved
# warnings from other tests.
warning_dict.clear()
13 changes: 13 additions & 0 deletions tests/integ/modin/frame/test_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
import pytest

import snowflake.snowpark.modin.plugin # noqa: F401
from snowflake.snowpark._internal.utils import (
PIVOT_DEFAULT_ON_NULL_WARNING,
PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING,
)
from snowflake.snowpark.modin.plugin._internal.unpivot_utils import (
UNPIVOT_NULL_REPLACE_VALUE,
)
Expand Down Expand Up @@ -348,3 +352,12 @@ def test_dataframe_transpose_args_warning_log(caplog, score_test_data):
"Transpose ignores args in Snowpark pandas API."
in [r.msg for r in caplog.records]
)


# Regression test for SNOW-1344848: transposing a Snowpark pandas DataFrame
# must not leak either private-preview pivot warning into the captured logs.
@sql_count_checker(query_count=1, union_count=1)
def test_transpose_does_not_raise_pivot_warning_snow_1344848(caplog):
# Test transpose, which calls snowflake.snowpark.dataframe.pivot() with
# the `values` parameter as None or a Snowpark DataFrame.
pd.DataFrame([1]).T.to_pandas()
# Both warning texts are checked, not only the one transpose would trigger.
assert PIVOT_DEFAULT_ON_NULL_WARNING not in caplog.text
assert PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING not in caplog.text
13 changes: 13 additions & 0 deletions tests/integ/modin/strings/test_get_dummies_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import pytest

import snowflake.snowpark.modin.plugin # noqa: F401
from snowflake.snowpark._internal.utils import (
PIVOT_DEFAULT_ON_NULL_WARNING,
PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING,
)
from tests.integ.modin.sql_counter import sql_count_checker
from tests.integ.modin.utils import assert_snowpark_pandas_equal_to_pandas

Expand Down Expand Up @@ -52,3 +56,12 @@ def test_get_dummies_series_negative(data):
native_pd.get_dummies(pandas_ser),
check_dtype=False,
)


# Regression test for SNOW-1344848: get_dummies on a Snowpark pandas Series
# must not leak either private-preview pivot warning into the captured logs.
@sql_count_checker(query_count=1)
def test_get_dummies_does_not_raise_pivot_warning_snow_1344848(caplog):
# Test get_dummies, which uses the `default_on_null` parameter of
# snowflake.snowpark.dataframe.pivot()
pd.get_dummies(pd.Series(["a"])).to_pandas()
# Both warning texts are checked, not only the one get_dummies would trigger.
assert PIVOT_DEFAULT_ON_NULL_WARNING not in caplog.text
assert PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING not in caplog.text
14 changes: 11 additions & 3 deletions tests/integ/scala/test_dataframe_aggregate_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import pytest

from snowflake.snowpark import GroupingSets, Row
from snowflake.snowpark._internal.utils import TempObjectType
from snowflake.snowpark._internal.utils import (
PIVOT_DEFAULT_ON_NULL_WARNING,
PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING,
TempObjectType,
)
from snowflake.snowpark.column import Column
from snowflake.snowpark.exceptions import (
SnowparkDataframeException,
Expand Down Expand Up @@ -97,7 +101,7 @@ def test_group_by_pivot(session):
).agg([sum(col("amount")), avg(col("amount"))])


def test_group_by_pivot_dynamic_any(session):
def test_group_by_pivot_dynamic_any(session, caplog):
Utils.check_answer(
TestData.monthly_sales_with_team(session)
.group_by("empid")
Expand All @@ -111,6 +115,8 @@ def test_group_by_pivot_dynamic_any(session):
sort=False,
)

assert PIVOT_VALUES_NONE_OR_DATAFRAME_WARNING in caplog.text

Utils.check_answer(
TestData.monthly_sales_with_team(session)
.group_by(["empid", "team"])
Expand Down Expand Up @@ -292,7 +298,7 @@ def test_pivot_dynamic_subquery_with_bad_subquery(session):
assert "Pivot subquery must select single column" in str(ex_info.value)


def test_pivot_default_on_none(session):
def test_pivot_default_on_none(session, caplog):
class MonthlySales(NamedTuple):
empid: int
amount: int
Expand Down Expand Up @@ -326,6 +332,8 @@ class MonthlySales(NamedTuple):
sort=False,
)

assert PIVOT_DEFAULT_ON_NULL_WARNING in caplog.text


@pytest.mark.localtest
def test_rel_grouped_dataframe_agg(session):
Expand Down
4 changes: 1 addition & 3 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from snowflake.snowpark import Column, Row, Window
from snowflake.snowpark._internal.analyzer.analyzer_utils import result_scan_statement
from snowflake.snowpark._internal.analyzer.expression import Attribute, Interval, Star
from snowflake.snowpark._internal.utils import TempObjectType, warning_dict
from snowflake.snowpark._internal.utils import TempObjectType
from snowflake.snowpark.exceptions import (
SnowparkColumnException,
SnowparkCreateDynamicTableException,
Expand Down Expand Up @@ -2719,8 +2719,6 @@ def test_write_temp_table_no_breaking_change(
Utils.assert_table_type(session, table_name, "temp")
finally:
Utils.drop_table(session, table_name)
# clear the warning dict otherwise it will affect the future tests
warning_dict.clear()


@pytest.mark.localtest
Expand Down
3 changes: 0 additions & 3 deletions tests/integ/test_pandas_to_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
TempObjectType,
is_in_stored_procedure,
random_name_for_temp_object,
warning_dict,
)
from snowflake.snowpark.exceptions import SnowparkPandasException
from tests.utils import Utils
Expand Down Expand Up @@ -304,8 +303,6 @@ def test_write_temp_table_no_breaking_change(session, table_type, caplog):
Utils.assert_table_type(session, table_name, "temp")
finally:
Utils.drop_table(session, table_name)
# clear the warning dict otherwise it will affect the future tests
warning_dict.clear()


@pytest.mark.localtest
Expand Down
19 changes: 6 additions & 13 deletions tests/integ/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@

from snowflake.connector.version import VERSION as SNOWFLAKE_CONNECTOR_VERSION
from snowflake.snowpark import Row, Session
from snowflake.snowpark._internal.utils import (
unwrap_stage_location_single_quote,
warning_dict,
)
from snowflake.snowpark._internal.utils import unwrap_stage_location_single_quote
from snowflake.snowpark.exceptions import (
SnowparkInvalidObjectNameException,
SnowparkSQLException,
Expand Down Expand Up @@ -2177,15 +2174,11 @@ def test_deprecate_call_udf_with_list(session, caplog):
return_type=IntegerType(),
input_types=[IntegerType(), IntegerType()],
)
try:
with caplog.at_level(logging.WARNING):
add_udf(["a", "b"])
assert (
"Passing arguments to a UDF with a list or tuple is deprecated"
in caplog.text
)
finally:
warning_dict.clear()
with caplog.at_level(logging.WARNING):
add_udf(["a", "b"])
assert (
"Passing arguments to a UDF with a list or tuple is deprecated" in caplog.text
)


def test_strict_udf(session):
Expand Down
76 changes: 32 additions & 44 deletions tests/unit/scala/test_utils_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
unwrap_stage_location_single_quote,
validate_object_name,
warning,
warning_dict,
zip_file_or_directory_to_stream,
)
from tests.utils import IS_WINDOWS, TestFiles
Expand Down Expand Up @@ -445,43 +444,36 @@ def test_warning(caplog):
def f():
return 1

try:
with caplog.at_level(logging.WARNING):
warning("aaa", "bbb", 2)
warning("aaa", "bbb", 2)
warning("aaa", "bbb", 2)
assert caplog.text.count("bbb") == 2
with caplog.at_level(logging.WARNING):
warning(f.__qualname__, "ccc", 2)
warning(f.__qualname__, "ccc", 2)
warning(f.__qualname__, "ccc", 2)
assert caplog.text.count("ccc") == 2
finally:
warning_dict.clear()
with caplog.at_level(logging.WARNING):
warning("aaa", "bbb", 2)
warning("aaa", "bbb", 2)
warning("aaa", "bbb", 2)
assert caplog.text.count("bbb") == 2
with caplog.at_level(logging.WARNING):
warning(f.__qualname__, "ccc", 2)
warning(f.__qualname__, "ccc", 2)
warning(f.__qualname__, "ccc", 2)
assert caplog.text.count("ccc") == 2


@pytest.mark.parametrize("decorator", [deprecated, experimental])
def test_func_decorator(caplog, decorator):
try:

@decorator(
version="1.0.0",
extra_warning_text="extra_warning_text",
extra_doc_string="extra_doc_string",
)
def f():
return 1
@decorator(
version="1.0.0",
extra_warning_text="extra_warning_text",
extra_doc_string="extra_doc_string",
)
def f():
return 1

assert "extra_doc_string" in f.__doc__
with caplog.at_level(logging.WARNING):
f()
f()
assert "extra_doc_string" in f.__doc__
with caplog.at_level(logging.WARNING):
f()
f()

assert decorator.__name__ in caplog.text
assert caplog.text.count("1.0.0") == 1
assert caplog.text.count("extra_warning_text") == 1
finally:
warning_dict.clear()
assert decorator.__name__ in caplog.text
assert caplog.text.count("1.0.0") == 1
assert caplog.text.count("extra_warning_text") == 1


def test_is_sql_select_statement():
Expand Down Expand Up @@ -564,18 +556,14 @@ def foo():
pass

caplog.clear()
warning_dict.clear()
try:
with caplog.at_level(logging.WARNING):
foo()
assert extra_doc in foo.__doc__
assert expected_warning_text in caplog.messages
caplog.clear()
with caplog.at_level(logging.WARNING):
foo()
assert expected_warning_text not in caplog.text
finally:
warning_dict.clear()
with caplog.at_level(logging.WARNING):
foo()
assert extra_doc in foo.__doc__
assert expected_warning_text in caplog.messages
caplog.clear()
with caplog.at_level(logging.WARNING):
foo()
assert expected_warning_text not in caplog.text


@pytest.mark.parametrize("function", [result_set_to_iter, result_set_to_rows])
Expand Down
Loading