diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py
index 67c1e604d43..a6f4b9c41f4 100644
--- a/src/snowflake/snowpark/dataframe.py
+++ b/src/snowflake/snowpark/dataframe.py
@@ -689,13 +689,6 @@ def to_local_iterator(
                 When it is ``False``, this function executes the underlying queries of the dataframe
                 asynchronously and returns an :class:`AsyncJob`.
         """
-        from snowflake.snowpark.mock.connection import MockServerConnection
-
-        if isinstance(self._session._conn, MockServerConnection):
-            raise NotImplementedError(
-                "[Local Testing] `DataFrame.to_local_iterator` is currently not supported."
-            )
-
         return self._session._conn.execute(
             self._plan,
             to_iter=True,
diff --git a/src/snowflake/snowpark/mock/analyzer.py b/src/snowflake/snowpark/mock/analyzer.py
index 288b80cccd4..a7b73abe8f3 100644
--- a/src/snowflake/snowpark/mock/analyzer.py
+++ b/src/snowflake/snowpark/mock/analyzer.py
@@ -479,11 +479,7 @@ def unary_expression_extractor(
                 raise NotImplementedError(
                     f"[Local Testing] Expression {type(expr.child).__name__} is not implemented."
                 )
-        expr_str = (
-            expr.name
-            if expr.name
-            else self.analyze(expr.child, expr_to_alias, parse_local_name)
-        )
+        expr_str = self.analyze(expr.child, expr_to_alias, parse_local_name)
         if parse_local_name:
             expr_str = expr_str.upper()
         return expr_str
diff --git a/src/snowflake/snowpark/mock/connection.py b/src/snowflake/snowpark/mock/connection.py
index d3804e8983b..cc36ae5a220 100644
--- a/src/snowflake/snowpark/mock/connection.py
+++ b/src/snowflake/snowpark/mock/connection.py
@@ -363,6 +363,9 @@ def execute(
                 rows.append(row)
         elif isinstance(res, list):
             rows = res
+
+        if to_iter:
+            return iter(rows)
         return rows
 
     @SnowflakePlan.Decorator.wrap_exception
diff --git a/src/snowflake/snowpark/mock/functions.py b/src/snowflake/snowpark/mock/functions.py
index f195dac73d5..3555787a9c3 100644
--- a/src/snowflake/snowpark/mock/functions.py
+++ b/src/snowflake/snowpark/mock/functions.py
@@ -406,3 +406,59 @@ def mock_iff(condition: ColumnEmulator, expr1: ColumnEmulator, expr2: ColumnEmul
     res.where(condition, other=expr2, inplace=True)
     res.where([not x for x in condition], other=expr1, inplace=True)
     return res
+
+
+@patch("coalesce")
+def mock_coalesce(*exprs):
+    """Emulate ``COALESCE``: take, per row, the first non-null value across ``exprs``.
+
+    Raises:
+        SnowparkSQLException: If fewer than two expressions are supplied.
+    """
+    import pandas
+
+    if len(exprs) < 2:
+        raise SnowparkSQLException(
+            f"not enough arguments for function [COALESCE], got {len(exprs)}, expected at least two"
+        )
+    res = pandas.Series(
+        exprs[0]
+    )  # workaround because sf_type is not inherited properly
+    # exprs[0] already seeds res, so only the later expressions can fill its nulls.
+    for expr in exprs[1:]:
+        res = res.combine_first(expr)
+    return ColumnEmulator(data=res, sf_type=exprs[0].sf_type, dtype=object)
+
+
+@patch("substring")
+def mock_substring(
+    base_expr: ColumnEmulator, start_expr: ColumnEmulator, length_expr: ColumnEmulator
+):
+    """Emulate ``SUBSTRING`` by slicing ``base_expr`` from a 1-based ``start_expr``.
+
+    NOTE(review): positions of 0 or below are not special-cased here; confirm
+    against Snowflake's SUBSTR semantics before relying on them.
+    """
+    return base_expr.str.slice(start=start_expr - 1, stop=start_expr - 1 + length_expr)
+
+
+@patch("startswith")
+def mock_startswith(expr1: ColumnEmulator, expr2: ColumnEmulator):
+    """Emulate ``STARTSWITH``: test whether each ``expr1`` value begins with ``expr2``."""
+    from snowflake.snowpark.types import BooleanType
+
+    res = expr1.str.startswith(expr2)
+    # The result is a true/false flag per row, so tag it as boolean, not string.
+    res.sf_type = ColumnType(BooleanType(), expr1.sf_type.nullable)
+    return res
+
+
+@patch("endswith")
+def mock_endswith(expr1: ColumnEmulator, expr2: ColumnEmulator):
+    """Emulate ``ENDSWITH``: test whether each ``expr1`` value ends with ``expr2``."""
+    from snowflake.snowpark.types import BooleanType
+
+    res = expr1.str.endswith(expr2)
+    # The result is a true/false flag per row, so tag it as boolean, not string.
+    res.sf_type = ColumnType(BooleanType(), expr1.sf_type.nullable)
+    return res
diff --git a/tests/integ/scala/test_function_suite.py b/tests/integ/scala/test_function_suite.py
index 85b2be7c9bd..aa97e9acca2 100644
--- a/tests/integ/scala/test_function_suite.py
+++ b/tests/integ/scala/test_function_suite.py
@@ -398,6 +398,7 @@ def test_variance(session):
     )
 
 
+@pytest.mark.localtest
 def test_coalesce(session):
     Utils.check_answer(
         TestData.null_data2(session).select(coalesce(col("A"), col("B"), col("C"))),
@@ -540,6 +541,7 @@ def test_builtin_function(session):
     )
 
 
+@pytest.mark.localtest
 def test_sub_string(session):
     Utils.check_answer(
         TestData.string1(session).select(substring(col("A"), lit(2), lit(4))),
diff --git a/tests/integ/scala/test_large_dataframe_suite.py b/tests/integ/scala/test_large_dataframe_suite.py
index 8f0d042d554..a943d60c02a 100644
--- a/tests/integ/scala/test_large_dataframe_suite.py
+++ b/tests/integ/scala/test_large_dataframe_suite.py
@@ -33,11 +33,6 @@
 )
 
 
-@pytest.mark.xfail(
-    condition="config.getvalue('local_testing_mode')",
-    raises=NotImplementedError,
-    strict=True,
-)
 @pytest.mark.xfail(reason="SNOW-754118 flaky test", strict=False)
 def test_to_local_iterator_should_not_load_all_data_at_once(session):
     df = (
diff --git a/tests/integ/test_column.py b/tests/integ/test_column.py
index 9023fead007..fab7f7a2af7 100644
--- a/tests/integ/test_column.py
+++ b/tests/integ/test_column.py
@@ -121,11 +121,7 @@ def test_cast_array_type(session):
     assert json.loads(result[0][0]) == [1, 2, 3]
 
 
-@pytest.mark.xfail(
-    condition="config.getvalue('local_testing_mode')",
-    raises=NotImplementedError,
-    strict=True,
-)
+@pytest.mark.localtest
 def test_startswith(session):
     Utils.check_answer(
         TestData.string4(session).select(col("a").startswith(lit("a"))),
@@ -134,11 +130,7 @@ def test_startswith(session):
     )
 
 
-@pytest.mark.xfail(
-    condition="config.getvalue('local_testing_mode')",
-    raises=NotImplementedError,
-    strict=True,
-)
+@pytest.mark.localtest
 def test_endswith(session):
     Utils.check_answer(
         TestData.string4(session).select(col("a").endswith(lit("ana"))),
@@ -147,11 +139,6 @@ def test_endswith(session):
     )
 
 
-@pytest.mark.xfail(
-    condition="config.getvalue('local_testing_mode')",
-    raises=NotImplementedError,
-    strict=True,
-)
 def test_substring(session):
     Utils.check_answer(
         TestData.string4(session).select(
diff --git a/tests/integ/test_function.py b/tests/integ/test_function.py
index f07656dd83f..7279bacb864 100644
--- a/tests/integ/test_function.py
+++ b/tests/integ/test_function.py
@@ -1073,6 +1073,7 @@ def test_to_binary(session):
     assert res == [Row(None), Row(None), Row(None), Row(None)]
 
 
+@pytest.mark.localtest
 def test_coalesce(session):
     # Taken from FunctionSuite.scala
     Utils.check_answer(