From 3f3f30528daabe0cb4031a229fc81457bfb8b5dd Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Wed, 4 Oct 2023 18:20:34 -0400 Subject: [PATCH 1/5] [Local Testing] SNOW-904988 Support Functions.substr and Column.substr (#1070) --- src/snowflake/snowpark/mock/analyzer.py | 6 +----- src/snowflake/snowpark/mock/functions.py | 7 +++++++ tests/integ/scala/test_function_suite.py | 1 + tests/integ/test_column.py | 5 ----- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/snowflake/snowpark/mock/analyzer.py b/src/snowflake/snowpark/mock/analyzer.py index 288b80cccd4..a7b73abe8f3 100644 --- a/src/snowflake/snowpark/mock/analyzer.py +++ b/src/snowflake/snowpark/mock/analyzer.py @@ -479,11 +479,7 @@ def unary_expression_extractor( raise NotImplementedError( f"[Local Testing] Expression {type(expr.child).__name__} is not implemented." ) - expr_str = ( - expr.name - if expr.name - else self.analyze(expr.child, expr_to_alias, parse_local_name) - ) + expr_str = self.analyze(expr.child, expr_to_alias, parse_local_name) if parse_local_name: expr_str = expr_str.upper() return expr_str diff --git a/src/snowflake/snowpark/mock/functions.py b/src/snowflake/snowpark/mock/functions.py index f195dac73d5..f0600a7189e 100644 --- a/src/snowflake/snowpark/mock/functions.py +++ b/src/snowflake/snowpark/mock/functions.py @@ -406,3 +406,10 @@ def mock_iff(condition: ColumnEmulator, expr1: ColumnEmulator, expr2: ColumnEmul res.where(condition, other=expr2, inplace=True) res.where([not x for x in condition], other=expr1, inplace=True) return res + + +@patch("substring") +def mock_substring( + base_expr: ColumnEmulator, start_expr: ColumnEmulator, length_expr: ColumnEmulator +): + return base_expr.str.slice(start=start_expr - 1, stop=start_expr - 1 + length_expr) diff --git a/tests/integ/scala/test_function_suite.py b/tests/integ/scala/test_function_suite.py index 85b2be7c9bd..5322cbc6ebe 100644 --- a/tests/integ/scala/test_function_suite.py +++ b/tests/integ/scala/test_function_suite.py @@ -540,6 +540,7 @@ def test_builtin_function(session): ) +@pytest.mark.localtest def test_sub_string(session): Utils.check_answer( TestData.string1(session).select(substring(col("A"), lit(2), lit(4))), diff --git a/tests/integ/test_column.py b/tests/integ/test_column.py index 9023fead007..73c6efcfd0d 100644 --- a/tests/integ/test_column.py +++ b/tests/integ/test_column.py @@ -147,11 +147,6 @@ def test_endswith(session): ) -@pytest.mark.xfail( - condition="config.getvalue('local_testing_mode')", - raises=NotImplementedError, - strict=True, -) def test_substring(session): Utils.check_answer( TestData.string4(session).select( From 8837989b9392912573dd5598d8ffd153d544b547 Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Thu, 5 Oct 2023 10:59:40 -0400 Subject: [PATCH 2/5] [Local Testing] SNOW-904865 Support DataFrame.toLocalIterator --- src/snowflake/snowpark/dataframe.py | 7 ------- src/snowflake/snowpark/mock/connection.py | 3 +++ tests/integ/scala/test_large_dataframe_suite.py | 5 ----- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 67c1e604d43..a6f4b9c41f4 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -689,13 +689,6 @@ def to_local_iterator( When it is ``False``, this function executes the underlying queries of the dataframe asynchronously and returns an :class:`AsyncJob`. """ - from snowflake.snowpark.mock.connection import MockServerConnection - - if isinstance(self._session._conn, MockServerConnection): - raise NotImplementedError( - "[Local Testing] `DataFrame.to_local_iterator` is currently not supported." - ) - return self._session._conn.execute( self._plan, to_iter=True, diff --git a/src/snowflake/snowpark/mock/connection.py b/src/snowflake/snowpark/mock/connection.py index d3804e8983b..cc36ae5a220 100644 --- a/src/snowflake/snowpark/mock/connection.py +++ b/src/snowflake/snowpark/mock/connection.py @@ -363,6 +363,9 @@ def execute( rows.append(row) elif isinstance(res, list): rows = res + + if to_iter: + return iter(rows) return rows @SnowflakePlan.Decorator.wrap_exception diff --git a/tests/integ/scala/test_large_dataframe_suite.py b/tests/integ/scala/test_large_dataframe_suite.py index 8f0d042d554..a943d60c02a 100644 --- a/tests/integ/scala/test_large_dataframe_suite.py +++ b/tests/integ/scala/test_large_dataframe_suite.py @@ -33,11 +33,6 @@ ) -@pytest.mark.xfail( - condition="config.getvalue('local_testing_mode')", - raises=NotImplementedError, - strict=True, -) @pytest.mark.xfail(reason="SNOW-754118 flaky test", strict=False) def test_to_local_iterator_should_not_load_all_data_at_once(session): df = ( From f994e0a81fb18a92220463f9b8ee66949ef1ca27 Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Fri, 6 Oct 2023 13:32:29 -0400 Subject: [PATCH 3/5] [Local Testing] SNOW-923345 Support functions.coalesce --- src/snowflake/snowpark/mock/functions.py | 17 +++++++++++++++++ tests/integ/scala/test_function_suite.py | 1 + tests/integ/test_function.py | 1 + 3 files changed, 19 insertions(+) diff --git a/src/snowflake/snowpark/mock/functions.py b/src/snowflake/snowpark/mock/functions.py index f0600a7189e..85903359c4d 100644 --- a/src/snowflake/snowpark/mock/functions.py +++ b/src/snowflake/snowpark/mock/functions.py @@ -408,8 +408,25 @@ def mock_iff(condition: ColumnEmulator, expr1: ColumnEmulator, expr2: ColumnEmul return res +@patch("coalesce") +def mock_coalesce(*exprs): + import pandas + + if len(exprs) < 2: + raise SnowparkSQLException( + f"not enough arguments for function [COALESCE], got {len(exprs)}, expected at least two" + ) + res = pandas.Series( + exprs[0] + ) # workaround because sf_type is not inherited properly + for expr in exprs: + res = res.combine_first(expr) + return ColumnEmulator(data=res, sf_type=exprs[0].sf_type, dtype=object) + + @patch("substring") def mock_substring( base_expr: ColumnEmulator, start_expr: ColumnEmulator, length_expr: ColumnEmulator ): return base_expr.str.slice(start=start_expr - 1, stop=start_expr - 1 + length_expr) + diff --git a/tests/integ/scala/test_function_suite.py b/tests/integ/scala/test_function_suite.py index 5322cbc6ebe..aa97e9acca2 100644 --- a/tests/integ/scala/test_function_suite.py +++ b/tests/integ/scala/test_function_suite.py @@ -398,6 +398,7 @@ def test_variance(session): ) +@pytest.mark.localtest def test_coalesce(session): Utils.check_answer( TestData.null_data2(session).select(coalesce(col("A"), col("B"), col("C"))), diff --git a/tests/integ/test_function.py b/tests/integ/test_function.py index f07656dd83f..7279bacb864 100644 --- a/tests/integ/test_function.py +++ b/tests/integ/test_function.py @@ -1073,6 +1073,7 @@ def test_to_binary(session): assert res == [Row(None), Row(None), Row(None), Row(None)] +@pytest.mark.localtest def test_coalesce(session): # Taken from FunctionSuite.scala Utils.check_answer( From 9695ab881a6321035f6479f276be0ea63c63d86a Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Fri, 6 Oct 2023 13:33:18 -0400 Subject: [PATCH 4/5] [Local Testing] SNOW-904987 Support Column.startswith and Column.endswith (#1071) --- src/snowflake/snowpark/mock/functions.py | 16 ++++++++++++++++ tests/integ/test_column.py | 12 ++---------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/snowflake/snowpark/mock/functions.py b/src/snowflake/snowpark/mock/functions.py index 85903359c4d..3555787a9c3 100644 --- a/src/snowflake/snowpark/mock/functions.py +++ b/src/snowflake/snowpark/mock/functions.py @@ -20,6 +20,7 @@ DecimalType, DoubleType, LongType, + StringType, TimestampType, TimeType, _NumericType, @@ -430,3 +431,18 @@ def mock_substring( ): return base_expr.str.slice(start=start_expr - 1, stop=start_expr - 1 + length_expr) + +@patch("startswith") +def mock_startswith(expr1: ColumnEmulator, expr2: ColumnEmulator): + res = expr1.str.startswith(expr2) + res.sf_type = ColumnType(StringType(), expr1.sf_type.nullable) + return res + + +@patch("endswith") +def mock_endswith(expr1: ColumnEmulator, expr2: ColumnEmulator): + res = expr1.str.endswith(expr2) + res.sf_type = ColumnType(StringType(), expr1.sf_type.nullable) + return res + + \ No newline at end of file diff --git a/tests/integ/test_column.py b/tests/integ/test_column.py index 73c6efcfd0d..fab7f7a2af7 100644 --- a/tests/integ/test_column.py +++ b/tests/integ/test_column.py @@ -121,11 +121,7 @@ def test_cast_array_type(session): assert json.loads(result[0][0]) == [1, 2, 3] -@pytest.mark.xfail( - condition="config.getvalue('local_testing_mode')", - raises=NotImplementedError, - strict=True, -) +@pytest.mark.localtest def test_startswith(session): Utils.check_answer( TestData.string4(session).select(col("a").startswith(lit("a"))), @@ -134,11 +130,7 @@ def test_startswith(session): ) -@pytest.mark.xfail( - condition="config.getvalue('local_testing_mode')", - raises=NotImplementedError, - strict=True, -) +@pytest.mark.localtest def test_endswith(session): Utils.check_answer( TestData.string4(session).select(col("a").endswith(lit("ana"))), From 61e0497b11af7434cfbfb53ff430654fbd53c39c Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Fri, 6 Oct 2023 14:25:49 -0400 Subject: [PATCH 5/5] [Local Testing] SNOW-904261 Support DataFrame.except_ --- .../snowpark/_internal/type_utils.py | 21 ++++++---- src/snowflake/snowpark/mock/plan.py | 42 +++++++++++++------ .../test_dataframe_set_operations_suite.py | 16 ++----- 3 files changed, 45 insertions(+), 34 deletions(-) diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index 43e7e3dc8ce..aafdfd9988b 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -96,7 +96,7 @@ def convert_sf_to_sp_type( return StringType(internal_size) elif internal_size == 0: return StringType() - raise ValueError(f"Negative value is not a valid input for StringType") + raise ValueError("Negative value is not a valid input for StringType") if column_type_name == "TIME": return TimeType() if column_type_name in ( @@ -327,7 +327,7 @@ def infer_schema( fields = [] for k, v in items: try: - fields.append(StructField(k, infer_type(v), True)) + fields.append(StructField(k, infer_type(v), v is None)) except TypeError as e: raise TypeError(f"Unable to infer the type of the field {k}.") from e return StructType(fields) @@ -347,22 +347,26 @@ def merge_type(a: DataType, b: DataType, name: Optional[str] = None) -> DataType # same type if isinstance(a, StructType): - nfs = {f.name: f.datatype for f in b.fields} + name_to_datatype_b = {f.name: f.datatype for f in b.fields} + name_to_nullable_b = {f.name: f.nullable for f in b.fields} fields = [ StructField( f.name, merge_type( f.datatype, - nfs.get(f.name, NullType()), + name_to_datatype_b.get(f.name, NullType()), name=f"field {f.name} in {name}" if name else f"field {f.name}", ), + f.nullable or name_to_nullable_b.get(f.name, True), ) for f in a.fields ] names = {f.name for f in fields} - for n in nfs: + for n in name_to_datatype_b: if n not in names: - fields.append(StructField(n, nfs[n])) + fields.append( + StructField(n, name_to_datatype_b[n], name_to_nullable_b[n]) + ) return StructType(fields) elif isinstance(a, ArrayType): @@ -404,6 +408,7 @@ def python_type_to_snow_type(tp: Union[str, Type]) -> Tuple[DataType, bool]: Returns a Snowpark type and whether it's nullable. """ from snowflake.snowpark.dataframe import DataFrame + # convert a type string to a type object if isinstance(tp, str): tp = python_type_str_to_object(tp) @@ -617,9 +622,7 @@ def get_data_type_string_object_mappings( ) # support type string format like " decimal ( 2 , 1 ) " -STRING_RE = re.compile( - r"^\s*(varchar|string|text)\s*\(\s*(\d*)\s*\)\s*$" -) +STRING_RE = re.compile(r"^\s*(varchar|string|text)\s*\(\s*(\d*)\s*\)\s*$") # support type string format like " string ( 23 ) " diff --git a/src/snowflake/snowpark/mock/plan.py b/src/snowflake/snowpark/mock/plan.py index d690b5e2745..ef3701b7ade 100644 --- a/src/snowflake/snowpark/mock/plan.py +++ b/src/snowflake/snowpark/mock/plan.py @@ -21,6 +21,7 @@ import snowflake.snowpark.mock.file_operation as mock_file_operation from snowflake.snowpark import Column, Row from snowflake.snowpark._internal.analyzer.analyzer_utils import ( + EXCEPT, UNION, UNION_ALL, quote_name, @@ -125,10 +126,9 @@ def __init__( self.expr_to_alias = expr_to_alias if expr_to_alias is not None else {} self.api_calls = [] - @cached_property + # @cached_property + @property def attributes(self) -> List[Attribute]: - # output = analyze_attributes(self.schema_query, self.session) - # self.schema_query = schema_value_statement(output) output = describe(self) return output @@ -283,16 +283,16 @@ def execute_mock_plan( for i in range(1, len(source_plan.set_operands)): operand = source_plan.set_operands[i] operator = operand.operator - if operator in (UNION, UNION_ALL): - cur_df = execute_mock_plan( - MockExecutionPlan(operand.selectable, source_plan.analyzer.session), - expr_to_alias, + cur_df = execute_mock_plan( + MockExecutionPlan(operand.selectable, source_plan.analyzer.session), + expr_to_alias, + ) + if len(res_df.columns) != len(cur_df.columns): + raise SnowparkSQLException( + f"SQL compilation error: invalid number of result columns for set operator input branches, expected {len(res_df.columns)}, got {len(cur_df.columns)} in branch {i + 1}" ) - if len(res_df.columns) != len(cur_df.columns): - raise SnowparkSQLException( - f"SQL compilation error: invalid number of result columns for set operator input branches, expected {len(res_df.columns)}, got {len(cur_df.columns)} in branch {i + 1}" - ) - cur_df.columns = res_df.columns + cur_df.columns = res_df.columns + if operator in (UNION, UNION_ALL): res_df = pd.concat([res_df, cur_df], ignore_index=True) res_df = ( res_df.drop_duplicates().reset_index(drop=True) @@ -300,6 +300,20 @@ def execute_mock_plan( else res_df ) res_df.sf_types = cur_df.sf_types + elif operator == EXCEPT: + # Dedup all none rows + if res_df.isnull().all(axis=1).where(lambda x: x).count() > 1: + res_df = res_df.drop(index=res_df.isnull().all(axis=1).index[1:]) + # If there are all none rows in cur_df, drop all none rows in res_df + if ( + cur_df.isnull().all(axis=1).any() + and res_df.isnull().all(axis=1).any() + ): + res_df = res_df[~res_df.isnull().all(axis=1).values] + # Compute NOT IS IN and drop duplicates + res_df = res_df[ + ~(res_df.isin(cur_df.values.ravel()).all(axis=1)).values + ].drop_duplicates() else: raise NotImplementedError( f"[Local Testing] SetStatement operator {operator} is currently not implemented." @@ -666,7 +680,9 @@ def describe(plan: MockExecutionPlan) -> List[Attribute]: Attribute( result[c].name, data_type, - bool(any([bool(item is None) for item in result[c]])), + result[ + c + ].sf_type.nullable, # bool(any([bool(item is None) for item in result[c]])) ) ) return ret diff --git a/tests/integ/scala/test_dataframe_set_operations_suite.py b/tests/integ/scala/test_dataframe_set_operations_suite.py index ce23b07c98d..3a3cf07080a 100644 --- a/tests/integ/scala/test_dataframe_set_operations_suite.py +++ b/tests/integ/scala/test_dataframe_set_operations_suite.py @@ -84,11 +84,7 @@ def check(new_col: Column, cfilter: Column, result: List[Row]): check(lit(2).cast(IntegerType()), col("c") != 2, list()) -@pytest.mark.xfail( - condition="config.getvalue('local_testing_mode')", - raises=NotImplementedError, - strict=True, -) +@pytest.mark.localtest def test_except(session): lower_case_data = TestData.lower_case_data(session) upper_case_data = TestData.upper_case_data(session) @@ -129,11 +125,7 @@ def test_except(session): Utils.check_answer(all_nulls.filter(lit(0) == 1).except_(all_nulls), []) -@pytest.mark.xfail( - condition="config.getvalue('local_testing_mode')", - raises=NotImplementedError, - strict=True, -) +@pytest.mark.localtest def test_except_between_two_projects_without_references_used_in_filter(session): df = session.create_dataframe(((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5))).to_df( "a", "b", "c" @@ -411,13 +403,13 @@ def test_project_should_not_be_pushed_down_through_intersect_or_except(session): assert df1.except_(df2).count() == 70 -# TODO: Fix this, `MockExecutionPlan.attributes` are ignoring nullability for now def test_except_nullability(session): - non_nullable_ints = session.create_dataframe(((11,), (3,))).to_df("a") + non_nullable_ints = session.create_dataframe(((11,), (3,))).to_df(["a"]) for attribute in non_nullable_ints.schema._to_attributes(): assert not attribute.nullable null_ints = TestData.null_ints(session) + df1 = non_nullable_ints.except_(null_ints) Utils.check_answer(df1, Row(11)) for attribute in df1.schema._to_attributes():