Skip to content

Commit

Permalink
fix: lance ray sink crash when fields contain none (#3322)
Browse files Browse the repository at this point in the history
fix #3308
  • Loading branch information
Jay-ju authored Jan 1, 2025
1 parent 8767c10 commit 783bc12
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
25 changes: 22 additions & 3 deletions python/python/lance/ray/sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

__all__ = ["LanceDatasink", "LanceFragmentWriter", "LanceCommitter", "write_lance"]

NONE_ARROW_STR = "None"


def _pd_to_arrow(
df: Union[pa.Table, "pd.DataFrame", Dict], schema: Optional[pa.Schema]
Expand All @@ -39,10 +41,27 @@ def _pd_to_arrow(

if isinstance(df, dict):
return pa.Table.from_pydict(df, schema=schema)
if _PANDAS_AVAILABLE and isinstance(df, pd.DataFrame):
elif _PANDAS_AVAILABLE and isinstance(df, pd.DataFrame):
tbl = pa.Table.from_pandas(df, schema=schema)
new_schema = tbl.schema.remove_metadata()
new_table = tbl.replace_schema_metadata(new_schema.metadata)
tbl.schema = tbl.schema.remove_metadata()
return tbl
elif isinstance(df, pa.Table):
fields = df.schema.names
new_columns = []
new_fields = []
for field in fields:
col = df[field]
new_field = pa.field(field, col.type)
if (
pa.types.is_null(col.type)
and schema.field_by_name(field).type == pa.string()
):
new_field = pa.field(field, pa.string())
col = pa.compute.if_else(pa.compute.is_null(col), NONE_ARROW_STR, col)
new_columns.append(col)
new_fields.append(new_field)
new_schema = pa.schema(fields=new_fields)
new_table = pa.Table.from_arrays(new_columns, schema=new_schema)
return new_table
return df

Expand Down
22 changes: 22 additions & 0 deletions python/python/tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,25 @@ def test_ray_empty_write_lance(tmp_path: Path):
# empty write would not generate dataset.
with pytest.raises(ValueError):
lance.dataset(tmp_path)


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_ray_write_lance_none_str(tmp_path: Path):
def f(row):
return {
"id": row["id"],
"str": None,
}

schema = pa.schema([pa.field("id", pa.int64()), pa.field("str", pa.string())])
(ray.data.range(10).map(f).write_lance(tmp_path, schema=schema))

ds = lance.dataset(tmp_path)
ds.count_rows() == 10
assert ds.schema == schema

tbl = ds.to_table()
pylist = tbl["str"].to_pylist()
assert len(pylist) == 10
for item in pylist:
assert item is None

0 comments on commit 783bc12

Please sign in to comment.