Skip to content

Commit

Permalink
Merge pull request #20 from hdmf-dev/enh/vizemb
Browse files Browse the repository at this point in the history
Add column for storing visualization embeddings
  • Loading branch information
ajtritt authored Jan 7, 2025
2 parents b5923f7 + d2556d1 commit a9253f8
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 8 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ jobs:
fail-fast: false
matrix:
include:
- { name: linux-python3.7-minimum , requirements: minimum, python-ver: "3.7" , os: ubuntu-latest }
- { name: linux-python3.12 , requirements: pinned , python-ver: "3.12", os: ubuntu-latest }
- { name: windows-python3.7-minimum , requirements: minimum, python-ver: "3.7" , os: windows-latest }
- { name: windows-python3.12 , requirements: pinned , python-ver: "3.12", os: windows-latest }
- { name: macos-python3.7-minimum , requirements: minimum, python-ver: "3.7" , os: macos-13 }
- { name: macos-python3.12 , requirements: pinned , python-ver: "3.12", os: macos-latest }
- { name: linux-python3.10-minimum , requirements: minimum, python-ver: "3.10", os: ubuntu-latest }
- { name: linux-python3.13 , requirements: pinned , python-ver: "3.13", os: ubuntu-latest }
- { name: windows-python3.10-minimum , requirements: minimum, python-ver: "3.10", os: windows-latest }
- { name: windows-python3.13 , requirements: pinned , python-ver: "3.13", os: windows-latest }
- { name: macos-python3.10-minimum , requirements: minimum, python-ver: "3.10", os: macos-13 }
- { name: macos-python3.13 , requirements: pinned , python-ver: "3.13", os: macos-latest }
steps:
- name: Checkout repo
uses: actions/checkout@v4
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies = [
"numpy>=1.21",
"scikit-learn>=1",
]
version = "0.2.0"
version = "0.3.0"
# dynamic = ["version"]

[project.urls]
Expand Down Expand Up @@ -108,4 +108,4 @@ exclude = [
"example.py" = ["E501", "T201"]

[tool.ruff.lint.mccabe]
max-complexity = 17
max-complexity = 17
32 changes: 32 additions & 0 deletions src/hdmf_ai/results_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,3 +454,35 @@ def add_embedding(self, **kwargs):
# `n_dims` kwarg is passed into `__add_col` and will be read as the length of the second dimension
# of the data only if the `data` kwarg is None.
return self.__add_col(EmbeddedValues, **kwargs)

@docval(
{
"name": "data",
"type": ("array_data", "data"),
"doc": "Embedding (float) of each sample.",
"default": None,
},
{
"name": "description",
"type": str,
"doc": "A description for this column.",
"default": "A column to store embeddings, e.g., from dimensionality reduction, for each sample.",
},
{
"name": "n_dims",
"type": int,
"doc": (
"The number of dimensions in the embedding, "
"used to define the shape of the column only if data is None"
),
"default": None,
},
)
def add_viz_embedding(self, **kwargs):
"""Add embedding (a.k.a. transformation or representation) of each sample."""
kwargs["name"] = "viz_embedding"
kwargs["dtype"] = float
kwargs["dim2_kwarg"] = "n_dims"
# `n_dims` kwarg is passed into `__add_col` and will be read as the length of the second dimension
# of the data only if the `data` kwarg is None.
return self.__add_col(EmbeddedValues, **kwargs)
4 changes: 4 additions & 0 deletions src/hdmf_ai/schema/results_table.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -185,4 +185,8 @@ groups:
data_type_inc: EmbeddedValues
doc: A column to store embeddings, e.g., from dimensionality reduction, for each sample.
quantity: '?'
- name: viz_embedding
data_type_inc: EmbeddedValues
doc: A column to store embeddings meant for visualization.
quantity: '?'

6 changes: 6 additions & 0 deletions tests/test_results_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ def test_add_embedding(self):
with self.get_hdf5io() as io:
io.write(rt)

def test_add_viz_embedding(self):
rt = ResultsTable(name="foo", description="a test results table")
rt.add_viz_embedding([[1.1, 2.9], [1.2, 2.8], [1.3, 2.7], [1.4, 2.6], [1.5, 2.5]])
with self.get_hdf5io() as io:
io.write(rt)

def test_add_topk_classes(self):
rt = ResultsTable(name="foo", description="a test results table")
rt.add_topk_classes([[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]])
Expand Down

0 comments on commit a9253f8

Please sign in to comment.