diff --git a/geniml/_version.py b/geniml/_version.py
index f6b7e26..3d18726 100644
--- a/geniml/_version.py
+++ b/geniml/_version.py
@@ -1 +1 @@
-__version__ = "0.4.3"
+__version__ = "0.5.0"
diff --git a/geniml/bbclient/bbclient.py b/geniml/bbclient/bbclient.py
index 136a742..a08241a 100644
--- a/geniml/bbclient/bbclient.py
+++ b/geniml/bbclient/bbclient.py
@@ -1,6 +1,7 @@
 import gzip
 import os
 import shutil
+from contextlib import suppress
 from logging import getLogger
 from typing import List, NoReturn, Union
@@ -11,7 +12,6 @@ from botocore.exceptions import ClientError
 from pybiocfilecache import BiocFileCache
 from pybiocfilecache._exceptions import RnameExistsError
 
-from contextlib import suppress
 from ubiquerg import is_url
 from zarr import Array
 from zarr.errors import PathNotFoundError
diff --git a/geniml/search/backends/__init__.py b/geniml/search/backends/__init__.py
index 3c2987d..41343e6 100644
--- a/geniml/search/backends/__init__.py
+++ b/geniml/search/backends/__init__.py
@@ -1,2 +1,3 @@
+from .bivecbackend import BiVectorBackend
 from .dbbackend import QdrantBackend
 from .filebackend import HNSWBackend
diff --git a/geniml/search/backends/bivecbackend.py b/geniml/search/backends/bivecbackend.py
new file mode 100644
index 0000000..cbe84a7
--- /dev/null
+++ b/geniml/search/backends/bivecbackend.py
@@ -0,0 +1,255 @@
+import logging
+import math
+from typing import Dict, List, Union
+
+import numpy as np
+
+from ...const import PKG_NAME
+from .abstract import EmSearchBackend
+
+_LOGGER = logging.getLogger(PKG_NAME)
+
+
+def batch_bed_vectors(matching_beds: List[Dict]) -> np.ndarray:
+    """
+    Stack the embedding vectors of the BED files that match a metadata tag for batch search
+
+    :param matching_beds: results of BED retrieval from the BED backend by ids
+    """
+    bed_vectors = []
+    for bed in matching_beds:
+        try:
+            bed_vec = bed["vector"]
+            bed_vectors.append(bed_vec)
+        except KeyError:
+            _LOGGER.warning(f"Retrieved result missing vector: {bed}")
+            continue
+        except TypeError:
+            _LOGGER.warning(
+                f"Please check the data loading; retrieved result is not a dictionary: {bed}"
+            )
+            continue
+    return np.array(bed_vectors)
+
+
+class BiVectorBackend:
+    """
+    Search backend that connects the embeddings of metadata tags and BED files
+    """
+
+    def __init__(
+        self,
+        metadata_backend: EmSearchBackend,
+        bed_backend: EmSearchBackend,
+        metadata_payload_matches: str = "matched_files",
+    ):
+        """
+        :param metadata_backend: search backend where embedding vectors of metadata tags are stored
+        :param bed_backend: search backend where embedding vectors of BED files are stored
+        :param metadata_payload_matches: the key in a metadata backend payload that stores the files matching that metadata tag
+        """
+        self.metadata_backend = metadata_backend
+        self.bed_backend = bed_backend
+        self.metadata_payload_matches = metadata_payload_matches
+
+    def search(
+        self,
+        query: np.ndarray,
+        limit: int,
+        with_payload: bool = True,
+        with_vectors: bool = True,
+        offset: int = 0,
+        p: float = 1.0,
+        q: float = 1.0,
+        distance: bool = False,
+        rank: bool = False,
+    ) -> List[Dict[str, Union[int, float, Dict[str, str], List[float]]]]:
+        """
+        :param query: query vector (embedding vector of the query term)
+        :param limit: number of nearest neighbors to search for the query vector
+        :param with_payload: whether the payload is included in the result
+        :param with_vectors: whether the stored vector is included in the result
+        :param offset: the offset of the search results
+        :param p: weight on the metadata search score; recommended: 0 < p <= 1.0
+        :param q: weight on the BED search score; recommended: 0 < q <= 1.0
+        :param distance: whether the score is a distance or a similarity
+        :param rank: whether the result is ranked by rank or by score
+        :return: the search result (a list of dictionaries; each dictionary
+            includes: storage id, payload (optional), and vector (optional))
+        """
+
+        # the key for the score in the result: distance or score (cosine similarity)
+        self.score_key = "distance" if distance else "score"
+
+        # metadata search
+        metadata_results = self.metadata_backend.search(
+            query,
+            limit=int(math.log(limit) * 5) if limit > 10 else 10,
+            with_payload=True,
+            offset=offset,
+        )
+
+        if not isinstance(metadata_results, list):
+            metadata_results = [metadata_results]
+
+        if rank:
+            return self._rank_search(metadata_results, limit, with_payload, with_vectors, offset)
+        else:
+            return self._score_search(
+                metadata_results, limit, with_payload, with_vectors, offset, p, q
+            )
+
+    def _rank_search(
+        self,
+        metadata_results: List[Dict],
+        limit: int,
+        with_payload: bool = True,
+        with_vectors: bool = True,
+        offset: int = 0,
+    ) -> List[Dict[str, Union[int, float, Dict[str, str], List[float]]]]:
+        """
+        Search based on the maximum rank in the results of the metadata and BED embedding searches
+
+        :param metadata_results: results of the metadata search
+        :param limit: see docstring of search()
+        :param with_payload:
+        :param with_vectors:
+        :param offset:
+        :return: the search result ranked by maximum rank
+        """
+        max_rank = []
+        bed_results = []
+
+        for i in range(len(metadata_results)):
+            result = metadata_results[i]
+
+            # all BED files matching the retrieved metadata tag
+            bed_ids = result["payload"][self.metadata_payload_matches]
+            matching_beds = self.bed_backend.retrieve_info(bed_ids, with_vectors=True)
+
+            # use each single BED file as the query in the BED embedding backend
+            bed_vecs = batch_bed_vectors(matching_beds)
+            if len(bed_vecs) == 0:
+                continue
+
+            retrieved_batch = self.bed_backend.search(
+                bed_vecs,
+                limit=limit * 2 if limit < 500 else 500,
+                with_payload=with_payload,
+                with_vectors=with_vectors,
+                offset=offset,
+            )
+
+            for retrieved_bed in retrieved_batch:
+                for j in range(len(retrieved_bed)):
+                    retrieval = retrieved_bed[j]
+                    bed_results.append(retrieval)
+                    # collect the maximum rank
+                    max_rank.append(max(i, j))
+
+        return self._top_k(max_rank, bed_results, limit, True)
+
+    def _score_search(
+        self,
+        metadata_results: List[Dict],
+        limit: int,
+        with_payload: bool = True,
+        with_vectors: bool = True,
+        offset: int = 0,
+        p: float = 1.0,
+        q: float = 1.0,
+    ) -> List[Dict[str, Union[int, float, Dict[str, str], List[float]]]]:
+        """
+        Search based on the weighted score from the results of the metadata and BED embedding searches
+
+        :param metadata_results: results of the metadata search
+        :param limit: see docstring of search()
+        :param with_payload:
+        :param with_vectors:
+        :param offset:
+        :param p:
+        :param q:
+        :return: the search result ranked by weighted similarity scores
+        """
+        overall_scores = []
+        bed_results = []
+        for result in metadata_results:
+            # similarity score between the query term and the metadata tag
+            text_score = (
+                1 - result[self.score_key]
+                if self.score_key == "distance"
+                else result[self.score_key]
+            )
+            bed_ids = result["payload"][self.metadata_payload_matches]
+            matching_beds = self.bed_backend.retrieve_info(bed_ids, with_vectors=True)
+            bed_vecs = batch_bed_vectors(matching_beds)
+
+            if len(bed_vecs) == 0:
+                continue
+
+            retrieved_batch = self.bed_backend.search(
+                bed_vecs,
+                limit=limit,
+                with_payload=with_payload,
+                with_vectors=with_vectors,
+                offset=offset,
+            )
+
+            for retrieved_bed in retrieved_batch:
+                for retrieval in retrieved_bed:
+                    # calculate the weighted score; the BED score comes from the
+                    # retrieved BED result, not from the metadata result
+                    bed_score = (
+                        1 - retrieval[self.score_key]
+                        if self.score_key == "distance"
+                        else retrieval[self.score_key]
+                    )
+                    bed_results.append(retrieval)
+                    overall_scores.append(p * text_score + q * bed_score)
+
+        return self._top_k(overall_scores, bed_results, limit, False)
+
+    def _top_k(
+        self,
+        scales: List[Union[int, float]],
+        results: List[Dict[str, Union[int, float, Dict[str, str], List[float]]]],
+        k: int,
+        rank: bool = True,
+    ):
+        """
+        Sort the top k results and remove duplicates
+
+        :param scales: list of weighted scores or maximum ranks
+        :param results: retrieval results
+        :param k: number of results to return
+        :param rank: whether the scale is maximum rank or not
+        :return: the top k results selected after ranking
+        """
+        paired_score_results = list(zip(scales, results))
+
+        # sort the results
+        if not rank:
+            paired_score_results.sort(reverse=True, key=lambda x: x[0])
+        else:
+            paired_score_results.sort(key=lambda x: x[0])
+
+        unique_result = {}
+        for scale, result in paired_score_results:
+            store_id = result["id"]
+            # filter out duplicates
+            if store_id not in unique_result:
+                # add the rank or score into the result
+                if not rank:
+                    if self.score_key == "distance":
+                        del result[self.score_key]
+                    result["score"] = scale
+                else:
+                    try:
+                        del result["score"]
+                    except KeyError:
+                        del result["distance"]
+
+                    result["max_rank"] = scale
+                unique_result[store_id] = result
+
+        top_k_results = list(unique_result.values())[:k]
+        return top_k_results
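The weighted combination in _score_search above is simply p * text_score + q * bed_score, with distances converted to similarities first. A minimal numeric sketch of that arithmetic, using made-up scores:

# made-up numbers, for illustration only
p, q = 1.0, 1.0        # weights on the metadata score and the BED score
text_distance = 0.25   # query vs. metadata tag, from a distance-based backend
bed_distance = 0.20    # query BED vs. retrieved BED

text_score = 1 - text_distance  # distance -> similarity, as in _score_search
bed_score = 1 - bed_distance
overall = p * text_score + q * bed_score  # 1.55; _top_k sorts these descending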
+ """ + output_list = [] + for result in search_results: + # build each dictionary + result_dict = {"id": result.id, "score": result.score} + if with_payload: + result_dict["payload"] = result.payload + if with_vectors: + result_dict["vector"] = result.vector + output_list.append(result_dict) + return output_list + + class QdrantBackend(EmSearchBackend): """A search backend that uses a qdrant server to store and search embeddings""" def __init__( self, - config: VectorParams = DEFAULT_QDRANT_CONFIG, + dim: int = DEFAULT_DIM, + dist: Distance = DEFAULT_QDRANT_DIST, collection: str = DEFAULT_COLLECTION_NAME, qdrant_host: str = DEFAULT_QDRANT_HOST, qdrant_port: int = DEFAULT_QDRANT_PORT, @@ -45,7 +103,7 @@ def __init__( """ super().__init__() self.collection = collection - self.config = config + self.config = VectorParams(size=dim, distance=dist) self.url = os.environ.get("QDRANT_HOST", qdrant_host) self.port = os.environ.get("QDRANT_PORT", qdrant_port) self.qd_client = QdrantClient( @@ -105,7 +163,10 @@ def search( with_payload: bool = True, with_vectors: bool = True, offset: int = 0, - ) -> List[Dict[str, Union[int, float, Dict[str, str], List[float]]]]: + ) -> Union[ + List[Dict[str, Union[int, float, Dict[str, str], List[float]]]], + List[List[Dict[str, Union[int, float, Dict[str, str], List[float]]]]], + ]: """ with a given query vector, get k nearest neighbors from vectors in the collection @@ -124,6 +185,8 @@ def search( "vector": [] } """ + if query.ndim > 1: + return self.batch_search(query, limit, with_payload, with_vectors, offset) # KNN search in qdrant client search_results = self.qd_client.search( collection_name=self.collection, @@ -135,15 +198,38 @@ def search( ) # add the results in to the output list + return results_processing(search_results, with_payload, with_vectors) + + def batch_search( + self, + queries: np.ndarray, + limit: int, + with_payload: bool = True, + with_vectors: bool = True, + offset: int = 0, + ) -> List[List[Dict[str, Union[int, float, Dict[str, str], List[float]]]]]: + """ + + :param queries: multiple search vectors, np.ndarray with shape of (n, dim) + :param limit: see docstring of def search + :type limit: + :param with_payload: + :param with_vectors: + :param offset: + :return: results of all search requests with each vector in queries + """ output_list = [] - for result in search_results: - # build each dictionary - result_dict = {"id": result.id, "score": result.score} - if with_payload: - result_dict["payload"] = result.payload - if with_vectors: - result_dict["vector"] = result.vector - output_list.append(result_dict) + # build all search requests + requests = queries_to_requests(queries, limit, with_payload, with_vectors, offset) + + search_results = self.qd_client.search_batch( + collection_name=self.collection, requests=requests + ) + + # add the results in to the output list + for batch in search_results: + batch_list = results_processing(batch, with_payload, with_vectors) + output_list.append(batch_list) return output_list def __len__(self) -> int: @@ -152,36 +238,51 @@ def __len__(self) -> int: """ return self.qd_client.get_collection(collection_name=self.collection).vectors_count - def retrieve_info(self, ids: Union[List[int], int], with_vec: bool = False) -> Union[ - Dict[str, Union[int, List[float], Dict[str, str]]], - List[Dict[str, Union[int, List[float], Dict[str, str]]]], + def retrieve_info( + self, ids: Union[List[int], int, List[str], str], with_vectors: bool = False + ) -> Union[ + Dict[str, Union[int, str, List[float], Dict[str, 
+        List[Dict[str, Union[int, str, List[float], Dict[str, str]]]],
     ]:
         """
         With a given list of storage ids, return the information of these vectors
 
         :param ids: list of ids, or a single id
-        :param with_vec: whether the vectors themselves will also be returned in the output
+        :param with_vectors: whether the vectors themselves will also be returned in the output
         :return: if ids is one id, a dictionary similar to the output of search() will be returned,
             without "score"; if ids is a list, a list of dictionaries will be returned
         """
         if not isinstance(ids, list):
             # retrieve() only takes iterable input
             ids = [ids]
+
+        # add hyphens to a uuid if they are missing
+        for i in range(len(ids)):
+            id_ = ids[i]
+            if isinstance(id_, str):
+                if "-" not in id_:
+                    ids[i] = f"{id_[:8]}-{id_[8:12]}-{id_[12:16]}-{id_[16:20]}-{id_[20:]}"
+
         output_list = []
         retrievals = self.qd_client.retrieve(
             collection_name=self.collection,
             ids=ids,
             with_payload=True,
-            with_vectors=with_vec,  # no need vectors
+            with_vectors=with_vectors,
         )
-        # retrieve() of qd client does not return result in the order of ids in the list
-        # sort it for convenience
-        sorted_retrievals = sorted(retrievals, key=lambda x: ids.index(x.id))
 
-        # get the information
-        for result in sorted_retrievals:
+        retrieval_dict = {result.id: result for result in retrievals}
+
+        # retrieve() of the qd client does not return results in the order of the ids list,
+        # so collect the results into the output in id order
+        for id_ in ids:
+            try:
+                result = retrieval_dict[id_]
+            except KeyError:
+                _LOGGER.warning(f"No id stored in the backend matches {id_}.")
+                continue
             result_dict = {"id": result.id, "payload": result.payload}
-            if with_vec:
+            if with_vectors:
                 result_dict["vector"] = result.vector
             output_list.append(result_dict)
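With the batch path above, QdrantBackend.search now accepts either a single vector of shape (dim,) or a matrix of shape (n, dim); the latter is dispatched to batch_search and returns one result list per row. A minimal sketch, assuming a Qdrant server is reachable on localhost and the default collection already holds 100-dimensional vectors:

import numpy as np

from geniml.search.backends import QdrantBackend

backend = QdrantBackend()  # defaults: dim=100, cosine distance, localhost:6333

single = backend.search(np.random.random(100), limit=5)        # List[Dict]
batched = backend.search(np.random.random((4, 100)), limit=5)  # List[List[Dict]], one list per row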
diff --git a/geniml/search/backends/filebackend.py b/geniml/search/backends/filebackend.py
index b2bc8ff..49d9d92 100644
--- a/geniml/search/backends/filebackend.py
+++ b/geniml/search/backends/filebackend.py
@@ -118,11 +118,11 @@ def load(
 
         # increase max_elements to contain new loadings
         current_max = self.idx.get_max_elements()
 
-        if not ids:
+        if ids is None:
             new_max = current_max + vectors.shape[0]
             ids = np.arange(start=current_max, stop=new_max)
         else:
-            new_max = ids.amax()
+            new_max = ids.max() + 1
 
         # check if the number of embedding vectors and labels are same
         verify_load_inputs(vectors, ids, payloads)
@@ -146,7 +146,7 @@ def search(
         with_vectors: bool = True,
         offset: int = 0,
     ) -> Union[
-        List[Dict[str, Union[int, float, Dict[str, str], List[float]]]],
+        List[Dict[str, Union[int, float, Dict[str, str], np.ndarray]]],
         List[List[Dict[str, Union[int, float, Dict[str, str], np.ndarray]]]],
     ]:
     """
@@ -199,14 +199,14 @@ def __len__(self) -> int:
         return self.idx.element_count
 
-    def retrieve_info(self, ids: Union[List[int], int], with_vec: bool = False) -> Union[
+    def retrieve_info(self, ids: Union[List[int], int], with_vectors: bool = False) -> Union[
         Dict[str, Union[int, List[float], Dict[str, str]]],
         List[Dict[str, Union[int, List[float], Dict[str, str]]]],
     ]:
         """
         With an id or a list of storage ids, return the information of these vectors
 
         :param ids: storage id, or a list of ids
-        :param with_vec: whether the stored vector is included in the result
+        :param with_vectors: whether the stored vector is included in the result
         :return:
         """
         if not isinstance(ids, list):
@@ -217,7 +217,7 @@ def retrieve_info(self, ids: Union[List[int], int], with_vec: bool = False) -> U
             output_dict = {"id": id_, "payload": self.payloads[id_]}
             output_list.append(output_dict)
 
-        if with_vec:
+        if with_vectors:
             vecs = self.idx.get_items(ids, return_type="numpy")
             for i in range(len(vecs)):
                 output_list[i]["vector"] = vecs[i]
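A note on the load fix above: np.ndarray has no amax method (amax exists only as the module-level np.amax), so the old line raised an AttributeError, and since this backend assigns zero-based sequential ids, the index capacity must be the largest id plus one:

import numpy as np

ids = np.array([0, 5, 9])
new_max = ids.max() + 1  # capacity 10 is needed to hold id 9; ids.amax() would raise AttributeError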
diff --git a/geniml/search/const.py b/geniml/search/const.py
index ad712c8..ced9e24 100644
--- a/geniml/search/const.py
+++ b/geniml/search/const.py
@@ -1,12 +1,12 @@
 from qdrant_client.http import models
-from qdrant_client.models import Distance, VectorParams
+from qdrant_client.models import Distance
 
 DEFAULT_QDRANT_HOST = "localhost"
 DEFAULT_QDRANT_PORT = 6333
 
 DEFAULT_COLLECTION_NAME = "embeddings"
 
-DEFAULT_QDRANT_CONFIG = VectorParams(size=100, distance=Distance.COSINE)
+DEFAULT_QDRANT_DIST = Distance.COSINE
 
 DEFAULT_INDEX_PATH = "./current_index.bin"
@@ -14,6 +14,8 @@
 DEFAULT_DIM = 100
 
+DEFAULT_TEXT_DIM = 384
+
 # the size of the dynamic list for the nearest neighbors
 # Higher ef leads to more accurate but slower search
@@ -38,3 +40,5 @@
 HF_INDEX = "index.bin"
 HF_PAYLOADS = "payloads.pkl"
 HF_METADATA = "metadata.json"
+
+TEXT_ENCODER_REPO = "databio/bivec-search-demo"
diff --git a/geniml/search/hfdemo/__init__.py b/geniml/search/hfdemo/__init__.py
new file mode 100644
index 0000000..aa88717
--- /dev/null
+++ b/geniml/search/hfdemo/__init__.py
@@ -0,0 +1 @@
+from .bivec_demo import hf_bivec_search
diff --git a/geniml/search/hfdemo/bivec_demo.py b/geniml/search/hfdemo/bivec_demo.py
new file mode 100644
index 0000000..d651b41
--- /dev/null
+++ b/geniml/search/hfdemo/bivec_demo.py
@@ -0,0 +1,121 @@
+import json
+import os
+import tempfile
+from typing import Dict
+
+import numpy as np
+from huggingface_hub import hf_hub_download
+
+from ..backends import BiVectorBackend, HNSWBackend
+from ..const import TEXT_ENCODER_REPO
+from ..interfaces import BiVectorSearchInterface
+
+
+def load_json(json_path: str) -> Dict:
+    """
+    Load metadata stored in a json file
+
+    :param json_path: path to the json file
+    :return: the dictionary stored in the json file
+    """
+    with open(json_path, "r") as f:
+        result = json.load(f)
+    return result
+
+
+def load_vectors(npz_path: str, vec_key: str = "vectors") -> np.ndarray:
+    """
+    Load vectors stored in a .npz file
+
+    :param npz_path: path to the npz file
+    :param vec_key: storage key of the vectors in the npz file
+    :return: the stored vectors
+    """
+    data = np.load(npz_path)
+    return data[vec_key]
+
+
+def hf_bivec_search(query, repo: str = TEXT_ENCODER_REPO, limit=5, p=1.0, q=1.0, rank=True):
+    """
+    Demo search using data loaded onto a huggingface dataset
+
+    :param query: free-form query term
+    :param repo: the huggingface repository of the text encoder model
+    :param limit: see docstring of geniml.search.backends.BiVectorBackend
+    :param p:
+    :param q:
+    :param rank:
+    :return: the search result from the demo dataset on huggingface
+    """
+
+    # download files from the huggingface dataset
+    bed_embeddings_path = hf_hub_download(repo, "bed_embeddings.npz", repo_type="dataset")
+    file_id_path = hf_hub_download(repo, "file_id.json", repo_type="dataset")
+    metadata_path = hf_hub_download(repo, "file_key_metadata.json", repo_type="dataset")
+    metadata_match_path = hf_hub_download(repo, "metadata_id_match.json", repo_type="dataset")
+    text_embeddings_path = hf_hub_download(repo, "text_embeddings.npz", repo_type="dataset")
+
+    # load data from the downloaded files
+    file_id_dict = load_json(file_id_path)
+    metadata_dict = load_json(metadata_path)
+    metadata_match_dict = load_json(metadata_match_path)
+
+    bed_data = np.load(bed_embeddings_path)
+    bed_embeddings = bed_data["vectors"]
+    bed_names = list(bed_data["names"])
+
+    bed_name_idx = {value: index for index, value in enumerate(bed_names)}
+
+    text_data = np.load(text_embeddings_path)
+
+    text_embeddings = text_data["vectors"]
+    text_annotations = list(text_data["texts"])
+
+    bed_payloads = []
+    bed_vecs = []
+
+    # vectors and payloads for the BED file backend
+    for i in range(len(file_id_dict)):
+        bed_embedding_id = bed_name_idx[file_id_dict[str(i)]]
+        bed_vecs.append(bed_embeddings[bed_embedding_id])
+        bed_payloads.append(
+            {"name": file_id_dict[str(i)], "metadata": metadata_dict[file_id_dict[str(i)]]}
+        )
+
+    # payloads for the metadata backend
+    text_payloads = []
+    for annotation in text_annotations:
+        text_payloads.append(
+            {"term": annotation, "matched_files": metadata_match_dict[annotation]}
+        )
+
+    # backends in a temporary directory
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # backend for BED file embedding vectors
+        bed_backend = HNSWBackend(local_index_path=os.path.join(temp_dir, "bed.bin"), dim=100)
+        bed_backend.load(vectors=np.array(bed_vecs), payloads=bed_payloads)
+
+        # backend for metadata embedding vectors
+        text_backend = HNSWBackend(local_index_path=os.path.join(temp_dir, "text.bin"), dim=384)
+        text_backend.load(vectors=np.array(text_embeddings), payloads=text_payloads)
+
+        # combined bi-vector search backend
+        search_backend = BiVectorBackend(text_backend, bed_backend)
+
+        # search interface
+        search_interface = BiVectorSearchInterface(
+            backend=search_backend, query2vec="sentence-transformers/all-MiniLM-L6-v2"
+        )
+
+        result = search_interface.query_search(
+            query=query,
+            limit=limit,
+            with_payload=True,
+            p=p,
+            q=q,
+            with_vectors=False,
+            distance=True,  # HNSWBackend returns results by distance instead of similarity
+            rank=rank,
+        )
+
+    return result
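With the demo module in place, the whole bi-vector pipeline can be tried in a single call; this assumes network access to Hugging Face (the dataset files and the sentence-transformers encoder are downloaded on first use), and the query string here is just an example:

from geniml.search.hfdemo import hf_bivec_search

results = hf_bivec_search("human kidney CTCF", limit=5, rank=True)
for hit in results:
    # with rank=True, each hit carries "max_rank" instead of "score"
    print(hit["max_rank"], hit["payload"]["name"])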
diff --git a/geniml/search/interfaces/__init__.py b/geniml/search/interfaces/__init__.py
index bcbbf99..ff2707c 100644
--- a/geniml/search/interfaces/__init__.py
+++ b/geniml/search/interfaces/__init__.py
@@ -1,2 +1,3 @@
 from .bed2bed import BED2BEDSearchInterface
+from .mlfree import BiVectorSearchInterface
 from .text2bed import Text2BEDSearchInterface
diff --git a/geniml/search/interfaces/mlfree.py b/geniml/search/interfaces/mlfree.py
new file mode 100644
index 0000000..918aeb4
--- /dev/null
+++ b/geniml/search/interfaces/mlfree.py
@@ -0,0 +1,59 @@
+from typing import Dict, List, Union
+
+import numpy as np
+
+from ..backends import BiVectorBackend
+from ..query2vec import Text2Vec
+from .abstract import BEDSearchInterface
+
+
+class BiVectorSearchInterface(BEDSearchInterface):
+    """Search interface for the ML-free bi-vector search backend"""
+
+    def __init__(self, backend: BiVectorBackend, query2vec: Union[str, Text2Vec]) -> None:
+        """
+        Initialize the search interface
+
+        :param backend: the backend where vectors are stored
+        :param query2vec: a Text2Vec; for details, see the docstrings in geniml.search.query2vec.text2vec
+        """
+        if isinstance(query2vec, str):
+            self.query2vec = Text2Vec(query2vec, v2v=None)
+        else:
+            self.query2vec = query2vec
+        self.backend = backend
+
+    def query_search(
+        self,
+        query: Union[str, np.ndarray],
+        limit: int,
+        with_payload: bool = True,
+        with_vectors: bool = True,
+        offset: int = 0,
+        p: float = 1.0,
+        q: float = 1.0,
+        distance: bool = False,
+        rank: bool = False,
+    ) -> List[Dict]:
+        """
+        :param query: the natural language query string, or a vector in the embedding space of region sets
+
+        for the rest of the parameters, check the docstring of QdrantBackend.search() or HNSWBackend.search()
+        """
+        if isinstance(query, np.ndarray):
+            search_vec = query
+        else:
+            search_vec = self.query2vec.forward(query)
+
+        return self.backend.search(
+            query=search_vec,
+            limit=limit,
+            with_payload=with_payload,
+            with_vectors=with_vectors,
+            offset=offset,
+            p=p,
+            q=q,
+            distance=distance,
+            rank=rank,
+        )
diff --git a/geniml/search/query2vec/text2vec.py b/geniml/search/query2vec/text2vec.py
index 3ec4b0c..3b9a854 100644
--- a/geniml/search/query2vec/text2vec.py
+++ b/geniml/search/query2vec/text2vec.py
@@ -14,7 +14,7 @@
 class Text2Vec(Query2Vec):
     """Map a query string into a vector into the embedding space of region sets"""
 
-    def __init__(self, hf_repo: str, v2v: Union[str, Vec2VecFNN]):
+    def __init__(self, hf_repo: str, v2v: Union[str, Vec2VecFNN, None]):
         """
         :param text_embedder: a model repository on Hugging Face
         :param v2v: a Vec2VecFNN (see geniml/text2bednn/text2bednn.py) or a model repository on Hugging Face
@@ -25,11 +25,10 @@ def __init__(self, hf_repo: str, v2v: Union[str, Vec2VecFNN]):
         if isinstance(v2v, Vec2VecFNN):
             self.v2v = v2v
         elif isinstance(v2v, str):
             self.v2v = Vec2VecFNN(v2v)
         else:
-            _LOGGER.error(
-                "TypeError: Please give a Vec2VecFNN or a model repository on Hugging Face"
-            )
+            # for bivec search (ML-free), no vec2vec model is needed
+            self.v2v = None
 
     def forward(self, query: str) -> np.ndarray:
         """
@@ -41,5 +40,8 @@ def forward(self, query: str) -> np.ndarray:
         """
         # embed query string
         query_embedding = np.array(self.text_embedder.embed_query(query))
-        # map the query string embedding into the embedding space of region sets
-        return self.v2v.embedding_to_embedding(query_embedding)
+        if self.v2v is None:
+            return query_embedding
+        else:
+            # map the query string embedding into the embedding space of region sets
+            return self.v2v.encode(query_embedding)
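To wire the pieces manually rather than through the demo, a minimal sketch with mock vectors and hypothetical index paths (toy sizes, for illustration only; a real setup would load real region2vec and text embeddings):

import numpy as np

from geniml.search.backends import BiVectorBackend, HNSWBackend
from geniml.search.interfaces import BiVectorSearchInterface

# BED backend: 20 mock 100-dim region-set embeddings
bed_backend = HNSWBackend(local_index_path="bed.bin", dim=100)
bed_backend.load(
    vectors=np.random.random((20, 100)),
    payloads=[{"name": f"bed_{i}", "metadata": {}} for i in range(20)],
)

# metadata backend: 12 mock 384-dim text embeddings; each payload lists
# the storage ids of the BED files matching that tag ("matched_files")
text_backend = HNSWBackend(local_index_path="text.bin", dim=384)
text_backend.load(
    vectors=np.random.random((12, 384)),
    payloads=[{"term": f"tag_{i}", "matched_files": [i, i + 1]} for i in range(12)],
)

search_backend = BiVectorBackend(text_backend, bed_backend)
interface = BiVectorSearchInterface(
    backend=search_backend,
    query2vec="sentence-transformers/all-MiniLM-L6-v2",  # downloads the text encoder
)

results = interface.query_search(
    query="kidney CTCF",  # free-form query term
    limit=5,
    with_payload=True,
    with_vectors=False,
    distance=True,  # HNSWBackend scores are distances, not similarities
    rank=True,
)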
["blood"]}, + "ENCQ1A": {"biosample": "K562", "target": "H3K27ac", "organ": ["blood"]}, + "ENCM2F": {"biosample": "K562", "target": "CTCF", "organ": ["blood"]}, + "ENCKMR": {"biosample": "apple"}, + } @pytest.fixture -def metadata(): +def annotations(): + return [ + "HEK293", + "A549", + "K562", + "H3K27ac", + "CTCF", + "TBP", + "kidney", + "epithelium", + "lung", + "blood", + "apple", + ] + + +@pytest.fixture +def annotation_matches(): + return { + "HEK293": [0, 1, 2], + "A549": [3, 4], + "K562": [5, 6, 7], + "H3K27ac": [0, 3, 5], + "CTCF": [1, 4, 7], + "TBP": [2, 6], + "kidney": [0, 2], + "epithelium": [0, 2, 3], + "lung": [3, 4], + "blood": [5, 6, 7], + "apple": [8], + } + + +@pytest.fixture +def uuids(): + return [ + "7bbab414-053d-4c06-9085-d3ca894dc8b8", + "8b3fa142-8866-4b4c-9df8-7734b4ef9f2a", + "478c9b96-3b4c-41c3-af56-68e8c39de0a3", + "971c58a5-c126-433b-887c-4184184cbce6", + "b28381f8-82ce-4b19-86c0-34a2d368e3b3", + "d6f1060e-6e14-4faf-8711-f25ed5c6618e", + "920ef6f6-f821-46f9-9d11-3516119feeec", + "6ec9d4a4-a481-43dc-81f3-098953c77b0a", + "ce5345d8-84a1-4c6c-9427-145a3f207805", + ] + + +@pytest.fixture +def bed_payloads(filenames, metadata): """ - mock metadata for testing + mock list of label dictionaries for testing """ - return "This is a mock metadata, just for testing." + output_list = [] + for name in filenames: + output_list.append({"name": name, "metadata": metadata[name]}) + return output_list @pytest.fixture -def labels(filenames, metadata): +def metadata_payloads(annotations, annotation_matches): """ mock list of label dictionaries for testing """ output_list = [] - for name in filenames: - output_list.append({"name": name, "metadata": metadata}) + for tag in annotations: + output_list.append({"text": tag, "matched_files": annotation_matches[tag]}) return output_list @pytest.fixture -def collection(): +def bed_collection(): """ - collection name for qdrant client storage + Returns: bed_collection name for qdrant client storage """ return "hg38_sample" @pytest.fixture -def ids(filenames): +def metadata_collection(): + """ + Returns: bed_collection name for qdrant client storage + """ + + return "bed_metadata" + + +@pytest.fixture +def bed_embeddings(filenames): + """ + mock embedding vectors for testing + """ + + return np.random.random((len(filenames), 100)) + + +@pytest.fixture +def text_embeddings(annotations): """ - list of randomly sampled ids + mock embedding vectors for testing """ - random.seed(100) - return random.sample(range(len(filenames)), 5) + return np.random.random((len(annotations), 384)) + + +@pytest.fixture +def int_ids(filenames): + """ + list of randomly sampled integer_ids + """ + return random.sample(range(len(filenames)), 3) + + +@pytest.fixture +def ids(uuids): + ids_with_hyphen = random.sample(uuids, 3) + return [uuid.replace("-", "") for uuid in ids_with_hyphen] # @pytest.fixture @@ -97,15 +203,27 @@ def temp_data_dir(tmp_path_factory): @pytest.fixture(scope="module") -def temp_idx_path(temp_data_dir): +def temp_bed_idx_path(temp_data_dir): # temporal index path return temp_data_dir / "testing_idx.bin" @pytest.fixture(scope="module") -def hnswb(temp_idx_path): +def temp_metadata_idx_path(temp_data_dir): + # temporal index path + return temp_data_dir / "testing_metadata_idx.bin" + + +@pytest.fixture(scope="module") +def bed_hnswb(temp_bed_idx_path): # init backend - return HNSWBackend(local_index_path=str(temp_idx_path)) + return HNSWBackend(local_index_path=str(temp_bed_idx_path)) + + +@pytest.fixture(scope="module") +def 
+def metadata_hnswb(temp_metadata_idx_path):
+    # init backend
+    return HNSWBackend(local_index_path=str(temp_metadata_idx_path), dim=384)
 
 
 @pytest.fixture
@@ -132,15 +250,6 @@ def v2v_hf_repo():
     return "databio/v2v-sentencetransformers-encode"
 
 
-@pytest.fixture
-def collection():
-    """
-    Returns: collection name for qdrant client storage
-    """
-
-    return "hg38_sample"
-
-
 @pytest.fixture
 def query_term():
     """
@@ -165,78 +274,114 @@ def query_bed():
     return "./data/s1_a.bed"
 
 
+def cosine_similarity(vec1: np.array, vec2: np.array) -> float:
+    # Ensure the vectors have shape (100,)
+    assert vec1.shape == (100,) and vec2.shape == (100,), "Both vectors must have shape (100,)"
+
+    # Compute the dot product of the two vectors
+    dot_product = np.dot(vec1, vec2)
+
+    # Compute the magnitude (L2 norm) of each vector
+    magnitude_vec1 = np.linalg.norm(vec1)
+    magnitude_vec2 = np.linalg.norm(vec2)
+
+    # Compute the cosine similarity
+    if magnitude_vec1 == 0 or magnitude_vec2 == 0:
+        return 0.0  # Avoid division by zero
+
+    cosine_sim = dot_product / (magnitude_vec1 * magnitude_vec2)
+
+    return cosine_sim
+
+
 @pytest.mark.skipif(
     "not config.getoption('--qdrant')",
     reason="Only run when --qdrant is given",
 )
-def test_QdrantBackend(filenames, embeddings, labels, collection, ids):
-    qd_search_backend = QdrantBackend(collection=collection)
+def test_QdrantBackend(filenames, bed_embeddings, bed_payloads, bed_collection, ids, uuids):
+    def search_results_test(search_results):
+        assert isinstance(search_results, list)
+        for result in search_results:
+            assert isinstance(result, dict)  # only target pairs
+            assert isinstance(result["id"], str)
+            assert isinstance(result["score"], float)
+
+            assert isinstance(result["vector"], list)
+            for i in result["vector"]:
+                assert isinstance(i, float)
+            assert isinstance(result["payload"], dict)
+            assert isinstance(result["payload"]["name"], str)
+            assert isinstance(result["payload"]["metadata"], dict)
+
+    qd_search_backend = QdrantBackend(collection=bed_collection)
     # load data
-    qd_search_backend.load(embeddings, payloads=labels)
+    qd_search_backend.load(bed_embeddings, payloads=bed_payloads, ids=uuids)
 
     # test searching
+    query_vec = np.random.random(
+        100,
+    )
     search_results = qd_search_backend.search(
-        np.random.random(
-            100,
-        ),
+        query_vec,
         5,
         with_payload=True,
         with_vectors=True,
     )
-    assert isinstance(search_results, list)
-    for result in search_results:
-        assert isinstance(result, dict)  # only target pairs
-        assert isinstance(result["id"], int)
-        assert isinstance(result["score"], float)
-        assert isinstance(result["vector"], list)
-        for i in result["vector"]:
-            assert isinstance(i, float)
-        assert isinstance(result["payload"], dict)
-        assert isinstance(result["payload"]["name"], str)
-        assert isinstance(result["payload"]["metadata"], str)
+
+    search_results_test(search_results)
+
     assert len(qd_search_backend) == len(filenames)
 
     # test information retrieval
     retrieval_results = qd_search_backend.retrieve_info(ids, True)
+
     assert len(retrieval_results) == len(ids)
     assert isinstance(retrieval_results, list)
     for i in range(len(ids)):
         assert ids[i] == retrieval_results[i]["id"]
         client_retrieval = qd_search_backend.qd_client.retrieve(
-            collection, [ids[i]], with_vectors=True
+            bed_collection, [ids[i]], with_vectors=True
         )
         assert retrieval_results[i]["vector"] == client_retrieval[0].vector
         assert retrieval_results[i]["payload"] == client_retrieval[0].payload
+
+    # test batch search
+    batch_query = np.random.random((6, 100))
+    batch_result = qd_search_backend.search(
+        batch_query, limit=3, with_payload=True, with_vectors=True
+    )
+    for batch in batch_result:
+        search_results_test(batch)
+
     qd_search_backend.qd_client.delete_collection(qd_search_backend.collection)
 
 
 @pytest.mark.skipif(
     DEP_HNSWLIB == False, reason="This test require installation of hnswlib (optional)"
 )
-def test_HNSWBackend_load(filenames, embeddings, labels, hnswb, ids):
+def test_HNSWBackend_load(filenames, bed_embeddings, bed_payloads, bed_hnswb, ids):
     num_upload = len(filenames)
 
     # batches to load
-    labels_1 = labels[: num_upload // 2]
-    labels_2 = labels[num_upload // 2 :]
-    embeddings_1 = embeddings[: num_upload // 2]
-    embeddings_2 = embeddings[num_upload // 2 :]
+    labels_1 = bed_payloads[: num_upload // 2]
+    labels_2 = bed_payloads[num_upload // 2 :]
+    embeddings_1 = bed_embeddings[: num_upload // 2]
+    embeddings_2 = bed_embeddings[num_upload // 2 :]
 
     # load first batch
-    hnswb.load(embeddings_1, payloads=labels_1)
-    assert len(hnswb) == num_upload // 2
+    bed_hnswb.load(embeddings_1, payloads=labels_1)
+    assert len(bed_hnswb) == num_upload // 2
 
     # load second batch
-    hnswb.load(embeddings_2, payloads=labels_2)
-    assert len(hnswb) == num_upload
-    # pytestconfig.cache.set('shared_backend', hnswb)
+    bed_hnswb.load(embeddings_2, payloads=labels_2)
+    assert len(bed_hnswb) == num_upload
 
 
 @pytest.mark.skipif(
     DEP_HNSWLIB == False, reason="This test require installation of hnswlib (optional)"
 )
-# @pytest.mark.dependency(depends=["test_HNSWBackend_load"])
-def test_HNSWBackend_search(filenames, hnswb, ids):
+@pytest.mark.dependency(depends=["test_HNSWBackend_load"])
+def test_HNSWBackend_search(filenames, bed_hnswb, int_ids):
     def search_result_check(dict_list: List[Dict], backend: HNSWBackend, with_dist: bool = False):
         """
         repeated test of the output of search / retrieve_info function of HNSWBackend
         to check if the result matches the content in index
@@ -248,6 +393,7 @@ def search_result_check(dict_list: List[Dict], backend: HNSWBackend, with_dist:
         """
         index = backend.idx
         assert isinstance(dict_list, list)
+
         for result in dict_list:
             assert isinstance(result, dict)
             assert isinstance(result["id"], int)
@@ -255,25 +401,25 @@ def search_result_check(dict_list: List[Dict], backend: HNSWBackend, with_dist:
             assert isinstance(result["distance"], float)
             assert isinstance(result["payload"], dict)
             assert isinstance(result["vector"], np.ndarray)
-            # assert result["vector"] == index.get_items([result["id"]])[0]
+
+            assert (
+                result["vector"] == index.get_items([result["id"]], return_type="numpy")[0]
+            ).all()
             for num in result["vector"]:
                 assert isinstance(num, np.float32)
 
-    # hnswb = pytestconfig.cache.get('shared_backend', None)
-    assert len(hnswb) == len(filenames)
+    # bed_hnswb = pytestconfig.cache.get('shared_backend', None)
+    assert len(bed_hnswb) == len(filenames)
 
     # test searching with one vector (np.ndarray with shape (dim,))
     query_vec = np.random.random(
         100,
     )
-    single_vec_search = hnswb.search(
+    single_vec_search = bed_hnswb.search(
         query_vec,
         3,
     )
-    single_vec_search_offset = hnswb.search(
+    single_vec_search_offset = bed_hnswb.search(
         query_vec,
         3,
         offset=2,
@@ -286,32 +432,32 @@ def search_result_check(dict_list: List[Dict], backend: HNSWBackend, with_dist:
             single_vec_search_offset[j]["payload"]["metadata"]
             == single_vec_search[j]["payload"]["metadata"]
         )
-    search_result_check(single_vec_search, hnswb, True)
-    search_result_check(single_vec_search_offset, hnswb, True)
+    search_result_check(single_vec_search, bed_hnswb, True)
+    search_result_check(single_vec_search_offset, bed_hnswb, True)
 
     # test searching with multiple vectors (np.ndarray with shape (n, dim))
-    multiple_vecs_search = hnswb.search(np.random.random((7, 100)), 5)
+    multiple_vecs_search = bed_hnswb.search(np.random.random((7, 100)), 5)
     assert isinstance(multiple_vecs_search, list)
     assert len(multiple_vecs_search) == 7
     for i in range(len(multiple_vecs_search)):
-        search_result_check(multiple_vecs_search[i], hnswb, True)
+        search_result_check(multiple_vecs_search[i], bed_hnswb, True)
 
     # test information retrieval / get items
-    retrieval_results = hnswb.retrieve_info(ids, True)
-    search_result_check(retrieval_results, hnswb, False)
+    retrieval_results = bed_hnswb.retrieve_info(int_ids, True)
+    search_result_check(retrieval_results, bed_hnswb, False)
 
 
 @pytest.mark.skipif(
     DEP_HNSWLIB == False, reason="This test require installation of hnswlib (optional)"
 )
-# @pytest.mark.dependency(depends=["test_HNSWBackend_load"])
-def test_HNSWBackend_save(filenames, hnswb, embeddings, temp_idx_path, temp_data_dir):
+@pytest.mark.dependency(depends=["test_HNSWBackend_load"])
+def test_HNSWBackend_save(filenames, bed_hnswb, bed_embeddings, temp_bed_idx_path, temp_data_dir):
     # test saving from local
-    new_hnswb = HNSWBackend(local_index_path=str(temp_idx_path), payloads=hnswb.payloads)
-    assert new_hnswb.idx.max_elements == embeddings.shape[0]
+    new_hnswb = HNSWBackend(local_index_path=str(temp_bed_idx_path), payloads=bed_hnswb.payloads)
+    assert new_hnswb.idx.max_elements == bed_embeddings.shape[0]
 
-    for i in range(embeddings.shape[0]):
-        old_result = hnswb.idx.get_items([i], return_type="numpy")
+    for i in range(bed_embeddings.shape[0]):
+        old_result = bed_hnswb.idx.get_items([i], return_type="numpy")
         new_result = new_hnswb.idx.get_items([i], return_type="numpy")
         assert (old_result == new_result).all()
 
@@ -321,6 +467,76 @@ def test_HNSWBackend_save(filenames, bed_hnswb, bed_embeddings, temp_bed_idx_path, temp_data
     assert len(empty_hnswb.payloads) == 0
 
 
+@pytest.mark.skipif(
+    DEP_HNSWLIB == False, reason="This test requires installation of hnswlib (optional)"
+)
+@pytest.mark.dependency(depends=["test_HNSWBackend_load"])
+@pytest.mark.skipif(
+    "not config.getoption('--qdrant')",
+    reason="Only run when --qdrant is given",
+)
+def test_QD_BiVectorBackend(
+    bed_hnswb,
+    metadata_hnswb,
+    bed_collection,
+    bed_embeddings,
+    bed_payloads,
+    metadata_collection,
+    text_embeddings,
+    metadata_payloads,
+):
+    def bivec_test(bivec_backend, dist: bool = False, rank: bool = False):
+        query_vec = np.random.random(
+            384,
+        )
+        search_results = bivec_backend.search(
+            query_vec, 2, with_payload=True, with_vectors=True, distance=dist, rank=rank
+        )
+        assert isinstance(search_results, list)
+        min_score = 100.0
+        max_rank = -1
+        for result in search_results:
+            assert isinstance(result, dict)  # only target pairs
+            assert isinstance(result["id"], int)
+
+            if not rank:
+                assert isinstance(result["score"], float)
+                assert result["score"] <= min_score
+                min_score = result["score"]
+            else:
+                assert isinstance(result["max_rank"], int)
+                assert result["max_rank"] >= max_rank
+                max_rank = result["max_rank"]
+
+            assert isinstance(result["vector"], list) or isinstance(result["vector"], np.ndarray)
+            if isinstance(result["vector"], list):
+                for i in result["vector"]:
+                    assert isinstance(i, float)
+            assert isinstance(result["payload"], dict)
+            assert isinstance(result["payload"]["name"], str)
+            assert isinstance(result["payload"]["metadata"], dict)
+
+    # test QdrantBackend
+    bed_backend = QdrantBackend(collection=bed_collection)
+    # load data
+    bed_backend.load(bed_embeddings, payloads=bed_payloads)
+
+    text_backend = QdrantBackend(collection=metadata_collection, dim=384)
+    text_backend.load(text_embeddings, payloads=metadata_payloads)
+
+    bivec_qd_backend = BiVectorBackend(text_backend, bed_backend)
+    bivec_test(bivec_qd_backend, rank=True)
+    bivec_test(bivec_qd_backend, rank=False)
+    bivec_qd_backend.metadata_backend.qd_client.delete_collection(text_backend.collection)
+    bivec_qd_backend.bed_backend.qd_client.delete_collection(bed_backend.collection)
+
+    # test HNSWBackend
+    metadata_hnswb.load(text_embeddings, payloads=metadata_payloads)
+    bivec_hnsw_backend = BiVectorBackend(metadata_hnswb, bed_hnswb)
+    bivec_test(bivec_hnsw_backend, dist=True, rank=True)
+    bivec_test(bivec_hnsw_backend, dist=True, rank=False)
+
+
 @pytest.mark.skipif(
     "not config.getoption('--huggingface')",
     reason="Only run when --huggingface is given",
@@ -353,7 +569,7 @@ def test_text2bed_search_interface(
     r2v_hf_repo,
     nl_embed_repo,
     v2v_hf_repo,
-    collection,
+    bed_collection,
     query_term,
     tmp_path_factory,
 ):
@@ -373,7 +589,7 @@ def test_text2bed_search_interface(
 
     vecs = np.array(vecs)
 
-    qd_search_backend = QdrantBackend(collection=collection)
+    qd_search_backend = QdrantBackend(collection=bed_collection)
     qd_search_backend.load(vectors=vecs, payloads=payloads)
 
     # # text2vec = Text2Vec(nl_embed_repo, v2v_hf_repo)
@@ -389,8 +605,8 @@ def test_text2bed_search_interface(
     assert eval_results["Mean AUC-ROC"] > 0
     assert eval_results["Average R-Precision"] > 0
 
-    # delete testing collection
-    db_interface.backend.qd_client.delete_collection(collection_name=collection)
+    # delete the testing bed_collection
+    db_interface.backend.qd_client.delete_collection(collection_name=bed_collection)
 
     # construct a search interface with file backend
     temp_data_dir = tmp_path_factory.mktemp("data")
@@ -421,7 +637,7 @@ def test_bed2bed_search_interface(
 def test_bed2bed_search_interface(
     bed_folder,
     r2v_hf_repo,
-    collection,
+    bed_collection,
     query_bed,
     tmp_path_factory,
 ):
@@ -439,7 +655,7 @@ def test_bed2bed_search_interface(
 
     vecs = np.array(vecs)
 
-    qd_search_backend = QdrantBackend(collection=collection)
+    qd_search_backend = QdrantBackend(collection=bed_collection)
     qd_search_backend.load(vectors=vecs, payloads=payloads)
 
     bed2vec = BED2Vec(r2v_hf_repo)
@@ -450,8 +666,8 @@ def test_bed2bed_search_interface(
     for i in range(len(db_search_result)):
         assert isinstance(db_search_result[i], dict)
 
-    # delete testing collection
-    db_interface.backend.qd_client.delete_collection(collection_name=collection)
+    # delete the testing bed_collection
+    db_interface.backend.qd_client.delete_collection(collection_name=bed_collection)
 
     # construct a search interface with file backend
     temp_data_dir = tmp_path_factory.mktemp("data")
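To run the new tests locally: the Qdrant-backed cases are gated behind the --qdrant option and the HNSW cases behind the optional hnswlib install, so, assuming the repository's conftest defines that option and a local Qdrant server is running on the default port, something like `pytest tests/test_search.py --qdrant` exercises test_QD_BiVectorBackend end to end.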