From 0efdc59793b3374d359f8a769fd35c2aaa9a1e3e Mon Sep 17 00:00:00 2001 From: sfc-gh-fgonzalezmendez Date: Wed, 17 Jan 2024 08:29:16 -0600 Subject: [PATCH] Update create_map implementation --- src/snowflake/snowpark/functions.py | 42 ++++++++++++----------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index 28c0675accd..29e681421cd 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -792,24 +792,16 @@ def create_map(*cols: Union[ColumnOrName, Union[List[ColumnOrName], Tuple[Column ----------------------- """ - def pairwise(iterable): - while len(iterable): - a = iterable.pop(0) - b = iterable.pop(0) if len(iterable) else None - yield a, b - - col_names = _flatten_col_list(cols) # flatten any iterables to process them in pairs - has_odd_columns = len(col_names) & 1 + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + + has_odd_columns = len(cols) & 1 if has_odd_columns: raise ValueError( - f"The 'create_map' function requires an even number of parameters but the actual number is {len(col_names)}" + f"The 'create_map' function requires an even number of parameters but the actual number is {len(cols)}" ) - col_list = [] - for name, value in pairwise(col_names): - col_list.append(_to_col_if_str(name, "create_map")) - col_list.append(value) - return object_construct_keep_null(*col_list) + return object_construct_keep_null(*cols) def kurtosis(e: ColumnOrName) -> Column: @@ -2375,8 +2367,18 @@ def struct(*cols: ColumnOrName) -> Column: --------------------- """ + + def flatten_col_list(obj): + if isinstance(obj, str) or isinstance(obj, Column): + return [obj] + elif hasattr(obj, "__iter__"): + acc = [] + for innerObj in obj: + acc = acc + flatten_col_list(innerObj) + return acc + new_cols = [] - for c in _flatten_col_list(cols): + for c in flatten_col_list(cols): # first insert field_name if isinstance(c, str): new_cols.append(lit(c)) @@ -2394,16 +2396,6 @@ def struct(*cols: ColumnOrName) -> Column: return object_construct_keep_null(*new_cols) -def _flatten_col_list(obj): - if isinstance(obj, str) or isinstance(obj, Column): - return [obj] - elif hasattr(obj, "__iter__"): - acc = [] - for innerObj in obj: - acc = acc + _flatten_col_list(innerObj) - return acc - - def log( base: Union[ColumnOrName, int, float], x: Union[ColumnOrName, int, float] ) -> Column: