From bd20a7658d03d372c28cf5b45b5e93d1ee576081 Mon Sep 17 00:00:00 2001 From: JCantu248 Date: Wed, 13 Nov 2024 16:34:27 -0600 Subject: [PATCH 1/9] Add test cases for domains and fixed issues in domains and vulnerabilities. --- .../xfd_django/xfd_api/api_methods/domain.py | 12 +- .../xfd_api/api_methods/vulnerability.py | 9 - .../xfd_api/helpers/filter_helpers.py | 30 ++- .../xfd_django/xfd_api/tests/test_domain.py | 251 ++++++++++++++++++ backend/src/xfd_django/xfd_api/views.py | 17 +- 5 files changed, 278 insertions(+), 41 deletions(-) create mode 100644 backend/src/xfd_django/xfd_api/tests/test_domain.py diff --git a/backend/src/xfd_django/xfd_api/api_methods/domain.py b/backend/src/xfd_django/xfd_api/api_methods/domain.py index 70e85f2c..d83d35d8 100644 --- a/backend/src/xfd_django/xfd_api/api_methods/domain.py +++ b/backend/src/xfd_django/xfd_api/api_methods/domain.py @@ -7,8 +7,10 @@ import csv # Third-Party Libraries +from django.core.exceptions import ObjectDoesNotExist from django.core.paginator import Paginator from django.db.models import Q +from django.http import Http404 from fastapi import HTTPException from ..auth import get_org_memberships, is_global_view_admin @@ -45,14 +47,6 @@ def search_domains(domain_search: DomainSearch, current_user): sort_direction(domain_search.sort, domain_search.order) ) - # Apply global filters based on user permissions - if not is_global_view_admin(current_user): - orgs = get_org_memberships(current_user) - if not orgs: - # No organization memberships, return empty result - return [], 0 - domains = domains.filter(organization__id__in=orgs) - # Add a filter to restrict based on FCEB and CIDR criteria domains = domains.filter(Q(isFceb=True) | Q(isFceb=False, fromCidr=True)) @@ -61,6 +55,8 @@ def search_domains(domain_search: DomainSearch, current_user): paginator = Paginator(domains, domain_search.pageSize) return paginator.get_page(domain_search.page) + except Domain.DoesNotExist as e: + raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py b/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py index 66905577..c29b58c6 100644 --- a/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py +++ b/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py @@ -58,15 +58,6 @@ def search_vulnerabilities(vulnerability_search: VulnerabilitySearch, current_us sort_direction(vulnerability_search.sort, vulnerability_search.order) ) - # Permissions check - if not is_global_view_admin(current_user): - org_ids = get_org_memberships(current_user) - if not org_ids: - return [], 0 # User has no accessible organizations - vulnerabilities = vulnerabilities.filter( - domain__organization_id__in=org_ids - ) - # Apply custom FCEB and CIDR filter vulnerabilities = vulnerabilities.filter( Q(domain__isFceb=True) | Q(domain__isFceb=False, domain__fromCidr=True) diff --git a/backend/src/xfd_django/xfd_api/helpers/filter_helpers.py b/backend/src/xfd_django/xfd_api/helpers/filter_helpers.py index cd5474f4..9e47f64a 100644 --- a/backend/src/xfd_django/xfd_api/helpers/filter_helpers.py +++ b/backend/src/xfd_django/xfd_api/helpers/filter_helpers.py @@ -43,7 +43,9 @@ def filter_domains(domains: QuerySet, domain_filters: DomainFilters): "domain" ) if not services_by_port.exists(): - raise Http404("No Domains found with the provided port") + raise ObjectDoesNotExist( + "Domain could not be found with provided port." + ) domains = domains.filter(id__in=services_by_port) if domain_filters.service: @@ -51,7 +53,7 @@ def filter_domains(domains: QuerySet, domain_filters: DomainFilters): "domain" ) if not service_by_id.exists(): - raise Http404("No Domains found with the provided service") + raise Domain.DoesNotExist("No Domains found with the provided service") domains = domains.filter(id__in=service_by_id) if domain_filters.reverseName: @@ -59,13 +61,15 @@ def filter_domains(domains: QuerySet, domain_filters: DomainFilters): reverseName=domain_filters.reverseName ).values("id") if not domains_by_reverse_name.exists(): - raise Http404("No Domains found with the provided reverse name") + raise Domain.DoesNotExist( + "No Domains found with the provided reverse name" + ) domains = domains.filter(id__in=domains_by_reverse_name) if domain_filters.ip: domains_by_ip = Domain.objects.filter(ip=domain_filters.ip).values("id") if not domains_by_ip.exists(): - raise Http404("No Domains found with the provided ip") + raise Domain.DoesNotExist("Domain could not be found with provided Ip.") domains = domains.filter(id__in=domains_by_ip) if domain_filters.organization: @@ -73,7 +77,9 @@ def filter_domains(domains: QuerySet, domain_filters: DomainFilters): organization_id=domain_filters.organization ).values("id") if not domains_by_org.exists(): - raise Http404("No Domains found with the provided organization") + raise Domain.DoesNotExist( + "No Domains found with the provided organization" + ) domains = domains.filter(id__in=domains_by_org) if domain_filters.organizationName: @@ -81,7 +87,9 @@ def filter_domains(domains: QuerySet, domain_filters: DomainFilters): name=domain_filters.organizationName ).values("id") if not organization_by_name.exists(): - raise Http404("No Domains found with the provided organization name") + raise Domain.DoesNotExist( + "No Domains found with the provided organization name" + ) domains = domains.filter(organization_id__in=organization_by_name) if domain_filters.vulnerabilities: @@ -89,13 +97,13 @@ def filter_domains(domains: QuerySet, domain_filters: DomainFilters): id=domain_filters.vulnerabilities ).values("domain") if not vulnerabilities_by_id.exists(): - raise Http404("No Domains found with the provided vulnerability") + raise Domain.DoesNotExist( + "No Domains found with the provided vulnerability" + ) domains = domains.filter(id__in=vulnerabilities_by_id) return domains - except ObjectDoesNotExist: - print("No vulnerability found with that ID.") - except Exception as e: - print(f"Error: {e}") + except Domain.DoesNotExist as e: + raise e def filter_vulnerabilities( diff --git a/backend/src/xfd_django/xfd_api/tests/test_domain.py b/backend/src/xfd_django/xfd_api/tests/test_domain.py new file mode 100644 index 00000000..da6cf9ca --- /dev/null +++ b/backend/src/xfd_django/xfd_api/tests/test_domain.py @@ -0,0 +1,251 @@ +# Standard Python Libraries +from datetime import datetime +import logging +import secrets + +# Configure logging +logging.basicConfig(level=logging.DEBUG) # Set the logging level to DEBUG +logger = logging.getLogger(__name__) + + +# Third-Party Libraries +from fastapi.testclient import TestClient +import pytest +from xfd_api.auth import create_jwt_token +from xfd_api.models import User, UserType +from xfd_django.asgi import app + +client = TestClient(app) + + +test_id = "960b7db7-f3af-411d-a247-33371739050b" +filters = { + "ports": "80", + "service": "6d9ecf5a-db5d-4b77-9752-a88a5d247631", + "reverseName": "local.crossfeed.quizzical-wing", + "ip": "127.116.195.151", + "organization": "5ef69132-d3ab-43d2-bbe4-a1c79962af9c", + "organizationName": "Wizardly Agency", + "vulnerabilities": "c0effe93-3647-475a-a0c5-0b629c348588", + "tag": "", +} + + +@pytest.mark.django_db(transaction=True) +def test_get_domain_by_id(): + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.get( + f"/domain/{test_id}", + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + data = response.json() + + assert response.status_code == 200 + + +@pytest.mark.django_db(transaction=True) +def test_filter_domain_by_ip(capfd): + # Filter domains by ip + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/domain/search", + json={"page": 1, "filters": {"ip": filters["ip"]}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + assert response.status_code == 200 + data = response.json() + for domain in data: + assert domain["ip"] == filters["ip"] + + +@pytest.mark.django_db(transaction=True) +def test_filter_domain_by_port(): + # Test filter domains by port + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + response = client.post( + "/domain/search", + json={"page": 1, "filters": {"ports": filters["ports"]}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + assert response.status_code == 200 + data = response.json() + assert data is not None + for domain in data: + assert domain["id"] != "" + + +@pytest.mark.django_db(transaction=True) +def test_filter_domain_by_service(): + # Test filter domains by service_id + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/domain/search", + json={"page": 1, "filters": {"service": filters["service"]}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + assert response.status_code == 200 + data = response.json() + assert data is not None + for domain in data: + assert domain["id"] != "" + + +@pytest.mark.django_db(transaction=True) +def test_filter_domain_by_organization(): + # Test filter domains by organization + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + response = client.post( + "/domain/search", + json={ + "page": 1, + "filters": {"organization": filters["organization"]}, + "pageSize": 25, + }, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + assert response.status_code == 200 + data = response.json() + for domain in data: + assert domain["organization_id"] == filters["organization"] + + +@pytest.mark.django_db(transaction=True) +def test_filter_domain_by_organization_name(): + # Test filter domains by organization + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/domain/search", + json={ + "page": 1, + "filters": {"organizationName": filters["organizationName"]}, + "pageSize": 25, + }, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + assert response.status_code == 200 + data = response.json() + assert data is not None + for domain in data: + assert domain["id"] != "" + + +@pytest.mark.django_db(transaction=True) +def test_filter_domain_by_vulnerabilities(): + # Test filter domains by vulnerabilities + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/domain/search", + json={ + "page": 1, + "filters": {"vulnerabilities": filters["vulnerabilities"]}, + "pageSize": 25, + }, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + assert response.status_code == 200 + data = response.json() + assert data is not None + for domain in data: + assert domain["id"] != "" + + +@pytest.mark.django_db(transaction=True) +def test_filter_domains_multiple_criteria(): + # Test filter domains by multiple criteria + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/domain/search", + json={ + "page": 1, + "filters": {"ip": filters["ip"], "ports": filters["ports"]}, + "pageSize": 25, + }, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + assert response.status_code == 200 + data = response.json() + for domain in data: + assert domain["ip"] == filters["ip"] + + +@pytest.mark.django_db(transaction=True) +def test_filter_domains_does_not_exist(): + # Test filter domains if record does not exist + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/domain/search", + json={"page": 1, "filters": {"ip": "Does not exist"}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 404 diff --git a/backend/src/xfd_django/xfd_api/views.py b/backend/src/xfd_django/xfd_api/views.py index 1ca646e6..8bdfe89a 100644 --- a/backend/src/xfd_django/xfd_api/views.py +++ b/backend/src/xfd_django/xfd_api/views.py @@ -33,7 +33,7 @@ ) from .auth import get_current_active_user from .login_gov import callback, login -from .models import Assessment, User +from .models import Assessment, Domain, User, Vulnerability from .schema_models import organization as OrganizationSchema from .schema_models import scan as scanSchema from .schema_models import scan_tasks as scanTaskSchema @@ -221,10 +221,7 @@ async def call_get_cves_by_name(cve_name): async def call_search_domains( domain_search: DomainSearch, current_user: User = Depends(get_current_active_user) ): - try: - return search_domains(domain_search, current_user) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + return search_domains(domain_search, current_user) @api_router.post( @@ -233,10 +230,7 @@ async def call_search_domains( tags=["Domains"], ) async def call_export_domains(domain_search: DomainSearch): - try: - return export_domains(domain_search) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + return export_domains(domain_search) @api_router.get( @@ -264,10 +258,7 @@ async def call_search_vulnerabilities( vulnerability_search: VulnerabilitySearch, current_user: User = Depends(get_current_active_user), ): - try: - return search_vulnerabilities(vulnerability_search, current_user) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + return search_vulnerabilities(vulnerability_search, current_user) @api_router.post("/vulnerabilities/export") From 9e09de41d4c0d38529aad778797376fc8bfd5e77 Mon Sep 17 00:00:00 2001 From: JCantu248 Date: Thu, 14 Nov 2024 12:09:42 -0600 Subject: [PATCH 2/9] WIP. --- backend/src/xfd_django/xfd_api/views.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/src/xfd_django/xfd_api/views.py b/backend/src/xfd_django/xfd_api/views.py index 8bdfe89a..b5c3c9e6 100644 --- a/backend/src/xfd_django/xfd_api/views.py +++ b/backend/src/xfd_django/xfd_api/views.py @@ -271,11 +271,11 @@ async def export_vulnerabilities(): @api_router.get( "/vulnerabilities/{vulnerabilityId}", - # dependencies=[Depends(get_current_active_user)], + dependencies=[Depends(get_current_active_user)], response_model=VulnerabilitySchema, - tags="Get vulnerability by id", + tags=["Get vulnerability by id"], ) -async def call_get_vulnerability_by_id(vuln_id): +async def call_get_vulnerability_by_id(vuln_id: str): """ Get vulnerability by id. Returns: From 68335fef0062582bf687db275421070f94cb2ceb Mon Sep 17 00:00:00 2001 From: JCantu248 Date: Thu, 14 Nov 2024 12:10:58 -0600 Subject: [PATCH 3/9] WIP. Adding vulnerability test file and update related api files prior to merging AL-python-serverless changes. --- .../xfd_django/xfd_api/api_methods/domain.py | 2 ++ .../xfd_api/api_methods/vulnerability.py | 6 ++++ .../xfd_api/helpers/filter_helpers.py | 34 ++++++++++++++----- .../xfd_api/tests/test_vulnerability.py | 18 ++++++++++ 4 files changed, 51 insertions(+), 9 deletions(-) create mode 100644 backend/src/xfd_django/xfd_api/tests/test_vulnerability.py diff --git a/backend/src/xfd_django/xfd_api/api_methods/domain.py b/backend/src/xfd_django/xfd_api/api_methods/domain.py index d83d35d8..40b28fde 100644 --- a/backend/src/xfd_django/xfd_api/api_methods/domain.py +++ b/backend/src/xfd_django/xfd_api/api_methods/domain.py @@ -70,5 +70,7 @@ def export_domains(domain_filters: DomainFilters): # TODO: Integrate methods to generate CSV from queryset and save to S3 bucket return domains + except Domain.DoesNotExist as e: + raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py b/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py index c29b58c6..49dee9c7 100644 --- a/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py +++ b/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py @@ -24,6 +24,8 @@ def get_vulnerability_by_id(vuln_id): try: vulnerability = Vulnerability.objects.get(id=vuln_id) return vulnerability + except Vulnerability.DoesNotExist: + raise HTTPException(status_code=404, detail="Vulnerability not found.") except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -75,6 +77,8 @@ def search_vulnerabilities(vulnerability_search: VulnerabilitySearch, current_us paginator = Paginator(vulnerabilities, vulnerability_search.pageSize) return paginator.get_page(vulnerability_search.page) + except Vulnerability.DoesNotExist as e: + raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -90,5 +94,7 @@ def export_vulnerabilities(vulnerability_filters: VulnerabilityFilters): # TODO: Integrate methods to generate CSV from queryset and save to S3 bucket return vulnerabilities + except Vulnerability.DoesNotExist as e: + raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/src/xfd_django/xfd_api/helpers/filter_helpers.py b/backend/src/xfd_django/xfd_api/helpers/filter_helpers.py index 9e47f64a..39c03652 100644 --- a/backend/src/xfd_django/xfd_api/helpers/filter_helpers.py +++ b/backend/src/xfd_django/xfd_api/helpers/filter_helpers.py @@ -104,6 +104,8 @@ def filter_domains(domains: QuerySet, domain_filters: DomainFilters): return domains except Domain.DoesNotExist as e: raise e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) def filter_vulnerabilities( @@ -123,7 +125,9 @@ def filter_vulnerabilities( id=vulnerability_filters.id ) if not vulnerability_by_id: - raise Http404("No Vulnerabilities found with the provided id") + raise Vulnerability.DoesNotExist( + "No Vulnerabilities found with the provided id" + ) vulnerabilities = vulnerabilities.filter(id=vulnerability_by_id) if vulnerability_filters.title: @@ -131,7 +135,9 @@ def filter_vulnerabilities( title=vulnerability_filters.title ) if not vulnerabilities_by_title.exists(): - raise Http404("No Vulnerabilities found with the provided title") + raise Vulnerability.DoesNotExist( + "No Vulnerabilities found with the provided title" + ) vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_title) if vulnerability_filters.domain: @@ -139,7 +145,9 @@ def filter_vulnerabilities( domain=vulnerability_filters.domain ) if not vulnerabilities_by_domain.exists(): - raise Http404("No Vulnerabilities found with the provided domain") + raise Vulnerability.DoesNotExist( + "No Vulnerabilities found with the provided domain" + ) vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_domain) if vulnerability_filters.severity: @@ -147,7 +155,7 @@ def filter_vulnerabilities( severity=vulnerability_filters.severity ) if not vulnerabilities_by_severity.exists(): - raise Http404( + raise Vulnerability.DoesNotExist( "No Vulnerabilities found with the provided severity level" ) vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_severity) @@ -157,7 +165,9 @@ def filter_vulnerabilities( cpe=vulnerability_filters.cpe ) if not vulnerabilities_by_cpe.exists(): - raise Http404("No Vulnerabilities found with the provided Cpe") + raise Vulnerability.DoesNotExist( + "No Vulnerabilities found with the provided Cpe" + ) vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_cpe) if vulnerability_filters.state: @@ -165,7 +175,9 @@ def filter_vulnerabilities( state=vulnerability_filters.state ) if not vulnerabilities_by_state.exists(): - raise Http404("No Vulnerabilities found with the provided state") + raise Vulnerability.DoesNotExist( + "No Vulnerabilities found with the provided state" + ) vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_state) if vulnerability_filters.organization: @@ -174,7 +186,7 @@ def filter_vulnerabilities( organization_id=vulnerability_filters.organization ) if not domains_by_organization.exists(): - raise Http404( + raise Vulnerability.DoesNotExist( "No Organization-Domain found with the provided organization ID" ) domains = domains.filter(id__in=domains_by_organization) @@ -182,7 +194,7 @@ def filter_vulnerabilities( id__in=domains ) if not vulnerabilities_by_domain.exists(): - raise Http404( + raise Vulnerability.DoesNotExist( "No Vulnerabilities found with the provided organization ID" ) vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_domain) @@ -192,8 +204,12 @@ def filter_vulnerabilities( isKev=vulnerability_filters.isKev ) if not vulnerabilities_by_is_kev.exists(): - raise Http404("No Vulnerabilities found with the provided isKev value") + raise Vulnerability.DoesNotExist( + "No Vulnerabilities found with the provided isKev value" + ) vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_is_kev) return vulnerabilities + except Domain.DoesNotExist as e: + raise e except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/src/xfd_django/xfd_api/tests/test_vulnerability.py b/backend/src/xfd_django/xfd_api/tests/test_vulnerability.py new file mode 100644 index 00000000..d7226d9a --- /dev/null +++ b/backend/src/xfd_django/xfd_api/tests/test_vulnerability.py @@ -0,0 +1,18 @@ +# Standard Python Libraries +from datetime import datetime +import logging +import secrets + +# Configure logging +logging.basicConfig(level=logging.DEBUG) # Set the logging level to DEBUG +logger = logging.getLogger(__name__) + + +# Third-Party Libraries +from fastapi.testclient import TestClient +import pytest +from xfd_api.auth import create_jwt_token +from xfd_api.models import User, UserType +from xfd_django.asgi import app + +client = TestClient(app) From c09850a20bfffe8eea2bb1208bccef5ba75d03ef Mon Sep 17 00:00:00 2001 From: JCantu248 Date: Fri, 15 Nov 2024 10:56:01 -0600 Subject: [PATCH 4/9] Vulnerabilities test cases work, refactor for improved clarity on filter_helpers.filter_vulnerabilities() --- .../xfd_api/helpers/filter_helpers.py | 136 +++---- .../xfd_django/xfd_api/tests/test_domain.py | 40 +- .../xfd_api/tests/test_vulnerabilities.py | 341 ++++++++++++++++++ .../xfd_api/tests/test_vulnerability.py | 18 - backend/src/xfd_django/xfd_api/views.py | 2 +- 5 files changed, 406 insertions(+), 131 deletions(-) create mode 100644 backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py delete mode 100644 backend/src/xfd_django/xfd_api/tests/test_vulnerability.py diff --git a/backend/src/xfd_django/xfd_api/helpers/filter_helpers.py b/backend/src/xfd_django/xfd_api/helpers/filter_helpers.py index 39c03652..02bb4897 100644 --- a/backend/src/xfd_django/xfd_api/helpers/filter_helpers.py +++ b/backend/src/xfd_django/xfd_api/helpers/filter_helpers.py @@ -112,104 +112,60 @@ def filter_vulnerabilities( vulnerabilities: QuerySet, vulnerability_filters: VulnerabilityFilters ): """ - Filter vulnerabilitie + Filter vulnerabilities based on given filters. + Arguments: - vulnerabilities: A list of all vulnerabilities, sorted - vulnerability_filters: Value to filter the vulnberabilities table by + vulnerabilities: A list of all vulnerabilities, sorted. + vulnerability_filters: Value to filter the vulnerabilities table by. + Returns: - object: a list of Vulnerability objects + QuerySet: A filtered list of Vulnerability objects. """ - try: - if vulnerability_filters.id: - vulnerability_by_id = Vulnerability.objects.values("id").get( - id=vulnerability_filters.id - ) - if not vulnerability_by_id: - raise Vulnerability.DoesNotExist( - "No Vulnerabilities found with the provided id" - ) - vulnerabilities = vulnerabilities.filter(id=vulnerability_by_id) + # Initialize a query that includes all vulnerabilities + query = vulnerabilities - if vulnerability_filters.title: - vulnerabilities_by_title = Vulnerability.objects.values("id").filter( - title=vulnerability_filters.title - ) - if not vulnerabilities_by_title.exists(): - raise Vulnerability.DoesNotExist( - "No Vulnerabilities found with the provided title" - ) - vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_title) + # Apply filters based on the provided criteria + if vulnerability_filters.id: + query = query.filter(id=vulnerability_filters.id) - if vulnerability_filters.domain: - vulnerabilities_by_domain = Vulnerability.objects.values("id").filter( - domain=vulnerability_filters.domain - ) - if not vulnerabilities_by_domain.exists(): - raise Vulnerability.DoesNotExist( - "No Vulnerabilities found with the provided domain" - ) - vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_domain) + if vulnerability_filters.title: + query = query.filter(title=vulnerability_filters.title) - if vulnerability_filters.severity: - vulnerabilities_by_severity = Vulnerability.objects.values("id").filter( - severity=vulnerability_filters.severity - ) - if not vulnerabilities_by_severity.exists(): - raise Vulnerability.DoesNotExist( - "No Vulnerabilities found with the provided severity level" - ) - vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_severity) + if vulnerability_filters.domain: + query = query.filter(domain=vulnerability_filters.domain) - if vulnerability_filters.cpe: - vulnerabilities_by_cpe = Vulnerability.objects.values("id").filter( - cpe=vulnerability_filters.cpe - ) - if not vulnerabilities_by_cpe.exists(): - raise Vulnerability.DoesNotExist( - "No Vulnerabilities found with the provided Cpe" - ) - vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_cpe) + if vulnerability_filters.severity: + query = query.filter(severity=vulnerability_filters.severity) - if vulnerability_filters.state: - vulnerabilities_by_state = Vulnerability.objects.values("id").filter( - state=vulnerability_filters.state - ) - if not vulnerabilities_by_state.exists(): - raise Vulnerability.DoesNotExist( - "No Vulnerabilities found with the provided state" - ) - vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_state) + if vulnerability_filters.cpe: + query = query.filter(cpe=vulnerability_filters.cpe) - if vulnerability_filters.organization: - domains = Domain.objects.all() - domains_by_organization = Domain.objects.values("id").filter( - organization_id=vulnerability_filters.organization - ) - if not domains_by_organization.exists(): - raise Vulnerability.DoesNotExist( - "No Organization-Domain found with the provided organization ID" - ) - domains = domains.filter(id__in=domains_by_organization) - vulnerabilities_by_domain = Vulnerability.objects.values("id").filter( - id__in=domains - ) - if not vulnerabilities_by_domain.exists(): - raise Vulnerability.DoesNotExist( - "No Vulnerabilities found with the provided organization ID" - ) - vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_domain) + if vulnerability_filters.state: + query = query.filter(state=vulnerability_filters.state) - if vulnerability_filters.isKev: - vulnerabilities_by_is_kev = Vulnerability.objects.values("id").filter( - isKev=vulnerability_filters.isKev + if vulnerability_filters.organization: + # Fetch domains based on the organization ID + domains_by_organization = Domain.objects.filter( + organization_id=vulnerability_filters.organization + ) + + if not domains_by_organization.exists(): + raise Vulnerability.DoesNotExist( + "No Organization-Domain found with the provided organization ID" ) - if not vulnerabilities_by_is_kev.exists(): - raise Vulnerability.DoesNotExist( - "No Vulnerabilities found with the provided isKev value" - ) - vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_is_kev) - return vulnerabilities - except Domain.DoesNotExist as e: - raise e - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + + # Filter vulnerabilities based on the found domains + query = query.filter(domain__in=domains_by_organization) + + if ( + vulnerability_filters.isKev is not None + ): # Check for None to distinguish between True/False + query = query.filter(isKev=vulnerability_filters.isKev) + + # If the queryset is empty, raise a not found exception (404) + if not query.exists(): + raise Vulnerability.DoesNotExist( + "No Vulnerabilities found with the provided filters." + ) + + return query diff --git a/backend/src/xfd_django/xfd_api/tests/test_domain.py b/backend/src/xfd_django/xfd_api/tests/test_domain.py index da6cf9ca..30955d14 100644 --- a/backend/src/xfd_django/xfd_api/tests/test_domain.py +++ b/backend/src/xfd_django/xfd_api/tests/test_domain.py @@ -1,13 +1,7 @@ # Standard Python Libraries from datetime import datetime -import logging import secrets -# Configure logging -logging.basicConfig(level=logging.DEBUG) # Set the logging level to DEBUG -logger = logging.getLogger(__name__) - - # Third-Party Libraries from fastapi.testclient import TestClient import pytest @@ -33,6 +27,7 @@ @pytest.mark.django_db(transaction=True) def test_get_domain_by_id(): + # Get domain by Id. user = User.objects.create( firstName="", lastName="", @@ -49,11 +44,12 @@ def test_get_domain_by_id(): data = response.json() assert response.status_code == 200 + assert data["id"] == test_id @pytest.mark.django_db(transaction=True) -def test_filter_domain_by_ip(capfd): - # Filter domains by ip +def test_search_domain_by_ip(): + # Search domains by ip user = User.objects.create( firstName="", lastName="", @@ -75,8 +71,8 @@ def test_filter_domain_by_ip(capfd): @pytest.mark.django_db(transaction=True) -def test_filter_domain_by_port(): - # Test filter domains by port +def test_search_domain_by_port(): + # Test search domains by port user = User.objects.create( firstName="", lastName="", @@ -98,8 +94,8 @@ def test_filter_domain_by_port(): @pytest.mark.django_db(transaction=True) -def test_filter_domain_by_service(): - # Test filter domains by service_id +def test_search_domain_by_service(): + # Test search domains by service_id user = User.objects.create( firstName="", lastName="", @@ -122,8 +118,8 @@ def test_filter_domain_by_service(): @pytest.mark.django_db(transaction=True) -def test_filter_domain_by_organization(): - # Test filter domains by organization +def test_search_domain_by_organization(): + # Test search domains by organization user = User.objects.create( firstName="", lastName="", @@ -148,8 +144,8 @@ def test_filter_domain_by_organization(): @pytest.mark.django_db(transaction=True) -def test_filter_domain_by_organization_name(): - # Test filter domains by organization +def test_search_domain_by_organization_name(): + # Test search domains by organization user = User.objects.create( firstName="", lastName="", @@ -176,8 +172,8 @@ def test_filter_domain_by_organization_name(): @pytest.mark.django_db(transaction=True) -def test_filter_domain_by_vulnerabilities(): - # Test filter domains by vulnerabilities +def test_search_domain_by_vulnerabilities(): + # Test search domains by vulnerabilities user = User.objects.create( firstName="", lastName="", @@ -204,8 +200,8 @@ def test_filter_domain_by_vulnerabilities(): @pytest.mark.django_db(transaction=True) -def test_filter_domains_multiple_criteria(): - # Test filter domains by multiple criteria +def test_search_domains_multiple_criteria(): + # Test search domains by multiple criteria user = User.objects.create( firstName="", lastName="", @@ -231,8 +227,8 @@ def test_filter_domains_multiple_criteria(): @pytest.mark.django_db(transaction=True) -def test_filter_domains_does_not_exist(): - # Test filter domains if record does not exist +def test_search_domains_does_not_exist(): + # Test search domains if record does not exist user = User.objects.create( firstName="", lastName="", diff --git a/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py b/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py new file mode 100644 index 00000000..959d349e --- /dev/null +++ b/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py @@ -0,0 +1,341 @@ +# Standard Python Libraries +from datetime import datetime +import logging +import secrets + +# Configure logging +logging.basicConfig(level=logging.DEBUG) # Set the logging level to DEBUG +logger = logging.getLogger(__name__) + + +# Third-Party Libraries +from fastapi.testclient import TestClient +import pytest +from xfd_api.auth import create_jwt_token +from xfd_api.models import User, UserType +from xfd_django.asgi import app + +client = TestClient(app) + +test_id = "c0effe93-3647-475a-a0c5-0b629c348588" +filters = { + "id": "d39a8536-0b64-45b6-b621-5d954329221c", + "title": "DNS Twist Domains", + "cpe": "cpe:/a:openbsd:openssh:7.4", + "severity": "Low", + "domain": "84313a29-0009-45dc-8a2d-1ff7e0ba0030", + "state": "open", + "substate": "unconfirmed", + "organization": "fff159cb-efc8-4ea8-be51-e6b65e38d3e9", + "isKev": False, +} + + +@pytest.mark.django_db(transaction=True) +def test_get_vulnerability_by_id(): + # Get vulnerability by Id. + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.get( + f"/vulnerabilities/{test_id}", + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + data = response.json() + + assert response.status_code == 200 + assert data["id"] == test_id + + +@pytest.mark.django_db(transaction=True) +def test_search_vulnerabilities_id(): + # Search vulnerabilities by ip. + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/vulnerabilities/search", + json={"page": 1, "filters": {"id": filters["id"]}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 200 + data = response.json() + assert data is not None + for vulnerability in data: + assert vulnerability["id"] == filters["id"] + + +@pytest.mark.django_db(transaction=True) +def test_search_vulnerabilities_by_title(): + # Test search vulnerabilities by title + + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/vulnerabilities/search", + json={"page": 1, "filters": {"title": filters["title"]}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 200 + data = response.json() + assert data is not None + for vulnerability in data: + assert vulnerability["title"] == filters["title"] + + +@pytest.mark.django_db(transaction=True) +def test_search_vulnerabilities_by_cpe(): + # Test search vulnerabilities by cpe + + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/vulnerabilities/search", + json={"page": 1, "filters": {"cpe": filters["cpe"]}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 200 + data = response.json() + assert data is not None + for vulnerability in data: + assert vulnerability["cpe"] == filters["cpe"] + + +@pytest.mark.django_db(transaction=True) +def test_search_vulnerabilities_by_severity(): + # Test search vulnerabilities by severity + + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/vulnerabilities/search", + json={"page": 1, "filters": {"severity": filters["severity"]}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 200 + data = response.json() + assert data is not None + for vulnerability in data: + assert vulnerability["severity"] == filters["severity"] + + +@pytest.mark.django_db(transaction=True) +def test_search_vulnerabilities_by_domain_id(): + # Test search vulnerabilities by domain id + + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/vulnerabilities/search", + json={"page": 1, "filters": {"domain": filters["domain"]}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 200 + data = response.json() + assert data is not None + for vulnerability in data: + assert vulnerability["domain_id"] == filters["domain"] + + +@pytest.mark.django_db(transaction=True) +def test_search_vulnerabilities_by_state(): + # Test search vulnerabilities by state + + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/vulnerabilities/search", + json={"page": 1, "filters": {"state": filters["state"]}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 200 + data = response.json() + assert data is not None + for vulnerability in data: + assert vulnerability["state"] == filters["state"] + + +@pytest.mark.django_db(transaction=True) +def test_search_vulnerabilities_by_substate(): + # Test search vulnerabilities by state + + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/vulnerabilities/search", + json={"page": 1, "filters": {"substate": filters["substate"]}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 200 + data = response.json() + assert data is not None + for vulnerability in data: + assert vulnerability["substate"] == filters["substate"] + + +@pytest.mark.django_db(transaction=True) +def test_search_vulnerabilities_by_organization_id(): + # Test search vulnerabilities by state + + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/vulnerabilities/search", + json={ + "page": 1, + "filters": {"organization": filters["organization"]}, + "pageSize": 25, + }, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 200 + data = response.json() + assert data is not None + + +@pytest.mark.django_db(transaction=True) +def test_search_vulnerabilities_by_is_kev(): + # Test search vulnerabilities by state + + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/vulnerabilities/search", + json={"page": 1, "filters": {"isKev": filters["isKev"]}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 200 + data = response.json() + assert data is not None + for vulnerability in data: + assert vulnerability["isKev"] == filters["isKev"] + + +@pytest.mark.django_db(transaction=True) +def test_search_vulnerabilities_by_multiple_criteria(): + # Test search vulnerabilities by state + + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/vulnerabilities/search", + json={ + "page": 1, + "filters": {"state": filters["state"], "substate": filters["substate"]}, + "pageSize": 25, + }, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 200 + data = response.json() + assert data is not None + for vulnerability in data: + assert vulnerability["state"] == filters["state"] + assert vulnerability["substate"] == filters["substate"] + + +@pytest.mark.django_db(transaction=True) +def test_search_vulnerabilities_does_not_exist(): + # Test search vulnerabilities by state + + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.post( + "/vulnerabilities/search", + json={"page": 1, "filters": {"title": "Does Not Exist"}, "pageSize": 25}, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 404 diff --git a/backend/src/xfd_django/xfd_api/tests/test_vulnerability.py b/backend/src/xfd_django/xfd_api/tests/test_vulnerability.py deleted file mode 100644 index d7226d9a..00000000 --- a/backend/src/xfd_django/xfd_api/tests/test_vulnerability.py +++ /dev/null @@ -1,18 +0,0 @@ -# Standard Python Libraries -from datetime import datetime -import logging -import secrets - -# Configure logging -logging.basicConfig(level=logging.DEBUG) # Set the logging level to DEBUG -logger = logging.getLogger(__name__) - - -# Third-Party Libraries -from fastapi.testclient import TestClient -import pytest -from xfd_api.auth import create_jwt_token -from xfd_api.models import User, UserType -from xfd_django.asgi import app - -client = TestClient(app) diff --git a/backend/src/xfd_django/xfd_api/views.py b/backend/src/xfd_django/xfd_api/views.py index 18f5de41..6f66ec36 100644 --- a/backend/src/xfd_django/xfd_api/views.py +++ b/backend/src/xfd_django/xfd_api/views.py @@ -229,7 +229,7 @@ async def export_vulnerabilities(): @api_router.get( - "/vulnerabilities/{vulnerabilityId}", + "/vulnerabilities/{vuln_id}", dependencies=[Depends(get_current_active_user)], response_model=VulnerabilitySchema, tags=["Get vulnerability by id"], From 32400cd1225c9958bd62c52f0a07bd868721f7f5 Mon Sep 17 00:00:00 2001 From: JCantu248 Date: Fri, 15 Nov 2024 11:10:54 -0600 Subject: [PATCH 5/9] Add permissions check back in now that authentication works. --- backend/src/xfd_django/xfd_api/api_methods/domain.py | 8 ++++++++ .../src/xfd_django/xfd_api/api_methods/vulnerability.py | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git a/backend/src/xfd_django/xfd_api/api_methods/domain.py b/backend/src/xfd_django/xfd_api/api_methods/domain.py index 40b28fde..68f35f4f 100644 --- a/backend/src/xfd_django/xfd_api/api_methods/domain.py +++ b/backend/src/xfd_django/xfd_api/api_methods/domain.py @@ -47,6 +47,14 @@ def search_domains(domain_search: DomainSearch, current_user): sort_direction(domain_search.sort, domain_search.order) ) + # Apply global filters based on user permissions + if not is_global_view_admin(current_user): + orgs = get_org_memberships(current_user) + if not orgs: + # No organization memberships, return empty result + return [], 0 + domains = domains.filter(organization__id__in=orgs) + # Add a filter to restrict based on FCEB and CIDR criteria domains = domains.filter(Q(isFceb=True) | Q(isFceb=False, fromCidr=True)) diff --git a/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py b/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py index 49dee9c7..2a3f5c74 100644 --- a/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py +++ b/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py @@ -60,6 +60,15 @@ def search_vulnerabilities(vulnerability_search: VulnerabilitySearch, current_us sort_direction(vulnerability_search.sort, vulnerability_search.order) ) + # Permissions check + if not is_global_view_admin(current_user): + org_ids = get_org_memberships(current_user) + if not org_ids: + return [], 0 # User has no accessible organizations + vulnerabilities = vulnerabilities.filter( + domain__organization_id__in=org_ids + ) + # Apply custom FCEB and CIDR filter vulnerabilities = vulnerabilities.filter( Q(domain__isFceb=True) | Q(domain__isFceb=False, domain__fromCidr=True) From c4b30d4445f809b1600902ecae44d2c5cac03557 Mon Sep 17 00:00:00 2001 From: JCantu248 Date: Fri, 15 Nov 2024 17:01:14 -0600 Subject: [PATCH 6/9] Add test case for get vulnerability when no record is found, resulting in a 404. --- .../xfd_api/tests/test_vulnerabilities.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py b/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py index 959d349e..69509c6d 100644 --- a/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py +++ b/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py @@ -18,6 +18,7 @@ client = TestClient(app) test_id = "c0effe93-3647-475a-a0c5-0b629c348588" +bad_id = "c0effe93-3647-475a-a0c5-0b629c348590" filters = { "id": "d39a8536-0b64-45b6-b621-5d954329221c", "title": "DNS Twist Domains", @@ -53,6 +54,27 @@ def test_get_vulnerability_by_id(): assert data["id"] == test_id +@pytest.mark.django_db(transaction=True) +def test_get_vulnerability_by_id_not_found(): + # Get error 404 if vulnerability does not exist + + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.get( + f"/vulnerabilities/{bad_id}", + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 404 + + @pytest.mark.django_db(transaction=True) def test_search_vulnerabilities_id(): # Search vulnerabilities by ip. From e122508a35c35bec083739b850c3ad11e7fcafd7 Mon Sep 17 00:00:00 2001 From: JCantu248 Date: Mon, 18 Nov 2024 09:50:15 -0600 Subject: [PATCH 7/9] Fixed vulnerability_update test cases, fix vulnerability_update method. --- .../xfd_api/api_methods/vulnerability.py | 88 +++++++- .../xfd_api/schema_models/vulnerability.py | 4 +- .../xfd_api/tests/test_vulnerabilities.py | 197 +++++++++++++++++- backend/src/xfd_django/xfd_api/views.py | 18 +- 4 files changed, 287 insertions(+), 20 deletions(-) diff --git a/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py b/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py index 2a3f5c74..c7909b4c 100644 --- a/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py +++ b/backend/src/xfd_django/xfd_api/api_methods/vulnerability.py @@ -2,27 +2,39 @@ Vulnerability API. """ +# Standard Python Libraries +import uuid # Third-Party Libraries from django.core.paginator import Paginator from django.db.models import Q +from django.shortcuts import get_object_or_404 from fastapi import HTTPException from ..auth import get_org_memberships, is_global_view_admin from ..helpers.filter_helpers import filter_vulnerabilities, sort_direction -from ..models import Vulnerability +from ..models import Domain, Service, Vulnerability from ..schema_models.vulnerability import Vulnerability as VulnerabilitySchema from ..schema_models.vulnerability import VulnerabilityFilters, VulnerabilitySearch -def get_vulnerability_by_id(vuln_id): +def is_valid_uuid(val: str) -> bool: + """Check if the given string is a valid UUID.""" + try: + uuid_obj = uuid.UUID(val, version=4) + except ValueError: + return False + return str(uuid_obj) == val + + +def get_vulnerability_by_id(vulnerability_id): """ Get vulnerability by id. Returns: object: a single Vulnerability object. """ try: - vulnerability = Vulnerability.objects.get(id=vuln_id) + vulnerability = Vulnerability.objects.get(id=vulnerability_id) return vulnerability except Vulnerability.DoesNotExist: raise HTTPException(status_code=404, detail="Vulnerability not found.") @@ -30,20 +42,80 @@ def get_vulnerability_by_id(vuln_id): raise HTTPException(status_code=500, detail=str(e)) -def update_vulnerability(vuln_id, data: VulnerabilitySchema): +def update_vulnerability( + vulnerability_id, vulnerability_data: VulnerabilitySchema, current_user +): """ Update vulnerability by id. + Args: + vulnerability_id (UUID): The ID of the vulnerability to update. + vulnerability_data (VulnerabilitySchema): The data to update the vulnerability with. + current_user: The user performing the update (not used in this snippet, but can be used for auditing). + Returns: - object: a single vulnerability object that has been modified. + Vulnerability: The updated vulnerability object. + + Raises: + HTTPException: If the vulnerability is not found or if a server error occurs. """ try: - vulnerability = Vulnerability.objects.get(id=vuln_id) - vulnerability = data + # Validate the vulnerability ID + if not is_valid_uuid(vulnerability_id): + raise HTTPException(status_code=404, detail="Vulnerability not found") + + # Fetch the existing vulnerability + vulnerability = Vulnerability.objects.get(id=vulnerability_id) + + # Create a mapping of fields to update + fields_to_update = { + "title": vulnerability_data.title, + "cve": vulnerability_data.cve, + "cwe": vulnerability_data.cwe, + "cpe": vulnerability_data.cpe, + "description": vulnerability_data.description, + "references": vulnerability_data.references, + "cvss": vulnerability_data.cvss, + "severity": vulnerability_data.severity, + "needsPopulation": vulnerability_data.needsPopulation, + "state": vulnerability_data.state, + "substate": vulnerability_data.substate, + "source": vulnerability_data.source, + "notes": vulnerability_data.notes, + "actions": vulnerability_data.actions, + "structuredData": vulnerability_data.structuredData, + "isKev": vulnerability_data.isKev, + "domain": vulnerability_data.domain_id, + "service": vulnerability_data.service_id, + } + + # Update fields that are not None + for field, value in fields_to_update.items(): + if value is not None: + if field == "domain": + # Handle domain ID to fetch the Domain instance + domain_instance = get_object_or_404( + Domain, id=vulnerability_data.domain_id + ) + vulnerability.domain = domain_instance + elif field == "service": + # Handle service ID to fetch the Service instance + service_instance = get_object_or_404( + Service, id=vulnerability_data.service_id + ) + vulnerability.service = service_instance + else: + setattr(vulnerability, field, value) + + # Save the updated vulnerability object vulnerability.save() + return vulnerability + + except Vulnerability.DoesNotExist: + raise HTTPException(status_code=404, detail="Vulnerability not found.") except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail="Internal Server Error") def search_vulnerabilities(vulnerability_search: VulnerabilitySearch, current_user): diff --git a/backend/src/xfd_django/xfd_api/schema_models/vulnerability.py b/backend/src/xfd_django/xfd_api/schema_models/vulnerability.py index 0a6719f7..1eab67c8 100644 --- a/backend/src/xfd_django/xfd_api/schema_models/vulnerability.py +++ b/backend/src/xfd_django/xfd_api/schema_models/vulnerability.py @@ -33,8 +33,8 @@ class Vulnerability(BaseModel): actions: Optional[Any] structuredData: Optional[Any] isKev: bool - domain_id: UUID - service_id: UUID + domain_id: Optional[UUID] + service_id: Optional[UUID] class Config: from_attributes = True diff --git a/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py b/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py index 69509c6d..d5cdaa2a 100644 --- a/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py +++ b/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py @@ -12,7 +12,7 @@ from fastapi.testclient import TestClient import pytest from xfd_api.auth import create_jwt_token -from xfd_api.models import User, UserType +from xfd_api.models import User, UserType, Vulnerability from xfd_django.asgi import app client = TestClient(app) @@ -55,9 +55,8 @@ def test_get_vulnerability_by_id(): @pytest.mark.django_db(transaction=True) -def test_get_vulnerability_by_id_not_found(): +def test_get_vulnerability_by_id_fails_404(): # Get error 404 if vulnerability does not exist - user = User.objects.create( firstName="", lastName="", @@ -75,6 +74,198 @@ def test_get_vulnerability_by_id_not_found(): assert response.status_code == 404 +@pytest.mark.django_db(transaction=True) +def test_update_vulnerability(): + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + vulnerability = Vulnerability.objects.create( + title="Old Vulnerability", + description="Old description.", + severity="Medium", + cvss=5.0, + createdAt=datetime.now(), + updatedAt=datetime.now(), + needsPopulation=True, + source="source1", + notes="old notes", + actions=[], + structuredData={}, + isKev=False, + kevResults={}, + domain_id="", + service_id="", + ) + + new_data = { + "id": str(vulnerability.id), + "createdAt": str(vulnerability.createdAt), + "updatedAt": str(datetime.now()), + "lastSeen": str(datetime.now()), + "title": "Updated Vulnerability", + "cve": vulnerability.cve, + "cwe": vulnerability.cwe, + "cpe": vulnerability.cpe, + "description": "Updated description.", + "references": None, + "severity": "High", + "cvss": 7.5, + "needsPopulation": False, + "state": vulnerability.state, + "substate": vulnerability.substate, + "source": "source2", + "notes": "updated notes", + "actions": ["action1"], + "structuredData": {"key": "value"}, + "isKev": True, + "domain_id": None, + "service_id": None, + } + + response = client.put( + f"/vulnerabilities/{vulnerability.id}", + json=new_data, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + + assert response.status_code == 200 + + vulnerability.refresh_from_db() + assert vulnerability.title == new_data["title"] + assert vulnerability.description == new_data["description"] + assert vulnerability.needsPopulation == new_data["needsPopulation"] + assert vulnerability.source == new_data["source"] + assert vulnerability.notes == new_data["notes"] + + assert vulnerability.id == vulnerability.id + + +@pytest.mark.django_db(transaction=True) +def test_update_vulnerability_fails_404(): + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + vulnerability = Vulnerability.objects.create( + title="Old Vulnerability", + description="Old description.", + severity="Medium", + cvss=5.0, + createdAt=datetime.now(), + updatedAt=datetime.now(), + needsPopulation=True, + source="source1", + notes="old notes", + actions=[], + structuredData={}, + isKev=False, + kevResults={}, + domain_id="", + service_id="", + ) + + new_data = { + "id": str(vulnerability.id), + "createdAt": str(vulnerability.createdAt), + "updatedAt": str(datetime.now()), + "lastSeen": str(datetime.now()), + "title": "Updated Vulnerability", + "cve": vulnerability.cve, + "cwe": vulnerability.cwe, + "cpe": vulnerability.cpe, + "description": "Updated description.", + "references": None, + "severity": "High", + "cvss": 7.5, + "needsPopulation": False, + "state": vulnerability.state, + "substate": vulnerability.substate, + "source": "source2", + "notes": "updated notes", + "actions": ["action1"], + "structuredData": {"key": "value"}, + "isKev": True, + "domain_id": None, + "service_id": None, + } + + response = client.put( + f"/vulnerabilities/{bad_id}", + json=new_data, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + assert response.status_code == 404 + + +@pytest.mark.django_db(transaction=True) +def test_update_vulnerability_fails_422(): + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + vulnerability = Vulnerability.objects.create( + title="Old Vulnerability", + description="Old description.", + severity="Medium", + cvss=5.0, + createdAt=datetime.now(), + updatedAt=datetime.now(), + needsPopulation=True, + source="source1", + notes="old notes", + actions=[], + structuredData={}, + isKev=False, + kevResults={}, + domain_id="", + service_id="", + ) + + new_data = { + "title": "Updated Vulnerability", + "cve": vulnerability.cve, + "cwe": vulnerability.cwe, + "cpe": vulnerability.cpe, + "description": "Updated description.", + "references": None, + "severity": "High", + "cvss": 7.5, + "needsPopulation": False, + "state": vulnerability.state, + "substate": vulnerability.substate, + "source": "source2", + "notes": "updated notes", + "actions": ["action1"], + "structuredData": {"key": "value"}, + "isKev": True, + "domain_id": None, + "service_id": None, + } + + response = client.put( + f"/vulnerabilities/{vulnerability.id}", + json=new_data, + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + assert response.status_code == 422 + + @pytest.mark.django_db(transaction=True) def test_search_vulnerabilities_id(): # Search vulnerabilities by ip. diff --git a/backend/src/xfd_django/xfd_api/views.py b/backend/src/xfd_django/xfd_api/views.py index 6f66ec36..eaec5acb 100644 --- a/backend/src/xfd_django/xfd_api/views.py +++ b/backend/src/xfd_django/xfd_api/views.py @@ -229,34 +229,38 @@ async def export_vulnerabilities(): @api_router.get( - "/vulnerabilities/{vuln_id}", + "/vulnerabilities/{vulnerability_id}", dependencies=[Depends(get_current_active_user)], response_model=VulnerabilitySchema, tags=["Get vulnerability by id"], ) -async def call_get_vulnerability_by_id(vuln_id: str): +async def call_get_vulnerability_by_id(vulnerability_id: str): """ Get vulnerability by id. Returns: object: a single Vulnerability object. """ - return get_vulnerability_by_id(vuln_id) + return get_vulnerability_by_id(vulnerability_id) @api_router.put( - "/vulnerabilities/{vulnerabilityId}", - # dependencies=[Depends(get_current_active_user)], + "/vulnerabilities/{vulnerability_id}", + dependencies=[Depends(get_current_active_user)], response_model=VulnerabilitySchema, tags="Update vulnerability", ) -async def call_update_vulnerability(vuln_id, data: VulnerabilitySchema): +async def call_update_vulnerability( + vulnerability_id, + data: VulnerabilitySchema, + current_user: User = Depends(get_current_active_user), +): """ Update vulnerability by id. Returns: object: a single vulnerability object that has been modified. """ - return update_vulnerability(vuln_id, data) + return update_vulnerability(vulnerability_id, data, current_user) # ======================================== From e0ae921eca54e5ca267549c8b673a85d5b9d82fb Mon Sep 17 00:00:00 2001 From: JCantu248 Date: Tue, 19 Nov 2024 09:52:18 -0600 Subject: [PATCH 8/9] Add test case for GET domain not found ends in a 404. --- .../xfd_django/xfd_api/tests/test_domain.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/backend/src/xfd_django/xfd_api/tests/test_domain.py b/backend/src/xfd_django/xfd_api/tests/test_domain.py index 30955d14..ff775580 100644 --- a/backend/src/xfd_django/xfd_api/tests/test_domain.py +++ b/backend/src/xfd_django/xfd_api/tests/test_domain.py @@ -13,6 +13,7 @@ test_id = "960b7db7-f3af-411d-a247-33371739050b" +bad_id = "960b7db7-f3af-411d-a247-333717390999" filters = { "ports": "80", "service": "6d9ecf5a-db5d-4b77-9752-a88a5d247631", @@ -47,6 +48,27 @@ def test_get_domain_by_id(): assert data["id"] == test_id +@pytest.mark.django_db(transaction=True) +def test_get_domain_by_id_fails_404(): + # Get domain by Id. + user = User.objects.create( + firstName="", + lastName="", + email=f"{secrets.token_hex(4)}@example.com", + userType=UserType.GLOBAL_ADMIN, + createdAt=datetime.now(), + updatedAt=datetime.now(), + ) + + response = client.get( + f"/domain/{bad_id}", + headers={"Authorization": "Bearer " + create_jwt_token(user)}, + ) + data = response.json() + + assert response.status_code == 404 + + @pytest.mark.django_db(transaction=True) def test_search_domain_by_ip(): # Search domains by ip From 7794f927b524bf1fc6a8ce327524d69899183979 Mon Sep 17 00:00:00 2001 From: JCantu248 Date: Tue, 19 Nov 2024 14:01:14 -0600 Subject: [PATCH 9/9] Add test fixtures to create users and vulnerabilities to test_domain, test_vulnerabilities. --- .../xfd_django/xfd_api/tests/test_domain.py | 108 ++----- .../xfd_api/tests/test_vulnerabilities.py | 263 ++++-------------- 2 files changed, 66 insertions(+), 305 deletions(-) diff --git a/backend/src/xfd_django/xfd_api/tests/test_domain.py b/backend/src/xfd_django/xfd_api/tests/test_domain.py index ff775580..05b511c2 100644 --- a/backend/src/xfd_django/xfd_api/tests/test_domain.py +++ b/backend/src/xfd_django/xfd_api/tests/test_domain.py @@ -26,9 +26,8 @@ } -@pytest.mark.django_db(transaction=True) -def test_get_domain_by_id(): - # Get domain by Id. +@pytest.fixture +def user(): user = User.objects.create( firstName="", lastName="", @@ -37,7 +36,13 @@ def test_get_domain_by_id(): createdAt=datetime.now(), updatedAt=datetime.now(), ) + yield user + user.delete() # Clean up after the test + +@pytest.mark.django_db(transaction=True) +def test_get_domain_by_id(user): + # Get domain by Id. response = client.get( f"/domain/{test_id}", headers={"Authorization": "Bearer " + create_jwt_token(user)}, @@ -49,17 +54,8 @@ def test_get_domain_by_id(): @pytest.mark.django_db(transaction=True) -def test_get_domain_by_id_fails_404(): +def test_get_domain_by_id_fails_404(user): # Get domain by Id. - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.get( f"/domain/{bad_id}", headers={"Authorization": "Bearer " + create_jwt_token(user)}, @@ -70,17 +66,8 @@ def test_get_domain_by_id_fails_404(): @pytest.mark.django_db(transaction=True) -def test_search_domain_by_ip(): +def test_search_domain_by_ip(user): # Search domains by ip - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/domain/search", json={"page": 1, "filters": {"ip": filters["ip"]}, "pageSize": 25}, @@ -93,16 +80,8 @@ def test_search_domain_by_ip(): @pytest.mark.django_db(transaction=True) -def test_search_domain_by_port(): +def test_search_domain_by_port(user): # Test search domains by port - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) response = client.post( "/domain/search", json={"page": 1, "filters": {"ports": filters["ports"]}, "pageSize": 25}, @@ -116,17 +95,8 @@ def test_search_domain_by_port(): @pytest.mark.django_db(transaction=True) -def test_search_domain_by_service(): +def test_search_domain_by_service(user): # Test search domains by service_id - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/domain/search", json={"page": 1, "filters": {"service": filters["service"]}, "pageSize": 25}, @@ -140,16 +110,8 @@ def test_search_domain_by_service(): @pytest.mark.django_db(transaction=True) -def test_search_domain_by_organization(): +def test_search_domain_by_organization(user): # Test search domains by organization - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) response = client.post( "/domain/search", json={ @@ -166,17 +128,8 @@ def test_search_domain_by_organization(): @pytest.mark.django_db(transaction=True) -def test_search_domain_by_organization_name(): +def test_search_domain_by_organization_name(user): # Test search domains by organization - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/domain/search", json={ @@ -194,17 +147,8 @@ def test_search_domain_by_organization_name(): @pytest.mark.django_db(transaction=True) -def test_search_domain_by_vulnerabilities(): +def test_search_domain_by_vulnerabilities(user): # Test search domains by vulnerabilities - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/domain/search", json={ @@ -222,17 +166,8 @@ def test_search_domain_by_vulnerabilities(): @pytest.mark.django_db(transaction=True) -def test_search_domains_multiple_criteria(): +def test_search_domains_multiple_criteria(user): # Test search domains by multiple criteria - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/domain/search", json={ @@ -249,17 +184,8 @@ def test_search_domains_multiple_criteria(): @pytest.mark.django_db(transaction=True) -def test_search_domains_does_not_exist(): +def test_search_domains_does_not_exist(user): # Test search domains if record does not exist - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/domain/search", json={"page": 1, "filters": {"ip": "Does not exist"}, "pageSize": 25}, diff --git a/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py b/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py index d5cdaa2a..c1323566 100644 --- a/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py +++ b/backend/src/xfd_django/xfd_api/tests/test_vulnerabilities.py @@ -32,9 +32,8 @@ } -@pytest.mark.django_db(transaction=True) -def test_get_vulnerability_by_id(): - # Get vulnerability by Id. +@pytest.fixture +def user(): user = User.objects.create( firstName="", lastName="", @@ -43,7 +42,36 @@ def test_get_vulnerability_by_id(): createdAt=datetime.now(), updatedAt=datetime.now(), ) + yield user + user.delete() # Clean up after the test + + +@pytest.fixture +def create_vulnerability(): + vulnerability = Vulnerability.objects.create( + title="Old Vulnerability", + description="Old description.", + severity="Medium", + cvss=5.0, + createdAt=datetime.now(), + updatedAt=datetime.now(), + needsPopulation=True, + source="source1", + notes="old notes", + actions=[], + structuredData={}, + isKev=False, + kevResults={}, + domain_id="", + service_id="", + ) + yield vulnerability + vulnerability.delete() + +@pytest.mark.django_db(transaction=True) +def test_get_vulnerability_by_id(user): + # Get vulnerability by Id. response = client.get( f"/vulnerabilities/{test_id}", headers={"Authorization": "Bearer " + create_jwt_token(user)}, @@ -55,17 +83,8 @@ def test_get_vulnerability_by_id(): @pytest.mark.django_db(transaction=True) -def test_get_vulnerability_by_id_fails_404(): +def test_get_vulnerability_by_id_fails_404(user): # Get error 404 if vulnerability does not exist - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.get( f"/vulnerabilities/{bad_id}", headers={"Authorization": "Bearer " + create_jwt_token(user)}, @@ -75,33 +94,8 @@ def test_get_vulnerability_by_id_fails_404(): @pytest.mark.django_db(transaction=True) -def test_update_vulnerability(): - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - - vulnerability = Vulnerability.objects.create( - title="Old Vulnerability", - description="Old description.", - severity="Medium", - cvss=5.0, - createdAt=datetime.now(), - updatedAt=datetime.now(), - needsPopulation=True, - source="source1", - notes="old notes", - actions=[], - structuredData={}, - isKev=False, - kevResults={}, - domain_id="", - service_id="", - ) +def test_update_vulnerability(user, create_vulnerability): + vulnerability = create_vulnerability new_data = { "id": str(vulnerability.id), @@ -147,34 +141,8 @@ def test_update_vulnerability(): @pytest.mark.django_db(transaction=True) -def test_update_vulnerability_fails_404(): - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - - vulnerability = Vulnerability.objects.create( - title="Old Vulnerability", - description="Old description.", - severity="Medium", - cvss=5.0, - createdAt=datetime.now(), - updatedAt=datetime.now(), - needsPopulation=True, - source="source1", - notes="old notes", - actions=[], - structuredData={}, - isKev=False, - kevResults={}, - domain_id="", - service_id="", - ) - +def test_update_vulnerability_fails_404(user, create_vulnerability): + vulnerability = create_vulnerability new_data = { "id": str(vulnerability.id), "createdAt": str(vulnerability.createdAt), @@ -209,33 +177,8 @@ def test_update_vulnerability_fails_404(): @pytest.mark.django_db(transaction=True) -def test_update_vulnerability_fails_422(): - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - - vulnerability = Vulnerability.objects.create( - title="Old Vulnerability", - description="Old description.", - severity="Medium", - cvss=5.0, - createdAt=datetime.now(), - updatedAt=datetime.now(), - needsPopulation=True, - source="source1", - notes="old notes", - actions=[], - structuredData={}, - isKev=False, - kevResults={}, - domain_id="", - service_id="", - ) +def test_update_vulnerability_fails_422(user, create_vulnerability): + vulnerability = create_vulnerability new_data = { "title": "Updated Vulnerability", @@ -267,17 +210,8 @@ def test_update_vulnerability_fails_422(): @pytest.mark.django_db(transaction=True) -def test_search_vulnerabilities_id(): +def test_search_vulnerabilities_id(user): # Search vulnerabilities by ip. - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/vulnerabilities/search", json={"page": 1, "filters": {"id": filters["id"]}, "pageSize": 25}, @@ -292,18 +226,9 @@ def test_search_vulnerabilities_id(): @pytest.mark.django_db(transaction=True) -def test_search_vulnerabilities_by_title(): +def test_search_vulnerabilities_by_title(user): # Test search vulnerabilities by title - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/vulnerabilities/search", json={"page": 1, "filters": {"title": filters["title"]}, "pageSize": 25}, @@ -318,18 +243,8 @@ def test_search_vulnerabilities_by_title(): @pytest.mark.django_db(transaction=True) -def test_search_vulnerabilities_by_cpe(): +def test_search_vulnerabilities_by_cpe(user): # Test search vulnerabilities by cpe - - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/vulnerabilities/search", json={"page": 1, "filters": {"cpe": filters["cpe"]}, "pageSize": 25}, @@ -344,18 +259,8 @@ def test_search_vulnerabilities_by_cpe(): @pytest.mark.django_db(transaction=True) -def test_search_vulnerabilities_by_severity(): +def test_search_vulnerabilities_by_severity(user): # Test search vulnerabilities by severity - - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/vulnerabilities/search", json={"page": 1, "filters": {"severity": filters["severity"]}, "pageSize": 25}, @@ -370,18 +275,8 @@ def test_search_vulnerabilities_by_severity(): @pytest.mark.django_db(transaction=True) -def test_search_vulnerabilities_by_domain_id(): +def test_search_vulnerabilities_by_domain_id(user): # Test search vulnerabilities by domain id - - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/vulnerabilities/search", json={"page": 1, "filters": {"domain": filters["domain"]}, "pageSize": 25}, @@ -396,18 +291,8 @@ def test_search_vulnerabilities_by_domain_id(): @pytest.mark.django_db(transaction=True) -def test_search_vulnerabilities_by_state(): +def test_search_vulnerabilities_by_state(user): # Test search vulnerabilities by state - - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/vulnerabilities/search", json={"page": 1, "filters": {"state": filters["state"]}, "pageSize": 25}, @@ -422,18 +307,8 @@ def test_search_vulnerabilities_by_state(): @pytest.mark.django_db(transaction=True) -def test_search_vulnerabilities_by_substate(): +def test_search_vulnerabilities_by_substate(user): # Test search vulnerabilities by state - - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/vulnerabilities/search", json={"page": 1, "filters": {"substate": filters["substate"]}, "pageSize": 25}, @@ -448,18 +323,8 @@ def test_search_vulnerabilities_by_substate(): @pytest.mark.django_db(transaction=True) -def test_search_vulnerabilities_by_organization_id(): +def test_search_vulnerabilities_by_organization_id(user): # Test search vulnerabilities by state - - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/vulnerabilities/search", json={ @@ -476,18 +341,8 @@ def test_search_vulnerabilities_by_organization_id(): @pytest.mark.django_db(transaction=True) -def test_search_vulnerabilities_by_is_kev(): +def test_search_vulnerabilities_by_is_kev(user): # Test search vulnerabilities by state - - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/vulnerabilities/search", json={"page": 1, "filters": {"isKev": filters["isKev"]}, "pageSize": 25}, @@ -502,18 +357,8 @@ def test_search_vulnerabilities_by_is_kev(): @pytest.mark.django_db(transaction=True) -def test_search_vulnerabilities_by_multiple_criteria(): +def test_search_vulnerabilities_by_multiple_criteria(user): # Test search vulnerabilities by state - - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/vulnerabilities/search", json={ @@ -533,18 +378,8 @@ def test_search_vulnerabilities_by_multiple_criteria(): @pytest.mark.django_db(transaction=True) -def test_search_vulnerabilities_does_not_exist(): +def test_search_vulnerabilities_does_not_exist(user): # Test search vulnerabilities by state - - user = User.objects.create( - firstName="", - lastName="", - email=f"{secrets.token_hex(4)}@example.com", - userType=UserType.GLOBAL_ADMIN, - createdAt=datetime.now(), - updatedAt=datetime.now(), - ) - response = client.post( "/vulnerabilities/search", json={"page": 1, "filters": {"title": "Does Not Exist"}, "pageSize": 25},