Skip to content

Commit

Permalink
SNOW-984699: Add support for to_timestamp_* to local testing (#1244)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jrose authored Mar 7, 2024
1 parent 58c0acb commit 624cffb
Show file tree
Hide file tree
Showing 8 changed files with 623 additions and 138 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
### New Features

- Added support for creating vectorized UDTFs with `process` method.
- Added support for dataframe functions:
- to_timestamp_ltz
- to_timestamp_ntz
- to_timestamp_tz
- Added support for the following local testing functions:
- to_timestamp
- to_timestamp_ltz
- to_timestamp_ntz
- to_timestamp_tz
- greatest
- least
- Added support for ASOF JOIN type.
Expand Down
60 changes: 60 additions & 0 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3152,6 +3152,66 @@ def to_timestamp(e: ColumnOrName, fmt: Optional["Column"] = None) -> Column:
)


def to_timestamp_ntz(
    e: ColumnOrName, fmt: Optional[ColumnOrLiteralStr] = None
) -> Column:
    """Converts an input expression into the corresponding timestamp without a timezone.

    When ``fmt`` is omitted, no format argument is passed and Snowflake detects the
    format automatically. When given, the format has to follow the rules set forth in
    <https://docs.snowflake.com/en/sql-reference/functions-conversion#date-and-time-formats-in-conversion-functions>

    Args:
        e: The column (or column name) holding the values to convert.
        fmt: Optional format, either a column or a literal string.

    Example::

        >>> import datetime
        >>> df = session.createDataFrame([datetime.datetime(2022, 12, 25, 13, 59, 38, 467)], schema=["a"])
        >>> df.select(to_timestamp_ntz(col("a"))).collect()
        [Row(TO_TIMESTAMP_NTZ("A")=datetime.datetime(2022, 12, 25, 13, 59, 38, 467))]
        >>> df = session.createDataFrame([datetime.date(2023, 3, 1)], schema=["a"])
        >>> df.select(to_timestamp_ntz(col("a"))).collect()
        [Row(TO_TIMESTAMP_NTZ("A")=datetime.datetime(2023, 3, 1, 0, 0))]
    """
    func_name = "to_timestamp_ntz"
    source_col = _to_col_if_str(e, func_name)
    convert = builtin(func_name)
    # Only forward a format argument when the caller supplied one.
    if fmt is None:
        return convert(source_col)
    return convert(source_col, _to_col_if_lit(fmt, func_name))


def to_timestamp_ltz(
    e: ColumnOrName, fmt: Optional[ColumnOrLiteralStr] = None
) -> Column:
    """Converts an input expression into the corresponding timestamp using the local timezone.

    When ``fmt`` is omitted, no format argument is passed and Snowflake detects the
    format automatically. When given, the format has to follow the rules set forth in
    <https://docs.snowflake.com/en/sql-reference/functions-conversion#date-and-time-formats-in-conversion-functions>

    Args:
        e: The column (or column name) holding the values to convert.
        fmt: Optional format, either a column or a literal string.
    """
    func_name = "to_timestamp_ltz"
    source_col = _to_col_if_str(e, func_name)
    convert = builtin(func_name)
    # Only forward a format argument when the caller supplied one.
    if fmt is None:
        return convert(source_col)
    return convert(source_col, _to_col_if_lit(fmt, func_name))


def to_timestamp_tz(
    e: ColumnOrName, fmt: Optional[ColumnOrLiteralStr] = None
) -> Column:
    """Converts an input expression into the corresponding timestamp with the timezone represented in each row.

    When ``fmt`` is omitted, no format argument is passed and Snowflake detects the
    format automatically. When given, the format has to follow the rules set forth in
    <https://docs.snowflake.com/en/sql-reference/functions-conversion#date-and-time-formats-in-conversion-functions>

    Args:
        e: The column (or column name) holding the values to convert.
        fmt: Optional format, either a column or a literal string.
    """
    func_name = "to_timestamp_tz"
    source_col = _to_col_if_str(e, func_name)
    convert = builtin(func_name)
    # Only forward a format argument when the caller supplied one.
    if fmt is None:
        return convert(source_col)
    return convert(source_col, _to_col_if_lit(fmt, func_name))


def from_utc_timestamp(e: ColumnOrName, tz: ColumnOrLiteral) -> Column:
"""Interprets an input expression as a UTC timestamp and converts it to the given time zone.
Expand Down
155 changes: 135 additions & 20 deletions src/snowflake/snowpark/mock/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
import datetime
import json
import math
import numbers
import string
from decimal import Decimal
from functools import partial, reduce
from numbers import Real
from typing import Any, Callable, Optional, Tuple, TypeVar, Union

import pytz

from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.mock._snowflake_data_type import (
ColumnEmulator,
Expand Down Expand Up @@ -50,6 +53,30 @@
_MOCK_FUNCTION_IMPLEMENTATION_MAP = {}


class LocalTimezone:
    """
    Holds the timezone that local testing treats as "local".

    The value lives on the class itself, so an override applies to every caller.
    Tests can override it to get the same results regardless of the region they run in.
    """

    # None means "no override": datetime falls back to the system timezone
    # when astimezone() is called with tz=None.
    LOCAL_TZ: Optional[datetime.timezone] = None

    @classmethod
    def set_local_timezone(cls, tz: Optional[datetime.timezone] = None) -> None:
        """Store ``tz`` as the local timezone override; ``None`` clears the override."""
        cls.LOCAL_TZ = tz

    @classmethod
    def to_local_timezone(cls, d: datetime.datetime) -> datetime.datetime:
        """Return ``d`` converted to the local timezone (the same instant, shifted wall time)."""
        target_tz = cls.LOCAL_TZ
        return d.astimezone(tz=target_tz)

    @classmethod
    def replace_tz(cls, d: datetime.datetime) -> datetime.datetime:
        """Return ``d`` with its tzinfo swapped for the local tzinfo, without adjusting the time.

        With no override set (``LOCAL_TZ`` is None) this strips tzinfo and yields a naive datetime.
        """
        target_tz = cls.LOCAL_TZ
        return d.replace(tzinfo=target_tz)


def _register_func_implementation(
snowpark_func: Union[str, Callable], func_implementation: Callable
):
Expand All @@ -74,7 +101,7 @@ def decorator(mocking_function):
_register_func_implementation(function, mocking_function)

def wrapper(*args, **kwargs):
mocking_function(*args, **kwargs)
return mocking_function(*args, **kwargs)

return wrapper

Expand Down Expand Up @@ -502,11 +529,11 @@ def mock_to_time(
)


@patch("to_timestamp")
def mock_to_timestamp(
def _to_timestamp(
column: ColumnEmulator,
fmt: Optional[str] = None,
fmt: Optional[ColumnEmulator],
try_cast: bool = False,
add_timezone: bool = False,
):
"""
[x] For NULL input, the result will be NULL.
Expand Down Expand Up @@ -546,25 +573,46 @@ def mock_to_timestamp(
[ ] If the value is greater than or equal to 31536000000000000, then the value is treated as nanoseconds.
"""
res = []
auto_detect = bool(not fmt)
default_format = "%Y-%m-%d %H:%M:%S.%f"
(
timestamp_format,
hour_delta,
fractional_seconds,
) = convert_snowflake_datetime_format(fmt, default_format=default_format)
fmt_column = fmt if fmt is not None else [None] * len(column)

for data, format in zip(column, fmt_column):
auto_detect = bool(not format)
default_format = "%Y-%m-%d %H:%M:%S.%f"
(
timestamp_format,
hour_delta,
fractional_seconds,
) = convert_snowflake_datetime_format(format, default_format=default_format)

for data in column:
try:
if data is None:
res.append(None)
continue
if auto_detect and (
isinstance(data, int) or (isinstance(data, str) and data.isnumeric())
):
res.append(
datetime.datetime.utcfromtimestamp(process_numeric_time(data))
)

if auto_detect:
if isinstance(data, numbers.Number) or (
isinstance(data, str) and data.isnumeric()
):
parsed = datetime.datetime.utcfromtimestamp(
process_numeric_time(data)
)
# utc timestamps should be in utc timezone
if add_timezone:
parsed = parsed.replace(tzinfo=pytz.utc)
elif isinstance(data, datetime.datetime):
parsed = data
elif isinstance(data, datetime.date):
parsed = datetime.datetime.combine(data, datetime.time(0, 0, 0))
elif isinstance(data, str):
# dateutil is a pandas dependency
import dateutil.parser

try:
parsed = dateutil.parser.parse(data)
except ValueError:
parsed = None
else:
parsed = None
else:
# handle seconds fraction
try:
Expand All @@ -587,20 +635,87 @@ def mock_to_timestamp(
)
else:
raise
res.append(datetime_data + datetime.timedelta(hours=hour_delta))
parsed = datetime_data + datetime.timedelta(hours=hour_delta)

# Add the local timezone if tzinfo is missing and a tz is desired
if parsed and add_timezone and parsed.tzinfo is None:
parsed = LocalTimezone.replace_tz(parsed)

res.append(parsed)
except BaseException:
if try_cast:
res.append(None)
else:
raise
return res


@patch("to_timestamp")
def mock_to_timestamp(
    column: ColumnEmulator,
    fmt: Optional[ColumnEmulator] = None,
    try_cast: bool = False,
):
    """Local-testing implementation of ``to_timestamp``.

    Delegates the per-row parsing to ``_to_timestamp`` and wraps the result in a
    ``ColumnEmulator`` typed as a timestamp with the default ``TimestampType()``.

    Args:
        column: Input values to convert.
        fmt: Optional column of format strings, paired row-by-row with ``column``.
        try_cast: When True, rows that fail to convert become None instead of raising.

    Returns:
        A ``ColumnEmulator`` with the converted values; nullability is copied from ``column``.
    """
    # The converted values come solely from _to_timestamp; no local result
    # list exists in this function.
    return ColumnEmulator(
        data=_to_timestamp(column, fmt, try_cast),
        sf_type=ColumnType(TimestampType(), column.sf_type.nullable),
        dtype=object,
    )


@patch("to_timestamp_ntz")
def mock_timestamp_ntz(
    column: ColumnEmulator,
    fmt: Optional[ColumnEmulator] = None,
    try_cast: bool = False,
):
    """Local-testing implementation of ``to_timestamp_ntz``.

    Parses each row with ``_to_timestamp`` and then drops any timezone
    information, so every non-NULL result is a naive datetime.

    Args:
        column: Input values to convert.
        fmt: Optional column of format strings, paired row-by-row with ``column``.
        try_cast: When True, rows that fail to convert become None instead of raising.

    Returns:
        A ``ColumnEmulator`` typed as ``TimestampType(TimestampTimeZone.NTZ)``;
        nullability is copied from ``column``.
    """
    result = _to_timestamp(column, fmt, try_cast)
    # Cast to NTZ by removing tz data if present.
    # _to_timestamp yields None for NULL input and for rows that failed under
    # try_cast; those must stay None rather than hit None.replace(...).
    return ColumnEmulator(
        data=[None if x is None else x.replace(tzinfo=None) for x in result],
        sf_type=ColumnType(
            TimestampType(TimestampTimeZone.NTZ), column.sf_type.nullable
        ),
        dtype=object,
    )


@patch("to_timestamp_ltz")
def mock_to_timestamp_ltz(
    column: ColumnEmulator,
    fmt: Optional[ColumnEmulator] = None,
    try_cast: bool = False,
):
    """Local-testing implementation of ``to_timestamp_ltz``.

    Parses each row with ``_to_timestamp`` (asking it to attach a timezone where
    one is missing) and then converts every non-NULL value to the local timezone
    via ``LocalTimezone.to_local_timezone``.

    Args:
        column: Input values to convert.
        fmt: Optional column of format strings, paired row-by-row with ``column``.
        try_cast: When True, rows that fail to convert become None instead of raising.

    Returns:
        A ``ColumnEmulator`` typed as ``TimestampType(TimestampTimeZone.LTZ)``;
        nullability is copied from ``column``.
    """
    result = _to_timestamp(column, fmt, try_cast, add_timezone=True)

    # Cast to ltz through LocalTimezone: when no override is set it calls
    # astimezone with an empty timezone and datetime fills in the local zone.
    # _to_timestamp yields None for NULL input and for rows that failed under
    # try_cast; those must stay None rather than hit None.astimezone(...).
    return ColumnEmulator(
        data=[
            None if x is None else LocalTimezone.to_local_timezone(x)
            for x in result
        ],
        sf_type=ColumnType(
            TimestampType(TimestampTimeZone.LTZ), column.sf_type.nullable
        ),
        dtype=object,
    )


@patch("to_timestamp_tz")
def mock_to_timestamp_tz(
    column: ColumnEmulator,
    fmt: Optional[ColumnEmulator] = None,
    try_cast: bool = False,
):
    """Local-testing implementation of ``to_timestamp_tz``.

    Args:
        column: Input values to convert.
        fmt: Optional column of format strings, paired row-by-row with ``column``.
        try_cast: When True, rows that fail to convert become None instead of raising.

    Returns:
        A ``ColumnEmulator`` typed as ``TimestampType(TimestampTimeZone.TZ)``;
        nullability and dtype are copied from ``column``.
    """
    # A tz already present in the data is kept as-is by _to_timestamp;
    # add_timezone=True makes it attach one to values that lack it.
    tz_values = _to_timestamp(column, fmt, try_cast, add_timezone=True)
    result_type = ColumnType(
        TimestampType(TimestampTimeZone.TZ), column.sf_type.nullable
    )
    # NOTE(review): the other to_timestamp mocks use dtype=object here, while
    # this one reuses the input column's dtype. Confirm this is intentional.
    return ColumnEmulator(
        data=tz_values,
        sf_type=result_type,
        dtype=column.dtype,
    )


def try_convert(convert: Callable, try_cast: bool, val: Any):
if val is None:
return None
Expand Down
4 changes: 3 additions & 1 deletion src/snowflake/snowpark/mock/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,12 @@ def convert_snowflake_datetime_format(format, default_format) -> Tuple[str, int,
time_fmt = time_fmt.replace("MI", "%M")
time_fmt = time_fmt.replace("SS", "%S")
time_fmt = time_fmt.replace("SS", "%S")
time_fmt = time_fmt.replace("TZHTZM", "%z")
time_fmt = time_fmt.replace("TZH", "%z")
fractional_seconds = 9
if format is not None and "FF" in format:
try:
ff_index = str(format).index("FF")
ff_index = str(time_fmt).index("FF")
# handle precision string 'FF[0-9]' which could be like FF0, FF1, ..., FF9
if str(format[ff_index + 2 : ff_index + 3]).isdigit():
fractional_seconds = int(format[ff_index + 2 : ff_index + 3])
Expand Down
9 changes: 6 additions & 3 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@
to_object,
to_time,
to_timestamp,
to_timestamp_ltz,
to_timestamp_ntz,
to_timestamp_tz,
to_variant,
)
from snowflake.snowpark.mock._analyzer import MockAnalyzer
Expand Down Expand Up @@ -2451,11 +2454,11 @@ def convert_row_to_list(
elif isinstance(field.datatype, TimestampType):
tz = field.datatype.tz
if tz == TimestampTimeZone.NTZ:
to_timestamp_func = builtin("to_timestamp_ntz")
to_timestamp_func = to_timestamp_ntz
elif tz == TimestampTimeZone.LTZ:
to_timestamp_func = builtin("to_timestamp_ltz")
to_timestamp_func = to_timestamp_ltz
elif tz == TimestampTimeZone.TZ:
to_timestamp_func = builtin("to_timestamp_tz")
to_timestamp_func = to_timestamp_tz
else:
to_timestamp_func = to_timestamp
project_columns.append(to_timestamp_func(column(name)).as_(name))
Expand Down
Loading

0 comments on commit 624cffb

Please sign in to comment.