Skip to content

Commit

Permalink
feat(server): support time & timestamp types
Browse files Browse the repository at this point in the history
  • Loading branch information
tekumara committed Aug 25, 2024
1 parent 618bedf commit c03953a
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 31 deletions.
76 changes: 70 additions & 6 deletions fakesnow/arrow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from __future__ import annotations

from typing import cast

import pyarrow as pa
import pyarrow.compute as pc

from fakesnow.types import ColumnInfo

Expand All @@ -13,8 +18,18 @@ def with_sf_metadata(schema: pa.Schema, rowtype: list[ColumnInfo]) -> pa.Schema:
# see https://github.com/snowflakedb/snowflake-connector-python/blob/e9393a6/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp#L32
# and https://github.com/snowflakedb/snowflake-connector-python/blob/e9393a6/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.cpp#L10

fms = [
schema.field(i).with_metadata(
def sf_field(field: pa.Field, c: ColumnInfo) -> pa.Field:
if isinstance(field.type, pa.TimestampType):
# snowflake uses a struct to represent timestamps, see timestamp_to_sf_struct
fields = [pa.field("epoch", pa.int64(), nullable=False), pa.field("fraction", pa.int32(), nullable=False)]
if field.type.tz:
fields.append(pa.field("timezone", nullable=False, type=pa.int32()))
field = field.with_type(pa.struct(fields))
elif isinstance(field.type, pa.Time64Type):
# this will cast values to int64 as expected by snowflake
field = field.with_type(pa.int64())

return field.with_metadata(
{
"logicalType": c["type"].upper(),
# required for FIXED type see
Expand All @@ -24,20 +39,69 @@ def with_sf_metadata(schema: pa.Schema, rowtype: list[ColumnInfo]) -> pa.Schema:
"charLength": str(c["length"] or 0),
}
)
for i, c in enumerate(rowtype)
]

fms = [sf_field(schema.field(i), c) for i, c in enumerate(rowtype)]
return pa.schema(fms)


def to_ipc(table: pa.Table, rowtype: list[ColumnInfo]) -> pa.Buffer:
def to_ipc(table: pa.Table) -> pa.Buffer:
batches = table.to_batches()
if len(batches) != 1:
raise NotImplementedError(f"{len(batches)} batches")
batch = batches[0]

sink = pa.BufferOutputStream()

with pa.ipc.new_stream(sink, with_sf_metadata(table.schema, rowtype)) as writer:
with pa.ipc.new_stream(sink, table.schema) as writer:
writer.write_batch(batch)

return sink.getvalue()


def to_sf(table: pa.Table, rowtype: list[ColumnInfo]) -> pa.Table:
def to_sf_col(col: pa.Array) -> pa.Array:
if pa.types.is_timestamp(col.type):
return timestamp_to_sf_struct(col)
elif pa.types.is_time(col.type):
# as nanoseconds
return pc.multiply(col.cast(pa.int64()), 1000) # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/44
return col

return pa.Table.from_arrays([to_sf_col(c) for c in table.columns], schema=with_sf_metadata(table.schema, rowtype))


def timestamp_to_sf_struct(ts: pa.Array | pa.ChunkedArray) -> pa.Array:
if isinstance(ts, pa.ChunkedArray):
# combine because pa.StructArray.from_arrays doesn't support ChunkedArray
ts = cast(pa.Array, ts.combine_chunks()) # see https://github.com/zen-xu/pyarrow-stubs/issues/46

if not isinstance(ts.type, pa.TimestampType):
raise ValueError(f"Expected TimestampArray, got {type(ts)}")

# Round to seconds, ie: strip subseconds
tsa_without_us = pc.floor_temporal(ts, unit="second") # type: ignore see https://github.com/zen-xu/pyarrow-stubs/issues/45
epoch = pc.divide(tsa_without_us.cast(pa.int64()), 1_000_000) # type: ignore see https://github.com/zen-xu/pyarrow-stubs/issues/44

# Calculate fractional part as nanoseconds
fraction = pc.multiply(pc.subsecond(ts), 1_000_000_000).cast(pa.int32()) # type: ignore

if ts.type.tz:
assert ts.type.tz == "UTC", f"Timezone {ts.type.tz} not yet supported"
timezone = pa.array([1440] * len(ts), type=pa.int32())

return pa.StructArray.from_arrays(
arrays=[epoch, fraction, timezone],
fields=[
pa.field("epoch", nullable=False, type=pa.int64()),
pa.field("fraction", nullable=False, type=pa.int32()),
pa.field("timezone", nullable=False, type=pa.int32()),
],
)
else:
return pa.StructArray.from_arrays(
arrays=[epoch, fraction],
fields=[
pa.field("epoch", nullable=False, type=pa.int64()),
pa.field("fraction", nullable=False, type=pa.int32()),
],
)
1 change: 1 addition & 0 deletions fakesnow/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def description(self) -> list[ResultMetadata]:
def _describe_last_sql(self) -> list:
# use a separate cursor to avoid consuming the result set on this cursor
with self._conn.cursor() as cur:
# TODO: can we replace with self._duck_conn.description?
expression = sqlglot.parse_one(f"DESCRIBE {self._last_sql}", read="duckdb")
cur._execute(expression, self._last_params) # noqa: SLF001
return cur.fetchall()
Expand Down
4 changes: 2 additions & 2 deletions fakesnow/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from starlette.responses import JSONResponse
from starlette.routing import Route

from fakesnow.arrow import to_ipc
from fakesnow.arrow import to_ipc, to_sf
from fakesnow.fakes import FakeSnowflakeConnection
from fakesnow.instance import FakeSnow
from fakesnow.types import describe_as_rowtype
Expand Down Expand Up @@ -66,7 +66,7 @@ async def query_request(request: Request) -> JSONResponse:
rowtype = describe_as_rowtype(cur._describe_last_sql()) # noqa: SLF001

if cur._arrow_table: # noqa: SLF001
batch_bytes = to_ipc(cur._arrow_table, rowtype) # noqa: SLF001
batch_bytes = to_ipc(to_sf(cur._arrow_table, rowtype)) # noqa: SLF001
rowset_b64 = b64encode(batch_bytes).decode("utf-8")
else:
rowset_b64 = ""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dev = [
# include compatible version of pandas, and secure-local-storage for token caching
"snowflake-connector-python[pandas, secure-local-storage]",
"pre-commit~=3.4",
"pyarrow-stubs",
"pyarrow-stubs==10.0.1.9",
"pytest~=8.0",
"pytest-asyncio",
"ruff~=0.5.1",
Expand Down
93 changes: 76 additions & 17 deletions tests/test_arrow.py

Large diffs are not rendered by default.

33 changes: 29 additions & 4 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
# ruff: noqa: E501

import datetime
import threading
from collections.abc import Iterator
from decimal import Decimal
from time import sleep
from typing import Callable

import pytest
import pytz
import snowflake.connector
import uvicorn
from snowflake.connector.cursor import ResultMetadata

import fakesnow.server
from tests.utils import indent


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -104,14 +107,36 @@ def test_server_types_no_result_set(sconn: snowflake.connector.SnowflakeConnecti


def test_server_types(scur: snowflake.connector.cursor.SnowflakeCursor) -> None:
scur.execute(
cur = scur
cur.execute(
# TODO: match columns names without using AS
"""
select true as TRUE, 1::int as "1::INT", 2.0::float as "2.0::FLOAT", to_decimal('12.3456', 10,2) as "TO_DECIMAL('12.3456', 10,2)",
'hello' as "'HELLO'"
select
true, 1::int, 2.0::float, to_decimal('12.3456', 10,2),
'hello', 'hello'::varchar(20),
to_date('2018-04-15'), to_time('04:15:29.123456'), to_timestamp_tz('2013-04-05 01:02:03.123456'), to_timestamp_ntz('2013-04-05 01:02:03.123456'),
/* X'41424320E29D84', ARRAY_CONSTRUCT('foo'), */ OBJECT_CONSTRUCT('k','v1'), 1.23::VARIANT
"""
)
assert scur.fetchall() == [(True, 1, 2.0, Decimal("12.35"), "hello")]
assert indent(cur.fetchall()) == [
(
True,
1,
2.0,
Decimal("12.35"),
"hello",
"hello",
datetime.date(2018, 4, 15),
datetime.time(4, 15, 29, 123456),
datetime.datetime(2013, 4, 5, 1, 2, 3, 123456, tzinfo=pytz.utc),
datetime.datetime(2013, 4, 5, 1, 2, 3, 123456),
# TODO
# bytearray(b"ABC \xe2\x9d\x84"),
# '[\n "foo"\n]',
'{\n "k": "v1"\n}',
"1.23",
)
]


def test_server_abort_request(server: dict) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def indent(rows: Sequence[tuple] | Sequence[dict]) -> list[tuple]:

def dindent(rows: Sequence[tuple] | Sequence[dict]) -> list[dict]:
# indent duckdb json strings dict values to match snowflake json strings
assert isinstance(rows[0], dict)
assert isinstance(rows[0], dict), f"{type(rows[0]).__name__} is not dict"
return [
{
k: json.dumps(json.loads(v), indent=2) if (isinstance(v, str) and v.startswith(("[", "{"))) else v
Expand Down

0 comments on commit c03953a

Please sign in to comment.