Skip to content

Commit

Permalink
Rework how a typed parameter is handled (#311)
Browse files Browse the repository at this point in the history
This change utilizes the visitor pattern to process nodes of type
typed_parameter and typed_default_parameter instead of manually checking
children of the function def. This works because typed parameters are
only within function defs.

This commit also adds test cases to ensure this is working.

This change also protects against the traceback found when encountering
an unexpected *arg as a function parameter.

Fixes #310

Signed-off-by: Eric Brown <[email protected]>
  • Loading branch information
ericwb authored Feb 29, 2024
1 parent 12e9f05 commit 836db2d
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 31 deletions.
56 changes: 25 additions & 31 deletions precli/parsers/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,41 +50,35 @@ def visit_class_definition(self, nodes: list[Node]):
def visit_function_definition(self, nodes: list[Node]):
func_id = self.child_by_type(self.context["node"], "identifier")
func = func_id.text.decode()

self.current_symtab = SymbolTable(func, parent=self.current_symtab)
self.visit(nodes)
self.current_symtab = self.current_symtab.parent()

func_parameters = self.child_by_type(
self.context["node"], "parameters"
)
func_args = func_parameters.named_children

for func_arg in func_args:
# typed_parameter or identifier
if func_arg.type == "typed_parameter":
param_id = self.child_by_type(func_arg, "identifier")
param_type = self.child_by_type(func_arg, "type")

if param_type.named_children[0].type in (
"attribute",
"identifier",
"string",
"integer",
"float",
"true",
"false",
"none",
):
param_name = param_id.text.decode()
param_type = self.literal_value(
param_type.named_children[0],
default=param_type.named_children[0],
)
self.current_symtab.put(
param_name, "identifier", param_type
)
def visit_typed_default_parameter(self, nodes: list[Node]):
self.visit_typed_parameter(nodes)

def visit_typed_parameter(self, nodes: list[Node]):
param_id = self.child_by_type(self.context["node"], "identifier")
param_type = self.child_by_type(self.context["node"], "type")

if param_id is not None and param_type.named_children[0].type in (
"attribute",
"identifier",
"string",
"integer",
"float",
"true",
"false",
"none",
):
param_name = param_id.text.decode()
param_type = self.literal_value(
param_type.named_children[0],
default=param_type.named_children[0],
)
self.current_symtab.put(param_name, "identifier", param_type)

self.visit(nodes)
self.current_symtab = self.current_symtab.parent()

def visit_named_expression(self, nodes: list[Node]):
if len(nodes) > 1 and nodes[1].text.decode() == ":=":
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# level: WARNING
# start_line: 10
# end_line: 10
# start_column: 27
# end_column: 39
import ssl


def set_curve(context: ssl.SSLContext = None) -> None:
context.set_ecdh_curve("prime192v1")


context = ssl.SSLContext()
set_curve(context)
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# level: WARNING
# start_line: 10
# end_line: 10
# start_column: 27
# end_column: 39
import ssl


def set_curve(context: ssl.SSLContext) -> None:
context.set_ecdh_curve("prime192v1")


context = ssl.SSLContext()
set_curve(context)
2 changes: 2 additions & 0 deletions tests/unit/rules/python/stdlib/test_ssl_context_weak_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def test_rule_meta(self):
"ssl_context_set_ecdh_curve_secp256r1.py",
"ssl_context_set_ecdh_curve_sect163k1.py",
"ssl_context_set_ecdh_curve_sect571k1.py",
"ssl_context_set_ecdh_curve_typed_default_param.py",
"ssl_context_set_ecdh_curve_typed_param.py",
"ssl_context_set_ecdh_curve_unverified_context.py",
]
)
Expand Down

0 comments on commit 836db2d

Please sign in to comment.