From 5f708f2891e5fa7b6b8eaff6d65f43ac349d1132 Mon Sep 17 00:00:00 2001 From: rly Date: Fri, 20 Dec 2024 09:56:26 -0800 Subject: [PATCH] Revert "Coverage to 90% (#1198)" This reverts commit c71537a82ce9095df5a5f06a9e90d2a78cc95982. --- CHANGELOG.md | 5 +- SortedQueryTest.h5 | Bin 2128 -> 0 bytes src/hdmf/__init__.py | 26 ++- src/hdmf/array.py | 197 ++++++++++++++++++++++ src/hdmf/backends/hdf5/__init__.py | 2 +- src/hdmf/backends/hdf5/h5_utils.py | 78 ++++++++- src/hdmf/backends/hdf5/h5tools.py | 120 ++++++++++--- src/hdmf/build/classgenerator.py | 4 +- src/hdmf/build/manager.py | 16 +- src/hdmf/build/objectmapper.py | 50 ++++-- src/hdmf/common/table.py | 4 +- src/hdmf/container.py | 49 ++++++ src/hdmf/query.py | 121 ++++++++++++- src/hdmf/region.py | 91 ++++++++++ src/hdmf/utils.py | 91 ++++++++++ tests/unit/common/test_table.py | 4 +- tests/unit/test_container.py | 12 ++ tests/unit/test_query.py | 161 ++++++++++++++++++ tests/unit/utils_test/test_core_DataIO.py | 27 ++- tests/unit/utils_test/test_docval.py | 78 ++++++++- 20 files changed, 1073 insertions(+), 63 deletions(-) delete mode 100644 SortedQueryTest.h5 create mode 100644 src/hdmf/array.py create mode 100644 src/hdmf/region.py create mode 100644 tests/unit/test_query.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a1ccd1ac..028e745c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,6 @@ # HDMF Changelog -## HDMF 4.0.0 (Upcoming) - -### Deprecations -- The following classes have been deprecated and removed: Array, AbstractSortedArray, SortedArray, LinSpace, Query, RegionSlicer, ListSlicer, H5RegionSlicer, DataRegion. The following methods have been deprecated and removed: fmt_docval_args, call_docval_func, get_container_cls, add_child, set_dataio (now refactored as set_data_io). We have also removed all early development for region references. @mavaylon1 [#1198](https://github.com/hdmf-dev/hdmf/pull/1198) +## HDMF 3.14.6 (Upcoming) ### Enhancements - Added support for expandable datasets of references for untyped and compound data types. @stephprince [#1188](https://github.com/hdmf-dev/hdmf/pull/1188) diff --git a/SortedQueryTest.h5 b/SortedQueryTest.h5 deleted file mode 100644 index 67f3d7d6cbd59ee6452fdaebcd21834eb191f551..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2128 zcmeHIF%H5o47Af0LM5cKAts){6F_QFC#Zmdu>-sS9swKA;3@nHlK3nk#!i(;D)lao z<0jVGz1QV>oX>Ld!&*wI2vlEhjK+#Z=epj4Fz^O^8X-*nA)3NILHu98!>+2xd1`z` zY@GteDVI!{bpPGiq<-uC;d5FJW0$z%s{sc aNH=r+3EQ=-@!NG>P{sf_1Zp1`Vg3Opl`aJU diff --git a/src/hdmf/__init__.py b/src/hdmf/__init__.py index 10305d37b..6fc72a117 100644 --- a/src/hdmf/__init__.py +++ b/src/hdmf/__init__.py @@ -1,10 +1,32 @@ from .
import query -from .backends.hdf5.h5_utils import H5Dataset -from .container import Container, Data, HERDManager +from .backends.hdf5.h5_utils import H5Dataset, H5RegionSlicer +from .container import Container, Data, DataRegion, HERDManager +from .region import ListSlicer from .utils import docval, getargs from .term_set import TermSet, TermSetWrapper, TypeConfigurator +@docval( + {"name": "dataset", "type": None, "doc": "the HDF5 dataset to slice"}, + {"name": "region", "type": None, "doc": "the region reference to use to slice"}, + is_method=False, +) +def get_region_slicer(**kwargs): + import warnings # noqa: E402 + + warnings.warn( + "get_region_slicer is deprecated and will be removed in HDMF 3.0.", + DeprecationWarning, + ) + + dataset, region = getargs("dataset", "region", kwargs) + if isinstance(dataset, (list, tuple, Data)): + return ListSlicer(dataset, region) + elif isinstance(dataset, H5Dataset): + return H5RegionSlicer(dataset, region) + return None + + try: # see https://effigies.gitlab.io/posts/python-packaging-2023/ from ._version import __version__ diff --git a/src/hdmf/array.py b/src/hdmf/array.py new file mode 100644 index 000000000..a684572e4 --- /dev/null +++ b/src/hdmf/array.py @@ -0,0 +1,197 @@ +from abc import abstractmethod, ABCMeta + +import numpy as np + + +class Array: + + def __init__(self, data): + self.__data = data + if hasattr(data, 'dtype'): + self.dtype = data.dtype + else: + tmp = data + while isinstance(tmp, (list, tuple)): + tmp = tmp[0] + self.dtype = type(tmp) + + @property + def data(self): + return self.__data + + def __len__(self): + return len(self.__data) + + def get_data(self): + return self.__data + + def __getidx__(self, arg): + return self.__data[arg] + + def __sliceiter(self, arg): + return (x for x in range(*arg.indices(len(self)))) + + def __getitem__(self, arg): + if isinstance(arg, list): + idx = list() + for i in arg: + if isinstance(i, slice): + idx.extend(x for x in self.__sliceiter(i)) + else: + idx.append(i) + return np.fromiter((self.__getidx__(x) for x in idx), dtype=self.dtype) + elif isinstance(arg, slice): + return np.fromiter((self.__getidx__(x) for x in self.__sliceiter(arg)), dtype=self.dtype) + elif isinstance(arg, tuple): + return (self.__getidx__(arg[0]), self.__getidx__(arg[1])) + else: + return self.__getidx__(arg) + + +class AbstractSortedArray(Array, metaclass=ABCMeta): + ''' + An abstract class for representing sorted array + ''' + + @abstractmethod + def find_point(self, val): + pass + + def get_data(self): + return self + + def __lower(self, other): + ins = self.find_point(other) + return ins + + def __upper(self, other): + ins = self.__lower(other) + while self[ins] == other: + ins += 1 + return ins + + def __lt__(self, other): + ins = self.__lower(other) + return slice(0, ins) + + def __le__(self, other): + ins = self.__upper(other) + return slice(0, ins) + + def __gt__(self, other): + ins = self.__upper(other) + return slice(ins, len(self)) + + def __ge__(self, other): + ins = self.__lower(other) + return slice(ins, len(self)) + + @staticmethod + def __sort(a): + if isinstance(a, tuple): + return a[0] + else: + return a + + def __eq__(self, other): + if isinstance(other, list): + ret = list() + for i in other: + eq = self == i + ret.append(eq) + ret = sorted(ret, key=self.__sort) + tmp = list() + for i in range(1, len(ret)): + a, b = ret[i - 1], ret[i] + if isinstance(a, tuple): + if isinstance(b, tuple): + if a[1] >= b[0]: + b[0] = a[0] + else: + tmp.append(slice(*a)) + else: + if b > a[1]: + 
tmp.append(slice(*a)) + elif b == a[1]: + a[1] == b + 1 + else: + ret[i] = a + else: + if isinstance(b, tuple): + if a < b[0]: + tmp.append(a) + else: + if b - a == 1: + ret[i] = (a, b) + else: + tmp.append(a) + if isinstance(ret[-1], tuple): + tmp.append(slice(*ret[-1])) + else: + tmp.append(ret[-1]) + ret = tmp + return ret + elif isinstance(other, tuple): + ge = self >= other[0] + ge = ge.start + lt = self < other[1] + lt = lt.stop + if ge == lt: + return ge + else: + return slice(ge, lt) + else: + lower = self.__lower(other) + upper = self.__upper(other) + d = upper - lower + if d == 1: + return lower + elif d == 0: + return None + else: + return slice(lower, upper) + + def __ne__(self, other): + eq = self == other + if isinstance(eq, tuple): + return [slice(0, eq[0]), slice(eq[1], len(self))] + else: + return [slice(0, eq), slice(eq + 1, len(self))] + + +class SortedArray(AbstractSortedArray): + ''' + A class for wrapping sorted arrays. This class overrides + <,>,<=,>=,==, and != to leverage the sorted content for + efficiency. + ''' + + def __init__(self, array): + super().__init__(array) + + def find_point(self, val): + return np.searchsorted(self.data, val) + + +class LinSpace(SortedArray): + + def __init__(self, start, stop, step): + self.start = start + self.stop = stop + self.step = step + self.dtype = float if any(isinstance(s, float) for s in (start, stop, step)) else int + self.__len = int((stop - start) / step) + + def __len__(self): + return self.__len + + def find_point(self, val): + nsteps = (val - self.start) / self.step + fl = int(nsteps) + if fl == nsteps: + return int(fl) + else: + return int(fl + 1) + + def __getidx__(self, arg): + return self.start + self.step * arg diff --git a/src/hdmf/backends/hdf5/__init__.py b/src/hdmf/backends/hdf5/__init__.py index 8f76d7bcc..6abfc8c85 100644 --- a/src/hdmf/backends/hdf5/__init__.py +++ b/src/hdmf/backends/hdf5/__init__.py @@ -1,3 +1,3 @@ from . 
import h5_utils, h5tools -from .h5_utils import H5DataIO +from .h5_utils import H5RegionSlicer, H5DataIO from .h5tools import HDF5IO, H5SpecWriter, H5SpecReader diff --git a/src/hdmf/backends/hdf5/h5_utils.py b/src/hdmf/backends/hdf5/h5_utils.py index 878ebf089..2d7187721 100644 --- a/src/hdmf/backends/hdf5/h5_utils.py +++ b/src/hdmf/backends/hdf5/h5_utils.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from copy import copy -from h5py import Group, Dataset, Reference, special_dtype +from h5py import Group, Dataset, RegionReference, Reference, special_dtype from h5py import filters as h5py_filters import json import numpy as np @@ -16,8 +16,10 @@ import os import logging +from ...array import Array from ...data_utils import DataIO, AbstractDataChunkIterator, append_data from ...query import HDMFDataset, ReferenceResolver, ContainerResolver, BuilderResolver +from ...region import RegionSlicer from ...spec import SpecWriter, SpecReader from ...utils import docval, getargs, popargs, get_docval, get_data_shape @@ -83,7 +85,7 @@ def append(self, dataset, data): class H5Dataset(HDMFDataset): - @docval({'name': 'dataset', 'type': Dataset, 'doc': 'the HDF5 file lazily evaluate'}, + @docval({'name': 'dataset', 'type': (Dataset, Array), 'doc': 'the HDF5 file lazily evaluate'}, {'name': 'io', 'type': 'hdmf.backends.hdf5.h5tools.HDF5IO', 'doc': 'the IO object that was used to read the underlying dataset'}) def __init__(self, **kwargs): @@ -94,6 +96,10 @@ def __init__(self, **kwargs): def io(self): return self.__io + @property + def regionref(self): + return self.dataset.regionref + @property def ref(self): return self.dataset.ref @@ -183,7 +189,7 @@ def get_object(self, h5obj): class AbstractH5TableDataset(DatasetOfReferences): - @docval({'name': 'dataset', 'type': Dataset, 'doc': 'the HDF5 file lazily evaluate'}, + @docval({'name': 'dataset', 'type': (Dataset, Array), 'doc': 'the HDF5 file lazily evaluate'}, {'name': 'io', 'type': 'hdmf.backends.hdf5.h5tools.HDF5IO', 'doc': 'the IO object that was used to read the underlying dataset'}, {'name': 'types', 'type': (list, tuple), @@ -193,7 +199,9 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.__refgetters = dict() for i, t in enumerate(types): - if t is Reference: + if t is RegionReference: + self.__refgetters[i] = self.__get_regref + elif t is Reference: self.__refgetters[i] = self._get_ref elif t is str: # we need this for when we read compound data types @@ -215,6 +223,8 @@ def __init__(self, **kwargs): t = sub.metadata['ref'] if t is Reference: tmp.append('object') + elif t is RegionReference: + tmp.append('region') else: tmp.append(sub.type.__name__) self.__dtype = tmp @@ -247,6 +257,10 @@ def _get_utf(self, string): """ return string.decode('utf-8') if isinstance(string, bytes) else string + def __get_regref(self, ref): + obj = self._get_ref(ref) + return obj[ref] + def resolve(self, manager): return self[0:len(self)] @@ -269,6 +283,18 @@ def dtype(self): return 'object' +class AbstractH5RegionDataset(AbstractH5ReferenceDataset): + + def __getitem__(self, arg): + obj = super().__getitem__(arg) + ref = self.dataset[arg] + return obj[ref] + + @property + def dtype(self): + return 'region' + + class ContainerH5TableDataset(ContainerResolverMixin, AbstractH5TableDataset): """ A reference-resolving dataset for resolving references inside tables @@ -313,6 +339,28 @@ def get_inverse_class(cls): return ContainerH5ReferenceDataset +class ContainerH5RegionDataset(ContainerResolverMixin, AbstractH5RegionDataset): + """ + A 
reference-resolving dataset for resolving region references that returns + resolved references as Containers + """ + + @classmethod + def get_inverse_class(cls): + return BuilderH5RegionDataset + + +class BuilderH5RegionDataset(BuilderResolverMixin, AbstractH5RegionDataset): + """ + A reference-resolving dataset for resolving region references that returns + resolved references as Builders + """ + + @classmethod + def get_inverse_class(cls): + return ContainerH5RegionDataset + + class H5SpecWriter(SpecWriter): __str_type = special_dtype(vlen=str) @@ -372,6 +420,28 @@ def read_namespace(self, ns_path): return ret +class H5RegionSlicer(RegionSlicer): + + @docval({'name': 'dataset', 'type': (Dataset, H5Dataset), 'doc': 'the HDF5 dataset to slice'}, + {'name': 'region', 'type': RegionReference, 'doc': 'the region reference to use to slice'}) + def __init__(self, **kwargs): + self.__dataset = getargs('dataset', kwargs) + self.__regref = getargs('region', kwargs) + self.__len = self.__dataset.regionref.selection(self.__regref)[0] + self.__region = None + + def __read_region(self): + if self.__region is None: + self.__region = self.__dataset[self.__regref] + + def __getitem__(self, idx): + self.__read_region() + return self.__region[idx] + + def __len__(self): + return self.__len + + class H5DataIO(DataIO): """ Wrap data arrays for write via HDF5IO to customize I/O behavior, such as compression and chunking diff --git a/src/hdmf/backends/hdf5/h5tools.py b/src/hdmf/backends/hdf5/h5tools.py index f0e789d32..4fc9c258f 100644 --- a/src/hdmf/backends/hdf5/h5tools.py +++ b/src/hdmf/backends/hdf5/h5tools.py @@ -7,14 +7,14 @@ import numpy as np import h5py -from h5py import File, Group, Dataset, special_dtype, SoftLink, ExternalLink, Reference, check_dtype +from h5py import File, Group, Dataset, special_dtype, SoftLink, ExternalLink, Reference, RegionReference, check_dtype -from .h5_utils import (BuilderH5ReferenceDataset, BuilderH5TableDataset, H5DataIO, +from .h5_utils import (BuilderH5ReferenceDataset, BuilderH5RegionDataset, BuilderH5TableDataset, H5DataIO, H5SpecReader, H5SpecWriter, HDF5IODataChunkIteratorQueue) from ..io import HDMFIO from ..errors import UnsupportedOperation from ..warnings import BrokenLinkWarning -from ...build import (Builder, GroupBuilder, DatasetBuilder, LinkBuilder, BuildManager, +from ...build import (Builder, GroupBuilder, DatasetBuilder, LinkBuilder, BuildManager, RegionBuilder, ReferenceBuilder, TypeMap, ObjectMapper) from ...container import Container from ...data_utils import AbstractDataChunkIterator @@ -28,6 +28,7 @@ H5_TEXT = special_dtype(vlen=str) H5_BINARY = special_dtype(vlen=bytes) H5_REF = special_dtype(ref=Reference) +H5_REGREF = special_dtype(ref=RegionReference) RDCC_NBYTES = 32*2**20 # set raw data chunk cache size = 32 MiB @@ -692,7 +693,10 @@ def __read_dataset(self, h5obj, name=None): target = h5obj.file[scalar] target_builder = self.__read_dataset(target) self.__set_built(target.file.filename, target.id, target_builder) - d = ReferenceBuilder(target_builder) + if isinstance(scalar, RegionReference): + d = RegionBuilder(scalar, target_builder) + else: + d = ReferenceBuilder(target_builder) kwargs['data'] = d kwargs['dtype'] = d.dtype elif h5obj.dtype.kind == 'V': # scalar compound data type @@ -709,6 +713,9 @@ def __read_dataset(self, h5obj, name=None): elem1 = h5obj[tuple([0] * (h5obj.ndim - 1) + [0])] if isinstance(elem1, (str, bytes)): d = self._check_str_dtype(h5obj) + elif isinstance(elem1, RegionReference): # read list of references + d = 
BuilderH5RegionDataset(h5obj, self) + kwargs['dtype'] = d.dtype elif isinstance(elem1, Reference): d = BuilderH5ReferenceDataset(h5obj, self) kwargs['dtype'] = d.dtype @@ -744,7 +751,9 @@ def __read_attrs(self, h5obj): for k, v in h5obj.attrs.items(): if k == SPEC_LOC_ATTR: # ignore cached spec continue - if isinstance(v, Reference): + if isinstance(v, RegionReference): + raise ValueError("cannot read region reference attributes yet") + elif isinstance(v, Reference): ret[k] = self.__read_ref(h5obj.file[v]) else: ret[k] = v @@ -911,6 +920,7 @@ def get_type(cls, data): "ref": H5_REF, "reference": H5_REF, "object": H5_REF, + "region": H5_REGREF, "isodatetime": H5_TEXT, "datetime": H5_TEXT, } @@ -1228,12 +1238,29 @@ def _filler(): dset = self.__scalar_fill__(parent, name, data, options) else: dset = self.__list_fill__(parent, name, data, options) - # Write a dataset containing references, i.e., object reference. + # Write a dataset containing references, i.e., a region or object reference. # NOTE: we can ignore options['io_settings'] for scalar data elif self.__is_ref(options['dtype']): _dtype = self.__dtypes.get(options['dtype']) + # Write a scalar data region reference dataset + if isinstance(data, RegionBuilder): + dset = parent.require_dataset(name, shape=(), dtype=_dtype) + self.__set_written(builder) + self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing a " + "region reference. attributes: %s" + % (name, list(attributes.keys()))) + + @self.__queue_ref + def _filler(): + self.logger.debug("Resolving region reference and setting attribute on dataset '%s' " + "containing attributes: %s" + % (name, list(attributes.keys()))) + ref = self.__get_ref(data.builder, data.region) + dset = parent[name] + dset[()] = ref + self.set_attributes(dset, attributes) # Write a scalar object reference dataset - if isinstance(data, ReferenceBuilder): + elif isinstance(data, ReferenceBuilder): dset = parent.require_dataset(name, dtype=_dtype, shape=()) self.__set_written(builder) self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing an " @@ -1251,24 +1278,44 @@ def _filler(): self.set_attributes(dset, attributes) # Write an array dataset of references else: - # Write array of object references - dset = parent.require_dataset(name, shape=(len(data),), dtype=_dtype, **options['io_settings']) - self.__set_written(builder) - self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing " - "object references. attributes: %s" - % (name, list(attributes.keys()))) + # Write a array of region references + if options['dtype'] == 'region': + dset = parent.require_dataset(name, dtype=_dtype, shape=(len(data),), **options['io_settings']) + self.__set_written(builder) + self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing " + "region references. 
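
For context on the region-reference handling this hunk restores: the underlying h5py primitives are Dataset.regionref for creating a reference, special_dtype(ref=RegionReference) for storing one (the H5_REGREF dtype above), and indexing with the reference to dereference it. A minimal standalone sketch, independent of HDMF and assuming an h5py build that exposes RegionReference; the file and dataset names are illustrative only:

    import h5py
    import numpy as np

    with h5py.File("regref_demo.h5", "w") as f:
        dset = f.create_dataset("data", data=np.arange(10))
        ref = dset.regionref[2:5]  # region reference selecting elements 2..4
        # scalar dataset with a region-reference dtype, as with H5_REGREF above
        ref_dset = f.create_dataset(
            "ref", shape=(), dtype=h5py.special_dtype(ref=h5py.RegionReference))
        ref_dset[()] = ref

    with h5py.File("regref_demo.h5", "r") as f:
        ref = f["ref"][()]
        target = f[ref]      # dereference to the target dataset
        print(target[ref])   # apply the region -> [2 3 4]
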
attributes: %s" + % (name, list(attributes.keys()))) - @self.__queue_ref - def _filler(): - self.logger.debug("Resolving object references and setting attribute on dataset '%s' " - "containing attributes: %s" + @self.__queue_ref + def _filler(): + self.logger.debug("Resolving region references and setting attribute on dataset '%s' " + "containing attributes: %s" + % (name, list(attributes.keys()))) + refs = list() + for item in data: + refs.append(self.__get_ref(item.builder, item.region)) + dset = parent[name] + dset[()] = refs + self.set_attributes(dset, attributes) + # Write array of object references + else: + dset = parent.require_dataset(name, shape=(len(data),), dtype=_dtype, **options['io_settings']) + self.__set_written(builder) + self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing " + "object references. attributes: %s" % (name, list(attributes.keys()))) - refs = list() - for item in data: - refs.append(self.__get_ref(item)) - dset = parent[name] - dset[()] = refs - self.set_attributes(dset, attributes) + + @self.__queue_ref + def _filler(): + self.logger.debug("Resolving object references and setting attribute on dataset '%s' " + "containing attributes: %s" + % (name, list(attributes.keys()))) + refs = list() + for item in data: + refs.append(self.__get_ref(item)) + dset = parent[name] + dset[()] = refs + self.set_attributes(dset, attributes) return # write a "regular" dataset else: @@ -1456,9 +1503,11 @@ def __list_fill__(cls, parent, name, data, options=None): @docval({'name': 'container', 'type': (Builder, Container, ReferenceBuilder), 'doc': 'the object to reference', 'default': None}, + {'name': 'region', 'type': (slice, list, tuple), 'doc': 'the region reference indexing object', + 'default': None}, returns='the reference', rtype=Reference) def __get_ref(self, **kwargs): - container = getargs('container', kwargs) + container, region = getargs('container', 'region', kwargs) if container is None: return None if isinstance(container, Builder): @@ -1476,10 +1525,20 @@ def __get_ref(self, **kwargs): path = self.__get_path(builder) self.logger.debug("Getting reference at path '%s'" % path) - return self.__file[path].ref + if isinstance(container, RegionBuilder): + region = container.region + if region is not None: + dset = self.__file[path] + if not isinstance(dset, Dataset): + raise ValueError('cannot create region reference without Dataset') + return self.__file[path].regionref[region] + else: + return self.__file[path].ref @docval({'name': 'container', 'type': (Builder, Container, ReferenceBuilder), 'doc': 'the object to reference', 'default': None}, + {'name': 'region', 'type': (slice, list, tuple), 'doc': 'the region reference indexing object', + 'default': None}, returns='the reference', rtype=Reference) def _create_ref(self, **kwargs): return self.__get_ref(**kwargs) @@ -1511,6 +1570,17 @@ def __queue_ref(self, func): # dependency self.__ref_queue.append(func) + def __rec_get_ref(self, ref_list): + ret = list() + for elem in ref_list: + if isinstance(elem, (list, tuple)): + ret.append(self.__rec_get_ref(elem)) + elif isinstance(elem, (Builder, Container)): + ret.append(self.__get_ref(elem)) + else: + ret.append(elem) + return ret + @property def mode(self): """ diff --git a/src/hdmf/build/classgenerator.py b/src/hdmf/build/classgenerator.py index 3b7d7c96e..a3336b98e 100644 --- a/src/hdmf/build/classgenerator.py +++ b/src/hdmf/build/classgenerator.py @@ -4,7 +4,7 @@ import numpy as np -from ..container import Container, Data, 
MultiContainerInterface +from ..container import Container, Data, DataRegion, MultiContainerInterface from ..spec import AttributeSpec, LinkSpec, RefSpec, GroupSpec from ..spec.spec import BaseStorageSpec, ZERO_OR_MANY, ONE_OR_MANY from ..utils import docval, getargs, ExtenderMeta, get_docval, popargs, AllowPositional @@ -195,7 +195,7 @@ def _ischild(cls, dtype): if isinstance(dtype, tuple): for sub in dtype: ret = ret or cls._ischild(sub) - elif isinstance(dtype, type) and issubclass(dtype, (Container, Data)): + elif isinstance(dtype, type) and issubclass(dtype, (Container, Data, DataRegion)): ret = True return ret diff --git a/src/hdmf/build/manager.py b/src/hdmf/build/manager.py index bc586013c..967c34010 100644 --- a/src/hdmf/build/manager.py +++ b/src/hdmf/build/manager.py @@ -490,6 +490,20 @@ def load_namespaces(self, **kwargs): self.register_container_type(new_ns, dt, container_cls) return deps + @docval({"name": "namespace", "type": str, "doc": "the namespace containing the data_type"}, + {"name": "data_type", "type": str, "doc": "the data type to create a AbstractContainer class for"}, + {"name": "autogen", "type": bool, "doc": "autogenerate class if one does not exist", "default": True}, + returns='the class for the given namespace and data_type', rtype=type) + def get_container_cls(self, **kwargs): + """Get the container class from data type specification. + If no class has been associated with the ``data_type`` from ``namespace``, a class will be dynamically + created and returned. + """ + # NOTE: this internally used function get_container_cls will be removed in favor of get_dt_container_cls + # Deprecated: Will be removed by HDMF 4.0 + namespace, data_type, autogen = getargs('namespace', 'data_type', 'autogen', kwargs) + return self.get_dt_container_cls(data_type, namespace, autogen) + @docval({"name": "data_type", "type": str, "doc": "the data type to create a AbstractContainer class for"}, {"name": "namespace", "type": str, "doc": "the namespace containing the data_type", "default": None}, {'name': 'post_init_method', 'type': Callable, 'default': None, @@ -501,7 +515,7 @@ def get_dt_container_cls(self, **kwargs): If no class has been associated with the ``data_type`` from ``namespace``, a class will be dynamically created and returned. - Namespace is optional. If namespace is unknown, it will be looked up from + Replaces get_container_cls but namespace is optional. If namespace is unknown, it will be looked up from all namespaces. 
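
The restored get_container_cls simply forwards to get_dt_container_cls with the arguments reordered. A hedged migration sketch, assuming hdmf.common has been imported so the hdmf-common namespace is registered; DynamicTable is just an example type:

    from hdmf.common import get_type_map

    tm = get_type_map()

    # deprecated form restored by this revert: namespace first
    cls_old = tm.get_container_cls("hdmf-common", "DynamicTable")

    # replacement: data_type first; namespace optional (looked up if omitted)
    cls_new = tm.get_dt_container_cls("DynamicTable", "hdmf-common")
    assert cls_old is cls_new
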
""" namespace, data_type, post_init_method, autogen = getargs('namespace', 'data_type', diff --git a/src/hdmf/build/objectmapper.py b/src/hdmf/build/objectmapper.py index c2ef44b5f..3394ebb91 100644 --- a/src/hdmf/build/objectmapper.py +++ b/src/hdmf/build/objectmapper.py @@ -15,7 +15,7 @@ IncorrectDatasetShapeBuildWarning) from hdmf.backends.hdf5.h5_utils import H5DataIO -from ..container import AbstractContainer, Data +from ..container import AbstractContainer, Data, DataRegion from ..term_set import TermSetWrapper from ..data_utils import DataIO, AbstractDataChunkIterator from ..query import ReferenceResolver @@ -966,23 +966,41 @@ def _filler(): return _filler def __get_ref_builder(self, builder, dtype, shape, container, build_manager): - self.logger.debug("Setting object reference dataset on %s '%s' data" - % (builder.__class__.__name__, builder.name)) - if isinstance(container, Data): - self.logger.debug("Setting %s '%s' data to list of reference builders" - % (builder.__class__.__name__, builder.name)) - bldr_data = list() - for d in container.data: - target_builder = self.__get_target_builder(d, build_manager, builder) - bldr_data.append(ReferenceBuilder(target_builder)) - if isinstance(container.data, H5DataIO): - # This is here to support appending a dataset of references. - bldr_data = H5DataIO(bldr_data, **container.data.get_io_params()) + bldr_data = None + if dtype.is_region(): + if shape is None: + if not isinstance(container, DataRegion): + msg = "'container' must be of type DataRegion if spec represents region reference" + raise ValueError(msg) + self.logger.debug("Setting %s '%s' data to region reference builder" + % (builder.__class__.__name__, builder.name)) + target_builder = self.__get_target_builder(container.data, build_manager, builder) + bldr_data = RegionBuilder(container.region, target_builder) + else: + self.logger.debug("Setting %s '%s' data to list of region reference builders" + % (builder.__class__.__name__, builder.name)) + bldr_data = list() + for d in container.data: + target_builder = self.__get_target_builder(d.target, build_manager, builder) + bldr_data.append(RegionBuilder(d.slice, target_builder)) else: - self.logger.debug("Setting %s '%s' data to reference builder" + self.logger.debug("Setting object reference dataset on %s '%s' data" % (builder.__class__.__name__, builder.name)) - target_builder = self.__get_target_builder(container, build_manager, builder) - bldr_data = ReferenceBuilder(target_builder) + if isinstance(container, Data): + self.logger.debug("Setting %s '%s' data to list of reference builders" + % (builder.__class__.__name__, builder.name)) + bldr_data = list() + for d in container.data: + target_builder = self.__get_target_builder(d, build_manager, builder) + bldr_data.append(ReferenceBuilder(target_builder)) + if isinstance(container.data, H5DataIO): + # This is here to support appending a dataset of references. 
+ bldr_data = H5DataIO(bldr_data, **container.data.get_io_params()) + else: + self.logger.debug("Setting %s '%s' data to reference builder" + % (builder.__class__.__name__, builder.name)) + target_builder = self.__get_target_builder(container, build_manager, builder) + bldr_data = ReferenceBuilder(target_builder) return bldr_data def __get_target_builder(self, container, build_manager, builder): diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 84ac4da3b..b4530c7b7 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -775,8 +775,8 @@ def add_column(self, **kwargs): # noqa: C901 index, table, enum, col_cls, check_ragged = popargs('index', 'table', 'enum', 'col_cls', 'check_ragged', kwargs) if isinstance(index, VectorIndex): - msg = "Passing a VectorIndex may lead to unexpected behavior. This functionality is not supported." - raise ValueError(msg) + warn("Passing a VectorIndex in for index may lead to unexpected behavior. This functionality will be " + "deprecated in a future version of HDMF.", category=FutureWarning, stacklevel=3) if name in self.__colids: # column has already been added msg = "column '%s' already exists in %s '%s'" % (name, self.__class__.__name__, self.name) diff --git a/src/hdmf/container.py b/src/hdmf/container.py index ce4e8b821..f2dba6e8d 100644 --- a/src/hdmf/container.py +++ b/src/hdmf/container.py @@ -1,4 +1,5 @@ import types +from abc import abstractmethod from collections import OrderedDict from copy import deepcopy from typing import Type, Optional @@ -466,6 +467,21 @@ def set_modified(self, **kwargs): def children(self): return tuple(self.__children) + @docval({'name': 'child', 'type': 'Container', + 'doc': 'the child Container for this Container', 'default': None}) + def add_child(self, **kwargs): + warn(DeprecationWarning('add_child is deprecated. Set the parent attribute instead.')) + child = getargs('child', kwargs) + if child is not None: + # if child.parent is a Container, then the mismatch between child.parent and parent + # is used to make a soft/external link from the parent to a child elsewhere + # if child.parent is not a Container, it is either None or a Proxy and should be set to self + if not isinstance(child.parent, AbstractContainer): + # actually add the child to the parent in parent setter + child.parent = self + else: + warn('Cannot add None as child to a container %s' % self.name) + @classmethod def type_hierarchy(cls): return cls.__mro__ @@ -915,6 +931,20 @@ def shape(self): """ return get_data_shape(self.__data) + @docval({'name': 'dataio', 'type': DataIO, 'doc': 'the DataIO to apply to the data held by this Data'}) + def set_dataio(self, **kwargs): + """ + Apply DataIO object to the data held by this Data object + """ + warn( + "Data.set_dataio() is deprecated. Please use Data.set_data_io() instead.", + DeprecationWarning, + stacklevel=3, + ) + dataio = getargs('dataio', kwargs) + dataio.data = self.__data + self.__data = dataio + def set_data_io( self, data_io_class: Type[DataIO], @@ -1010,6 +1040,25 @@ def _validate_new_data_element(self, arg): pass +class DataRegion(Data): + + @property + @abstractmethod + def data(self): + ''' + The target data that this region applies to + ''' + pass + + @property + @abstractmethod + def region(self): + ''' + The region that indexes into data e.g. slice or list of indices + ''' + pass + + class MultiContainerInterface(Container): """Class that dynamically defines methods to support a Container holding multiple Containers of the same type. 
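
Of the abstract classes restored above, DataRegion is implemented concretely by the RegionSlicer hierarchy that this patch restores below in src/hdmf/region.py. A short usage sketch of the list-backed implementation, derived from the restored ListSlicer code:

    from hdmf.region import ListSlicer

    values = [10, 11, 12, 13, 14]
    region = ListSlicer(values, slice(1, 4))

    assert len(region) == 3
    assert region[0] == 11               # indexes within the selection
    assert region.data is values         # DataRegion.data -> the target
    assert region.region == slice(1, 4)  # DataRegion.region -> the selection
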
diff --git a/src/hdmf/query.py b/src/hdmf/query.py index abe2a93a7..9693b0b1c 100644 --- a/src/hdmf/query.py +++ b/src/hdmf/query.py @@ -2,24 +2,143 @@ import numpy as np +from .array import Array from .utils import ExtenderMeta, docval_macro, docval, getargs +class Query(metaclass=ExtenderMeta): + __operations__ = ( + '__lt__', + '__gt__', + '__le__', + '__ge__', + '__eq__', + '__ne__', + ) + + @classmethod + def __build_operation(cls, op): + def __func(self, arg): + return cls(self, op, arg) + + @ExtenderMeta.pre_init + def __make_operators(cls, name, bases, classdict): + if not isinstance(cls.__operations__, tuple): + raise TypeError("'__operations__' must be of type tuple") + # add any new operations + if len(bases) and 'Query' in globals() and issubclass(bases[-1], Query) \ + and bases[-1].__operations__ is not cls.__operations__: + new_operations = list(cls.__operations__) + new_operations[0:0] = bases[-1].__operations__ + cls.__operations__ = tuple(new_operations) + for op in cls.__operations__: + if not hasattr(cls, op): + setattr(cls, op, cls.__build_operation(op)) + + def __init__(self, obj, op, arg): + self.obj = obj + self.op = op + self.arg = arg + self.collapsed = None + self.expanded = None + + @docval({'name': 'expand', 'type': bool, 'help': 'whether or not to expand result', 'default': True}) + def evaluate(self, **kwargs): + expand = getargs('expand', kwargs) + if expand: + if self.expanded is None: + self.expanded = self.__evalhelper() + return self.expanded + else: + if self.collapsed is None: + self.collapsed = self.__collapse(self.__evalhelper()) + return self.collapsed + + def __evalhelper(self): + obj = self.obj + arg = self.arg + if isinstance(obj, Query): + obj = obj.evaluate() + elif isinstance(obj, HDMFDataset): + obj = obj.dataset + if isinstance(arg, Query): + arg = self.arg.evaluate() + return getattr(obj, self.op)(self.arg) + + def __collapse(self, result): + if isinstance(result, slice): + return (result.start, result.stop) + elif isinstance(result, list): + ret = list() + for idx in result: + if isinstance(idx, slice) and (idx.step is None or idx.step == 1): + ret.append((idx.start, idx.stop)) + else: + ret.append(idx) + return ret + else: + return result + + def __and__(self, other): + return NotImplemented + + def __or__(self, other): + return NotImplemented + + def __xor__(self, other): + return NotImplemented + + def __contains__(self, other): + return NotImplemented + + @docval_macro('array_data') class HDMFDataset(metaclass=ExtenderMeta): + __operations__ = ( + '__lt__', + '__gt__', + '__le__', + '__ge__', + '__eq__', + '__ne__', + ) + + @classmethod + def __build_operation(cls, op): + def __func(self, arg): + return Query(self, op, arg) + + setattr(__func, '__name__', op) + return __func + + @ExtenderMeta.pre_init + def __make_operators(cls, name, bases, classdict): + if not isinstance(cls.__operations__, tuple): + raise TypeError("'__operations__' must be of type tuple") + # add any new operations + if len(bases) and 'Query' in globals() and issubclass(bases[-1], Query) \ + and bases[-1].__operations__ is not cls.__operations__: + new_operations = list(cls.__operations__) + new_operations[0:0] = bases[-1].__operations__ + cls.__operations__ = tuple(new_operations) + for op in cls.__operations__: + setattr(cls, op, cls.__build_operation(op)) + def __evaluate_key(self, key): if isinstance(key, tuple) and len(key) == 0: return key if isinstance(key, (tuple, list, np.ndarray)): return list(map(self.__evaluate_key, key)) else: + if isinstance(key, 
Query): + return key.evaluate() return key def __getitem__(self, key): idx = self.__evaluate_key(key) return self.dataset[idx] - @docval({'name': 'dataset', 'type': 'array_data', 'doc': 'the HDF5 file lazily evaluate'}) + @docval({'name': 'dataset', 'type': ('array_data', Array), 'doc': 'the HDF5 file lazily evaluate'}) def __init__(self, **kwargs): super().__init__() self.__dataset = getargs('dataset', kwargs) diff --git a/src/hdmf/region.py b/src/hdmf/region.py new file mode 100644 index 000000000..9feeba401 --- /dev/null +++ b/src/hdmf/region.py @@ -0,0 +1,91 @@ +from abc import ABCMeta, abstractmethod +from operator import itemgetter + +from .container import Data, DataRegion +from .utils import docval, getargs + + +class RegionSlicer(DataRegion, metaclass=ABCMeta): + ''' + A abstract base class to control getting using a region + + Subclasses must implement `__getitem__` and `__len__` + ''' + + @docval({'name': 'target', 'type': None, 'doc': 'the target to slice'}, + {'name': 'slice', 'type': None, 'doc': 'the region to slice'}) + def __init__(self, **kwargs): + self.__target = getargs('target', kwargs) + self.__slice = getargs('slice', kwargs) + + @property + def data(self): + """The target data. Same as self.target""" + return self.target + + @property + def region(self): + """The selected region. Same as self.slice""" + return self.slice + + @property + def target(self): + """The target data""" + return self.__target + + @property + def slice(self): + """The selected slice""" + return self.__slice + + @property + @abstractmethod + def __getitem__(self, idx): + """Must be implemented by subclasses""" + pass + + @property + @abstractmethod + def __len__(self): + """Must be implemented by subclasses""" + pass + + +class ListSlicer(RegionSlicer): + """Implementation of RegionSlicer for slicing Lists and Data""" + + @docval({'name': 'dataset', 'type': (list, tuple, Data), 'doc': 'the dataset to slice'}, + {'name': 'region', 'type': (list, tuple, slice), 'doc': 'the region reference to use to slice'}) + def __init__(self, **kwargs): + self.__dataset, self.__region = getargs('dataset', 'region', kwargs) + super().__init__(self.__dataset, self.__region) + if isinstance(self.__region, slice): + self.__getter = itemgetter(self.__region) + self.__len = len(range(*self.__region.indices(len(self.__dataset)))) + else: + self.__getter = itemgetter(*self.__region) + self.__len = len(self.__region) + + def __read_region(self): + """ + Internal helper function used to define self._read + """ + if not hasattr(self, '_read'): + self._read = self.__getter(self.__dataset) + del self.__getter + + def __getitem__(self, idx): + """ + Get data values from selected data + """ + self.__read_region() + getter = None + if isinstance(idx, (list, tuple)): + getter = itemgetter(*idx) + else: + getter = itemgetter(idx) + return getter(self._read) + + def __len__(self): + """Number of values in the slice/region""" + return self.__len diff --git a/src/hdmf/utils.py b/src/hdmf/utils.py index c21382a2a..d05b52d93 100644 --- a/src/hdmf/utils.py +++ b/src/hdmf/utils.py @@ -382,6 +382,8 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True, for key in extras.keys(): type_errors.append("unrecognized argument: '%s'" % key) else: + # TODO: Extras get stripped out if function arguments are composed with fmt_docval_args. 
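
The Query/HDMFDataset machinery restored above in src/hdmf/query.py turns comparison operators into lazy Query objects; evaluate() defers to the wrapped array, so a SortedArray can collapse the answer to a slice. A sketch mirroring what tests/unit/test_query.py below exercises:

    import numpy as np
    from hdmf.array import SortedArray
    from hdmf.query import HDMFDataset, Query

    wrapper = HDMFDataset(SortedArray(np.arange(10)))

    q = wrapper < 5                      # builds a lazy Query, not a result
    assert isinstance(q, Query)
    assert q.evaluate() == slice(0, 5)   # resolved via SortedArray.__lt__
    assert list(wrapper[q]) == [0, 1, 2, 3, 4]  # a Query can index the wrapper
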
+ # allow_extra needs to be tracked on a function so that fmt_docval_args doesn't strip them out for key in extras.keys(): ret[key] = extras[key] return {'args': ret, 'future_warnings': future_warnings, 'type_errors': type_errors, 'value_errors': value_errors, @@ -412,6 +414,95 @@ def get_docval(func, *args): return tuple() +# def docval_wrap(func, is_method=True): +# if is_method: +# @docval(*get_docval(func)) +# def method(self, **kwargs): +# +# return call_docval_args(func, kwargs) +# return method +# else: +# @docval(*get_docval(func)) +# def static_method(**kwargs): +# return call_docval_args(func, kwargs) +# return method + + +def fmt_docval_args(func, kwargs): + ''' Separate positional and keyword arguments + + Useful for methods that wrap other methods + ''' + warnings.warn("fmt_docval_args will be deprecated in a future version of HDMF. Instead of using fmt_docval_args, " + "call the function directly with the kwargs. Please note that fmt_docval_args " + "removes all arguments not accepted by the function's docval, so if you are passing kwargs that " + "includes extra arguments and the function's docval does not allow extra arguments (allow_extra=True " + "is set), then you will need to pop the extra arguments out of kwargs before calling the function.", + PendingDeprecationWarning, stacklevel=2) + func_docval = getattr(func, docval_attr_name, None) + ret_args = list() + ret_kwargs = dict() + kwargs_copy = _copy.copy(kwargs) + if func_docval: + for arg in func_docval[__docval_args_loc]: + val = kwargs_copy.pop(arg['name'], None) + if 'default' in arg: + if val is not None: + ret_kwargs[arg['name']] = val + else: + ret_args.append(val) + if func_docval['allow_extra']: + ret_kwargs.update(kwargs_copy) + else: + raise ValueError('no docval found on %s' % str(func)) + return ret_args, ret_kwargs + + +# def _remove_extra_args(func, kwargs): +# """Return a dict of only the keyword arguments that are accepted by the function's docval. +# +# If the docval specifies allow_extra=True, then the original kwargs are returned. +# """ +# # NOTE: this has the same functionality as the to-be-deprecated fmt_docval_args except that +# # kwargs are kept as kwargs instead of parsed into args and kwargs +# func_docval = getattr(func, docval_attr_name, None) +# if func_docval: +# if func_docval['allow_extra']: +# # if extra args are allowed, return all args +# return kwargs +# else: +# # save only the arguments listed in the function's docval (skip any others present in kwargs) +# ret_kwargs = dict() +# for arg in func_docval[__docval_args_loc]: +# val = kwargs.get(arg['name'], None) +# if val is not None: # do not return arguments that are not present or have value None +# ret_kwargs[arg['name']] = val +# return ret_kwargs +# else: +# raise ValueError('No docval found on %s' % str(func)) + + +def call_docval_func(func, kwargs): + """Call the function with only the keyword arguments that are accepted by the function's docval. + + Extra keyword arguments are not passed to the function unless the function's docval has allow_extra=True. + """ + warnings.warn("call_docval_func will be deprecated in a future version of HDMF. Instead of using call_docval_func, " + "call the function directly with the kwargs. 
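
Since both helpers above now warn, a hedged sketch of the replacement the warnings recommend: filter kwargs against the function's docval yourself, then call the function directly. The example function and its arguments are illustrative, not HDMF API:

    from hdmf.utils import docval, get_docval, getargs

    @docval({'name': 'arg1', 'type': str, 'doc': 'a string'},
            {'name': 'arg2', 'type': int, 'doc': 'an int', 'default': 0},
            is_method=False)
    def example(**kwargs):
        return getargs('arg1', 'arg2', kwargs)

    kwargs = {'arg1': 'a', 'arg2': 1, 'extra': True}
    accepted = {spec['name'] for spec in get_docval(example)}
    filtered = {k: v for k, v in kwargs.items() if k in accepted}
    assert example(**filtered) == ['a', 1]   # extras dropped, direct call
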
Please note that call_docval_func " + "removes all arguments not accepted by the function's docval, so if you are passing kwargs that " + "includes extra arguments and the function's docval does not allow extra arguments (allow_extra=True " + "is set), then you will need to pop the extra arguments out of kwargs before calling the function.", + PendingDeprecationWarning, stacklevel=2) + with warnings.catch_warnings(record=True): + # catch and ignore only PendingDeprecationWarnings from fmt_docval_args so that two + # PendingDeprecationWarnings saying the same thing are not raised + warnings.simplefilter("ignore", UserWarning) + warnings.simplefilter("always", PendingDeprecationWarning) + fargs, fkwargs = fmt_docval_args(func, kwargs) + + return func(*fargs, **fkwargs) + + def __resolve_type(t): if t is None: return t diff --git a/tests/unit/common/test_table.py b/tests/unit/common/test_table.py index 15a0c9e91..38175b230 100644 --- a/tests/unit/common/test_table.py +++ b/tests/unit/common/test_table.py @@ -429,7 +429,9 @@ def test_add_column_vectorindex(self): table.add_column(name='qux', description='qux column') ind = VectorIndex(name='quux', data=list(), target=table['qux']) - with self.assertRaises(ValueError): + msg = ("Passing a VectorIndex in for index may lead to unexpected behavior. This functionality will be " + "deprecated in a future version of HDMF.") + with self.assertWarnsWith(FutureWarning, msg): table.add_column(name='bad', description='bad column', index=ind) def test_add_column_multi_index(self): diff --git a/tests/unit/test_container.py b/tests/unit/test_container.py index 2abe6349b..c12247de7 100644 --- a/tests/unit/test_container.py +++ b/tests/unit/test_container.py @@ -213,6 +213,18 @@ def test_all_children(self): obj = species.all_objects self.assertEqual(sorted(list(obj.keys())), sorted([species.object_id, species.id.object_id, col1.object_id])) + def test_add_child(self): + """Test that add child creates deprecation warning and also properly sets child's parent and modified + """ + parent_obj = Container('obj1') + child_obj = Container('obj2') + parent_obj.set_modified(False) + with self.assertWarnsWith(DeprecationWarning, 'add_child is deprecated. 
Set the parent attribute instead.'): + parent_obj.add_child(child_obj) + self.assertIs(child_obj.parent, parent_obj) + self.assertTrue(parent_obj.modified) + self.assertIs(parent_obj.children[0], child_obj) + def test_parent_set_link_warning(self): col1 = VectorData( name='col1', diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py new file mode 100644 index 000000000..b2ff267a7 --- /dev/null +++ b/tests/unit/test_query.py @@ -0,0 +1,161 @@ +import os +from abc import ABCMeta, abstractmethod + +import numpy as np +from h5py import File +from hdmf.array import SortedArray, LinSpace +from hdmf.query import HDMFDataset, Query +from hdmf.testing import TestCase + + +class AbstractQueryMixin(metaclass=ABCMeta): + + @abstractmethod + def getDataset(self): + raise NotImplementedError('Cannot run test unless getDataset is implemented') + + def setUp(self): + self.dset = self.getDataset() + self.wrapper = HDMFDataset(self.dset) + + def test_get_dataset(self): + array = self.wrapper.dataset + self.assertIsInstance(array, SortedArray) + + def test___gt__(self): + ''' + Test wrapper greater than magic method + ''' + q = self.wrapper > 5 + self.assertIsInstance(q, Query) + result = q.evaluate() + expected = [False, False, False, False, False, + False, True, True, True, True] + expected = slice(6, 10) + self.assertEqual(result, expected) + + def test___ge__(self): + ''' + Test wrapper greater than or equal magic method + ''' + q = self.wrapper >= 5 + self.assertIsInstance(q, Query) + result = q.evaluate() + expected = [False, False, False, False, False, + True, True, True, True, True] + expected = slice(5, 10) + self.assertEqual(result, expected) + + def test___lt__(self): + ''' + Test wrapper less than magic method + ''' + q = self.wrapper < 5 + self.assertIsInstance(q, Query) + result = q.evaluate() + expected = [True, True, True, True, True, + False, False, False, False, False] + expected = slice(0, 5) + self.assertEqual(result, expected) + + def test___le__(self): + ''' + Test wrapper less than or equal magic method + ''' + q = self.wrapper <= 5 + self.assertIsInstance(q, Query) + result = q.evaluate() + expected = [True, True, True, True, True, + True, False, False, False, False] + expected = slice(0, 6) + self.assertEqual(result, expected) + + def test___eq__(self): + ''' + Test wrapper equals magic method + ''' + q = self.wrapper == 5 + self.assertIsInstance(q, Query) + result = q.evaluate() + expected = [False, False, False, False, False, + True, False, False, False, False] + expected = 5 + self.assertTrue(np.array_equal(result, expected)) + + def test___ne__(self): + ''' + Test wrapper not equal magic method + ''' + q = self.wrapper != 5 + self.assertIsInstance(q, Query) + result = q.evaluate() + expected = [True, True, True, True, True, + False, True, True, True, True] + expected = [slice(0, 5), slice(6, 10)] + self.assertTrue(np.array_equal(result, expected)) + + def test___getitem__(self): + ''' + Test wrapper getitem using slice + ''' + result = self.wrapper[0:5] + expected = [0, 1, 2, 3, 4] + self.assertTrue(np.array_equal(result, expected)) + + def test___getitem__query(self): + ''' + Test wrapper getitem using query + ''' + q = self.wrapper < 5 + result = self.wrapper[q] + expected = [0, 1, 2, 3, 4] + self.assertTrue(np.array_equal(result, expected)) + + +class SortedQueryTest(AbstractQueryMixin, TestCase): + + path = 'SortedQueryTest.h5' + + def getDataset(self): + self.f = File(self.path, 'w') + self.input = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + self.d = 
self.f.create_dataset('dset', data=self.input) + return SortedArray(self.d) + + def tearDown(self): + self.f.close() + if os.path.exists(self.path): + os.remove(self.path) + + +class LinspaceQueryTest(AbstractQueryMixin, TestCase): + + path = 'LinspaceQueryTest.h5' + + def getDataset(self): + return LinSpace(0, 10, 1) + + +class CompoundQueryTest(TestCase): + + def getM(self): + return SortedArray(np.arange(10, 20, 1)) + + def getN(self): + return SortedArray(np.arange(10.0, 20.0, 0.5)) + + def setUp(self): + self.m = HDMFDataset(self.getM()) + self.n = HDMFDataset(self.getN()) + + # TODO: test not completed + # def test_map(self): + # q = self.m == (12, 16) # IN operation + # q.evaluate() # [2,3,4,5] + # q.evaluate(False) # RangeResult(2,6) + # r = self.m[q] # noqa: F841 + # r = self.m[q.evaluate()] # noqa: F841 + # r = self.m[q.evaluate(False)] # noqa: F841 + + def tearDown(self): + pass diff --git a/tests/unit/utils_test/test_core_DataIO.py b/tests/unit/utils_test/test_core_DataIO.py index 80518a316..4c2ffac15 100644 --- a/tests/unit/utils_test/test_core_DataIO.py +++ b/tests/unit/utils_test/test_core_DataIO.py @@ -1,8 +1,10 @@ from copy import copy, deepcopy import numpy as np +from hdmf.container import Data from hdmf.data_utils import DataIO from hdmf.testing import TestCase +import warnings class DataIOTests(TestCase): @@ -28,13 +30,34 @@ def test_dataio_slice_delegation(self): dset = DataIO(indata) self.assertTrue(np.all(dset[1:3, 5:8] == indata[1:3, 5:8])) - def test_set_data_io_data_already_set(self): + def test_set_dataio(self): + """ + Test that Data.set_dataio works as intended + """ + dataio = DataIO() + data = np.arange(30).reshape(5, 2, 3) + container = Data('wrapped_data', data) + msg = "Data.set_dataio() is deprecated. Please use Data.set_data_io() instead." + with self.assertWarnsWith(DeprecationWarning, msg): + container.set_dataio(dataio) + self.assertIs(dataio.data, data) + self.assertIs(dataio, container.data) + + def test_set_dataio_data_already_set(self): """ Test that Data.set_dataio works as intended """ dataio = DataIO(data=np.arange(30).reshape(5, 2, 3)) + data = np.arange(30).reshape(5, 2, 3) + container = Data('wrapped_data', data) with self.assertRaisesWith(ValueError, "cannot overwrite 'data' on DataIO"): - dataio.data=[1,2,3,4] + with warnings.catch_warnings(record=True): + warnings.filterwarnings( + action='ignore', + category=DeprecationWarning, + message="Data.set_dataio() is deprecated. Please use Data.set_data_io() instead.", + ) + container.set_dataio(dataio) def test_dataio_options(self): """ diff --git a/tests/unit/utils_test/test_docval.py b/tests/unit/utils_test/test_docval.py index bed5cd134..c766dcf46 100644 --- a/tests/unit/utils_test/test_docval.py +++ b/tests/unit/utils_test/test_docval.py @@ -1,7 +1,7 @@ import numpy as np from hdmf.testing import TestCase -from hdmf.utils import (docval, get_docval, getargs, popargs, AllowPositional, get_docval_macro, - docval_macro, popargs_to_dict) +from hdmf.utils import (docval, fmt_docval_args, get_docval, getargs, popargs, AllowPositional, get_docval_macro, + docval_macro, popargs_to_dict, call_docval_func) class MyTestClass(object): @@ -137,6 +137,80 @@ def method1(self, **kwargs): with self.assertRaises(ValueError): method1(self, arg1=[[1, 1, 1]]) + fmt_docval_warning_msg = ( + "fmt_docval_args will be deprecated in a future version of HDMF. Instead of using fmt_docval_args, " + "call the function directly with the kwargs. 
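
As the tests below exercise, fmt_docval_args splits a kwargs dict into positional args (docval entries without defaults) and keyword args (entries with defaults), dropping keys the docval does not accept unless allow_extra is set. A standalone sketch; the function here re-creates basic_add2_kw for illustration:

    import warnings
    from hdmf.utils import docval, fmt_docval_args

    @docval({'name': 'arg1', 'type': str, 'doc': 'a string'},
            {'name': 'arg2', 'type': int, 'doc': 'an int'},
            {'name': 'arg3', 'type': bool, 'doc': 'a bool', 'default': False},
            is_method=False)
    def basic_add2_kw(**kwargs):
        return kwargs

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', PendingDeprecationWarning)
        args, kw = fmt_docval_args(basic_add2_kw,
                                   {'arg1': 'a string', 'arg2': 1,
                                    'arg3': True, 'hello': 'abc'})
    assert args == ['a string', 1]   # args without defaults, in order
    assert kw == {'arg3': True}      # defaulted args passed by keyword
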
Please note that fmt_docval_args " + "removes all arguments not accepted by the function's docval, so if you are passing kwargs that " + "includes extra arguments and the function's docval does not allow extra arguments (allow_extra=True " + "is set), then you will need to pop the extra arguments out of kwargs before calling the function." + ) + + def test_fmt_docval_args(self): + """ Test that fmt_docval_args parses the args and strips extra args """ + test_kwargs = { + 'arg1': 'a string', + 'arg2': 1, + 'arg3': True, + 'hello': 'abc', + 'list': ['abc', 1, 2, 3] + } + with self.assertWarnsWith(PendingDeprecationWarning, self.fmt_docval_warning_msg): + rec_args, rec_kwargs = fmt_docval_args(self.test_obj.basic_add2_kw, test_kwargs) + exp_args = ['a string', 1] + self.assertListEqual(rec_args, exp_args) + exp_kwargs = {'arg3': True} + self.assertDictEqual(rec_kwargs, exp_kwargs) + + def test_fmt_docval_args_no_docval(self): + """ Test that fmt_docval_args raises an error when run on function without docval """ + def method1(self, **kwargs): + pass + + with self.assertRaisesRegex(ValueError, r"no docval found on .*method1.*"): + with self.assertWarnsWith(PendingDeprecationWarning, self.fmt_docval_warning_msg): + fmt_docval_args(method1, {}) + + def test_fmt_docval_args_allow_extra(self): + """ Test that fmt_docval_args works """ + test_kwargs = { + 'arg1': 'a string', + 'arg2': 1, + 'arg3': True, + 'hello': 'abc', + 'list': ['abc', 1, 2, 3] + } + with self.assertWarnsWith(PendingDeprecationWarning, self.fmt_docval_warning_msg): + rec_args, rec_kwargs = fmt_docval_args(self.test_obj.basic_add2_kw_allow_extra, test_kwargs) + exp_args = ['a string', 1] + self.assertListEqual(rec_args, exp_args) + exp_kwargs = {'arg3': True, 'hello': 'abc', 'list': ['abc', 1, 2, 3]} + self.assertDictEqual(rec_kwargs, exp_kwargs) + + def test_call_docval_func(self): + """Test that call_docval_func strips extra args and calls the function.""" + test_kwargs = { + 'arg1': 'a string', + 'arg2': 1, + 'arg3': True, + 'hello': 'abc', + 'list': ['abc', 1, 2, 3] + } + msg = ( + "call_docval_func will be deprecated in a future version of HDMF. Instead of using call_docval_func, " + "call the function directly with the kwargs. Please note that call_docval_func " + "removes all arguments not accepted by the function's docval, so if you are passing kwargs that " + "includes extra arguments and the function's docval does not allow extra arguments (allow_extra=True " + "is set), then you will need to pop the extra arguments out of kwargs before calling the function." + ) + with self.assertWarnsWith(PendingDeprecationWarning, msg): + ret_kwargs = call_docval_func(self.test_obj.basic_add2_kw, test_kwargs) + exp_kwargs = { + 'arg1': 'a string', + 'arg2': 1, + 'arg3': True + } + self.assertDictEqual(ret_kwargs, exp_kwargs) + def test_docval_add(self): """Test that docval works with a single positional argument