Add new embedding notebook backed by sqlite database.
PiperOrigin-RevId: 676929217
Chirp Team authored and copybara-github committed Sep 20, 2024
1 parent 14e610f commit 3b72c67
Showing 4 changed files with 340 additions and 8 deletions.
201 changes: 201 additions & 0 deletions chirp/projects/agile2/1_embed_audio_v2.ipynb
@@ -0,0 +1,201 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GTtVnkC-6_i7"
},
"outputs": [],
"source": [
"#@title Imports. { vertical-output: true }\n",
"\n",
"import dataclasses\n",
"import functools\n",
"import sqlite3\n",
"import os\n",
"from typing import Callable\n",
"import numpy as np\n",
"from concurrent import futures\n",
"import tqdm\n",
"import time\n",
"from scipy import stats\n",
"from matplotlib import pyplot as plt\n",
"from ml_collections import config_dict\n",
"from IPython.display import display\n",
"import ipywidgets as widgets\n",
"from chirp import audio_utils\n",
"from chirp.projects.agile2 import colab_utils\n",
"from chirp.projects.agile2 import embed\n",
"from chirp.projects.hoplite import interface"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4T4vILrO80iP"
},
"source": [
"## Embed"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c6zdGxl68vft"
},
"outputs": [],
"source": [
"#@title Configuration { vertical-output: true }\n",
"\n",
"# Configure the raw dataset location(s). The format is a mapping from a\n",
"# dataset_name to a (path, fileglob) pair. Note that the file globs are case\n",
"# sensitive. The dataset name can be anything you want.\n",
"#\n",
"# This structure allows you to move your data around without having to re-embed\n",
"# the dataset. The generated embedding database will be placed next to the\n",
"# audio files. This allows you to simply swap out the base path here if you ever\n",
"# move your dataset.\n",
"audio_globs = {\n",
" 'dataset_1':\n",
" ('/path/to/dataset/1', '*.WAV',),\n",
" 'dataset_2':\n",
" ('/path/to/dataset/2', '*/*.mp4',),\n",
"}\n",
"\n",
"# By default we only process one dataset at a time. Re-run this entire notebook\n",
"# once per dataset. The embeddings database will be located in the same\n",
"# directory as the raw audio\n",
"dataset_name = 'dataset_1' #@param\n",
"\n",
"if dataset_name not in audio_globs:\n",
" raise ValueError(f'Dataset {dataset_name} not found in audio_globs')\n",
"\n",
"globs_to_process = {dataset_name: audio_globs[dataset_name]}\n",
"\n",
"# You do not need to change this unless you want to maintain multiple distinct\n",
"# embedding databases.\n",
"db_path = None\n",
"configs = colab_utils.load_configs(globs_to_process, db_path)\n",
"configs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NN9Uyy1yqAWS"
},
"outputs": [],
"source": [
"#@title Initialize the DB { vertical-output: true }\n",
"global db\n",
"db = configs.db_config.load_db()\n",
"db.setup()\n",
"num_embeddings = db.count_embeddings()\n",
"\n",
"print('Initialized DB located at ', configs.db_config.db_config.db_path)\n",
"print('Existing DB contains datasets: ', db.get_dataset_names())\n",
"print('num embeddings: ', num_embeddings)\n",
"\n",
"def drop_and_reload_db(_) -\u003e interface.GraphSearchDBInterface:\n",
" os.unlink(configs.db_config.db_config.db_path)\n",
" print('\\n Deleted previous db at: ', configs.db_config.db_config.db_path)\n",
" db = configs.db_config.load_db()\n",
" db.setup()\n",
"\n",
"drop_existing_db = 'True' #@param['True', 'False']\n",
"\n",
"if num_embeddings \u003e 0 and drop_existing_db == 'True':\n",
" print(f'\\n\\nClick the button below to confirm you really want to drop the database at ')\n",
" print(f'{configs.db_config.db_config.db_path}\\n')\n",
" print(f'This will permanently delete all {num_embeddings} embeddings from the existing database.\\n')\n",
" print('If you do NOT want to delete this data, set `drop_existing_db` above to `False` and re-run this cell.\\n')\n",
"\n",
" button = widgets.Button(description=f'Delete database?')\n",
" button.on_click(drop_and_reload_db)\n",
" display(button)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MnGWbhc0LhiU"
},
"outputs": [],
"source": [
"#@title Run the embedding { vertical-output: true }\n",
"\n",
"# If the DB already exists, we need to make sure that the the current\n",
"# model_config is compatible with the model_config that was used previously.\n",
"colab_utils.validate_and_save_configs(configs, db)\n",
"\n",
"print(f'Datasets requested to embed: {[key for key in globs_to_process]}')\n",
"\n",
"# Avoid re-embedding datasets that are already present in the DB\n",
"# TODO(roblaber) Make this filtering more granular, ie, avoid re-embedding\n",
"# (dataset, filename) pairs\n",
"for dataset in db.get_dataset_names():\n",
" if dataset in globs_to_process:\n",
" globs_to_process.pop(dataset)\n",
" print(f'\\nDataset \\'{dataset}\\' already present in DB, not re-embedding')\n",
"\n",
"new_datasets = [key for key in globs_to_process]\n",
"\n",
"print(f'\\nNew datasets to embed: {new_datasets}')\n",
"print(f'\\nPreparing to embed {len(new_datasets)} datasets...\\n')\n",
"\n",
"worker = embed.EmbedWorker(\n",
" embed_config=configs.audio_sources_config,\n",
" db=db,\n",
" model_config=configs.model_config)\n",
"\n",
"worker.process_all()\n",
"\n",
"print('\\n\\nEmbedding complete, total embeddings: ', db.count_embeddings())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HvVuFw-somHe"
},
"outputs": [],
"source": [
"#@title Per dataset statistics { vertical-output: true }\n",
"\n",
"for dataset in db.get_dataset_names():\n",
" print(f'\\nDataset \\'{dataset}\\':')\n",
" print('\\tnum embeddings: ', db.get_embeddings_by_source(dataset, source_id=None).shape[0])"
]
}
],
"metadata": {
"colab": {
"last_runtime": {
"build_target": "",
"kind": "local"
},
"name": "v2_1_embed_unlabeled_audio.ipynb",
"private_outputs": true,
"provenance": [
{
"file_id": "1ePT3-fDB3kA3_T7trthFtu8xTJQWQBoQ",
"timestamp": 1723499538314
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
133 changes: 133 additions & 0 deletions chirp/projects/agile2/colab_utils.py
@@ -0,0 +1,133 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for sqlite-backed Agile modeling notebooks."""

import dataclasses

from chirp.projects.agile2 import embed
from chirp.projects.hoplite import db_loader
from chirp.projects.hoplite import interface
from etils import epath
from ml_collections import config_dict


def supported_models() -> dict[str, dict[str, int | embed.ModelConfig]]:
"""Returns the set of supported models and their corresponding configs."""
return {
'perch': {
'embedding_dim': 1280,
'model_config': embed.ModelConfig(
model_key='taxonomy_model_tf',
model_config=config_dict.ConfigDict({
'tfhub_version': 8,
'window_size_s': 5.0,
'hop_size_s': 5.0,
'model_path': '',
'sample_rate': 32000,
}),
),
},
}


@dataclasses.dataclass
class AgileConfigs:
"""Container for the various configs used in the Agile notebooks."""

# Config for the raw audio sources.
audio_sources_config: embed.EmbedConfig
# Database config for the embeddings database.
db_config: db_loader.DBConfig
# Config for the embedding model.
model_config: embed.ModelConfig

def as_config_dict(self) -> config_dict.ConfigDict:
"""Returns the configs as a ConfigDict."""
return config_dict.ConfigDict({
'audio_sources_config': self.audio_sources_config.to_config_dict(),
'db_config': self.db_config.to_config_dict(),
'model_config': self.model_config.to_config_dict(),
})


def validate_and_save_configs(
configs: AgileConfigs,
db: interface.GraphSearchDBInterface,
):
"""Validates that the model config is compatible with the DB."""

model_config = configs.model_config
db_metadata = db.get_metadata(None)
if 'model_config' in db_metadata:
if db_metadata['model_config'].model_key != model_config.model_key:
raise AssertionError(
'The configured embedding model does not match the embedding model'
' that is already in the DB. You either need to drop the database or'
" use the '%s' model confg."
% db_metadata['model_config'].model_key
)

db.insert_metadata('model_config', model_config.to_config_dict())
db.insert_metadata(
'embed_config', configs.audio_sources_config.to_config_dict()
)
db.commit()


def load_configs(
audio_globs: dict[str, tuple[str, str]],
db_path: str,
model_config_key: str = 'perch',
) -> AgileConfigs:
"""Load default configs for the notebook and return them as an AgileConfigs.
Args:
audio_globs: Mapping from dataset name to pairs of `(root directory, file
glob)`.
db_path: Location of the database. If None, the database will be created in
the same directory as the audio.
model_config_key: Name of the embedding model to use.
"""
if model_config_key not in supported_models():
# TODO(roblaber): Add support for other models.
raise ValueError(f'Unsupported model: {model_config_key}')

if db_path is None:
if len(audio_globs) > 1:
raise ValueError(
'db_path must be specified when embedding multiple datasets.'
)
# Put the DB in the same directory as the audio.
db_path = (
epath.Path(next(iter(audio_globs.values()))[0]) / 'hoplite_db.sqlite'
)

model_config = supported_models()[model_config_key]
db_config = config_dict.ConfigDict({
'db_path': db_path,
'embedding_dim': model_config['embedding_dim'],
})

audio_srcs_config = embed.EmbedConfig(
audio_globs=audio_globs,
min_audio_len_s=1.0,
)

return AgileConfigs(
audio_sources_config=audio_srcs_config,
db_config=db_loader.DBConfig('sqlite', db_config),
model_config=model_config['model_config'],
)
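For orientation, a minimal sketch of how the pieces added in colab_utils.py fit together outside the notebook. This is not part of the diff; the dataset name and audio path are hypothetical, and the call sequence mirrors the notebook cells above.

from chirp.projects.agile2 import colab_utils
from chirp.projects.agile2 import embed

# Hypothetical dataset: one directory of wav files.
audio_globs = {'my_dataset': ('/path/to/my_dataset', '*.wav')}

# With db_path=None, the sqlite DB is created next to the audio.
configs = colab_utils.load_configs(audio_globs, db_path=None)

# Open (or create) the DB, check the configured model against any model_config
# already stored in it, and persist the current configs.
db = configs.db_config.load_db()
db.setup()
colab_utils.validate_and_save_configs(configs, db)

# Embed everything matched by the globs.
worker = embed.EmbedWorker(
    embed_config=configs.audio_sources_config,
    db=db,
    model_config=configs.model_config,
)
worker.process_all()
print('Total embeddings:', db.count_embeddings())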
8 changes: 2 additions & 6 deletions chirp/projects/agile2/embed.py
@@ -69,21 +69,17 @@ def __init__(
       self,
       embed_config: EmbedConfig,
       model_config: ModelConfig,
-      db_config: db_loader.DBConfig,
+      db: hoplite_interface.GraphSearchDBInterface,
       embedding_model: zoo_interface.EmbeddingModel | None = None,
   ):
+    self.db = db
     self.model_config = model_config
     self.embed_config = embed_config
     if embedding_model is None:
       model_class = models.model_class_map()[model_config.model_key]
       self.embedding_model = model_class.from_config(model_config.model_config)
     else:
       self.embedding_model = embedding_model
-    # TODO(tomdenton): Check that the DB's model config matches ours.
-    self.db = db_config.load_db()
-    self.db.setup()
-    self.db.insert_metadata('embed_config', embed_config.to_config_dict())
-    self.db.insert_metadata('model_config', model_config.to_config_dict())
 
   def _log_error(self, source_id, exception, counter_name):
     logging.warning(
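Note on the constructor change above: EmbedWorker no longer opens the database or writes the embed/model config metadata itself. Callers now pass in a ready GraphSearchDBInterface and handle metadata separately (for example via colab_utils.validate_and_save_configs). A minimal sketch of the new call pattern, assuming embed_config, model_config, and db_config objects are already constructed; the db.setup() call follows the notebook's usage:

db = db_config.load_db()
db.setup()
worker = embed.EmbedWorker(
    embed_config=embed_config,
    model_config=model_config,
    db=db,
)
worker.process_all()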
6 changes: 4 additions & 2 deletions chirp/projects/agile2/tests/embed_test.py
@@ -67,15 +67,17 @@ def test_embed_worker(self):
         model_config=placeholder_model_config,
     )
 
+    db = db_config.load_db()
+
     embed_worker = embed.EmbedWorker(
         embed_config=embed_config,
         model_config=model_config,
-        db_config=db_config,
+        db=db,
     )
     embed_worker.process_all()
     # The hop size is 1.0s and each file is 6.0s, so we should get 6 embeddings
     # per file. There are six files, so we should get 36 embeddings.
-    self.assertEqual(embed_worker.db.count_embeddings(), 36)
+    self.assertEqual(db.count_embeddings(), 36)
 
 
 if __name__ == '__main__':
