From b27536d531369fafdc4ccf524866828b13610443 Mon Sep 17 00:00:00 2001 From: Tushar Goel Date: Mon, 6 Jan 2025 17:07:06 +0530 Subject: [PATCH] Add tests Signed-off-by: Tushar Goel --- vulnerabilities/models.py | 97 +++++++++++++++++++++++++++- vulnerabilities/tests/test_models.py | 62 ++++++++++++++++-- vulnerabilities/tests/test_view.py | 54 ++++++++++++++++ vulnerabilities/utils.py | 10 +++ vulnerabilities/views.py | 69 +------------------- 5 files changed, 217 insertions(+), 75 deletions(-) diff --git a/vulnerabilities/models.py b/vulnerabilities/models.py index 6248e1e47..866eb430c 100644 --- a/vulnerabilities/models.py +++ b/vulnerabilities/models.py @@ -10,12 +10,15 @@ import hashlib import json import logging -import typing from contextlib import suppress from functools import cached_property -from typing import Optional +from itertools import groupby +from operator import attrgetter from typing import Union +from cvss.exceptions import CVSS2MalformedError +from cvss.exceptions import CVSS3MalformedError +from cvss.exceptions import CVSS4MalformedError from cwe2.database import Database from django.contrib.auth import get_user_model from django.contrib.auth.models import UserManager @@ -45,6 +48,7 @@ from aboutcode import hashid from vulnerabilities import utils +from vulnerabilities.severity_systems import EPSS from vulnerabilities.severity_systems import SCORING_SYSTEMS from vulnerabilities.utils import normalize_purl from vulnerabilities.utils import purl_to_dict @@ -371,6 +375,95 @@ def get_related_purls(self): """ return [p.package_url for p in self.packages.distinct().all()] + def aggregate_fixed_and_affected_packages(self): + from vulnerabilities.views import get_purl_version_class + + sorted_fixed_by_packages = self.fixed_by_packages.filter(is_ghost=False).order_by( + "type", "namespace", "name", "qualifiers", "subpath" + ) + + sorted_affected_packages = self.affected_packages.all() + + grouped_fixed_by_packages = { + key: list(group) + for key, group in groupby( + sorted_fixed_by_packages, + key=attrgetter("type", "namespace", "name", "qualifiers", "subpath"), + ) + } + + all_affected_fixed_by_matches = [] + + for sorted_affected_package in sorted_affected_packages: + affected_fixed_by_matches = { + "affected_package": sorted_affected_package, + "matched_fixed_by_packages": [], + } + + # Build the key to find matching group + key = ( + sorted_affected_package.type, + sorted_affected_package.namespace, + sorted_affected_package.name, + sorted_affected_package.qualifiers, + sorted_affected_package.subpath, + ) + + # Get matching group from pre-grouped fixed_by_packages + matching_fixed_packages = grouped_fixed_by_packages.get(key, []) + + # Get version classes for comparison + affected_version_class = get_purl_version_class(sorted_affected_package) + affected_version = affected_version_class(sorted_affected_package.version) + + # Compare versions and filter valid matches + matched_fixed_by_packages = [ + fixed_by_package.purl + for fixed_by_package in matching_fixed_packages + if get_purl_version_class(fixed_by_package)(fixed_by_package.version) + > affected_version + ] + + affected_fixed_by_matches["matched_fixed_by_packages"] = matched_fixed_by_packages + all_affected_fixed_by_matches.append(affected_fixed_by_matches) + return sorted_fixed_by_packages, sorted_affected_packages, all_affected_fixed_by_matches + + def get_severity_vectors_and_values(self): + """ + Collect severity vectors and values, excluding EPSS scoring systems and handling errors gracefully. + """ + severity_vectors = [] + severity_values = set() + + # Exclude EPSS scoring system + base_severities = self.severities.exclude(scoring_system=EPSS.identifier) + + # QuerySet for severities with valid scoring_elements and scoring_system in SCORING_SYSTEMS + valid_scoring_severities = base_severities.filter( + scoring_elements__isnull=False, scoring_system__in=SCORING_SYSTEMS.keys() + ) + + for severity in valid_scoring_severities: + try: + vector_values = SCORING_SYSTEMS[severity.scoring_system].get( + severity.scoring_elements + ) + if vector_values: + severity_vectors.append(vector_values) + except ( + CVSS2MalformedError, + CVSS3MalformedError, + CVSS4MalformedError, + NotImplementedError, + ) as e: + logging.error(f"CVSSMalformedError for {severity.scoring_elements}: {e}") + + valid_value_severities = base_severities.filter(value__isnull=False).exclude(value="") + + severity_values.update(valid_value_severities.values_list("value", flat=True)) + + return severity_vectors, severity_values + class Weakness(models.Model): """ diff --git a/vulnerabilities/tests/test_models.py b/vulnerabilities/tests/test_models.py index 014754786..5825fbdae 100644 --- a/vulnerabilities/tests/test_models.py +++ b/vulnerabilities/tests/test_models.py @@ -9,14 +9,10 @@ import urllib.parse from datetime import datetime -from unittest import TestCase +from django.test import TestCase from unittest import mock import pytest -from django.db import transaction -from django.db.models.query import QuerySet -from django.db.utils import IntegrityError -from freezegun import freeze_time from packageurl import PackageURL from univers import versions from univers.version_range import RANGE_CLASS_BY_SCHEMES @@ -26,7 +22,6 @@ from vulnerabilities.models import Alias from vulnerabilities.models import Package from vulnerabilities.models import Vulnerability -from vulnerabilities.models import VulnerabilityQuerySet class TestVulnerabilityModel(TestCase): @@ -604,3 +599,58 @@ def test_get_fixed_by_package_versions(self): assert all_package_versions[1] == self.package_pypi_redis_4_3_6 assert all_package_versions[2] == self.package_pypi_redis_5_0_0b1 assert all_package_versions.count() == 3 + + +class TestVulnerabilityModel(TestCase): + def setUp(self): + self.vuln1 = models.Vulnerability.objects.create( + vulnerability_id="VCID-1", summary="Vuln 1" + ) + self.vuln2 = models.Vulnerability.objects.create( + vulnerability_id="VCID-2", summary="Vuln 2" + ) + self.vuln3 = models.Vulnerability.objects.create( + vulnerability_id="VCID-3", summary="Vuln 3" + ) + self.vuln4 = models.Vulnerability.objects.create( + vulnerability_id="VCID-4", summary="Vuln 4" + ) + self.vuln5 = models.Vulnerability.objects.create( + vulnerability_id="VCID-5", summary="Vuln 5" + ) + + self.package1 = models.Package.objects.create( + type="pypi", name="django", version="1.0.0" + ) + self.package2 = models.Package.objects.create( + type="pypi", name="django", version="2.0.0" + ) + self.package3 = models.Package.objects.create( + type="pypi", name="django", version="3.0.0" + ) + + models.AffectedByPackageRelatedVulnerability.objects.create( + package=self.package1, vulnerability=self.vuln1 + ) + + models.AffectedByPackageRelatedVulnerability.objects.create( + package=self.package1, vulnerability=self.vuln2 + ) + + models.AffectedByPackageRelatedVulnerability.objects.create( + package=self.package2, vulnerability=self.vuln3 + ) + + models.AffectedByPackageRelatedVulnerability.objects.create( + package=self.package2, vulnerability=self.vuln4 + ) + + # Associate fixed_by package with vuln5 + + models.FixingPackageRelatedVulnerability.objects.create( + package=self.package3, vulnerability=self.vuln5 + ) + + def test_aggregate_fixed_and_affected_packages(self): + with self.assertNumQueries(2): + self.vuln1.aggregate_fixed_and_affected_packages() \ No newline at end of file diff --git a/vulnerabilities/tests/test_view.py b/vulnerabilities/tests/test_view.py index 692305f8d..9749d75f4 100644 --- a/vulnerabilities/tests/test_view.py +++ b/vulnerabilities/tests/test_view.py @@ -15,6 +15,7 @@ from packageurl import PackageURL from univers import versions +from vulnerabilities import models from vulnerabilities.models import Alias from vulnerabilities.models import Package from vulnerabilities.models import Vulnerability @@ -273,3 +274,56 @@ class TestCustomFilters: def test_url_quote_filter(self, input_value, expected_output): filtered = url_quote_filter(input_value) assert filtered == expected_output + + +class VulnerabilitySearchTestCaseWithPackages(TestCase): + def setUp(self): + self.vuln1 = models.Vulnerability.objects.create( + vulnerability_id="VCID-1", summary="Vuln 1" + ) + self.vuln2 = models.Vulnerability.objects.create( + vulnerability_id="VCID-2", summary="Vuln 2" + ) + self.vuln3 = models.Vulnerability.objects.create( + vulnerability_id="VCID-3", summary="Vuln 3" + ) + self.vuln4 = models.Vulnerability.objects.create( + vulnerability_id="VCID-4", summary="Vuln 4" + ) + self.vuln5 = models.Vulnerability.objects.create( + vulnerability_id="VCID-5", summary="Vuln 5" + ) + + self.package1 = models.Package.objects.create(type="pypi", name="django", version="1.0.0") + self.package2 = models.Package.objects.create(type="pypi", name="django", version="2.0.0") + self.package3 = models.Package.objects.create(type="pypi", name="django", version="3.0.0") + + models.AffectedByPackageRelatedVulnerability.objects.create( + package=self.package1, vulnerability=self.vuln1 + ) + + models.AffectedByPackageRelatedVulnerability.objects.create( + package=self.package1, vulnerability=self.vuln2 + ) + + models.AffectedByPackageRelatedVulnerability.objects.create( + package=self.package2, vulnerability=self.vuln3 + ) + + models.AffectedByPackageRelatedVulnerability.objects.create( + package=self.package2, vulnerability=self.vuln4 + ) + + # Associate fixed_by package with vuln5 + + models.FixingPackageRelatedVulnerability.objects.create( + package=self.package3, vulnerability=self.vuln5 + ) + + def test_aggregate_fixed_and_affected_packages(self): + with self.assertNumQueries(11): + response = self.client.get(f"/vulnerabilities/{self.vuln1.vulnerability_id}") + self.assertEqual(response.status_code, 200) + + with self.assertNumQueries(2): + self.vuln1.aggregate_fixed_and_affected_packages() diff --git a/vulnerabilities/utils.py b/vulnerabilities/utils.py index 969a08f2f..32cfcbc02 100644 --- a/vulnerabilities/utils.py +++ b/vulnerabilities/utils.py @@ -32,6 +32,7 @@ from packageurl import PackageURL from packageurl.contrib.django.utils import without_empty_values from univers.version_range import RANGE_CLASS_BY_SCHEMES +from univers.version_range import AlpineLinuxVersionRange from univers.version_range import NginxVersionRange from univers.version_range import VersionRange @@ -536,3 +537,12 @@ def normalize_purl(purl: Union[PackageURL, str]): if isinstance(purl, PackageURL): purl = str(purl) return PackageURL.from_string(purl) + + +def get_purl_version_class(purl): + RANGE_CLASS_BY_SCHEMES["alpine"] = AlpineLinuxVersionRange + purl_version_class = None + check_version_class = RANGE_CLASS_BY_SCHEMES.get(purl.type, None) + if check_version_class: + purl_version_class = check_version_class.version_class + return purl_version_class diff --git a/vulnerabilities/views.py b/vulnerabilities/views.py index 3f791dcd6..c977d0f9d 100644 --- a/vulnerabilities/views.py +++ b/vulnerabilities/views.py @@ -7,8 +7,6 @@ # See https://aboutcode.org for more information about nexB OSS projects. # import logging -from itertools import groupby -from operator import attrgetter from cvss.exceptions import CVSS2MalformedError from cvss.exceptions import CVSS3MalformedError @@ -24,17 +22,14 @@ from django.views import generic from django.views.generic.detail import DetailView from django.views.generic.list import ListView -from univers.version_range import RANGE_CLASS_BY_SCHEMES -from univers.version_range import AlpineLinuxVersionRange from vulnerabilities import models from vulnerabilities.forms import ApiUserCreationForm from vulnerabilities.forms import PackageSearchForm from vulnerabilities.forms import VulnerabilitySearchForm -from vulnerabilities.models import VulnerabilityStatusType from vulnerabilities.severity_systems import EPSS from vulnerabilities.severity_systems import SCORING_SYSTEMS -from vulnerabilities.utils import get_severity_range +from vulnerabilities.utils import get_purl_version_class from vulnerablecode.settings import env PAGE_SIZE = 20 @@ -54,15 +49,6 @@ def purl_sort_key(purl: models.Package): return (purl.type, purl.namespace, purl.name, purl_sort_version, purl.qualifiers, purl.subpath) -def get_purl_version_class(purl: models.Package): - RANGE_CLASS_BY_SCHEMES["alpine"] = AlpineLinuxVersionRange - purl_version_class = None - check_version_class = RANGE_CLASS_BY_SCHEMES.get(purl.type, None) - if check_version_class: - purl_version_class = check_version_class.version_class - return purl_version_class - - class PackageSearch(ListView): model = models.Package template_name = "packages.html" @@ -183,7 +169,7 @@ def get_context_data(self, **kwargs): sorted_fixed_by_packages, sorted_affected_packages, all_affected_fixed_by_matches, - ) = self.aggregate_fixed_and_affected_packages() + ) = self.object.aggregate_fixed_and_affected_packages() context.update( { @@ -204,57 +190,6 @@ def get_context_data(self, **kwargs): ) return context - def aggregate_fixed_and_affected_packages(self): - sorted_fixed_by_packages = self.object.fixed_by_packages.filter(is_ghost=False).order_by( - "type", "namespace", "name", "qualifiers", "subpath" - ) - - sorted_affected_packages = self.object.affected_packages.all() - - grouped_fixed_by_packages = { - key: list(group) - for key, group in groupby( - sorted_fixed_by_packages, - key=attrgetter("type", "namespace", "name", "qualifiers", "subpath"), - ) - } - - all_affected_fixed_by_matches = [] - - for sorted_affected_package in sorted_affected_packages: - affected_fixed_by_matches = { - "affected_package": sorted_affected_package, - "matched_fixed_by_packages": [], - } - - # Build the key to find matching group - key = ( - sorted_affected_package.type, - sorted_affected_package.namespace, - sorted_affected_package.name, - sorted_affected_package.qualifiers, - sorted_affected_package.subpath, - ) - - # Get matching group from pre-grouped fixed_by_packages - matching_fixed_packages = grouped_fixed_by_packages.get(key, []) - - # Get version classes for comparison - affected_version_class = get_purl_version_class(sorted_affected_package) - affected_version = affected_version_class(sorted_affected_package.version) - - # Compare versions and filter valid matches - matched_fixed_by_packages = [ - fixed_by_package.purl - for fixed_by_package in matching_fixed_packages - if get_purl_version_class(fixed_by_package)(fixed_by_package.version) - > affected_version - ] - - affected_fixed_by_matches["matched_fixed_by_packages"] = matched_fixed_by_packages - all_affected_fixed_by_matches.append(affected_fixed_by_matches) - return sorted_fixed_by_packages, sorted_affected_packages, all_affected_fixed_by_matches - def get_severity_vectors_and_values(self): """ Collect severity vectors and values, excluding EPSS scoring systems and handling errors gracefully.