Skip to content

Commit

Permalink
add way to respect user defined schema before inferring schema
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam committed Feb 17, 2024
1 parent 638ce3d commit 2497049
Showing 1 changed file with 50 additions and 14 deletions.
64 changes: 50 additions & 14 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
from snowflake.connector import ProgrammingError, SnowflakeConnection
from snowflake.connector.options import installed_pandas, pandas
from snowflake.connector.pandas_tools import write_pandas
from snowflake.snowpark._internal.analyzer import analyzer_utils
from snowflake.snowpark._internal.analyzer.analyzer import Analyzer
from snowflake.snowpark._internal.analyzer.analyzer_utils import result_scan_statement
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
attribute_to_schema_string,
result_scan_statement,
)
from snowflake.snowpark._internal.analyzer.datatype_mapper import str_to_sql
from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark._internal.analyzer.select_statement import (
Expand Down Expand Up @@ -2156,6 +2158,31 @@ def write_pandas(
str(ci_output)
)

def _create_temp_table_for_given_schema(
    self, temp_table_name: str, schema: StructType
) -> bool:
    """Try to create a scoped temp table matching a user-provided schema.

    Args:
        temp_table_name: Name of the temp table to create.
        schema: User-provided ``StructType`` schema; its attributes are
            rendered into the column list of the ``CREATE`` statement.

    Returns:
        True if the table was created successfully, else False. A
        ``ProgrammingError`` raised while creating the table is logged at
        debug level and reported as False instead of being re-raised, so
        the caller can fall back to inferring the schema.
    """
    try:
        schema_string = attribute_to_schema_string(schema._to_attributes())
        self._run_query(
            f"CREATE SCOPED TEMP TABLE {temp_table_name} ({schema_string})"
        )
    except ProgrammingError as e:
        # Best-effort: failure here is not fatal, the caller decides how to
        # proceed based on the boolean result. The two literals below are
        # concatenated, so the trailing space after "using" is required.
        logging.debug(
            "Cannot create temp table for specified schema, fall back to using "
            "infer schema string from select query. Exception: %s",
            e,
        )
        return False
    return True

def create_dataframe(
self,
data: Union[List, Tuple, "pandas.DataFrame"],
Expand Down Expand Up @@ -2243,6 +2270,26 @@ def create_dataframe(
)
sf_schema = self._conn._get_current_parameter("schema", quoted=False)

if isinstance(
schema, StructType
) and self._create_temp_table_for_given_schema(temp_table_name, schema):
try:
t = self.write_pandas(
data,
temp_table_name,
database=sf_database,
schema=sf_schema,
quote_identifiers=True,
use_logical_type=self._use_logical_type_for_create_df,
)
set_api_call_source(t, "Session.create_dataframe[pandas]")
return t
except ProgrammingError as e:
logging.debug(
"Cannot create dataframe using specified schema for database."
f"Falling back to inferring schema from pandas dataframe. Exception: {e}"
)

t = self.write_pandas(
data,
temp_table_name,
Expand Down Expand Up @@ -2270,19 +2317,8 @@ def create_dataframe(
and all([field.datatype.is_primitive() for field in schema.fields])
):
temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
schema_string = analyzer_utils.attribute_to_schema_string(
schema._to_attributes()
)
try:
self._run_query(
f"CREATE SCOPED TEMP TABLE {temp_table_name} ({schema_string})"
)
if self._create_temp_table_for_given_schema(temp_table_name, schema):
schema_query = f"SELECT * FROM {self.get_current_database()}.{self.get_current_schema()}.{temp_table_name}"
except ProgrammingError as e:
logging.debug(
f"Cannot create temp table for specified non-nullable schema, fall back to using schema "
f"string from select query. Exception: {str(e)}"
)
else:
if not data:
raise ValueError("Cannot infer schema from empty data")
Expand Down

0 comments on commit 2497049

Please sign in to comment.