diff --git a/cve_bin_tool/cli.py b/cve_bin_tool/cli.py index ddedd7714b..5f9099aafc 100644 --- a/cve_bin_tool/cli.py +++ b/cve_bin_tool/cli.py @@ -734,15 +734,12 @@ def main(argv=None): nvd_api_key=args["nvd_api_key"], error_mode=error_mode, ) - default_sources = [source_nvd] - default_sources.extend(enabled_sources) - else: - default_sources = enabled_sources + enabled_sources = [source_nvd] + enabled_sources # Database update related settings # Connect to the database cvedb_orig = CVEDB( - sources=default_sources, + sources=enabled_sources, version_check=not version_check, error_mode=error_mode, ) @@ -1024,6 +1021,7 @@ def main(argv=None): exclude_folders=args["exclude"], error_mode=error_mode, validate=not args["disable_validation_check"], + sources=enabled_sources, ) version_scanner.remove_skiplist(skips) LOGGER.info(f"Number of checkers: {version_scanner.number_of_checkers()}") diff --git a/cve_bin_tool/cvedb.py b/cve_bin_tool/cvedb.py index 1451eaa996..c43b39c1b3 100644 --- a/cve_bin_tool/cvedb.py +++ b/cve_bin_tool/cvedb.py @@ -677,14 +677,15 @@ def get_vendor_product_pairs(self, package_names) -> list[dict[str, str]]: """ cursor = self.db_open_and_get_cursor() vendor_package_pairs = [] - query = """ - SELECT DISTINCT vendor FROM cve_range - WHERE product=? - """ + query = ( + "SELECT DISTINCT vendor FROM cve_range WHERE product=? AND data_source IN (%s)" # nosec + % ",".join("?" for i in self.sources) + ) + data_sources = list(map(lambda x: x.source_name, self.sources)) # For python package checkers we don't need the progress bar running if type(package_names) is not list: - cursor.execute(query, [package_names]) + cursor.execute(query, [package_names] + data_sources) vendors = list(map(lambda x: x[0], cursor.fetchall())) for vendor in vendors: @@ -703,7 +704,7 @@ def get_vendor_product_pairs(self, package_names) -> list[dict[str, str]]: for package_name in track( package_names, description="Processing the given list...." ): - cursor.execute(query, [package_name["name"].lower()]) + cursor.execute(query, [package_name["name"].lower()] + data_sources) vendors = list(map(lambda x: x[0], cursor.fetchall())) for vendor in vendors: if vendor != "": diff --git a/cve_bin_tool/version_scanner.py b/cve_bin_tool/version_scanner.py index 3e617ff07d..39faa6a1ad 100644 --- a/cve_bin_tool/version_scanner.py +++ b/cve_bin_tool/version_scanner.py @@ -54,6 +54,7 @@ def __init__( error_mode: ErrorMode = ErrorMode.TruncTrace, score: int = 0, validate: bool = True, + sources=None, ): self.logger = logger or LOGGER.getChild(self.__class__.__name__) # Update egg if installed in development mode @@ -76,7 +77,7 @@ def __init__( self.should_extract = should_extract self.file_stack: list[str] = [] self.error_mode = error_mode - self.cve_db = CVEDB() + self.cve_db = CVEDB(sources=sources) self.validate = validate # self.logger.info("Checkers loaded: %s" % (", ".join(self.checkers.keys()))) self.language_checkers = self.available_language_checkers()