From 5a38c97a69da83383c4629ab31b8e1a23f65a597 Mon Sep 17 00:00:00 2001 From: KuechA <31155350+KuechA@users.noreply.github.com> Date: Sat, 29 Jul 2023 17:21:46 +0200 Subject: [PATCH] Use assign expressions in LLVM frontend (#1265) --- .../cpg/frontends/llvm/StatementHandler.kt | 76 ++++++++++--------- .../llvm/LLVMIRLanguageFrontendTest.kt | 74 +++++++++++------- 2 files changed, 86 insertions(+), 64 deletions(-) diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt index c5fe499ad1..bb741ce560 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt @@ -506,15 +506,15 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * of a de-referenced pointer in C like `*a = 1`. */ private fun handleStore(instr: LLVMValueRef): Statement { - val binOp = newBinaryOperator("=", frontend.codeOf(instr)) - val dereference = newUnaryOperator("*", postfix = false, prefix = true, "") dereference.input = frontend.getOperandValueAtIndex(instr, 1) - binOp.lhs = dereference - binOp.rhs = frontend.getOperandValueAtIndex(instr, 0) - - return binOp + return newAssignExpression( + "=", + listOf(dereference), + listOf(frontend.getOperandValueAtIndex(instr, 0)), + frontend.codeOf(instr) + ) } /** @@ -707,10 +707,8 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : } val compoundStatement = newCompoundStatement(frontend.codeOf(instr)) - - val assignment = newBinaryOperator("=", frontend.codeOf(instr)) - assignment.lhs = base - assignment.rhs = valueToSet + val assignment = + newAssignExpression("=", listOf(base), listOf(valueToSet), frontend.codeOf(instr)) compoundStatement.addStatement(copy) compoundStatement.addStatement(assignment) @@ -848,9 +846,7 @@ class StatementHandler(lang: 
LLVMIRLanguageFrontend) : val ptrDerefAssign = newUnaryOperator("*", false, true, instrStr) ptrDerefAssign.input = frontend.getOperandValueAtIndex(instr, 0) - val assignment = newBinaryOperator("=", instrStr) - assignment.lhs = ptrDerefAssign - assignment.rhs = value + val assignment = newAssignExpression("=", listOf(ptrDerefAssign), listOf(value), instrStr) val ifStatement = newIfStatement(instrStr) ifStatement.condition = cmpExpr @@ -872,7 +868,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : val ptr = frontend.getOperandValueAtIndex(instr, 0) val value = frontend.getOperandValueAtIndex(instr, 1) val ty = value.type - val exchOp = newBinaryOperator("=", instrStr) + val exchOp = newAssignExpression("=", code = instrStr) exchOp.name = Name("atomicrmw") val ptrDeref = newUnaryOperator("*", false, true, instrStr) @@ -880,31 +876,31 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : val ptrDerefExch = newUnaryOperator("*", false, true, instrStr) ptrDerefExch.input = frontend.getOperandValueAtIndex(instr, 0) - exchOp.lhs = ptrDerefExch + exchOp.lhs = listOf(ptrDerefExch) when (operation) { LLVMAtomicRMWBinOpXchg -> { - exchOp.rhs = value + exchOp.rhs = listOf(value) } LLVMAtomicRMWBinOpFAdd, LLVMAtomicRMWBinOpAdd -> { val binaryOperator = newBinaryOperator("+", instrStr) binaryOperator.lhs = ptrDeref binaryOperator.rhs = value - exchOp.rhs = binaryOperator + exchOp.rhs = listOf(binaryOperator) } LLVMAtomicRMWBinOpFSub, LLVMAtomicRMWBinOpSub -> { val binaryOperator = newBinaryOperator("-", instrStr) binaryOperator.lhs = ptrDeref binaryOperator.rhs = value - exchOp.rhs = binaryOperator + exchOp.rhs = listOf(binaryOperator) } LLVMAtomicRMWBinOpAnd -> { val binaryOperator = newBinaryOperator("&", instrStr) binaryOperator.lhs = ptrDeref binaryOperator.rhs = value - exchOp.rhs = binaryOperator + exchOp.rhs = listOf(binaryOperator) } LLVMAtomicRMWBinOpNand -> { val binaryOperator = newBinaryOperator("|", instrStr) @@ -912,19 +908,19 @@ class 
StatementHandler(lang: LLVMIRLanguageFrontend) : binaryOperator.rhs = value val unaryOperator = newUnaryOperator("~", false, true, instrStr) unaryOperator.input = binaryOperator - exchOp.rhs = unaryOperator + exchOp.rhs = listOf(unaryOperator) } LLVMAtomicRMWBinOpOr -> { val binaryOperator = newBinaryOperator("|", instrStr) binaryOperator.lhs = ptrDeref binaryOperator.rhs = value - exchOp.rhs = binaryOperator + exchOp.rhs = listOf(binaryOperator) } LLVMAtomicRMWBinOpXor -> { val binaryOperator = newBinaryOperator("^", instrStr) binaryOperator.lhs = ptrDeref binaryOperator.rhs = value - exchOp.rhs = binaryOperator + exchOp.rhs = listOf(binaryOperator) } LLVMAtomicRMWBinOpMax, LLVMAtomicRMWBinOpMin -> { @@ -947,7 +943,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : value, ty, ) - exchOp.rhs = conditional + exchOp.rhs = listOf(conditional) } LLVMAtomicRMWBinOpUMax, LLVMAtomicRMWBinOpUMin -> { @@ -977,7 +973,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : value, ty, ) - exchOp.rhs = conditional + exchOp.rhs = listOf(conditional) } else -> { throw TranslationException("LLVMAtomicRMWBinOp $operation not supported") @@ -1260,10 +1256,14 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : ) arrayExpr.subscriptExpression = frontend.getOperandValueAtIndex(instr, 2) - val binaryExpr = newBinaryOperator("=", instrStr) - binaryExpr.lhs = arrayExpr - binaryExpr.rhs = frontend.getOperandValueAtIndex(instr, 1) - compoundStatement.addStatement(binaryExpr) + val assignExpr = + newAssignExpression( + "=", + listOf(arrayExpr), + listOf(frontend.getOperandValueAtIndex(instr, 1)), + instrStr + ) + compoundStatement.addStatement(assignExpr) return compoundStatement } @@ -1440,13 +1440,19 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : for ((l, r) in labelMap) { // Now, we iterate over all the basic blocks and add an assign statement. 
- val assignment = newBinaryOperator("=", code) - assignment.rhs = r - assignment.lhs = newDeclaredReferenceExpression(varName, type, code) - (assignment.lhs as DeclaredReferenceExpression).type = type - (assignment.lhs as DeclaredReferenceExpression).unregisterTypeListener(assignment) - assignment.unregisterTypeListener(assignment.lhs as DeclaredReferenceExpression) - (assignment.lhs as DeclaredReferenceExpression).refersTo = declaration + val assignment = + newAssignExpression( + "=", + listOf(newDeclaredReferenceExpression(varName, type, code)), + listOf(r), + code + ) + (assignment.lhs.first() as DeclaredReferenceExpression).type = type + (assignment.lhs.first() as DeclaredReferenceExpression).unregisterTypeListener( + assignment + ) + assignment.unregisterTypeListener(assignment.lhs.first() as DeclaredReferenceExpression) + (assignment.lhs.first() as DeclaredReferenceExpression).refersTo = declaration flatAST.add(assignment) val basicBlock = l.subStatement as? CompoundStatement diff --git a/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/LLVMIRLanguageFrontendTest.kt b/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/LLVMIRLanguageFrontendTest.kt index 0d98a6852b..97710bb3a0 100644 --- a/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/LLVMIRLanguageFrontendTest.kt +++ b/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/LLVMIRLanguageFrontendTest.kt @@ -388,12 +388,14 @@ class LLVMIRLanguageFrontendTest { assertLocalName("ptr", (decl.initializer as UnaryOperator).input) // Check that the replacement equals *ptr = *ptr + 1 - val replacement = (atomicrmwStatement.statements[1] as BinaryOperator) + val replacement = (atomicrmwStatement.statements[1] as AssignExpression) + assertEquals(1, replacement.lhs.size) + assertEquals(1, replacement.rhs.size) assertEquals("=", replacement.operatorCode) - assertEquals("*", (replacement.lhs as UnaryOperator).operatorCode) - 
assertLocalName("ptr", (replacement.lhs as UnaryOperator).input) + assertEquals("*", (replacement.lhs.first() as UnaryOperator).operatorCode) + assertLocalName("ptr", (replacement.lhs.first() as UnaryOperator).input) // Check that the rhs is equal to *ptr + 1 - val add = replacement.rhs as BinaryOperator + val add = replacement.rhs.first() as BinaryOperator assertEquals("+", add.operatorCode) assertEquals("*", (add.lhs as UnaryOperator).operatorCode) assertLocalName("ptr", (add.lhs as UnaryOperator).input) @@ -445,12 +447,14 @@ class LLVMIRLanguageFrontendTest { assertLocalName("ptr", (ifExpr.lhs as UnaryOperator).input) assertEquals(5L, (ifExpr.rhs as Literal<*>).value as Long) - val thenExpr = ifStatement.thenStatement as BinaryOperator + val thenExpr = ifStatement.thenStatement as AssignExpression + assertEquals(1, thenExpr.lhs.size) + assertEquals(1, thenExpr.rhs.size) assertEquals("=", thenExpr.operatorCode) - assertEquals("*", (thenExpr.lhs as UnaryOperator).operatorCode) - assertLocalName("ptr", (thenExpr.lhs as UnaryOperator).input) - assertLocalName("old", thenExpr.rhs as DeclaredReferenceExpression) - assertLocalName("old", (thenExpr.rhs as DeclaredReferenceExpression).refersTo) + assertEquals("*", (thenExpr.lhs.first() as UnaryOperator).operatorCode) + assertLocalName("ptr", (thenExpr.lhs.first() as UnaryOperator).input) + assertLocalName("old", thenExpr.rhs.first() as DeclaredReferenceExpression) + assertLocalName("old", (thenExpr.rhs.first() as DeclaredReferenceExpression).refersTo) } @Test @@ -596,17 +600,19 @@ class LLVMIRLanguageFrontendTest { assertEquals("i32*", alloca.type.typeName) // store i32 3, i32* %ptr - val store = main.bodyOrNull<BinaryOperator>() + val store = main.bodyOrNull<AssignExpression>() assertNotNull(store) assertEquals("=", store.operatorCode) - val dereferencePtr = store.lhs as? UnaryOperator + assertEquals(1, store.lhs.size) + val dereferencePtr = store.lhs.first() as? 
UnaryOperator assertNotNull(dereferencePtr) assertEquals("*", dereferencePtr.operatorCode) assertEquals("i32", dereferencePtr.type.typeName) assertSame(ptr, (dereferencePtr.input as? DeclaredReferenceExpression)?.refersTo) - val value = store.rhs as? Literal<*> + assertEquals(1, store.rhs.size) + val value = store.rhs.first() as? Literal<*> assertNotNull(value) assertEquals(3L, value.value) assertEquals("i32", value.type.typeName) @@ -655,12 +661,14 @@ class LLVMIRLanguageFrontendTest { assertEquals("literal_i32_i8", copyStatement.type.typeName) // Now, we set b.field_1 to 7 - val assignment = (compoundStatement.statements[1] as BinaryOperator) + val assignment = (compoundStatement.statements[1] as AssignExpression) assertEquals("=", assignment.operatorCode) - assertLocalName("b", (assignment.lhs as MemberExpression).base) - assertEquals(".", (assignment.lhs as MemberExpression).operatorCode) - assertLocalName("field_1", assignment.lhs as MemberExpression) - assertEquals(7L, (assignment.rhs as Literal<*>).value as Long) + assertEquals(1, assignment.lhs.size) + assertEquals(1, assignment.rhs.size) + assertLocalName("b", (assignment.lhs.first() as MemberExpression).base) + assertEquals(".", (assignment.lhs.first() as MemberExpression).operatorCode) + assertLocalName("field_1", assignment.lhs.first() as MemberExpression) + assertEquals(7L, (assignment.rhs.first() as Literal<*>).value as Long) } @Test @@ -758,13 +766,15 @@ class LLVMIRLanguageFrontendTest { val thenStmt = ifStatement.thenStatement as? CompoundStatement assertNotNull(thenStmt) assertEquals(3, thenStmt.statements.size) - assertNotNull(thenStmt.statements[1] as? BinaryOperator) + assertNotNull(thenStmt.statements[1] as? 
AssignExpression) val aDecl = (thenStmt.statements[0] as DeclarationStatement).singleDeclaration as VariableDeclaration - val thenY = thenStmt.statements[1] as BinaryOperator - assertSame(aDecl, (thenY.rhs as DeclaredReferenceExpression).refersTo) - assertSame(yDecl, (thenY.lhs as DeclaredReferenceExpression).refersTo) + val thenY = thenStmt.statements[1] as AssignExpression + assertEquals(1, thenY.lhs.size) + assertEquals(1, thenY.rhs.size) + assertSame(aDecl, (thenY.rhs.first() as DeclaredReferenceExpression).refersTo) + assertSame(yDecl, (thenY.lhs.first() as DeclaredReferenceExpression).refersTo) val elseStmt = ifStatement.elseStatement as? CompoundStatement assertNotNull(elseStmt) @@ -772,10 +782,12 @@ class LLVMIRLanguageFrontendTest { val bDecl = (elseStmt.statements[0] as DeclarationStatement).singleDeclaration as VariableDeclaration - assertNotNull(elseStmt.statements[1] as? BinaryOperator) - val elseY = elseStmt.statements[1] as BinaryOperator - assertSame(bDecl, (elseY.rhs as DeclaredReferenceExpression).refersTo) - assertSame(yDecl, (elseY.lhs as DeclaredReferenceExpression).refersTo) + assertNotNull(elseStmt.statements[1] as? AssignExpression) + val elseY = elseStmt.statements[1] as AssignExpression + assertEquals(1, elseY.lhs.size) + assertEquals(1, elseY.rhs.size) + assertSame(bDecl, (elseY.rhs.first() as DeclaredReferenceExpression).refersTo) + assertSame(yDecl, (elseY.lhs.first() as DeclaredReferenceExpression).refersTo) val continueBlock = (thenStmt.statements[2] as? GotoStatement)?.targetLabel?.subStatement @@ -844,19 +856,23 @@ class LLVMIRLanguageFrontendTest { assertEquals("y", (yModInit.initializer as? DeclaredReferenceExpression)?.name?.localName) assertSame(origY, (yModInit.initializer as? DeclaredReferenceExpression)?.refersTo) // Now, test the modification of yMod[3] = 8 - val yMod = ((mainBody.statements[3] as CompoundStatement).statements[1] as? 
BinaryOperator) + val yMod = + ((mainBody.statements[3] as CompoundStatement).statements[1] as? AssignExpression) assertNotNull(yMod) + assertEquals(1, yMod.lhs.size) + assertEquals(1, yMod.rhs.size) assertEquals( 3L, - ((yMod.lhs as? ArraySubscriptionExpression)?.subscriptExpression as? Literal<*>)?.value + ((yMod.lhs.first() as? ArraySubscriptionExpression)?.subscriptExpression as? Literal<*>) + ?.value ) assertSame( yModInit, - ((yMod.lhs as? ArraySubscriptionExpression)?.arrayExpression + ((yMod.lhs.first() as? ArraySubscriptionExpression)?.arrayExpression as? DeclaredReferenceExpression) ?.refersTo ) - assertEquals(8L, (yMod.rhs as? Literal<*>)?.value) + assertEquals(8L, (yMod.rhs.first() as? Literal<*>)?.value) // Test the last shufflevector instruction which does not contain constant as initializers. val shuffledInit =