From 65db9321f533eaf133bd5a27ccafebee7d00df22 Mon Sep 17 00:00:00 2001
From: Arturo Herrera Aguilar <127373818+sfc-gh-aherreraaguilar@users.noreply.github.com>
Date: Thu, 21 Mar 2024 12:13:12 -0600
Subject: [PATCH] Add the function DataFrame.write.csv (#1298)

* Add the function DataFrame.write.csv to unload data from a DataFrame into one or more CSV files in a stage.

* Fix comment

* Fix comment and test

* Create a temp stage for test

* Add compression and single arguments.

* Use temp stage in all examples.

* Update example

* Move test to test_dataframe_writer_suite.py and fix comments

* Update tests/utils.py

Co-authored-by: Afroz Alam

* Remove empty line

* Add new line

---------

Co-authored-by: Afroz Alam
---
 CHANGELOG.md                               |  1 +
 docs/source/io.rst                         |  1 +
 src/snowflake/snowpark/dataframe_writer.py | 30 ++++++++
 .../scala/test_dataframe_writer_suite.py   | 70 ++++++++++++++++++-
 tests/utils.py                             |  8 +++
 5 files changed, 109 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6697d3fad2b..520e1933037 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@
   - file.put_stream
   - file.get
   - file.get_stream
+- Added the function `DataFrame.write.csv` to unload data from a ``DataFrame`` into one or more CSV files in a stage.
 
 ## 1.14.0 (2024-03-20)
 
diff --git a/docs/source/io.rst b/docs/source/io.rst
index f135ce335cb..e544a35632b 100644
--- a/docs/source/io.rst
+++ b/docs/source/io.rst
@@ -35,6 +35,7 @@ Input/Output
     DataFrameWriter.mode
     DataFrameWriter.saveAsTable
     DataFrameWriter.save_as_table
+    DataFrameWriter.csv
     FileOperation.get
     FileOperation.get_stream
     FileOperation.put
diff --git a/src/snowflake/snowpark/dataframe_writer.py b/src/snowflake/snowpark/dataframe_writer.py
index 3be54d11017..ddbd8cf4428 100644
--- a/src/snowflake/snowpark/dataframe_writer.py
+++ b/src/snowflake/snowpark/dataframe_writer.py
@@ -335,4 +335,34 @@ def copy_into_location(
             block=block,
         )
 
+    def csv(self, path: str, overwrite: bool = False, compression: str = None, single: bool = True,
+            partition_by: ColumnOrName = None) -> Union[List[Row], AsyncJob]:
+        """Executes internally a `COPY INTO `__ command to unload data from a ``DataFrame`` into one or more CSV files in a stage or external stage.
+
+        Args:
+            path: The destination stage location.
+            overwrite: Specifies whether to overwrite an existing file with the same name. The default value is ``False``.
+            compression: String (constant) that compresses the unloaded data files using the specified compression algorithm. Use the options documented in the `Format Type Options `__.
+            single: Boolean that specifies whether to generate a single file or multiple files. If ``False``, a filename prefix must be included in ``path``.
+            partition_by: Specifies an expression used to partition the unloaded table rows into separate files. It can be a :class:`Column`, a column name, or a SQL expression.
+
+        Returns:
+            A list of :class:`Row` objects containing unloading results.
+
+        Example::
+
+            >>> # save this dataframe to a CSV file on the session stage
+            >>> df = session.create_dataframe([["John", "Berry"], ["Rick", "Berry"], ["Anthony", "Davis"]], schema=["FIRST_NAME", "LAST_NAME"])
+            >>> remote_file_path = f"{session.get_session_stage()}/names.csv"
+            >>> copy_result = df.write.csv(remote_file_path, overwrite=True)
+            >>> copy_result[0].rows_unloaded
+            3
+        """
+        return self.copy_into_location(path,
+                                       file_format_type="CSV",
+                                       partition_by=partition_by,
+                                       overwrite=overwrite,
+                                       format_type_options=dict(compression=compression),
+                                       single=single)
+
     saveAsTable = save_as_table
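The ``csv`` method shown above is a thin wrapper over the existing ``DataFrameWriter.copy_into_location`` API. As a rough sketch (not part of this patch; the session setup and stage path are assumptions), a compressed single-file unload through ``csv`` is approximately equivalent to the direct call below::

    # Sketch only: illustrates how DataFrame.write.csv delegates to copy_into_location.
    # Assumes an already-configured Snowflake connection; the stage path is a placeholder.
    from snowflake.snowpark import Session

    session = Session.builder.create()
    df = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", "b"])
    path = f"{session.get_session_stage()}/equivalence_demo/data.csv.gz"

    # New convenience method introduced by this patch.
    result_csv = df.write.csv(path, overwrite=True, compression="gzip")

    # Roughly the call that csv() issues internally; overwrite and single are
    # forwarded to COPY INTO as copy options.
    result_copy = df.write.copy_into_location(
        path,
        file_format_type="CSV",
        partition_by=None,
        overwrite=True,
        format_type_options=dict(compression="gzip"),
        single=True,
    )

    assert result_csv[0].rows_unloaded == result_copy[0].rows_unloaded == 3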
f"{directory}/{os.path.basename(path6)}" + + try: + assert len(downloadedFile) == 1 + assert downloadedFile[0].status == "DOWNLOADED" + finally: + os.remove(downloadedFilePath) + finally: + Utils.drop_stage(session, temp_stage) diff --git a/tests/utils.py b/tests/utils.py index 75a966fde41..82c857be602 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -335,6 +335,14 @@ def assert_table_type(session: Session, table_name: str, table_type: str) -> Non expected_table_kind = table_type.upper() assert table_info[0]["kind"] == expected_table_kind + @staticmethod + def assert_rows_count(data: DataFrame, row_number: int): + row_counter = len(data.collect()) + + assert ( + row_counter == row_number + ), f"Expect {row_number} rows, Got {row_counter} instead" + class TestData: __test__ = (