SNOW-1000284: Add schema support for structure types #1323

Merged 7 commits on Apr 30, 2024
Changes from 3 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@

### New Features
- Added support for registering stored procedures with packages given as Python modules.
- Added support for structured type schema parsing.

## 1.15.0 (2024-04-24)

52 changes: 49 additions & 3 deletions src/snowflake/snowpark/_internal/type_utils.py
@@ -125,6 +125,37 @@ def convert_metadata_to_sp_type(
raise ValueError(
f"Invalid result metadata for vector type: invalid element type: {element_type_name}"
)
elif column_type_name in {"ARRAY", "MAP", "OBJECT"} and getattr(
metadata, "fields", None
):
# If fields is not defined or empty then the legacy type can be returned instead
if column_type_name == "ARRAY":
assert (
len(metadata.fields) == 1
), "ArrayType columns should have one metadata field."
return ArrayType(
convert_metadata_to_sp_type(metadata.fields[0]), structured=True
)
elif column_type_name == "MAP":
assert (
len(metadata.fields) == 2
), "MapType columns should have two metadata fields."
return MapType(
convert_metadata_to_sp_type(metadata.fields[0]),
convert_metadata_to_sp_type(metadata.fields[1]),
structured=True,
)
else:
assert all(
getattr(field, "name", None) for field in metadata.fields
), "All fields of a StructType should be named."
return StructType(
[
StructField(field.name, convert_metadata_to_sp_type(field))
for field in metadata.fields
],
structured=True,
)
else:
return convert_sf_to_sp_type(
column_type_name,
@@ -142,7 +173,7 @@ def convert_sf_to_sp_type(
return ArrayType(StringType())
if column_type_name == "VARIANT":
return VariantType()
if column_type_name == "OBJECT":
if column_type_name in {"OBJECT", "MAP"}:
return MapType(StringType(), StringType())
if column_type_name == "GEOGRAPHY":
return GeographyType()
@@ -235,9 +266,24 @@ def convert_sp_to_sf_type(datatype: DataType) -> str:
if isinstance(datatype, BinaryType):
return "BINARY"
if isinstance(datatype, ArrayType):
return "ARRAY"
if datatype.structured:
return f"ARRAY({convert_sp_to_sf_type(datatype.element_type)})"
else:
return "ARRAY"
if isinstance(datatype, MapType):
return "OBJECT"
if datatype.structured:
return f"MAP({convert_sp_to_sf_type(datatype.key_type)}, {convert_sp_to_sf_type(datatype.value_type)})"
else:
return "OBJECT"
if isinstance(datatype, StructType):
if datatype.structured:
fields = ", ".join(
f"{field.name.upper()} {convert_sp_to_sf_type(field.datatype)}"
for field in datatype.fields
)
return f"OBJECT({fields})"
else:
return "OBJECT"
if isinstance(datatype, VariantType):
return "VARIANT"
if isinstance(datatype, GeographyType):
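The two halves of this file are symmetric: `convert_metadata_to_sp_type` parses structured result metadata into `structured=True` Snowpark types, while `convert_sp_to_sf_type` renders those types back into Snowflake SQL type strings. Below is a minimal sketch of the rendering direction, assuming a snowflake-snowpark-python build that includes this change; note that `type_utils` is an internal module, so the import path is not a stable API.

```python
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
from snowflake.snowpark.types import (
    ArrayType,
    DoubleType,
    MapType,
    StringType,
    StructField,
    StructType,
)

# Legacy (semi-structured) containers still render as the generic SQL names.
print(convert_sp_to_sf_type(ArrayType(DoubleType())))              # ARRAY
print(convert_sp_to_sf_type(MapType(StringType(), DoubleType())))  # OBJECT

# Structured containers carry their element, key/value, and field types into SQL.
print(convert_sp_to_sf_type(ArrayType(DoubleType(), structured=True)))
# ARRAY(DOUBLE)
print(convert_sp_to_sf_type(MapType(StringType(), DoubleType(), structured=True)))
# MAP(STRING, DOUBLE)
print(convert_sp_to_sf_type(StructType(
    [StructField("a", StringType()), StructField("b", DoubleType())],
    structured=True,
)))
# OBJECT(A STRING, B DOUBLE)  -- field names are upper-cased per the diff above
```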
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/_internal/udf_utils.py
@@ -1163,7 +1163,7 @@ def create_python_udf_or_sp(

if replace and if_not_exists:
raise ValueError("options replace and if_not_exists are incompatible")
if isinstance(return_type, StructType):
if isinstance(return_type, StructType) and not return_type.structured:
return_sql = f'RETURNS TABLE ({",".join(f"{field.name} {convert_sp_to_sf_type(field.datatype)}" for field in return_type.fields)})'
elif installed_pandas and isinstance(return_type, PandasDataFrameType):
return_sql = f'RETURNS TABLE ({",".join(f"{name} {convert_sp_to_sf_type(datatype)}" for name, datatype in zip(return_type.col_names, return_type.col_types))})'
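The tightened guard above matters because `StructType` now has two meanings: a plain `StructType` return still describes a table function (`RETURNS TABLE (...)`), while a `structured=True` `StructType` falls through to the generic path and becomes a scalar `OBJECT(...)` return. A sketch of that branch, using a hypothetical `return_clause` helper that mirrors the diff rather than the library's actual API:

```python
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
from snowflake.snowpark.types import (
    DataType,
    DoubleType,
    StringType,
    StructField,
    StructType,
)


def return_clause(return_type: DataType) -> str:
    # Hypothetical helper mirroring the branch in create_python_udf_or_sp.
    if isinstance(return_type, StructType) and not return_type.structured:
        # A plain StructType still means a table function schema.
        cols = ",".join(
            f"{field.name} {convert_sp_to_sf_type(field.datatype)}"
            for field in return_type.fields
        )
        return f"RETURNS TABLE ({cols})"
    # Structured StructTypes (and all other types) are scalar returns.
    return f"RETURNS {convert_sp_to_sf_type(return_type)}"


fields = [StructField("A", StringType()), StructField("B", DoubleType())]
print(return_clause(StructType(fields)))
# RETURNS TABLE (A STRING,B DOUBLE)
print(return_clause(StructType(fields, structured=True)))
# RETURNS OBJECT(A STRING, B DOUBLE)
```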
20 changes: 15 additions & 5 deletions src/snowflake/snowpark/types.py
@@ -217,7 +217,10 @@ def __repr__(self) -> str:
class ArrayType(DataType):
"""Array data type. This maps to the ARRAY data type in Snowflake."""

def __init__(self, element_type: Optional[DataType] = None) -> None:
def __init__(
self, element_type: Optional[DataType] = None, structured: bool = False
) -> None:
self.structured = structured
self.element_type = element_type if element_type else StringType()

def __repr__(self) -> str:
@@ -228,11 +231,15 @@ def is_primitive(self):


class MapType(DataType):
"""Map data type. This maps to the OBJECT data type in Snowflake."""
"""Map data type. This maps to the OBJECT data type in Snowflake if key and value types are not defined otherwise MAP."""

def __init__(
self, key_type: Optional[DataType] = None, value_type: Optional[DataType] = None
self,
key_type: Optional[DataType] = None,
value_type: Optional[DataType] = None,
structured: bool = False,
) -> None:
self.structured = structured
self.key_type = key_type if key_type else StringType()
self.value_type = value_type if value_type else StringType()

@@ -366,9 +373,12 @@ def __eq__(self, other):


class StructType(DataType):
"""Represents a table schema. Contains :class:`StructField` for each column."""
"""Represents a table schema or structured column. Contains :class:`StructField` for each field."""

def __init__(self, fields: Optional[List["StructField"]] = None) -> None:
def __init__(
self, fields: Optional[List["StructField"]] = None, structured=False
) -> None:
self.structured = structured
if fields is None:
fields = []
self.fields = fields
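From the public API side, the whole change is a single opt-in flag on the three container types, with defaults that preserve the legacy semi-structured behavior. A minimal sketch, using only the constructors shown in the diff above:

```python
from snowflake.snowpark.types import (
    ArrayType,
    DoubleType,
    LongType,
    MapType,
    StringType,
    StructField,
    StructType,
)

# Defaults are unchanged: element/key/value types fall back to StringType()
# and structured stays False, so existing code keeps ARRAY/OBJECT semantics.
legacy_arr = ArrayType()
legacy_map = MapType()
assert not legacy_arr.structured and not legacy_map.structured

# Opting in records that the container is a Snowflake structured type.
arr = ArrayType(DoubleType(), structured=True)
mp = MapType(StringType(), LongType(), structured=True)
obj = StructType(
    [StructField("A", StringType()), StructField("B", DoubleType())],
    structured=True,
)
assert arr.structured and mp.structured and obj.structured
```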
233 changes: 231 additions & 2 deletions tests/integ/scala/test_datatype_suite.py
@@ -2,13 +2,16 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

import uuid

# Many of the tests have been moved to unit/scala/test_datatype_suite.py
from decimal import Decimal

import pytest

from snowflake.snowpark import Row
from snowflake.snowpark.functions import lit
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.functions import col, lit, udf
from snowflake.snowpark.types import (
ArrayType,
BinaryType,
@@ -33,7 +36,69 @@
VariantType,
VectorType,
)
from tests.utils import Utils
from tests.utils import IS_ICEBERG_SUPPORTED, IS_STRUCTURED_TYPES_SUPPORTED, Utils

# Map of structured type enabled state to test params
STRUCTURED_TYPES_EXAMPLES = {
True: pytest.param(
"""
select
object_construct('k1', 1) :: map(varchar, int) as map,
object_construct('A', 'foo', 'B', 0.05) :: object(A varchar, B float) as obj,
[1.0, 3.1, 4.5] :: array(float) as arr
""",
[
("MAP", "map<string(16777216),bigint>"),
("OBJ", "struct<string(16777216),double>"),
("ARR", "array<double>"),
],
StructType(
[
StructField(
"MAP",
MapType(StringType(16777216), LongType(), structured=True),
[Review comment, Collaborator]: Do we have to explicitly use the length 16777216? This number may change soon.
[Reply, Contributor/Author]: Without a BCR we have to use the length for now. I expect that whenever I make the lob change I'll need to update this test and others that depend on this constant.
nullable=True,
),
StructField(
"OBJ",
StructType(
[
StructField("A", StringType(16777216), nullable=True),
StructField("B", DoubleType(), nullable=True),
],
structured=True,
),
nullable=True,
),
StructField(
"ARR", ArrayType(DoubleType(), structured=True), nullable=True
),
]
),
id="structured-types-enabled",
),
False: pytest.param(
[Review comment, Collaborator]: Will we eventually remove this example once all deployments support structured types?
[Reply, Contributor/Author]: Yes, at some point in the future we will make a BCR that turns structured types on by default. That will probably be the best time to remove this.
"""
select
object_construct('k1', 1) :: map(varchar, int) as map,
object_construct('a', 'foo', 'b', 0.05) :: object(a varchar, b float) as obj,
[1.0, 3.1, 4.5] :: array(float) as arr
""",
[
("MAP", "map<string,string>"),
("OBJ", "map<string,string>"),
("ARR", "array<string>"),
],
StructType(
[
StructField("MAP", MapType(StringType(), StringType()), nullable=True),
StructField("OBJ", MapType(StringType(), StringType()), nullable=True),
StructField("ARR", ArrayType(StringType()), nullable=True),
]
),
id="legacy",
),
}


def test_verify_datatypes_reference(session):
@@ -229,6 +294,170 @@ def test_dtypes(session):
]


@pytest.mark.parametrize(
[Review comment, Collaborator]: Should we also skip these tests in stored procedures? Does the stored proc connector have the corresponding change for structured types?
[Reply, Contributor/Author]: By the time I merge this change it should have the corresponding change.
"query,expected_dtypes,expected_schema",
[STRUCTURED_TYPES_EXAMPLES[IS_STRUCTURED_TYPES_SUPPORTED]],
)
def test_structured_dtypes(session, query, expected_dtypes, expected_schema):
df = session.sql(query)
assert df.schema == expected_schema
assert df.dtypes == expected_dtypes


@pytest.mark.parametrize(
"query,expected_dtypes,expected_schema",
[STRUCTURED_TYPES_EXAMPLES[IS_STRUCTURED_TYPES_SUPPORTED]],
)
def test_structured_dtypes_select(session, query, expected_dtypes, expected_schema):
df = session.sql(query)
flattened_df = df.select(
df.map["k1"].alias("value1"),
df.obj["a"].alias("a"),
col("obj")["b"].alias("b"),
df.arr[0].alias("value2"),
df.arr[1].alias("value3"),
col("arr")[2].alias("value4"),
)
assert flattened_df.schema == StructType(
[
StructField("VALUE1", LongType(), nullable=True),
StructField("A", StringType(16777216), nullable=True),
StructField("B", DoubleType(), nullable=True),
StructField("VALUE2", DoubleType(), nullable=True),
StructField("VALUE3", DoubleType(), nullable=True),
StructField("VALUE4", DoubleType(), nullable=True),
]
)
assert flattened_df.dtypes == [
("VALUE1", "bigint"),
("A", "string(16777216)"),
("B", "double"),
("VALUE2", "double"),
("VALUE3", "double"),
("VALUE4", "double"),
]
assert flattened_df.collect() == [
Row(VALUE1=1, A="foo", B=0.05, VALUE2=1.0, VALUE3=3.1, VALUE4=4.5)
]


@pytest.mark.parametrize(
"query,expected_dtypes,expected_schema",
[STRUCTURED_TYPES_EXAMPLES[IS_STRUCTURED_TYPES_SUPPORTED]],
)
def test_structured_dtypes_pandas(session, query, expected_dtypes, expected_schema):
pdf = session.sql(query).to_pandas()
if IS_STRUCTURED_TYPES_SUPPORTED:
assert (
pdf.to_json()
== '{"MAP":{"0":{"k1":1.0}},"OBJ":{"0":{"a":"foo","b":0.05}},"ARR":{"0":[1.0,3.1,4.5]}}'
)
else:
assert (
pdf.to_json()
== '{"MAP":{"0":"{\\n \\"k1\\": 1\\n}"},"OBJ":{"0":"{\\n \\"a\\": \\"foo\\",\\n \\"b\\": 5.000000000000000e-02\\n}"},"ARR":{"0":"[\\n 1.000000000000000e+00,\\n 3.100000000000000e+00,\\n 4.500000000000000e+00\\n]"}}'
)


@pytest.mark.skipif(
not (IS_STRUCTURED_TYPES_SUPPORTED and IS_ICEBERG_SUPPORTED),
reason="Test requires iceberg support and structured type support.",
)
@pytest.mark.parametrize(
"query,expected_dtypes,expected_schema",
[STRUCTURED_TYPES_EXAMPLES[IS_STRUCTURED_TYPES_SUPPORTED]],
)
def test_structured_dtypes_iceberg(session, query, expected_dtypes, expected_schema):
table_name = f"snowpark_structured_dtypes_{uuid.uuid4().hex[:5]}"
try:
session.sql(
f"""
create iceberg table if not exists {table_name} (
map map(varchar, int),
obj object(a varchar, b float),
arr array(float)
)
CATALOG = 'SNOWFLAKE'
EXTERNAL_VOLUME = 'python_connector_iceberg_exvol'
BASE_LOCATION = 'python_connector_merge_gate';
"""
).collect()
session.sql(
f"""
insert into {table_name}
{query}
"""
).collect()
df = session.table(table_name)
assert df.schema == expected_schema
assert df.dtypes == expected_dtypes
finally:
session.sql(f"drop table if exists {table_name}")


@pytest.mark.skipif(
not (IS_STRUCTURED_TYPES_SUPPORTED and IS_ICEBERG_SUPPORTED),
reason="Test requires iceberg support and structured type support.",
)
@pytest.mark.parametrize(
"query,expected_dtypes,expected_schema",
[STRUCTURED_TYPES_EXAMPLES[IS_STRUCTURED_TYPES_SUPPORTED]],
)
def test_structured_dtypes_iceberg_udf(
session, query, expected_dtypes, expected_schema
):
table_name = f"snowpark_structured_dtypes_udf_test{uuid.uuid4().hex[:5]}"

def nop(x):
return x

(map_type, object_type, array_type) = expected_schema
nop_map_udf = udf(
nop, return_type=map_type.datatype, input_types=[map_type.datatype]
)
nop_object_udf = udf(
nop, return_type=object_type.datatype, input_types=[object_type.datatype]
)
nop_array_udf = udf(
nop, return_type=array_type.datatype, input_types=[array_type.datatype]
)

try:
session.sql(
f"""
create iceberg table if not exists {table_name} (
map map(varchar, int),
obj object(A varchar, B float),
arr array(float)
)
CATALOG = 'SNOWFLAKE'
EXTERNAL_VOLUME = 'python_connector_iceberg_exvol'
BASE_LOCATION = 'python_connector_merge_gate';
"""
).collect()
session.sql(
f"""
insert into {table_name}
{query}
"""
).collect()

df = session.table(table_name)
working = df.select(
nop_object_udf(col("obj")).alias("obj"),
nop_array_udf(col("arr")).alias("arr"),
)
assert working.schema == StructType([object_type, array_type])

with pytest.raises(SnowparkSQLException):
# SNOW-XXXXXXX: Map not supported as a udf return type.
df.select(
nop_map_udf(col("map")).alias("map"),
).collect()
finally:
session.sql(f"drop table if exists {table_name}")


@pytest.mark.xfail(reason="SNOW-974852 vectors are not yet rolled out", strict=False)
def test_dtypes_vector(session):
schema = StructType(
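For reference, a hedged end-to-end sketch of the behavior these tests exercise: cast to structured types in SQL, then read the parsed schema back through `DataFrame.schema`. The connection parameters are placeholders, and the structured results require an account with structured types enabled:

```python
from snowflake.snowpark import Session

# Placeholder credentials; substitute real connection parameters.
connection_parameters = {
    "account": "<account>",
    "user": "<user>",
    "password": "<password>",
}
session = Session.builder.configs(connection_parameters).create()

df = session.sql(
    """
    select
        object_construct('k1', 1) :: map(varchar, int) as map,
        [1.0, 3.1, 4.5] :: array(float) as arr
    """
)

# With structured types enabled, the schema round-trips the element types:
#   MAP -> MapType(StringType(16777216), LongType(), structured=True)
#   ARR -> ArrayType(DoubleType(), structured=True)
# Without the feature, both fall back to the legacy semi-structured types.
print(df.schema)
print(df.dtypes)
```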