Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start with python match statement #1801

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import org.neo4j.ogm.annotation.Relationship

/**
* Represents a Java or C++ switch statement of the `switch (selector) {...}` that can include case
* and default statements. Break statements break out of the switch and labeled breaks in JAva are
* and default statements. Break statements break out of the switch and labeled breaks in Java are
* handled properly.
*/
class SwitchStatement : Statement(), BranchingNode {
Expand All @@ -51,7 +51,7 @@ class SwitchStatement : Statement(), BranchingNode {

@Relationship(value = "SELECTOR_DECLARATION")
var selectorDeclarationEdge = astOptionalEdgeOf<Declaration>()
/** C++ allows to use a declaration instead of a expression as selector */
/** C++ allows to use a declaration instead of an expression as selector */
var selectorDeclaration by unwrapping(SwitchStatement::selectorDeclarationEdge)

@Relationship(value = "STATEMENT") var statementEdge = astOptionalEdgeOf<Statement>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,29 @@
}
}

fun joinListWithBinOp(
KuechA marked this conversation as resolved.
Show resolved Hide resolved
operatorCode: String,
nodes: List<Expression>,
rawNode: Python.AST.AST? = null,

Check warning on line 165 in cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandler.kt

View check run for this annotation

Codecov / codecov/patch

cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandler.kt#L165

Added line #L165 was not covered by tests
isImplicit: Boolean = true
maximiliankaul marked this conversation as resolved.
Show resolved Hide resolved
): BinaryOperator {
val lastTwo = newBinaryOperator(operatorCode, rawNode = rawNode)
maximiliankaul marked this conversation as resolved.
Show resolved Hide resolved
lastTwo.rhs = nodes.last()
lastTwo.lhs = nodes[nodes.size - 2]
return nodes.subList(0, nodes.size - 2).foldRight(lastTwo) { newVal, start ->
val nextValue = newBinaryOperator(operatorCode)
if (isImplicit && rawNode != null)
nextValue.implicit(
code = frontend.codeOf(rawNode),
location = frontend.locationOf(rawNode)
)
else if (isImplicit) nextValue.implicit()
maximiliankaul marked this conversation as resolved.
Show resolved Hide resolved
nextValue.rhs = start
nextValue.lhs = newVal
nextValue
}
}

private fun handleStarred(node: Python.AST.Starred): Expression {
val unaryOp = newUnaryOperator("*", postfix = false, prefix = false, rawNode = node)
unaryOp.input = handle(node.value)
Expand Down Expand Up @@ -203,18 +226,7 @@
rawNode = node
)
} else {
// Start with the last two operands, then keep prepending the previous ones until the
// list is finished.
val lastTwo = newBinaryOperator(op, rawNode = node)
lastTwo.rhs = handle(node.values.last())
lastTwo.lhs = handle(node.values[node.values.size - 2])
return node.values.subList(0, node.values.size - 2).foldRight(lastTwo) { newVal, start
->
val nextValue = newBinaryOperator(op, rawNode = node)
nextValue.rhs = start
nextValue.lhs = handle(newVal)
nextValue
}
joinListWithBinOp(op, node.values.map(::handle), node, true)
KuechA marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,7 @@ interface Python {
* ```
*/
class MatchSingleton(pyObject: PyObject) : BasePattern(pyObject) {
val value: Any by lazy { "value" of pyObject }
val value: Any? by lazy { "value" of pyObject }
KuechA marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
is Python.AST.Global -> handleGlobal(node)
is Python.AST.Nonlocal -> handleNonLocal(node)
is Python.AST.Raise -> handleRaise(node)
is Python.AST.Match,
is Python.AST.Match -> handleMatch(node)
is Python.AST.TryStar ->
newProblemExpression(
"The statement of class ${node.javaClass} is not supported yet",
Expand All @@ -86,6 +86,106 @@
}
}

/**
* Translates a pattern which can be used by a `match_case`. There are various options available
* and all of them are translated to traditional comparisons and logical expressions which could
* also be seen in the condition of an if-statement.
*/
fun handlePattern(node: Python.AST.BasePattern, selector: String): Expression {
KuechA marked this conversation as resolved.
Show resolved Hide resolved
return when (node) {
is Python.AST.MatchValue ->
newBinaryOperator("==", node).implicit().apply {
this.lhs = newReference(selector)
this.rhs = frontend.expressionHandler.handle(node.value)
}
is Python.AST.MatchSingleton ->
newBinaryOperator("===", node).implicit().apply {
this.lhs = newReference(selector)
this.rhs =
when (val value = node.value) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like this should be an extra function (at least if we use this logic somewhere else, too). Have you seen easyConstant in the ExpressionHandler?.
We only expect True, False or None here according to the doc.

is Python.AST.BaseExpr -> frontend.expressionHandler.handle(value)
null ->
newProblemExpression(
"Can't handle value 'None'/'null' in value of Python.AST.MatchSingleton yet"
)
else ->
newProblemExpression(
"Can't handle ${value::class} in value of Python.AST.MatchSingleton yet"

Check warning on line 113 in cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/StatementHandler.kt

View check run for this annotation

Codecov / codecov/patch

cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/StatementHandler.kt#L112-L113

Added lines #L112 - L113 were not covered by tests
)
}
}
is Python.AST.MatchOr ->
frontend.expressionHandler.joinListWithBinOp(
"or",
node.patterns.map { handlePattern(it, selector) },
node
)
is Python.AST.MatchSequence,
is Python.AST.MatchMapping,
is Python.AST.MatchClass,
is Python.AST.MatchStar,
is Python.AST.MatchAs ->
newProblemExpression("Cannot handle of type ${node::class} yet")
else -> newProblemExpression("Cannot handle of type ${node::class} yet")

Check warning on line 129 in cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/StatementHandler.kt

View check run for this annotation

Codecov / codecov/patch

cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/StatementHandler.kt#L129

Added line #L129 was not covered by tests
}
}

/**
* Translates a [`match_case`](https://docs.python.org/3/library/ast.html#ast.match_case) to a
* [Block] which holds the [CaseStatement] and then all other statements of the
* [Python.AST.match_case.body].
*
* The [CaseStatement] is generated by the [Python.AST.match_case.pattern] and, if available,
* [Python.AST.match_case.guard]. If there's a `guard` present, we model the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"If there's a guard present, we model the... " this is a bit confusing. There is also a CaseStmt if there is no guard, isn't there?

* [CaseStatement.caseExpression] as an `AND` BinaryOperator, where the `lhs` is the normal
* pattern and the `rhs` is the guard. This is in line with
* [PEP 634](https://peps.python.org/pep-0634/).
*/
fun handleCase(node: Python.AST.match_case, selector: String): List<Statement> {
val statements = mutableListOf<Statement>()
// First, we add the caseStatement
statements +=
newCaseStatement(node).apply {
this.caseExpression =
node.guard?.let {
newBinaryOperator("and")
.implicit(
code = frontend.codeOf(node),
location = frontend.locationOf(node)
)
.apply {
this.lhs = handlePattern(node.pattern, selector)
this.rhs = frontend.expressionHandler.handle(it)
}
} ?: handlePattern(node.pattern, selector)
}
// Now, we add the remaining body.
statements += node.body.map(::handle)
// Currently, the EOG pass requires a break statement to work as expected. For this reason,
// we insert an implicit break statement at the end of the block.
statements +=
newBreakStatement()
.implicit(code = frontend.codeOf(node), location = frontend.locationOf(node))
return statements
}

/**
* Translates a Python [`Match`](https://docs.python.org/3/library/ast.html#ast.Match) into a
* [SwitchStatement].
*/
fun handleMatch(node: Python.AST.Match): Statement {
return newSwitchStatement(node).apply {
maximiliankaul marked this conversation as resolved.
Show resolved Hide resolved
val selector = frontend.expressionHandler.handle(node.subject)
this.selector = selector

this.statement =
node.cases.fold(newBlock().implicit()) { block, case ->
block.statements += handleCase(case, selector.name.localName)
block
}
}
}

/**
* Translates a Python [`Raise`](https://docs.python.org/3/library/ast.html#ast.Raise) into a
* [ThrowStatement].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,78 @@ class StatementHandlerTest : BaseTest() {
assertNotNull(result)
}

@Test
fun testMatch() {
maximiliankaul marked this conversation as resolved.
Show resolved Hide resolved
analyzeFile("match.py")

val func = result.functions["matcher"]
assertNotNull(func)

val switchStatement = func.switches.singleOrNull()
assertNotNull(switchStatement)

assertLocalName("x", switchStatement.selector)
assertIs<Reference>(switchStatement.selector)
val paramX = func.parameters.singleOrNull()
assertNotNull(paramX)
assertRefersTo(switchStatement.selector, paramX)

val statementBlock = switchStatement.statement as? Block
assertNotNull(statementBlock)
KuechA marked this conversation as resolved.
Show resolved Hide resolved
val caseSingleton = statementBlock[0]
assertIs<CaseStatement>(caseSingleton)
val singletonCheck = caseSingleton.caseExpression
assertIs<BinaryOperator>(singletonCheck)
assertNotNull(singletonCheck)
assertEquals("===", singletonCheck.operatorCode)
assertRefersTo(singletonCheck.lhs, paramX)
assertIs<ProblemExpression>(singletonCheck.rhs)
assertIs<BreakStatement>(statementBlock[2])

val caseValue = statementBlock[3]
assertIs<CaseStatement>(caseValue)
val valueCheck = caseValue.caseExpression
assertIs<BinaryOperator>(valueCheck)
assertNotNull(valueCheck)
assertEquals("==", valueCheck.operatorCode)
assertRefersTo(valueCheck.lhs, paramX)
assertLiteralValue("value", valueCheck.rhs)
assertIs<BreakStatement>(statementBlock[5])

val caseAnd = statementBlock[6]
assertIs<CaseStatement>(caseAnd)
val andExpr = caseAnd.caseExpression
assertIs<BinaryOperator>(andExpr)
assertEquals("and", andExpr.operatorCode)
val andRhs = andExpr.rhs
assertIs<BinaryOperator>(andRhs)
assertEquals(">", andRhs.operatorCode)
assertRefersTo(andRhs.lhs, paramX)
assertLiteralValue(0L, andRhs.rhs)
assertIs<BreakStatement>(statementBlock[8])

assertIs<CaseStatement>(statementBlock[9])
assertIs<BreakStatement>(statementBlock[11])
assertIs<CaseStatement>(statementBlock[12])
assertIs<BreakStatement>(statementBlock[14])
assertIs<CaseStatement>(statementBlock[15])
assertIs<BreakStatement>(statementBlock[17])
assertIs<CaseStatement>(statementBlock[18])
assertIs<BreakStatement>(statementBlock[20])
assertIs<CaseStatement>(statementBlock[21])
assertIs<BreakStatement>(statementBlock[23])
assertIs<CaseStatement>(statementBlock[24])
assertIs<BreakStatement>(statementBlock[26])

val caseOr = statementBlock[27]
assertIs<CaseStatement>(caseOr)
val orExpr = caseOr.caseExpression
assertIs<BinaryOperator>(orExpr)
assertNotNull(orExpr)
assertEquals("or", orExpr.operatorCode)
assertIs<BreakStatement>(statementBlock[29])
}

@Test
fun testTry() {
val tu =
Expand Down
22 changes: 22 additions & 0 deletions cpg-language-python/src/test/resources/python/match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
def matcher(x):
match x:
case None:
print("singleton" + x)
case "value":
print("value" + x)
case [x] if x>0:
print(x)
case [1, 2]:
print("sequence" + x)
case [1, 2, *rest]:
print("star" + x)
case [*_]:
print("star2" + x)
case {1: _, 2: _}:
print("mapping" + x)
case Point2D(0, 0):
print("class" + x)
case [x] as y:
print("as" + y)
case [x] | [y]:
maximiliankaul marked this conversation as resolved.
Show resolved Hide resolved
print("or" + x)
maximiliankaul marked this conversation as resolved.
Show resolved Hide resolved
Loading