Skip to content

Commit

Permalink
Clear warning_dict before tests and add sql_count_checker
Browse files Browse the repository at this point in the history
Signed-off-by: sfc-gh-mvashishtha <[email protected]>
  • Loading branch information
sfc-gh-mvashishtha committed Apr 25, 2024
1 parent 718df63 commit 47c269a
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 63 deletions.
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import pytest

from snowflake.snowpark._internal.utils import warning_dict

logging.getLogger("snowflake.connector").setLevel(logging.ERROR)

# TODO: SNOW-1305522: Enable Modin doctests for the below frontend files
Expand Down Expand Up @@ -91,3 +93,11 @@ def cte_optimization_enabled(pytestconfig):

def pytest_sessionstart(session):
os.environ["SNOWPARK_LOCAL_TESTING_INTERNAL_TELEMETRY"] = "1"


@pytest.fixture(autouse=True)
def clear_warning_dict():
yield
# clear the warning dict so that warnings from one test don't affect
# warnings from other tests.
warning_dict.clear()
1 change: 1 addition & 0 deletions tests/integ/modin/frame/test_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ def test_dataframe_transpose_args_warning_log(caplog, score_test_data):
)


@sql_count_checker(query_count=1, union_count=1)
def test_transpose_does_not_raise_pivot_warning_snow_1344848(caplog):
# Test transpose, which calls snowflake.snowpark.dataframe.pivot() with
# the `values` parameter as None or a Snowpark DataFrame.
Expand Down
1 change: 1 addition & 0 deletions tests/integ/modin/strings/test_get_dummies_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def test_get_dummies_series_negative(data):
)


@sql_count_checker(query_count=1)
def test_get_dummies_does_not_raise_pivot_warning_snow_1344848(caplog):
# Test get_dummies, which uses the `default_on_null` parameter of
# snowflake.snowpark.dataframe.pivot()
Expand Down
4 changes: 1 addition & 3 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from snowflake.snowpark import Column, Row, Window
from snowflake.snowpark._internal.analyzer.analyzer_utils import result_scan_statement
from snowflake.snowpark._internal.analyzer.expression import Attribute, Interval, Star
from snowflake.snowpark._internal.utils import TempObjectType, warning_dict
from snowflake.snowpark._internal.utils import TempObjectType
from snowflake.snowpark.exceptions import (
SnowparkColumnException,
SnowparkCreateDynamicTableException,
Expand Down Expand Up @@ -2719,8 +2719,6 @@ def test_write_temp_table_no_breaking_change(
Utils.assert_table_type(session, table_name, "temp")
finally:
Utils.drop_table(session, table_name)
# clear the warning dict otherwise it will affect the future tests
warning_dict.clear()


@pytest.mark.localtest
Expand Down
3 changes: 0 additions & 3 deletions tests/integ/test_pandas_to_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
TempObjectType,
is_in_stored_procedure,
random_name_for_temp_object,
warning_dict,
)
from snowflake.snowpark.exceptions import SnowparkPandasException
from tests.utils import Utils
Expand Down Expand Up @@ -304,8 +303,6 @@ def test_write_temp_table_no_breaking_change(session, table_type, caplog):
Utils.assert_table_type(session, table_name, "temp")
finally:
Utils.drop_table(session, table_name)
# clear the warning dict otherwise it will affect the future tests
warning_dict.clear()


@pytest.mark.localtest
Expand Down
19 changes: 6 additions & 13 deletions tests/integ/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@

from snowflake.connector.version import VERSION as SNOWFLAKE_CONNECTOR_VERSION
from snowflake.snowpark import Row, Session
from snowflake.snowpark._internal.utils import (
unwrap_stage_location_single_quote,
warning_dict,
)
from snowflake.snowpark._internal.utils import unwrap_stage_location_single_quote
from snowflake.snowpark.exceptions import (
SnowparkInvalidObjectNameException,
SnowparkSQLException,
Expand Down Expand Up @@ -2177,15 +2174,11 @@ def test_deprecate_call_udf_with_list(session, caplog):
return_type=IntegerType(),
input_types=[IntegerType(), IntegerType()],
)
try:
with caplog.at_level(logging.WARNING):
add_udf(["a", "b"])
assert (
"Passing arguments to a UDF with a list or tuple is deprecated"
in caplog.text
)
finally:
warning_dict.clear()
with caplog.at_level(logging.WARNING):
add_udf(["a", "b"])
assert (
"Passing arguments to a UDF with a list or tuple is deprecated" in caplog.text
)


def test_strict_udf(session):
Expand Down
76 changes: 32 additions & 44 deletions tests/unit/scala/test_utils_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
unwrap_stage_location_single_quote,
validate_object_name,
warning,
warning_dict,
zip_file_or_directory_to_stream,
)
from tests.utils import IS_WINDOWS, TestFiles
Expand Down Expand Up @@ -445,43 +444,36 @@ def test_warning(caplog):
def f():
return 1

try:
with caplog.at_level(logging.WARNING):
warning("aaa", "bbb", 2)
warning("aaa", "bbb", 2)
warning("aaa", "bbb", 2)
assert caplog.text.count("bbb") == 2
with caplog.at_level(logging.WARNING):
warning(f.__qualname__, "ccc", 2)
warning(f.__qualname__, "ccc", 2)
warning(f.__qualname__, "ccc", 2)
assert caplog.text.count("ccc") == 2
finally:
warning_dict.clear()
with caplog.at_level(logging.WARNING):
warning("aaa", "bbb", 2)
warning("aaa", "bbb", 2)
warning("aaa", "bbb", 2)
assert caplog.text.count("bbb") == 2
with caplog.at_level(logging.WARNING):
warning(f.__qualname__, "ccc", 2)
warning(f.__qualname__, "ccc", 2)
warning(f.__qualname__, "ccc", 2)
assert caplog.text.count("ccc") == 2


@pytest.mark.parametrize("decorator", [deprecated, experimental])
def test_func_decorator(caplog, decorator):
try:

@decorator(
version="1.0.0",
extra_warning_text="extra_warning_text",
extra_doc_string="extra_doc_string",
)
def f():
return 1
@decorator(
version="1.0.0",
extra_warning_text="extra_warning_text",
extra_doc_string="extra_doc_string",
)
def f():
return 1

assert "extra_doc_string" in f.__doc__
with caplog.at_level(logging.WARNING):
f()
f()
assert "extra_doc_string" in f.__doc__
with caplog.at_level(logging.WARNING):
f()
f()

assert decorator.__name__ in caplog.text
assert caplog.text.count("1.0.0") == 1
assert caplog.text.count("extra_warning_text") == 1
finally:
warning_dict.clear()
assert decorator.__name__ in caplog.text
assert caplog.text.count("1.0.0") == 1
assert caplog.text.count("extra_warning_text") == 1


def test_is_sql_select_statement():
Expand Down Expand Up @@ -564,18 +556,14 @@ def foo():
pass

caplog.clear()
warning_dict.clear()
try:
with caplog.at_level(logging.WARNING):
foo()
assert extra_doc in foo.__doc__
assert expected_warning_text in caplog.messages
caplog.clear()
with caplog.at_level(logging.WARNING):
foo()
assert expected_warning_text not in caplog.text
finally:
warning_dict.clear()
with caplog.at_level(logging.WARNING):
foo()
assert extra_doc in foo.__doc__
assert expected_warning_text in caplog.messages
caplog.clear()
with caplog.at_level(logging.WARNING):
foo()
assert expected_warning_text not in caplog.text


@pytest.mark.parametrize("function", [result_set_to_iter, result_set_to_rows])
Expand Down

0 comments on commit 47c269a

Please sign in to comment.