From 632ecdc4bea84e5e9da31657e2e829f0116fb1b1 Mon Sep 17 00:00:00 2001 From: Samarpan Harit Date: Wed, 18 Dec 2024 01:25:40 +0530 Subject: [PATCH] Add tests for label repository --- todo/tests/fixtures/label.py | 2 + .../repositories/test_label_repository.py | 42 +++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 todo/tests/unit/repositories/test_label_repository.py diff --git a/todo/tests/fixtures/label.py b/todo/tests/fixtures/label.py index 3017688..05de140 100644 --- a/todo/tests/fixtures/label.py +++ b/todo/tests/fixtures/label.py @@ -18,3 +18,5 @@ "createdBy": "qMbT6M2GB65W7UHgJS4g", }, ] + +label_models = [LabelModel(**data) for data in label_db_data] diff --git a/todo/tests/unit/repositories/test_label_repository.py b/todo/tests/unit/repositories/test_label_repository.py new file mode 100644 index 0000000..a6ea123 --- /dev/null +++ b/todo/tests/unit/repositories/test_label_repository.py @@ -0,0 +1,42 @@ +from unittest import TestCase +from unittest.mock import patch, MagicMock +from pymongo.collection import Collection +from todo.models.label import LabelModel +from todo.repositories.label_repository import LabelRepository +from todo.tests.fixtures.label import label_db_data + + +class LabelRepositoryTests(TestCase): + def setUp(self): + self.label_ids = [label_data["_id"] for label_data in label_db_data] + self.label_data = label_db_data + + self.patcher_get_collection = patch("todo.repositories.label_repository.LabelRepository.get_collection") + self.mock_get_collection = self.patcher_get_collection.start() + self.mock_collection = MagicMock(spec=Collection) + self.mock_get_collection.return_value = self.mock_collection + + def tearDown(self): + self.patcher_get_collection.stop() + + def test_list_by_ids_returns_label_models(self): + self.mock_collection.find.return_value = self.label_data + + result = LabelRepository.list_by_ids(self.label_ids) + + self.assertEqual(len(result), len(self.label_data)) + self.assertTrue(all(isinstance(label, LabelModel) for label in result)) + + def test_list_by_ids_returns_empty_list_if_not_found(self): + self.mock_collection.find.return_value = [] + + result = LabelRepository.list_by_ids([self.label_ids[0]]) + + self.assertEqual(result, []) + + def test_list_by_ids_skips_db_call_for_empty_input(self): + result = LabelRepository.list_by_ids([]) + + self.assertEqual(result, []) + self.mock_get_collection.assert_not_called() + self.mock_collection.assert_not_called()