Skip to content

Commit

Permalink
Add column for storing visualization embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtritt committed Jan 7, 2025
1 parent b5923f7 commit 01094d9
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 0 deletions.
32 changes: 32 additions & 0 deletions src/hdmf_ai/results_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,3 +454,35 @@ def add_embedding(self, **kwargs):
# `n_dims` kwarg is passed into `__add_col` and will be read as the length of the second dimension
# of the data only if the `data` kwarg is None.
return self.__add_col(EmbeddedValues, **kwargs)

@docval(
{
"name": "data",
"type": ("array_data", "data"),
"doc": "Embedding (float) of each sample.",
"default": None,
},
{
"name": "description",
"type": str,
"doc": "A description for this column.",
"default": "A column to store embeddings, e.g., from dimensionality reduction, for each sample.",
},
{
"name": "n_dims",
"type": int,
"doc": (
"The number of dimensions in the embedding, "
"used to define the shape of the column only if data is None"
),
"default": None,
},
)
def add_viz_embedding(self, **kwargs):
"""Add embedding (a.k.a. transformation or representation) of each sample."""
kwargs["name"] = "viz_embedding"
kwargs["dtype"] = float
kwargs["dim2_kwarg"] = "n_dims"
# `n_dims` kwarg is passed into `__add_col` and will be read as the length of the second dimension
# of the data only if the `data` kwarg is None.
return self.__add_col(EmbeddedValues, **kwargs)
4 changes: 4 additions & 0 deletions src/hdmf_ai/schema/results_table.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -185,4 +185,8 @@ groups:
data_type_inc: EmbeddedValues
doc: A column to store embeddings, e.g., from dimensionality reduction, for each sample.
quantity: '?'
- name: viz_embedding
data_type_inc: EmbeddedValues
doc: A column to store embeddings meant for visualization.
quantity: '?'

6 changes: 6 additions & 0 deletions tests/test_results_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ def test_add_embedding(self):
with self.get_hdf5io() as io:
io.write(rt)

def test_add_viz_embedding(self):
rt = ResultsTable(name="foo", description="a test results table")
rt.add_viz_embedding([[1.1, 2.9], [1.2, 2.8], [1.3, 2.7], [1.4, 2.6], [1.5, 2.5]])
with self.get_hdf5io() as io:
io.write(rt)

def test_add_topk_classes(self):
rt = ResultsTable(name="foo", description="a test results table")
rt.add_topk_classes([[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]])
Expand Down

0 comments on commit 01094d9

Please sign in to comment.