Skip to content

Commit

Permalink
Merge pull request #727 from cisagov/jd-domain-vuln-test-cases
Browse files Browse the repository at this point in the history
Add test cases for Domain and Vulnerability endpoints.
  • Loading branch information
JCantu248 authored Nov 19, 2024
2 parents 0dfc4d8 + e0ae921 commit 87d9391
Show file tree
Hide file tree
Showing 7 changed files with 1,001 additions and 119 deletions.
6 changes: 6 additions & 0 deletions backend/src/xfd_django/xfd_api/api_methods/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import csv

# Third-Party Libraries
from django.core.exceptions import ObjectDoesNotExist
from django.core.paginator import Paginator
from django.db.models import Q
from django.http import Http404
from fastapi import HTTPException

from ..auth import get_org_memberships, is_global_view_admin
Expand Down Expand Up @@ -61,6 +63,8 @@ def search_domains(domain_search: DomainSearch, current_user):
paginator = Paginator(domains, domain_search.pageSize)

return paginator.get_page(domain_search.page)
except Domain.DoesNotExist as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

Expand All @@ -74,5 +78,7 @@ def export_domains(domain_filters: DomainFilters):

# TODO: Integrate methods to generate CSV from queryset and save to S3 bucket
return domains
except Domain.DoesNotExist as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
94 changes: 86 additions & 8 deletions backend/src/xfd_django/xfd_api/api_methods/vulnerability.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,120 @@
Vulnerability API.
"""
# Standard Python Libraries
import uuid

# Third-Party Libraries
from django.core.paginator import Paginator
from django.db.models import Q
from django.shortcuts import get_object_or_404
from fastapi import HTTPException

from ..auth import get_org_memberships, is_global_view_admin
from ..helpers.filter_helpers import filter_vulnerabilities, sort_direction
from ..models import Vulnerability
from ..models import Domain, Service, Vulnerability
from ..schema_models.vulnerability import Vulnerability as VulnerabilitySchema
from ..schema_models.vulnerability import VulnerabilityFilters, VulnerabilitySearch


def get_vulnerability_by_id(vuln_id):
def is_valid_uuid(val: str) -> bool:
"""Check if the given string is a valid UUID."""
try:
uuid_obj = uuid.UUID(val, version=4)
except ValueError:
return False
return str(uuid_obj) == val


def get_vulnerability_by_id(vulnerability_id):
"""
Get vulnerability by id.
Returns:
object: a single Vulnerability object.
"""
try:
vulnerability = Vulnerability.objects.get(id=vuln_id)
vulnerability = Vulnerability.objects.get(id=vulnerability_id)
return vulnerability
except Vulnerability.DoesNotExist:
raise HTTPException(status_code=404, detail="Vulnerability not found.")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


def update_vulnerability(vuln_id, data: VulnerabilitySchema):
def update_vulnerability(
vulnerability_id, vulnerability_data: VulnerabilitySchema, current_user
):
"""
Update vulnerability by id.
Args:
vulnerability_id (UUID): The ID of the vulnerability to update.
vulnerability_data (VulnerabilitySchema): The data to update the vulnerability with.
current_user: The user performing the update (not used in this snippet, but can be used for auditing).
Returns:
object: a single vulnerability object that has been modified.
Vulnerability: The updated vulnerability object.
Raises:
HTTPException: If the vulnerability is not found or if a server error occurs.
"""
try:
vulnerability = Vulnerability.objects.get(id=vuln_id)
vulnerability = data
# Validate the vulnerability ID
if not is_valid_uuid(vulnerability_id):
raise HTTPException(status_code=404, detail="Vulnerability not found")

# Fetch the existing vulnerability
vulnerability = Vulnerability.objects.get(id=vulnerability_id)

# Create a mapping of fields to update
fields_to_update = {
"title": vulnerability_data.title,
"cve": vulnerability_data.cve,
"cwe": vulnerability_data.cwe,
"cpe": vulnerability_data.cpe,
"description": vulnerability_data.description,
"references": vulnerability_data.references,
"cvss": vulnerability_data.cvss,
"severity": vulnerability_data.severity,
"needsPopulation": vulnerability_data.needsPopulation,
"state": vulnerability_data.state,
"substate": vulnerability_data.substate,
"source": vulnerability_data.source,
"notes": vulnerability_data.notes,
"actions": vulnerability_data.actions,
"structuredData": vulnerability_data.structuredData,
"isKev": vulnerability_data.isKev,
"domain": vulnerability_data.domain_id,
"service": vulnerability_data.service_id,
}

# Update fields that are not None
for field, value in fields_to_update.items():
if value is not None:
if field == "domain":
# Handle domain ID to fetch the Domain instance
domain_instance = get_object_or_404(
Domain, id=vulnerability_data.domain_id
)
vulnerability.domain = domain_instance
elif field == "service":
# Handle service ID to fetch the Service instance
service_instance = get_object_or_404(
Service, id=vulnerability_data.service_id
)
vulnerability.service = service_instance
else:
setattr(vulnerability, field, value)

# Save the updated vulnerability object
vulnerability.save()

return vulnerability

except Vulnerability.DoesNotExist:
raise HTTPException(status_code=404, detail="Vulnerability not found.")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail="Internal Server Error")


def search_vulnerabilities(vulnerability_search: VulnerabilitySearch, current_user):
Expand Down Expand Up @@ -84,6 +158,8 @@ def search_vulnerabilities(vulnerability_search: VulnerabilitySearch, current_us

paginator = Paginator(vulnerabilities, vulnerability_search.pageSize)
return paginator.get_page(vulnerability_search.page)
except Vulnerability.DoesNotExist as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

Expand All @@ -99,5 +175,7 @@ def export_vulnerabilities(vulnerability_filters: VulnerabilityFilters):

# TODO: Integrate methods to generate CSV from queryset and save to S3 bucket
return vulnerabilities
except Vulnerability.DoesNotExist as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
154 changes: 67 additions & 87 deletions backend/src/xfd_django/xfd_api/helpers/filter_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,149 +43,129 @@ def filter_domains(domains: QuerySet, domain_filters: DomainFilters):
"domain"
)
if not services_by_port.exists():
raise Http404("No Domains found with the provided port")
raise ObjectDoesNotExist(
"Domain could not be found with provided port."
)
domains = domains.filter(id__in=services_by_port)

if domain_filters.service:
service_by_id = Service.objects.filter(id=domain_filters.service).values(
"domain"
)
if not service_by_id.exists():
raise Http404("No Domains found with the provided service")
raise Domain.DoesNotExist("No Domains found with the provided service")
domains = domains.filter(id__in=service_by_id)

if domain_filters.reverseName:
domains_by_reverse_name = Domain.objects.filter(
reverseName=domain_filters.reverseName
).values("id")
if not domains_by_reverse_name.exists():
raise Http404("No Domains found with the provided reverse name")
raise Domain.DoesNotExist(
"No Domains found with the provided reverse name"
)
domains = domains.filter(id__in=domains_by_reverse_name)

if domain_filters.ip:
domains_by_ip = Domain.objects.filter(ip=domain_filters.ip).values("id")
if not domains_by_ip.exists():
raise Http404("No Domains found with the provided ip")
raise Domain.DoesNotExist("Domain could not be found with provided Ip.")
domains = domains.filter(id__in=domains_by_ip)

if domain_filters.organization:
domains_by_org = Domain.objects.filter(
organization_id=domain_filters.organization
).values("id")
if not domains_by_org.exists():
raise Http404("No Domains found with the provided organization")
raise Domain.DoesNotExist(
"No Domains found with the provided organization"
)
domains = domains.filter(id__in=domains_by_org)

if domain_filters.organizationName:
organization_by_name = Organization.objects.filter(
name=domain_filters.organizationName
).values("id")
if not organization_by_name.exists():
raise Http404("No Domains found with the provided organization name")
raise Domain.DoesNotExist(
"No Domains found with the provided organization name"
)
domains = domains.filter(organization_id__in=organization_by_name)

if domain_filters.vulnerabilities:
vulnerabilities_by_id = Vulnerability.objects.filter(
id=domain_filters.vulnerabilities
).values("domain")
if not vulnerabilities_by_id.exists():
raise Http404("No Domains found with the provided vulnerability")
raise Domain.DoesNotExist(
"No Domains found with the provided vulnerability"
)
domains = domains.filter(id__in=vulnerabilities_by_id)
return domains
except ObjectDoesNotExist:
print("No vulnerability found with that ID.")
except Domain.DoesNotExist as e:
raise e
except Exception as e:
print(f"Error: {e}")
raise HTTPException(status_code=500, detail=str(e))


def filter_vulnerabilities(
vulnerabilities: QuerySet, vulnerability_filters: VulnerabilityFilters
):
"""
Filter vulnerabilitie
Filter vulnerabilities based on given filters.
Arguments:
vulnerabilities: A list of all vulnerabilities, sorted
vulnerability_filters: Value to filter the vulnberabilities table by
vulnerabilities: A list of all vulnerabilities, sorted.
vulnerability_filters: Value to filter the vulnerabilities table by.
Returns:
object: a list of Vulnerability objects
QuerySet: A filtered list of Vulnerability objects.
"""
try:
if vulnerability_filters.id:
vulnerability_by_id = Vulnerability.objects.values("id").get(
id=vulnerability_filters.id
)
if not vulnerability_by_id:
raise Http404("No Vulnerabilities found with the provided id")
vulnerabilities = vulnerabilities.filter(id=vulnerability_by_id)
# Initialize a query that includes all vulnerabilities
query = vulnerabilities

if vulnerability_filters.title:
vulnerabilities_by_title = Vulnerability.objects.values("id").filter(
title=vulnerability_filters.title
)
if not vulnerabilities_by_title.exists():
raise Http404("No Vulnerabilities found with the provided title")
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_title)
# Apply filters based on the provided criteria
if vulnerability_filters.id:
query = query.filter(id=vulnerability_filters.id)

if vulnerability_filters.domain:
vulnerabilities_by_domain = Vulnerability.objects.values("id").filter(
domain=vulnerability_filters.domain
)
if not vulnerabilities_by_domain.exists():
raise Http404("No Vulnerabilities found with the provided domain")
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_domain)
if vulnerability_filters.title:
query = query.filter(title=vulnerability_filters.title)

if vulnerability_filters.severity:
vulnerabilities_by_severity = Vulnerability.objects.values("id").filter(
severity=vulnerability_filters.severity
)
if not vulnerabilities_by_severity.exists():
raise Http404(
"No Vulnerabilities found with the provided severity level"
)
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_severity)
if vulnerability_filters.domain:
query = query.filter(domain=vulnerability_filters.domain)

if vulnerability_filters.cpe:
vulnerabilities_by_cpe = Vulnerability.objects.values("id").filter(
cpe=vulnerability_filters.cpe
)
if not vulnerabilities_by_cpe.exists():
raise Http404("No Vulnerabilities found with the provided Cpe")
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_cpe)
if vulnerability_filters.severity:
query = query.filter(severity=vulnerability_filters.severity)

if vulnerability_filters.state:
vulnerabilities_by_state = Vulnerability.objects.values("id").filter(
state=vulnerability_filters.state
)
if not vulnerabilities_by_state.exists():
raise Http404("No Vulnerabilities found with the provided state")
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_state)

if vulnerability_filters.organization:
domains = Domain.objects.all()
domains_by_organization = Domain.objects.values("id").filter(
organization_id=vulnerability_filters.organization
)
if not domains_by_organization.exists():
raise Http404(
"No Organization-Domain found with the provided organization ID"
)
domains = domains.filter(id__in=domains_by_organization)
vulnerabilities_by_domain = Vulnerability.objects.values("id").filter(
id__in=domains
)
if not vulnerabilities_by_domain.exists():
raise Http404(
"No Vulnerabilities found with the provided organization ID"
)
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_domain)
if vulnerability_filters.cpe:
query = query.filter(cpe=vulnerability_filters.cpe)

if vulnerability_filters.state:
query = query.filter(state=vulnerability_filters.state)

if vulnerability_filters.isKev:
vulnerabilities_by_is_kev = Vulnerability.objects.values("id").filter(
isKev=vulnerability_filters.isKev
if vulnerability_filters.organization:
# Fetch domains based on the organization ID
domains_by_organization = Domain.objects.filter(
organization_id=vulnerability_filters.organization
)

if not domains_by_organization.exists():
raise Vulnerability.DoesNotExist(
"No Organization-Domain found with the provided organization ID"
)
if not vulnerabilities_by_is_kev.exists():
raise Http404("No Vulnerabilities found with the provided isKev value")
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_is_kev)
return vulnerabilities
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

# Filter vulnerabilities based on the found domains
query = query.filter(domain__in=domains_by_organization)

if (
vulnerability_filters.isKev is not None
): # Check for None to distinguish between True/False
query = query.filter(isKev=vulnerability_filters.isKev)

# If the queryset is empty, raise a not found exception (404)
if not query.exists():
raise Vulnerability.DoesNotExist(
"No Vulnerabilities found with the provided filters."
)

return query
Loading

0 comments on commit 87d9391

Please sign in to comment.