diff --git a/CHANGELOG.md b/CHANGELOG.md index 7968f2c797d..4f632d735f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,11 @@ - Added support for an optional `date_part` argument in function `last_day` - `SessionBuilder.app_name` will set the query_tag after the session is created. +- Added support for the following local testing functions: + - upper + - lower + - length + - initcap ### Improvements diff --git a/src/snowflake/snowpark/mock/_functions.py b/src/snowflake/snowpark/mock/_functions.py index 70a6c4b179d..7a7ad195ca7 100644 --- a/src/snowflake/snowpark/mock/_functions.py +++ b/src/snowflake/snowpark/mock/_functions.py @@ -6,6 +6,7 @@ import datetime import json import math +import string from decimal import Decimal from functools import partial from numbers import Real @@ -843,13 +844,6 @@ def mock_row_number(window: TableEmulator, row_idx: int): return ColumnEmulator(data=[row_idx + 1], sf_type=ColumnType(LongType(), False)) -@patch("upper") -def mock_upper(expr: ColumnEmulator): - res = expr.apply(lambda x: x.upper()) - res.sf_type = ColumnType(StringType(), expr.sf_type.nullable) - return res - - @patch("parse_json") def mock_parse_json(expr: ColumnEmulator): from snowflake.snowpark.mock import CUSTOM_JSON_DECODER @@ -929,3 +923,48 @@ def mock_to_variant(expr: ColumnEmulator): res = expr.copy() res.sf_type = ColumnType(VariantType(), expr.sf_type.nullable) return res + + +@patch("upper") +def mock_upper(expr: ColumnEmulator): + return expr.str.upper() + + +@patch("lower") +def mock_lower(expr: ColumnEmulator): + return expr.str.lower() + + +@patch("length") +def mock_length(expr: ColumnEmulator): + result = expr.str.len() + result.sf_type = ColumnType(LongType(), nullable=expr.sf_type.nullable) + return result + + +# See https://docs.snowflake.com/en/sql-reference/functions/initcap for list of delimiters +DEFAULT_INITCAP_DELIMITERS = set('!?@"^#$&~_,.:;+-*%/|\\[](){}<>' + string.whitespace) + + +def _initcap(value: Optional[str], delimiters: Optional[str]) -> str: + if value is None: + return None + + delims = DEFAULT_INITCAP_DELIMITERS if delimiters is None else set(delimiters) + + result = "" + cap = True + for char in value: + if cap: + result += char.upper() + else: + result += char.lower() + cap = char in delims + return result + + +@patch("initcap") +def mock_initcap(values: ColumnEmulator, delimiters: ColumnEmulator): + result = values.combine(delimiters, _initcap) + result.sf_type = values.sf_type + return result diff --git a/tests/integ/scala/test_function_suite.py b/tests/integ/scala/test_function_suite.py index 5d68f40569d..2b0afb0d63a 100644 --- a/tests/integ/scala/test_function_suite.py +++ b/tests/integ/scala/test_function_suite.py @@ -6,6 +6,7 @@ import json from datetime import date, datetime, time from decimal import Decimal +from functools import partial import pytest import pytz @@ -3256,17 +3257,71 @@ def test_ascii(session, col_B): ) -@pytest.mark.parametrize("col_A", ["A", col("A")]) -def test_initcap_length_lower_upper(session, col_A): - Utils.check_answer( - TestData.string2(session).select( - initcap(col_A), length(col_A), lower(col_A), upper(col_A) +@pytest.mark.localtest +@pytest.mark.parametrize( + "func,expected", + [ + ( + initcap, + [ + Row( + "Foo-Bar;Baz", + "Qwer,Dvor>Azer", + "Lower", + "Upper", + "Chief Variable Officer", + "Lorem Ipsum Dolor Sit Amet", + ) + ], ), - [ - Row("Asdfg", 5, "asdfg", "ASDFG"), - Row("Qqq", 3, "qqq", "QQQ"), - Row("Qw", 2, "qw", "QW"), - ], + ( + partial(initcap, delimiters=lit("-")), + [ + Row( + "Foo-Bar;baz", + "Qwer,dvor>azer", + "Lower", + "Upper", + "Chief variable officer", + "Lorem ipsum dolor sit amet", + ) + ], + ), + (length, [Row(11, 14, 5, 5, 22, 26)]), + ( + lower, + [ + Row( + "foo-bar;baz", + "qwer,dvor>azer", + "lower", + "upper", + "chief variable officer", + "lorem ipsum dolor sit amet", + ) + ], + ), + ( + upper, + [ + Row( + "FOO-BAR;BAZ", + "QWER,DVOR>AZER", + "LOWER", + "UPPER", + "CHIEF VARIABLE OFFICER", + "LOREM IPSUM DOLOR SIT AMET", + ) + ], + ), + ], +) +@pytest.mark.parametrize("use_col", [True, False]) +def test_initcap_length_lower_upper(func, expected, use_col, session): + df = TestData.string8(session) + Utils.check_answer( + df.select(*[func(col(c) if use_col else c) for c in df.columns]), + expected, sort=False, ) diff --git a/tests/utils.py b/tests/utils.py index 1316c9e0785..e6132aad19f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -585,6 +585,22 @@ def string6(cls, session: "Session") -> DataFrame: def string7(cls, session: "Session") -> DataFrame: return session.create_dataframe([["str", 1], [None, 2]], schema=["a", "b"]) + @classmethod + def string8(cls, session: "Session") -> DataFrame: + return session.create_dataframe( + [ + ( + "foo-bar;baz", + "qwer,dvor>azer", + "lower", + "UPPER", + "Chief Variable Officer", + "Lorem ipsum dolor sit amet", + ) + ], + schema=["delim1", "delim2", "lower", "upper", "title", "sentence"], + ) + @classmethod def array1(cls, session: "Session") -> DataFrame: return session.sql(