diff --git a/src/macaron/slsa_analyzer/checks/provenance_l3_check.py b/src/macaron/slsa_analyzer/checks/provenance_l3_check.py
index 3064f822c..9b778acc3 100644
--- a/src/macaron/slsa_analyzer/checks/provenance_l3_check.py
+++ b/src/macaron/slsa_analyzer/checks/provenance_l3_check.py
@@ -239,21 +239,35 @@ def _extract_archive(self, file_path: str, temp_path: str) -> bool:
         bool
             Returns True if successful.
         """
+
+        def _validate_path_traversal(path: str) -> bool:
+            """Check for path traversal attacks."""
+            if path.startswith("/") or ".." in path:
+                logger.debug("Found suspicious path in the archive file: %s.", path)
+                return False
+            try:
+                # Check if there are any symbolic links.
+                if os.path.realpath(path):
+                    return True
+            except OSError as error:
+                logger.debug("Failed to extract artifact from archive file: %s", error)
+                return False
+            return False
+
         try:
             if zipfile.is_zipfile(file_path):
                 with zipfile.ZipFile(file_path, "r") as zip_file:
-                    zip_file.extractall(temp_path)
+                    members = (path for path in zip_file.namelist() if _validate_path_traversal(path))
+                    zip_file.extractall(temp_path, members=members)  # nosec B202:tarfile_unsafe_members
                 return True
             elif tarfile.is_tarfile(file_path):
                 with tarfile.open(file_path, mode="r:gz") as tar_file:
-                    tar_file.extractall(temp_path)
+                    members_tarinfo = (
+                        tarinfo for tarinfo in tar_file.getmembers() if _validate_path_traversal(tarinfo.name)
+                    )
+                    tar_file.extractall(temp_path, members=members_tarinfo)  # nosec B202:tarfile_unsafe_members
                 return True
-        except (
-            tarfile.TarError,
-            zipfile.BadZipFile,
-            zipfile.LargeZipFile,
-            OSError,
-        ) as error:
+        except (tarfile.TarError, zipfile.BadZipFile, zipfile.LargeZipFile, OSError, ValueError) as error:
             logger.info(error)
 
         return False
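For illustration only, here is a minimal standalone sketch of the member-filtering idea the hunk above applies; `validate_member` and `safe_extract_tar` are hypothetical names for this sketch, not part of the patch:

```python
import logging
import tarfile

logger = logging.getLogger(__name__)


def validate_member(name: str) -> bool:
    """Reject archive entries that could escape the extraction root."""
    # Mirrors the patch's filter: absolute paths and any ".." occurrence
    # (checked as a substring, as in the patch) are treated as traversal
    # attempts and skipped.
    if name.startswith("/") or ".." in name:
        logger.debug("Found suspicious path in the archive file: %s.", name)
        return False
    return True


def safe_extract_tar(archive_path: str, dest: str) -> None:
    """Extract a .tar.gz archive, skipping suspicious members."""
    with tarfile.open(archive_path, mode="r:gz") as tar_file:
        members = [m for m in tar_file.getmembers() if validate_member(m.name)]
        tar_file.extractall(dest, members=members)
```

With the filter in place, an entry named `../../etc/passwd` or `/tmp/evil` is dropped before `extractall` runs, so a crafted archive can no longer write outside the destination directory.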
diff --git a/src/macaron/util.py b/src/macaron/util.py
index 91758c720..2db5a3d56 100644
--- a/src/macaron/util.py
+++ b/src/macaron/util.py
@@ -36,7 +36,9 @@ def send_get_http(url: str, headers: dict) -> dict:
         The response's json data or an empty dict if there is an error.
     """
     logger.debug("GET - %s", url)
-    response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10))
+    response = requests.get(
+        url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)
+    )  # nosec B113:request_without_timeout
     while response.status_code != 200:
         logger.error(
             "Receiving error code %s from server. Message: %s.",
@@ -47,7 +49,9 @@ def send_get_http(url: str, headers: dict) -> dict:
             check_rate_limit(response)
         else:
             return {}
-        response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10))
+        response = requests.get(
+            url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)
+        )  # nosec B113:request_without_timeout
 
     return dict(response.json())
 
@@ -70,7 +74,9 @@ def send_get_http_raw(url: str, headers: dict) -> Response | None:
         The response object or None if there is an error.
     """
     logger.debug("GET - %s", url)
-    response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10))
+    response = requests.get(
+        url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)
+    )  # nosec B113:request_without_timeout
     while response.status_code != 200:
         logger.error(
             "Receiving error code %s from server. Message: %s.",
@@ -81,7 +87,9 @@ def send_get_http_raw(url: str, headers: dict) -> Response | None:
             check_rate_limit(response)
         else:
             return None
-        response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10))
+        response = requests.get(
+            url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)
+        )  # nosec B113:request_without_timeout
 
     return response
 
@@ -155,7 +163,9 @@ def download_github_build_log(url: str, headers: dict) -> str:
         The content of the downloaded build log or empty if error.
     """
     logger.debug("Downloading content at link %s", url)
-    response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10))
+    response = requests.get(
+        url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)
+    )  # nosec B113:request_without_timeout
     return response.content.decode("utf-8")
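All five hunks in util.py make the same change: each `requests.get` call is reflowed so its explicit, configurable timeout is visible, and Bandit's B113 (request_without_timeout) finding is suppressed with `# nosec`, presumably because the scanner cannot resolve the computed `timeout` argument. Below is a minimal sketch of how the configurable timeout behaves, assuming `defaults` acts like a standard `configparser.ConfigParser` (the INI content here is made up for illustration):

```python
import configparser

import requests

# Stand-in for Macaron's `defaults` object; assumed here to behave like a
# plain ConfigParser. The INI content is illustrative only.
defaults = configparser.ConfigParser()
defaults.read_string("[requests]\ntimeout = 10\n")

# getint falls back to 10 when the section or key is missing, so every
# request carries an explicit timeout rather than blocking forever.
timeout = defaults.getint("requests", "timeout", fallback=10)
response = requests.get("https://example.com", timeout=timeout)
print(response.status_code)
```

If the `[requests]` section or the `timeout` key is absent, `getint` returns the `fallback` value of 10 seconds, so no GET request can hang indefinitely.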