Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SCT-7891: Add support for create_map function in snowflake.snowpark.functions #1204

Merged
merged 4 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
- `SessionBuilder.getOrCreate` will now attempt to replace the singleton it returns when token expiration has been detected.
- Added support for new function(s) in `snowflake.snowpark.functions`:
- `array_except`
- `create_map`
- `sign`/`signum`

### Bug Fixes
Expand Down
1 change: 1 addition & 0 deletions docs/source/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Functions
count_distinct
covar_pop
covar_samp
create_map
cume_dist
current_available_roles
current_database
Expand Down
54 changes: 54 additions & 0 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,60 @@ def covar_samp(column1: ColumnOrName, column2: ColumnOrName) -> Column:
return builtin("covar_samp")(col1, col2)


def create_map(*cols: Union[ColumnOrName, Iterable[ColumnOrName]]) -> Column:
"""Transforms multiple column pairs into a single map :class:`~snowflake.snowpark.Column` where each pair of
columns is treated as a key-value pair in the resulting map.

Args:
*cols: A variable number of column names or :class:`~snowflake.snowpark.Column` objects that can also be
expressed as a list of columns.
The function expects an even number of arguments, where each pair of arguments represents a key-value
pair for the map.

Returns:
A :class:`~snowflake.snowpark.Column` where each row contains a map created from the provided column pairs.

Example:
>>> from snowflake.snowpark.functions import create_map
>>> df = session.create_dataframe([("Paris", "France"), ("Tokyo", "Japan")], ("city", "country"))
>>> df.select(create_map("city", "country").alias("map")).show()
-----------------------
|"MAP" |
-----------------------
|{ |
| "Paris": "France" |
|} |
|{ |
| "Tokyo": "Japan" |
|} |
-----------------------
<BLANKLINE>

>>> df.select(create_map([df.city, df.country]).alias("map")).show()
-----------------------
|"MAP" |
-----------------------
|{ |
| "Paris": "France" |
|} |
|{ |
| "Tokyo": "Japan" |
|} |
-----------------------
<BLANKLINE>
"""
if len(cols) == 1 and isinstance(cols[0], (list, set)):
cols = cols[0]

has_odd_columns = len(cols) & 1
if has_odd_columns:
raise ValueError(
f"The 'create_map' function requires an even number of parameters but the actual number is {len(cols)}"
)

return object_construct_keep_null(*cols)


def kurtosis(e: ColumnOrName) -> Column:
"""
Returns the population excess kurtosis of non-NULL records. If all records
Expand Down
116 changes: 116 additions & 0 deletions tests/integ/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import decimal
import json
import re
from itertools import chain

import pytest

Expand Down Expand Up @@ -69,6 +70,7 @@
concat_ws,
contains,
count_distinct,
create_map,
current_date,
current_time,
current_timestamp,
Expand Down Expand Up @@ -1643,3 +1645,117 @@ def _result_str2lst(result):
assert (
result_list == expected_result
), f"Unexpected result: {result_list}, expected: {expected_result}"


def test_create_map(session):
df = session.create_dataframe(
[("Sales", 6500, "USA"), ("Legal", 3000, None)],
("department", "salary", "location")
)

# Case 1: create_map with column names
Utils.check_answer(
df.select(create_map("department", "salary").alias("map")),
[
Row(MAP='{\n "Sales": 6500\n}'),
Row(MAP='{\n "Legal": 3000\n}')
],
sort=False,
)

# Case 2: create_map with column objects
Utils.check_answer(
df.select(create_map(df.department, df.salary).alias("map")),
[
Row(MAP='{\n "Sales": 6500\n}'),
Row(MAP='{\n "Legal": 3000\n}')
],
sort=False,
)

# Case 3: create_map with a list of column names
Utils.check_answer(
df.select(create_map(["department", "salary"]).alias("map")),
[
Row(MAP='{\n "Sales": 6500\n}'),
Row(MAP='{\n "Legal": 3000\n}')
],
sort=False,
)

# Case 4: create_map with a list of column objects
Utils.check_answer(
df.select(create_map([df.department, df.salary]).alias("map")),
[
Row(MAP='{\n "Sales": 6500\n}'),
Row(MAP='{\n "Legal": 3000\n}')
],
sort=False,
)

# Case 5: create_map with constant values
Utils.check_answer(
df.select(create_map(lit("department"), col("department"), lit("salary"), col("salary")).alias("map")),
[
Row(MAP='{\n "department": "Sales",\n "salary": 6500\n}'),
Row(MAP='{\n "department": "Legal",\n "salary": 3000\n}')
],
sort=False,
)

# Case 6: create_map with a nested map
Utils.check_answer(
df.select(create_map(col("department"), create_map(lit("salary"), col("salary"))).alias("map")),
[
Row(MAP='{\n "Sales": {\n "salary": 6500\n }\n}'),
Row(MAP='{\n "Legal": {\n "salary": 3000\n }\n}')
],
sort=False,
)

# Case 7: create_map with None values
Utils.check_answer(
df.select(create_map("department", "location").alias("map")),
[
Row(MAP='{\n "Sales": "USA"\n}'),
Row(MAP='{\n "Legal": null\n}')
],
sort=False,
)

# Case 8: create_map dynamic creation
Utils.check_answer(
df.select(create_map(list(chain(*((lit(name), col(name)) for name in df.columns)))).alias("map")),
[
Row(MAP='{\n "DEPARTMENT": "Sales",\n "LOCATION": "USA",\n "SALARY": 6500\n}'),
Row(MAP='{\n "DEPARTMENT": "Legal",\n "LOCATION": null,\n "SALARY": 3000\n}')
],
sort=False,
)

# Case 9: create_map without columns
Utils.check_answer(
df.select(create_map().alias("map")),
[
Row(MAP='{}'),
Row(MAP='{}')
],
sort=False,
)


def test_create_map_negative(session):
df = session.create_dataframe(
[("Sales", 6500, "USA"), ("Legal", 3000, None)],
("department", "salary", "location")
)

# Case 1: create_map with odd number of columns
with pytest.raises(ValueError) as ex_info:
df.select(create_map("department").alias("map"))
assert "The 'create_map' function requires an even number of parameters but the actual number is 1" in str(ex_info)

# Case 2: create_map with odd number of columns (list)
with pytest.raises(ValueError) as ex_info:
df.select(create_map([df.department, df.salary, df.location]).alias("map"))
assert "The 'create_map' function requires an even number of parameters but the actual number is 3" in str(ex_info)
Loading