diff --git a/CHANGELOG.md b/CHANGELOG.md index c2ecd419059..64d248304bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ #### New Features - Added support for StringType, TimestampType and VariantType data conversion in the mocked function `to_date`. +- Added support for the following APIs: + - snowflake.snowpark.functions + - concat + - concat_ws ## 1.15.0 (2024-04-24) diff --git a/src/snowflake/snowpark/mock/_functions.py b/src/snowflake/snowpark/mock/_functions.py index c3cd0013cc2..2c2eb98e433 100644 --- a/src/snowflake/snowpark/mock/_functions.py +++ b/src/snowflake/snowpark/mock/_functions.py @@ -17,6 +17,7 @@ import pytz import snowflake.snowpark +from snowflake.connector.options import pandas from snowflake.snowpark.exceptions import SnowparkSQLException from snowflake.snowpark.mock._snowflake_data_type import ( ColumnEmulator, @@ -1028,8 +1029,6 @@ def mock_iff(condition: ColumnEmulator, expr1: ColumnEmulator, expr2: ColumnEmul @patch("coalesce") def mock_coalesce(*exprs): - import pandas - if len(exprs) < 2: raise SnowparkSQLException( f"not enough arguments for function [COALESCE], got {len(exprs)}, expected at least two" @@ -1175,8 +1174,6 @@ def mock_to_variant(expr: ColumnEmulator): def _object_construct(exprs, drop_nulls): - import pandas - expr_count = len(exprs) if expr_count % 2 != 0: raise TypeError( @@ -1222,10 +1219,8 @@ def add_years(date, duration): def add_months(scalar, date, duration): - import pandas as pd - res = ( - pd.to_datetime(date) + pd.DateOffset(months=scalar * duration) + pandas.to_datetime(date) + pandas.DateOffset(months=scalar * duration) ).to_pydatetime() if not isinstance(date, datetime.datetime): @@ -1286,8 +1281,6 @@ def mock_date_part(part: str, datetime_expr: ColumnEmulator): SNOW-1183874: Add support for relevant session parameters. https://docs.snowflake.com/en/sql-reference/functions/date_part#usage-notes """ - import pandas - unaliased = unalias_datetime_part(part) datatype = datetime_expr.sf_type.datatype @@ -1365,8 +1358,6 @@ def mock_date_trunc(part: str, datetime_expr: ColumnEmulator) -> ColumnEmulator: SNOW-1183874: Add support for relevant session parameters. https://docs.snowflake.com/en/sql-reference/functions/date_part#usage-notes """ - import pandas - # Map snowflake time unit to pandas rounding alias # Not all units have an alias so handle those with a special case SUPPORTED_UNITS = { @@ -1581,3 +1572,31 @@ def mock_current_database(): return ColumnEmulator( data=session.get_current_database(), sf_type=ColumnType(StringType(), False) ) + + +@patch("concat") +def mock_concat(*columns: ColumnEmulator) -> ColumnEmulator: + if len(columns) < 1: + raise ValueError("concat expects one or more column(s) to be passed in.") + pdf = pandas.concat(columns, axis=1).reset_index(drop=True) + result = pdf.T.apply( + lambda c: None if c.isnull().values.any() else c.astype(str).str.cat() + ) + result.sf_type = ColumnType(StringType(), result.hasnans) + return result + + +@patch("concat_ws") +def mock_concat_ws(*columns: ColumnEmulator) -> ColumnEmulator: + if len(columns) < 2: + raise ValueError( + "concat_ws expects a seperator column and one or more value column(s) to be passed in." + ) + pdf = pandas.concat(columns, axis=1).reset_index(drop=True) + result = pdf.T.apply( + lambda c: None + if c.isnull().values.any() + else c[1:].astype(str).str.cat(sep=c[0]) + ) + result.sf_type = ColumnType(StringType(), result.hasnans) + return result diff --git a/tests/integ/test_function.py b/tests/integ/test_function.py index c31454d4b52..4c468ebc913 100644 --- a/tests/integ/test_function.py +++ b/tests/integ/test_function.py @@ -273,6 +273,7 @@ def test_regexp_extract(session): assert res[0]["RES"] == "30" and res[1]["RES"] == "50" +@pytest.mark.localtest @pytest.mark.parametrize( "col_a, col_b, col_c", [("a", "b", "c"), (col("a"), col("b"), col("c"))] ) @@ -282,15 +283,32 @@ def test_concat(session, col_a, col_b, col_c): assert res[0][0] == "123" +@pytest.mark.localtest @pytest.mark.parametrize( "col_a, col_b, col_c", [("a", "b", "c"), (col("a"), col("b"), col("c"))] ) def test_concat_ws(session, col_a, col_b, col_c): df = session.create_dataframe([["1", "2", "3"]], schema=["a", "b", "c"]) - res = df.select(concat_ws(lit(","), col("a"), col("b"), col("c"))).collect() + res = df.select(concat_ws(lit(","), col_a, col_b, col_c)).collect() assert res[0][0] == "1,2,3" +@pytest.mark.localtest +def test_concat_edge_cases(session): + df = session.create_dataframe( + [[None, 1, 2, 3], [4, None, 6, 7], [8, 9, None, 11], [12, 13, 14, None]] + ).to_df(["a", "b", "c", "d"]) + + single = df.select(concat("a")).collect() + single_ws = df.select(concat_ws(lit(","), "a")).collect() + assert single == single_ws == [Row(None), Row("4"), Row("8"), Row("12")] + + nulls = df.select(concat("a", "b", "c")).collect() + nulls_ws = df.select(concat_ws(lit(","), "a", "b", "c")).collect() + assert nulls == [Row(None), Row(None), Row(None), Row("121314")] + assert nulls_ws == [Row(None), Row(None), Row(None), Row("12,13,14")] + + @pytest.mark.localtest @pytest.mark.parametrize( "col_a",