Add the function DataFrame.write.csv (#1298)
* Add the function DataFrame.write.csv to unload data from a DataFrame into one or more CSV files in a stage.

* Fix comment

* Fix comment and test

* Create a temp stage for test

* Add compression and single arguments.

* Use temp stage in all examples.

* Update example

* Move test to test_dataframe_writer_suite.py and fix comments

* Update tests/utils.py

Co-authored-by: Afroz Alam <[email protected]>

* Remove empty line

* Add new line

---------

Co-authored-by: Afroz Alam <[email protected]>
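
A minimal usage sketch of the API added by this commit (the stage path below is hypothetical; note that a filename prefix must be included in the path when `single=False`):

    df = session.create_dataframe([[1, "a"], [2, "b"]], schema=["id", "val"])
    results = df.write.csv(
        "@my_stage/unloaded/data_",  # hypothetical stage path
        overwrite=True,
        compression="gzip",
        single=False,
        partition_by="id",
    )
    print(results[0].rows_unloaded)  # 2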
sfc-gh-aherreraaguilar and sfc-gh-aalam authored Mar 21, 2024
1 parent 11eb2af commit 65db932
Showing 5 changed files with 109 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@
- file.put_stream
- file.get
- file.get_stream
- Added the function `DataFrame.write.csv` to unload data from a `DataFrame` into one or more CSV files in a stage.

## 1.14.0 (2024-03-20)

1 change: 1 addition & 0 deletions docs/source/io.rst
@@ -35,6 +35,7 @@ Input/Output
DataFrameWriter.mode
DataFrameWriter.saveAsTable
DataFrameWriter.save_as_table
DataFrameWriter.csv
FileOperation.get
FileOperation.get_stream
FileOperation.put
30 changes: 30 additions & 0 deletions src/snowflake/snowpark/dataframe_writer.py
@@ -335,4 +335,34 @@ def copy_into_location(
            block=block,
        )

    def csv(
        self,
        path: str,
        overwrite: bool = False,
        compression: Optional[str] = None,
        single: bool = True,
        partition_by: Optional[ColumnOrName] = None,
    ) -> Union[List[Row], AsyncJob]:
        """Executes a `COPY INTO <location> <https://docs.snowflake.com/en/sql-reference/sql/copy-into-location.html>`__ command internally to unload data from a ``DataFrame`` into one or more CSV files in a stage or external stage.

        Args:
            path: The destination stage location.
            overwrite: Specifies whether to overwrite the target file if it already exists. Defaults to ``False``.
            compression: String (constant) that specifies the compression algorithm to apply to the unloaded data files. Use the options documented in the `Format Type Options <https://docs.snowflake.com/en/sql-reference/sql/copy-into-location.html#format-type-options-formattypeoptions>`__.
            single: Boolean that specifies whether to generate a single file or multiple files. If ``False``, a filename prefix must be included in ``path``.
            partition_by: Specifies an expression used to partition the unloaded table rows into separate files. It can be a :class:`Column`, a column name, or a SQL expression.

        Returns:
            A list of :class:`Row` objects containing the unloading results.

        Example::

            >>> # save this dataframe to a CSV file on the session stage
            >>> df = session.create_dataframe([["John", "Berry"], ["Rick", "Berry"], ["Anthony", "Davis"]], schema=["FIRST_NAME", "LAST_NAME"])
            >>> remote_file_path = f"{session.get_session_stage()}/names.csv"
            >>> copy_result = df.write.csv(remote_file_path, overwrite=True)
            >>> copy_result[0].rows_unloaded
            3
        """
        return self.copy_into_location(
            path,
            file_format_type="CSV",
            partition_by=partition_by,
            overwrite=overwrite,
            format_type_options=dict(compression=compression),
            single=single,
        )

    saveAsTable = save_as_table
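
As the implementation shows, `csv()` is a thin wrapper over the existing `copy_into_location()`; under that reading, the two calls below should be equivalent (the session-stage path mirrors the docstring example):

    path = f"{session.get_session_stage()}/names.csv"
    df.write.csv(path, overwrite=True)
    # ...is the same as:
    df.write.copy_into_location(
        path,
        file_format_type="CSV",
        partition_by=None,
        overwrite=True,
        format_type_options=dict(compression=None),
        single=True,
    )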
70 changes: 69 additions & 1 deletion tests/integ/scala/test_dataframe_writer_suite.py
@@ -3,11 +3,13 @@
#

import copy
import os

import pytest

from snowflake.snowpark import Row
from snowflake.snowpark._internal.utils import parse_table_name
from snowflake.snowpark._internal.utils import parse_table_name, TempObjectType
from snowflake.snowpark.functions import col
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.types import (
DoubleType,
@@ -360,3 +362,69 @@ def create_and_append_check_answer(table_name_input):
    # drop schema
    Utils.drop_schema(session, schema)
    Utils.drop_schema(session, double_quoted_schema)


def test_writer_csv(session, tmpdir_factory):
    """Tests for df.write.csv()."""
    df = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", "b"])
    ROW_NUMBER = 3
    schema = StructType(
        [StructField("a", IntegerType()), StructField("b", IntegerType())]
    )

    temp_stage = Utils.random_name_for_temp_object(TempObjectType.STAGE)
    Utils.create_stage(session, temp_stage, is_temporary=True)

    try:
        # test default case
        path1 = f"{temp_stage}/test_csv_example1/my_file.csv"
        result1 = df.write.csv(path1)
        assert result1[0].rows_unloaded == ROW_NUMBER
        data1 = session.read.schema(schema).csv(f"@{path1}")
        Utils.assert_rows_count(data1, ROW_NUMBER)

        # test overwrite case
        result2 = df.write.csv(path1, overwrite=True)
        assert result2[0].rows_unloaded == ROW_NUMBER
        data2 = session.read.schema(schema).csv(f"@{path1}")
        Utils.assert_rows_count(data2, ROW_NUMBER)

        # test partition_by with a Column and with a column name
        path3 = f"{temp_stage}/test_csv_example3/my_file.csv"
        result3 = df.write.csv(path3, single=False, partition_by=col("a"))
        assert result3[0].rows_unloaded == ROW_NUMBER
        data3 = session.read.schema(schema).csv(f"@{path3}")
        Utils.assert_rows_count(data3, ROW_NUMBER)

        path4 = f"{temp_stage}/test_csv_example4/my_file.csv"
        result4 = df.write.csv(path4, single=False, partition_by="a")
        assert result4[0].rows_unloaded == ROW_NUMBER
        data4 = session.read.schema(schema).csv(f"@{path4}")
        Utils.assert_rows_count(data4, ROW_NUMBER)

        # test multiple-file case (single=False appends a suffix per file)
        path5 = f"{temp_stage}/test_csv_example5/my_file.csv"
        result5 = df.write.csv(path5, single=False)
        assert result5[0].rows_unloaded == ROW_NUMBER
        data5 = session.read.schema(schema).csv(f"@{path5}_0_0_0.csv")
        Utils.assert_rows_count(data5, ROW_NUMBER)

        # test compression case
        path6 = f"{temp_stage}/test_csv_example6/my_file.csv.gz"
        result6 = df.write.csv(path6, compression="gzip")
        assert result6[0].rows_unloaded == ROW_NUMBER

        directory = tmpdir_factory.mktemp("snowpark_test_target")
        downloaded_file = session.file.get(f"@{path6}", str(directory))
        downloaded_file_path = f"{directory}/{os.path.basename(path6)}"

        try:
            assert len(downloaded_file) == 1
            assert downloaded_file[0].status == "DOWNLOADED"
        finally:
            os.remove(downloaded_file_path)
    finally:
        Utils.drop_stage(session, temp_stage)
8 changes: 8 additions & 0 deletions tests/utils.py
@@ -335,6 +335,14 @@ def assert_table_type(session: Session, table_name: str, table_type: str) -> None:
        expected_table_kind = table_type.upper()
        assert table_info[0]["kind"] == expected_table_kind

    @staticmethod
    def assert_rows_count(data: DataFrame, row_number: int) -> None:
        row_counter = len(data.collect())
        assert (
            row_counter == row_number
        ), f"Expected {row_number} rows, got {row_counter} instead"


class TestData:
    __test__ = (
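
For illustration, a hypothetical standalone use of the new `Utils.assert_rows_count` helper:

    df = session.create_dataframe([[1], [2], [3]], schema=["a"])
    Utils.assert_rows_count(df, 3)  # passes
    Utils.assert_rows_count(df, 4)  # AssertionError: Expected 4 rows, got 3 instead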
