Skip to content

Commit

Permalink
SNOW-1063716: [Local Testing] Add support for get (#1431)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jrose authored Apr 30, 2024
1 parent 8875ddc commit c470057
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
- Added support for StringType, TimestampType and VariantType data conversion in the mocked function `to_date`.
- Added support for the following APIs:
- snowflake.snowpark.functions
- get
- concat
- concat_ws

#### Bug Fixes

- Fixed a bug that caused NaT and NaN values to not be recognized.


## 1.15.0 (2024-04-24)

### New Features
Expand Down
27 changes: 27 additions & 0 deletions src/snowflake/snowpark/mock/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,33 @@ def mock_current_database():
)


@patch("get")
def mock_get(
column_expression: ColumnEmulator, value_expression: ColumnEmulator
) -> ColumnEmulator:
def get(obj, key):
try:
if isinstance(obj, list):
return obj[key]
elif isinstance(obj, dict):
return obj.get(key, None)
else:
return None
except KeyError:
return None

# pandas.Series.combine does not work here because it will not allow Nones in int columns
result = []
for exp, k in zip(column_expression, value_expression):
result.append(get(exp, k))

return ColumnEmulator(
result,
sf_type=ColumnType(column_expression.sf_type.datatype, True),
dtype=object,
)


@patch("concat")
def mock_concat(*columns: ColumnEmulator) -> ColumnEmulator:
if len(columns) < 1:
Expand Down
9 changes: 5 additions & 4 deletions tests/integ/scala/test_function_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3779,28 +3779,29 @@ def test_get_path(session, v, k):
)


@pytest.mark.localtest
def test_get(session):
Utils.check_answer(
[Row("21"), Row(None)],
TestData.object2(session).select(get(col("obj"), col("k"))),
[Row("21"), Row(None)],
sort=False,
)
Utils.check_answer(
[Row(None), Row(None)],
TestData.object2(session).select(get(col("obj"), lit("AGE"))),
[Row(None), Row(None)],
sort=False,
)

# Same as above, but pass str instead of Column
Utils.check_answer(
[Row("21"), Row(None)],
TestData.object2(session).select(get("obj", "k")),
[Row("21"), Row(None)],
sort=False,
)

Utils.check_answer(
[Row(None), Row(None)],
TestData.object2(session).select(get("obj", lit("AGE"))),
[Row(None), Row(None)],
sort=False,
)

Expand Down

0 comments on commit c470057

Please sign in to comment.