Skip to content

Commit

Permalink
Use assign expressions in LLVM frontend (#1265)
Browse files Browse the repository at this point in the history
  • Loading branch information
KuechA authored Jul 29, 2023
1 parent 470290f commit 5a38c97
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -506,15 +506,15 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) :
* of a de-referenced pointer in C like `*a = 1`.
*/
private fun handleStore(instr: LLVMValueRef): Statement {
val binOp = newBinaryOperator("=", frontend.codeOf(instr))

val dereference = newUnaryOperator("*", postfix = false, prefix = true, "")
dereference.input = frontend.getOperandValueAtIndex(instr, 1)

binOp.lhs = dereference
binOp.rhs = frontend.getOperandValueAtIndex(instr, 0)

return binOp
return newAssignExpression(
"=",
listOf(dereference),
listOf(frontend.getOperandValueAtIndex(instr, 0)),
frontend.codeOf(instr)
)
}

/**
Expand Down Expand Up @@ -707,10 +707,8 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) :
}

val compoundStatement = newCompoundStatement(frontend.codeOf(instr))

val assignment = newBinaryOperator("=", frontend.codeOf(instr))
assignment.lhs = base
assignment.rhs = valueToSet
val assignment =
newAssignExpression("=", listOf(base), listOf(valueToSet), frontend.codeOf(instr))
compoundStatement.addStatement(copy)
compoundStatement.addStatement(assignment)

Expand Down Expand Up @@ -848,9 +846,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) :
val ptrDerefAssign = newUnaryOperator("*", false, true, instrStr)
ptrDerefAssign.input = frontend.getOperandValueAtIndex(instr, 0)

val assignment = newBinaryOperator("=", instrStr)
assignment.lhs = ptrDerefAssign
assignment.rhs = value
val assignment = newAssignExpression("=", listOf(ptrDerefAssign), listOf(value), instrStr)

val ifStatement = newIfStatement(instrStr)
ifStatement.condition = cmpExpr
Expand All @@ -872,59 +868,59 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) :
val ptr = frontend.getOperandValueAtIndex(instr, 0)
val value = frontend.getOperandValueAtIndex(instr, 1)
val ty = value.type
val exchOp = newBinaryOperator("=", instrStr)
val exchOp = newAssignExpression("=", code = instrStr)
exchOp.name = Name("atomicrmw")

val ptrDeref = newUnaryOperator("*", false, true, instrStr)
ptrDeref.input = ptr

val ptrDerefExch = newUnaryOperator("*", false, true, instrStr)
ptrDerefExch.input = frontend.getOperandValueAtIndex(instr, 0)
exchOp.lhs = ptrDerefExch
exchOp.lhs = listOf(ptrDerefExch)

when (operation) {
LLVMAtomicRMWBinOpXchg -> {
exchOp.rhs = value
exchOp.rhs = listOf(value)
}
LLVMAtomicRMWBinOpFAdd,
LLVMAtomicRMWBinOpAdd -> {
val binaryOperator = newBinaryOperator("+", instrStr)
binaryOperator.lhs = ptrDeref
binaryOperator.rhs = value
exchOp.rhs = binaryOperator
exchOp.rhs = listOf(binaryOperator)
}
LLVMAtomicRMWBinOpFSub,
LLVMAtomicRMWBinOpSub -> {
val binaryOperator = newBinaryOperator("-", instrStr)
binaryOperator.lhs = ptrDeref
binaryOperator.rhs = value
exchOp.rhs = binaryOperator
exchOp.rhs = listOf(binaryOperator)
}
LLVMAtomicRMWBinOpAnd -> {
val binaryOperator = newBinaryOperator("&", instrStr)
binaryOperator.lhs = ptrDeref
binaryOperator.rhs = value
exchOp.rhs = binaryOperator
exchOp.rhs = listOf(binaryOperator)
}
LLVMAtomicRMWBinOpNand -> {
val binaryOperator = newBinaryOperator("|", instrStr)
binaryOperator.lhs = ptrDeref
binaryOperator.rhs = value
val unaryOperator = newUnaryOperator("~", false, true, instrStr)
unaryOperator.input = binaryOperator
exchOp.rhs = unaryOperator
exchOp.rhs = listOf(unaryOperator)
}
LLVMAtomicRMWBinOpOr -> {
val binaryOperator = newBinaryOperator("|", instrStr)
binaryOperator.lhs = ptrDeref
binaryOperator.rhs = value
exchOp.rhs = binaryOperator
exchOp.rhs = listOf(binaryOperator)
}
LLVMAtomicRMWBinOpXor -> {
val binaryOperator = newBinaryOperator("^", instrStr)
binaryOperator.lhs = ptrDeref
binaryOperator.rhs = value
exchOp.rhs = binaryOperator
exchOp.rhs = listOf(binaryOperator)
}
LLVMAtomicRMWBinOpMax,
LLVMAtomicRMWBinOpMin -> {
Expand All @@ -947,7 +943,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) :
value,
ty,
)
exchOp.rhs = conditional
exchOp.rhs = listOf(conditional)
}
LLVMAtomicRMWBinOpUMax,
LLVMAtomicRMWBinOpUMin -> {
Expand Down Expand Up @@ -977,7 +973,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) :
value,
ty,
)
exchOp.rhs = conditional
exchOp.rhs = listOf(conditional)
}
else -> {
throw TranslationException("LLVMAtomicRMWBinOp $operation not supported")
Expand Down Expand Up @@ -1260,10 +1256,14 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) :
)
arrayExpr.subscriptExpression = frontend.getOperandValueAtIndex(instr, 2)

val binaryExpr = newBinaryOperator("=", instrStr)
binaryExpr.lhs = arrayExpr
binaryExpr.rhs = frontend.getOperandValueAtIndex(instr, 1)
compoundStatement.addStatement(binaryExpr)
val assignExpr =
newAssignExpression(
"=",
listOf(arrayExpr),
listOf(frontend.getOperandValueAtIndex(instr, 1)),
instrStr
)
compoundStatement.addStatement(assignExpr)

return compoundStatement
}
Expand Down Expand Up @@ -1440,13 +1440,19 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) :

for ((l, r) in labelMap) {
// Now, we iterate over all the basic blocks and add an assign statement.
val assignment = newBinaryOperator("=", code)
assignment.rhs = r
assignment.lhs = newDeclaredReferenceExpression(varName, type, code)
(assignment.lhs as DeclaredReferenceExpression).type = type
(assignment.lhs as DeclaredReferenceExpression).unregisterTypeListener(assignment)
assignment.unregisterTypeListener(assignment.lhs as DeclaredReferenceExpression)
(assignment.lhs as DeclaredReferenceExpression).refersTo = declaration
val assignment =
newAssignExpression(
"=",
listOf(newDeclaredReferenceExpression(varName, type, code)),
listOf(r),
code
)
(assignment.lhs.first() as DeclaredReferenceExpression).type = type
(assignment.lhs.first() as DeclaredReferenceExpression).unregisterTypeListener(
assignment
)
assignment.unregisterTypeListener(assignment.lhs.first() as DeclaredReferenceExpression)
(assignment.lhs.first() as DeclaredReferenceExpression).refersTo = declaration
flatAST.add(assignment)

val basicBlock = l.subStatement as? CompoundStatement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,12 +388,14 @@ class LLVMIRLanguageFrontendTest {
assertLocalName("ptr", (decl.initializer as UnaryOperator).input)

// Check that the replacement equals *ptr = *ptr + 1
val replacement = (atomicrmwStatement.statements[1] as BinaryOperator)
val replacement = (atomicrmwStatement.statements[1] as AssignExpression)
assertEquals(1, replacement.lhs.size)
assertEquals(1, replacement.rhs.size)
assertEquals("=", replacement.operatorCode)
assertEquals("*", (replacement.lhs as UnaryOperator).operatorCode)
assertLocalName("ptr", (replacement.lhs as UnaryOperator).input)
assertEquals("*", (replacement.lhs.first() as UnaryOperator).operatorCode)
assertLocalName("ptr", (replacement.lhs.first() as UnaryOperator).input)
// Check that the rhs is equal to *ptr + 1
val add = replacement.rhs as BinaryOperator
val add = replacement.rhs.first() as BinaryOperator
assertEquals("+", add.operatorCode)
assertEquals("*", (add.lhs as UnaryOperator).operatorCode)
assertLocalName("ptr", (add.lhs as UnaryOperator).input)
Expand Down Expand Up @@ -445,12 +447,14 @@ class LLVMIRLanguageFrontendTest {
assertLocalName("ptr", (ifExpr.lhs as UnaryOperator).input)
assertEquals(5L, (ifExpr.rhs as Literal<*>).value as Long)

val thenExpr = ifStatement.thenStatement as BinaryOperator
val thenExpr = ifStatement.thenStatement as AssignExpression
assertEquals(1, thenExpr.lhs.size)
assertEquals(1, thenExpr.rhs.size)
assertEquals("=", thenExpr.operatorCode)
assertEquals("*", (thenExpr.lhs as UnaryOperator).operatorCode)
assertLocalName("ptr", (thenExpr.lhs as UnaryOperator).input)
assertLocalName("old", thenExpr.rhs as DeclaredReferenceExpression)
assertLocalName("old", (thenExpr.rhs as DeclaredReferenceExpression).refersTo)
assertEquals("*", (thenExpr.lhs.first() as UnaryOperator).operatorCode)
assertLocalName("ptr", (thenExpr.lhs.first() as UnaryOperator).input)
assertLocalName("old", thenExpr.rhs.first() as DeclaredReferenceExpression)
assertLocalName("old", (thenExpr.rhs.first() as DeclaredReferenceExpression).refersTo)
}

@Test
Expand Down Expand Up @@ -596,17 +600,19 @@ class LLVMIRLanguageFrontendTest {
assertEquals("i32*", alloca.type.typeName)

// store i32 3, i32* %ptr
val store = main.bodyOrNull<BinaryOperator>()
val store = main.bodyOrNull<AssignExpression>()
assertNotNull(store)
assertEquals("=", store.operatorCode)

val dereferencePtr = store.lhs as? UnaryOperator
assertEquals(1, store.lhs.size)
val dereferencePtr = store.lhs.first() as? UnaryOperator
assertNotNull(dereferencePtr)
assertEquals("*", dereferencePtr.operatorCode)
assertEquals("i32", dereferencePtr.type.typeName)
assertSame(ptr, (dereferencePtr.input as? DeclaredReferenceExpression)?.refersTo)

val value = store.rhs as? Literal<*>
assertEquals(1, store.rhs.size)
val value = store.rhs.first() as? Literal<*>
assertNotNull(value)
assertEquals(3L, value.value)
assertEquals("i32", value.type.typeName)
Expand Down Expand Up @@ -655,12 +661,14 @@ class LLVMIRLanguageFrontendTest {
assertEquals("literal_i32_i8", copyStatement.type.typeName)

// Now, we set b.field_1 to 7
val assignment = (compoundStatement.statements[1] as BinaryOperator)
val assignment = (compoundStatement.statements[1] as AssignExpression)
assertEquals("=", assignment.operatorCode)
assertLocalName("b", (assignment.lhs as MemberExpression).base)
assertEquals(".", (assignment.lhs as MemberExpression).operatorCode)
assertLocalName("field_1", assignment.lhs as MemberExpression)
assertEquals(7L, (assignment.rhs as Literal<*>).value as Long)
assertEquals(1, assignment.lhs.size)
assertEquals(1, assignment.rhs.size)
assertLocalName("b", (assignment.lhs.first() as MemberExpression).base)
assertEquals(".", (assignment.lhs.first() as MemberExpression).operatorCode)
assertLocalName("field_1", assignment.lhs.first() as MemberExpression)
assertEquals(7L, (assignment.rhs.first() as Literal<*>).value as Long)
}

@Test
Expand Down Expand Up @@ -758,24 +766,28 @@ class LLVMIRLanguageFrontendTest {
val thenStmt = ifStatement.thenStatement as? CompoundStatement
assertNotNull(thenStmt)
assertEquals(3, thenStmt.statements.size)
assertNotNull(thenStmt.statements[1] as? BinaryOperator)
assertNotNull(thenStmt.statements[1] as? AssignExpression)
val aDecl =
(thenStmt.statements[0] as DeclarationStatement).singleDeclaration
as VariableDeclaration
val thenY = thenStmt.statements[1] as BinaryOperator
assertSame(aDecl, (thenY.rhs as DeclaredReferenceExpression).refersTo)
assertSame(yDecl, (thenY.lhs as DeclaredReferenceExpression).refersTo)
val thenY = thenStmt.statements[1] as AssignExpression
assertEquals(1, thenY.lhs.size)
assertEquals(1, thenY.rhs.size)
assertSame(aDecl, (thenY.rhs.first() as DeclaredReferenceExpression).refersTo)
assertSame(yDecl, (thenY.lhs.first() as DeclaredReferenceExpression).refersTo)

val elseStmt = ifStatement.elseStatement as? CompoundStatement
assertNotNull(elseStmt)
assertEquals(3, elseStmt.statements.size)
val bDecl =
(elseStmt.statements[0] as DeclarationStatement).singleDeclaration
as VariableDeclaration
assertNotNull(elseStmt.statements[1] as? BinaryOperator)
val elseY = elseStmt.statements[1] as BinaryOperator
assertSame(bDecl, (elseY.rhs as DeclaredReferenceExpression).refersTo)
assertSame(yDecl, (elseY.lhs as DeclaredReferenceExpression).refersTo)
assertNotNull(elseStmt.statements[1] as? AssignExpression)
val elseY = elseStmt.statements[1] as AssignExpression
assertEquals(1, elseY.lhs.size)
assertEquals(1, elseY.lhs.size)
assertSame(bDecl, (elseY.rhs.first() as DeclaredReferenceExpression).refersTo)
assertSame(yDecl, (elseY.lhs.first() as DeclaredReferenceExpression).refersTo)

val continueBlock =
(thenStmt.statements[2] as? GotoStatement)?.targetLabel?.subStatement
Expand Down Expand Up @@ -844,19 +856,23 @@ class LLVMIRLanguageFrontendTest {
assertEquals("y", (yModInit.initializer as? DeclaredReferenceExpression)?.name?.localName)
assertSame(origY, (yModInit.initializer as? DeclaredReferenceExpression)?.refersTo)
// Now, test the modification of yMod[3] = 8
val yMod = ((mainBody.statements[3] as CompoundStatement).statements[1] as? BinaryOperator)
val yMod =
((mainBody.statements[3] as CompoundStatement).statements[1] as? AssignExpression)
assertNotNull(yMod)
assertEquals(1, yMod.lhs.size)
assertEquals(1, yMod.rhs.size)
assertEquals(
3L,
((yMod.lhs as? ArraySubscriptionExpression)?.subscriptExpression as? Literal<*>)?.value
((yMod.lhs.first() as? ArraySubscriptionExpression)?.subscriptExpression as? Literal<*>)
?.value
)
assertSame(
yModInit,
((yMod.lhs as? ArraySubscriptionExpression)?.arrayExpression
((yMod.lhs.first() as? ArraySubscriptionExpression)?.arrayExpression
as? DeclaredReferenceExpression)
?.refersTo
)
assertEquals(8L, (yMod.rhs as? Literal<*>)?.value)
assertEquals(8L, (yMod.rhs.first() as? Literal<*>)?.value)

// Test the last shufflevector instruction which does not contain constant as initializers.
val shuffledInit =
Expand Down

0 comments on commit 5a38c97

Please sign in to comment.