From c9dc7970d45746f44aa9b4a4f053ed7147cf04f6 Mon Sep 17 00:00:00 2001 From: Samarpan Harit Date: Wed, 18 Dec 2024 01:25:26 +0530 Subject: [PATCH] Add tests for task repository --- todo/tests/fixtures/task.py | 2 + .../unit/repositories/test_task_repository.py | 51 +++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 todo/tests/unit/repositories/test_task_repository.py diff --git a/todo/tests/fixtures/task.py b/todo/tests/fixtures/task.py index ff87579..81a4a17 100644 --- a/todo/tests/fixtures/task.py +++ b/todo/tests/fixtures/task.py @@ -35,3 +35,5 @@ "updatedBy": "qMbT6M2GB65W7UHgJS4g", }, ] + +tasks_models = [TaskModel(**data) for data in tasks_db_data] diff --git a/todo/tests/unit/repositories/test_task_repository.py b/todo/tests/unit/repositories/test_task_repository.py new file mode 100644 index 0000000..e8c5f28 --- /dev/null +++ b/todo/tests/unit/repositories/test_task_repository.py @@ -0,0 +1,51 @@ +from unittest import TestCase +from unittest.mock import patch, MagicMock +from pymongo.collection import Collection +from todo.models.task import TaskModel +from todo.repositories.task_repository import TaskRepository +from todo.tests.fixtures.task import tasks_db_data + + +class TaskRepositoryTests(TestCase): + def setUp(self): + self.task_data = tasks_db_data + + self.patcher_get_collection = patch("todo.repositories.task_repository.TaskRepository.get_collection") + self.mock_get_collection = self.patcher_get_collection.start() + self.mock_collection = MagicMock(spec=Collection) + self.mock_get_collection.return_value = self.mock_collection + + def tearDown(self): + self.patcher_get_collection.stop() + + def test_list_applies_pagination_correctly(self): + self.mock_collection.find.return_value.skip.return_value.limit.return_value = self.task_data + + page = 1 + limit = 10 + result = TaskRepository.list(page, limit) + + self.assertEqual(len(result), len(self.task_data)) + self.assertTrue(all(isinstance(task, TaskModel) for task in result)) + + self.mock_collection.find.assert_called_once() + self.mock_collection.find.return_value.skip.assert_called_once_with(0) + self.mock_collection.find.return_value.skip.return_value.limit.assert_called_once_with(limit) + + def test_list_returns_empty_list_for_no_tasks(self): + self.mock_collection.find.return_value.skip.return_value.limit.return_value = [] + + result = TaskRepository.list(2, 10) + + self.assertEqual(result, []) + self.mock_collection.find.assert_called_once() + self.mock_collection.find.return_value.skip.assert_called_once_with(10) + self.mock_collection.find.return_value.skip.return_value.limit.assert_called_once_with(10) + + def test_count_returns_total_task_count(self): + self.mock_collection.count_documents.return_value = 42 + + result = TaskRepository.count() + + self.assertEqual(result, 42) + self.mock_collection.count_documents.assert_called_once_with({})