Skip to content

Commit

Permalink
SNOW-1043119: Add upper/lower/length/initcap support to local testing (
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jrose authored Feb 20, 2024
1 parent 2d88994 commit d0612c9
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 17 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

- Added support for an optional `date_part` argument in function `last_day`
- `SessionBuilder.app_name` will set the query_tag after the session is created.
- Added support for the following local testing functions:
- upper
- lower
- length
- initcap

### Improvements

Expand Down
53 changes: 46 additions & 7 deletions src/snowflake/snowpark/mock/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datetime
import json
import math
import string
from decimal import Decimal
from functools import partial
from numbers import Real
Expand Down Expand Up @@ -843,13 +844,6 @@ def mock_row_number(window: TableEmulator, row_idx: int):
return ColumnEmulator(data=[row_idx + 1], sf_type=ColumnType(LongType(), False))


@patch("upper")
def mock_upper(expr: ColumnEmulator):
res = expr.apply(lambda x: x.upper())
res.sf_type = ColumnType(StringType(), expr.sf_type.nullable)
return res


@patch("parse_json")
def mock_parse_json(expr: ColumnEmulator):
from snowflake.snowpark.mock import CUSTOM_JSON_DECODER
Expand Down Expand Up @@ -929,3 +923,48 @@ def mock_to_variant(expr: ColumnEmulator):
res = expr.copy()
res.sf_type = ColumnType(VariantType(), expr.sf_type.nullable)
return res


@patch("upper")
def mock_upper(expr: ColumnEmulator):
return expr.str.upper()


@patch("lower")
def mock_lower(expr: ColumnEmulator):
return expr.str.lower()


@patch("length")
def mock_length(expr: ColumnEmulator):
result = expr.str.len()
result.sf_type = ColumnType(LongType(), nullable=expr.sf_type.nullable)
return result


# See https://docs.snowflake.com/en/sql-reference/functions/initcap for list of delimiters
DEFAULT_INITCAP_DELIMITERS = set('!?@"^#$&~_,.:;+-*%/|\\[](){}<>' + string.whitespace)


def _initcap(value: Optional[str], delimiters: Optional[str]) -> str:
if value is None:
return None

delims = DEFAULT_INITCAP_DELIMITERS if delimiters is None else set(delimiters)

result = ""
cap = True
for char in value:
if cap:
result += char.upper()
else:
result += char.lower()
cap = char in delims
return result


@patch("initcap")
def mock_initcap(values: ColumnEmulator, delimiters: ColumnEmulator):
result = values.combine(delimiters, _initcap)
result.sf_type = values.sf_type
return result
75 changes: 65 additions & 10 deletions tests/integ/scala/test_function_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
from datetime import date, datetime, time
from decimal import Decimal
from functools import partial

import pytest
import pytz
Expand Down Expand Up @@ -3256,17 +3257,71 @@ def test_ascii(session, col_B):
)


@pytest.mark.parametrize("col_A", ["A", col("A")])
def test_initcap_length_lower_upper(session, col_A):
Utils.check_answer(
TestData.string2(session).select(
initcap(col_A), length(col_A), lower(col_A), upper(col_A)
@pytest.mark.localtest
@pytest.mark.parametrize(
"func,expected",
[
(
initcap,
[
Row(
"Foo-Bar;Baz",
"Qwer,Dvor>Azer",
"Lower",
"Upper",
"Chief Variable Officer",
"Lorem Ipsum Dolor Sit Amet",
)
],
),
[
Row("Asdfg", 5, "asdfg", "ASDFG"),
Row("Qqq", 3, "qqq", "QQQ"),
Row("Qw", 2, "qw", "QW"),
],
(
partial(initcap, delimiters=lit("-")),
[
Row(
"Foo-Bar;baz",
"Qwer,dvor>azer",
"Lower",
"Upper",
"Chief variable officer",
"Lorem ipsum dolor sit amet",
)
],
),
(length, [Row(11, 14, 5, 5, 22, 26)]),
(
lower,
[
Row(
"foo-bar;baz",
"qwer,dvor>azer",
"lower",
"upper",
"chief variable officer",
"lorem ipsum dolor sit amet",
)
],
),
(
upper,
[
Row(
"FOO-BAR;BAZ",
"QWER,DVOR>AZER",
"LOWER",
"UPPER",
"CHIEF VARIABLE OFFICER",
"LOREM IPSUM DOLOR SIT AMET",
)
],
),
],
)
@pytest.mark.parametrize("use_col", [True, False])
def test_initcap_length_lower_upper(func, expected, use_col, session):
df = TestData.string8(session)
Utils.check_answer(
df.select(*[func(col(c) if use_col else c) for c in df.columns]),
expected,
sort=False,
)

Expand Down
16 changes: 16 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,22 @@ def string6(cls, session: "Session") -> DataFrame:
def string7(cls, session: "Session") -> DataFrame:
return session.create_dataframe([["str", 1], [None, 2]], schema=["a", "b"])

@classmethod
def string8(cls, session: "Session") -> DataFrame:
return session.create_dataframe(
[
(
"foo-bar;baz",
"qwer,dvor>azer",
"lower",
"UPPER",
"Chief Variable Officer",
"Lorem ipsum dolor sit amet",
)
],
schema=["delim1", "delim2", "lower", "upper", "title", "sentence"],
)

@classmethod
def array1(cls, session: "Session") -> DataFrame:
return session.sql(
Expand Down

0 comments on commit d0612c9

Please sign in to comment.