Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/embeddings_metadata_support #245

Merged
merged 2 commits into from
Jul 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 77 additions & 63 deletions ovos_plugin_manager/templates/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
import abc
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Dict, Union, Iterable

import numpy as np
# Typing helpers for readability
try:
import numpy as np
EmbeddingsArray = np.ndarray
except ImportError:
EmbeddingsArray = Iterable[Union[int, float]]
EmbeddingsTuple = Union[Tuple[str, float], Tuple[str, float, Dict]]


class EmbeddingsDB:
"""Base plugin for embeddings database"""
"""Base class for an embeddings database that supports storage, retrieval, and querying of embeddings."""

@abc.abstractmethod
def add_embeddings(self, key: str, embedding: np.ndarray) -> np.ndarray:
def add_embeddings(self, key: str, embedding: EmbeddingsArray,
metadata: Optional[Dict[str, any]] = None) -> EmbeddingsArray:
"""Store 'embedding' under 'key' with associated metadata.

Args:
key (str): The unique key for the embedding.
embedding (np.ndarray): The embedding vector to store.
metadata (Optional[Dict[str, any]]): Optional metadata associated with the embedding.

Returns:
np.ndarray: The stored embedding.
"""
return NotImplemented
raise NotImplementedError

@abc.abstractmethod
def get_embeddings(self, key: str) -> np.ndarray:
def get_embeddings(self, key: str) -> EmbeddingsArray:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"""Retrieve embeddings stored under 'key'.

Args:
Expand All @@ -30,10 +38,10 @@ def get_embeddings(self, key: str) -> np.ndarray:
Returns:
np.ndarray: The retrieved embedding.
"""
return NotImplemented
raise NotImplementedError

@abc.abstractmethod
def delete_embeddings(self, key: str) -> np.ndarray:
def delete_embeddings(self, key: str) -> EmbeddingsArray:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"""Delete embeddings stored under 'key'.

Args:
Expand All @@ -42,27 +50,29 @@ def delete_embeddings(self, key: str) -> np.ndarray:
Returns:
np.ndarray: The deleted embedding.
"""
return NotImplemented
raise NotImplementedError

@abc.abstractmethod
def query(self, embeddings: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]:
"""Return top_k embeddings closest to the given 'embeddings'.
def query(self, embeddings: EmbeddingsArray, top_k: int = 5,
return_metadata: bool = False) -> List[EmbeddingsTuple]:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"""Return the top_k embeddings closest to the given 'embeddings'.

Args:
embeddings (np.ndarray): The embedding vector to query.
top_k (int, optional): The number of top results to return. Defaults to 5.
return_metadata (bool, optional): Whether to include metadata in the results. Defaults to False.

Returns:
List[Tuple[str, float]]: List of tuples containing the key and distance.
List[EmbeddingsTuple]: List of tuples containing the key and distance, and optionally metadata.
"""
return NotImplemented
raise NotImplementedError

def distance(self, embeddings_a: np.ndarray, embeddings_b: np.ndarray, metric: str = "cosine",
def distance(self, embeddings_a: EmbeddingsArray, embeddings_b: EmbeddingsArray, metric: str = "cosine",
alpha: float = 0.5, # for alpha_divergence and tversky metrics
beta: float = 0.5, # for tversky metric
p: float = 3, # for minkowski and weighted_minkowski metrics
euclidean_weights: Optional[np.ndarray] = None, # required for weighted_euclidean and weighted_minkowski metrics
covariance_matrix: Optional[np.ndarray] = None # required for mahalanobis distance with user-defined covariance
euclidean_weights: Optional[EmbeddingsArray] = None, # required for weighted_euclidean and weighted_minkowski metrics
covariance_matrix: Optional[EmbeddingsArray] = None # required for mahalanobis distance with user-defined covariance
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
) -> float:
"""
Calculate the distance between two embeddings vectors using the specified distance metric.
Expand Down Expand Up @@ -306,7 +316,7 @@ def distance(self, embeddings_a: np.ndarray, embeddings_b: np.ndarray, metric: s


class TextEmbeddingsStore:
"""A store for text embeddings interfacing with the embeddings database"""
"""A store for text embeddings interfacing with the embeddings database."""

def __init__(self, db: EmbeddingsDB):
"""Initialize the text embeddings store.
Expand All @@ -317,7 +327,7 @@ def __init__(self, db: EmbeddingsDB):
self.db = db

@abc.abstractmethod
def get_text_embeddings(self, text: str) -> np.ndarray:
def get_text_embeddings(self, text: str) -> EmbeddingsArray:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"""Convert text to its corresponding embeddings.

Args:
Expand All @@ -326,16 +336,17 @@ def get_text_embeddings(self, text: str) -> np.ndarray:
Returns:
np.ndarray: The resulting embeddings.
"""
return NotImplemented
raise NotImplementedError

def add_document(self, document: str) -> None:
def add_document(self, document: str, metadata: Optional[Dict[str, any]] = None) -> None:
"""Add a document and its embeddings to the database.

Args:
document (str): The document to add.
metadata (Optional[Dict[str, any]]): Optional metadata associated with the document.
"""
embeddings = self.get_text_embeddings(document)
self.db.add_embeddings(document, embeddings)
self.db.add_embeddings(document, embeddings, metadata)

def delete_document(self, document: str) -> None:
"""Delete a document and its embeddings from the database.
Expand Down Expand Up @@ -369,13 +380,13 @@ def distance(self, text_a: str, text_b: str, metric: str = "cosine") -> float:
Returns:
float: The calculated distance.
"""
emb: np.ndarray = self.get_text_embeddings(text_a)
emb2: np.ndarray = self.get_text_embeddings(text_b)
return self.db.distance(emb, emb2, metric)
emb_a = self.get_text_embeddings(text_a)
emb_b = self.get_text_embeddings(text_b)
return self.db.distance(emb_a, emb_b, metric)


class FaceEmbeddingsStore:
"""A store for face embeddings interfacing with the embeddings database"""
"""A store for face embeddings interfacing with the embeddings database."""

def __init__(self, db: EmbeddingsDB):
"""Initialize the face embeddings store.
Expand All @@ -386,7 +397,7 @@ def __init__(self, db: EmbeddingsDB):
self.db = db

@abc.abstractmethod
def get_face_embeddings(self, frame: np.ndarray) -> np.ndarray:
def get_face_embeddings(self, frame: EmbeddingsArray) -> EmbeddingsArray:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"""Convert an image frame to its corresponding face embeddings.

Args:
Expand All @@ -395,22 +406,23 @@ def get_face_embeddings(self, frame: np.ndarray) -> np.ndarray:
Returns:
np.ndarray: The resulting face embeddings.
"""
return NotImplemented
raise NotImplementedError

def add_face(self, user_id: str, frame: np.ndarray):
def add_face(self, user_id: str, frame: EmbeddingsArray, metadata: Optional[Dict[str, any]] = None) -> EmbeddingsArray:
"""Add a face and its embeddings to the database.

Args:
user_id (str): The unique user ID.
frame (np.ndarray): The image frame containing the face.
metadata (Optional[Dict[str, any]]): Optional metadata associated with the face.

Returns:
np.ndarray: The stored face embeddings.
"""
emb: np.ndarray = self.get_face_embeddings(frame)
return self.db.add_embeddings(user_id, emb)
embeddings = self.get_face_embeddings(frame)
return self.db.add_embeddings(user_id, embeddings, metadata)

def delete_face(self, user_id: str):
def delete_face(self, user_id: str) -> EmbeddingsArray:
"""Delete a face and its embeddings from the database.

Args:
Expand All @@ -421,7 +433,7 @@ def delete_face(self, user_id: str):
"""
return self.db.delete_embeddings(user_id)

def predict(self, frame: np.ndarray, top_k: int = 3, thresh: float = 0.15) -> Optional[str]:
def predict(self, frame: EmbeddingsArray, top_k: int = 3, thresh: float = 0.15) -> Optional[str]:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"""Return the top predicted face closest to the given frame.

Args:
Expand All @@ -435,12 +447,12 @@ def predict(self, frame: np.ndarray, top_k: int = 3, thresh: float = 0.15) -> Op
matches = self.query(frame, top_k)
if not matches:
return None
best = min(matches, key=lambda k: k[1])
if best[1] > thresh:
best_match = min(matches, key=lambda k: k[1])
if best_match[1] > thresh:
return None
return best[0]
return best_match[0]

def query(self, frame: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]:
def query(self, frame: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, float]]:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"""Query the database for the top_k closest face embeddings to the frame.

Args:
Expand All @@ -450,10 +462,10 @@ def query(self, frame: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]:
Returns:
List[Tuple[str, float]]: List of tuples containing the user ID and distance.
"""
emb = self.get_face_embeddings(frame)
return self.db.query(emb, top_k)
embeddings = self.get_face_embeddings(frame)
return self.db.query(embeddings, top_k)

def distance(self, face_a: np.ndarray, face_b: np.ndarray, metric: str = "cosine") -> float:
def distance(self, face_a: EmbeddingsArray, face_b: EmbeddingsArray, metric: str = "cosine") -> float:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"""Calculate the distance between embeddings of two faces.

Args:
Expand All @@ -464,13 +476,13 @@ def distance(self, face_a: np.ndarray, face_b: np.ndarray, metric: str = "cosine
Returns:
float: The calculated distance.
"""
emb: np.ndarray = self.get_face_embeddings(face_a)
emb2: np.ndarray = self.get_face_embeddings(face_b)
return self.db.distance(emb, emb2, metric)
emb_a = self.get_face_embeddings(face_a)
emb_b = self.get_face_embeddings(face_b)
return self.db.distance(emb_a, emb_b, metric)


class VoiceEmbeddingsStore:
"""A store for voice embeddings interfacing with the embeddings database"""
"""A store for voice embeddings interfacing with the embeddings database."""

def __init__(self, db: EmbeddingsDB):
"""Initialize the voice embeddings store.
Expand All @@ -481,7 +493,7 @@ def __init__(self, db: EmbeddingsDB):
self.db = db

@staticmethod
def audiochunk2array(audio_bytes: bytes) -> np.ndarray:
def audiochunk2array(audio_bytes: bytes) -> EmbeddingsArray:
"""Convert audio buffer to a normalized float32 NumPy array.

Args:
Expand All @@ -494,11 +506,10 @@ def audiochunk2array(audio_bytes: bytes) -> np.ndarray:
audio_as_np_float32 = audio_as_np_int16.astype(np.float32)
# Normalise float32 array so that values are between -1.0 and +1.0
max_int16 = 2 ** 15
data = audio_as_np_float32 / max_int16
return data
return audio_as_np_float32 / max_int16

@abc.abstractmethod
def get_voice_embeddings(self, audio_data: np.ndarray) -> np.ndarray:
def get_voice_embeddings(self, audio_data: EmbeddingsArray) -> EmbeddingsArray:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"""Convert audio data to its corresponding voice embeddings.

Args:
Expand All @@ -507,22 +518,23 @@ def get_voice_embeddings(self, audio_data: np.ndarray) -> np.ndarray:
Returns:
np.ndarray: The resulting voice embeddings.
"""
return NotImplemented
raise NotImplementedError

def add_voice(self, user_id: str, audio_data: np.ndarray):
def add_voice(self, user_id: str, audio_data: EmbeddingsArray, metadata: Optional[Dict[str, any]] = None) -> EmbeddingsArray:
"""Add a voice and its embeddings to the database.

Args:
user_id (str): The unique user ID.
audio_data (np.ndarray): The input audio data.
metadata (Optional[Dict[str, any]]): Optional metadata associated with the voice.

Returns:
np.ndarray: The stored voice embeddings.
"""
emb: np.ndarray = self.get_voice_embeddings(audio_data)
return self.db.add_embeddings(user_id, emb)
embeddings = self.get_voice_embeddings(audio_data)
return self.db.add_embeddings(user_id, embeddings, metadata)

def delete_voice(self, user_id: str):
def delete_voice(self, user_id: str) -> EmbeddingsArray:
"""Delete a voice and its embeddings from the database.

Args:
Expand All @@ -533,7 +545,7 @@ def delete_voice(self, user_id: str):
"""
return self.db.delete_embeddings(user_id)

def predict(self, audio_data: np.ndarray, top_k: int = 3, thresh: float = 0.75) -> Optional[str]:
def predict(self, audio_data: EmbeddingsArray, top_k: int = 3, thresh: float = 0.75) -> Optional[str]:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"""Return the top predicted voice closest to the given audio_data.

Args:
Expand All @@ -545,12 +557,14 @@ def predict(self, audio_data: np.ndarray, top_k: int = 3, thresh: float = 0.75)
Optional[str]: The predicted user ID or None if the best match exceeds the threshold.
"""
matches = self.query(audio_data, top_k)
best = min(matches, key=lambda k: k[1])
if best[1] > thresh:
if not matches:
return None
return best[0]
best_match = min(matches, key=lambda k: k[1])
if best_match[1] > thresh:
return None
return best_match[0]

def query(self, audio_data: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]:
def query(self, audio_data: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, float]]:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"""Query the database for the top_k closest voice embeddings to the audio_data.

Args:
Expand All @@ -560,10 +574,10 @@ def query(self, audio_data: np.ndarray, top_k: int = 5) -> List[Tuple[str, float
Returns:
List[Tuple[str, float]]: List of tuples containing the user ID and distance.
"""
emb = self.get_voice_embeddings(audio_data)
return self.db.query(emb, top_k)
embeddings = self.get_voice_embeddings(audio_data)
return self.db.query(embeddings, top_k)

def distance(self, voice_a: np.ndarray, voice_b: np.ndarray, metric: str = "cosine") -> float:
def distance(self, voice_a: EmbeddingsArray, voice_b: EmbeddingsArray, metric: str = "cosine") -> float:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"""Calculate the distance between embeddings of two voices.

Args:
Expand All @@ -574,6 +588,6 @@ def distance(self, voice_a: np.ndarray, voice_b: np.ndarray, metric: str = "cosi
Returns:
float: The calculated distance.
"""
emb = self.get_voice_embeddings(voice_a)
emb2 = self.get_voice_embeddings(voice_b)
return self.db.distance(emb, emb2, metric)
emb_a = self.get_voice_embeddings(voice_a)
emb_b = self.get_voice_embeddings(voice_b)
return self.db.distance(emb_a, emb_b, metric)
Loading