diff --git a/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/LLVMIRLanguageFrontendTest.kt b/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/LLVMIRLanguageFrontendTest.kt index 0f6e0567d7..dbeeb0ad41 100644 --- a/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/LLVMIRLanguageFrontendTest.kt +++ b/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/LLVMIRLanguageFrontendTest.kt @@ -71,15 +71,21 @@ class LLVMIRLanguageFrontendTest { assertNotNull(main) assertLocalName("i32", main.type) - val xVector = - (main.bodyOrNull(0)?.statements?.get(0) as? DeclarationStatement) - ?.singleDeclaration as? VariableDeclaration - val xInit = xVector?.initializer as? InitializerListExpression - assertNotNull(xInit) - assertLocalName("poison", xInit.initializers[0] as? Reference) - assertEquals(0L, (xInit.initializers[1] as? Literal<*>)?.value) - assertEquals(0L, (xInit.initializers[2] as? Literal<*>)?.value) - assertEquals(0L, (xInit.initializers[3] as? Literal<*>)?.value) + // We want to see that the declaration is the very first statement of the method body (it's + // wrapped inside another block). + val mainBody = main.bodyOrNull(0) + assertIs(mainBody) + val declarationStmt = mainBody.statements.firstOrNull() + assertIs(declarationStmt) + val xVector = declarationStmt.singleDeclaration + assertIs(xVector) + val xInit = xVector.initializer + assertIs(xInit) + assertIs(xInit.initializers[0]) + assertLocalName("poison", xInit.initializers[0]) + assertLiteralValue(0L, xInit.initializers[1]) + assertLiteralValue(0L, xInit.initializers[2]) + assertLiteralValue(0L, xInit.initializers[3]) } @Test @@ -104,25 +110,25 @@ class LLVMIRLanguageFrontendTest { assertNotNull(rand) assertNull(rand.body) - val decl = tu.variables["x"] - assertNotNull(decl) + val xDeclaration = tu.variables["x"] + assertNotNull(xDeclaration) - val call = decl.initializer as? CallExpression - assertNotNull(call) + val call = xDeclaration.initializer + assertIs(call) assertLocalName("rand", call) - assertTrue(call.invokes.contains(rand)) + assertContains(call.invokes, rand) assertEquals(0, call.arguments.size) val xorStatement = main.bodyOrNull(3) assertNotNull(xorStatement) - val xorDecl = xorStatement.singleDeclaration as? VariableDeclaration - assertNotNull(xorDecl) - assertLocalName("a", xorDecl) - assertEquals("i32", xorDecl.type.typeName) + val xorDeclaration = xorStatement.singleDeclaration + assertIs(xorDeclaration) + assertLocalName("a", xorDeclaration) + assertEquals("i32", xorDeclaration.type.typeName) - val xor = xorDecl.initializer as? BinaryOperator - assertNotNull(xor) + val xor = xorDeclaration.initializer + assertIs(xor) assertEquals("^", xor.operatorCode) } @@ -163,58 +169,57 @@ class LLVMIRLanguageFrontendTest { val s = foo.parameters.firstOrNull { it.name.localName == "s" } assertNotNull(s) - val arrayidx = - foo.bodyOrNull(0)?.singleDeclaration as? VariableDeclaration + val arrayidx = foo.variables["arrayidx"] assertNotNull(arrayidx) // arrayidx will be assigned to a chain of the following expressions: // &s[1].field2.field1[5][13] // we will check them in the reverse order (after the unary operator) - val unary = arrayidx.initializer as? UnaryOperator - assertNotNull(unary) + val unary = arrayidx.initializer + assertIs(unary) assertEquals("&", unary.operatorCode) - var arrayExpr = unary.input as? SubscriptExpression - assertNotNull(arrayExpr) + var arrayExpr = unary.input + assertIs(arrayExpr) assertLocalName("13", arrayExpr) - assertEquals( + assertLiteralValue( 13L, - (arrayExpr.subscriptExpression as? Literal<*>)?.value + arrayExpr.subscriptExpression ) // should this be integer instead of long? - arrayExpr = arrayExpr.arrayExpression as? SubscriptExpression - assertNotNull(arrayExpr) + arrayExpr = arrayExpr.arrayExpression + assertIs(arrayExpr) assertLocalName("5", arrayExpr) - assertEquals( + assertLiteralValue( 5L, - (arrayExpr.subscriptExpression as? Literal<*>)?.value + arrayExpr.subscriptExpression ) // should this be integer instead of long? - var memberExpression = arrayExpr.arrayExpression as? MemberExpression - assertNotNull(memberExpression) + var memberExpression = arrayExpr.arrayExpression + assertIs(memberExpression) assertLocalName("field_1", memberExpression) - memberExpression = memberExpression.base as? MemberExpression - assertNotNull(memberExpression) + memberExpression = memberExpression.base + assertIs(memberExpression) assertLocalName("field_2", memberExpression) - arrayExpr = memberExpression.base as? SubscriptExpression - assertNotNull(arrayExpr) + arrayExpr = memberExpression.base + assertIs(arrayExpr) assertLocalName("1", arrayExpr) - assertEquals( + assertLiteralValue( 1L, - (arrayExpr.subscriptExpression as? Literal<*>)?.value + arrayExpr.subscriptExpression ) // should this be integer instead of long? - val ref = arrayExpr.arrayExpression as? Reference - assertNotNull(ref) + val ref = arrayExpr.arrayExpression + assertIs(ref) assertLocalName("s", ref) - assertSame(s, ref.refersTo) + assertRefersTo(ref, s) } @Test - fun testSwitchCase() { // TODO: Update the test + fun testSwitchCase() { val topLevel = Path.of("src", "test", "resources", "llvm") val tu = analyzeAndGetFirstTU( @@ -231,17 +236,17 @@ class LLVMIRLanguageFrontendTest { val onzeroLabel = main.labels.getOrNull(0) assertNotNull(onzeroLabel) assertLocalName("onzero", onzeroLabel) - assertTrue(onzeroLabel.subStatement is Block) + assertIs(onzeroLabel.subStatement) val ononeLabel = main.labels.getOrNull(1) assertNotNull(ononeLabel) assertLocalName("onone", ononeLabel) - assertTrue(ononeLabel.subStatement is Block) + assertIs(ononeLabel.subStatement) val defaultLabel = main.labels.getOrNull(2) assertNotNull(defaultLabel) assertLocalName("otherwise", defaultLabel) - assertTrue(defaultLabel.subStatement is Block) + assertIs(defaultLabel.subStatement) // Check that the type of %a is i32 val a = main.variables["a"] @@ -254,21 +259,24 @@ class LLVMIRLanguageFrontendTest { assertNotNull(switchStatement) // Check that we have switch(a) - assertSame(a, (switchStatement.selector as Reference).refersTo) + assertRefersTo(switchStatement.selector, a) - val cases = switchStatement.statement as Block + val cases = switchStatement.statement + assertIs(cases) // Check that the first case is case 0 -> goto onzero and that the BB is inlined - val case1 = cases.statements[0] as CaseStatement - assertEquals(0L, (case1.caseExpression as Literal<*>).value as Long) + val case1 = cases.statements[0] + assertIs(case1) + assertLiteralValue(0L, case1.caseExpression) assertSame(onzeroLabel.subStatement, cases.statements[1]) // Check that the second case is case 1 -> goto onone and that the BB is inlined - val case2 = cases.statements[2] as CaseStatement - assertEquals(1L, (case2.caseExpression as Literal<*>).value as Long) + val case2 = cases.statements[2] + assertIs(case2) + assertLiteralValue(1L, case2.caseExpression) assertSame(ononeLabel.subStatement, cases.statements[3]) // Check that the default location is inlined val defaultStatement = cases.statements[4] as? DefaultStatement - assertNotNull(defaultStatement) + assertIs(defaultStatement) assertSame(defaultLabel.subStatement, cases.statements[5]) } @@ -288,60 +296,75 @@ class LLVMIRLanguageFrontendTest { // Test that the types and values of the comparison expression are correct val icmpStatement = main.bodyOrNull(1) assertNotNull(icmpStatement) - val variableDecl = icmpStatement.declarations[0] as VariableDeclaration - val comparison = variableDecl.initializer as BinaryOperator + val variableDecl = icmpStatement.declarations[0] + assertIs(variableDecl) + val comparison = variableDecl.initializer + assertIs(comparison) assertEquals("==", comparison.operatorCode) - val rhs = (comparison.rhs as Literal<*>) - val lhs = (comparison.lhs as Reference).refersTo as VariableDeclaration - assertEquals(10L, (rhs.value as Long)) + val rhs = comparison.rhs + assertIs>(rhs) + assertLiteralValue(10L, rhs) assertEquals(tu.primitiveType("i32"), rhs.type) - assertLocalName("x", comparison.lhs as Reference) - assertLocalName("x", lhs) - assertEquals(tu.primitiveType("i32"), lhs.type) + val lhsRef = comparison.lhs + assertIs(lhsRef) + assertLocalName("x", lhsRef) + val lhsDeclaration = lhsRef.refersTo + assertIs(lhsDeclaration) + assertLocalName("x", lhsDeclaration) + assertSame(tu.primitiveType("i32"), lhsDeclaration.type) // Check that the jump targets are set correctly val ifStatement = main.ifs.firstOrNull() assertNotNull(ifStatement) - assertEquals("IfUnequal", (ifStatement.elseStatement!! as GotoStatement).labelName) - val ifBranch = (ifStatement.thenStatement as Block) + val elseStatement = ifStatement.elseStatement + assertIs(elseStatement) + assertEquals("IfUnequal", elseStatement.labelName) + val thenBranch = ifStatement.thenStatement + assertIs(thenBranch) // Check that the condition is set correctly val ifCondition = ifStatement.condition - assertSame(variableDecl, (ifCondition as Reference).refersTo) + assertRefersTo(ifCondition, variableDecl) - val elseBranch = - (ifStatement.elseStatement!! as GotoStatement).targetLabel?.subStatement as Block + val elseBranch = elseStatement.targetLabel?.subStatement + assertIs(elseBranch) assertEquals(2, elseBranch.statements.size) assertEquals(" %y = mul i32 %x, 32768", elseBranch.statements[0].code) assertEquals(" ret i32 %y", elseBranch.statements[1].code) // Check that it's the correct then-branch - assertEquals(2, ifBranch.statements.size) - assertEquals(" %condUnsigned = icmp ugt i32 %x, -3", ifBranch.statements[0].code) - - val ifBranchVariableDecl = - (ifBranch.statements[0] as DeclarationStatement).declarations[0] as VariableDeclaration - val ifBranchComp = ifBranchVariableDecl.initializer as BinaryOperator + assertEquals(2, thenBranch.statements.size) + assertEquals(" %condUnsigned = icmp ugt i32 %x, -3", thenBranch.statements[0].code) + + val ifBranchDeclarationStatement = thenBranch.statements[0] + assertIs(ifBranchDeclarationStatement) + val ifBranchVariableDeclaration = ifBranchDeclarationStatement.declarations[0] + assertIs(ifBranchVariableDeclaration) + val ifBranchComp = ifBranchVariableDeclaration.initializer + assertIs(ifBranchComp) assertEquals(">", ifBranchComp.operatorCode) - assertTrue(ifBranchComp.rhs is CastExpression) - assertTrue(ifBranchComp.lhs is CastExpression) + assertIs(ifBranchComp.rhs) + assertIs(ifBranchComp.lhs) - val ifBranchCompRhs = ifBranchComp.rhs as CastExpression + val ifBranchCompRhs = ifBranchComp.rhs + assertIs(ifBranchCompRhs) assertEquals(tu.objectType("ui32"), ifBranchCompRhs.castType) assertEquals(tu.objectType("ui32"), ifBranchCompRhs.type) - val ifBranchCompLhs = ifBranchComp.lhs as CastExpression + val ifBranchCompLhs = ifBranchComp.lhs + assertIs(ifBranchCompLhs) assertEquals(tu.objectType("ui32"), ifBranchCompLhs.castType) assertEquals(tu.objectType("ui32"), ifBranchCompLhs.type) - val declRefExpr = ifBranchCompLhs.expression as Reference - assertEquals(-3, ((ifBranchCompRhs.expression as Literal<*>).value as Long)) + val declRefExpr = ifBranchCompLhs.expression + assertIs(declRefExpr) assertLocalName("x", declRefExpr) - // TODO: declRefExpr.refersTo is null. Is that expected/intended? + assertLiteralValue(-3L, ifBranchCompRhs.expression) + assertNotNull(declRefExpr.refersTo) - val ifBranchSecondStatement = ifBranch.statements[1] as? IfStatement - assertNotNull(ifBranchSecondStatement) - val ifRet = ifBranchSecondStatement.thenStatement as? Block - assertNotNull(ifRet) + val ifBranchSecondStatement = thenBranch.statements[1] + assertIs(ifBranchSecondStatement) + val ifRet = ifBranchSecondStatement.thenStatement + assertIs(ifRet) assertEquals(1, ifRet.statements.size) assertEquals(" ret i32 1", ifRet.statements[0].code) } @@ -365,25 +388,34 @@ class LLVMIRLanguageFrontendTest { assertNotNull(atomicrmwStatement) // Check that the value is assigned to - val decl = (atomicrmwStatement.statements[0].declarations[0] as VariableDeclaration) - assertLocalName("old", decl) - assertLocalName("i32", decl.type) - assertEquals("*", (decl.initializer as UnaryOperator).operatorCode) - assertLocalName("ptr", (decl.initializer as UnaryOperator).input) + val declaration = atomicrmwStatement.statements[0].declarations[0] + assertIs(declaration) + assertLocalName("old", declaration) + assertLocalName("i32", declaration.type) + val initializer = declaration.initializer + assertIs(initializer) + assertEquals("*", initializer.operatorCode) + assertLocalName("ptr", initializer.input) // Check that the replacement equals *ptr = *ptr + 1 - val replacement = (atomicrmwStatement.statements[1] as AssignExpression) + val replacement = atomicrmwStatement.statements[1] + assertIs(replacement) assertEquals(1, replacement.lhs.size) assertEquals(1, replacement.rhs.size) assertEquals("=", replacement.operatorCode) - assertEquals("*", (replacement.lhs.first() as UnaryOperator).operatorCode) - assertLocalName("ptr", (replacement.lhs.first() as UnaryOperator).input) + val replacementLhs = replacement.lhs.first() + assertIs(replacementLhs) + assertEquals("*", replacementLhs.operatorCode) + assertLocalName("ptr", replacementLhs.input) // Check that the rhs is equal to *ptr + 1 - val add = replacement.rhs.first() as BinaryOperator + val add = replacement.rhs.first() + assertIs(add) assertEquals("+", add.operatorCode) - assertEquals("*", (add.lhs as UnaryOperator).operatorCode) - assertLocalName("ptr", (add.lhs as UnaryOperator).input) - assertEquals(1L, (add.rhs as Literal<*>).value as Long) + val addLhs = add.lhs + assertIs(addLhs) + assertEquals("*", addLhs.operatorCode) + assertLocalName("ptr", addLhs.input) + assertLiteralValue(1L, add.rhs) } @Test @@ -407,38 +439,53 @@ class LLVMIRLanguageFrontendTest { // Check that the first statement is "literal_i32_i1 val_success = literal_i32_i1(*ptr, *ptr // == 5)" - val decl = (cmpxchgStatement.statements[0].declarations[0] as VariableDeclaration) - assertLocalName("val_success", decl) - assertLocalName("literal_i32_i1", decl.type) + val declaration = cmpxchgStatement.statements[0].declarations[0] + assertIs(declaration) + assertLocalName("val_success", declaration) + assertLocalName("literal_i32_i1", declaration.type) // Check that the first value is *ptr - val value1 = (decl.initializer as ConstructExpression).arguments[0] as UnaryOperator + val declarationInitializer = declaration.initializer + assertIs(declarationInitializer) + val value1 = declarationInitializer.arguments[0] + assertIs(value1) assertEquals("*", value1.operatorCode) assertLocalName("ptr", value1.input) // Check that the first value is *ptr == 5 - val value2 = (decl.initializer as ConstructExpression).arguments[1] as BinaryOperator + val value2 = declarationInitializer.arguments[1] + assertIs(value2) assertEquals("==", value2.operatorCode) - assertEquals("*", (value2.lhs as UnaryOperator).operatorCode) - assertLocalName("ptr", (value2.lhs as UnaryOperator).input) - assertEquals(5L, (value2.rhs as Literal<*>).value as Long) - - val ifStatement = cmpxchgStatement.statements[1] as IfStatement + val value2Lhs = value2.lhs + assertIs(value2Lhs) + assertEquals("*", value2Lhs.operatorCode) + assertLocalName("ptr", value2Lhs.input) + assertLiteralValue(5L, value2.rhs) + + val ifStatement = cmpxchgStatement.statements[1] + assertIs(ifStatement) // The condition is the same as the second value above - val ifExpr = ifStatement.condition as BinaryOperator + val ifExpr = ifStatement.condition + assertIs(ifExpr) assertEquals("==", ifExpr.operatorCode) - assertEquals("*", (ifExpr.lhs as UnaryOperator).operatorCode) - assertLocalName("ptr", (ifExpr.lhs as UnaryOperator).input) - assertEquals(5L, (ifExpr.rhs as Literal<*>).value as Long) - - val thenExpr = ifStatement.thenStatement as AssignExpression + val ifExprLhs = ifExpr.lhs + assertIs(ifExprLhs) + assertEquals("*", ifExprLhs.operatorCode) + assertLocalName("ptr", ifExprLhs.input) + assertLiteralValue(5L, ifExpr.rhs) + + val thenExpr = ifStatement.thenStatement + assertIs(thenExpr) assertEquals(1, thenExpr.lhs.size) assertEquals(1, thenExpr.rhs.size) assertEquals("=", thenExpr.operatorCode) - assertEquals("*", (thenExpr.lhs.first() as UnaryOperator).operatorCode) - assertLocalName("ptr", (thenExpr.lhs.first() as UnaryOperator).input) - assertLocalName("old", thenExpr.rhs.first() as Reference) - assertLocalName("old", (thenExpr.rhs.first() as Reference).refersTo) + val thenExprLhs = thenExpr.lhs.first() + assertIs(thenExprLhs) + assertEquals("*", thenExprLhs.operatorCode) + assertLocalName("ptr", thenExprLhs.input) + assertIs(thenExpr.rhs.first()) + assertLocalName("old", thenExpr.rhs.first()) + assertRefersTo(thenExpr.rhs.first(), tu.variables["old"]) } @Test @@ -456,13 +503,15 @@ class LLVMIRLanguageFrontendTest { val foo = tu.functions["foo"] assertNotNull(foo) - val decl = foo.variables["value_loaded"] - assertNotNull(decl) - assertLocalName("i1", decl.type) + val declaration = foo.variables["value_loaded"] + assertNotNull(declaration) + assertLocalName("i1", declaration.type) - assertLocalName("val_success", (decl.initializer as MemberExpression).base) - assertEquals(".", (decl.initializer as MemberExpression).operatorCode) - assertLocalName("field_1", decl.initializer as MemberExpression) + val initializer = declaration.initializer + assertIs(initializer) + assertLocalName("val_success", initializer.base) + assertEquals(".", initializer.operatorCode) + assertLocalName("field_1", initializer) } @Test @@ -481,28 +530,26 @@ class LLVMIRLanguageFrontendTest { val foo = tu.functions["foo"] assertNotNull(foo) - assertEquals("literal_i32_i8", foo.type.typeName) + val fooType = foo.type + assertIs(fooType) + assertEquals("literal_i32_i8", fooType.typeName) - val record = (foo.type as? ObjectType)?.recordDeclaration + val record = fooType.recordDeclaration assertNotNull(record) assertEquals(2, record.fields.size) - val returnStatement = foo.bodyOrNull(0) + val returnStatement = foo.returns.singleOrNull() assertNotNull(returnStatement) - val construct = returnStatement.returnValue as? ConstructExpression - assertNotNull(construct) + val construct = returnStatement.returnValue + assertIs(construct) assertEquals(2, construct.arguments.size) - var arg = construct.arguments.getOrNull(0) as? Literal<*> - assertNotNull(arg) - assertEquals("i32", arg.type.typeName) - assertEquals(4L, arg.value) + assertEquals("i32", construct.arguments[0].type.typeName) + assertLiteralValue(4L, construct.arguments[0]) - arg = construct.arguments.getOrNull(1) as? Literal<*> - assertNotNull(arg) - assertEquals("i8", arg.type.typeName) - assertEquals(2L, arg.value) + assertEquals("i8", construct.arguments[1].type.typeName) + assertLiteralValue(2L, construct.arguments[1]) } @Test @@ -534,26 +581,30 @@ class LLVMIRLanguageFrontendTest { assertNotNull(loadXStatement) assertLocalName("locX", loadXStatement.singleDeclaration) - val initXOp = - (loadXStatement.singleDeclaration as VariableDeclaration).initializer as UnaryOperator + val initXOpDeclaration = loadXStatement.singleDeclaration + assertIs(initXOpDeclaration) + val initXOp = initXOpDeclaration.initializer + assertIs(initXOp) assertEquals("*", initXOp.operatorCode) - var ref = initXOp.input as? Reference - assertNotNull(ref) + var ref = initXOp.input + assertIs(ref) assertLocalName("x", ref) - assertSame(globalX, ref.refersTo) + assertRefersTo(ref, globalX) val loadAStatement = main.bodyOrNull(2) assertNotNull(loadAStatement) + val loadADeclaration = loadAStatement.singleDeclaration + assertIs(loadADeclaration) assertLocalName("locA", loadAStatement.singleDeclaration) - val initAOp = - (loadAStatement.singleDeclaration as VariableDeclaration).initializer as UnaryOperator + val initAOp = loadADeclaration.initializer + assertIs(initAOp) assertEquals("*", initAOp.operatorCode) - ref = initAOp.input as? Reference - assertNotNull(ref) + ref = initAOp.input + assertIs(ref) assertLocalName("a", ref) - assertSame(globalA, ref.refersTo) + assertRefersTo(ref, globalA) } @Test @@ -570,11 +621,11 @@ class LLVMIRLanguageFrontendTest { assertNotNull(main) // %ptr = alloca i32 - val ptr = main.bodyOrNull()?.singleDeclaration as? VariableDeclaration - assertNotNull(ptr) + val ptr = main.bodyOrNull()?.singleDeclaration + assertIs(ptr) - val alloca = ptr.initializer as? NewArrayExpression - assertNotNull(alloca) + val alloca = ptr.initializer + assertIs(alloca) assertEquals("i32*", alloca.type.typeName) // store i32 3, i32* %ptr @@ -583,16 +634,16 @@ class LLVMIRLanguageFrontendTest { assertEquals("=", store.operatorCode) assertEquals(1, store.lhs.size) - val dereferencePtr = store.lhs.first() as? UnaryOperator - assertNotNull(dereferencePtr) + val dereferencePtr = store.lhs.firstOrNull() + assertIs(dereferencePtr) assertEquals("*", dereferencePtr.operatorCode) assertEquals("i32", dereferencePtr.type.typeName) - assertSame(ptr, (dereferencePtr.input as? Reference)?.refersTo) + assertRefersTo(dereferencePtr.input, ptr) assertEquals(1, store.rhs.size) - val value = store.rhs.first() as? Literal<*> - assertNotNull(value) - assertEquals(3L, value.value) + val value = store.rhs.firstOrNull() + assertIs>(value) + assertLiteralValue(3L, value) assertEquals("i32", value.type.typeName) } @@ -614,17 +665,22 @@ class LLVMIRLanguageFrontendTest { assertNotNull(foo) assertEquals("literal_i32_i8", foo.type.typeName) - val record = (foo.type as? ObjectType)?.recordDeclaration + val fooType = foo.type + assertIs(fooType) + val record = fooType.recordDeclaration assertNotNull(record) assertEquals(2, record.fields.size) - val declStatement = foo.bodyOrNull() - assertNotNull(declStatement) + val declarationStatement = foo.bodyOrNull() + assertNotNull(declarationStatement) - val varDecl = declStatement.singleDeclaration as VariableDeclaration - assertLocalName("a", varDecl) - assertEquals("literal_i32_i8", varDecl.type.typeName) - val args = (varDecl.initializer as ConstructExpression).arguments + val varDeclaration = declarationStatement.singleDeclaration + assertIs(varDeclaration) + assertLocalName("a", varDeclaration) + assertEquals("literal_i32_i8", varDeclaration.type.typeName) + val initializer = varDeclaration.initializer + assertIs(initializer) + val args = initializer.arguments assertEquals(2, args.size) assertLiteralValue(100L, args[0]) assertLiteralValue(null, args[1]) @@ -645,10 +701,12 @@ class LLVMIRLanguageFrontendTest { assertEquals("=", assign.operatorCode) assertEquals(1, assign.lhs.size) assertEquals(1, assign.rhs.size) - assertLocalName("b", (assign.lhs.first() as MemberExpression).base) - assertEquals(".", (assign.lhs.first() as MemberExpression).operatorCode) - assertLocalName("field_1", assign.lhs.first() as MemberExpression) - assertEquals(7L, (assign.rhs.first() as Literal<*>).value as Long) + val assignLhs = assign.lhs.first() + assertIs(assignLhs) + assertLocalName("b", assignLhs.base) + assertEquals(".", assignLhs.operatorCode) + assertLocalName("field_1", assignLhs) + assertLiteralValue(7L, assign.rhs.first()) } @Test @@ -668,41 +726,44 @@ class LLVMIRLanguageFrontendTest { val main = tu.functions["main"] assertNotNull(main) - val mainBody = main.body as Block - val tryStatement = mainBody.statements[0] as? TryStatement - assertNotNull(tryStatement) + val mainBody = main.body + assertIs(mainBody) + val tryStatement = mainBody.statements[0] + assertIs(tryStatement) // Check the assignment of the function call - val resDecl = - (tryStatement.tryBlock?.statements?.get(0) as? DeclarationStatement)?.singleDeclaration - as? VariableDeclaration - assertNotNull(resDecl) - assertLocalName("res", resDecl) - val call = resDecl.initializer as? CallExpression - assertNotNull(call) + val resDeclarationStatement = tryStatement.tryBlock?.statements?.get(0) + assertIs(resDeclarationStatement) + val resDeclaration = resDeclarationStatement.singleDeclaration + assertIs(resDeclaration) + assertLocalName("res", resDeclaration) + val call = resDeclaration.initializer + assertIs(call) assertLocalName("throwingFoo", call) - assertTrue(call.invokes.contains(throwingFoo)) + assertContains(call.invokes, throwingFoo) assertEquals(0, call.arguments.size) // Check that the second part of the try-block is inlined by the pass - val aDecl = - (tryStatement.tryBlock?.statements?.get(1) as? DeclarationStatement)?.singleDeclaration - as? VariableDeclaration - assertNotNull(aDecl) - assertLocalName("a", aDecl) - val resStatement = tryStatement.tryBlock?.statements?.get(2) as? ReturnStatement - assertNotNull(resStatement) + val aDeclarationStatement = tryStatement.tryBlock?.statements?.get(1) + assertIs(aDeclarationStatement) + val aDeclaration = aDeclarationStatement.singleDeclaration + assertIs(aDeclaration) + assertLocalName("a", aDeclaration) + val resStatement = tryStatement.tryBlock?.statements?.get(2) + assertIs(resStatement) // Check that the catch block is inlined by the pass assertEquals(1, tryStatement.catchClauses.size) assertEquals(5, tryStatement.catchClauses[0].body?.statements?.size) assertLocalName("_ZTIi | ...", tryStatement.catchClauses[0]) - val ifStatement = tryStatement.catchClauses[0].body?.statements?.get(4) as? IfStatement - assertNotNull(ifStatement) - assertTrue(ifStatement.thenStatement is Block) - assertEquals(4, (ifStatement.thenStatement as Block).statements.size) - assertTrue(ifStatement.elseStatement is Block) - assertEquals(1, (ifStatement.elseStatement as Block).statements.size) + val ifStatement = tryStatement.catchClauses[0].body?.statements?.get(4) + assertIs(ifStatement) + val thenStatement = ifStatement.thenStatement + assertIs(thenStatement) + assertEquals(4, thenStatement.statements.size) + val elseStatement = ifStatement.elseStatement + assertIs(elseStatement) + assertEquals(1, elseStatement.statements.size) } @Test @@ -726,48 +787,52 @@ class LLVMIRLanguageFrontendTest { val main = tu.functions["main"] assertNotNull(main) - val mainBody = main.body as Block - val yDecl = - (mainBody.statements[0] as DeclarationStatement).singleDeclaration - as VariableDeclaration - assertNotNull(yDecl) + val mainBody = main.body + assertIs(mainBody) + val yDeclarationStatement = mainBody.statements[0] + assertIs(yDeclarationStatement) + val yDecl = yDeclarationStatement.singleDeclaration + assertIs(yDecl) - val ifStatement = mainBody.statements[3] as? IfStatement - assertNotNull(ifStatement) + val ifStatement = mainBody.statements[3] + assertIs(ifStatement) - val thenStmt = ifStatement.thenStatement as? Block - assertNotNull(thenStmt) + val thenStmt = ifStatement.thenStatement + assertIs(thenStmt) assertEquals(3, thenStmt.statements.size) - assertNotNull(thenStmt.statements[1] as? AssignExpression) - val aDecl = - (thenStmt.statements[0] as DeclarationStatement).singleDeclaration - as VariableDeclaration - val thenY = thenStmt.statements[1] as AssignExpression + val aDeclarationStatement = thenStmt.statements[0] + assertIs(aDeclarationStatement) + val aDecl = aDeclarationStatement.singleDeclaration + assertIs(aDecl) + val thenY = thenStmt.statements[1] + assertIs(thenY) assertEquals(1, thenY.lhs.size) assertEquals(1, thenY.rhs.size) - assertSame(aDecl, (thenY.rhs.first() as Reference).refersTo) - assertSame(yDecl, (thenY.lhs.first() as Reference).refersTo) + assertRefersTo(thenY.rhs.first(), aDecl) + assertRefersTo(thenY.lhs.first(), yDecl) - val elseStmt = ifStatement.elseStatement as? Block - assertNotNull(elseStmt) + val elseStmt = ifStatement.elseStatement + assertIs(elseStmt) assertEquals(3, elseStmt.statements.size) - val bDecl = - (elseStmt.statements[0] as DeclarationStatement).singleDeclaration - as VariableDeclaration - assertNotNull(elseStmt.statements[1] as? AssignExpression) - val elseY = elseStmt.statements[1] as AssignExpression + val bDeclarationStatement = elseStmt.statements[0] + assertIs(bDeclarationStatement) + val bDecl = bDeclarationStatement.singleDeclaration + assertIs(bDecl) + val elseY = elseStmt.statements[1] + assertIs(elseY) assertEquals(1, elseY.lhs.size) assertEquals(1, elseY.lhs.size) - assertSame(bDecl, (elseY.rhs.first() as Reference).refersTo) - assertSame(yDecl, (elseY.lhs.first() as Reference).refersTo) - - val continueBlock = - (thenStmt.statements[2] as? GotoStatement)?.targetLabel?.subStatement as? Block - assertNotNull(continueBlock) - assertEquals( - yDecl, - ((continueBlock.statements[1] as ReturnStatement).returnValue as Reference).refersTo - ) + assertRefersTo(elseY.rhs.first(), bDecl) + assertRefersTo(elseY.lhs.first(), yDecl) + + val gotoStatement = thenStmt.statements[2] + assertIs(gotoStatement) + val continueBlock = gotoStatement.targetLabel?.subStatement + assertIs(continueBlock) + val returnStatement = continueBlock.statements[1] + assertIs(returnStatement) + assertIs(returnStatement.returnValue) + assertRefersTo(returnStatement.returnValue, yDecl) } @Test @@ -781,99 +846,85 @@ class LLVMIRLanguageFrontendTest { assertNotNull(main) // Test that x is initialized correctly - val mainBody = main.body as Block - val origX = - ((mainBody.statements[0] as? DeclarationStatement)?.singleDeclaration - as? VariableDeclaration) - val xInit = origX?.initializer as? InitializerListExpression - assertNotNull(xInit) - assertEquals(10L, (xInit.initializers[0] as? Literal<*>)?.value) - assertEquals(9L, (xInit.initializers[1] as? Literal<*>)?.value) - assertEquals(6L, (xInit.initializers[2] as? Literal<*>)?.value) - assertEquals(-100L, (xInit.initializers[3] as? Literal<*>)?.value) + val mainBody = main.body + assertIs(mainBody) + val xDeclarationStatement = mainBody.statements[0] + assertIs(xDeclarationStatement) + val origX = xDeclarationStatement.singleDeclaration + assertIs(origX) + val xInit = origX.initializer + assertIs(xInit) + assertLiteralValue(10L, xInit.initializers[0]) + assertLiteralValue(9L, xInit.initializers[1]) + assertLiteralValue(6L, xInit.initializers[2]) + assertLiteralValue(-100L, xInit.initializers[3]) // Test that y is initialized correctly - val origY = - ((mainBody.statements[1] as? DeclarationStatement)?.singleDeclaration - as? VariableDeclaration) - val yInit = origY?.initializer as? InitializerListExpression - assertNotNull(yInit) - assertEquals(15L, (yInit.initializers[0] as? Literal<*>)?.value) - assertEquals(34L, (yInit.initializers[1] as? Literal<*>)?.value) - assertEquals(99L, (yInit.initializers[2] as? Literal<*>)?.value) - assertEquals(1000L, (yInit.initializers[3] as? Literal<*>)?.value) + + val yDeclarationStatement = mainBody.statements[1] + assertIs(yDeclarationStatement) + val origY = yDeclarationStatement.singleDeclaration + assertIs(origY) + val yInit = origY.initializer + assertIs(yInit) + assertLiteralValue(15L, yInit.initializers[0]) + assertLiteralValue(34L, yInit.initializers[1]) + assertLiteralValue(99L, yInit.initializers[2]) + assertLiteralValue(1000L, yInit.initializers[3]) // Test that extractelement works - val zInit = - ((mainBody.statements[2] as? DeclarationStatement)?.singleDeclaration - as? VariableDeclaration) - ?.initializer as? SubscriptExpression - assertNotNull(zInit) - assertEquals(0L, (zInit.subscriptExpression as? Literal<*>)?.value) - assertEquals("x", (zInit.arrayExpression as? Reference)?.name?.localName) - assertSame(origX, (zInit.arrayExpression as? Reference)?.refersTo) + val zDeclarationStatement = mainBody.statements[2] + assertIs(zDeclarationStatement) + val origZ = zDeclarationStatement.singleDeclaration + assertIs(origZ) + val zInit = origZ.initializer + assertIs(zInit) + assertLiteralValue(0L, zInit.subscriptExpression) + assertLocalName("x", zInit.arrayExpression) + assertRefersTo(zInit.arrayExpression, origX) // Test the assignment of y to yMod - val yModInit = - ((mainBody.statements[3] as Block).statements[0] as? DeclarationStatement) - ?.singleDeclaration as? VariableDeclaration - assertNotNull(yModInit) - assertEquals("y", (yModInit.initializer as? Reference)?.name?.localName) - assertSame(origY, (yModInit.initializer as? Reference)?.refersTo) + val yModDeclarationStatementBlock = mainBody.statements[3] + assertIs(yModDeclarationStatementBlock) + val yModDeclarationStatement = yModDeclarationStatementBlock.statements[0] + assertIs(yModDeclarationStatement) + val modY = yModDeclarationStatement.singleDeclaration + assertIs(modY) + val yModInit = modY.initializer + assertIs(yModInit) + assertLocalName("y", yModInit) + assertRefersTo(yModInit, origY) + // Now, test the modification of yMod[3] = 8 - val yMod = ((mainBody.statements[3] as Block).statements[1] as? AssignExpression) - assertNotNull(yMod) + val yMod = yModDeclarationStatementBlock.statements[1] + assertIs(yMod) assertEquals(1, yMod.lhs.size) assertEquals(1, yMod.rhs.size) - assertEquals( - 3L, - ((yMod.lhs.first() as? SubscriptExpression)?.subscriptExpression as? Literal<*>)?.value - ) - assertSame( - yModInit, - ((yMod.lhs.first() as? SubscriptExpression)?.arrayExpression as? Reference)?.refersTo - ) - assertEquals(8L, (yMod.rhs.first() as? Literal<*>)?.value) + val yModLhs = yMod.lhs.first() + assertIs(yModLhs) + assertLiteralValue(3L, yModLhs.subscriptExpression) + assertRefersTo(yModLhs.arrayExpression, modY) + assertLiteralValue(8L, yMod.rhs.first()) // Test the last shufflevector instruction which does not contain constant as initializers. - val shuffledInit = - ((mainBody.statements[4] as? DeclarationStatement)?.singleDeclaration - as? VariableDeclaration) - ?.initializer as? InitializerListExpression - assertNotNull(shuffledInit) - assertSame( - origX, - ((shuffledInit.initializers[0] as? SubscriptExpression)?.arrayExpression as? Reference) - ?.refersTo - ) - assertSame( - yModInit, - ((shuffledInit.initializers[1] as? SubscriptExpression)?.arrayExpression as? Reference) - ?.refersTo - ) - assertSame( - yModInit, - ((shuffledInit.initializers[2] as? SubscriptExpression)?.arrayExpression as? Reference) - ?.refersTo - ) - assertSame( - 1, - ((shuffledInit.initializers[0] as? SubscriptExpression)?.subscriptExpression - as? Literal<*>) - ?.value - ) - assertSame( - 2, - ((shuffledInit.initializers[1] as? SubscriptExpression)?.subscriptExpression - as? Literal<*>) - ?.value - ) - assertSame( - 3, - ((shuffledInit.initializers[2] as? SubscriptExpression)?.subscriptExpression - as? Literal<*>) - ?.value - ) + val shuffledInitDeclarationStatement = mainBody.statements[4] + assertIs(shuffledInitDeclarationStatement) + val shuffledInitDeclaration = shuffledInitDeclarationStatement.singleDeclaration + assertIs(shuffledInitDeclaration) + val shuffledInit = shuffledInitDeclaration.initializer + assertIs(shuffledInit) + val shuffledInit0 = shuffledInit.initializers[0] + assertIs(shuffledInit0) + val shuffledInit1 = shuffledInit.initializers[1] + assertIs(shuffledInit1) + val shuffledInit2 = shuffledInit.initializers[2] + assertIs(shuffledInit2) + assertRefersTo(shuffledInit0.arrayExpression, origX) + assertRefersTo(shuffledInit1.arrayExpression, modY) + assertRefersTo(shuffledInit2.arrayExpression, modY) + assertLiteralValue(1, shuffledInit0.subscriptExpression) + assertLiteralValue(2, shuffledInit1.subscriptExpression) + assertLiteralValue(3, shuffledInit2.subscriptExpression) } @Test @@ -887,19 +938,20 @@ class LLVMIRLanguageFrontendTest { assertNotNull(main) // Test that x is initialized correctly - val mainBody = main.body as Block + val mainBody = main.body + assertIs(mainBody) - val fenceCall = mainBody.statements[0] as? CallExpression - assertNotNull(fenceCall) + val fenceCall = mainBody.statements[0] + assertIs(fenceCall) assertEquals(1, fenceCall.arguments.size) - assertEquals(2, (fenceCall.arguments[0] as Literal<*>).value) + assertLiteralValue(2, fenceCall.arguments[0]) - val fenceCallScope = mainBody.statements[2] as? CallExpression - assertNotNull(fenceCallScope) + val fenceCallScope = mainBody.statements[2] + assertIs(fenceCallScope) assertEquals(2, fenceCallScope.arguments.size) // TODO: This doesn't match but it doesn't seem to be our mistake // assertEquals(5, (fenceCallScope.arguments[0] as Literal<*>).value) - assertEquals("scope", (fenceCallScope.arguments[1] as Literal<*>).value) + assertLiteralValue("scope", fenceCallScope.arguments[1]) } @Test @@ -917,26 +969,25 @@ class LLVMIRLanguageFrontendTest { val funcF = tu.functions["f"] assertNotNull(funcF) - val tryStatement = - (funcF.bodyOrNull(0)?.subStatement as? Block) - ?.statements - ?.firstOrNull { s -> s is TryStatement } - assertIs(tryStatement) - assertEquals(2, tryStatement.tryBlock?.statements?.size) - assertFullName( - "_CxxThrowException", - tryStatement.tryBlock?.statements?.get(0) as? CallExpression - ) - assertLocalName( - "end", - (tryStatement.tryBlock?.statements?.get(1) as? GotoStatement)?.targetLabel - ) + val tryStatement = funcF.bodyOrNull(0)?.subStatement?.trys?.firstOrNull() + assertNotNull(tryStatement) + val tryBlock = tryStatement.tryBlock + assertNotNull(tryBlock) + assertEquals(2, tryBlock.statements.size) + assertIs(tryBlock.statements[0]) + assertFullName("_CxxThrowException", tryBlock.statements[0]) + val gotoStatement = tryBlock.statements[1] + assertIs(gotoStatement) + assertLocalName("end", gotoStatement.targetLabel) assertEquals(1, tryStatement.catchClauses.size) - val catchSwitchExpr = tryStatement.catchClauses[0].body?.statements?.get(0) + val catchBody = tryStatement.catchClauses[0].body + assertNotNull(catchBody) + val catchSwitchExpr = catchBody.statements[0] assertIs(catchSwitchExpr) - val catchswitchCall = - (catchSwitchExpr.singleDeclaration as? VariableDeclaration)?.initializer + val catchSwitchDeclaration = catchSwitchExpr.singleDeclaration + assertIs(catchSwitchDeclaration) + val catchswitchCall = catchSwitchDeclaration.initializer assertIs(catchswitchCall) assertFullName("llvm.catchswitch", catchswitchCall) val ifExceptionMatches = tryStatement.catchClauses[0].body?.statements?.get(1) @@ -944,47 +995,47 @@ class LLVMIRLanguageFrontendTest { val matchesExceptionCall = ifExceptionMatches.condition assertIs(matchesExceptionCall) assertFullName("llvm.matchesCatchpad", matchesExceptionCall) - assertEquals( - catchSwitchExpr.singleDeclaration, - (matchesExceptionCall.arguments[0] as Reference).refersTo - ) - assertEquals(null, (matchesExceptionCall.arguments[1] as Literal<*>).value) - assertEquals(64L, (matchesExceptionCall.arguments[2] as Literal<*>).value as Long) - assertEquals(null, (matchesExceptionCall.arguments[3] as Literal<*>).value) + assertRefersTo(matchesExceptionCall.arguments[0], catchSwitchDeclaration) + assertLiteralValue(null, matchesExceptionCall.arguments[1]) + assertLiteralValue(64L, matchesExceptionCall.arguments[2]) + assertLiteralValue(null, matchesExceptionCall.arguments[3]) val catchBlock = ifExceptionMatches.thenStatement assertIs(catchBlock) - assertFullName( - "llvm.catchpad", - ((catchBlock.statements[0] as? DeclarationStatement)?.singleDeclaration - as? VariableDeclaration) - ?.initializer as? CallExpression - ) + val catchpadDeclarationStatement = catchBlock.statements[0] + assertIs(catchpadDeclarationStatement) + val catchpadDeclaration = catchpadDeclarationStatement.singleDeclaration + assertIs(catchpadDeclaration) + assertIs(catchpadDeclaration.initializer) + assertFullName("llvm.catchpad", catchpadDeclaration.initializer) val innerTry = catchBlock.statements[1] assertIs(innerTry) - assertFullName( - "_CxxThrowException", - innerTry.tryBlock?.statements?.get(0) as? CallExpression - ) - assertLocalName( - "try.cont", - (innerTry.tryBlock?.statements?.get(1) as? GotoStatement)?.targetLabel - ) - - val innerCatchClause = - (innerTry.catchClauses[0].body?.statements?.get(1) as? IfStatement)?.thenStatement + val innerTryBlock = innerTry.tryBlock + assertNotNull(innerTryBlock) + assertIs(innerTryBlock.statements[0]) + assertFullName("_CxxThrowException", innerTryBlock.statements[0]) + val innerTryGoto = innerTryBlock.statements[1] + assertIs(innerTryGoto) + assertLocalName("try.cont", innerTryGoto.targetLabel) + + val innerCatchBody = innerTry.catchClauses[0].body + assertNotNull(innerCatchBody) + val innerCatchIf = innerCatchBody.statements[1] + assertIs(innerCatchIf) + val innerCatchClause = innerCatchIf.thenStatement assertIs(innerCatchClause) - assertFullName( - "llvm.catchpad", - ((innerCatchClause.statements[0] as? DeclarationStatement)?.singleDeclaration - as? VariableDeclaration) - ?.initializer - ) - assertLocalName("try.cont", (innerCatchClause.statements[1] as? GotoStatement)?.targetLabel) - - val innerCatchThrows = - (innerTry.catchClauses[0].body?.statements?.get(1) as? IfStatement)?.elseStatement + val innerCatchpadDeclarationStatement = innerCatchClause.statements[0] + assertIs(innerCatchpadDeclarationStatement) + val innerCatchDeclaration = innerCatchpadDeclarationStatement.singleDeclaration + assertIs(innerCatchDeclaration) + assertFullName("llvm.catchpad", innerCatchDeclaration.initializer) + + val innerCatchGoto = innerCatchClause.statements[1] + assertIs(innerCatchGoto) + assertLocalName("try.cont", innerCatchGoto.targetLabel) + + val innerCatchThrows = innerCatchIf.elseStatement assertIs(innerCatchThrows) assertNotNull(innerCatchThrows.exception) assertRefersTo(innerCatchThrows.exception, innerTry.catchClauses[0].parameter)