Skip to content

Commit

Permalink
SNOW-1063723: [Local Testing] Add support for concat and concat_ws (#…
Browse files Browse the repository at this point in the history
…1413)

This PR adds support for the concat and concat_ws functions. I also
imported pandas from the connector instead of in each individual mock
function. This pandas import only works if pandas is installed and will
throw an error if used without pandas installed.
  • Loading branch information
sfc-gh-jrose authored Apr 26, 2024
1 parent 40157c5 commit cdeeb00
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 12 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
#### New Features

- Added support for StringType, TimestampType and VariantType data conversion in the mocked function `to_date`.
- Added support for the following APIs:
  - snowflake.snowpark.functions
    - concat
    - concat_ws


## 1.15.0 (2024-04-24)
Expand Down
41 changes: 30 additions & 11 deletions src/snowflake/snowpark/mock/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pytz

import snowflake.snowpark
from snowflake.connector.options import pandas
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.mock._snowflake_data_type import (
ColumnEmulator,
Expand Down Expand Up @@ -1028,8 +1029,6 @@ def mock_iff(condition: ColumnEmulator, expr1: ColumnEmulator, expr2: ColumnEmul

@patch("coalesce")
def mock_coalesce(*exprs):
import pandas

if len(exprs) < 2:
raise SnowparkSQLException(
f"not enough arguments for function [COALESCE], got {len(exprs)}, expected at least two"
Expand Down Expand Up @@ -1175,8 +1174,6 @@ def mock_to_variant(expr: ColumnEmulator):


def _object_construct(exprs, drop_nulls):
import pandas

expr_count = len(exprs)
if expr_count % 2 != 0:
raise TypeError(
Expand Down Expand Up @@ -1222,10 +1219,8 @@ def add_years(date, duration):


def add_months(scalar, date, duration):
import pandas as pd

res = (
pd.to_datetime(date) + pd.DateOffset(months=scalar * duration)
pandas.to_datetime(date) + pandas.DateOffset(months=scalar * duration)
).to_pydatetime()

if not isinstance(date, datetime.datetime):
Expand Down Expand Up @@ -1286,8 +1281,6 @@ def mock_date_part(part: str, datetime_expr: ColumnEmulator):
SNOW-1183874: Add support for relevant session parameters.
https://docs.snowflake.com/en/sql-reference/functions/date_part#usage-notes
"""
import pandas

unaliased = unalias_datetime_part(part)
datatype = datetime_expr.sf_type.datatype

Expand Down Expand Up @@ -1365,8 +1358,6 @@ def mock_date_trunc(part: str, datetime_expr: ColumnEmulator) -> ColumnEmulator:
SNOW-1183874: Add support for relevant session parameters.
https://docs.snowflake.com/en/sql-reference/functions/date_part#usage-notes
"""
import pandas

# Map snowflake time unit to pandas rounding alias
# Not all units have an alias so handle those with a special case
SUPPORTED_UNITS = {
Expand Down Expand Up @@ -1581,3 +1572,31 @@ def mock_current_database():
return ColumnEmulator(
data=session.get_current_database(), sf_type=ColumnType(StringType(), False)
)


@patch("concat")
def mock_concat(*columns: ColumnEmulator) -> ColumnEmulator:
    """Local-testing implementation of ``concat``.

    Joins the stringified values of ``columns`` row by row. A row that
    contains any null value produces ``None`` instead of a string.

    Args:
        *columns: One or more columns whose values are concatenated.

    Returns:
        The per-row results, with ``sf_type`` set to a ``StringType`` column
        type that is nullable only when the result contains nulls.

    Raises:
        ValueError: If no columns are passed in.
    """
    if not columns:
        raise ValueError("concat expects one or more column(s) to be passed in.")

    def _join_row(row):
        # A single null anywhere in the row makes the whole result null.
        if row.isnull().values.any():
            return None
        return row.astype(str).str.cat()

    # Place the inputs side by side on a fresh 0..n-1 index. Transposing turns
    # every input row into a column, which is the unit ``apply`` iterates over.
    side_by_side = pandas.concat(columns, axis=1).reset_index(drop=True)
    joined = side_by_side.T.apply(_join_row)
    joined.sf_type = ColumnType(StringType(), joined.hasnans)
    return joined


@patch("concat_ws")
def mock_concat_ws(*columns: ColumnEmulator) -> ColumnEmulator:
    """Local-testing implementation of ``concat_ws``.

    The first column supplies the separator; the remaining columns supply the
    values. For each row the stringified values are joined with that row's
    separator. A row that contains any null (separator included) produces
    ``None`` instead of a string.

    Args:
        *columns: The separator column followed by one or more value columns.

    Returns:
        The per-row results, with ``sf_type`` set to a ``StringType`` column
        type that is nullable only when the result contains nulls.

    Raises:
        ValueError: If fewer than two columns are passed in.
    """
    if len(columns) < 2:
        raise ValueError(
            "concat_ws expects a seperator column and one or more value column(s) to be passed in."
        )

    def _join_row(row):
        # A single null anywhere in the row makes the whole result null.
        if row.isnull().values.any():
            return None
        # ``row`` is indexed by the original column labels, so use explicit
        # positional access. Plain ``row[0]`` relies on pandas' deprecated
        # integer-key fallback and turns into a label lookup when the labels
        # happen to be integers.
        separator = row.iloc[0]
        return row.iloc[1:].astype(str).str.cat(sep=separator)

    # Place the inputs side by side on a fresh 0..n-1 index. Transposing turns
    # every input row into a column, which is the unit ``apply`` iterates over.
    pdf = pandas.concat(columns, axis=1).reset_index(drop=True)
    result = pdf.T.apply(_join_row)
    result.sf_type = ColumnType(StringType(), result.hasnans)
    return result
20 changes: 19 additions & 1 deletion tests/integ/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def test_regexp_extract(session):
assert res[0]["RES"] == "30" and res[1]["RES"] == "50"


@pytest.mark.localtest
@pytest.mark.parametrize(
"col_a, col_b, col_c", [("a", "b", "c"), (col("a"), col("b"), col("c"))]
)
Expand All @@ -282,15 +283,32 @@ def test_concat(session, col_a, col_b, col_c):
assert res[0][0] == "123"


@pytest.mark.localtest
@pytest.mark.parametrize(
    "col_a, col_b, col_c", [("a", "b", "c"), (col("a"), col("b"), col("c"))]
)
def test_concat_ws(session, col_a, col_b, col_c):
    """Check ``concat_ws`` with columns given as names and as Column objects."""
    df = session.create_dataframe([["1", "2", "3"]], schema=["a", "b", "c"])
    # Build the query only from the parametrized arguments so that both the
    # string-name and the Column-object variants are actually exercised.
    res = df.select(concat_ws(lit(","), col_a, col_b, col_c)).collect()
    assert res[0][0] == "1,2,3"


@pytest.mark.localtest
def test_concat_edge_cases(session):
    """Check ``concat`` / ``concat_ws`` with a single value column and with nulls."""
    # Each row has its null in a different column; only the last row is
    # null-free across columns a, b and c.
    rows = [[None, 1, 2, 3], [4, None, 6, 7], [8, 9, None, 11], [12, 13, 14, None]]
    df = session.create_dataframe(rows).to_df(["a", "b", "c", "d"])

    # With one value column there is nothing to join, so both functions agree.
    one_col = df.select(concat("a")).collect()
    one_col_ws = df.select(concat_ws(lit(","), "a")).collect()
    assert one_col == one_col_ws == [Row(None), Row("4"), Row("8"), Row("12")]

    # Any null among the inputs nulls out the whole row.
    with_nulls = df.select(concat("a", "b", "c")).collect()
    with_nulls_ws = df.select(concat_ws(lit(","), "a", "b", "c")).collect()
    assert with_nulls == [Row(None), Row(None), Row(None), Row("121314")]
    assert with_nulls_ws == [Row(None), Row(None), Row(None), Row("12,13,14")]


@pytest.mark.localtest
@pytest.mark.parametrize(
"col_a",
Expand Down

0 comments on commit cdeeb00

Please sign in to comment.