diff --git a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/Python.kt b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/Python.kt index 91899a9776..b0fbf27dad 100644 --- a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/Python.kt +++ b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/Python.kt @@ -433,7 +433,9 @@ interface Python { */ class Try(pyObject: PyObject) : BaseStmt(pyObject) { val body: kotlin.collections.List by lazy { "body" of pyObject } - val handlers: kotlin.collections.List by lazy { "handlers" of pyObject } + val handlers: kotlin.collections.List by lazy { + "handlers" of pyObject + } val orelse: kotlin.collections.List by lazy { "orelse" of pyObject } val stmt: kotlin.collections.List by lazy { "StmtBase" of pyObject } } @@ -446,7 +448,9 @@ interface Python { */ class TryStar(pyObject: PyObject) : BaseStmt(pyObject) { val body: kotlin.collections.List by lazy { "body" of pyObject } - val handlers: kotlin.collections.List by lazy { "handlers" of pyObject } + val handlers: kotlin.collections.List by lazy { + "handlers" of pyObject + } val orelse: kotlin.collections.List by lazy { "orelse" of pyObject } val finalbody: kotlin.collections.List by lazy { "finalbody" of pyObject } } @@ -1317,12 +1321,16 @@ interface Python { * ast.excepthandler = class excepthandler(AST) * | excepthandler = ExceptHandler(expr? type, identifier? name, stmt* body) * ``` - * - * TODO: excepthandler <-> ExceptHandler */ - class excepthandler(pyObject: PyObject) : AST(pyObject), WithLocation { - val type: BaseExpr by lazy { "type" of pyObject } - val name: String by lazy { "name" of pyObject } + sealed class BaseExcepthandler(python: PyObject) : AST(python), WithLocation + + /** + * ast.ExceptHandler = class ExceptHandler(excepthandler) | ExceptHandler(expr? type, + * identifier? name, stmt* body) + */ + class ExceptHandler(pyObject: PyObject) : BaseExcepthandler(pyObject) { + val type: BaseExpr? by lazy { "type" of pyObject } + val name: String? by lazy { "name" of pyObject } val body: kotlin.collections.List by lazy { "body" of pyObject } } diff --git a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/PythonLanguageFrontend.kt b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/PythonLanguageFrontend.kt index 029981cbe1..5e46113d1d 100644 --- a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/PythonLanguageFrontend.kt +++ b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/PythonLanguageFrontend.kt @@ -144,24 +144,7 @@ class PythonLanguageFrontend(language: Language, ctx: Tr autoType() } is Python.AST.Name -> { - // We have some kind of name here; let's quickly check, if this is a primitive type - val id = type.id - if (id in language.primitiveTypeNames) { - return primitiveType(id) - } - - // Otherwise, this could already be a fully qualified type - val name = - if (language.namespaceDelimiter in id) { - // TODO: This might create problem with nested classes - parseName(id) - } else { - // otherwise, we can just simply take the unqualified name and the type - // resolver will take care of the rest - id - } - - objectType(name) + this.typeOf(type.id) } else -> { // The AST supplied us with some kind of type information, but we could not parse @@ -171,6 +154,21 @@ class PythonLanguageFrontend(language: Language, ctx: Tr } } + /** Resolves a [Type] based on its string identifier. */ + fun typeOf(typeId: String): Type { + // Check if the typeId contains a namespace delimiter for qualified types + val name = + if (language.namespaceDelimiter in typeId) { + // TODO: This might create problem with nested classes + parseName(typeId) + } else { + // Unqualified name, resolved by the type resolver + typeId + } + + return objectType(name) + } + /** * This functions extracts the source code from the input file given a location. This is a bit * tricky in Python, as indents are part of the syntax. We also don't want to include leading @@ -320,7 +318,7 @@ fun fromPython(pyObject: Any?): Python.BaseObject { return when (objectname) { "ast.Module" -> Python.AST.Module(pyObject) - // statements + // `ast.stmt` "ast.FunctionDef" -> Python.AST.FunctionDef(pyObject) "ast.AsyncFunctionDef" -> Python.AST.AsyncFunctionDef(pyObject) "ast.ClassDef" -> Python.AST.ClassDef(pyObject) @@ -349,7 +347,7 @@ fun fromPython(pyObject: Any?): Python.BaseObject { "ast.Break" -> Python.AST.Break(pyObject) "ast.Continue" -> Python.AST.Continue(pyObject) - // `"ast.expr` + // `ast.expr` "ast.BoolOp" -> Python.AST.BoolOp(pyObject) "ast.NamedExpr" -> Python.AST.NamedExpr(pyObject) "ast.BinOp" -> Python.AST.BinOp(pyObject) @@ -378,11 +376,11 @@ fun fromPython(pyObject: Any?): Python.BaseObject { "ast.Tuple" -> Python.AST.Tuple(pyObject) "ast.Slice" -> Python.AST.Slice(pyObject) - // `"ast.boolop` + // `ast.boolop` "ast.And" -> Python.AST.And(pyObject) "ast.Or" -> Python.AST.Or(pyObject) - // `"ast.cmpop` + // `ast.cmpop` "ast.Eq" -> Python.AST.Eq(pyObject) "ast.NotEq" -> Python.AST.NotEq(pyObject) "ast.Lt" -> Python.AST.Lt(pyObject) @@ -394,12 +392,12 @@ fun fromPython(pyObject: Any?): Python.BaseObject { "ast.In" -> Python.AST.In(pyObject) "ast.NotIn" -> Python.AST.NotIn(pyObject) - // `"ast.expr_context` + // `ast.expr_context` "ast.Load" -> Python.AST.Load(pyObject) "ast.Store" -> Python.AST.Store(pyObject) "ast.Del" -> Python.AST.Del(pyObject) - // `"ast.operator` + // `ast.operator` "ast.Add" -> Python.AST.Add(pyObject) "ast.Sub" -> Python.AST.Sub(pyObject) "ast.Mult" -> Python.AST.Mult(pyObject) @@ -414,7 +412,7 @@ fun fromPython(pyObject: Any?): Python.BaseObject { "ast.BitAnd" -> Python.AST.BitAnd(pyObject) "ast.FloorDiv" -> Python.AST.FloorDiv(pyObject) - // `"ast.pattern` + // `ast.pattern` "ast.MatchValue" -> Python.AST.MatchValue(pyObject) "ast.MatchSingleton" -> Python.AST.MatchSingleton(pyObject) "ast.MatchSequence" -> Python.AST.MatchSequence(pyObject) @@ -424,18 +422,20 @@ fun fromPython(pyObject: Any?): Python.BaseObject { "ast.MatchAs" -> Python.AST.MatchAs(pyObject) "ast.MatchOr" -> Python.AST.MatchOr(pyObject) - // `"ast.unaryop` + // `ast.unaryop` "ast.Invert" -> Python.AST.Invert(pyObject) "ast.Not" -> Python.AST.Not(pyObject) "ast.UAdd" -> Python.AST.UAdd(pyObject) "ast.USub" -> Python.AST.USub(pyObject) + // `ast.excepthandler` + "ast.ExceptHandler" -> Python.AST.ExceptHandler(pyObject) + // misc "ast.alias" -> Python.AST.alias(pyObject) "ast.arg" -> Python.AST.arg(pyObject) "ast.arguments" -> Python.AST.arguments(pyObject) "ast.comprehension" -> Python.AST.comprehension(pyObject) - "ast.excepthandler" -> Python.AST.excepthandler(pyObject) "ast.keyword" -> Python.AST.keyword(pyObject) "ast.match_case" -> Python.AST.match_case(pyObject) "ast.type_ignore" -> Python.AST.type_ignore(pyObject) diff --git a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/StatementHandler.kt b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/StatementHandler.kt index da4af6a219..85b0acd39f 100644 --- a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/StatementHandler.kt +++ b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/StatementHandler.kt @@ -237,18 +237,15 @@ class StatementHandler(frontend: PythonLanguageFrontend) : return frontend.expressionHandler.handle(node.value) } - private fun handleAnnAssign(node: Python.AST.AnnAssign): Statement { - // TODO: annotations + /** + * Translates a Python [`AnnAssign`](https://docs.python.org/3/library/ast.html#ast.AnnAssign) + * into an [AssignExpression]. + */ + private fun handleAnnAssign(node: Python.AST.AnnAssign): AssignExpression { val lhs = frontend.expressionHandler.handle(node.target) - return if (node.value != null) { - newAssignExpression( - lhs = listOf(lhs), - rhs = listOf(frontend.expressionHandler.handle(node.value!!)), // TODO !! - rawNode = node - ) - } else { - lhs - } + lhs.type = frontend.typeOf(node.annotation) + val rhs = node.value?.let { listOf(frontend.expressionHandler.handle(it)) } ?: emptyList() + return newAssignExpression(lhs = listOf(lhs), rhs = rhs, rawNode = node) } private fun handleIf(node: Python.AST.If): Statement { @@ -275,8 +272,16 @@ class StatementHandler(frontend: PythonLanguageFrontend) : return ret } - private fun handleAssign(node: Python.AST.Assign): Statement { + /** + * Translates a Python [`Assign`](https://docs.python.org/3/library/ast.html#ast.Assign) into an + * [AssignExpression]. + */ + private fun handleAssign(node: Python.AST.Assign): AssignExpression { val lhs = node.targets.map { frontend.expressionHandler.handle(it) } + node.type_comment?.let { typeComment -> + val tpe = frontend.typeOf(typeComment) + lhs.forEach { it.type = tpe } + } val rhs = frontend.expressionHandler.handle(node.value) if (rhs is List<*>) newAssignExpression( diff --git a/cpg-language-python/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/python/statementHandler/StatementHandlerTest.kt b/cpg-language-python/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/python/statementHandler/StatementHandlerTest.kt index 68a45b3eef..c6f95f3bad 100644 --- a/cpg-language-python/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/python/statementHandler/StatementHandlerTest.kt +++ b/cpg-language-python/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/python/statementHandler/StatementHandlerTest.kt @@ -26,22 +26,18 @@ package de.fraunhofer.aisec.cpg.frontends.python.statementHandler import de.fraunhofer.aisec.cpg.TranslationResult -import de.fraunhofer.aisec.cpg.frontends.python.PythonLanguage +import de.fraunhofer.aisec.cpg.frontends.python.* import de.fraunhofer.aisec.cpg.graph.* -import de.fraunhofer.aisec.cpg.graph.statements.AssertStatement -import de.fraunhofer.aisec.cpg.graph.statements.expressions.Literal -import de.fraunhofer.aisec.cpg.test.analyze -import de.fraunhofer.aisec.cpg.test.analyzeAndGetFirstTU -import de.fraunhofer.aisec.cpg.test.assertResolvedType +import de.fraunhofer.aisec.cpg.graph.statements.* +import de.fraunhofer.aisec.cpg.graph.statements.expressions.* +import de.fraunhofer.aisec.cpg.test.* import java.nio.file.Path -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertNotNull +import kotlin.test.* import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.TestInstance @TestInstance(TestInstance.Lifecycle.PER_CLASS) -class StatementHandlerTest { +class StatementHandlerTest : BaseTest() { private lateinit var topLevel: Path private lateinit var result: TranslationResult @@ -123,4 +119,19 @@ class StatementHandlerTest { assertNotNull(message, "Assert statement should have a message") assertEquals("Test message", message.value, "The assert message is incorrect") } + + @Test + fun testTypeHints() { + analyzeFile("type_hints.py") + + // type comments + val a = result.refs["a"] + assertNotNull(a) + assertEquals(with(result) { assertResolvedType("int") }, a.type) + + // type annotation + val b = result.refs["b"] + assertNotNull(b) + assertEquals(with(result) { assertResolvedType("str") }, b.type) + } } diff --git a/cpg-language-python/src/test/resources/python/type_hints.py b/cpg-language-python/src/test/resources/python/type_hints.py new file mode 100644 index 0000000000..49032d1caa --- /dev/null +++ b/cpg-language-python/src/test/resources/python/type_hints.py @@ -0,0 +1,3 @@ +a = 1 #type: int + +b: str \ No newline at end of file