diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index ea23116c2c6..e5904000a21 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -58,6 +58,7 @@ Expression, FunctionExpression, InExpression, + Interval, Like, ListAgg, Literal, @@ -339,6 +340,9 @@ def analyze( sql = sql.upper() return sql + if isinstance(expr, Interval): + return expr.sql + if isinstance(expr, Attribute): assert self.alias_maps_to_use is not None name = self.alias_maps_to_use.get(expr.expr_id, expr.name) diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 236f1c5501f..de7a9d933dd 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -225,6 +225,54 @@ def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None: self.datatype = infer_type(value) +class Interval(Expression): + def __init__( + self, + year: Optional[int] = None, + quarter: Optional[int] = None, + month: Optional[int] = None, + week: Optional[int] = None, + day: Optional[int] = None, + hour: Optional[int] = None, + minute: Optional[int] = None, + second: Optional[int] = None, + millisecond: Optional[int] = None, + microsecond: Optional[int] = None, + nanosecond: Optional[int] = None, + ) -> None: + super().__init__() + self.values_dict = {} + if year is not None: + self.values_dict["YEAR"] = year + if quarter is not None: + self.values_dict["QUARTER"] = quarter + if month is not None: + self.values_dict["MONTH"] = month + if week is not None: + self.values_dict["WEEK"] = week + if day is not None: + self.values_dict["DAY"] = day + if hour is not None: + self.values_dict["HOUR"] = hour + if minute is not None: + self.values_dict["MINUTE"] = minute + if second is not None: + self.values_dict["SECOND"] = second + if millisecond is not None: + self.values_dict["MILLISECOND"] = millisecond + if microsecond is not None: + self.values_dict["MICROSECOND"] = microsecond + if nanosecond is not None: + self.values_dict["NANOSECOND"] = nanosecond + + @property + def sql(self) -> str: + return f"""INTERVAL '{",".join(f"{v} {k}" for k, v in self.values_dict.items())}'""" + + def __str__(self) -> str: + return self.sql + + class Like(Expression): def __init__(self, expr: Expression, pattern: Expression) -> None: super().__init__(expr) diff --git a/tests/integ/test_column_names.py b/tests/integ/test_column_names.py index 39721d32b77..4d3bdead871 100644 --- a/tests/integ/test_column_names.py +++ b/tests/integ/test_column_names.py @@ -2,11 +2,13 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # +import datetime import math import pytest -from snowflake.snowpark import Window +from snowflake.snowpark import Column, Window +from snowflake.snowpark._internal.analyzer.expression import Interval from snowflake.snowpark._internal.utils import TempObjectType, quote_name from snowflake.snowpark.functions import ( any_value, @@ -286,6 +288,41 @@ def test_literal(session): ) +def test_interval(session): + df1 = session.create_dataframe( + [ + [datetime.datetime(2010, 1, 1), datetime.datetime(2011, 1, 1)], + [datetime.datetime(2012, 1, 1), datetime.datetime(2013, 1, 1)], + ], + schema=["a", "b"], + ) + df2 = df1.select( + df1["a"] + + Column( + Interval( + quarter=1, + month=1, + week=2, + day=2, + hour=2, + minute=3, + second=3, + millisecond=3, + microsecond=4, + nanosecond=4, + ) + ) + ) + assert ( + [x.name for x in df2._output] + == df2.columns + == get_metadata_names(session, df2) + == [ + '"(""A"" + INTERVAL \'1 QUARTER,1 MONTH,2 WEEK,2 DAY,2 HOUR,3 MINUTE,3 SECOND,3 MILLISECOND,4 MICROSECOND,4 NANOSECOND\')"', + ] + ) + + @pytest.mark.localtest def test_attribute(session): df1 = session.create_dataframe([[1, 2]], schema=[" a", "a"]) diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index 8c3c85873a1..a45f6797831 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -28,7 +28,7 @@ from snowflake.connector import IntegrityError from snowflake.snowpark import Column, Row, Window from snowflake.snowpark._internal.analyzer.analyzer_utils import result_scan_statement -from snowflake.snowpark._internal.analyzer.expression import Attribute, Star +from snowflake.snowpark._internal.analyzer.expression import Attribute, Interval, Star from snowflake.snowpark._internal.utils import TempObjectType, warning_dict from snowflake.snowpark.exceptions import ( SnowparkColumnException, @@ -3440,3 +3440,47 @@ def test_drop_columns_special_names(session): Utils.check_answer(df2, [Row(1), Row(2)]) finally: Utils.drop_table(session, table_name) + + +def test_dataframe_interval_operation(session): + df = session.create_dataframe( + [ + [datetime.datetime(2010, 1, 1), datetime.datetime(2011, 1, 1)], + [datetime.datetime(2012, 1, 1), datetime.datetime(2013, 1, 1)], + ], + schema=["a", "b"], + ) + df2 = df.with_column( + "TWO_DAYS_AHEAD", + df["a"] + + Column( + Interval( + year=1, + quarter=1, + month=1, + week=2, + day=2, + hour=2, + minute=3, + second=3, + millisecond=3, + microsecond=4, + nanosecond=4, + ) + ), + ) + Utils.check_answer( + df2, + [ + Row( + datetime.datetime(2010, 1, 1, 0, 0, 0), + datetime.datetime(2011, 1, 1, 0, 0, 0), + datetime.datetime(2011, 5, 17, 2, 3, 3, 3004), + ), + Row( + datetime.datetime(2012, 1, 1, 0, 0, 0), + datetime.datetime(2013, 1, 1, 0, 0, 0), + datetime.datetime(2013, 5, 17, 2, 3, 3, 3004), + ), + ], + )