diff --git a/vulnerabilities/api_v2.py b/vulnerabilities/api_v2.py index 2ab782d59..b570570ed 100644 --- a/vulnerabilities/api_v2.py +++ b/vulnerabilities/api_v2.py @@ -8,6 +8,7 @@ # +from django.db.models import Prefetch from django_filters import rest_framework as filters from drf_spectacular.utils import OpenApiParameter from drf_spectacular.utils import extend_schema @@ -20,8 +21,6 @@ from rest_framework.response import Response from rest_framework.reverse import reverse -from vulnerabilities.api import PackageFilterSet -from vulnerabilities.api import VulnerabilitySeveritySerializer from vulnerabilities.models import Package from vulnerabilities.models import Vulnerability from vulnerabilities.models import VulnerabilityReference @@ -195,7 +194,20 @@ class Meta: ] def get_affected_by_vulnerabilities(self, obj): - return [vuln.vulnerability_id for vuln in obj.affected_by_vulnerabilities.all()] + """ + Return a dictionary with vulnerabilities as keys and their details, including fixed_by_packages. + """ + result = {} + for vuln in getattr(obj, "prefetched_affected_vulnerabilities", []): + fixed_by_package = vuln.fixed_by_packages.first() + purl = None + if fixed_by_package: + purl = fixed_by_package.package_url + result[vuln.vulnerability_id] = { + "vulnerability_id": vuln.vulnerability_id, + "fixed_by_packages": purl, + } + return result def get_fixing_vulnerabilities(self, obj): # Ghost package should not fix any vulnerability. @@ -233,7 +245,13 @@ class PackageV2FilterSet(filters.FilterSet): class PackageV2ViewSet(viewsets.ReadOnlyModelViewSet): - queryset = Package.objects.all() + queryset = Package.objects.all().prefetch_related( + Prefetch( + "affected_by_vulnerabilities", + queryset=Vulnerability.objects.prefetch_related("fixed_by_packages"), + to_attr="prefetched_affected_vulnerabilities", + ) + ) serializer_class = PackageV2Serializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = PackageV2FilterSet diff --git a/vulnerabilities/tests/test_api_v2.py b/vulnerabilities/tests/test_api_v2.py index fa3b7773c..af4dc47c8 100644 --- a/vulnerabilities/tests/test_api_v2.py +++ b/vulnerabilities/tests/test_api_v2.py @@ -7,6 +7,7 @@ # See https://aboutcode.org for more information about nexB OSS projects. # +from django.db.models import Prefetch from django.urls import reverse from packageurl import PackageURL from rest_framework import status @@ -67,6 +68,8 @@ def test_list_vulnerabilities(self): """ url = reverse("vulnerability-v2-list") response = self.client.get(url, format="json") + with self.assertNumQueries(5): + response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("results", response.data) self.assertIn("vulnerabilities", response.data["results"]) @@ -80,7 +83,8 @@ def test_retrieve_vulnerability_detail(self): Test retrieving vulnerability details by vulnerability_id. """ url = reverse("vulnerability-v2-detail", kwargs={"vulnerability_id": "VCID-1234"}) - response = self.client.get(url, format="json") + with self.assertNumQueries(8): + response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["vulnerability_id"], "VCID-1234") self.assertEqual(response.data["summary"], "Test vulnerability 1") @@ -93,7 +97,8 @@ def test_filter_vulnerability_by_vulnerability_id(self): Test filtering vulnerabilities by vulnerability_id. """ url = reverse("vulnerability-v2-list") - response = self.client.get(url, {"vulnerability_id": "VCID-1234"}, format="json") + with self.assertNumQueries(4): + response = self.client.get(url, {"vulnerability_id": "VCID-1234"}, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["vulnerability_id"], "VCID-1234") @@ -102,7 +107,8 @@ def test_filter_vulnerability_by_alias(self): Test filtering vulnerabilities by alias. """ url = reverse("vulnerability-v2-list") - response = self.client.get(url, {"alias": "CVE-2021-5678"}, format="json") + with self.assertNumQueries(5): + response = self.client.get(url, {"alias": "CVE-2021-5678"}, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("results", response.data) self.assertIn("vulnerabilities", response.data["results"]) @@ -116,9 +122,10 @@ def test_filter_vulnerabilities_multiple_ids(self): Test filtering vulnerabilities by multiple vulnerability_ids. """ url = reverse("vulnerability-v2-list") - response = self.client.get( - url, {"vulnerability_id": ["VCID-1234", "VCID-5678"]}, format="json" - ) + with self.assertNumQueries(5): + response = self.client.get( + url, {"vulnerability_id": ["VCID-1234", "VCID-5678"]}, format="json" + ) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]["vulnerabilities"]), 2) @@ -127,9 +134,10 @@ def test_filter_vulnerabilities_multiple_aliases(self): Test filtering vulnerabilities by multiple aliases. """ url = reverse("vulnerability-v2-list") - response = self.client.get( - url, {"alias": ["CVE-2021-1234", "CVE-2021-5678"]}, format="json" - ) + with self.assertNumQueries(5): + response = self.client.get( + url, {"alias": ["CVE-2021-1234", "CVE-2021-5678"]}, format="json" + ) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]["vulnerabilities"]), 2) @@ -139,7 +147,8 @@ def test_invalid_vulnerability_id(self): Should return 404 Not Found. """ url = reverse("vulnerability-v2-detail", kwargs={"vulnerability_id": "VCID-9999"}) - response = self.client.get(url, format="json") + with self.assertNumQueries(5): + response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_get_url_in_serializer(self): @@ -207,7 +216,8 @@ def test_list_packages(self): Should return a list of packages with their details and associated vulnerabilities. """ url = reverse("package-v2-list") - response = self.client.get(url, format="json") + with self.assertNumQueries(31): + response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("results", response.data) self.assertIn("packages", response.data["results"]) @@ -228,7 +238,8 @@ def test_filter_packages_by_purl(self): Test filtering packages by one or more PURLs. """ url = reverse("package-v2-list") - response = self.client.get(url, {"purl": "pkg:pypi/django@3.2"}, format="json") + with self.assertNumQueries(19): + response = self.client.get(url, {"purl": "pkg:pypi/django@3.2"}, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]["packages"]), 1) self.assertEqual(response.data["results"]["packages"][0]["purl"], "pkg:pypi/django@3.2") @@ -238,7 +249,10 @@ def test_filter_packages_by_affected_vulnerability(self): Test filtering packages by affected_by_vulnerability. """ url = reverse("package-v2-list") - response = self.client.get(url, {"affected_by_vulnerability": "VCID-1234"}, format="json") + with self.assertNumQueries(19): + response = self.client.get( + url, {"affected_by_vulnerability": "VCID-1234"}, format="json" + ) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]["packages"]), 1) self.assertEqual(response.data["results"]["packages"][0]["purl"], "pkg:pypi/django@3.2") @@ -248,26 +262,59 @@ def test_filter_packages_by_fixing_vulnerability(self): Test filtering packages by fixing_vulnerability. """ url = reverse("package-v2-list") - response = self.client.get(url, {"fixing_vulnerability": "VCID-5678"}, format="json") + with self.assertNumQueries(18): + response = self.client.get(url, {"fixing_vulnerability": "VCID-5678"}, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]["packages"]), 1) self.assertEqual(response.data["results"]["packages"][0]["purl"], "pkg:npm/lodash@4.17.20") def test_package_serializer_fields(self): """ - Test that the PackageV2Serializer returns the correct fields. + Test that the PackageV2Serializer returns the correct fields and formats them correctly. """ + # Fetch the package package = Package.objects.get(package_url="pkg:pypi/django@3.2") + + # Ensure prefetched data is available for the serializer + package = ( + Package.objects.filter(package_url="pkg:pypi/django@3.2") + .prefetch_related( + Prefetch( + "affected_by_vulnerabilities", + queryset=Vulnerability.objects.prefetch_related("fixed_by_packages"), + to_attr="prefetched_affected_vulnerabilities", + ) + ) + .first() + ) + + # Serialize the package serializer = PackageV2Serializer(package) data = serializer.data + + # Verify the presence of required fields self.assertIn("purl", data) self.assertIn("affected_by_vulnerabilities", data) self.assertIn("fixing_vulnerabilities", data) self.assertIn("next_non_vulnerable_version", data) self.assertIn("latest_non_vulnerable_version", data) + self.assertIn("risk_score", data) + + # Verify field values self.assertEqual(data["purl"], "pkg:pypi/django@3.2") - self.assertEqual(data["affected_by_vulnerabilities"], ["VCID-1234"]) - self.assertEqual(data["fixing_vulnerabilities"], []) + self.assertEqual(data["next_non_vulnerable_version"], None) + self.assertEqual(data["latest_non_vulnerable_version"], None) + self.assertEqual(data["risk_score"], None) + + # Verify affected_by_vulnerabilities structure + expected_affected_by_vulnerabilities = { + "VCID-1234": {"vulnerability_id": "VCID-1234", "fixed_by_packages": None} + } + self.assertEqual(data["affected_by_vulnerabilities"], expected_affected_by_vulnerabilities) + + # Verify fixing_vulnerabilities structure + expected_fixing_vulnerabilities = [] + self.assertEqual(data["fixing_vulnerabilities"], expected_fixing_vulnerabilities) def test_list_packages_pagination(self): """ @@ -300,7 +347,10 @@ def test_invalid_vulnerability_filter(self): Should return an empty list. """ url = reverse("package-v2-list") - response = self.client.get(url, {"affected_by_vulnerability": "VCID-9999"}, format="json") + with self.assertNumQueries(4): + response = self.client.get( + url, {"affected_by_vulnerability": "VCID-9999"}, format="json" + ) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]["packages"]), 0) @@ -310,7 +360,10 @@ def test_invalid_purl_filter(self): Should return an empty list. """ url = reverse("package-v2-list") - response = self.client.get(url, {"purl": "pkg:nonexistent/package@1.0.0"}, format="json") + with self.assertNumQueries(4): + response = self.client.get( + url, {"purl": "pkg:nonexistent/package@1.0.0"}, format="json" + ) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]["packages"]), 0) @@ -318,10 +371,24 @@ def test_get_affected_by_vulnerabilities(self): """ Test the get_affected_by_vulnerabilities method in the serializer. """ - package = Package.objects.get(package_url="pkg:pypi/django@3.2") + package = ( + Package.objects.filter(package_url="pkg:pypi/django@3.2") + .prefetch_related( + Prefetch( + "affected_by_vulnerabilities", + queryset=Vulnerability.objects.prefetch_related("fixed_by_packages"), + to_attr="prefetched_affected_vulnerabilities", + ) + ) + .first() + ) + serializer = PackageV2Serializer() vulnerabilities = serializer.get_affected_by_vulnerabilities(package) - self.assertEqual(vulnerabilities, ["VCID-1234"]) + self.assertEqual( + vulnerabilities, + {"VCID-1234": {"vulnerability_id": "VCID-1234", "fixed_by_packages": None}}, + ) def test_get_fixing_vulnerabilities(self): """ @@ -339,7 +406,8 @@ def test_bulk_lookup_with_valid_purls(self): """ url = reverse("package-v2-bulk-lookup") data = {"purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"]} - response = self.client.post(url, data, format="json") + with self.assertNumQueries(28): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("packages", response.data) self.assertIn("vulnerabilities", response.data) @@ -363,7 +431,8 @@ def test_bulk_lookup_with_invalid_purls(self): """ url = reverse("package-v2-bulk-lookup") data = {"purls": ["pkg:pypi/nonexistent@1.0.0", "pkg:npm/unknown@0.0.1"]} - response = self.client.post(url, data, format="json") + with self.assertNumQueries(4): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Since the packages don't exist, the response should be empty self.assertEqual(len(response.data["packages"]), 0) @@ -376,7 +445,8 @@ def test_bulk_lookup_with_empty_purls(self): """ url = reverse("package-v2-bulk-lookup") data = {"purls": []} - response = self.client.post(url, data, format="json") + with self.assertNumQueries(3): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("error", response.data) self.assertIn("message", response.data) @@ -389,7 +459,8 @@ def test_bulk_search_with_valid_purls(self): """ url = reverse("package-v2-bulk-search") data = {"purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"]} - response = self.client.post(url, data, format="json") + with self.assertNumQueries(28): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("packages", response.data) self.assertIn("vulnerabilities", response.data) @@ -416,7 +487,8 @@ def test_bulk_search_with_purl_only_true(self): "purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"], "purl_only": True, } - response = self.client.post(url, data, format="json") + with self.assertNumQueries(17): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Since purl_only=True, response should be a list of PURLs self.assertIsInstance(response.data, list) @@ -442,7 +514,8 @@ def test_bulk_search_with_plain_purl_true(self): "purls": ["pkg:pypi/django@3.2", "pkg:pypi/django@3.2?extension=tar.gz"], "plain_purl": True, } - response = self.client.post(url, data, format="json") + with self.assertNumQueries(16): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("packages", response.data) self.assertIn("vulnerabilities", response.data) @@ -462,7 +535,8 @@ def test_bulk_search_with_purl_only_and_plain_purl_true(self): "purl_only": True, "plain_purl": True, } - response = self.client.post(url, data, format="json") + with self.assertNumQueries(11): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Response should be a list of plain PURLs self.assertIsInstance(response.data, list) @@ -477,7 +551,8 @@ def test_bulk_search_with_invalid_purls(self): """ url = reverse("package-v2-bulk-search") data = {"purls": ["pkg:pypi/nonexistent@1.0.0", "pkg:npm/unknown@0.0.1"]} - response = self.client.post(url, data, format="json") + with self.assertNumQueries(4): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Since the packages don't exist, the response should be empty self.assertEqual(len(response.data["packages"]), 0) @@ -490,7 +565,8 @@ def test_bulk_search_with_empty_purls(self): """ url = reverse("package-v2-bulk-search") data = {"purls": []} - response = self.client.post(url, data, format="json") + with self.assertNumQueries(3): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("error", response.data) self.assertIn("message", response.data) @@ -501,7 +577,8 @@ def test_all_vulnerable_packages(self): Test the 'all' endpoint that returns all vulnerable package URLs. """ url = reverse("package-v2-all") - response = self.client.get(url, format="json") + with self.assertNumQueries(4): + response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Since package1 is vulnerable, it should be returned expected_purls = ["pkg:pypi/django@3.2"] @@ -514,7 +591,8 @@ def test_lookup_with_valid_purl(self): """ url = reverse("package-v2-lookup") data = {"purl": "pkg:pypi/django@3.2"} - response = self.client.post(url, data, format="json") + with self.assertNumQueries(12): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(1, len(response.data)) self.assertIn("purl", response.data[0]) @@ -523,7 +601,10 @@ def test_lookup_with_valid_purl(self): self.assertIn("next_non_vulnerable_version", response.data[0]) self.assertIn("latest_non_vulnerable_version", response.data[0]) self.assertEqual(response.data[0]["purl"], "pkg:pypi/django@3.2") - self.assertEqual(response.data[0]["affected_by_vulnerabilities"], ["VCID-1234"]) + self.assertEqual( + response.data[0]["affected_by_vulnerabilities"], + {"VCID-1234": {"vulnerability_id": "VCID-1234", "fixed_by_packages": None}}, + ) self.assertEqual(response.data[0]["fixing_vulnerabilities"], []) def test_lookup_with_invalid_purl(self): @@ -533,7 +614,8 @@ def test_lookup_with_invalid_purl(self): """ url = reverse("package-v2-lookup") data = {"purl": "pkg:pypi/nonexistent@1.0.0"} - response = self.client.post(url, data, format="json") + with self.assertNumQueries(4): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # No packages or vulnerabilities should be returned self.assertEqual(len(response.data), 0) @@ -545,7 +627,8 @@ def test_lookup_with_missing_purl(self): """ url = reverse("package-v2-lookup") data = {} - response = self.client.post(url, data, format="json") + with self.assertNumQueries(3): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("error", response.data) self.assertIn("message", response.data) @@ -558,7 +641,8 @@ def test_lookup_with_invalid_purl_format(self): """ url = reverse("package-v2-lookup") data = {"purl": "invalid_purl_format"} - response = self.client.post(url, data, format="json") + with self.assertNumQueries(4): + response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # No packages or vulnerabilities should be returned self.assertEqual(len(response.data), 0) diff --git a/vulnerabilities/tests/test_view.py b/vulnerabilities/tests/test_view.py index 692305f8d..fd62e94a1 100644 --- a/vulnerabilities/tests/test_view.py +++ b/vulnerabilities/tests/test_view.py @@ -8,6 +8,7 @@ # import os +import time import pytest from django.test import Client @@ -15,9 +16,13 @@ from packageurl import PackageURL from univers import versions +from vulnerabilities import models +from vulnerabilities.models import AffectedByPackageRelatedVulnerability from vulnerabilities.models import Alias +from vulnerabilities.models import FixingPackageRelatedVulnerability from vulnerabilities.models import Package from vulnerabilities.models import Vulnerability +from vulnerabilities.models import VulnerabilitySeverity from vulnerabilities.templatetags.url_filters import url_quote_filter from vulnerabilities.views import PackageDetails from vulnerabilities.views import PackageSearch @@ -273,3 +278,56 @@ class TestCustomFilters: def test_url_quote_filter(self, input_value, expected_output): filtered = url_quote_filter(input_value) assert filtered == expected_output + + +class VulnerabilitySearchTestCaseWithPackages(TestCase): + def setUp(self): + self.vuln1 = Vulnerability.objects.create(vulnerability_id="VCID-1", summary="Vuln 1") + self.vuln2 = Vulnerability.objects.create(vulnerability_id="VCID-2", summary="Vuln 2") + self.vuln3 = Vulnerability.objects.create(vulnerability_id="VCID-3", summary="Vuln 3") + self.vuln4 = Vulnerability.objects.create(vulnerability_id="VCID-4", summary="Vuln 4") + self.vuln5 = Vulnerability.objects.create(vulnerability_id="VCID-5", summary="Vuln 5") + + self.package1 = Package.objects.create(type="pypi", name="django", version="1.0.0") + self.package2 = Package.objects.create(type="pypi", name="django", version="2.0.0") + self.package3 = Package.objects.create(type="pypi", name="django", version="3.0.0") + + AffectedByPackageRelatedVulnerability.objects.create( + package=self.package1, vulnerability=self.vuln1 + ) + AffectedByPackageRelatedVulnerability.objects.create( + package=self.package1, vulnerability=self.vuln2 + ) + AffectedByPackageRelatedVulnerability.objects.create( + package=self.package2, vulnerability=self.vuln3 + ) + AffectedByPackageRelatedVulnerability.objects.create( + package=self.package2, vulnerability=self.vuln4 + ) + + FixingPackageRelatedVulnerability.objects.create( + package=self.package3, vulnerability=self.vuln5 + ) + + self.severity1 = VulnerabilitySeverity.objects.create( + scoring_system="CVSSv3", + value="9.8", + scoring_elements="AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H", + ) + self.severity2 = VulnerabilitySeverity.objects.create( + scoring_system="CVSSv3", + value="7.5", + scoring_elements="AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H", + ) + + self.vuln1.severities.add(self.severity1) + self.vuln1.severities.add(self.severity2) + self.vuln1.save() + + def test_aggregate_fixed_and_affected_packages(self): + with self.assertNumQueries(11): + start_time = time.time() + response = self.client.get(f"/vulnerabilities/{self.vuln1.vulnerability_id}") + end_time = time.time() + assert end_time - start_time < 0.05 + self.assertEqual(response.status_code, 200)