From 836db2deb4996a7d5f23e5d79c4be0a03b6ce110 Mon Sep 17 00:00:00 2001 From: Eric Brown Date: Thu, 29 Feb 2024 14:18:26 -0800 Subject: [PATCH] Rework how a typed parameter is handled (#311) This change utilizes the visitor pattern to process nodes of type typed_parameter and typed_default_parameter instead of manually checking children of the function def. This works because typed parameters are only within function defs. This commit also adds test cases to ensure this is working. This change also protects against the traceback found when encountering an unexpected *arg as a function parameter. Fixes #310 Signed-off-by: Eric Brown --- precli/parsers/python.py | 56 +++++++++---------- ...text_set_ecdh_curve_typed_default_param.py | 14 +++++ .../ssl_context_set_ecdh_curve_typed_param.py | 14 +++++ .../stdlib/test_ssl_context_weak_key.py | 2 + 4 files changed, 55 insertions(+), 31 deletions(-) create mode 100644 tests/unit/rules/python/stdlib/examples/ssl_context_set_ecdh_curve_typed_default_param.py create mode 100644 tests/unit/rules/python/stdlib/examples/ssl_context_set_ecdh_curve_typed_param.py diff --git a/precli/parsers/python.py b/precli/parsers/python.py index f2709ce0..c5337a54 100644 --- a/precli/parsers/python.py +++ b/precli/parsers/python.py @@ -50,41 +50,35 @@ def visit_class_definition(self, nodes: list[Node]): def visit_function_definition(self, nodes: list[Node]): func_id = self.child_by_type(self.context["node"], "identifier") func = func_id.text.decode() - self.current_symtab = SymbolTable(func, parent=self.current_symtab) + self.visit(nodes) + self.current_symtab = self.current_symtab.parent() - func_parameters = self.child_by_type( - self.context["node"], "parameters" - ) - func_args = func_parameters.named_children - - for func_arg in func_args: - # typed_parameter or identifier - if func_arg.type == "typed_parameter": - param_id = self.child_by_type(func_arg, "identifier") - param_type = self.child_by_type(func_arg, "type") - - if param_type.named_children[0].type in ( - "attribute", - "identifier", - "string", - "integer", - "float", - "true", - "false", - "none", - ): - param_name = param_id.text.decode() - param_type = self.literal_value( - param_type.named_children[0], - default=param_type.named_children[0], - ) - self.current_symtab.put( - param_name, "identifier", param_type - ) + def visit_typed_default_parameter(self, nodes: list[Node]): + self.visit_typed_parameter(nodes) + + def visit_typed_parameter(self, nodes: list[Node]): + param_id = self.child_by_type(self.context["node"], "identifier") + param_type = self.child_by_type(self.context["node"], "type") + + if param_id is not None and param_type.named_children[0].type in ( + "attribute", + "identifier", + "string", + "integer", + "float", + "true", + "false", + "none", + ): + param_name = param_id.text.decode() + param_type = self.literal_value( + param_type.named_children[0], + default=param_type.named_children[0], + ) + self.current_symtab.put(param_name, "identifier", param_type) self.visit(nodes) - self.current_symtab = self.current_symtab.parent() def visit_named_expression(self, nodes: list[Node]): if len(nodes) > 1 and nodes[1].text.decode() == ":=": diff --git a/tests/unit/rules/python/stdlib/examples/ssl_context_set_ecdh_curve_typed_default_param.py b/tests/unit/rules/python/stdlib/examples/ssl_context_set_ecdh_curve_typed_default_param.py new file mode 100644 index 00000000..b6345b47 --- /dev/null +++ b/tests/unit/rules/python/stdlib/examples/ssl_context_set_ecdh_curve_typed_default_param.py @@ -0,0 +1,14 @@ +# level: WARNING +# start_line: 10 +# end_line: 10 +# start_column: 27 +# end_column: 39 +import ssl + + +def set_curve(context: ssl.SSLContext = None) -> None: + context.set_ecdh_curve("prime192v1") + + +context = ssl.SSLContext() +set_curve(context) diff --git a/tests/unit/rules/python/stdlib/examples/ssl_context_set_ecdh_curve_typed_param.py b/tests/unit/rules/python/stdlib/examples/ssl_context_set_ecdh_curve_typed_param.py new file mode 100644 index 00000000..5d7274f6 --- /dev/null +++ b/tests/unit/rules/python/stdlib/examples/ssl_context_set_ecdh_curve_typed_param.py @@ -0,0 +1,14 @@ +# level: WARNING +# start_line: 10 +# end_line: 10 +# start_column: 27 +# end_column: 39 +import ssl + + +def set_curve(context: ssl.SSLContext) -> None: + context.set_ecdh_curve("prime192v1") + + +context = ssl.SSLContext() +set_curve(context) diff --git a/tests/unit/rules/python/stdlib/test_ssl_context_weak_key.py b/tests/unit/rules/python/stdlib/test_ssl_context_weak_key.py index 72738f7c..9042b491 100644 --- a/tests/unit/rules/python/stdlib/test_ssl_context_weak_key.py +++ b/tests/unit/rules/python/stdlib/test_ssl_context_weak_key.py @@ -49,6 +49,8 @@ def test_rule_meta(self): "ssl_context_set_ecdh_curve_secp256r1.py", "ssl_context_set_ecdh_curve_sect163k1.py", "ssl_context_set_ecdh_curve_sect571k1.py", + "ssl_context_set_ecdh_curve_typed_default_param.py", + "ssl_context_set_ecdh_curve_typed_param.py", "ssl_context_set_ecdh_curve_unverified_context.py", ] )