Skip to content

Commit

Permalink
Update create_map implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-fgonzalezmendez committed Jan 19, 2024
1 parent f7293c9 commit 0efdc59
Showing 1 changed file with 17 additions and 25 deletions.
42 changes: 17 additions & 25 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,24 +792,16 @@ def create_map(*cols: Union[ColumnOrName, Union[List[ColumnOrName], Tuple[Column
-----------------------
<BLANKLINE>
"""
def pairwise(iterable):
while len(iterable):
a = iterable.pop(0)
b = iterable.pop(0) if len(iterable) else None
yield a, b

col_names = _flatten_col_list(cols) # flatten any iterables to process them in pairs
has_odd_columns = len(col_names) & 1
if len(cols) == 1 and isinstance(cols[0], (list, set)):
cols = cols[0]

has_odd_columns = len(cols) & 1
if has_odd_columns:
raise ValueError(
f"The 'create_map' function requires an even number of parameters but the actual number is {len(col_names)}"
f"The 'create_map' function requires an even number of parameters but the actual number is {len(cols)}"
)

col_list = []
for name, value in pairwise(col_names):
col_list.append(_to_col_if_str(name, "create_map"))
col_list.append(value)
return object_construct_keep_null(*col_list)
return object_construct_keep_null(*cols)


def kurtosis(e: ColumnOrName) -> Column:
Expand Down Expand Up @@ -2375,8 +2367,18 @@ def struct(*cols: ColumnOrName) -> Column:
---------------------
<BLANKLINE>
"""

def flatten_col_list(obj):
if isinstance(obj, str) or isinstance(obj, Column):
return [obj]
elif hasattr(obj, "__iter__"):
acc = []
for innerObj in obj:
acc = acc + flatten_col_list(innerObj)
return acc

new_cols = []
for c in _flatten_col_list(cols):
for c in flatten_col_list(cols):
# first insert field_name
if isinstance(c, str):
new_cols.append(lit(c))
Expand All @@ -2394,16 +2396,6 @@ def struct(*cols: ColumnOrName) -> Column:
return object_construct_keep_null(*new_cols)


def _flatten_col_list(obj):
if isinstance(obj, str) or isinstance(obj, Column):
return [obj]
elif hasattr(obj, "__iter__"):
acc = []
for innerObj in obj:
acc = acc + _flatten_col_list(innerObj)
return acc


def log(
base: Union[ColumnOrName, int, float], x: Union[ColumnOrName, int, float]
) -> Column:
Expand Down

0 comments on commit 0efdc59

Please sign in to comment.