Skip to content

Commit

Permalink
pythongh-122313: Clean up deep recursion guarding code in the compiler
Browse files Browse the repository at this point in the history
Add ENTER_RECURSIVE and LEAVE_RECURSIVE macros in ast.c, ast_opt.c and
symtable.c. Remove VISIT_QUIT macro in symtable.c.

The current recursion depth counter only needs to be updated during
normal execution -- all functions should just return an error code
if an error occurs.
  • Loading branch information
serhiy-storchaka committed Aug 3, 2024
1 parent 4b63cd1 commit 442518d
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 162 deletions.
46 changes: 22 additions & 24 deletions Python/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@ struct validator {
int recursion_limit; /* recursion limit */
};

#define ENTER_RECURSIVE(ST) \
do { \
if (++(ST)->recursion_depth > (ST)->recursion_limit) { \
PyErr_SetString(PyExc_RecursionError, \
"maximum recursion depth exceeded during compilation"); \
return 0; \
} \
} while(0)

#define LEAVE_RECURSIVE(ST) \
do { \
--(ST)->recursion_depth; \
} while(0)

static int validate_stmts(struct validator *, asdl_stmt_seq *);
static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
Expand Down Expand Up @@ -166,11 +180,7 @@ validate_constant(struct validator *state, PyObject *value)
return 1;

if (PyTuple_CheckExact(value) || PyFrozenSet_CheckExact(value)) {
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
ENTER_RECURSIVE(state);

PyObject *it = PyObject_GetIter(value);
if (it == NULL)
Expand All @@ -195,7 +205,7 @@ validate_constant(struct validator *state, PyObject *value)
}

Py_DECREF(it);
--state->recursion_depth;
LEAVE_RECURSIVE(state);
return 1;
}

Expand All @@ -213,11 +223,7 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
assert(!PyErr_Occurred());
VALIDATE_POSITIONS(exp);
int ret = -1;
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
ENTER_RECURSIVE(state);
int check_ctx = 1;
expr_context_ty actual_ctx;

Expand Down Expand Up @@ -398,7 +404,7 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
PyErr_SetString(PyExc_SystemError, "unexpected expression");
ret = 0;
}
state->recursion_depth--;
LEAVE_RECURSIVE(state);
return ret;
}

Expand Down Expand Up @@ -544,11 +550,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
assert(!PyErr_Occurred());
VALIDATE_POSITIONS(p);
int ret = -1;
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
ENTER_RECURSIVE(state);
switch (p->kind) {
case MatchValue_kind:
ret = validate_pattern_match_value(state, p->v.MatchValue.value);
Expand Down Expand Up @@ -690,7 +692,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
PyErr_SetString(PyExc_SystemError, "unexpected pattern");
ret = 0;
}
state->recursion_depth--;
LEAVE_RECURSIVE(state);
return ret;
}

Expand Down Expand Up @@ -725,11 +727,7 @@ validate_stmt(struct validator *state, stmt_ty stmt)
assert(!PyErr_Occurred());
VALIDATE_POSITIONS(stmt);
int ret = -1;
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
ENTER_RECURSIVE(state);
switch (stmt->kind) {
case FunctionDef_kind:
ret = validate_body(state, stmt->v.FunctionDef.body, "FunctionDef") &&
Expand Down Expand Up @@ -946,7 +944,7 @@ validate_stmt(struct validator *state, stmt_ty stmt)
PyErr_SetString(PyExc_SystemError, "unexpected statement");
ret = 0;
}
state->recursion_depth--;
LEAVE_RECURSIVE(state);
return ret;
}

Expand Down
39 changes: 20 additions & 19 deletions Python/ast_opt.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@ typedef struct {
int recursion_limit; /* recursion limit */
} _PyASTOptimizeState;

#define ENTER_RECURSIVE(ST) \
do { \
if (++(ST)->recursion_depth > (ST)->recursion_limit) { \
PyErr_SetString(PyExc_RecursionError, \
"maximum recursion depth exceeded during compilation"); \
return 0; \
} \
} while(0)

#define LEAVE_RECURSIVE(ST) \
do { \
--(ST)->recursion_depth; \
} while(0)

static int
make_const(expr_ty node, PyObject *val, PyArena *arena)
Expand Down Expand Up @@ -708,11 +721,7 @@ astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
static int
astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
ENTER_RECURSIVE(state);
switch (node_->kind) {
case BoolOp_kind:
CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
Expand Down Expand Up @@ -811,7 +820,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
case Name_kind:
if (node_->v.Name.ctx == Load &&
_PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
state->recursion_depth--;
LEAVE_RECURSIVE(state);
return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
}
break;
Expand All @@ -824,7 +833,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
// No default case, so the compiler will emit a warning if new expression
// kinds are added without being handled here
}
state->recursion_depth--;
LEAVE_RECURSIVE(state);;
return 1;
}

Expand Down Expand Up @@ -871,11 +880,7 @@ astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
static int
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
ENTER_RECURSIVE(state);
switch (node_->kind) {
case FunctionDef_kind:
CALL_SEQ(astfold_type_param, type_param, node_->v.FunctionDef.type_params);
Expand Down Expand Up @@ -999,7 +1004,7 @@ astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
// No default case, so the compiler will emit a warning if new statement
// kinds are added without being handled here
}
state->recursion_depth--;
LEAVE_RECURSIVE(state);
return 1;
}

Expand Down Expand Up @@ -1031,11 +1036,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
// Currently, this is really only used to form complex/negative numeric
// constants in MatchValue and MatchMapping nodes
// We still recurse into all subexpressions and subpatterns anyway
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
ENTER_RECURSIVE(state);
switch (node_->kind) {
case MatchValue_kind:
CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
Expand Down Expand Up @@ -1067,7 +1068,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
// No default case, so the compiler will emit a warning if new pattern
// kinds are added without being handled here
}
state->recursion_depth--;
LEAVE_RECURSIVE(state);
return 1;
}

Expand Down
Loading

0 comments on commit 442518d

Please sign in to comment.