From 94869ebca9a3c194d097e7df9857ff03d9657c85 Mon Sep 17 00:00:00 2001 From: Luis Nino Date: Fri, 27 Sep 2024 17:58:46 -0500 Subject: [PATCH] Improved `snowflake.snowpark.functions` to support double and varchar. --- CHANGELOG.md | 1 + src/snowflake/snowpark/functions.py | 3 +- tests/integ/scala/test_function_suite.py | 97 +++++++++++++++++++++--- 3 files changed, 89 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8be1b26e59..b526efe9ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ - Added support for file writes. This feature is currently in private preview. #### Improvements +- Improved `snowflake.snowpark.functions` to support double and varchar. #### Bug Fixes diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index 27691392aa..7e744fdfbe 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -5381,7 +5381,8 @@ def array_remove(array: ColumnOrName, element: ColumnOrLiteral) -> Column: - `ARRAY `_ for more details on semi-structured arrays. """ a = _to_col_if_str(array, "array_remove") - return builtin("array_remove")(a, element) + e = lit(element).cast("VARIANT") if isinstance(element, str) else element + return builtin("array_remove")(a, e) def array_cat(array1: ColumnOrName, array2: ColumnOrName) -> Column: diff --git a/tests/integ/scala/test_function_suite.py b/tests/integ/scala/test_function_suite.py index bd667f2a8c..bffe71f3c0 100644 --- a/tests/integ/scala/test_function_suite.py +++ b/tests/integ/scala/test_function_suite.py @@ -2829,25 +2829,100 @@ def test_array_append(session): reason="array_remove is not yet supported in local testing mode.", ) def test_array_remove(session): - Utils.check_answer( + actual = session.createDataFrame([([1, 2, 4, 4, 3],), ([],)], ["data"]) + actual = actual.select(array_remove(actual.data, 4)) + expected = session.createDataFrame( [ - Row("[\n 2,\n 3\n]"), - Row("[\n 6,\n 7\n]"), + Row("[\n 1,\n 2,\n 3\n]"), + Row("[]"), ], - TestData.array1(session).select( - array_remove(array_remove(col("arr1"), lit(1)), lit(8)) - ), - sort=False, + ["data"], ) + Utils.check_answer(actual, expected) - Utils.check_answer( + actual = session.createDataFrame([(["a", "b", "c", "a", "a"],), ([],)], ["data"]) + actual = actual.select(array_remove(actual.data, "a")) + expected = session.createDataFrame( + [ + Row('[\n "b",\n "c"\n]'), + Row("[]"), + ], + ["data"], + ) + Utils.check_answer(actual, expected) + + actual = session.createDataFrame( + [(["apple", "banana", "apple", "orange"],), ([],)], ["data"] + ) + actual = actual.select(array_remove(actual.data, "apple")) + expected = session.createDataFrame( + [ + Row('[\n "banana",\n "orange"\n]'), + Row("[]"), + ], + ["data"], + ) + Utils.check_answer(actual, expected) + + actual = session.createDataFrame([([1, "2", 3.1, 1, 3],), ([],)], ["data"]) + actual = actual.select(array_remove(actual.data, 1)) + expected = session.createDataFrame( + [ + Row('[\n "2",\n 3.1,\n 3\n]'), + Row("[]"), + ], + ["data"], + ) + Utils.check_answer(actual, expected) + + actual = session.createDataFrame([(["@", ";", "3.1", 1, 5 / 3],), ([],)], ["data"]) + actual = actual.select(array_remove(actual.data, 1)) + expected = session.createDataFrame( + [ + Row('[\n "@",\n ";",\n "3.1",\n 1.6666666666666667\n]'), + Row("[]"), + ], + ["data"], + ) + Utils.check_answer(actual, expected) + + actual = session.createDataFrame([([-1, -2, -4, -4, -3],), ([],)], ["data"]) + actual = actual.select(array_remove(actual.data, 1)) + expected = session.createDataFrame( + [ + Row("[\n -1,\n -2,\n -4,\n -4,\n -3\n]"), + Row("[]"), + ], + ["data"], + ) + Utils.check_answer(actual, expected) + + actual = session.createDataFrame([([4.4, 5.5, 1.1],), ([],)], ["data"]) + actual = actual.select(array_remove(actual.data, 5.5)) + expected = session.createDataFrame( + [ + Row("[\n 4.4,\n 1.1\n]"), + Row("[]"), + ], + ["data"], + ) + Utils.check_answer(actual, expected) + + actual = TestData.array1(session).select( + array_remove(array_remove(col("arr1"), lit(1)), lit(8)) + ) + + expected = session.createDataFrame( [ Row("[\n 2,\n 3\n]"), Row("[\n 6,\n 7\n]"), ], - TestData.array1(session).select( - array_remove(array_remove(col("arr1"), 1), lit(8)) - ), + ["data"], + ) + + Utils.check_answer( + actual, + expected, sort=False, )