Skip to content

Commit

Permalink
resort to Pickle codec for complex cases
Browse files Browse the repository at this point in the history
  • Loading branch information
magland committed Mar 6, 2024
1 parent b8300e8 commit 2ac5a3f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
24 changes: 23 additions & 1 deletion src/hdmf_zarr/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,13 +1050,18 @@ def write_dataset(self, **kwargs): # noqa: C901
new_dtype.append((field['name'], self.__resolve_dtype_helper__(field['dtype'])))
dtype = np.dtype(new_dtype)

object_codec = self.__codec_cls()
if not isinstance(object_codec, numcodecs.Pickle):
print(f'Resorting to Pickle codec for dataset {name} of {parent.name}')
object_codec = numcodecs.Pickle()

# cast and store compound dataset
arr = np.array(new_items, dtype=dtype)
dset = parent.require_dataset(
name,
shape=(len(arr),),
dtype=dtype,
object_codec=self.__codec_cls(),
object_codec=object_codec,
**options['io_settings']
)
dset.attrs['zarr_dtype'] = type_str
Expand Down Expand Up @@ -1268,6 +1273,23 @@ def __list_fill__(self, parent, name, data, options=None): # noqa: C901
else:
data_shape = get_data_shape(data)

# Let's check to see if we have a structured array somewhere in the data
# If we do, then we are going to resort to pickling the data and
# printing a warning.
has_structured_array = False
if dtype == object:
for c in np.ndindex(data_shape):
o = data
for i in c:
o = o[i]
if isinstance(o, np.void) and o.dtype.names is not None:
has_structured_array = True
if has_structured_array:
object_codec = io_settings.get('object_codec')
if not isinstance(object_codec, numcodecs.Pickle):
print(f'Warning: Resorting to Pickle codec for {name} of {parent.name}.')
io_settings['object_codec'] = numcodecs.Pickle()

# Create the dataset
dset = parent.require_dataset(name, shape=data_shape, dtype=dtype, **io_settings)
dset.attrs['zarr_dtype'] = type_str
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/base_tests_zarrio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

# Try to import numcodecs and disable compression tests if it is not available
try:
from numcodecs import Blosc, Delta, JSON
from numcodecs import Blosc, Delta, JSON, Pickle
DISABLE_ZARR_COMPRESSION_TESTS = False
except ImportError:
DISABLE_ZARR_COMPRESSION_TESTS = True
Expand Down Expand Up @@ -491,12 +491,12 @@ def setUp(self):
# ZarrDataIO general
#############################################
def test_set_object_codec(self):
# Test that the default codec is the Pickle store
# Test that the default codec is JSON
tempIO = ZarrIO(self.store, mode='w')
self.assertEqual(tempIO.object_codec_class.__qualname__, 'Pickle')
del tempIO # also calls tempIO.close()
tempIO = ZarrIO(self.store, mode='w', object_codec_class=JSON)
self.assertEqual(tempIO.object_codec_class.__qualname__, 'JSON')
del tempIO # also calls tempIO.close()
tempIO = ZarrIO(self.store, mode='w', object_codec_class=Pickle)
self.assertEqual(tempIO.object_codec_class.__qualname__, 'Pickle')
tempIO.close()

def test_synchronizer_constructor_arg_bool(self):
Expand Down

0 comments on commit 2ac5a3f

Please sign in to comment.