fix and enable tests
sfc-gh-aling committed Apr 29, 2024
1 parent d86caad commit a2f5b38
Showing 6 changed files with 69 additions and 32 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,10 @@
- concat
- concat_ws

#### Bug Fixes

- Fixed a bug where DataFrameReader.csv was unable to handle quoted values containing the delimiter.


## 1.15.0 (2024-04-24)

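To illustrate the fix, here is a minimal sketch of the newly supported read path. The session setup, stage path, and file are illustrative assumptions, not part of this commit; the reader calls mirror the new test added below.

```python
from snowflake.snowpark import Session
from snowflake.snowpark.types import StringType, StructField, StructType

# Hypothetical local-testing session; the fix targets the local testing mock.
session = Session.builder.config("local_testing", True).create()

schema = StructType(
    [
        StructField("col1", StringType()),
        StructField("col2", StringType()),
        StructField("col3", StringType()),
    ]
)

# Before this fix, the mock CSV reader split a quoted value such as
# "value 5, but with a comma" at the embedded comma.
df = (
    session.read.schema(schema)
    .option("field_optionally_enclosed_by", '"')
    .csv("@my_stage/quoted.csv")  # hypothetical stage path
)
df.show()
```
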
26 changes: 0 additions & 26 deletions src/snowflake/snowpark/mock/_snowflake_to_pandas_converter.py
@@ -37,30 +37,11 @@
TIME_FORMAT = "%H:%M:%S"


def _process_field_optionally_enclosed_by(
value: str, field_optionally_enclosed_by: str = None
):
if field_optionally_enclosed_by and len(field_optionally_enclosed_by) > 1:
raise SnowparkSQLException(
f"invalid value ['{field_optionally_enclosed_by}'] for parameter 'FIELD_OPTIONALLY_ENCLOSED_BY'"
)

if (
field_optionally_enclosed_by
and len(value) >= 2
and value[0] == field_optionally_enclosed_by
and value[-1] == field_optionally_enclosed_by
):
return value[1:-1]
return value


def _integer_converter(
value: str, datatype: DataType, field_optionally_enclosed_by: str = None
) -> Optional[int]:
if value is None or value == "":
return None
value = _process_field_optionally_enclosed_by(value, field_optionally_enclosed_by)
try:
return int(value)
except ValueError:
@@ -74,7 +55,6 @@ def _fraction_converter(
) -> Optional[float]:
if value is None or value == "":
return None
value = _process_field_optionally_enclosed_by(value, field_optionally_enclosed_by)
try:
return float(value)
except ValueError:
@@ -88,7 +68,6 @@ def _decimal_converter(
) -> Optional[Union[int, Decimal]]:
if value is None or value == "":
return None
value = _process_field_optionally_enclosed_by(value, field_optionally_enclosed_by)
try:
precision = datatype.precision
scale = datatype.scale
@@ -116,7 +95,6 @@ def _bool_converter(
) -> Optional[bool]:
if value is None or value == "":
return None
value = _process_field_optionally_enclosed_by(value, field_optionally_enclosed_by)
if value.lower() == "true":
return True
if value.lower() == "false":
@@ -135,7 +113,6 @@ def _string_converter(
) -> Optional[str]:
if value is None or value == "":
return value
value = _process_field_optionally_enclosed_by(value, field_optionally_enclosed_by)
return value


@@ -144,7 +121,6 @@ def _date_converter(
) -> Optional[datetime.date]:
if value is None or value == "":
return None
value = _process_field_optionally_enclosed_by(value, field_optionally_enclosed_by)
try:
return datetime.datetime.strptime(value, DATE_FORMAT).date()
except Exception as e:
@@ -158,7 +134,6 @@ def _timestamp_converter(
) -> Optional[datetime.datetime]:
if value is None or value == "":
return None
value = _process_field_optionally_enclosed_by(value, field_optionally_enclosed_by)
try:
return datetime.datetime.strptime(value, TIMESTAMP_FORMAT)
except Exception as e:
@@ -172,7 +147,6 @@ def _time_converter(
) -> Optional[datetime.time]:
if value is None or value == "":
return None
value = _process_field_optionally_enclosed_by(value, field_optionally_enclosed_by)
try:
return datetime.datetime.strptime(value, TIME_FORMAT).time()
except Exception as e:
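
The removed helper stripped enclosing quotes field by field, after the row had already been split on the delimiter — too late for quoted values that contain the delimiter. A stdlib-only sketch of the failure mode (the sample line is made up):

```python
import csv
import io

line = '"value 4","value 5, but with a comma",value6'

# Post-hoc quote stripping cannot help here: a naive split has already
# broken the quoted field apart at the embedded comma.
print(line.split(","))
# ['"value 4"', '"value 5', ' but with a comma"', 'value6']

# Quote-aware parsing keeps the field intact, which is why quote handling
# moves into the CSV parser itself in this commit.
print(next(csv.reader(io.StringIO(line), quotechar='"')))
# ['value 4', 'value 5, but with a comma', 'value6']
```
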
14 changes: 13 additions & 1 deletion src/snowflake/snowpark/mock/_stage_registry.py
@@ -1,6 +1,7 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import csv
import glob
import json
import os
@@ -364,12 +365,17 @@ def read_file(
)

if file_format == "csv":
# check SNOW-1355487 for improvements
skip_header = options.get("SKIP_HEADER", 0)
skip_blank_lines = options.get("SKIP_BLANK_LINES", False)
field_delimiter = options.get("FIELD_DELIMITER", ",")
field_optionally_enclosed_by = options.get(
"FIELD_OPTIONALLY_ENCLOSED_BY", None
)
if field_optionally_enclosed_by and len(field_optionally_enclosed_by) >= 2:
raise SnowparkSQLException(
f"Invalid value ['{field_optionally_enclosed_by}'] for parameter 'FIELD_OPTIONALLY_ENCLOSED_BY'"
)
if (
field_delimiter[0]
and field_delimiter[-1] == "'"
@@ -445,7 +451,13 @@ def read_file(
delimiter=field_delimiter,
dtype=object,
converters=converters_dict,
quoting=3, # QUOTE_NONE
# see https://docs.python.org/3/library/csv.html#csv.QUOTE_MINIMAL
# csv.QUOTE_MINIMAL: the engine parses quoted values for us, using
# field_optionally_enclosed_by as the quote character
# csv.QUOTE_NONE: quotes are treated literally; Snowflake's
# FIELD_OPTIONALLY_ENCLOSED_BY defaults to None
quoting=csv.QUOTE_MINIMAL
if field_optionally_enclosed_by
else csv.QUOTE_NONE,
quotechar=field_optionally_enclosed_by,
)
# set df columns to be result_df columns such that it can be concatenated
df.columns = result_df.columns
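
The quoting switch above can be illustrated with plain pandas. This is a sketch under the assumption that the pandas engine follows the stdlib csv constants as documented; the sample data is made up:

```python
import csv
import io

import pandas as pd

data = 'a,"b, with comma",c\n'

# FIELD_OPTIONALLY_ENCLOSED_BY='"' -> csv.QUOTE_MINIMAL: quotes are stripped
# and the embedded delimiter stays inside the field.
quoted = pd.read_csv(
    io.StringIO(data), header=None, quoting=csv.QUOTE_MINIMAL, quotechar='"'
)
print(list(quoted.iloc[0]))  # ['a', 'b, with comma', 'c']

# FIELD_OPTIONALLY_ENCLOSED_BY unset -> csv.QUOTE_NONE: quotes are literal
# and the embedded comma splits the field.
plain = pd.read_csv(io.StringIO(data), header=None, quoting=csv.QUOTE_NONE)
print(list(plain.iloc[0]))  # ['a', '"b', ' with comma"', 'c']
```
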
52 changes: 47 additions & 5 deletions tests/integ/scala/test_dataframe_reader_suite.py
@@ -50,6 +50,7 @@
test_file_csv_colon = "testCSVcolon.csv"
test_file_csv_header = "testCSVheader.csv"
test_file_csv_quotes = "testCSVquotes.csv"
test_file_csv_quotes_special = "testCSVquotesSpecial.csv"
test_file_json = "testJson.json"
test_file_json_same_schema = "testJsonSameSchema.json"
test_file_json_new_schema = "testJsonNewSchema.json"
@@ -157,6 +158,12 @@ def setup(session, resources_path, local_testing_mode):
test_files.test_file_csv_quotes,
compress=False,
)
Utils.upload_to_stage(
session,
"@" + tmp_stage_name1,
test_files.test_file_csv_quotes_special,
compress=False,
)
Utils.upload_to_stage(
session,
"@" + tmp_stage_name1,
@@ -235,7 +242,7 @@ def setup(session, resources_path, local_testing_mode):
session.sql(f"DROP STAGE IF EXISTS {tmp_stage_only_json_file}").collect()


# @pytest.mark.localtest
@pytest.mark.localtest
@pytest.mark.parametrize("mode", ["select", "copy"])
def test_read_csv(session, mode):
reader = get_reader(session, mode)
@@ -363,6 +370,7 @@ def test_read_csv_with_infer_schema(session, mode, parse_header):
Utils.check_answer(df, [Row(1, "one", 1.2), Row(2, "two", 2.2)])


@pytest.mark.localtest
@pytest.mark.parametrize("mode", ["select", "copy"])
def test_read_csv_with_infer_schema_negative(session, mode, caplog):
reader = get_reader(session, mode)
@@ -382,6 +390,7 @@ def mock_run_query(*args, **kwargs):
assert "Could not infer csv schema due to exception:" in caplog.text


@pytest.mark.localtest
@pytest.mark.parametrize("mode", ["select", "copy"])
def test_read_csv_incorrect_schema(session, mode):
reader = get_reader(session, mode)
@@ -451,7 +460,7 @@ def test_save_as_table_do_not_change_col_name(session):
Utils.drop_table(session, table_name)


# @pytest.mark.localtest
@pytest.mark.localtest
def test_read_csv_with_more_operations(session):
test_file_on_stage = f"@{tmp_stage_name1}/{test_file_csv}"
df1 = session.read.schema(user_schema).csv(test_file_on_stage).filter(col("a") < 2)
@@ -499,7 +508,7 @@ def test_read_csv_with_more_operations(session):
]


# @pytest.mark.localtest
@pytest.mark.localtest
@pytest.mark.parametrize("mode", ["select", "copy"])
def test_read_csv_with_format_type_options(session, mode, local_testing_mode):
test_file_colon = f"@{tmp_stage_name1}/{test_file_csv_colon}"
@@ -562,7 +571,7 @@ def test_read_csv_with_format_type_options(session, mode, local_testing_mode):
]


# @pytest.mark.localtest
@pytest.mark.localtest
@pytest.mark.parametrize("mode", ["select", "copy"])
def test_to_read_files_from_stage(session, resources_path, mode, local_testing_mode):
data_files_stage = Utils.random_stage_name()
@@ -597,6 +606,7 @@ def test_to_read_files_from_stage(session, resources_path, mode, local_testing_mode):
session.sql(f"DROP STAGE IF EXISTS {data_files_stage}")


@pytest.mark.localtest
@pytest.mark.xfail(reason="SNOW-575700 flaky test", strict=False)
@pytest.mark.parametrize("mode", ["select", "copy"])
def test_for_all_csv_compression_keywords(session, temp_schema, mode):
@@ -633,7 +643,7 @@ def test_for_all_csv_compression_keywords(session, temp_schema, mode):
session.sql(f"drop file format {format_name}")


# @pytest.mark.localtest
@pytest.mark.localtest
@pytest.mark.parametrize("mode", ["select", "copy"])
def test_read_csv_with_special_chars_in_format_type_options(session, mode):
schema1 = StructType(
@@ -715,6 +725,38 @@ def test_read_csv_with_special_chars_in_format_type_options(session, mode):
assert res == [Row('"1.234"', '"09:10:11"'), Row('"2.5"', "12:34:56")]


@pytest.mark.localtest
@pytest.mark.parametrize("mode", ["select", "copy"])
def test_read_csv_with_quotes_containing_delimiter(session, mode):
schema1 = StructType(
[
StructField("col1", StringType()),
StructField("col2", StringType()),
StructField("col3", StringType()),
]
)
test_file = f"@{tmp_stage_name1}/{test_file_csv_quotes_special}"

reader = get_reader(session, mode)

df1 = (
reader.schema(schema1)
.option("field_optionally_enclosed_by", '"')
.option("skip_header", 1)
.csv(test_file)
)
res = df1.collect()
res.sort(key=lambda x: x[0])
assert res == [
Row(
"value 1",
"value 2 with no comma",
"value3",
),
Row("value 4", "value 5, but with a comma", " value6"),
]
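
The fixture testCSVquotesSpecial.csv itself is not shown in this diff. Given skip_header=1 and the asserted rows (note the preserved leading space in " value6"), its contents are presumably along these lines — a reconstruction, not the actual file:

```
col1,col2,col3
"value 1","value 2 with no comma",value3
"value 4","value 5, but with a comma", value6
```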


@pytest.mark.parametrize(
"file_format", ["csv", "json", "avro", "parquet", "xml", "orc"]
)
1 change: 1 addition & 0 deletions tests/unit/scala/test_utils_suite.py
@@ -289,6 +289,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files):
"resources/testCSVcolon.csv",
"resources/testCSVheader.csv",
"resources/testCSVquotes.csv",
"resources/testCSVquotesSpecial.csv",
"resources/testCSVspecialFormat.csv",
"resources/testJSONspecialFormat.json.gz",
"resources/testJson.json",
4 changes: 4 additions & 0 deletions tests/utils.py
@@ -1209,6 +1209,10 @@ def test_file_csv_header(self):
def test_file_csv_quotes(self):
return os.path.join(self.resources_path, "testCSVquotes.csv")

@property
def test_file_csv_quotes_special(self):
return os.path.join(self.resources_path, "testCSVquotesSpecial.csv")

@functools.cached_property
def test_file_csv_special_format(self):
return os.path.join(self.resources_path, "testCSVspecialFormat.csv")
