Skip to content

Commit

Permalink
refactor: search from storages
Browse files Browse the repository at this point in the history
  • Loading branch information
bathienle committed Aug 27, 2024
1 parent c064321 commit 570a303
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 8 deletions.
23 changes: 15 additions & 8 deletions cbir/api/searches.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Search API"""

from typing import List

from fastapi import (
APIRouter,
Depends,
Expand All @@ -10,7 +12,7 @@
)
from fastapi.responses import JSONResponse

from cbir.api.utils.utils import get_retrieval
from cbir.api.utils.utils import get_retrievals
from cbir.retrieval.retrieval import ImageRetrieval

router = APIRouter()
Expand All @@ -21,9 +23,9 @@ async def retrieve_image(
request: Request,
image: UploadFile,
nrt_neigh: int = Query(...),
storage_name: str = Query(..., alias="storage"),
storage_names: List[str] = Query(..., alias="storage"),
index_name: str = Query(default="index", alias="index"),
retrieval: ImageRetrieval = Depends(get_retrieval),
retrievals: List[ImageRetrieval] = Depends(get_retrievals),
) -> JSONResponse:
"""
Search for similar images from the index.
Expand All @@ -32,26 +34,31 @@ async def retrieve_image(
request (Request): The incoming HTTP request.
image (UploadFile): The query image.
nrt_neigh (int): The number of nearest neighbors to retrieve.
storage_name (str): The name of the storage where the index is stored.
storage_names (List[str]): The list of storage names.
index_name (str): The name of the index where the image features will be added.
retrieval (ImageRetrieval): The image retrieval object.
retrievals (List[ImageRetrieval]): The image retrieval object.
Returns:
JSONResponse: A JSON containing the list of similarities.
"""

if not storage_name:
if not storage_names:
raise HTTPException(status_code=404, detail="Storage is required")

model = request.app.state.model
content = await image.read()

similarities = retrieval.search(model, content, nrt_neigh)
similarities = [
result
for retrieval in retrievals
for result in retrieval.search(model, content, nrt_neigh)
]
similarities = sorted(similarities, key=lambda x: x[1])[:nrt_neigh]

return JSONResponse(
content={
"query": image.filename,
"storage": storage_name,
"storage": storage_names,
"index": index_name,
"similarities": similarities,
}
Expand Down
67 changes: 67 additions & 0 deletions cbir/api/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utility functions for dependency injection."""

from typing import List

from fastapi import Depends, Query
from redis import Redis # type: ignore

Expand Down Expand Up @@ -29,6 +31,25 @@ def get_store(
return Store(storage_name, redis, index_name)


def get_stores(
storage_names: List[str] = Query(..., alias="storage"),
index_name: str = Query(default="index", alias="index"),
redis: Redis = Depends(get_redis),
) -> List[Store]:
"""
Instantiate a list of Store objects based on the provided storage names.
Args:
storage_names (List[str]): The names of the storages.
index_name (str): The name of the index.
redis (Redis): An instance of the Redis client.
Returns:
List[Store]: A list of Store instances.
"""
return [Store(storage_name, redis, index_name) for storage_name in storage_names]


def get_indexer(
storage_name: str = Query(..., alias="storage"),
index_name: str = Query(default="index", alias="index"),
Expand All @@ -54,6 +75,35 @@ def get_indexer(
)


def get_indexers(
storage_names: List[str] = Query(..., alias="storage"),
index_name: str = Query(default="index", alias="index"),
settings: Settings = Depends(get_settings),
) -> List[Indexer]:
"""
Instantiate a list of Indexer objects based on the provided storage names.
Args:
storage_names (List[str]): The names of the storages.
index_name (str): The name of the index.
settings (Settings): The app settings.
Returns:
List[Indexer]: A list of Indexer instances.
"""
device = settings.device.type == "cuda"
return [
Indexer(
settings.data_path,
storage_name,
index_name,
settings.n_features,
device,
)
for storage_name in storage_names
]


def get_retrieval(
store: Store = Depends(get_store),
indexer: Indexer = Depends(get_indexer),
Expand All @@ -69,3 +119,20 @@ def get_retrieval(
ImageRetrieval: An instance of the ImageRetrieval.
"""
return ImageRetrieval(store, indexer)


def get_retrievals(
stores: List[Store] = Depends(get_stores),
indexers: List[Indexer] = Depends(get_indexers),
) -> List[ImageRetrieval]:
"""
Instantiate a list of ImageRetrieval objects, one for each store and indexer.
Args:
stores (List[Store]): A list of Store instances.
indexers (List[Indexer]): A list of Indexer instances.
Returns:
List[ImageRetrieval]: A list of ImageRetrieval instances.
"""
return [ImageRetrieval(store, indexer) for store, indexer in zip(stores, indexers)]

0 comments on commit 570a303

Please sign in to comment.