Skip to content

Commit

Permalink
Add Interval expression to Snowpark API (#1191)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-nkrishna authored Jan 11, 2024
1 parent 51a0c44 commit d924d78
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
Expression,
FunctionExpression,
InExpression,
Interval,
Like,
ListAgg,
Literal,
Expand Down Expand Up @@ -339,6 +340,9 @@ def analyze(
sql = sql.upper()
return sql

if isinstance(expr, Interval):
return expr.sql

if isinstance(expr, Attribute):
assert self.alias_maps_to_use is not None
name = self.alias_maps_to_use.get(expr.expr_id, expr.name)
Expand Down
48 changes: 48 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,54 @@ def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None:
self.datatype = infer_type(value)


class Interval(Expression):
    """Expression node that renders as a SQL ``INTERVAL`` constant.

    Every keyword argument is one unit of the interval. Units left as
    ``None`` are omitted; supplied units are stored in ``values_dict`` keyed
    by their upper-case unit name, in the order ``YEAR`` through
    ``NANOSECOND``.
    """

    def __init__(
        self,
        year: Optional[int] = None,
        quarter: Optional[int] = None,
        month: Optional[int] = None,
        week: Optional[int] = None,
        day: Optional[int] = None,
        hour: Optional[int] = None,
        minute: Optional[int] = None,
        second: Optional[int] = None,
        millisecond: Optional[int] = None,
        microsecond: Optional[int] = None,
        nanosecond: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Largest unit first: this is also the order in which the parts show
        # up in the rendered SQL, because dicts keep insertion order.
        unit_amounts = (
            ("YEAR", year),
            ("QUARTER", quarter),
            ("MONTH", month),
            ("WEEK", week),
            ("DAY", day),
            ("HOUR", hour),
            ("MINUTE", minute),
            ("SECOND", second),
            ("MILLISECOND", millisecond),
            ("MICROSECOND", microsecond),
            ("NANOSECOND", nanosecond),
        )
        # Only an explicit None is dropped; a value of 0 is kept.
        self.values_dict = {
            unit: amount for unit, amount in unit_amounts if amount is not None
        }

    @property
    def sql(self) -> str:
        """Return ``INTERVAL '<amount> <UNIT>,<amount> <UNIT>,...'``."""
        rendered_parts = ",".join(
            f"{amount} {unit}" for unit, amount in self.values_dict.items()
        )
        return f"INTERVAL '{rendered_parts}'"

    def __str__(self) -> str:
        """Return the same text as :attr:`sql`."""
        return self.sql


class Like(Expression):
def __init__(self, expr: Expression, pattern: Expression) -> None:
super().__init__(expr)
Expand Down
39 changes: 38 additions & 1 deletion tests/integ/test_column_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import datetime
import math

import pytest

from snowflake.snowpark import Window
from snowflake.snowpark import Column, Window
from snowflake.snowpark._internal.analyzer.expression import Interval
from snowflake.snowpark._internal.utils import TempObjectType, quote_name
from snowflake.snowpark.functions import (
any_value,
Expand Down Expand Up @@ -286,6 +288,41 @@ def test_literal(session):
)


def test_interval(session):
    """Check the column name generated for ``timestamp + Interval``.

    The name reported by the plan output, by ``DataFrame.columns`` and by the
    query metadata must all equal the quoted SQL text of the expression.
    """
    rows = [
        [datetime.datetime(2010, 1, 1), datetime.datetime(2011, 1, 1)],
        [datetime.datetime(2012, 1, 1), datetime.datetime(2013, 1, 1)],
    ]
    df1 = session.create_dataframe(rows, schema=["a", "b"])

    # Every unit except ``year`` is supplied here.
    interval = Interval(
        quarter=1,
        month=1,
        week=2,
        day=2,
        hour=2,
        minute=3,
        second=3,
        millisecond=3,
        microsecond=4,
        nanosecond=4,
    )
    df2 = df1.select(df1["a"] + Column(interval))

    expected_names = [
        '"(""A"" + INTERVAL \'1 QUARTER,1 MONTH,2 WEEK,2 DAY,2 HOUR,3 MINUTE,3 SECOND,3 MILLISECOND,4 MICROSECOND,4 NANOSECOND\')"',
    ]
    assert (
        [x.name for x in df2._output]
        == df2.columns
        == get_metadata_names(session, df2)
        == expected_names
    )


@pytest.mark.localtest
def test_attribute(session):
df1 = session.create_dataframe([[1, 2]], schema=[" a", "a"])
Expand Down
46 changes: 45 additions & 1 deletion tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from snowflake.connector import IntegrityError
from snowflake.snowpark import Column, Row, Window
from snowflake.snowpark._internal.analyzer.analyzer_utils import result_scan_statement
from snowflake.snowpark._internal.analyzer.expression import Attribute, Star
from snowflake.snowpark._internal.analyzer.expression import Attribute, Interval, Star
from snowflake.snowpark._internal.utils import TempObjectType, warning_dict
from snowflake.snowpark.exceptions import (
SnowparkColumnException,
Expand Down Expand Up @@ -3440,3 +3440,47 @@ def test_drop_columns_special_names(session):
Utils.check_answer(df2, [Row(1), Row(2)])
finally:
Utils.drop_table(session, table_name)


def test_dataframe_interval_operation(session):
    """Add an Interval using every unit to a timestamp column and check the rows."""
    source_rows = [
        [datetime.datetime(2010, 1, 1), datetime.datetime(2011, 1, 1)],
        [datetime.datetime(2012, 1, 1), datetime.datetime(2013, 1, 1)],
    ]
    df = session.create_dataframe(source_rows, schema=["a", "b"])

    interval = Interval(
        year=1,
        quarter=1,
        month=1,
        week=2,
        day=2,
        hour=2,
        minute=3,
        second=3,
        millisecond=3,
        microsecond=4,
        nanosecond=4,
    )
    df2 = df.with_column("TWO_DAYS_AHEAD", df["a"] + Column(interval))

    # Third column is column "a" shifted by the interval; the expected
    # timestamps carry microsecond precision (3004 microseconds).
    expected = [
        Row(
            datetime.datetime(2010, 1, 1, 0, 0, 0),
            datetime.datetime(2011, 1, 1, 0, 0, 0),
            datetime.datetime(2011, 5, 17, 2, 3, 3, 3004),
        ),
        Row(
            datetime.datetime(2012, 1, 1, 0, 0, 0),
            datetime.datetime(2013, 1, 1, 0, 0, 0),
            datetime.datetime(2013, 5, 17, 2, 3, 3, 3004),
        ),
    ]
    Utils.check_answer(df2, expected)

0 comments on commit d924d78

Please sign in to comment.