Skip to content

Commit

Permalink
fix: removed nvd api 1.0 code (intel#3599)
Browse files Browse the repository at this point in the history
* Fixes intel#3583
Co-authored-by: Terri Oda <[email protected]>
  • Loading branch information
mastersans authored Jan 23, 2024
1 parent dbe0659 commit debdf8e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 59 deletions.
2 changes: 1 addition & 1 deletion cve_bin_tool/config_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def config_generator(args, types):
{first_char}cve_data_download{last_char}
#set your nvd api key
nvd_api_key {sign} {coma}{args["nvd_api_key"]}{coma}
# choose method for getting CVE lists from NVD (default: api) other option available api2, json
# choose method for getting CVE lists from NVD (default: api2) other option available json
nvd {sign} {coma}{args["nvd"]}{coma}
# update schedule for data sources and exploits database (default: daily)
update {sign} {coma}{args["update"]}{coma}
Expand Down
12 changes: 5 additions & 7 deletions cve_bin_tool/data_sources/nvd_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ def __init__(
self.nvd_api_key = nvd_api_key

async def get_cve_data(self):
"""Retrieves the CVE data from the data source."""
await self.fetch_cves()

if self.nvd_type == "api":
return self.format_data(self.all_cve_entries), self.source_name
elif self.nvd_type == "api2":
if self.nvd_type == "api2":
return self.format_data_api2(self.all_cve_entries), self.source_name

else:
severity_data = []
affected_data = []
Expand Down Expand Up @@ -310,6 +310,7 @@ def parse_node_api2(
return affects_list

async def fetch_cves(self):
"""Fetches CVEs from the NVD data source."""
if not self.session:
connector = aiohttp.TCPConnector(limit_per_host=19)
self.session = RateLimiter(
Expand All @@ -320,7 +321,7 @@ async def fetch_cves(self):

tasks = []
LOGGER.info("Getting NVD CVE data...")
if self.nvd_type in ["api", "api2"]:
if self.nvd_type == "api2":
self.all_cve_entries = await asyncio.create_task(
self.nist_fetch_using_api(),
)
Expand Down Expand Up @@ -363,9 +364,6 @@ async def nist_fetch_using_api(self) -> list:
if self.nvd_type == "api2":
LOGGER.info("[Using NVD API 2.0]")
api_version = "2.0"
else:
LOGGER.info("[Using NVD API]")
api_version = "1.0"

# Can only do incremental update if database exists
if not db.dbpath.exists():
Expand Down
62 changes: 11 additions & 51 deletions cve_bin_tool/nvd_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from __future__ import annotations

import asyncio
import json
import math
import time
from datetime import datetime, timedelta, timezone
Expand Down Expand Up @@ -48,7 +47,7 @@ def __init__(
error_mode: ErrorMode = ErrorMode.TruncTrace,
incremental_update=False,
api_key: str = "",
api_version: str = "1.0",
api_version: str = "2.0",
max_hosts: int = MAX_HOSTS,
):
self.logger = logger or LOGGER.getChild(self.__class__.__name__)
Expand All @@ -69,25 +68,16 @@ def __init__(
self.feed = f"{feed}{self.api_version}"
self.api_key = api_key
if self.api_key != "":
if self.api_version == "1.0":
# API key is passed as URL parameter
self.params["apiKey"] = self.api_key
self.header = None
else:
if self.api_version == "2.0":
# API key is passed as part of header
self.header = HTTP_HEADERS
self.header["apiKey"] = self.api_key

else:
# Because of rate limiting
self.max_hosts = 1
self.header = HTTP_HEADERS

@staticmethod
def convert_date_to_nvd_date(date: datetime) -> str:
"""Returns a datetime string of NVD recognized date format"""
utc_date = date.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S:%f")[:-3]
return f"{utc_date} UTC-00:00"

@staticmethod
async def nvd_count_metadata(session):
"""Returns CVE Status count from NVD"""
Expand All @@ -106,18 +96,6 @@ async def nvd_count_metadata(session):
cve_count[key["name"]] = int(key["count"])
return cve_count

@staticmethod
def get_reject_count(fetched_data: dict) -> int:
"""Returns total rejected CVE count"""
all_cve_list = fetched_data["result"]["CVE_Items"]
reject_count = 0
for cve_item in all_cve_list:
if cve_item["cve"]["description"]["description_data"][0][
"value"
].startswith("** REJECT **"):
reject_count += 1
return reject_count

@staticmethod
def convert_date_to_nvd_date_api2(date: datetime) -> str:
"""Returns a datetime string of NVD recognized date format"""
Expand Down Expand Up @@ -179,18 +157,7 @@ async def get_nvd_params(
else:
if time_of_last_update:
# Fetch all the updated CVE entries from the modified date. Subtracting 2-minute offset for updating cve entries
if self.api_version == "1.0":
self.params["modStartDate"] = self.convert_date_to_nvd_date(
time_of_last_update - timedelta(minutes=2)
)
self.params["modEndDate"] = self.convert_date_to_nvd_date(
datetime.now()
)
self.params["includeMatchStringChange"] = json.dumps(True)
self.logger.info(
f'Fetching updated CVE entries after {self.params["modStartDate"]}'
)
else:
if self.api_version == "2.0":
self.params[
"lastModStartDate"
] = self.convert_date_to_nvd_date_api2(
Expand All @@ -202,6 +169,7 @@ async def get_nvd_params(
self.logger.info(
f'Fetching updated CVE entries after {self.params["lastModStartDate"]}'
)

# Check modified strings inside CVEs as well
with Progress() as progress:
task = progress.add_task(
Expand Down Expand Up @@ -246,7 +214,7 @@ async def validate_nvd_api(self):
# If the API key provided is invalid, delete from params
# list and try the request again.
self.logger.error("unset api key, retrying")
if self.api_version == "1.0":
if self.api_version == "2.0":
del self.params["apiKey"]
self.api_key = ""
else:
Expand Down Expand Up @@ -276,28 +244,20 @@ async def load_nvd_request(self, start_index):
fetched_data = await response.json()
if start_index == 0:
# Update total results in case there is discrepancy between NVD dashboard and API
reject_count = (
self.get_reject_count(fetched_data)
if self.api_version == "1.0"
else self.get_reject_count_api2(fetched_data)
)
reject_count = self.get_reject_count_api2(fetched_data)
self.total_results = (
fetched_data["totalResults"] - reject_count
)
if self.api_version == "1.0":
self.all_cve_entries.extend(
fetched_data["result"]["CVE_Items"]
)
else:
if self.api_version == "2.0":
if len(fetched_data["vulnerabilities"]) > 0:
self.all_cve_entries.extend(
fetched_data["vulnerabilities"]
)

elif response.status == 503:
if self.api_version == "1.0":
raise NVDServiceError(self.params["modStartDate"])
else:
if self.api_version == "2.0":
raise NVDServiceError(self.params["lastModStartDate"])

else:
self.logger.info(f"Response code: {response.status}")
self.logger.info(f"Response content: {response.content}")
Expand Down

0 comments on commit debdf8e

Please sign in to comment.