From 709fab7c5509001991302fd6b9262a2454609fde Mon Sep 17 00:00:00 2001 From: Ba Thien Le Date: Fri, 15 Dec 2023 11:41:59 +0100 Subject: [PATCH 1/6] feat: add remove endpoint --- cbir/api/image.py | 12 ++++++++++++ cbir/retrieval/database.py | 15 ++++++++------- tests/test_retrieval.py | 24 ++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 7 deletions(-) diff --git a/cbir/api/image.py b/cbir/api/image.py index 0395a82..2cd5649 100644 --- a/cbir/api/image.py +++ b/cbir/api/image.py @@ -58,6 +58,18 @@ async def index_image(request: Request, image: UploadFile = File()) -> None: ) +@router.get("/images/remove") +def remove_image(request: Request, filename: str) -> None: + """Remove an indexed image.""" + + database = request.app.state.database + + if not database.contains(filename): + raise HTTPException(status_code=404, detail="Filename not found") + + database.remove(filename) + + @router.post("/images/retrieve") async def retrieve_image( request: Request, diff --git a/cbir/retrieval/database.py b/cbir/retrieval/database.py index cc7f7ce..a92808d 100644 --- a/cbir/retrieval/database.py +++ b/cbir/retrieval/database.py @@ -54,6 +54,11 @@ def __init__( self.resources = faiss.StandardGpuResources() self.index = faiss.index_cpu_to_gpu(self.resources, 0, self.index) + def contains(self, name: str) -> bool: + """Check if a filename is in the index database.""" + + return self.redis.get(name) is not None + def save(self) -> None: """Save the index to the file.""" @@ -77,22 +82,18 @@ def add(self, images: torch.Tensor, names: List[str]) -> None: def remove(self, name: str) -> None: """Remove an image from the index database.""" + key = self.redis.get(name).decode("utf-8") label = int(key) - id_selector = faiss.IDSelectorRange(label, label + 1) - if self.gpu: - self.index = faiss.index_gpu_to_cpu(self.index) + index = faiss.index_gpu_to_cpu(self.index) if self.gpu else self.index + index.remove_ids(id_selector) - self.index.remove_ids(id_selector) self.save() self.redis.delete(key) self.redis.delete(name) - if self.gpu: - self.index = faiss.index_cpu_to_gpu(self.resources, 0, self.index) - def index_image(self, model: Model, image: torch.Tensor, filename: str) -> None: """Index an image.""" diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py index 8044f63..29e8c22 100644 --- a/tests/test_retrieval.py +++ b/tests/test_retrieval.py @@ -37,6 +37,30 @@ def test_index_image() -> None: assert os.path.isfile(database_settings.filename) is True +def test_remove_image_not_found() -> None: + """Test remove an image that do not exist in the database.""" + + with TestClient(app) as client: + response = client.get("/api/images/remove", params={"filename": "notfound.png"}) + + assert response.status_code == 404 + + +def test_remove_image() -> None: + """Test remove an image from the database.""" + + with open("tests/data/image.png", "rb") as image: + files = {"image": image.read()} + + with TestClient(app) as client: + response = client.post("/api/images/index", files=files) + + with TestClient(app) as client: + response = client.get("/api/images/remove", params={"filename": "image.png"}) + + assert response.status_code == 200 + + def test_retrieve_image() -> None: """Test image retrieval.""" From 961eefdf0cfcf4efeebc75e9f7e0653556231513 Mon Sep 17 00:00:00 2001 From: Ba Thien Le Date: Mon, 22 Jan 2024 10:32:53 +0100 Subject: [PATCH 2/6] fix: set index to gpu if needed --- cbir/retrieval/database.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cbir/retrieval/database.py b/cbir/retrieval/database.py index a92808d..f4268de 100644 --- a/cbir/retrieval/database.py +++ b/cbir/retrieval/database.py @@ -87,8 +87,12 @@ def remove(self, name: str) -> None: label = int(key) id_selector = faiss.IDSelectorRange(label, label + 1) - index = faiss.index_gpu_to_cpu(self.index) if self.gpu else self.index - index.remove_ids(id_selector) + self.index = faiss.index_gpu_to_cpu(self.index) if self.gpu else self.index + self.index.remove_ids(id_selector) + + if self.gpu: + self.resources = faiss.StandardGpuResources() + self.index = faiss.index_cpu_to_gpu(self.resources, 0, self.index) self.save() self.redis.delete(key) From 432a88572ae311e6fa41f8685abe167c54f2f52a Mon Sep 17 00:00:00 2001 From: Ba Thien Le Date: Mon, 22 Jan 2024 15:12:53 +0100 Subject: [PATCH 3/6] fix: tests for deletion --- tests/test_retrieval.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py index 29e8c22..7ccfba8 100644 --- a/tests/test_retrieval.py +++ b/tests/test_retrieval.py @@ -27,11 +27,11 @@ def test_index_image() -> None: database_settings = DatabaseSetting.get_settings() - with open("tests/data/image.png", "rb") as image: - files = {"image": image.read()} - with TestClient(app) as client: - response = client.post("/api/images/index", files=files) + response = client.post( + "/api/images/index", + files={"image": open("tests/data/image.png", "rb")}, + ) assert response.status_code == 200 assert os.path.isfile(database_settings.filename) is True @@ -41,7 +41,10 @@ def test_remove_image_not_found() -> None: """Test remove an image that do not exist in the database.""" with TestClient(app) as client: - response = client.get("/api/images/remove", params={"filename": "notfound.png"}) + response = client.delete( + "/api/images/remove", + params={"filename": "notfound.png"}, + ) assert response.status_code == 404 @@ -49,14 +52,11 @@ def test_remove_image_not_found() -> None: def test_remove_image() -> None: """Test remove an image from the database.""" - with open("tests/data/image.png", "rb") as image: - files = {"image": image.read()} - with TestClient(app) as client: - response = client.post("/api/images/index", files=files) - - with TestClient(app) as client: - response = client.get("/api/images/remove", params={"filename": "image.png"}) + response = client.delete( + "/api/images/remove", + params={"filename": "image.png"}, + ) assert response.status_code == 200 @@ -64,14 +64,11 @@ def test_remove_image() -> None: def test_retrieve_image() -> None: """Test image retrieval.""" - with open("tests/data/image.png", "rb") as image: - files = {"image": image.read()} - with TestClient(app) as client: response = client.post( "/api/images/retrieve", data={"nrt_neigh": "10"}, - files=files, + files={"image": open("tests/data/image.png", "rb")}, ) data = response.json() From a140e9b1fb9f1ddc13301d08b09fa42152c7d894 Mon Sep 17 00:00:00 2001 From: Ba Thien Le Date: Mon, 22 Jan 2024 15:14:00 +0100 Subject: [PATCH 4/6] refactor: change GET to DELETE --- cbir/api/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cbir/api/image.py b/cbir/api/image.py index 2cd5649..363b240 100644 --- a/cbir/api/image.py +++ b/cbir/api/image.py @@ -58,7 +58,7 @@ async def index_image(request: Request, image: UploadFile = File()) -> None: ) -@router.get("/images/remove") +@router.delete("/images/remove") def remove_image(request: Request, filename: str) -> None: """Remove an indexed image.""" From c356456b0762ad55da18a953d084a6b8dc1c0e64 Mon Sep 17 00:00:00 2001 From: Ba Thien Le Date: Mon, 22 Jan 2024 15:16:28 +0100 Subject: [PATCH 5/6] feat: display the filename itself --- cbir/api/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cbir/api/image.py b/cbir/api/image.py index 363b240..62545c9 100644 --- a/cbir/api/image.py +++ b/cbir/api/image.py @@ -65,7 +65,7 @@ def remove_image(request: Request, filename: str) -> None: database = request.app.state.database if not database.contains(filename): - raise HTTPException(status_code=404, detail="Filename not found") + raise HTTPException(status_code=404, detail=f"{filename} not found") database.remove(filename) From 177779aec95b7cf7b5353054c147d5ad41c7a687 Mon Sep 17 00:00:00 2001 From: Ba Thien Le Date: Mon, 22 Jan 2024 15:26:00 +0100 Subject: [PATCH 6/6] style: use with statement --- tests/test_retrieval.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py index 7ccfba8..8087a9e 100644 --- a/tests/test_retrieval.py +++ b/tests/test_retrieval.py @@ -27,10 +27,10 @@ def test_index_image() -> None: database_settings = DatabaseSetting.get_settings() - with TestClient(app) as client: + with TestClient(app) as client, open("tests/data/image.png", "rb") as image: response = client.post( "/api/images/index", - files={"image": open("tests/data/image.png", "rb")}, + files={"image": image}, ) assert response.status_code == 200 @@ -64,11 +64,11 @@ def test_remove_image() -> None: def test_retrieve_image() -> None: """Test image retrieval.""" - with TestClient(app) as client: + with TestClient(app) as client, open("tests/data/image.png", "rb") as image: response = client.post( "/api/images/retrieve", data={"nrt_neigh": "10"}, - files={"image": open("tests/data/image.png", "rb")}, + files={"image": image}, ) data = response.json()