diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c72205cd..473be3100 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ ### Added - Added script to check Python version support for HDMF dependencies. @rly [#1230](https://github.com/hdmf-dev/hdmf/pull/1230) +### Fixed +- Fixed issue with `DynamicTable.add_column` not allowing subclasses of `DynamicTableRegion` or `EnumData`. @rly [#1091](https://github.com/hdmf-dev/hdmf/pull/1091) + ## HDMF 3.14.6 (December 20, 2024) ### Enhancements diff --git a/src/hdmf/common/io/table.py b/src/hdmf/common/io/table.py index 50395ba24..379553c07 100644 --- a/src/hdmf/common/io/table.py +++ b/src/hdmf/common/io/table.py @@ -78,12 +78,11 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i required=field_spec.required ) dtype = cls._get_type(field_spec, type_map) + column_conf['class'] = dtype if issubclass(dtype, DynamicTableRegion): # the spec does not know which table this DTR points to # the user must specify the table attribute on the DTR after it is generated column_conf['table'] = True - else: - column_conf['class'] = dtype index_counter = 0 index_name = attr_name diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 84ac4da3b..2f6401672 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -521,7 +521,7 @@ def _init_class_columns(self): description=col['description'], index=col.get('index', False), table=col.get('table', False), - col_cls=col.get('class', VectorData), + col_cls=col.get('class'), # Pass through extra kwargs for add_column that subclasses may have added **{k: col[k] for k in col.keys() if k not in DynamicTable.__reserved_colspec_keys}) @@ -564,10 +564,13 @@ def _set_dtr_targets(self, target_tables: dict): if not column_conf.get('table', False): raise ValueError("Column '%s' must be a DynamicTableRegion to have a target table." 
% colname) - self.add_column(name=column_conf['name'], - description=column_conf['description'], - index=column_conf.get('index', False), - table=True) + self.add_column( + name=column_conf['name'], + description=column_conf['description'], + index=column_conf.get('index', False), + table=True, + col_cls=column_conf.get('class'), + ) if isinstance(self[colname], VectorIndex): col = self[colname].target else: @@ -681,7 +684,7 @@ def add_row(self, **kwargs): index=col.get('index', False), table=col.get('table', False), enum=col.get('enum', False), - col_cls=col.get('class', VectorData), + col_cls=col.get('class'), # Pass through extra keyword arguments for add_column that # subclasses may have added **{k: col[k] for k in col.keys() @@ -753,7 +756,7 @@ def __eq__(self, other): 'default': False}, {'name': 'enum', 'type': (bool, 'array_data'), 'default': False, 'doc': ('whether or not this column contains data from a fixed set of elements')}, - {'name': 'col_cls', 'type': type, 'default': VectorData, + {'name': 'col_cls', 'type': type, 'default': None, 'doc': ('class to use to represent the column data. If table=True, this field is ignored and a ' 'DynamicTableRegion object is used. If enum=True, this field is ignored and a EnumData ' 'object is used.')}, @@ -805,29 +808,39 @@ def add_column(self, **kwargs): # noqa: C901 % (name, self.__class__.__name__, spec_index)) warn(msg, stacklevel=3) - spec_col_cls = self.__uninit_cols[name].get('class', VectorData) - if col_cls != spec_col_cls: - msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered " - "col_cls argument. The predefined class spec will be ignored. " - "Please ensure the new column complies with the spec. " - "This will raise an error in a future version of HDMF." 
- % (name, self.__class__.__name__, spec_col_cls)) - warn(msg, stacklevel=2) - ckwargs = dict(kwargs) # Add table if it's been specified if table and enum: raise ValueError("column '%s' cannot be both a table region " "and come from an enumerable set of elements" % name) + # Update col_cls if table is specified if table is not False: - col_cls = DynamicTableRegion + if col_cls is None: + col_cls = DynamicTableRegion if isinstance(table, DynamicTable): ckwargs['table'] = table + # Update col_cls if enum is specified if enum is not False: - col_cls = EnumData + if col_cls is None: + col_cls = EnumData if isinstance(enum, (list, tuple, np.ndarray, VectorData)): ckwargs['elements'] = enum + # Update col_cls to the default VectorData if col_cls is None + if col_cls is None: + col_cls = VectorData + + if name in self.__uninit_cols: # column is a predefined optional column from the spec + # check the given values against the predefined optional column spec. if they do not match, raise a warning + # but use the given arguments; the predefined spec is ignored. users should not override these values + spec_col_cls = self.__uninit_cols[name].get('class') + if spec_col_cls is not None and col_cls != spec_col_cls: + msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered " + "col_cls argument. The predefined class spec will be ignored. " + "Please ensure the new column complies with the spec. " + "This will raise an error in a future version of HDMF." 
+ % (name, self.__class__.__name__, spec_col_cls)) + warn(msg, stacklevel=2) # If the user provided a list of lists that needs to be indexed, then we now need to flatten the data # We can only create the index actual VectorIndex once we have the VectorData column so we compute @@ -873,7 +886,7 @@ def add_column(self, **kwargs): # noqa: C901 if col in self.__uninit_cols: self.__uninit_cols.pop(col) - if col_cls is EnumData: + if issubclass(col_cls, EnumData): columns.append(col.elements) col.elements.parent = self diff --git a/tests/unit/common/test_generate_table.py b/tests/unit/common/test_generate_table.py index 7f7d7da40..71c15aad0 100644 --- a/tests/unit/common/test_generate_table.py +++ b/tests/unit/common/test_generate_table.py @@ -16,6 +16,13 @@ class TestDynamicDynamicTable(TestCase): def setUp(self): + + self.dtr_spec = DatasetSpec( + data_type_def='CustomDTR', + data_type_inc='DynamicTableRegion', + doc='a test DynamicTableRegion column', # this is overridden where it is used + ) + self.dt_spec = GroupSpec( 'A test extension that contains a dynamic table', data_type_def='TestTable', @@ -99,7 +106,13 @@ def setUp(self): doc='a test column', dtype='float', quantity='?', - ) + ), + DatasetSpec( + data_type_inc='CustomDTR', + name='optional_custom_dtr_col', + doc='a test DynamicTableRegion column', + quantity='?' 
+ ), ] ) @@ -107,6 +120,7 @@ def setUp(self): writer = YAMLSpecWriter(outdir='.') self.spec_catalog = SpecCatalog() + self.spec_catalog.register_spec(self.dtr_spec, 'test.yaml') self.spec_catalog.register_spec(self.dt_spec, 'test.yaml') self.spec_catalog.register_spec(self.dt_spec2, 'test.yaml') self.namespace = SpecNamespace( @@ -124,7 +138,7 @@ def setUp(self): self.test_dir = tempfile.mkdtemp() spec_fpath = os.path.join(self.test_dir, 'test.yaml') namespace_fpath = os.path.join(self.test_dir, 'test-namespace.yaml') - writer.write_spec(dict(groups=[self.dt_spec, self.dt_spec2]), spec_fpath) + writer.write_spec(dict(datasets=[self.dtr_spec], groups=[self.dt_spec, self.dt_spec2]), spec_fpath) writer.write_namespace(self.namespace, namespace_fpath) self.namespace_catalog = NamespaceCatalog() hdmf_typemap = get_type_map() @@ -133,6 +147,7 @@ def setUp(self): self.type_map.load_namespaces(namespace_fpath) self.manager = BuildManager(self.type_map) + self.CustomDTR = self.type_map.get_dt_container_cls('CustomDTR', CORE_NAMESPACE) self.TestTable = self.type_map.get_dt_container_cls('TestTable', CORE_NAMESPACE) self.TestDTRTable = self.type_map.get_dt_container_cls('TestDTRTable', CORE_NAMESPACE) @@ -228,6 +243,22 @@ def test_dynamic_table_region_non_dtr_target(self): self.TestDTRTable(name='test_dtr_table', description='my table', target_tables={'optional_col3': test_table}) + def test_custom_dtr_class(self): + test_table = self.TestTable(name='test_table', description='my test table') + test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5) + test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5) + + test_dtr_table = self.TestDTRTable(name='test_dtr_table', description='my table', + target_tables={'optional_custom_dtr_col': test_table}) + + self.assertIsInstance(test_dtr_table['optional_custom_dtr_col'], self.CustomDTR) + self.assertEqual(test_dtr_table['optional_custom_dtr_col'].description, "a test DynamicTableRegion column") + 
self.assertIs(test_dtr_table['optional_custom_dtr_col'].table, test_table) + + test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_custom_dtr_col=0) + test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_custom_dtr_col=1) + self.assertEqual(test_dtr_table['optional_custom_dtr_col'].data, [0, 1]) + def test_attribute(self): test_table = self.TestTable(name='test_table', description='my test table') assert test_table.my_col is not None @@ -266,3 +297,17 @@ def test_roundtrip(self): for err in errors: raise Exception(err) self.reader.close() + + def test_add_custom_dtr_column(self): + test_table = self.TestTable(name='test_table', description='my test table') + test_table.add_column( + name='custom_dtr_column', + description='this is a custom DynamicTableRegion column', + col_cls=self.CustomDTR, + ) + self.assertIsInstance(test_table['custom_dtr_column'], self.CustomDTR) + self.assertEqual(test_table['custom_dtr_column'].description, 'this is a custom DynamicTableRegion column') + + test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], custom_dtr_column=0) + test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], custom_dtr_column=1) + self.assertEqual(test_table['custom_dtr_column'].data, [0, 1])