Skip to content

Commit

Permalink
fix: Also split on forward slashes during hive path inference on Windows (#19282)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Oct 17, 2024
1 parent 577cd62 commit 01a4e06
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 11 deletions.
19 changes: 8 additions & 11 deletions crates/polars-plan/src/plans/hive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,21 +231,18 @@ pub fn hive_partitions_from_paths(
}

/// Determine the path separator for identifying Hive partitions.
#[cfg(target_os = "windows")]
fn separator(url: &Path) -> char {
if polars_io::path_utils::is_cloud_url(url) {
'/'
fn separator(url: &Path) -> &[char] {
if cfg!(target_family = "windows") {
if polars_io::path_utils::is_cloud_url(url) {
&['/']
} else {
&['/', '\\']
}
} else {
'\\'
&['/']
}
}

/// Determine the path separator for identifying Hive partitions.
#[cfg(not(target_os = "windows"))]
fn separator(_url: &Path) -> char {
'/'
}

/// Parse a Hive partition string (e.g. "column=1.5") into a name and value part.
///
/// Returns `None` if the string is not a Hive partition string.
Expand Down
42 changes: 42 additions & 0 deletions py-polars/tests/unit/io/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,3 +779,45 @@ def test_hive_predicate_dates_14712(
)
pl.scan_parquet(tmp_path).filter(pl.col("a") != datetime(2024, 1, 1)).collect()
assert "hive partitioning: skipped 1 files" in capfd.readouterr().err


@pytest.mark.skipif(sys.platform != "win32", reason="Test is only for Windows paths")
@pytest.mark.write_disk
def test_hive_windows_splits_on_forward_slashes(tmp_path: Path) -> None:
    # Writes one parquet file under a five-level hive directory tree, then
    # scans it through five differently spelled paths (forward slashes,
    # backslashes, and mixtures, with and without globs). Each source must
    # yield the same row with all five hive columns inferred.

    # Note: This needs to be an absolute path.
    tmp_path = tmp_path.resolve()
    path = f"{tmp_path}/a=1/b=1/c=1/d=1/e=1"
    Path(path).mkdir(exist_ok=True, parents=True)

    pl.DataFrame({"x": "x"}).write_parquet(f"{path}/data.parquet")

    # One row as it should appear after hive inference, repeated five times
    # because the same file is scanned once per source path below.
    single_row = pl.DataFrame(
        {
            "x": "x",
            "a": 1,
            "b": 1,
            "c": 1,
            "d": 1,
            "e": 1,
        }
    )
    expect = pl.DataFrame([s.new_from_index(0, 5) for s in single_row])

    sources = [
        f"{tmp_path}/a=1/b=1/c=1/d=1/e=1/data.parquet",
        f"{tmp_path}\\a=1\\b=1\\c=1\\d=1\\e=1\\data.parquet",
        f"{tmp_path}\\a=1/b=1/c=1/d=1/**/*",
        f"{tmp_path}/a=1/b=1\\c=1/d=1/**/*",
        f"{tmp_path}/a=1/b=1/c=1/d=1\\e=1/*",
    ]
    result = pl.scan_parquet(sources, hive_partitioning=True).collect()

    assert_frame_equal(result, expect)

0 comments on commit 01a4e06

Please sign in to comment.