Skip to content

Commit

Permalink
fix: allow image extension types to be pickled in PyArrow 12 (#1955)
Browse files Browse the repository at this point in the history
Fixes #1943.
  • Loading branch information
wjones127 authored Feb 14, 2024
1 parent 8b07f90 commit 0d2083e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
24 changes: 24 additions & 0 deletions python/python/lance/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ def __arrow_ext_class__(self):
def __arrow_ext_scalar_class__(self):
return ImageURIScalar

def __reduce__(self):
# Workaround to ensure pickle works in earlier versions of PyArrow
# https://github.com/apache/arrow/issues/35599
return type(self).__arrow_ext_deserialize__, (
self.storage_type,
self.__arrow_ext_serialize__(),
)


class EncodedImageType(pa.ExtensionType):
def __init__(self, storage_type: pa.DataType = pa.binary()):
Expand All @@ -94,6 +102,14 @@ def __arrow_ext_class__(self):
def __arrow_ext_scalar_class__(self):
return EncodedImageScalar

def __reduce__(self):
# Workaround to ensure pickle works in earlier versions of PyArrow
# https://github.com/apache/arrow/issues/35599
return type(self).__arrow_ext_deserialize__, (
self.storage_type,
self.__arrow_ext_serialize__(),
)


class FixedShapeImageTensorType(pa.ExtensionType):
def __init__(self, arrow_type: pa.DataType, shape):
Expand Down Expand Up @@ -126,6 +142,14 @@ def __arrow_ext_class__(self):
def __arrow_ext_scalar_class__(self):
return FixedShapeImageTensorScalar

def __reduce__(self):
# Workaround to ensure pickle works in earlier versions of PyArrow
# https://github.com/apache/arrow/issues/35599
return type(self).__arrow_ext_deserialize__, (
self.storage_type,
self.__arrow_ext_serialize__(),
)


pa.register_extension_type(ImageURIType())
pa.register_extension_type(EncodedImageType())
Expand Down
21 changes: 21 additions & 0 deletions python/python/tests/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
import pickle
import re
from pathlib import Path

Expand All @@ -24,6 +25,7 @@
from lance.arrow import (
BFloat16,
BFloat16Array,
FixedShapeImageTensorType,
ImageArray,
ImageURIArray,
PandasBFloat16Array,
Expand Down Expand Up @@ -317,3 +319,22 @@ def test_roundtrip_image_tensor(tmp_path: Path):
tensor_image_array_2 = tbl2.take(indices).column(2)

assert tensor_image_array_2.type == tensor_image_array.type


def test_image_array_pickle(tmp_path: Path, png_uris):
# Note: this test will only fail in PyArrow 12.0.0. It was fixed in 13.0.0.
uri_array = ImageURIArray.from_uris(png_uris)
encoded_array = uri_array.read_uris()
tensor_array = pa.FixedShapeTensorArray.from_numpy_ndarray(
np.random.random((10, 1, 1, 4)).astype(np.uint8)
)
tensor_array = pa.ExtensionArray.from_storage(
FixedShapeImageTensorType(pa.uint8(), tensor_array.type.shape),
tensor_array.storage,
)

arrays = [uri_array, encoded_array, tensor_array]
for arr in arrays:
data = pickle.dumps(arr)
arr2 = pickle.loads(data)
assert arr == arr2

0 comments on commit 0d2083e

Please sign in to comment.