Skip to content

Commit

Permalink
fix: fix ray lance sink error
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay-ju committed Dec 30, 2024
1 parent 24692b6 commit cac21f8
Showing 1 changed file with 28 additions and 34 deletions.
62 changes: 28 additions & 34 deletions python/python/lance/ray/sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pickle
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Expand All @@ -16,14 +17,16 @@
Union,
)

import pandas as pd
import pyarrow as pa

import lance
from lance.fragment import DEFAULT_MAX_BYTES_PER_FILE, FragmentMetadata, write_fragments

from ..dependencies import ray

if TYPE_CHECKING:
import pandas as pd

__all__ = ["LanceDatasink", "LanceFragmentWriter", "LanceCommitter", "write_lance"]


Expand All @@ -38,23 +41,22 @@ def _pd_to_arrow(
return pa.Table.from_pydict(df, schema=schema)
if _PANDAS_AVAILABLE and isinstance(df, pd.DataFrame):
tbl = pa.Table.from_pandas(df, schema=schema)
new_schema = tbl.schema.remove_metadata()
new_table = tbl.replace_schema_metadata(new_schema.metadata)
return new_table
tbl = tbl.replace_schema_metadata()
return tbl
return df


def _write_fragment(
stream: Iterable[Union[pa.Table, "pd.DataFrame"]],
stream: Iterable[Union[pa.Table, "pd.DataFrame"]],
uri: str,
*,
schema: Optional[pa.Schema] = None,
max_rows_per_file: int = 1024 * 1024,
max_bytes_per_file: Optional[int] = None,
max_rows_per_group: int = 1024, # Only useful for v1 writer.
data_storage_version: Optional[str] = None,
data_storage_version: str = "stable",
storage_options: Optional[Dict[str, Any]] = None,
) -> List[Tuple[FragmentMetadata, pa.Schema]]:
) -> Tuple[FragmentMetadata, pa.Schema]:
from ..dependencies import _PANDAS_AVAILABLE
from ..dependencies import pandas as pd

Expand Down Expand Up @@ -126,7 +128,7 @@ def on_write_start(self):

def on_write_complete(
self,
write_results,
write_results: List[List[Tuple[str, str]]],
):
if not write_results:
import warnings
Expand All @@ -137,22 +139,15 @@ def on_write_complete(
)
return "Empty list"

first_element = write_results[0]
if isinstance(first_element, (pa.Table, pd.DataFrame)):
write_results = [
result["write_result"].iloc[0]["result"] for result in write_results
]
if hasattr(write_results, "write_returns"):
write_results = write_results.write_returns
fragments = []
schema = None
for batch in write_results:
for fragment_str, schema_str in batch:
fragment = pickle.loads(fragment_str)
fragments.append(fragment)
schema = pickle.loads(schema_str)
# Check whether writer has fragments or not.
# Skip commit when there are no fragments.
if not schema:
return
if self.mode in set(["create", "overwrite"]):
op = lance.LanceOperation.Overwrite(schema, fragments)
elif self.mode == "append":
Expand Down Expand Up @@ -184,7 +179,7 @@ class LanceDatasink(_BaseLanceDatasink):
Choices are 'append', 'create', 'overwrite'.
max_rows_per_file : int, optional
The maximum number of rows per file. Default is 1024 * 1024.
data_storage_version: optional, str, default None
data_storage_version: optional, str, default "legacy"
The version of the data storage format to use. Newer versions are more
efficient but require newer versions of lance to read. The default is
"legacy" which will use the legacy v1 version. See the user guide
Expand All @@ -204,7 +199,7 @@ def __init__(
schema: Optional[pa.Schema] = None,
mode: Literal["create", "append", "overwrite"] = "create",
max_rows_per_file: int = 1024 * 1024,
data_storage_version: Optional[str] = None,
data_storage_version: str = "stable",
use_legacy_format: Optional[bool] = None,
storage_options: Optional[Dict[str, Any]] = None,
*args,
Expand Down Expand Up @@ -289,10 +284,11 @@ class LanceFragmentWriter:
max_rows_per_group : int, optional
The maximum number of rows per group. Default is 1024.
Only useful for v1 writer.
data_storage_version: optional, str, default None
data_storage_version: optional, str, default "legacy"
The version of the data storage format to use. Newer versions are more
efficient but require newer versions of lance to read. The default
(None) will use the 2.0 version. See the user guide for more details.
efficient but require newer versions of lance to read. The default is
"legacy" which will use the legacy v1 version. See the user guide
for more details.
use_legacy_format : optional, bool, default None
Deprecated method for setting the data storage version. Use the
`data_storage_version` parameter instead.
Expand All @@ -305,12 +301,13 @@ def __init__(
self,
uri: str,
*,
transform: Optional[Callable[[pa.Table], Union[pa.Table, Generator]]] = None,
transform: Optional[Callable[[pa.Table],
Union[pa.Table, Generator]]] = None,
schema: Optional[pa.Schema] = None,
max_rows_per_file: int = 1024 * 1024,
max_bytes_per_file: Optional[int] = None,
max_rows_per_group: Optional[int] = None, # Only useful for v1 writer.
data_storage_version: Optional[str] = None,
data_storage_version: str = "stable",
use_legacy_format: Optional[bool] = False,
storage_options: Optional[Dict[str, Any]] = None,
):
Expand Down Expand Up @@ -363,7 +360,7 @@ def __call__(self, batch: Union[pa.Table, "pd.DataFrame"]) -> Dict[str, Any]:


class LanceCommitter(_BaseLanceDatasink):
"""Lance Committer as Ray Datasink.
"""Lance Commiter as Ray Datasink.

Check warning on line 363 in python/python/lance/ray/sink.py

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"Commiter" should be "Committer".
This is used with `LanceFragmentWriter` to write large-than-memory data to
lance file.
Expand All @@ -374,7 +371,7 @@ def num_rows_per_write(self) -> int:
return 1

def get_name(self) -> str:
return f"LanceCommitter({self.mode})"
return f"LanceCommiter({self.mode})"

Check warning on line 374 in python/python/lance/ray/sink.py

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"Commiter" should be "Committer".

def write(
self,
Expand All @@ -384,10 +381,6 @@ def write(
"""Passthrough the fragments to commit phase"""
v = []
for block in blocks:
# If block is empty, skip to get "fragment" and "schema" filed
if len(block) == 0:
continue

for fragment, schema in zip(
block["fragment"].to_pylist(), block["schema"].to_pylist()
):
Expand All @@ -406,7 +399,7 @@ def write_lance(
max_rows_per_file: int = 1024 * 1024,
max_bytes_per_file: Optional[int] = None,
storage_options: Optional[Dict[str, Any]] = None,
data_storage_version: Optional[str] = None,
data_storage_version: str = "stable",
) -> None:
"""Write Ray dataset at scale.
Expand All @@ -429,10 +422,11 @@ def write_lance(
The maximum number of bytes per file. Default is 90GB.
storage_options : Dict[str, Any], optional
The storage options for the writer. Default is None.
data_storage_version: optional, str, default None
data_storage_version: optional, str, default "legacy"
The version of the data storage format to use. Newer versions are more
efficient but require newer versions of lance to read. The default
(None) will use the 2.0 version. See the user guide for more details.
efficient but require newer versions of lance to read. The default is
"legacy" which will use the legacy v1 version. See the user guide
for more details.
"""
data.map_batches(
LanceFragmentWriter(
Expand Down

0 comments on commit cac21f8

Please sign in to comment.