-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add new embedding notebook backed by sqlite database.
PiperOrigin-RevId: 676929217
- Loading branch information
1 parent
14e610f
commit 3b72c67
Showing
4 changed files
with
340 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "GTtVnkC-6_i7" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Imports. { vertical-output: true }\n", | ||
"\n", | ||
"import dataclasses\n", | ||
"import functools\n", | ||
"import sqlite3\n", | ||
"import os\n", | ||
"from typing import Callable\n", | ||
"import numpy as np\n", | ||
"from concurrent import futures\n", | ||
"import tqdm\n", | ||
"import time\n", | ||
"from scipy import stats\n", | ||
"from matplotlib import pyplot as plt\n", | ||
"from ml_collections import config_dict\n", | ||
"from IPython.display import display\n", | ||
"import ipywidgets as widgets\n", | ||
"from chirp import audio_utils\n", | ||
"from chirp.projects.agile2 import colab_utils\n", | ||
"from chirp.projects.agile2 import embed\n", | ||
"from chirp.projects.hoplite import interface" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "4T4vILrO80iP" | ||
}, | ||
"source": [ | ||
"## Embed" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "c6zdGxl68vft" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Configuration { vertical-output: true }\n", | ||
"\n", | ||
"# Configure the raw dataset location(s). The format is a mapping from a\n", | ||
"# dataset_name to a (path, fileglob) pair. Note that the file globs are case\n", | ||
"# sensitive. The dataset name can be anything you want.\n", | ||
"#\n", | ||
"# This structure allows you to move your data around without having to re-embed\n", | ||
"# the dataset. The generated embedding database will be placed next to the\n", | ||
"# audio files. This allows you to simply swap out the base path here if you ever\n", | ||
"# move your dataset.\n", | ||
"audio_globs = {\n", | ||
" 'dataset_1':\n", | ||
" ('/path/to/dataset/1', '*.WAV',),\n", | ||
" 'dataset_2':\n", | ||
" ('/path/to/dataset/2', '*/*.mp4',),\n", | ||
"}\n", | ||
"\n", | ||
"# By default we only process one dataset at a time. Re-run this entire notebook\n", | ||
"# once per dataset. The embeddings database will be located in the same\n", | ||
"# directory as the raw audio\n", | ||
"dataset_name = 'dataset_1' #@param\n", | ||
"\n", | ||
"if dataset_name not in audio_globs:\n", | ||
" raise ValueError(f'Dataset {dataset_name} not found in audio_globs')\n", | ||
"\n", | ||
"globs_to_process = {dataset_name: audio_globs[dataset_name]}\n", | ||
"\n", | ||
"# You do not need to change this unless you want to maintain multiple distinct\n", | ||
"# embedding databases.\n", | ||
"db_path = None\n", | ||
"configs = colab_utils.load_configs(globs_to_process, db_path)\n", | ||
"configs" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "NN9Uyy1yqAWS" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Initialize the DB { vertical-output: true }\n", | ||
"global db\n", | ||
"db = configs.db_config.load_db()\n", | ||
"db.setup()\n", | ||
"num_embeddings = db.count_embeddings()\n", | ||
"\n", | ||
"print('Initialized DB located at ', configs.db_config.db_config.db_path)\n", | ||
"print('Existing DB contains datasets: ', db.get_dataset_names())\n", | ||
"print('num embeddings: ', num_embeddings)\n", | ||
"\n", | ||
"def drop_and_reload_db(_) -\u003e interface.GraphSearchDBInterface:\n", | ||
"  # Rebind the notebook-level `db`. Without this declaration the assignment\n", | ||
"  # below only creates a function-local variable, and later cells keep using\n", | ||
"  # the handle that was opened on the deleted database file.\n", | ||
"  global db\n", | ||
"  os.unlink(configs.db_config.db_config.db_path)\n", | ||
"  print('\\n Deleted previous db at: ', configs.db_config.db_config.db_path)\n", | ||
"  db = configs.db_config.load_db()\n", | ||
"  db.setup()\n", | ||
"  return db\n", | ||
"\n", | ||
"drop_existing_db = 'True' #@param['True', 'False']\n", | ||
"\n", | ||
"if num_embeddings \u003e 0 and drop_existing_db == 'True':\n", | ||
"  print(f'\\n\\nClick the button below to confirm you really want to drop the database at ')\n", | ||
"  print(f'{configs.db_config.db_config.db_path}\\n')\n", | ||
"  print(f'This will permanently delete all {num_embeddings} embeddings from the existing database.\\n')\n", | ||
"  print('If you do NOT want to delete this data, set `drop_existing_db` above to `False` and re-run this cell.\\n')\n", | ||
"\n", | ||
"  button = widgets.Button(description=f'Delete database?')\n", | ||
"  button.on_click(drop_and_reload_db)\n", | ||
"  display(button)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "MnGWbhc0LhiU" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Run the embedding { vertical-output: true }\n", | ||
"\n", | ||
"# If the DB already exists, we need to make sure that the current\n", | ||
"# model_config is compatible with the model_config that was used previously.\n", | ||
"colab_utils.validate_and_save_configs(configs, db)\n", | ||
"\n", | ||
"print(f'Datasets requested to embed: {[key for key in globs_to_process]}')\n", | ||
"\n", | ||
"# Avoid re-embedding datasets that are already present in the DB\n", | ||
"# TODO(roblaber) Make this filtering more granular, ie, avoid re-embedding\n", | ||
"# (dataset, filename) pairs\n", | ||
"for dataset in db.get_dataset_names():\n", | ||
" if dataset in globs_to_process:\n", | ||
" globs_to_process.pop(dataset)\n", | ||
" print(f'\\nDataset \\'{dataset}\\' already present in DB, not re-embedding')\n", | ||
"\n", | ||
"new_datasets = [key for key in globs_to_process]\n", | ||
"\n", | ||
"print(f'\\nNew datasets to embed: {new_datasets}')\n", | ||
"print(f'\\nPreparing to embed {len(new_datasets)} datasets...\\n')\n", | ||
"\n", | ||
"worker = embed.EmbedWorker(\n", | ||
" embed_config=configs.audio_sources_config,\n", | ||
" db=db,\n", | ||
" model_config=configs.model_config)\n", | ||
"\n", | ||
"worker.process_all()\n", | ||
"\n", | ||
"print('\\n\\nEmbedding complete, total embeddings: ', db.count_embeddings())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "HvVuFw-somHe" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Per dataset statistics { vertical-output: true }\n", | ||
"\n", | ||
"for dataset in db.get_dataset_names():\n", | ||
" print(f'\\nDataset \\'{dataset}\\':')\n", | ||
" print('\\tnum embeddings: ', db.get_embeddings_by_source(dataset, source_id=None).shape[0])" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"last_runtime": { | ||
"build_target": "", | ||
"kind": "local" | ||
}, | ||
"name": "v2_1_embed_unlabeled_audio.ipynb", | ||
"private_outputs": true, | ||
"provenance": [ | ||
{ | ||
"file_id": "1ePT3-fDB3kA3_T7trthFtu8xTJQWQBoQ", | ||
"timestamp": 1723499538314 | ||
} | ||
] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# coding=utf-8 | ||
# Copyright 2024 The Perch Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Utility functions for sqlite-backed Agile modeling notebooks.""" | ||
|
||
import dataclasses | ||
|
||
from chirp.projects.agile2 import embed | ||
from chirp.projects.hoplite import db_loader | ||
from chirp.projects.hoplite import interface | ||
from etils import epath | ||
from ml_collections import config_dict | ||
|
||
|
||
def supported_models() -> dict[str, dict[str, int | embed.ModelConfig]]:
  """Builds the registry of embedding models available to the notebooks.

  Returns:
    A mapping from model name to a dict with two entries: 'embedding_dim' (an
    int) and 'model_config' (an `embed.ModelConfig`). A fresh mapping is
    constructed on every call.
  """
  # Settings handed to the 'taxonomy_model_tf' model.
  perch_settings = config_dict.ConfigDict({
      'tfhub_version': 8,
      'window_size_s': 5.0,
      'hop_size_s': 5.0,
      'model_path': '',
      'sample_rate': 32000,
  })
  perch_entry = {
      'embedding_dim': 1280,
      'model_config': embed.ModelConfig(
          model_key='taxonomy_model_tf',
          model_config=perch_settings,
      ),
  }
  return {'perch': perch_entry}
|
||
|
||
@dataclasses.dataclass
class AgileConfigs:
  """Bundle of the configuration objects used by the Agile notebooks.

  Attributes:
    audio_sources_config: Config for the raw audio sources.
    db_config: Database config for the embeddings database.
    model_config: Config for the embedding model.
  """

  audio_sources_config: embed.EmbedConfig
  db_config: db_loader.DBConfig
  model_config: embed.ModelConfig

  def as_config_dict(self) -> config_dict.ConfigDict:
    """Converts the three configs into a single ConfigDict.

    Returns:
      A ConfigDict keyed by 'audio_sources_config', 'db_config' and
      'model_config', each holding the result of calling `to_config_dict()`
      on the corresponding attribute.
    """
    audio_sources = self.audio_sources_config.to_config_dict()
    database = self.db_config.to_config_dict()
    model = self.model_config.to_config_dict()
    return config_dict.ConfigDict({
        'audio_sources_config': audio_sources,
        'db_config': database,
        'model_config': model,
    })
|
||
|
||
def validate_and_save_configs(
    configs: AgileConfigs,
    db: interface.GraphSearchDBInterface,
):
  """Checks the model config against the DB, then stores the configs in it.

  If the DB metadata already contains a 'model_config' entry, its `model_key`
  must equal the `model_key` of `configs.model_config`. After that check the
  model config and the audio sources config are written to the DB metadata
  (under 'model_config' and 'embed_config') and the DB is committed.

  Args:
    configs: The notebook configs to validate and persist.
    db: The embeddings database to check against and write to.

  Raises:
    AssertionError: If the DB already stores a model config whose `model_key`
      differs from the configured one.
  """
  requested_model = configs.model_config
  stored_metadata = db.get_metadata(None)

  if 'model_config' in stored_metadata:
    stored_key = stored_metadata['model_config'].model_key
    if stored_key != requested_model.model_key:
      raise AssertionError(
          'The configured embedding model does not match the embedding model'
          ' that is already in the DB. You either need to drop the database or'
          " use the '%s' model confg."
          % stored_key
      )

  # Persist both configs so that later runs can be validated against them.
  embed_config = configs.audio_sources_config.to_config_dict()
  db.insert_metadata('model_config', requested_model.to_config_dict())
  db.insert_metadata('embed_config', embed_config)
  db.commit()
|
||
|
||
def load_configs(
    audio_globs: dict[str, tuple[str, str]],
    db_path: str | None,
    model_config_key: str = 'perch',
) -> AgileConfigs:
  """Load default configs for the notebook and return them as an AgileConfigs.

  Args:
    audio_globs: Mapping from dataset name to pairs of `(root directory, file
      glob)`.
    db_path: Location of the database. If None, the database is placed in the
      root directory of the single dataset in `audio_globs`, under the file
      name `hoplite_db.sqlite`.
    model_config_key: Name of the embedding model to use. Must be a key of
      `supported_models()`.

  Returns:
    An AgileConfigs holding the audio sources config, a sqlite DB config and
    the model config of the selected model.

  Raises:
    ValueError: If `model_config_key` is not a supported model, or if `db_path`
      is None and `audio_globs` does not contain exactly one dataset.
  """
  # Build the registry once; it is reconstructed on every call.
  models = supported_models()
  if model_config_key not in models:
    # TODO(roblaber): Add support for other models.
    raise ValueError(f'Unsupported model: {model_config_key}')

  if db_path is None:
    if not audio_globs:
      # Without a dataset there is no directory to derive the DB location
      # from; fail clearly instead of leaking a StopIteration from next().
      raise ValueError(
          'db_path must be specified when audio_globs is empty.'
      )
    if len(audio_globs) > 1:
      raise ValueError(
          'db_path must be specified when embedding multiple datasets.'
      )
    # Put the DB in the same directory as the audio.
    audio_root = next(iter(audio_globs.values()))[0]
    db_path = epath.Path(audio_root) / 'hoplite_db.sqlite'

  model_entry = models[model_config_key]
  db_config = config_dict.ConfigDict({
      'db_path': db_path,
      'embedding_dim': model_entry['embedding_dim'],
  })

  audio_srcs_config = embed.EmbedConfig(
      audio_globs=audio_globs,
      min_audio_len_s=1.0,
  )

  return AgileConfigs(
      audio_sources_config=audio_srcs_config,
      db_config=db_loader.DBConfig('sqlite', db_config),
      model_config=model_entry['model_config'],
  )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters