Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support python type hints #1701

Merged
merged 11 commits into from
Sep 25, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -144,24 +144,7 @@ class PythonLanguageFrontend(language: Language<PythonLanguageFrontend>, ctx: Tr
autoType()
}
is Python.AST.Name -> {
// We have some kind of name here; let's quickly check, if this is a primitive type
val id = type.id
if (id in language.primitiveTypeNames) {
return primitiveType(id)
}

// Otherwise, this could already be a fully qualified type
val name =
if (language.namespaceDelimiter in id) {
// TODO: This might create problem with nested classes
parseName(id)
} else {
// otherwise, we can just simply take the unqualified name and the type
// resolver will take care of the rest
id
}

objectType(name)
this.typeOf(type.id)
}
else -> {
// The AST supplied us with some kind of type information, but we could not parse
Expand All @@ -171,6 +154,21 @@ class PythonLanguageFrontend(language: Language<PythonLanguageFrontend>, ctx: Tr
}
}

/** Resolves a [Type] based on its string identifier. */
fun typeOf(typeId: String): Type {
// Check if the typeId contains a namespace delimiter for qualified types
val name =
if (language.namespaceDelimiter in typeId) {
// TODO: This might create problem with nested classes
parseName(typeId)
} else {
// Unqualified name, resolved by the type resolver
typeId
}

return objectType(name)
}

/**
* This functions extracts the source code from the input file given a location. This is a bit
* tricky in Python, as indents are part of the syntax. We also don't want to include leading
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import de.fraunhofer.aisec.cpg.graph.declarations.*
import de.fraunhofer.aisec.cpg.graph.statements.AssertStatement
import de.fraunhofer.aisec.cpg.graph.statements.DeclarationStatement
import de.fraunhofer.aisec.cpg.graph.statements.Statement
import de.fraunhofer.aisec.cpg.graph.statements.expressions.AssignExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.Block
import de.fraunhofer.aisec.cpg.graph.statements.expressions.Expression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.MemberExpression
Expand Down Expand Up @@ -84,7 +85,7 @@ class StatementHandler(frontend: PythonLanguageFrontend) :
}

/**
* Translates a Python (https://docs.python.org/3/library/ast.html#ast.Assert] into a
* Translates a Python [`Assert`](https://docs.python.org/3/library/ast.html#ast.Assert) into a
* [AssertStatement].
*/
private fun handleAssert(node: Python.AST.Assert): AssertStatement {
Expand Down Expand Up @@ -196,18 +197,15 @@ class StatementHandler(frontend: PythonLanguageFrontend) :
return frontend.expressionHandler.handle(node.value)
}

private fun handleAnnAssign(node: Python.AST.AnnAssign): Statement {
// TODO: annotations
/**
* Translates a Python [`AnnAssign`](https://docs.python.org/3/library/ast.html#ast.AnnAssign)
* into a [Statement].
maximiliankaul marked this conversation as resolved.
Show resolved Hide resolved
*/
private fun handleAnnAssign(node: Python.AST.AnnAssign): AssignExpression {
oxisto marked this conversation as resolved.
Show resolved Hide resolved
val lhs = frontend.expressionHandler.handle(node.target)
return if (node.value != null) {
newAssignExpression(
lhs = listOf(lhs),
rhs = listOf(frontend.expressionHandler.handle(node.value!!)), // TODO !!
rawNode = node
)
} else {
lhs
}
lhs.type = frontend.typeOf(node.annotation)
oxisto marked this conversation as resolved.
Show resolved Hide resolved
val rhs = node.value?.let { listOf(frontend.expressionHandler.handle(it)) } ?: emptyList()
return newAssignExpression(lhs = listOf(lhs), rhs = rhs, rawNode = node)
}

private fun handleIf(node: Python.AST.If): Statement {
Expand All @@ -234,8 +232,16 @@ class StatementHandler(frontend: PythonLanguageFrontend) :
return ret
}

private fun handleAssign(node: Python.AST.Assign): Statement {
/**
* Translates a Python [`Assign`](https://docs.python.org/3/library/ast.html#ast.Assign) into a
* [Statement].
*/
private fun handleAssign(node: Python.AST.Assign): AssignExpression {
val lhs = node.targets.map { frontend.expressionHandler.handle(it) }
node.type_comment?.let { typeComment ->
val tpe = frontend.typeOf(typeComment)
lhs.forEach { it.type = tpe }
}
val rhs = frontend.expressionHandler.handle(node.value)
if (rhs is List<*>)
newAssignExpression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class PythonAddDeclarationsPass(ctx: TranslationContext) : ComponentPass(ctx) {
when (node) {
is AssignExpression -> handleAssignExpression(node)
is ForEachStatement -> handleForEach(node)
is Reference -> handleReference(node)
oxisto marked this conversation as resolved.
Show resolved Hide resolved
else -> {
// Nothing to do for all other types of nodes
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,18 @@
package de.fraunhofer.aisec.cpg.frontends.python.statementHandler

import de.fraunhofer.aisec.cpg.TranslationResult
import de.fraunhofer.aisec.cpg.frontends.python.PythonLanguage
import de.fraunhofer.aisec.cpg.frontends.python.*
import de.fraunhofer.aisec.cpg.graph.*
import de.fraunhofer.aisec.cpg.graph.statements.AssertStatement
import de.fraunhofer.aisec.cpg.graph.statements.expressions.Literal
import de.fraunhofer.aisec.cpg.test.analyze
import de.fraunhofer.aisec.cpg.test.analyzeAndGetFirstTU
import de.fraunhofer.aisec.cpg.test.assertResolvedType
import de.fraunhofer.aisec.cpg.graph.statements.*
import de.fraunhofer.aisec.cpg.graph.statements.expressions.*
import de.fraunhofer.aisec.cpg.test.*
import java.nio.file.Path
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.*
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.TestInstance

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class StatementHandlerTest {
class StatementHandlerTest : BaseTest() {

private lateinit var topLevel: Path
private lateinit var result: TranslationResult
Expand Down Expand Up @@ -123,4 +119,19 @@ class StatementHandlerTest {
assertNotNull(message, "Assert statement should have a message")
assertEquals("Test message", message.value, "The assert message is incorrect")
}

@Test
fun testTypeHints() {
analyzeFile("type_hints.py")
lshala marked this conversation as resolved.
Show resolved Hide resolved

// type comments
val a = result.refs["a"]
assertNotNull(a)
assertEquals(result.assertResolvedType("int"), a.type)
lshala marked this conversation as resolved.
Show resolved Hide resolved

// type annotation
val b = result.refs["b"]
assertNotNull(b)
assertEquals(result.assertResolvedType("str"), b.type)
}
}
3 changes: 3 additions & 0 deletions cpg-language-python/src/test/resources/python/type_hints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
a = 1 #type: int

b: str
Loading