diff --git a/ovos_plugin_manager/templates/embeddings.py b/ovos_plugin_manager/templates/embeddings.py index d941f081..4014f32b 100644 --- a/ovos_plugin_manager/templates/embeddings.py +++ b/ovos_plugin_manager/templates/embeddings.py @@ -1,27 +1,35 @@ import abc -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Dict, Union, Iterable, Any -import numpy as np +# Typing helpers for readability +try: + import numpy as np + EmbeddingsArray = np.ndarray +except ImportError: + EmbeddingsArray = Iterable[Union[int, float]] +EmbeddingsTuple = Union[Tuple[str, float], Tuple[str, float, Dict[str, Any]]] class EmbeddingsDB: - """Base plugin for embeddings database""" + """Base class for an embeddings database that supports storage, retrieval, and querying of embeddings.""" @abc.abstractmethod - def add_embeddings(self, key: str, embedding: np.ndarray) -> np.ndarray: + def add_embeddings(self, key: str, embedding: EmbeddingsArray, + metadata: Optional[Dict[str, Any]] = None) -> EmbeddingsArray: """Store 'embedding' under 'key' with associated metadata. Args: key (str): The unique key for the embedding. embedding (np.ndarray): The embedding vector to store. + metadata (Optional[Dict[str, Any]]): Optional metadata associated with the embedding. Returns: np.ndarray: The stored embedding. """ - return NotImplemented + raise NotImplementedError @abc.abstractmethod - def get_embeddings(self, key: str) -> np.ndarray: + def get_embeddings(self, key: str) -> EmbeddingsArray: """Retrieve embeddings stored under 'key'. Args: @@ -30,10 +38,10 @@ def get_embeddings(self, key: str) -> np.ndarray: Returns: np.ndarray: The retrieved embedding. """ - return NotImplemented + raise NotImplementedError @abc.abstractmethod - def delete_embeddings(self, key: str) -> np.ndarray: + def delete_embeddings(self, key: str) -> EmbeddingsArray: """Delete embeddings stored under 'key'. 
Args: @@ -42,27 +50,29 @@ def delete_embeddings(self, key: str) -> np.ndarray: Returns: np.ndarray: The deleted embedding. """ - return NotImplemented + raise NotImplementedError @abc.abstractmethod - def query(self, embeddings: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]: - """Return top_k embeddings closest to the given 'embeddings'. + def query(self, embeddings: EmbeddingsArray, top_k: int = 5, + return_metadata: bool = False) -> List[EmbeddingsTuple]: + """Return the top_k embeddings closest to the given 'embeddings'. Args: embeddings (np.ndarray): The embedding vector to query. top_k (int, optional): The number of top results to return. Defaults to 5. + return_metadata (bool, optional): Whether to include metadata in the results. Defaults to False. Returns: - List[Tuple[str, float]]: List of tuples containing the key and distance. + List[EmbeddingsTuple]: List of tuples containing the key and distance, and optionally metadata. """ - return NotImplemented + raise NotImplementedError - def distance(self, embeddings_a: np.ndarray, embeddings_b: np.ndarray, metric: str = "cosine", + def distance(self, embeddings_a: EmbeddingsArray, embeddings_b: EmbeddingsArray, metric: str = "cosine", alpha: float = 0.5, # for alpha_divergence and tversky metrics beta: float = 0.5, # for tversky metric p: float = 3, # for minkowski and weighted_minkowski metrics - euclidean_weights: Optional[np.ndarray] = None, # required for weighted_euclidean and weighted_minkowski metrics - covariance_matrix: Optional[np.ndarray] = None # required for mahalanobis distance with user-defined covariance + euclidean_weights: Optional[EmbeddingsArray] = None, # required for weighted_euclidean and weighted_minkowski metrics + covariance_matrix: Optional[EmbeddingsArray] = None # required for mahalanobis distance with user-defined covariance ) -> float: """ Calculate the distance between two embeddings vectors using the specified distance metric. 
@@ -306,7 +316,7 @@ def distance(self, embeddings_a: np.ndarray, embeddings_b: np.ndarray, metric: s class TextEmbeddingsStore: - """A store for text embeddings interfacing with the embeddings database""" + """A store for text embeddings interfacing with the embeddings database.""" def __init__(self, db: EmbeddingsDB): """Initialize the text embeddings store. @@ -317,7 +327,7 @@ def __init__(self, db: EmbeddingsDB): self.db = db @abc.abstractmethod - def get_text_embeddings(self, text: str) -> np.ndarray: + def get_text_embeddings(self, text: str) -> EmbeddingsArray: """Convert text to its corresponding embeddings. Args: @@ -326,16 +336,17 @@ def get_text_embeddings(self, text: str) -> np.ndarray: Returns: np.ndarray: The resulting embeddings. """ - return NotImplemented + raise NotImplementedError - def add_document(self, document: str) -> None: + def add_document(self, document: str, metadata: Optional[Dict[str, Any]] = None) -> None: """Add a document and its embeddings to the database. Args: document (str): The document to add. + metadata (Optional[Dict[str, Any]]): Optional metadata associated with the document. """ embeddings = self.get_text_embeddings(document) - self.db.add_embeddings(document, embeddings) + self.db.add_embeddings(document, embeddings, metadata) def delete_document(self, document: str) -> None: """Delete a document and its embeddings from the database. @@ -369,13 +380,13 @@ def distance(self, text_a: str, text_b: str, metric: str = "cosine") -> float: Returns: float: The calculated distance. 
""" - emb: np.ndarray = self.get_text_embeddings(text_a) - emb2: np.ndarray = self.get_text_embeddings(text_b) - return self.db.distance(emb, emb2, metric) + emb_a = self.get_text_embeddings(text_a) + emb_b = self.get_text_embeddings(text_b) + return self.db.distance(emb_a, emb_b, metric) class FaceEmbeddingsStore: - """A store for face embeddings interfacing with the embeddings database""" + """A store for face embeddings interfacing with the embeddings database.""" def __init__(self, db: EmbeddingsDB): """Initialize the face embeddings store. @@ -386,7 +397,7 @@ def __init__(self, db: EmbeddingsDB): self.db = db @abc.abstractmethod - def get_face_embeddings(self, frame: np.ndarray) -> np.ndarray: + def get_face_embeddings(self, frame: EmbeddingsArray) -> EmbeddingsArray: """Convert an image frame to its corresponding face embeddings. Args: @@ -395,22 +406,23 @@ def get_face_embeddings(self, frame: np.ndarray) -> np.ndarray: Returns: np.ndarray: The resulting face embeddings. """ - return NotImplemented + raise NotImplementedError - def add_face(self, user_id: str, frame: np.ndarray): + def add_face(self, user_id: str, frame: EmbeddingsArray, metadata: Optional[Dict[str, Any]] = None) -> EmbeddingsArray: """Add a face and its embeddings to the database. Args: user_id (str): The unique user ID. frame (np.ndarray): The image frame containing the face. + metadata (Optional[Dict[str, Any]]): Optional metadata associated with the face. Returns: np.ndarray: The stored face embeddings. """ - emb: np.ndarray = self.get_face_embeddings(frame) - return self.db.add_embeddings(user_id, emb) + embeddings = self.get_face_embeddings(frame) + return self.db.add_embeddings(user_id, embeddings, metadata) - def delete_face(self, user_id: str): + def delete_face(self, user_id: str) -> EmbeddingsArray: """Delete a face and its embeddings from the database. 
Args: @@ -421,7 +433,7 @@ def delete_face(self, user_id: str): """ return self.db.delete_embeddings(user_id) - def predict(self, frame: np.ndarray, top_k: int = 3, thresh: float = 0.15) -> Optional[str]: + def predict(self, frame: EmbeddingsArray, top_k: int = 3, thresh: float = 0.15) -> Optional[str]: """Return the top predicted face closest to the given frame. Args: @@ -435,12 +447,12 @@ def predict(self, frame: np.ndarray, top_k: int = 3, thresh: float = 0.15) -> Op matches = self.query(frame, top_k) if not matches: return None - best = min(matches, key=lambda k: k[1]) - if best[1] > thresh: + best_match = min(matches, key=lambda k: k[1]) + if best_match[1] > thresh: return None - return best[0] + return best_match[0] - def query(self, frame: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]: + def query(self, frame: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, float]]: """Query the database for the top_k closest face embeddings to the frame. Args: @@ -450,10 +462,10 @@ def query(self, frame: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]: Returns: List[Tuple[str, float]]: List of tuples containing the user ID and distance. """ - emb = self.get_face_embeddings(frame) - return self.db.query(emb, top_k) + embeddings = self.get_face_embeddings(frame) + return self.db.query(embeddings, top_k) - def distance(self, face_a: np.ndarray, face_b: np.ndarray, metric: str = "cosine") -> float: + def distance(self, face_a: EmbeddingsArray, face_b: EmbeddingsArray, metric: str = "cosine") -> float: """Calculate the distance between embeddings of two faces. Args: @@ -464,13 +476,13 @@ def distance(self, face_a: np.ndarray, face_b: np.ndarray, metric: str = "cosine Returns: float: The calculated distance. 
""" - emb: np.ndarray = self.get_face_embeddings(face_a) - emb2: np.ndarray = self.get_face_embeddings(face_b) - return self.db.distance(emb, emb2, metric) + emb_a = self.get_face_embeddings(face_a) + emb_b = self.get_face_embeddings(face_b) + return self.db.distance(emb_a, emb_b, metric) class VoiceEmbeddingsStore: - """A store for voice embeddings interfacing with the embeddings database""" + """A store for voice embeddings interfacing with the embeddings database.""" def __init__(self, db: EmbeddingsDB): """Initialize the voice embeddings store. @@ -481,7 +493,7 @@ def __init__(self, db: EmbeddingsDB): self.db = db @staticmethod - def audiochunk2array(audio_bytes: bytes) -> np.ndarray: + def audiochunk2array(audio_bytes: bytes) -> EmbeddingsArray: """Convert audio buffer to a normalized float32 NumPy array. Args: @@ -494,11 +506,10 @@ def audiochunk2array(audio_bytes: bytes) -> np.ndarray: audio_as_np_float32 = audio_as_np_int16.astype(np.float32) # Normalise float32 array so that values are between -1.0 and +1.0 max_int16 = 2 ** 15 - data = audio_as_np_float32 / max_int16 - return data + return audio_as_np_float32 / max_int16 @abc.abstractmethod - def get_voice_embeddings(self, audio_data: np.ndarray) -> np.ndarray: + def get_voice_embeddings(self, audio_data: EmbeddingsArray) -> EmbeddingsArray: """Convert audio data to its corresponding voice embeddings. Args: @@ -507,22 +518,23 @@ def get_voice_embeddings(self, audio_data: np.ndarray) -> np.ndarray: Returns: np.ndarray: The resulting voice embeddings. """ - return NotImplemented + raise NotImplementedError - def add_voice(self, user_id: str, audio_data: np.ndarray): + def add_voice(self, user_id: str, audio_data: EmbeddingsArray, metadata: Optional[Dict[str, Any]] = None) -> EmbeddingsArray: """Add a voice and its embeddings to the database. Args: user_id (str): The unique user ID. audio_data (np.ndarray): The input audio data. 
+ metadata (Optional[Dict[str, Any]]): Optional metadata associated with the voice. Returns: np.ndarray: The stored voice embeddings. """ - emb: np.ndarray = self.get_voice_embeddings(audio_data) - return self.db.add_embeddings(user_id, emb) + embeddings = self.get_voice_embeddings(audio_data) + return self.db.add_embeddings(user_id, embeddings, metadata) - def delete_voice(self, user_id: str): + def delete_voice(self, user_id: str) -> EmbeddingsArray: """Delete a voice and its embeddings from the database. Args: @@ -533,7 +545,7 @@ def delete_voice(self, user_id: str): """ return self.db.delete_embeddings(user_id) - def predict(self, audio_data: np.ndarray, top_k: int = 3, thresh: float = 0.75) -> Optional[str]: + def predict(self, audio_data: EmbeddingsArray, top_k: int = 3, thresh: float = 0.75) -> Optional[str]: """Return the top predicted voice closest to the given audio_data. Args: @@ -545,12 +557,14 @@ def predict(self, audio_data: np.ndarray, top_k: int = 3, thresh: float = 0.75) Optional[str]: The predicted user ID or None if the best match exceeds the threshold. """ matches = self.query(audio_data, top_k) - best = min(matches, key=lambda k: k[1]) - if best[1] > thresh: + if not matches: return None - return best[0] + best_match = min(matches, key=lambda k: k[1]) + if best_match[1] > thresh: + return None + return best_match[0] - def query(self, audio_data: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]: + def query(self, audio_data: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, float]]: """Query the database for the top_k closest voice embeddings to the audio_data. Args: @@ -560,10 +574,10 @@ def query(self, audio_data: np.ndarray, top_k: int = 5) -> List[Tuple[str, float Returns: List[Tuple[str, float]]: List of tuples containing the user ID and distance. 
""" - emb = self.get_voice_embeddings(audio_data) - return self.db.query(emb, top_k) + embeddings = self.get_voice_embeddings(audio_data) + return self.db.query(embeddings, top_k) - def distance(self, voice_a: np.ndarray, voice_b: np.ndarray, metric: str = "cosine") -> float: + def distance(self, voice_a: EmbeddingsArray, voice_b: EmbeddingsArray, metric: str = "cosine") -> float: """Calculate the distance between embeddings of two voices. Args: @@ -574,6 +588,6 @@ def distance(self, voice_a: np.ndarray, voice_b: np.ndarray, metric: str = "cosi Returns: float: The calculated distance. """ - emb = self.get_voice_embeddings(voice_a) - emb2 = self.get_voice_embeddings(voice_b) - return self.db.distance(emb, emb2, metric) + emb_a = self.get_voice_embeddings(voice_a) + emb_b = self.get_voice_embeddings(voice_b) + return self.db.distance(emb_a, emb_b, metric)