diff --git a/cbir/api/image.py b/cbir/api/image.py index 0395a82..62545c9 100644 --- a/cbir/api/image.py +++ b/cbir/api/image.py @@ -58,6 +58,18 @@ async def index_image(request: Request, image: UploadFile = File()) -> None: ) +@router.delete("/images/remove") +def remove_image(request: Request, filename: str) -> None: + """Remove an indexed image.""" + + database = request.app.state.database + + if not database.contains(filename): + raise HTTPException(status_code=404, detail=f"{filename} not found") + + database.remove(filename) + + @router.post("/images/retrieve") async def retrieve_image( request: Request, diff --git a/cbir/retrieval/database.py b/cbir/retrieval/database.py index cc7f7ce..f4268de 100644 --- a/cbir/retrieval/database.py +++ b/cbir/retrieval/database.py @@ -54,6 +54,11 @@ def __init__( self.resources = faiss.StandardGpuResources() self.index = faiss.index_cpu_to_gpu(self.resources, 0, self.index) + def contains(self, name: str) -> bool: + """Check if a filename is in the index database.""" + + return self.redis.get(name) is not None + def save(self) -> None: """Save the index to the file.""" @@ -77,22 +82,22 @@ def add(self, images: torch.Tensor, names: List[str]) -> None: def remove(self, name: str) -> None: """Remove an image from the index database.""" + key = self.redis.get(name).decode("utf-8") label = int(key) - id_selector = faiss.IDSelectorRange(label, label + 1) + self.index = faiss.index_gpu_to_cpu(self.index) if self.gpu else self.index + self.index.remove_ids(id_selector) + if self.gpu: - self.index = faiss.index_gpu_to_cpu(self.index) + self.resources = faiss.StandardGpuResources() + self.index = faiss.index_cpu_to_gpu(self.resources, 0, self.index) - self.index.remove_ids(id_selector) self.save() self.redis.delete(key) self.redis.delete(name) - if self.gpu: - self.index = faiss.index_cpu_to_gpu(self.resources, 0, self.index) - def index_image(self, model: Model, image: torch.Tensor, filename: str) -> None: """Index an image.""" diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py index 8044f63..8087a9e 100644 --- a/tests/test_retrieval.py +++ b/tests/test_retrieval.py @@ -27,27 +27,48 @@ def test_index_image() -> None: database_settings = DatabaseSetting.get_settings() - with open("tests/data/image.png", "rb") as image: - files = {"image": image.read()} + with TestClient(app) as client, open("tests/data/image.png", "rb") as image: + response = client.post( + "/api/images/index", + files={"image": image}, + ) + + assert response.status_code == 200 + assert os.path.isfile(database_settings.filename) is True + + +def test_remove_image_not_found() -> None: + """Test remove an image that do not exist in the database.""" with TestClient(app) as client: - response = client.post("/api/images/index", files=files) + response = client.delete( + "/api/images/remove", + params={"filename": "notfound.png"}, + ) + + assert response.status_code == 404 + + +def test_remove_image() -> None: + """Test remove an image from the database.""" + + with TestClient(app) as client: + response = client.delete( + "/api/images/remove", + params={"filename": "image.png"}, + ) assert response.status_code == 200 - assert os.path.isfile(database_settings.filename) is True def test_retrieve_image() -> None: """Test image retrieval.""" - with open("tests/data/image.png", "rb") as image: - files = {"image": image.read()} - - with TestClient(app) as client: + with TestClient(app) as client, open("tests/data/image.png", "rb") as image: response = client.post( "/api/images/retrieve", data={"nrt_neigh": "10"}, - files=files, + files={"image": image}, ) data = response.json()