diff --git a/pytm/pytm.py b/pytm/pytm.py index 0de9a10..afe8de3 100644 --- a/pytm/pytm.py +++ b/pytm/pytm.py @@ -34,6 +34,12 @@ """ +class UIError(Exception): + def __init__(self, e, context): + self.error = e + self.context = context + + logger = logging.getLogger(__name__) @@ -777,8 +783,11 @@ def _init_threats(self): self._add_threats() def _add_threats(self): - with open(self.threatsFile, "r", encoding="utf8") as threat_file: - threats_json = json.load(threat_file) + try: + with open(self.threatsFile, "r", encoding="utf8") as threat_file: + threats_json = json.load(threat_file) + except (FileNotFoundError, PermissionError, IsADirectoryError) as e: + raise UIError(e, f"Failed while trying to open the the threat file ({self.threatsFile}).") for i in threats_json: TM._threats.append(Threat(**i)) @@ -1000,8 +1009,11 @@ def seq(self): ) def report(self, template_path): - with open(template_path) as file: - template = file.read() + try: + with open(template_path) as file: + template = file.read() + except (FileNotFoundError, PermissionError, IsADirectoryError) as e: + raise UIError(e, f"Failed while trying to open the report template file ({template_path}).") threats = encode_threat_data(TM._threats) findings = encode_threat_data(self.findings) @@ -1027,6 +1039,15 @@ def report(self, template_path): return self._sf.format(template, **data) def process(self): + try: + self._process() + except UIError as e: + print("Failed to complete requested command!") + print(f" {e.context}") + print(f" {e.error}") + sys.exit(127) + + def _process(self): self.check() result = get_args() logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") @@ -1055,8 +1076,11 @@ def process(self): self.sqlDump(result.sqldump) if result.json: - with open(result.json, "w", encoding="utf8") as f: - json.dump(self, f, default=to_serializable) + try: + with open(result.json, "w", encoding="utf8") as f: + json.dump(self, f, default=to_serializable) + except (FileExistsError, PermissionError, IsADirectoryError) as e: + raise UIError(e, f"Failed while trying to write to the result file ({result.json})") if result.report is not None: print(self.report(result.report)) diff --git a/tests/test_private_func.py b/tests/test_private_func.py index c21389a..21c6e2c 100644 --- a/tests/test_private_func.py +++ b/tests/test_private_func.py @@ -12,6 +12,7 @@ Process, Server, Threat, + UIError, ) @@ -46,10 +47,10 @@ def test_kwargs(self): def test_load_threats(self): tm = TM("TM") self.assertNotEqual(len(TM._threats), 0) - with self.assertRaises(FileNotFoundError): + with self.assertRaises(UIError): tm.threatsFile = "threats.json" - with self.assertRaises(FileNotFoundError): + with self.assertRaises(UIError): TM("TM", threatsFile="threats.json") def test_responses(self):