diff --git a/cve_bin_tool/config_generator.py b/cve_bin_tool/config_generator.py index dea9e60127..c8416fde61 100644 --- a/cve_bin_tool/config_generator.py +++ b/cve_bin_tool/config_generator.py @@ -24,7 +24,7 @@ def config_generator(args, types): {first_char}cve_data_download{last_char} #set your nvd api key nvd_api_key {sign} {coma}{args["nvd_api_key"]}{coma} - # choose method for getting CVE lists from NVD (default: api) other option available api2, json + # choose method for getting CVE lists from NVD (default: api2) other option available json nvd {sign} {coma}{args["nvd"]}{coma} # update schedule for data sources and exploits database (default: daily) update {sign} {coma}{args["update"]}{coma} diff --git a/cve_bin_tool/data_sources/nvd_source.py b/cve_bin_tool/data_sources/nvd_source.py index e1490d724b..fa7d5302ea 100644 --- a/cve_bin_tool/data_sources/nvd_source.py +++ b/cve_bin_tool/data_sources/nvd_source.py @@ -89,12 +89,12 @@ def __init__( self.nvd_api_key = nvd_api_key async def get_cve_data(self): + """Retrieves the CVE data from the data source.""" await self.fetch_cves() - if self.nvd_type == "api": - return self.format_data(self.all_cve_entries), self.source_name - elif self.nvd_type == "api2": + if self.nvd_type == "api2": return self.format_data_api2(self.all_cve_entries), self.source_name + else: severity_data = [] affected_data = [] @@ -310,6 +310,7 @@ def parse_node_api2( return affects_list async def fetch_cves(self): + """Fetches CVEs from the NVD data source.""" if not self.session: connector = aiohttp.TCPConnector(limit_per_host=19) self.session = RateLimiter( @@ -320,7 +321,7 @@ async def fetch_cves(self): tasks = [] LOGGER.info("Getting NVD CVE data...") - if self.nvd_type in ["api", "api2"]: + if self.nvd_type == "api2": self.all_cve_entries = await asyncio.create_task( self.nist_fetch_using_api(), ) @@ -363,9 +364,6 @@ async def nist_fetch_using_api(self) -> list: if self.nvd_type == "api2": LOGGER.info("[Using NVD API 2.0]") api_version = "2.0" - else: - LOGGER.info("[Using NVD API]") - api_version = "1.0" # Can only do incremental update if database exists if not db.dbpath.exists(): diff --git a/cve_bin_tool/nvd_api.py b/cve_bin_tool/nvd_api.py index 5777d8d1d7..33da2c53f6 100644 --- a/cve_bin_tool/nvd_api.py +++ b/cve_bin_tool/nvd_api.py @@ -9,7 +9,6 @@ from __future__ import annotations import asyncio -import json import math import time from datetime import datetime, timedelta, timezone @@ -48,7 +47,7 @@ def __init__( error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False, api_key: str = "", - api_version: str = "1.0", + api_version: str = "2.0", max_hosts: int = MAX_HOSTS, ): self.logger = logger or LOGGER.getChild(self.__class__.__name__) @@ -69,25 +68,16 @@ def __init__( self.feed = f"{feed}{self.api_version}" self.api_key = api_key if self.api_key != "": - if self.api_version == "1.0": - # API key is passed as URL parameter - self.params["apiKey"] = self.api_key - self.header = None - else: + if self.api_version == "2.0": # API key is passed as part of header self.header = HTTP_HEADERS self.header["apiKey"] = self.api_key + else: # Because of rate limiting self.max_hosts = 1 self.header = HTTP_HEADERS - @staticmethod - def convert_date_to_nvd_date(date: datetime) -> str: - """Returns a datetime string of NVD recognized date format""" - utc_date = date.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S:%f")[:-3] - return f"{utc_date} UTC-00:00" - @staticmethod async def nvd_count_metadata(session): """Returns CVE Status count from NVD""" @@ -106,18 +96,6 @@ async def nvd_count_metadata(session): cve_count[key["name"]] = int(key["count"]) return cve_count - @staticmethod - def get_reject_count(fetched_data: dict) -> int: - """Returns total rejected CVE count""" - all_cve_list = fetched_data["result"]["CVE_Items"] - reject_count = 0 - for cve_item in all_cve_list: - if cve_item["cve"]["description"]["description_data"][0][ - "value" - ].startswith("** REJECT **"): - reject_count += 1 - return reject_count - @staticmethod def convert_date_to_nvd_date_api2(date: datetime) -> str: """Returns a datetime string of NVD recognized date format""" @@ -179,18 +157,7 @@ async def get_nvd_params( else: if time_of_last_update: # Fetch all the updated CVE entries from the modified date. Subtracting 2-minute offset for updating cve entries - if self.api_version == "1.0": - self.params["modStartDate"] = self.convert_date_to_nvd_date( - time_of_last_update - timedelta(minutes=2) - ) - self.params["modEndDate"] = self.convert_date_to_nvd_date( - datetime.now() - ) - self.params["includeMatchStringChange"] = json.dumps(True) - self.logger.info( - f'Fetching updated CVE entries after {self.params["modStartDate"]}' - ) - else: + if self.api_version == "2.0": self.params[ "lastModStartDate" ] = self.convert_date_to_nvd_date_api2( @@ -202,6 +169,7 @@ async def get_nvd_params( self.logger.info( f'Fetching updated CVE entries after {self.params["lastModStartDate"]}' ) + # Check modified strings inside CVEs as well with Progress() as progress: task = progress.add_task( @@ -246,7 +214,7 @@ async def validate_nvd_api(self): # If the API key provided is invalid, delete from params # list and try the request again. self.logger.error("unset api key, retrying") - if self.api_version == "1.0": + if self.api_version == "2.0": del self.params["apiKey"] self.api_key = "" else: @@ -276,28 +244,20 @@ async def load_nvd_request(self, start_index): fetched_data = await response.json() if start_index == 0: # Update total results in case there is discrepancy between NVD dashboard and API - reject_count = ( - self.get_reject_count(fetched_data) - if self.api_version == "1.0" - else self.get_reject_count_api2(fetched_data) - ) + reject_count = self.get_reject_count_api2(fetched_data) self.total_results = ( fetched_data["totalResults"] - reject_count ) - if self.api_version == "1.0": - self.all_cve_entries.extend( - fetched_data["result"]["CVE_Items"] - ) - else: + if self.api_version == "2.0": if len(fetched_data["vulnerabilities"]) > 0: self.all_cve_entries.extend( fetched_data["vulnerabilities"] ) + elif response.status == 503: - if self.api_version == "1.0": - raise NVDServiceError(self.params["modStartDate"]) - else: + if self.api_version == "2.0": raise NVDServiceError(self.params["lastModStartDate"]) + else: self.logger.info(f"Response code: {response.status}") self.logger.info(f"Response content: {response.content}")