Skip to content

Commit

Permalink
Merge branch 'dev/local-testing' into local/support-intersect
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-stan authored Oct 6, 2023
2 parents d3ef22a + 61e0497 commit 4e41044
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 32 deletions.
7 changes: 0 additions & 7 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,13 +689,6 @@ def to_local_iterator(
When it is ``False``, this function executes the underlying queries of the dataframe
asynchronously and returns an :class:`AsyncJob`.
"""
from snowflake.snowpark.mock.connection import MockServerConnection

if isinstance(self._session._conn, MockServerConnection):
raise NotImplementedError(
"[Local Testing] `DataFrame.to_local_iterator` is currently not supported."
)

return self._session._conn.execute(
self._plan,
to_iter=True,
Expand Down
6 changes: 1 addition & 5 deletions src/snowflake/snowpark/mock/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,11 +479,7 @@ def unary_expression_extractor(
raise NotImplementedError(
f"[Local Testing] Expression {type(expr.child).__name__} is not implemented."
)
expr_str = (
expr.name
if expr.name
else self.analyze(expr.child, expr_to_alias, parse_local_name)
)
expr_str = self.analyze(expr.child, expr_to_alias, parse_local_name)
if parse_local_name:
expr_str = expr_str.upper()
return expr_str
Expand Down
3 changes: 3 additions & 0 deletions src/snowflake/snowpark/mock/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ def execute(
rows.append(row)
elif isinstance(res, list):
rows = res

if to_iter:
return iter(rows)
return rows

@SnowflakePlan.Decorator.wrap_exception
Expand Down
40 changes: 40 additions & 0 deletions src/snowflake/snowpark/mock/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DecimalType,
DoubleType,
LongType,
StringType,
TimestampType,
TimeType,
_NumericType,
Expand Down Expand Up @@ -406,3 +407,42 @@ def mock_iff(condition: ColumnEmulator, expr1: ColumnEmulator, expr2: ColumnEmul
res.where(condition, other=expr2, inplace=True)
res.where([not x for x in condition], other=expr1, inplace=True)
return res


@patch("coalesce")
def mock_coalesce(*exprs):
import pandas

if len(exprs) < 2:
raise SnowparkSQLException(
f"not enough arguments for function [COALESCE], got {len(exprs)}, expected at least two"
)
res = pandas.Series(
exprs[0]
) # workaround because sf_type is not inherited properly
for expr in exprs:
res = res.combine_first(expr)
return ColumnEmulator(data=res, sf_type=exprs[0].sf_type, dtype=object)


@patch("substring")
def mock_substring(
base_expr: ColumnEmulator, start_expr: ColumnEmulator, length_expr: ColumnEmulator
):
return base_expr.str.slice(start=start_expr - 1, stop=start_expr - 1 + length_expr)


@patch("startswith")
def mock_startswith(expr1: ColumnEmulator, expr2: ColumnEmulator):
res = expr1.str.startswith(expr2)
res.sf_type = ColumnType(StringType(), expr1.sf_type.nullable)
return res


@patch("endswith")
def mock_endswith(expr1: ColumnEmulator, expr2: ColumnEmulator):
res = expr1.str.endswith(expr2)
res.sf_type = ColumnType(StringType(), expr1.sf_type.nullable)
return res


2 changes: 2 additions & 0 deletions tests/integ/scala/test_function_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def test_variance(session):
)


@pytest.mark.localtest
def test_coalesce(session):
Utils.check_answer(
TestData.null_data2(session).select(coalesce(col("A"), col("B"), col("C"))),
Expand Down Expand Up @@ -540,6 +541,7 @@ def test_builtin_function(session):
)


@pytest.mark.localtest
def test_sub_string(session):
Utils.check_answer(
TestData.string1(session).select(substring(col("A"), lit(2), lit(4))),
Expand Down
5 changes: 0 additions & 5 deletions tests/integ/scala/test_large_dataframe_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@
)


@pytest.mark.xfail(
condition="config.getvalue('local_testing_mode')",
raises=NotImplementedError,
strict=True,
)
@pytest.mark.xfail(reason="SNOW-754118 flaky test", strict=False)
def test_to_local_iterator_should_not_load_all_data_at_once(session):
df = (
Expand Down
17 changes: 2 additions & 15 deletions tests/integ/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,7 @@ def test_cast_array_type(session):
assert json.loads(result[0][0]) == [1, 2, 3]


@pytest.mark.xfail(
condition="config.getvalue('local_testing_mode')",
raises=NotImplementedError,
strict=True,
)
@pytest.mark.localtest
def test_startswith(session):
Utils.check_answer(
TestData.string4(session).select(col("a").startswith(lit("a"))),
Expand All @@ -134,11 +130,7 @@ def test_startswith(session):
)


@pytest.mark.xfail(
condition="config.getvalue('local_testing_mode')",
raises=NotImplementedError,
strict=True,
)
@pytest.mark.localtest
def test_endswith(session):
Utils.check_answer(
TestData.string4(session).select(col("a").endswith(lit("ana"))),
Expand All @@ -147,11 +139,6 @@ def test_endswith(session):
)


@pytest.mark.xfail(
condition="config.getvalue('local_testing_mode')",
raises=NotImplementedError,
strict=True,
)
def test_substring(session):
Utils.check_answer(
TestData.string4(session).select(
Expand Down
1 change: 1 addition & 0 deletions tests/integ/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,7 @@ def test_to_binary(session):
assert res == [Row(None), Row(None), Row(None), Row(None)]


@pytest.mark.localtest
def test_coalesce(session):
# Taken from FunctionSuite.scala
Utils.check_answer(
Expand Down

0 comments on commit 4e41044

Please sign in to comment.