diff --git a/CHANGELOG.md b/CHANGELOG.md index 22899bacf78..925e5435e35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Release History +## 1.13.0 (TBD) + +### New Features + +- Added support for an optional `date_part` argument in the function `last_day`. +- `SessionBuilder.app_name` will set the query_tag after the session is created. + +### Bug Fixes + +- Fixed a bug in `DataFrame.to_local_iterator` where the iterator could yield wrong results if another query was executed before the iterator finished, due to a wrong isolation level. For details, please see #945. +- Fixed a bug that truncated table names in error messages while running a plan with local testing enabled. +- Fixed a bug where `Session.range` returned an empty result when the range was large. + ## 1.12.1 (2024-02-08) ### Improvements @@ -28,7 +41,9 @@ - `sign`/`signum` - Added the following functions to `DataFrame.analytics`: - Added the `moving_agg` function in `DataFrame.analytics` to enable moving aggregations like sums and averages with multiple window sizes. - - Added the `cummulative_agg` function in `DataFrame.analytics` to enable moving aggregations like sums and averages with multiple window sizes. + - Added the `cumulative_agg` function in `DataFrame.analytics` to enable cumulative aggregations like sums and averages on multiple columns. + - Added the `compute_lag` and `compute_lead` functions in `DataFrame.analytics` for enabling lead and lag calculations on multiple columns. + - Added the `time_series_agg` function in `DataFrame.analytics` to enable time series aggregations like sums and averages with multiple time windows. ### Bug Fixes diff --git a/docs/source/_templates/autosummary/accessor_method.rst b/docs/source/_templates/autosummary/accessor_method.rst new file mode 100644 index 00000000000..96dec5e8a99 --- /dev/null +++ b/docs/source/_templates/autosummary/accessor_method.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module + "." + objname.split(".")[0] }} + +.. automethod:: {{ ".".join(objname.split(".")[1:]) }} \ No newline at end of file diff --git a/docs/source/session.rst b/docs/source/session.rst index 468aa7a4a34..9ee2521dcf7 100644 --- a/docs/source/session.rst +++ b/docs/source/session.rst @@ -11,6 +11,19 @@ Snowpark Session Session + +.. rubric:: SessionBuilder + +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Session.SessionBuilder.app_name + Session.SessionBuilder.config + Session.SessionBuilder.configs + Session.SessionBuilder.create + Session.SessionBuilder.getOrCreate + .. rubric:: Methods ..
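[Reviewer note] A minimal usage sketch of the 1.13.0 entries above, assuming a reachable Snowflake account; the connection parameters are placeholders, not real values:

    from snowflake.snowpark import Session
    from snowflake.snowpark.functions import last_day, to_date

    # Placeholder credentials -- replace with real connection parameters.
    session = (
        Session.builder
        .app_name("my_app")  # appended to the session's query_tag as APPNAME=my_app
        .configs({"account": "...", "user": "...", "password": "..."})
        .create()
    )

    # New optional date_part argument of last_day.
    df = session.create_dataframe(["2020-05-02"], schema=["a"]).select(to_date("a").alias("a"))
    df.select(last_day("a")).show()          # default date part: MONTH
    df.select(last_day("a", "YEAR")).show()  # last day of the year

    # Large ranges no longer return an empty result.
    assert session.range(0, 10**15, 10**12).count() == 1000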
diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 2b7a11f4d02..8ef0e7e2103 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -448,7 +448,7 @@ def sort_statement(order: List[str], child: str) -> str: def range_statement(start: int, end: int, step: int, column_name: str) -> str: range = end - start - if range * step < 0: + if (range > 0 > step) or (range < 0 < step): count = 0 else: count = math.ceil(range / step) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 37df1c8cf7e..33d19e2a34f 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -434,8 +434,16 @@ def _to_data_or_iter( results_cursor: SnowflakeCursor, to_pandas: bool = False, to_iter: bool = False, - num_statements: Optional[int] = None, ) -> Dict[str, Any]: + if ( + to_iter and not to_pandas + ): # Fix for SNOW-869536, to_pandas doesn't have this issue, SnowflakeCursor.fetch_pandas_batches already handles the isolation. + new_cursor = results_cursor.connection.cursor() + new_cursor.execute( + f"SELECT * FROM TABLE(RESULT_SCAN('{results_cursor.sfqid}'))" + ) + results_cursor = new_cursor + if to_pandas: try: data_or_iter = ( diff --git a/src/snowflake/snowpark/dataframe_analytics_functions.py b/src/snowflake/snowpark/dataframe_analytics_functions.py index 096314f9ece..ec47a985fc7 100644 --- a/src/snowflake/snowpark/dataframe_analytics_functions.py +++ b/src/snowflake/snowpark/dataframe_analytics_functions.py @@ -2,12 +2,30 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Tuple, Union import snowflake.snowpark -from snowflake.snowpark.functions import expr +from snowflake.snowpark._internal.utils import experimental +from snowflake.snowpark.column import Column, _to_col_if_str +from snowflake.snowpark.functions import ( + add_months, + col, + dateadd, + expr, + from_unixtime, + lag, + lead, + lit, + months_between, + to_timestamp, + unix_timestamp, + year, +) from snowflake.snowpark.window import Window +# "s" (seconds), "m" (minutes), "h" (hours), "d" (days), "w" (weeks), "mm" (months), "y" (years) +SUPPORTED_TIME_UNITS = ["s", "m", "h", "d", "w", "mm", "y"] + class DataFrameAnalyticsFunctions: """Provides data analytics functions for DataFrames. @@ -85,6 +103,167 @@ def _validate_formatter_argument(self, fromatter): if not callable(fromatter): raise TypeError("formatter must be a callable function") + def _compute_window_function( + self, + cols: List[Union[str, Column]], + periods: List[int], + order_by: List[str], + group_by: List[str], + col_formatter: Callable[[str, str, int], str], + window_func: Callable[[Column, int], Column], + func_name: str, + ) -> "snowflake.snowpark.dataframe.DataFrame": + """ + Generic function to create window function columns (lag or lead) for the DataFrame. + Args: + func_name: Should be either "LEAD" or "LAG". 
+ """ + self._validate_string_list_argument(order_by, "order_by") + self._validate_string_list_argument(group_by, "group_by") + self._validate_positive_integer_list_argument(periods, func_name.lower() + "s") + self._validate_formatter_argument(col_formatter) + + window_spec = Window.partition_by(group_by).order_by(order_by) + df = self._df + col_names = [] + values = [] + for c in cols: + for period in periods: + column = _to_col_if_str(c, f"transform.compute_{func_name.lower()}") + window_col = window_func(column, period).over(window_spec) + formatted_col_name = col_formatter( + column.get_name().replace('"', ""), func_name, period + ) + col_names.append(formatted_col_name) + values.append(window_col) + + return df.with_columns(col_names, values) + + def _parse_time_string(self, time_str: str) -> Tuple[int, str]: + index = len(time_str) + for i, char in enumerate(time_str): + if not char.isdigit() and char not in ["+", "-"]: + index = i + break + + duration = int(time_str[:index]) + unit = time_str[index:].lower() + + return duration, unit + + def _validate_and_extract_time_unit( + self, time_str: str, argument_name: str, allow_negative: bool = True + ) -> Tuple[int, str]: + argument_requirements = ( + f"The '{argument_name}' argument must adhere to the following criteria: " + "1) It must not be an empty string. " + "2) The last character must be a supported time unit. " + f"Supported units are '{', '.join(SUPPORTED_TIME_UNITS)}'. " + "3) The preceding characters must represent an integer. " + "4) The integer must not be negative if allow_negative is False." + ) + if not time_str: + raise ValueError( + f"{argument_name} must not be empty. {argument_requirements}" + ) + + duration, unit = self._parse_time_string(time_str) + + if not allow_negative and duration < 0: + raise ValueError( + f"{argument_name} must not be negative. {argument_requirements}" + ) + + if unit not in SUPPORTED_TIME_UNITS: + raise ValueError( + f"Unsupported unit '{unit}'. Supported units are '{SUPPORTED_TIME_UNITS}. {argument_requirements}" + ) + + return duration, unit + + def _get_sliding_interval_start( + self, time_col: Column, unit: str, duration: int + ) -> Column: + unit_seconds = { + "s": 1, # seconds + "m": 60, # minutes + "h": 3600, # hours + "d": 86400, # days + "w": 604800, # weeks + } + + if unit == "mm": + base_date = lit("1970-01-01").cast("date") + months_since_base = months_between(time_col, base_date) + current_window_start_month = (months_since_base / duration).cast( + "long" + ) * duration + return to_timestamp(add_months(base_date, current_window_start_month)) + + elif unit == "y": + base_date = lit("1970-01-01").cast("date") + years_since_base = year(time_col) - year(base_date) + current_window_start_year = (years_since_base / duration).cast( + "long" + ) * duration + return to_timestamp(add_months(base_date, current_window_start_year * 12)) + + elif unit in unit_seconds: + # Handle seconds, minutes, hours, days, weeks + interval_seconds = unit_seconds[unit] * duration + return from_unixtime( + (unix_timestamp(time_col) / interval_seconds).cast("long") + * interval_seconds + ) + + else: + raise ValueError( + "Invalid unit. Supported units are 'S', 'M', 'H', 'D', 'W', 'MM', 'Y'." 
+ ) + + def _perform_window_aggregations( + self, + base_df: "snowflake.snowpark.dataframe.DataFrame", + input_df: "snowflake.snowpark.dataframe.DataFrame", + aggs: Dict[str, List[str]], + group_by_cols: List[str], + col_formatter: Callable[[str, str, str], str] = None, + window: str = None, + rename_suffix: str = "", + ) -> "snowflake.snowpark.dataframe.DataFrame": + """ + Perform window-based aggregations on the given DataFrame. + + This function applies specified aggregation functions to columns of an input DataFrame, + grouped by specified columns, and joins the results back to a base DataFrame. + + Parameters: + - base_df: DataFrame to which the aggregated results will be joined. + - input_df: DataFrame on which aggregations are to be performed. + - aggs: A dictionary where keys are column names and values are lists of aggregation functions. + - group_by_cols: List of column names to group by. + - col_formatter: Optional callable to format column names of aggregated results. + - window: Optional window specification for aggregations. + - rename_suffix: Optional suffix to append to column names. + + Returns: + - DataFrame with the aggregated data joined to the base DataFrame. + """ + for column, funcs in aggs.items(): + for func in funcs: + agg_column_name = ( + col_formatter(column, func, window) + if col_formatter + else f"{column}_{func}{rename_suffix}" + ) + agg_expression = expr(f"{func}({column}{rename_suffix})").alias( + agg_column_name + ) + agg_df = input_df.group_by(group_by_cols).agg(agg_expression) + base_df = base_df.join(agg_df, on=group_by_cols, how="left") + + return base_df + def moving_agg( self, aggs: Dict[str, List[str]], @@ -248,3 +427,252 @@ def cumulative_agg( agg_df = agg_df.with_column(formatted_col_name, agg_col) return agg_df + + def compute_lag( + self, + cols: List[Union[str, Column]], + lags: List[int], + order_by: List[str], + group_by: List[str], + col_formatter: Callable[[str, str, int], str] = _default_col_formatter, + ) -> "snowflake.snowpark.dataframe.DataFrame": + """ + Creates lag columns for the specified columns of the DataFrame based on the grouping and ordering criteria. + + Args: + cols: List of column names or Column objects to calculate lag features. + lags: List of positive integers specifying periods to lag by. + order_by: A list of column names that specify the order in which rows are processed. + group_by: A list of column names on which the DataFrame is partitioned for separate window calculations. + col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_LAG_<lag>'. + This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation, + and 'value' (int) for the lag value, and returns a formatted string for the column name. + + Returns: + A Snowflake DataFrame with additional columns corresponding to each specified lag period. + + Example: + >>> sample_data = [ + ... ["2023-01-01", 101, 200], + ... ["2023-01-02", 101, 100], + ... ["2023-01-03", 101, 300], + ... ["2023-01-04", 102, 250], + ... ] + >>> df = session.create_dataframe(sample_data).to_df( + ... "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT" + ... ) + >>> res = df.analytics.compute_lag( + ... cols=["SALESAMOUNT"], + ... lags=[1, 2], + ... order_by=["ORDERDATE"], + ... group_by=["PRODUCTKEY"], + ...
) + >>> res.show() + ------------------------------------------------------------------------------------------ + |"ORDERDATE" |"PRODUCTKEY" |"SALESAMOUNT" |"SALESAMOUNT_LAG_1" |"SALESAMOUNT_LAG_2" | + ------------------------------------------------------------------------------------------ + |2023-01-04 |102 |250 |NULL |NULL | + |2023-01-01 |101 |200 |NULL |NULL | + |2023-01-02 |101 |100 |200 |NULL | + |2023-01-03 |101 |300 |100 |200 | + ------------------------------------------------------------------------------------------ + + """ + return self._compute_window_function( + cols, lags, order_by, group_by, col_formatter, lag, "LAG" + ) + + def compute_lead( + self, + cols: List[Union[str, Column]], + leads: List[int], + order_by: List[str], + group_by: List[str], + col_formatter: Callable[[str, str, int], str] = _default_col_formatter, + ) -> "snowflake.snowpark.dataframe.DataFrame": + """ + Creates lead columns for the specified columns of the DataFrame based on the grouping and ordering criteria. + + Args: + cols: List of column names or Column objects to calculate lead features. + leads: List of positive integers specifying periods to lead by. + order_by: A list of column names that specify the order in which rows are processed. + group_by: A list of column names on which the DataFrame is partitioned for separate window calculations. + col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_LEAD_<lead>'. + This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation, + and 'value' (int) for the lead value, and returns a formatted string for the column name. + + Returns: + A Snowflake DataFrame with additional columns corresponding to each specified lead period. + + Example: + >>> sample_data = [ + ... ["2023-01-01", 101, 200], + ... ["2023-01-02", 101, 100], + ... ["2023-01-03", 101, 300], + ... ["2023-01-04", 102, 250], + ... ] + >>> df = session.create_dataframe(sample_data).to_df( + ... "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT" + ... ) + >>> res = df.analytics.compute_lead( + ... cols=["SALESAMOUNT"], + ... leads=[1, 2], + ... order_by=["ORDERDATE"], + ... group_by=["PRODUCTKEY"] + ... ) + >>> res.show() + -------------------------------------------------------------------------------------------- + |"ORDERDATE" |"PRODUCTKEY" |"SALESAMOUNT" |"SALESAMOUNT_LEAD_1" |"SALESAMOUNT_LEAD_2" | + -------------------------------------------------------------------------------------------- + |2023-01-04 |102 |250 |NULL |NULL | + |2023-01-01 |101 |200 |100 |300 | + |2023-01-02 |101 |100 |300 |NULL | + |2023-01-03 |101 |300 |NULL |NULL | + -------------------------------------------------------------------------------------------- + + """ + return self._compute_window_function( + cols, leads, order_by, group_by, col_formatter, lead, "LEAD" + ) + + @experimental(version="1.12.0") + def time_series_agg( + self, + time_col: str, + aggs: Dict[str, List[str]], + windows: List[str], + group_by: List[str], + sliding_interval: str, + col_formatter: Callable[[str, str, int], str] = _default_col_formatter, + ) -> "snowflake.snowpark.dataframe.DataFrame": + """ + Applies aggregations to the specified columns of the DataFrame over the specified time windows + and grouping criteria. + + Args: + aggs: A dictionary where keys are column names and values are lists of the desired aggregation functions.
+ windows: Time windows for aggregations using strings such as '7D' for 7 days, where the units are + S: Seconds, M: Minutes, H: Hours, D: Days, W: Weeks, MM: Months, Y: Years. For future-oriented analysis, use positive numbers, + and for past-oriented analysis, use negative numbers. + sliding_interval: Interval at which the window slides, specified in the same format as the windows. + group_by: A list of column names on which the DataFrame is partitioned for separate window calculations. + col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_<agg>_<window>'. + This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation, + and 'value' (int) for the window size, and returns a formatted string for the column name. + + Returns: + A Snowflake DataFrame with additional columns corresponding to each specified time window aggregation. + + Raises: + ValueError: If an unsupported value is specified in arguments. + TypeError: If an unsupported type is specified in arguments. + SnowparkSQLException: If an unsupported aggregation is specified. + + Example: + >>> sample_data = [ + ... ["2023-01-01", 101, 200], + ... ["2023-01-02", 101, 100], + ... ["2023-01-03", 101, 300], + ... ["2023-01-04", 102, 250], + ... ] + >>> df = session.create_dataframe(sample_data).to_df( + ... "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT" + ... ) + >>> df = df.with_column("ORDERDATE", to_timestamp(df["ORDERDATE"])) + >>> def custom_formatter(input_col, agg, window): + ... return f"{agg}_{input_col}_{window}" + >>> res = df.analytics.time_series_agg( + ... time_col="ORDERDATE", + ... group_by=["PRODUCTKEY"], + ... aggs={"SALESAMOUNT": ["SUM", "MAX"]}, + ... windows=["1D", "-1D"], + ... sliding_interval="12H", + ... col_formatter=custom_formatter, + ...
) + >>> res.show() + -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + |"PRODUCTKEY" |"SLIDING_POINT" |"SALESAMOUNT" |"ORDERDATE" |"SUM_SALESAMOUNT_1D" |"MAX_SALESAMOUNT_1D" |"SUM_SALESAMOUNT_-1D" |"MAX_SALESAMOUNT_-1D" | + -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + |101 |2023-01-01 00:00:00 |200 |2023-01-01 00:00:00 |300 |200 |200 |200 | + |101 |2023-01-02 00:00:00 |100 |2023-01-02 00:00:00 |400 |300 |300 |200 | + |101 |2023-01-03 00:00:00 |300 |2023-01-03 00:00:00 |300 |300 |400 |300 | + |102 |2023-01-04 00:00:00 |250 |2023-01-04 00:00:00 |250 |250 |250 |250 | + -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + + """ + self._validate_aggs_argument(aggs) + self._validate_string_list_argument(group_by, "group_by") + self._validate_formatter_argument(col_formatter) + + if not windows: + raise ValueError("windows must not be empty") + + if not sliding_interval: + raise ValueError("sliding_interval must not be empty") + + if not time_col or not isinstance(time_col, str): + raise ValueError("time_col must be a string") + + slide_duration, slide_unit = self._validate_and_extract_time_unit( + sliding_interval, "sliding_interval", allow_negative=False + ) + sliding_point_col = "sliding_point" + + agg_df = self._df + agg_df = agg_df.with_column( + sliding_point_col, + self._get_sliding_interval_start(time_col, slide_unit, slide_duration), + ) + + # Perform aggregations at sliding interval granularity. + group_by_cols = group_by + [sliding_point_col] + sliding_windows_df = self._perform_window_aggregations( + agg_df, agg_df, aggs, group_by_cols + ) + + # Perform aggregations at window intervals. + result_df = agg_df + for window in windows: + window_duration, window_unit = self._validate_and_extract_time_unit( + window, "window" + ) + # Perform self-join on DataFrame for aggregation within each group and time window. + left_df = sliding_windows_df.alias("A") + right_df = sliding_windows_df.alias("B") + + for column in right_df.columns: + if column not in group_by: + right_df = right_df.with_column_renamed(column, f"{column}B") + + self_joined_df = left_df.join(right_df, on=group_by, how="leftouter") + + window_frame = dateadd( + window_unit, lit(window_duration), f"{sliding_point_col}" + ) + + if window_duration > 0: # Future window + window_start = col(f"{sliding_point_col}") + window_end = window_frame + else: # Past window + window_start = window_frame + window_end = col(f"{sliding_point_col}") + + # Filter rows to include only those within the specified time window for aggregation. + self_joined_df = self_joined_df.filter( + col(f"{sliding_point_col}B") >= window_start + ).filter(col(f"{sliding_point_col}B") <= window_end) + + # Perform final aggregations.
+ group_by_cols = group_by + [sliding_point_col] + result_df = self._perform_window_aggregations( + result_df, + self_joined_df, + aggs, + group_by_cols, + col_formatter, + window, + rename_suffix="B", + ) + + return result_df diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index b97ca7178cf..459ea886ef6 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -3244,11 +3244,16 @@ def hour(e: ColumnOrName) -> Column: return builtin("hour")(c) -def last_day(e: ColumnOrName) -> Column: +def last_day(expr: ColumnOrName, part: Optional[ColumnOrName] = None) -> Column: """ Returns the last day of the specified date part for a date or timestamp. Commonly used to return the last day of the month for a date or timestamp. + Args: + expr: The date or timestamp column + part: The date part used to compute the last day of the given date or timestamp column; the default is "MONTH". + Valid values are "YEAR", "MONTH", "QUARTER", "WEEK" or any of their supported variations. + Example:: >>> import datetime >>> df = session.create_dataframe([ ... ], schema=["a"]) >>> df.select(last_day("a")).collect() [Row(LAST_DAY("A")=datetime.date(2020, 5, 31)), Row(LAST_DAY("A")=datetime.date(2020, 8, 31))] + >>> df.select(last_day("a", "YEAR")).collect() + [Row(LAST_DAY("A", "YEAR")=datetime.date(2020, 12, 31)), Row(LAST_DAY("A", "YEAR")=datetime.date(2020, 12, 31))] """ - c = _to_col_if_str(e, "last_day") - return builtin("last_day")(c) + expr_col = _to_col_if_str(expr, "last_day") + if part is None: + # Ensure we do not change the column name + return builtin("last_day")(expr_col) + + part_col = _to_col_if_str(part, "last_day") + return builtin("last_day")(expr_col, part_col) def minute(e: ColumnOrName) -> Column: diff --git a/src/snowflake/snowpark/mock/_plan.py b/src/snowflake/snowpark/mock/_plan.py index f49419255b7..c882a26e852 100644 --- a/src/snowflake/snowpark/mock/_plan.py +++ b/src/snowflake/snowpark/mock/_plan.py @@ -553,8 +553,9 @@ def execute_mock_plan( return res_df else: db_schme_table = parse_table_name(entity_name) + table = ".".join([part.strip("\"'") for part in db_schme_table[:3]]) raise SnowparkSQLException( - f"Object '{db_schme_table[0][1:-1]}.{db_schme_table[1][1:-1]}.{db_schme_table[2][1:-1]}' does not exist or not authorized." + f"Object '{table}' does not exist or not authorized." ) if isinstance(source_plan, Aggregate): child_rf = execute_mock_plan(source_plan.child) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index db90f0b0b6a..5d4e9a90051 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -300,12 +300,20 @@ class SessionBuilder: def __init__(self) -> None: self._options = {} + self._app_name = None def _remove_config(self, key: str) -> "Session.SessionBuilder": """Only used in test.""" self._options.pop(key, None) return self + def app_name(self, app_name: str) -> "Session.SessionBuilder": + """ + Adds an app name to the :class:`SessionBuilder` that will be set in the query_tag after the session is created. + """ + self._app_name = app_name + return self + def config(self, key: str, value: Union[int, str]) -> "Session.SessionBuilder": """ Adds the specified connection parameter to the :class:`SessionBuilder` configuration.
@@ -334,6 +342,11 @@ def create(self) -> "Session": _add_session(session) else: session = self._create_internal(self._options.get("connection")) + + if self._app_name: + app_name_tag = f'APPNAME={self._app_name}' + session.append_query_tag(app_name_tag) + return session def getOrCreate(self) -> "Session": @@ -1627,7 +1640,7 @@ def _get_remote_query_tag(self) -> None: def append_query_tag(self, tag: str, separator: str = ",") -> None: """ - Appends a tag to the current query tag. The input tag is appended to the current sessions query tag with the given sperator. + Appends a tag to the current query tag. The input tag is appended to the current session's query tag with the given separator. Args: tag: The tag to append to the current query tag. diff --git a/tests/integ/scala/test_dataframe_range_suite.py b/tests/integ/scala/test_dataframe_range_suite.py index bc45bc7aaf4..63341da8ba7 100644 --- a/tests/integ/scala/test_dataframe_range_suite.py +++ b/tests/integ/scala/test_dataframe_range_suite.py @@ -10,6 +10,7 @@ from snowflake.snowpark import Row from snowflake.snowpark.functions import col, count, sum as sum_ +from tests.integ.test_packaging import is_pandas_and_numpy_available @pytest.mark.localtest @@ -108,3 +109,13 @@ def test_range_with_max_and_min(session): end = MIN_VALUE + 2 assert session.range(start, end, 1).collect() == [] assert session.range(start, start, 1).collect() == [] + + +@pytest.mark.skipif(not is_pandas_and_numpy_available, reason="requires numpy") +@pytest.mark.localtest +def test_range_with_large_range_and_step(session): + import numpy as np + + ints = np.array([691200000000000], dtype="int64") + # Use a numpy int64 range with a python int step + assert session.range(0, ints[0], 86400000000000).collect() != [] diff --git a/tests/integ/scala/test_update_delete_merge_suite.py b/tests/integ/scala/test_update_delete_merge_suite.py index cc7b9f38ce1..8bb4cad6556 100644 --- a/tests/integ/scala/test_update_delete_merge_suite.py +++ b/tests/integ/scala/test_update_delete_merge_suite.py @@ -27,7 +27,7 @@ when_not_matched, ) from snowflake.snowpark.types import IntegerType, StructField, StructType -from tests.utils import TestData, Utils +from tests.utils import IS_IN_STORED_PROC, TestData, Utils table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) table_name2 = Utils.random_name_for_temp_object(TempObjectType.TABLE) @@ -74,6 +74,10 @@ def test_update_rows_in_table(session): assert "condition should also be provided if source is provided" in str(ex_info) +@pytest.mark.skipif( + IS_IN_STORED_PROC, + reason="Cannot alter session in SP", +) def test_update_rows_nondeterministic_update(session): TestData.test_data2(session).write.save_as_table( table_name, mode="overwrite", table_type="temporary" diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index b573b3c18a8..36e1f14f121 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -3673,3 +3673,20 @@ def test_dataframe_interval_operation(session): ), ], ) + + +def test_dataframe_to_local_iterator_isolation(session): + ROW_NUMBER = 10 + df = session.create_dataframe( + [[1, 2, 3] for _ in range(ROW_NUMBER)], schema=["a", "b", "c"] + ) + my_iter = df.to_local_iterator() + row_counter = 0 + for _ in my_iter: + len(df.schema.fields) # this executes a schema query internally + row_counter += 1 + + # my_iter should be iterating on df.collect()'s query's results, not the schema query (1 row) + assert ( + row_counter == ROW_NUMBER + ), f"Expected {ROW_NUMBER} rows, got
{row_counter} instead" diff --git a/tests/integ/test_df_analytics.py b/tests/integ/test_df_analytics.py index 9d6b6ad1754..983d651b0b2 100644 --- a/tests/integ/test_df_analytics.py +++ b/tests/integ/test_df_analytics.py @@ -13,7 +13,9 @@ import pytest +from snowflake.snowpark.dataframe_analytics_functions import DataFrameAnalyticsFunctions from snowflake.snowpark.exceptions import SnowparkSQLException +from snowflake.snowpark.functions import col, to_timestamp def get_sample_dataframe(session): @@ -242,7 +244,7 @@ def bad_formatter(input_col, agg): @pytest.mark.skipif(not is_pandas_available, reason="pandas is required") def test_cumulative_agg_forward_direction(session): - """Tests df.transform.cumulative_agg() with forward direction for cumulative calculations.""" + """Tests df.analytics.cumulative_agg() with forward direction for cumulative calculations.""" df = get_sample_dataframe(session) @@ -278,7 +280,7 @@ def custom_formatter(input_col, agg): @pytest.mark.skipif(not is_pandas_available, reason="pandas is required") def test_cumulative_agg_backward_direction(session): - """Tests df.transform.cumulative_agg() with backward direction for cumulative calculations.""" + """Tests df.analytics.cumulative_agg() with backward direction for cumulative calculations.""" df = get_sample_dataframe(session) @@ -310,3 +312,355 @@ def custom_formatter(input_col, agg): check_dtype=False, atol=1e-1, ) + + +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") +def test_compute_lead(session): + """Tests df.analytics.compute_lead() happy path.""" + + df = get_sample_dataframe(session) + + def custom_col_formatter(input_col, op, lead): + return f"{op}_{input_col}_{lead}" + + res = df.analytics.compute_lead( + cols=["SALESAMOUNT"], + leads=[1, 2], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + col_formatter=custom_col_formatter, + ) + + res = res.withColumn("LEAD_SALESAMOUNT_1", col("LEAD_SALESAMOUNT_1").cast("float")) + res = res.withColumn("LEAD_SALESAMOUNT_2", col("LEAD_SALESAMOUNT_2").cast("float")) + + expected_data = { + "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"], + "PRODUCTKEY": [101, 101, 101, 102], + "SALESAMOUNT": [200, 100, 300, 250], + "LEAD_SALESAMOUNT_1": [100, 300, None, None], + "LEAD_SALESAMOUNT_2": [300, None, None, None], + } + expected_df = pd.DataFrame(expected_data) + + assert_frame_equal( + res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1 + ) + + +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") +def test_compute_lag(session): + """Tests df.analytics.compute_lag() happy path.""" + + df = get_sample_dataframe(session) + + def custom_col_formatter(input_col, op, lead): + return f"{op}_{input_col}_{lead}" + + res = df.analytics.compute_lag( + cols=["SALESAMOUNT"], + lags=[1, 2], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + col_formatter=custom_col_formatter, + ) + + res = res.withColumn("LAG_SALESAMOUNT_1", col("LAG_SALESAMOUNT_1").cast("float")) + res = res.withColumn("LAG_SALESAMOUNT_2", col("LAG_SALESAMOUNT_2").cast("float")) + + expected_data = { + "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"], + "PRODUCTKEY": [101, 101, 101, 102], + "SALESAMOUNT": [200, 100, 300, 250], + "LAG_SALESAMOUNT_1": [None, 200, 100, None], + "LAG_SALESAMOUNT_2": [None, None, 200, None], + } + expected_df = pd.DataFrame(expected_data) + + assert_frame_equal( + res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1 + ) + + 
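[Reviewer note] The tests above pass a custom formatter; the default names seen in the compute_lag/compute_lead docstrings (e.g. "SALESAMOUNT_LAG_1") suggest the default formatter behaves like the sketch below. This is an assumption about `_default_col_formatter`, not a copy of its actual body:

    def _default_col_formatter(input_col: str, operation: str, value: int) -> str:
        # e.g. ("SALESAMOUNT", "LAG", 1) -> "SALESAMOUNT_LAG_1"
        return f"{input_col}_{operation}_{value}"

    assert _default_col_formatter("SALESAMOUNT", "LAG", 1) == "SALESAMOUNT_LAG_1"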
+@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") +def test_lead_lag_invalid_inputs(session): + """Tests df.analytics.compute_lag() and df.analytics.compute_lead() with invalid inputs.""" + + df = get_sample_dataframe(session) + + with pytest.raises(ValueError) as exc: + df.analytics.compute_lead( + cols=["SALESAMOUNT"], + leads=[-1, -2], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "leads must be a list of integers > 0" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.compute_lead( + cols=["SALESAMOUNT"], + leads=[0, 2], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "leads must be a list of integers > 0" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.compute_lag( + cols=["SALESAMOUNT"], + lags=[-1, -2], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "lags must be a list of integers > 0" in str(exc) + + +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") +def test_time_series_agg(session): + """Tests the time_series_agg function with various window sizes.""" + + df = get_sample_dataframe(session) + df = df.withColumn("ORDERDATE", to_timestamp(df["ORDERDATE"])) + + def custom_formatter(input_col, agg, window): + return f"{agg}_{input_col}_{window}" + + res = df.analytics.time_series_agg( + time_col="ORDERDATE", + group_by=["PRODUCTKEY"], + aggs={"SALESAMOUNT": ["SUM", "MAX"]}, + windows=["1D", "-1D", "2D", "-2D"], + sliding_interval="12H", + col_formatter=custom_formatter, + ) + + # Define the expected data + expected_data = { + "PRODUCTKEY": [101, 101, 101, 102], + "SLIDING_POINT": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"], + "SALESAMOUNT": [200, 100, 300, 250], + "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"], + "SUM_SALESAMOUNT_1D": [300, 400, 300, 250], + "MAX_SALESAMOUNT_1D": [200, 300, 300, 250], + "SUM_SALESAMOUNT_-1D": [200, 300, 400, 250], + "MAX_SALESAMOUNT_-1D": [200, 200, 300, 250], + "SUM_SALESAMOUNT_2D": [600, 400, 300, 250], + "MAX_SALESAMOUNT_2D": [300, 300, 300, 250], + "SUM_SALESAMOUNT_-2D": [200, 300, 600, 250], + "MAX_SALESAMOUNT_-2D": [200, 200, 300, 250], + } + expected_df = pd.DataFrame(expected_data) + + expected_df["ORDERDATE"] = pd.to_datetime(expected_df["ORDERDATE"]) + expected_df["SLIDING_POINT"] = pd.to_datetime(expected_df["SLIDING_POINT"]) + + # Compare the result to the expected DataFrame + assert_frame_equal( + res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1 + ) + + +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") +def test_time_series_agg_month_sliding_window(session): + """Tests the time_series_agg function with month window sizes.""" + + data = [ + ["2023-01-15", 101, 100], + ["2023-02-15", 101, 200], + ["2023-03-15", 101, 300], + ["2023-04-15", 101, 400], + ["2023-01-20", 102, 150], + ["2023-02-20", 102, 250], + ["2023-03-20", 102, 350], + ["2023-04-20", 102, 450], + ] + df = session.create_dataframe(data).to_df("ORDERDATE", "PRODUCTKEY", "SALESAMOUNT") + + df = df.withColumn("ORDERDATE", to_timestamp(df["ORDERDATE"])) + + def custom_formatter(input_col, agg, window): + return f"{agg}_{input_col}_{window}" + + res = df.analytics.time_series_agg( + time_col="ORDERDATE", + group_by=["PRODUCTKEY"], + aggs={"SALESAMOUNT": ["SUM", "MAX"]}, + windows=["-2mm"], + sliding_interval="1mm", + col_formatter=custom_formatter, + ) + + expected_data = { + "PRODUCTKEY": [101, 101, 101, 101,
102, 102, 102, 102], + "SLIDING_POINT": [ + "2023-01-01", + "2023-02-01", + "2023-03-01", + "2023-04-01", + "2023-02-01", + "2023-03-01", + "2023-04-01", + "2023-05-01", + ], + "SALESAMOUNT": [100, 200, 300, 400, 150, 250, 350, 450], + "ORDERDATE": [ + "2023-01-15", + "2023-02-15", + "2023-03-15", + "2023-04-15", + "2023-01-20", + "2023-02-20", + "2023-03-20", + "2023-04-20", + ], + "SUM_SALESAMOUNT_-2mm": [100, 300, 600, 900, 150, 400, 750, 1050], + "MAX_SALESAMOUNT_-2mm": [100, 200, 300, 400, 150, 250, 350, 450], + } + expected_df = pd.DataFrame(expected_data) + expected_df["ORDERDATE"] = pd.to_datetime(expected_df["ORDERDATE"]) + expected_df["SLIDING_POINT"] = pd.to_datetime(expected_df["SLIDING_POINT"]) + expected_df = expected_df.sort_values(by=["PRODUCTKEY", "ORDERDATE"]) + + result_df = res.order_by("PRODUCTKEY", "ORDERDATE").to_pandas() + result_df = result_df.sort_values(by=["PRODUCTKEY", "ORDERDATE"]) + + assert_frame_equal(result_df, expected_df, check_dtype=False, atol=1e-1) + + +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") +def test_time_series_agg_year_sliding_window(session): + """Tests the time_series_agg function with year window sizes.""" + + data = [ + ["2021-01-15", 101, 100], + ["2022-01-15", 101, 200], + ["2023-01-15", 101, 300], + ["2024-01-15", 101, 400], + ["2021-01-20", 102, 150], + ["2022-01-20", 102, 250], + ["2023-01-20", 102, 350], + ["2024-01-20", 102, 450], + ] + df = session.create_dataframe(data).to_df("ORDERDATE", "PRODUCTKEY", "SALESAMOUNT") + df = df.withColumn("ORDERDATE", to_timestamp(df["ORDERDATE"])) + + def custom_formatter(input_col, agg, window): + return f"{agg}_{input_col}_{window}" + + res = df.analytics.time_series_agg( + time_col="ORDERDATE", + group_by=["PRODUCTKEY"], + aggs={"SALESAMOUNT": ["SUM", "MAX"]}, + windows=["-1Y"], + sliding_interval="1Y", + col_formatter=custom_formatter, + ) + + # Expected data for a -1Y window with a 1Y sliding interval + expected_data = { + "PRODUCTKEY": [101, 101, 101, 101, 102, 102, 102, 102], + "SLIDING_POINT": [ + "2021-01-01", + "2022-01-01", + "2023-01-01", + "2024-01-01", + "2021-01-01", + "2022-01-01", + "2023-01-01", + "2024-01-01", + ], + "SALESAMOUNT": [100, 200, 300, 400, 150, 250, 350, 450], + "ORDERDATE": [ + "2021-01-15", + "2022-01-15", + "2023-01-15", + "2024-01-15", + "2021-01-20", + "2022-01-20", + "2023-01-20", + "2024-01-20", + ], + "SUM_SALESAMOUNT_-1Y": [100, 300, 500, 700, 150, 400, 600, 800], + "MAX_SALESAMOUNT_-1Y": [100, 200, 300, 400, 150, 250, 350, 450], + } + expected_df = pd.DataFrame(expected_data) + expected_df["ORDERDATE"] = pd.to_datetime(expected_df["ORDERDATE"]) + expected_df["SLIDING_POINT"] = pd.to_datetime(expected_df["SLIDING_POINT"]) + expected_df = expected_df.sort_values(by=["PRODUCTKEY", "ORDERDATE"]) + + result_df = res.order_by("PRODUCTKEY", "ORDERDATE").to_pandas() + result_df = result_df.sort_values(by=["PRODUCTKEY", "ORDERDATE"]) + + assert_frame_equal(result_df, expected_df, check_dtype=False, atol=1e-1) + + +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") +def test_time_series_agg_invalid_inputs(session): + """Tests time_series_agg function with invalid inputs.""" + + df = get_sample_dataframe(session) + + # Test with invalid time_col type + with pytest.raises(ValueError) as exc: + df.analytics.time_series_agg( + time_col=123, # Invalid type + group_by=["PRODUCTKEY"], + aggs={"SALESAMOUNT": ["SUM"]}, + windows=["7D"], + sliding_interval="1D", + ).collect() + assert "time_col must be a string" in
str(exc) + + # Test with empty windows list + with pytest.raises(ValueError) as exc: + df.analytics.time_series_agg( + time_col="ORDERDATE", + group_by=["PRODUCTKEY"], + aggs={"SALESAMOUNT": ["SUM"]}, + windows=[], # Empty list + sliding_interval="1D", + ).collect() + assert "windows must not be empty" in str(exc) + + # Test with invalid window format + with pytest.raises(ValueError) as exc: + df.analytics.time_series_agg( + time_col="ORDERDATE", + group_by=["PRODUCTKEY"], + aggs={"SALESAMOUNT": ["SUM"]}, + windows=["Invalid"], + sliding_interval="1D", + ).collect() + assert "invalid literal for int() with base 10" in str(exc) + + # Test with an unsupported time unit + with pytest.raises(ValueError) as exc: + df.analytics.time_series_agg( + time_col="ORDERDATE", + group_by=["PRODUCTKEY"], + aggs={"SALESAMOUNT": ["SUM"]}, + windows=["2k"], + sliding_interval="1D", + ).collect() + assert "Unsupported unit" in str(exc) + + # Test with invalid sliding_interval format + with pytest.raises(ValueError) as exc: + df.analytics.time_series_agg( + time_col="ORDERDATE", + group_by=["PRODUCTKEY"], + aggs={"SALESAMOUNT": ["SUM"]}, + windows=["7D"], + sliding_interval="invalid", # Invalid format + ).collect() + assert "invalid literal for int() with base 10" in str(exc) + + +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") +def test_parse_time_string(session): + daf = DataFrameAnalyticsFunctions(pd.DataFrame()) + assert daf._parse_time_string("10d") == (10, "d") + assert daf._parse_time_string("-5h") == (-5, "h") + assert daf._parse_time_string("-6mm") == (-6, "mm") + assert daf._parse_time_string("-6m") == (-6, "m") diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 6c50d0a1e5b..858a24508a3 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -112,7 +112,7 @@ def test_get_or_create(session): @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") def test_get_or_create_no_previous(db_parameters, session): - # Test getOrCreate error. In this case we wan to make sure that + # Test getOrCreate error.
In this case we want to make sure that # if there was not a session the session gets created sessions_backup = list(_active_sessions) _active_sessions.clear() @@ -331,6 +331,21 @@ def test_create_session_from_connection_with_noise_parameters( new_session.close() +@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") +def test_session_builder_app_name(session, db_parameters): + builder = session.builder + app_name = 'my_app' + expected_query_tag = f'APPNAME={app_name}' + same_session = builder.app_name(app_name).getOrCreate() + new_session = builder.app_name(app_name).configs(db_parameters).create() + try: + assert session == same_session + assert same_session.query_tag is None + assert new_session.query_tag == expected_query_tag + finally: + new_session.close() + + @pytest.mark.skipif( IS_IN_STORED_PROC, reason="The test creates temporary tables of which the names do not follow the rules of temp object on purposes.", diff --git a/tests/integ/test_stored_procedure.py b/tests/integ/test_stored_procedure.py index a1a1f6b1418..007ddea4c19 100644 --- a/tests/integ/test_stored_procedure.py +++ b/tests/integ/test_stored_procedure.py @@ -73,6 +73,10 @@ def setup(session, resources_path, local_testing_mode): ) +@pytest.mark.skipif( + IS_IN_STORED_PROC, + reason="Cannot create session in SP", +) @patch("snowflake.snowpark.stored_procedure.VERSION", (999, 9, 9)) @pytest.mark.parametrize( "packages,should_fail", @@ -112,6 +116,10 @@ def return1(session_): assert return1_sproc(session=new_session) == "1" +@pytest.mark.skipif( + IS_IN_STORED_PROC, + reason="Cannot create session in SP", +) @patch( "snowflake.snowpark.stored_procedure.resolve_imports_and_packages", wraps=resolve_imports_and_packages, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0f0e53d96fe..cfe9fe273cd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -418,3 +418,63 @@ def test_connection_expiry(): ) as m: assert builder.getOrCreate() is None m.assert_called_once() + + +def test_session_builder_app_name_no_existing_query_tag(): + mocked_session = Session( + ServerConnection( + {"": ""}, + mock.Mock( + spec=SnowflakeConnection, + _telemetry=mock.Mock(), + _session_parameters=mock.Mock(), + is_closed=mock.Mock(return_value=False), + expired=False, + ), + ), + ) + + mocked_session._get_remote_query_tag = MagicMock(return_value=None) + + builder = Session.builder + + with mock.patch.object( + builder, + "_create_internal", + return_value=mocked_session) as m: + app_name = 'my_app_name' + assert builder.app_name(app_name) is builder + created_session = builder.getOrCreate() + m.assert_called_once() + assert created_session.query_tag == f'APPNAME={app_name}' + + +def test_session_builder_app_name_existing_query_tag(): + mocked_session = Session( + ServerConnection( + {"": ""}, + mock.Mock( + spec=SnowflakeConnection, + _telemetry=mock.Mock(), + _session_parameters=mock.Mock(), + is_closed=mock.Mock(return_value=False), + expired=False, + ), + ), + ) + + existing_query_tag = 'tag' + + mocked_session._get_remote_query_tag = MagicMock(return_value=existing_query_tag) + + builder = Session.builder + + with mock.patch.object( + builder, + "_create_internal", + return_value=mocked_session) as m: + app_name = 'my_app_name' + assert builder.app_name(app_name) is builder + created_session = builder.getOrCreate() + m.assert_called_once() + assert created_session.query_tag == f'tag,APPNAME={app_name}'
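[Reviewer note] For context on the `to_local_iterator` fix in server_connection.py: re-reading a finished query through RESULT_SCAN pins the iterator to that query's result set, so queries issued in between (such as schema queries) cannot change what is iterated. A standalone sketch of the same pattern using the Python connector directly; the connection parameters are placeholders:

    import snowflake.connector

    # Placeholder credentials -- replace with real connection parameters.
    conn = snowflake.connector.connect(account="...", user="...", password="...")

    cursor = conn.cursor()
    cursor.execute("SELECT seq4() AS n FROM TABLE(GENERATOR(ROWCOUNT => 10))")

    # Re-materialize the completed query's results on a fresh cursor; later
    # queries on the original cursor can no longer affect this iteration.
    scan_cursor = conn.cursor()
    scan_cursor.execute(f"SELECT * FROM TABLE(RESULT_SCAN('{cursor.sfqid}'))")

    for row in scan_cursor:
        print(row)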