Skip to content

Commit

Permalink
feat: enable list of paths for read_csv (#824)
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored Aug 22, 2024
1 parent b2982ec commit 805183b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
7 changes: 5 additions & 2 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ def read_json(

def read_csv(
self,
path: str | pathlib.Path,
path: str | pathlib.Path | list[str] | list[pathlib.Path],
schema: pyarrow.Schema | None = None,
has_header: bool = True,
delimiter: str = ",",
Expand Down Expand Up @@ -914,9 +914,12 @@ def read_csv(
"""
if table_partition_cols is None:
table_partition_cols = []

path = [str(p) for p in path] if isinstance(path, list) else str(path)

return DataFrame(
self.ctx.read_csv(
str(path),
path,
schema,
has_header,
delimiter,
Expand Down
16 changes: 16 additions & 0 deletions python/datafusion/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,22 @@ def test_read_csv(ctx):
csv_df.select(column("c1")).show()


def test_read_csv_list(ctx):
csv_df = ctx.read_csv(path=["testing/data/csv/aggregate_test_100.csv"])
expected = csv_df.count() * 2

double_csv_df = ctx.read_csv(
path=[
"testing/data/csv/aggregate_test_100.csv",
"testing/data/csv/aggregate_test_100.csv",
]
)
actual = double_csv_df.count()

double_csv_df.select(column("c1")).show()
assert actual == expected


def test_read_csv_compressed(ctx, tmp_path):
test_data_path = "testing/data/csv/aggregate_test_100.csv"

Expand Down
15 changes: 7 additions & 8 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ impl PySessionContext {
file_compression_type=None))]
pub fn read_csv(
&self,
path: PathBuf,
path: &Bound<'_, PyAny>,
schema: Option<PyArrowType<Schema>>,
has_header: bool,
delimiter: &str,
Expand All @@ -815,10 +815,6 @@ impl PySessionContext {
file_compression_type: Option<String>,
py: Python,
) -> PyResult<PyDataFrame> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

let delimiter = delimiter.as_bytes();
if delimiter.len() != 1 {
return Err(PyValueError::new_err(
Expand All @@ -833,13 +829,16 @@ impl PySessionContext {
.file_extension(file_extension)
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.file_compression_type(parse_file_compression_type(file_compression_type)?);
options.schema = schema.as_ref().map(|x| &x.0);

if let Some(py_schema) = schema {
options.schema = Some(&py_schema.0);
let result = self.ctx.read_csv(path, options);
if path.is_instance_of::<PyList>() {
let paths = path.extract::<Vec<String>>()?;
let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
let result = self.ctx.read_csv(paths, options);
let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
Ok(df)
} else {
let path = path.extract::<String>()?;
let result = self.ctx.read_csv(path, options);
let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
Ok(df)
Expand Down

0 comments on commit 805183b

Please sign in to comment.