diff --git a/src/hdmf_ai/results_table.py b/src/hdmf_ai/results_table.py index 042f742..72f829f 100644 --- a/src/hdmf_ai/results_table.py +++ b/src/hdmf_ai/results_table.py @@ -454,3 +454,35 @@ def add_embedding(self, **kwargs): # `n_dims` kwarg is passed into `__add_col` and will be read as the length of the second dimension # of the data only if the `data` kwarg is None. return self.__add_col(EmbeddedValues, **kwargs) + + @docval( + { + "name": "data", + "type": ("array_data", "data"), + "doc": "Embedding (float) of each sample.", + "default": None, + }, + { + "name": "description", + "type": str, + "doc": "A description for this column.", + "default": "A column to store embeddings, e.g., from dimensionality reduction, for each sample.", + }, + { + "name": "n_dims", + "type": int, + "doc": ( + "The number of dimensions in the embedding, " + "used to define the shape of the column only if data is None" + ), + "default": None, + }, + ) + def add_viz_embedding(self, **kwargs): + """Add embedding (a.k.a. transformation or representation) of each sample.""" + kwargs["name"] = "viz_embedding" + kwargs["dtype"] = float + kwargs["dim2_kwarg"] = "n_dims" + # `n_dims` kwarg is passed into `__add_col` and will be read as the length of the second dimension + # of the data only if the `data` kwarg is None. + return self.__add_col(EmbeddedValues, **kwargs) diff --git a/src/hdmf_ai/schema/results_table.yaml b/src/hdmf_ai/schema/results_table.yaml index 2e6b859..9c3d135 100644 --- a/src/hdmf_ai/schema/results_table.yaml +++ b/src/hdmf_ai/schema/results_table.yaml @@ -185,4 +185,8 @@ groups: data_type_inc: EmbeddedValues doc: A column to store embeddings, e.g., from dimensionality reduction, for each sample. quantity: '?' + - name: viz_embedding + data_type_inc: EmbeddedValues + doc: A column to store embeddings meant for visualization. + quantity: '?' diff --git a/tests/test_results_table.py b/tests/test_results_table.py index a843eae..ed99d12 100644 --- a/tests/test_results_table.py +++ b/tests/test_results_table.py @@ -122,6 +122,12 @@ def test_add_embedding(self): with self.get_hdf5io() as io: io.write(rt) + def test_add_viz_embedding(self): + rt = ResultsTable(name="foo", description="a test results table") + rt.add_viz_embedding([[1.1, 2.9], [1.2, 2.8], [1.3, 2.7], [1.4, 2.6], [1.5, 2.5]]) + with self.get_hdf5io() as io: + io.write(rt) + def test_add_topk_classes(self): rt = ResultsTable(name="foo", description="a test results table") rt.add_topk_classes([[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]])