diff --git a/src/main/resources/runtime.c0 b/src/main/resources/runtime.c0 index bcd6738b..c40ea5c3 100644 --- a/src/main/resources/runtime.c0 +++ b/src/main/resources/runtime.c0 @@ -1,188 +1,367 @@ #use #use -int GROW_CAPACITY (int oldCapacity){ - return ((oldCapacity) < 128 ? 128 : (oldCapacity) + 128); -} - -int hash(int index, int arrayLength) +int runtime_hash(int index, int arrayLength) //@requires index >= 0; //@ensures \result >= 0 && \result < arrayLength; { - index = ((index >> 16) ^ index) * 0x45d9f3b; - index = ((index >> 16) ^ index) * 0x45d9f3b; - index = (index >> 16) ^ index; - return index % arrayLength; + index = ((index >> 16) ^ index) * 0x45d9f3b; + index = ((index >> 16) ^ index) * 0x45d9f3b; + index = (index >> 16) ^ index; + return index % arrayLength; } -OwnedFields* initOwnedFields( int * instanceCounter){ +OwnedFields* runtime_init(){ OwnedFields* fields = alloc(OwnedFields); - fields->instanceCounter = instanceCounter; - int oldCapacity = 0; - fields->capacity = GROW_CAPACITY(oldCapacity); + fields->capacity = 128; fields->contents = alloc_array(FieldArray*, fields->capacity); - - for(int i = 0; i < fields->capacity; i += 1){ - fields->contents[i] = NULL; - } return fields; } -void grow(OwnedFields* fields) +int runtime_compact(FieldArray*[] contents, int capacity, int newId) { + // Find first NULL + int start = 0; + while (contents[start] != NULL) + start++; -{ - int oldCapacity = fields->capacity; - fields->capacity = GROW_CAPACITY(oldCapacity); - FieldArray*[] newContents = alloc_array(FieldArray*, fields->capacity); - for(int i = 0; icontents[i] != NULL && !fields->contents[i]->deleted){ - int _id = fields->contents[i]->_id; - int newIndex = hash(_id, fields->capacity); - while(newContents[newIndex] != NULL){ - newIndex = (newIndex + 1) % fields->capacity; + int removed = 0; + int insert = start; + int i = start + 1; + while (i != start) { + FieldArray* entry = contents[i]; + if (entry == NULL) { + // skip + } else if (entry->accessible == 0 && entry->id != newId) { + removed++; + } else { + if (insert != i) { + int psl = entry->psl; + // Calculate the distance between `insert` and `i` + // Allow `i` to wrap around before `insert` does + int distance = insert <= i ? i - insert : capacity + i - insert; + if (psl < distance) { + // Only move up to `psl` spots backwards, otherwise it would be prior + // to the hash location + int newInsert = (insert + distance - psl) % capacity; + while (insert != newInsert) { + contents[insert] = NULL; + insert = (insert + 1) % capacity; + } + + distance = psl; } - newContents[newIndex] = fields->contents[i]; + + if (distance != 0) { + contents[insert] = entry; + entry->psl -= distance; + } + + } + + insert = (insert + 1) % capacity; } + + i = (i + 1) % capacity; } - fields->contents = newContents; + + // `i == start` + // NULL all locations from `insert` up to `start` (which is already NULL) + while (insert != i) { + contents[insert] = NULL; + insert = (insert + 1) % capacity; + } + + return removed; } -FieldArray* find(OwnedFields* fields, int _id){ - if(_id >= 0){ - int index = hash(_id, fields->capacity); - while(fields->contents[index] != NULL){ - if(!fields->contents[index]->deleted && fields->contents[index]->_id == _id){ - return fields->contents[index]; - }else{ - index = (index + 1) % fields->capacity; +// Called whenever a new element is added. Increments the `length` and resizes +// the backing array whenever necessary. +// `newIndex` specifies an index that will not be dropped, even if no fields in +// it are accessible (i.e., if it has just been added to the array). +void runtime_grow(OwnedFields* fields, int newId) +{ + int newLength = fields->length + 1; + fields->length = newLength; + + // Keep the hash table's "load factor" under 80% + // I.e., length is always < 80% of capacity + int capacity = fields->capacity; + if (newLength * 5 >= capacity * 4) { + int GROW_CONST = 128; + FieldArray*[] contents = fields->contents; + + // Check if we can gain enough space just by dropping unused entries + int unused = runtime_compact(contents, capacity, newId); + if (unused != 0) { + fields->length -= unused; + if (unused >= GROW_CONST) + return; + } + + // Otherwise, use a larger array and rehash + int newCapacity = capacity + GROW_CONST; + fields->capacity = newCapacity; + + FieldArray*[] newContents = alloc_array(FieldArray*, newCapacity); + fields->contents = newContents; + + // Add all items into the new array + for (int i = 0; i < capacity; i++) { + FieldArray* entry = contents[i]; + if (entry != NULL) { + int k = runtime_hash(entry->id, newCapacity); + int psl = 0; + while (newContents[k] != NULL) { + FieldArray* existing = newContents[k]; + if (psl > existing->psl) { + entry->psl = psl; + psl = existing->psl; + + newContents[k] = entry; + entry = existing; + } + + k = (k + 1) % newCapacity; + psl++; } + newContents[k] = entry; + entry->psl = psl; } } + } +} + +FieldArray* runtime_find(OwnedFields* fields, int id) { + if (id >= 0) { + FieldArray*[] contents = fields->contents; + int capacity = fields->capacity; + int index = runtime_hash(id, capacity); + int psl = 0; + while (true) { + FieldArray* entry = contents[index]; + if (entry == NULL || entry->psl < psl) + return NULL; + else if (entry->id == id) + return entry; + else + index = (index + 1) % capacity; + psl++; + } + } + return NULL; } -FieldArray * newFieldArray(OwnedFields * fields, int _id, int numFields, bool accAll) - //@requires fields != NULL && _id >= 0 && numFields > 0; - //@ensures fields != NULL && fields->capacity > fields->length; -{ - if(fields->length > (fields->capacity * 4 / 5)) grow(fields); +// Add a new field array with the specified reference ID +// NOTE: This assumes that no field with the same reference ID have been added +FieldArray* runtime_newFieldArray(OwnedFields* fields, int id, int numFields) { + FieldArray*[] contents = fields->contents; + int capacity = fields->capacity; + int index = runtime_hash(id, capacity); + int psl = 0; + FieldArray* current = NULL; + FieldArray* newEntry = NULL; - int fieldIndex = hash(_id, fields->capacity); - while(fields->contents[fieldIndex] != NULL && !fields->contents[fieldIndex]->deleted) fieldIndex = (fieldIndex + 1) % fields->capacity; + while (contents[index] != NULL) { + FieldArray* entry = contents[index]; + if (entry->accessible == 0) { + if (current == NULL) { + entry->id = id; + if (entry->length < numFields) { + entry->contents = alloc_array(bool, numFields); + entry->length = numFields; + // Because we do not shrink `length` when we re-use the `contents` + // array, we can end up with entries that have a larger `length` than + // the number of fields contained at the heap location. But, assuming + // sound type-checking, this is not problematic. For performance we + // prioritize memory reuse above memory usage (on the assumption that + // GC/allocation is slow). + } - FieldArray * array = alloc(FieldArray); - fields->contents[fieldIndex] = array; - fields->length += 1; + current = entry; + newEntry = entry; + } - array->contents = alloc_array(bool, numFields); - array->length = numFields; - array->_id = _id; - array->deleted = false; + entry = contents[(index + 1) % capacity]; + while (entry != NULL && psl < entry->psl) { + entry->psl--; + contents[index] = entry; + index = (index + 1) % capacity; + entry = contents[(index + 1) % capacity]; + psl++; + } - for(int i = 0; ilength; i += 1){ - array->contents[i] = accAll; + current->psl = psl; + contents[index] = current; + return newEntry; + } + + // Robin-hood hashing: For any i, the values that hash to bucket i precede + // the values that hash to bucket i+1 + // https://www.cs.cornell.edu/courses/JavaAndDS/files/hashing_RobinHood.pdf + if (psl > entry->psl) { + // If the current PSL is longer than this PSL, we replace this one with + // the new item, and then "re-insert" the item we just replaced + + if (current == NULL) { + newEntry = alloc(FieldArray); + newEntry->id = id; + newEntry->length = numFields; + newEntry->contents = alloc_array(bool, numFields); + current = newEntry; + } + + // Swap psl and contents[index]->psl + current->psl = psl; + psl = entry->psl; + + // Swap current and contents[index] + contents[index] = current; + current = entry; + + // Continue inserting `entry` + } + + index = (index + 1) % capacity; + psl++; } - if(accAll) { - array->numAccessible = array->length; - } else { - array->numAccessible = 0; + + if (current == NULL) { + newEntry = alloc(FieldArray); + newEntry->id = id; + newEntry->length = numFields; + newEntry->contents = alloc_array(bool, numFields); + current = newEntry; } - return fields->contents[fieldIndex]; + + contents[index] = current; + current->psl = psl; + runtime_grow(fields, id); + return newEntry; } -int addStructAcc(OwnedFields * fields, int numFields){ - newFieldArray(fields, *fields->instanceCounter, numFields, true); - *(fields->instanceCounter) += 1; - return *(fields->instanceCounter) - 1; +// Adds a new FieldArray and marks all fields as accessible +// NOTE: This assumes that no field with the same reference ID have been added +void runtime_addAll(OwnedFields* fields, int id, int numFields) { + if (fields == NULL) return; + + // Reuse the entry if one already exists with that ID + // Cannot let a duplicate (unused) entry exist with the same ID + FieldArray* entry = runtime_find(fields, id); + if (entry == NULL) + entry = runtime_newFieldArray(fields, id, numFields); + bool[] entryContents = entry->contents; + for (int i = 0; i < numFields; i++){ + entryContents[i] = true; + } + entry->accessible = numFields; } -void addAcc(OwnedFields * fields, int _id, int numFields, int fieldIndex){ - FieldArray * array = find(fields, _id); - if(array != NULL){ - if(!array->contents[fieldIndex]){ - array->numAccessible += 1; - array->contents[fieldIndex] = true; - } - }else{ - array = newFieldArray(fields, _id, numFields, false); - array->contents[fieldIndex] = true; - array->numAccessible += 1; +// Adds the specified permission (id, fieldIndex) to the set of permissions. +// Returns false if permission already exists, otherwise returns true. +bool runtime_tryAdd(OwnedFields* fields, int id, int numFields, int fieldIndex) { + int capacity = fields->capacity; + FieldArray*[] contents = fields->contents; + + FieldArray* entry = runtime_find(fields, id); + if (entry == NULL) { + entry = runtime_newFieldArray(fields, id, numFields); + } else if (entry->contents[fieldIndex]) { + return false; } + + entry->contents[fieldIndex] = true; + entry->accessible++; + + return true; } -void assertAcc(OwnedFields* fields, int _id, int fieldIndex, string errorMessage){ - FieldArray* toCheck = find(fields, _id); - if(toCheck == NULL || !toCheck->contents[fieldIndex]){ - error(errorMessage); - } +// Adds the specified permission (id, fieldIndex) to the set of permissions. +// Throws `errorMessage` if the permission already exists. +void runtime_add(OwnedFields* fields, int id, int numFields, int fieldIndex, string errorMessage) { + if (!runtime_tryAdd(fields, id, numFields, fieldIndex)) + error(errorMessage); } -void addAccEnsureSeparate(OwnedFields* fields, int _id, int fieldIndex, int numFields, string errorMessage){ - FieldArray* toCheck = find(fields, _id); - if (toCheck == NULL) { - toCheck = newFieldArray(fields, _id, numFields, false); - } else if (toCheck->contents[fieldIndex]) { - error(errorMessage); - } - toCheck->contents[fieldIndex] = true; - toCheck->numAccessible += 1; +// Assert that the permission (id, fieldIndex) is already contained, otherwise +// throws `errorMessage` +void runtime_assert(OwnedFields* fields, int id, int fieldIndex, string errorMessage) { + FieldArray* entry = runtime_find(fields, id); + if (entry == NULL || !entry->contents[fieldIndex]){ + println(errorMessage); + assert(false); + } } -void loseAcc(OwnedFields* fields, int _id, int fieldIndex){ - FieldArray * toLose = find(fields, _id); - if(toLose != NULL){ - if(fieldIndex >= toLose->length){ - error("[INTERNAL] Field index exceeds maximum for the given struct.\n"); - }else if(toLose->contents[fieldIndex]){ - toLose->contents[fieldIndex] = false; - toLose->numAccessible -= 1; - } - if(toLose->numAccessible == 0) { - toLose->deleted = true; - fields->length -= 1; - } - } +bool runtime_tryRemove(OwnedFields* fields, int id, int fieldIndex) { + FieldArray* entry = runtime_find(fields, id); + if (entry == NULL || !entry->contents[fieldIndex]) { + return false; + } + + entry->contents[fieldIndex] = false; + entry->accessible--; + return true; } -void join(OwnedFields* target, OwnedFields* source){ - if(source != NULL && target != NULL){ - for(int i = 0; icapacity; i += 1){ - FieldArray* currFields = source->contents[i]; - if(currFields != NULL && currFields->numAccessible > 0 ){ - for(int j = 0; j< currFields->length; j += 1){ - addAcc(target, currFields->_id, currFields->length, j); - } - } +// Removes the permission (id, fieldIndex) from the set of permissions. Throws +// `errorMessage` if the permission does not exist. +void runtime_remove(OwnedFields* fields, int id, int fieldIndex, string errorMessage) { + if (!runtime_tryRemove(fields, id, fieldIndex)) + error(errorMessage); +} + +void runtime_join(OwnedFields* target, OwnedFields* source) { + if (target == NULL || source == NULL) + return; + + int sourceCapacity = source->capacity; + FieldArray*[] sourceContents = source->contents; + for (int i = 0; i < sourceCapacity; i++) { + FieldArray* sourceEntry = sourceContents[i]; + if (sourceEntry != NULL && sourceEntry->accessible != 0) { + // Add all fields from `sourceEntry` to an existing or new entry in target + int id = sourceEntry->id; + FieldArray* targetEntry = runtime_find(target, id); + if (targetEntry == NULL) + targetEntry = runtime_newFieldArray(target, id, sourceEntry->length); + + // Assume that they both have the same length + int length = sourceEntry->length; + bool[] sc = sourceEntry->contents; + bool[] tc = targetEntry->contents; + for (int j = 0; j < length; j++) { + if (sc[j]) { + assert(!tc[j]); + tc[j] = true; } + } + + targetEntry->accessible += sourceEntry->accessible; } + } } -void printof(OwnedFields* fields) { - if (fields != NULL) { - print("Capacity: "); - printint(fields->capacity); - print(" "); - print("Length: "); - printint(fields->length); - print(" "); - print("OwnedFields: [ "); - for (int i = 0; i < fields->capacity; i += 1) { - FieldArray* currFields = fields->contents[i]; - if (currFields != NULL) { - print("("); - printint(currFields->_id); - print(","); - for (int j = 0; j < currFields->length; j +=1) { - print(" "); - printbool(currFields->contents[j]); - } - print(", i: "); - printint(i); - print(") "); - } +void runtime_print(OwnedFields* fields) { + if (fields != NULL) { + print("{ "); + + bool first = true; + for (int i = 0; i < fields->capacity; i++) { + FieldArray* entry = fields->contents[i]; + if (entry != NULL) { + if (first) first = false; + else print(", "); + + printf("%d@%d#%d+%d:", entry->id, i, runtime_hash(entry->id, fields->capacity), entry->psl); + for (int j = 0; j < entry->length; j++) { + print(entry->contents[j] ? "X" : "_"); } - print("]\n"); - } else { - println("OwnedFields is empty/NULL"); + } } + + printf(" } (%d of %d)\n", fields->length, fields->capacity); + } else { + println("OwnedFields is empty/NULL"); + } } \ No newline at end of file diff --git a/src/main/resources/runtime.h0 b/src/main/resources/runtime.h0 index 04d98fc5..00d2ab58 100644 --- a/src/main/resources/runtime.h0 +++ b/src/main/resources/runtime.h0 @@ -2,28 +2,23 @@ struct OwnedFields { struct FieldArray*[] contents; int capacity; int length; - int* instanceCounter; - }; -typedef struct OwnedFields OwnedFields; - struct FieldArray { - bool[] contents; + int psl; + int id; int length; - int _id; - int numAccessible; - bool deleted; + int accessible; + bool[] contents; }; +typedef struct OwnedFields OwnedFields; typedef struct FieldArray FieldArray; -OwnedFields* initOwnedFields(int * instanceCounter); -int addStructAcc(OwnedFields * fields, int numFields); -void addAcc(OwnedFields * fields, int _id, int numFields, int fieldIndex); -void loseAcc(OwnedFields * fields, int _id, int fieldIndex); -void join(OwnedFields* target, OwnedFields* source); -void assertAcc(OwnedFields* fields, int _id, int fieldIndex, string errorMessage); -void addAccEnsureSeparate(OwnedFields* fields, int _id, int fieldIndex, int numFields, string errorMessage); -FieldArray* find(OwnedFields* fields, int _id); -void printof(OwnedFields* fields); \ No newline at end of file +OwnedFields* runtime_init(); +void runtime_addAll(OwnedFields* fields, int id, int numFields); +void runtime_add(OwnedFields* fields, int id, int numFields, int fieldIndex, string errorMessage); +void runtime_remove(OwnedFields* fields, int id, int fieldIndex, string errorMessage); +void runtime_join(OwnedFields* target, OwnedFields* source); +void runtime_assert(OwnedFields* fields, int id, int fieldIndex, string errorMessage); +void runtime_print(OwnedFields* fields); \ No newline at end of file diff --git a/src/main/scala/gvc/analyzer/Resolver.scala b/src/main/scala/gvc/analyzer/Resolver.scala index 7785dfac..f5d651e2 100644 --- a/src/main/scala/gvc/analyzer/Resolver.scala +++ b/src/main/scala/gvc/analyzer/Resolver.scala @@ -300,7 +300,9 @@ object Resolver { ) ) ), - specifications = f.specifications + // Specifications have already been resolved prior to resolving this + // statement + specifications = List.empty ) } diff --git a/src/main/scala/gvc/benchmarking/BaselineChecker.scala b/src/main/scala/gvc/benchmarking/BaselineChecker.scala deleted file mode 100644 index 183c2f13..00000000 --- a/src/main/scala/gvc/benchmarking/BaselineChecker.scala +++ /dev/null @@ -1,492 +0,0 @@ -package gvc.benchmarking - -import gvc.transformer.IR -import gvc.weaver.Collector.getCallstyle -import gvc.weaver._ - -object BaselineChecker { - - def check(program: IR.Program, onlyFraming: Boolean = false): Unit = { - val structIds = - program.structs.map(s => (s.name, s.addField("_id", IR.IntType))).toMap - val runtime = CheckRuntime.addToIR(program) - val checks = new CheckImplementation(program, runtime, structIds) - program.methods.foreach(checkMethod(_, checks, onlyFraming)) - } - - def checkFraming(program: IR.Program): Unit = - check(program, onlyFraming = true) - - private def checkMethod( - method: IR.Method, - checks: CheckImplementation, - onlyFraming: Boolean - ): Unit = { - val globalPerms = method.name match { - case "main" => - method.addVar( - checks.runtime.ownedFieldsRef, - CheckRuntime.Names.primaryOwnedFields - ) - case _ => - method.addParameter( - checks.runtime.ownedFieldsRef, - CheckRuntime.Names.primaryOwnedFields - ) - } - - val tempPerms = method.addVar( - checks.runtime.ownedFieldsRef, - CheckRuntime.Names.temporaryOwnedFields - ) - val callstyle = getCallstyle(method) - - callstyle match { - - case Collector.PreciseCallStyle | Collector.PrecisePreCallStyle => - val contextPerms = method.addVar(checks.runtime.ownedFieldsRef, - CheckRuntime.Names.contextOwnedFields) - - checkBlock(method.body, - checks, - contextPerms, - tempPerms, - globalPerms, - onlyFraming) - - // CheckAddRemove mode - // Check in global perms, add to context perms, remove from global perms. - val mode = if (onlyFraming) AddRemoveMode else CheckAddRemoveMode - method.precondition.toSeq - .flatMap( - checks - .translate(mode, _, contextPerms, Some(globalPerms), ValueContext) - ) - .toList ++=: method.body - - if (Collector.hasImplicitReturn(method)) { - - if (callstyle == Collector.PrecisePreCallStyle) { - if (!onlyFraming) { - method.body ++= method.postcondition.toSeq.flatMap( - validateSpec(_, contextPerms, tempPerms, checks) - ) - } - method.body ++= Seq( - new IR.Invoke( - checks.runtime.join, - List(globalPerms, contextPerms), - None - )) - - } else { - val mode = if (onlyFraming) AddMode else CheckAddMode - method.body ++= method.postcondition.toSeq - .flatMap( - checks.translate(mode, - _, - globalPerms, - Some(contextPerms), - ValueContext) - ) - .toList - } - } - Seq( - new IR.Invoke( - checks.runtime.initOwnedFields, - List( - new IR.FieldMember( - globalPerms, - checks.runtime.ownedFieldInstanceCounter - ) - ), - Some(contextPerms) - ) - ) ++=: method.body - case Collector.ImpreciseCallStyle | Collector.MainCallStyle => - checkBlock(method.body, - checks, - globalPerms, - tempPerms, - globalPerms, - onlyFraming) - - if (!onlyFraming) { - method.precondition.toSeq.flatMap( - validateSpec(_, globalPerms, tempPerms, checks) - ) ++=: method.body - } - - if (!onlyFraming && Collector.hasImplicitReturn(method)) { - method.body ++= method.postcondition.toSeq.flatMap( - validateSpec(_, globalPerms, tempPerms, checks) - ) - } - - if (callstyle == Collector.MainCallStyle) { - val instanceCounter = method.addVar( - new IR.PointerType(IR.IntType), - CheckRuntime.Names.instanceCounter - ) - - Seq( - new IR.AllocValue(IR.IntType, instanceCounter), - new IR.Invoke( - checks.runtime.initOwnedFields, - List(instanceCounter), - Some(globalPerms) - ) - ) ++=: method.body - } - } - } - - private def equivalentFields(x: IR.Member, y: IR.Member): Boolean = { - (x, y) match { - case (xf: IR.FieldMember, yf: IR.FieldMember) => - xf.field == yf.field && ((xf.root, yf.root) match { - case (xr: IR.Var, yr: IR.Var) => xr == yr - case (xr: IR.FieldMember, yr: IR.FieldMember) => - equivalentFields(xr, yr) - case _ => false - }) - case _ => false - } - } - - private def validateAccess( - expr: IR.Expression, - perms: IR.Var, - checks: CheckImplementation, - context: SpecificationContext = ValueContext, - inSpec: Boolean = false, - fieldAccs: List[IR.Member] = Nil - ): (Seq[IR.Op], List[IR.Member]) = expr match { - case acc: IR.Accessibility => - // Check framing - val (ops, fields) = validateAccess( - acc.member.root, - perms, - checks, - context, - inSpec, - fieldAccs - ) - (ops, acc.member :: fields) - - case cond: IR.Conditional => { - val (initial, fields) = - validateAccess(cond.condition, perms, checks, context, false, fieldAccs) - val (ifTrue, _) = - validateAccess(cond.ifTrue, perms, checks, context, inSpec, fieldAccs) - val (ifFalse, _) = - validateAccess(cond.ifFalse, perms, checks, context, inSpec, fieldAccs) - - if (ifTrue.isEmpty && ifFalse.isEmpty) { - (initial, fields) - } else if (ifTrue.isEmpty) { - val iff = new IR.If(new IR.Unary(IR.UnaryOp.Not, cond.condition)) - iff.ifTrue ++= ifFalse - (initial :+ iff, fields) - } else if (ifFalse.isEmpty) { - val iff = new IR.If(cond.condition) - iff.ifTrue ++= ifTrue - (initial :+ iff, fields) - } else { - val iff = new IR.If(cond.condition) - iff.ifTrue ++= ifTrue - iff.ifFalse ++= ifFalse - (initial :+ iff, fields) - } - } - - case b: IR.Binary => { - val subSpec = inSpec && b.operator == IR.BinaryOp.And - val (left, leftFields) = - validateAccess(b.left, perms, checks, context, subSpec, fieldAccs) - val (right, rightFields) = - validateAccess(b.right, perms, checks, context, subSpec, leftFields) - - if (right.isEmpty) { - (left, leftFields) - } else { - b.operator match { - // If we are in the top-level of a specification, the conditions must all - // be satisfied anyway, and we cannot switch based on the condition value - // (e.g. we cannot check if an acc() is true). - - // But, if we are not in a spec, the short-circuiting behavior of AND - // must be followed - case IR.BinaryOp.And if !inSpec => - val iff = new IR.If(b.left) - iff.ifTrue ++= right - (left :+ iff, leftFields) - - case IR.BinaryOp.Or => - val iff = new IR.If(new IR.Unary(IR.UnaryOp.Not, b.left)) - iff.ifTrue ++= right - (left :+ iff, leftFields) - - case _ => - (left ++ right, rightFields) - } - } - } - - case u: IR.Unary => - validateAccess(u.operand, perms, checks, context, false, fieldAccs) - case imp: IR.Imprecise => - imp.precise match { - case None => (Seq.empty, fieldAccs) - case Some(precise) => - validateAccess(precise, perms, checks, context, inSpec, fieldAccs) - } - case _: IR.Literal | _: IR.Result | _: IR.Var => - (Seq.empty, fieldAccs) - - case field: IR.FieldMember => - val (rootOps, fields) = - validateAccess(field.root, perms, checks, context, inSpec, fieldAccs) - if (fields.exists(equivalentFields(_, field))) { - (rootOps, fields) - } else { - val acc = - checks.translateFieldPermission(VerifyMode, - field, - perms, - None, - context) - (rootOps ++ acc, field :: fields) - } - - case pred: IR.PredicateInstance => - var fields = fieldAccs - val arguments = pred.arguments.flatMap(arg => { - val (argOps, argFields) = - validateAccess(arg, perms, checks, context, false, fields) - fields = argFields - argOps - }) - (arguments, fields) - - case _: IR.ArrayMember | _: IR.DereferenceMember => - throw new WeaverException("Invalid member") - } - - private def validateSpec( - expr: IR.Expression, - primaryPerms: IR.Var, - tempPerms: IR.Var, - checks: CheckImplementation, - context: SpecificationContext = ValueContext - ): Seq[IR.Op] = { - val (access, _) = - validateAccess(expr, primaryPerms, checks, context, true, Nil) - val verify = checks.translate(VerifyMode, expr, primaryPerms, None, context) - - if (verify.isEmpty) { - // If there are no checks in the specification, there will be no separation checks - access - } else { - val separation = - checks.translate(SeparationMode, expr, tempPerms, None, context) - if (separation.isEmpty) { - access ++ verify - } else { - Seq.concat( - access, - verify, - Seq( - new IR.Invoke( - checks.runtime.initOwnedFields, - List( - new IR.FieldMember( - primaryPerms, - checks.runtime.ownedFieldInstanceCounter - ) - ), - Some(tempPerms) - ) - ), - separation - ) - } - } - - } - - private def checkBlock( - block: IR.Block, - checks: CheckImplementation, - perms: IR.Var, - tempPerms: IR.Var, - globalPerms: IR.Var, - onlyFraming: Boolean - ): Unit = { - for (op <- block) op match { - case _: IR.AllocValue | _: IR.AllocArray => - throw new WeaverException("Unsupported alloc") - - case alloc: IR.AllocStruct => - checks.trackAllocation(alloc, perms) - - case assert: IR.Assert => - assert.kind match { - case IR.AssertKind.Imperative => - val (access, _) = validateAccess(assert.value, perms, checks) - assert.insertBefore(access) - case IR.AssertKind.Specification if !onlyFraming => - assert.insertAfter( - validateSpec(assert.value, perms, tempPerms, checks) - ) - case _ => - } - - case assign: IR.Assign => { - val (access, _) = validateAccess(assign.value, perms, checks) - assign.insertBefore(access) - } - - case assign: IR.AssignMember => - assign.member match { - case field: IR.FieldMember => - val (valueAccess, valueFields) = - validateAccess(assign.value, perms, checks) - val (rootAccess, rootFields) = validateAccess( - assign.member.root, - perms, - checks, - fieldAccs = valueFields - ) - assign.insertBefore( - valueAccess ++ - rootAccess ++ - checks.translateFieldPermission( - VerifyMode, - field, - perms, - None, - ValueContext - ) - ) - case _: IR.DereferenceMember | _: IR.ArrayMember => - throw new WeaverException("Invalid member") - } - - case err: IR.Error => { - val (access, _) = validateAccess(err.value, perms, checks) - err.insertBefore(access) - } - - case iff: IR.If => - val (condAccess, _) = validateAccess(iff.condition, perms, checks) - iff.insertBefore(condAccess) - - checkBlock(iff.ifTrue, - checks, - perms, - tempPerms, - globalPerms, - onlyFraming) - - checkBlock(iff.ifFalse, - checks, - perms, - tempPerms, - globalPerms, - onlyFraming) - - case ret: IR.Return => - val context = ret.value match { - case None => ValueContext - case Some(value) => new ReturnContext(value) - } - - val valueAccess = - ret.value.toSeq.flatMap(validateAccess(_, perms, checks) match { - case (ops, _) => ops - }) - - val validationOps = if (!onlyFraming) { - valueAccess ++ - block.method.postcondition.toSeq.flatMap( - validateSpec( - _, - perms, - tempPerms, - checks, - context = context - ) - ) - } else { - Seq.empty - } - - ret.insertBefore(validationOps ++ (getCallstyle(block.method) match { - case Collector.PreciseCallStyle => - block.method.postcondition.toSeq - .flatMap( - checks.translate(AddMode, _, globalPerms, None, context) - ) - .toList - case Collector.PrecisePreCallStyle => - Seq( - new IR.Invoke( - checks.runtime.join, - List(globalPerms, perms), - None - )) - case _ => Seq.empty - })) - case loop: IR.While => - val preAccessibility = validateAccess( - loop.condition, - perms, - checks - )._1 ++ (if (!onlyFraming) - validateSpec(loop.invariant, perms, tempPerms, checks) - else Seq.empty) - loop.insertBefore( - preAccessibility - ) - - checkBlock(loop.body, - checks, - perms, - tempPerms, - globalPerms, - onlyFraming) - - val postAccessibility = validateAccess( - loop.condition, - perms, - checks - )._1 ++ (if (!onlyFraming) - validateSpec(loop.invariant, perms, tempPerms, checks) - else Seq.empty) - loop.body ++= postAccessibility - - case invoke: IR.Invoke => - // Pre-conditions are handled inside callee - var fields: List[IR.Member] = Nil - val argAccess = invoke.arguments.flatMap(arg => { - val (argOps, argFields) = - validateAccess(arg, perms, checks, fieldAccs = fields) - fields = argFields - argOps - }) - val targetAccess = invoke.target.toSeq.flatMap(t => - validateAccess(t, perms, checks, fieldAccs = fields)._1) - invoke.insertBefore(argAccess ++ targetAccess) - invoke.callee match { - case method: IR.Method => - invoke.arguments = invoke.arguments :+ perms - case method: IR.DependencyMethod => - } - case fold: IR.Fold => - case unfold: IR.Unfold => - case _ => - } - } -} diff --git a/src/main/scala/gvc/benchmarking/BaselineChecks.scala b/src/main/scala/gvc/benchmarking/BaselineChecks.scala new file mode 100644 index 00000000..e823f95a --- /dev/null +++ b/src/main/scala/gvc/benchmarking/BaselineChecks.scala @@ -0,0 +1,425 @@ +package gvc.benchmarking + +import gvc.transformer.IR +import gvc.weaver._ +import gvc.weaver.CheckRuntime.Names +import gvc.transformer.IRPrinter + +object BaselineChecks { + def insert( + program: IR.Program, + checkFraming: Boolean = true, + checkSpecs: Boolean = true + ): Unit = { + new BaselineChecks(program, checkFraming, checkSpecs).insert() + } +} + +class BaselineChecks( + program: IR.Program, + checkFraming: Boolean = true, + checkSpecs: Boolean = true, +) { + val structIds = + program.structs.map(s => (s.name, s.addField("_id", IR.IntType))).toMap + val runtime = + CheckRuntime.addToIR(program) + val impl = new CheckImplementation(program, runtime, structIds) + val precision = new EquirecursivePrecision(program) + + def insert(): Unit = { + // Only inject instance counter into existing methods, not methods that are + // added to implement predicates + val programMethods = program.methods.toList + + program.methods.foreach(insertChecks) + + InstanceCounter.inject(programMethods, structIds) + } + + def insertChecks( + method: IR.Method + ): Unit = { + val (callerPerms, perms) = { + if (method.name == "main") { + (None, method.addVar(impl.permsType, Names.primaryOwnedFields)) + } else { + val permsArg = + method.addParameter(impl.permsType, Names.primaryOwnedFields) + if (precision.isPrecise(method.precondition)) + (Some(permsArg), method.addVar(impl.permsType, Names.temporaryOwnedFields)) + else + (None, permsArg) + } + } + + insertChecks(method.body, perms, callerPerms) + + val initOps = callerPerms match { + case Some(callerPerms) => + // Precise pre-condition + impl.init(perms) ++ + (method.precondition match { + case None => Seq.empty + case Some(pre) => + // No need to use another temp set for separation + checkSpecFraming(pre, callerPerms, ValueContext) ++ + impl.translate( + pre, + ValueContext, + (if (checkSpecs) List(AssertMode(callerPerms)) else Nil) ::: + List(RemoveMode(callerPerms), AddMode(perms, guarded=true)) + ) + }) + case None => + // Imprecise pre-condition + (if (method.name == "main") impl.init(perms) else Seq.empty) ++ + assertSpec(method.precondition, perms, method, ValueContext) + } + + initOps ++=: method.body + + method.body ++= (method.body.lastOption match { + case Some(_: IR.Return) => + Seq.empty // Handled when checking the `return` statement + case _ => + postcondition(ValueContext, method, perms, callerPerms) + }) + } + + def insertChecks( + block: IR.Block, + perms: IR.Var, + callerPerms: Option[IR.Var] + ): Unit = { + var element = block.headOption + while (element.isDefined) { + val e = element.get + // Get the next before the current element is removed, instructions are + // added after it, etc. + element = e.getNext + insertChecks(e, perms, callerPerms) + } + } + + def insertChecks( + op: IR.Op, + perms: IR.Var, + callerPerms: Option[IR.Var] + ): Unit = op match { + case alloc: IR.AllocStruct => { + alloc.insertBefore(checkFraming(alloc.target, perms)) + alloc.insertAfter(impl.trackAllocation(alloc, perms)) + } + + case assert: IR.Assert if assert.kind == IR.AssertKind.Specification => { + // We are assuming that expressions are self-framing + if (checkSpecs) { + val (baseMode, init) = + if (SeparationChecks.canOverlap(assert.value)) { + val tempPerms = assert.method.addVar(impl.permsType, Names.temporaryOwnedFields) + (AddMode(tempPerms, guarded=true) :: Nil, impl.init(tempPerms)) + } else { + (Nil, Seq.empty) + } + val checks = impl.translate( + assert.value, ValueContext, AssertMode(perms) :: baseMode + ) + + assert.insertAfter(init ++ checks) + } + assert.remove() + } + + case assert: IR.Assert => { + // Imperative assert + assert.insertBefore(checkFraming(assert.value, perms)) + } + + case assign: IR.Assign => { + assign.insertBefore( + checkFraming(assign.value, perms) ++ + checkFraming(assign.target, perms) + ) + } + + case assign: IR.AssignMember => { + assign.insertBefore( + checkFraming(assign.member, perms) ++ + checkFraming(assign.value, perms) + ) + } + + case error: IR.Error => { + error.insertBefore(checkFraming(error.value, perms)) + } + + case fold: IR.Fold => { + fold.remove() + } + + case iff: IR.If => { + iff.insertBefore(checkFraming(iff.condition, perms)) + insertChecks(iff.ifTrue, perms, callerPerms) + insertChecks(iff.ifFalse, perms, callerPerms) + } + + case invoke: IR.Invoke => { + invoke.insertBefore( + invoke.arguments.flatMap(checkFraming(_, perms)) ++ + invoke.target.map(checkFraming(_, perms)).getOrElse(Seq.empty) + ) + invoke.callee match { + case _: IR.Method => + invoke.arguments :+= perms + case _ => () + } + // Pre- and post-conditions handled inside the callee + } + + case ret: IR.Return if ret.value.isDefined => { + val value = ret.value.get + ret.insertBefore( + checkFraming(value, perms) ++ + postcondition(new ReturnContext(value), ret.method, perms, callerPerms) + ) + } + + case ret: IR.Return => + throw new WeaverException("Invalid return") + + case unfold: IR.Unfold => + unfold.remove() + + case loop: IR.While => { + if (precision.isPrecise(loop.invariant)) { + // Precise invariant + + val method = loop.method + val loopPerms = method.addVar(impl.permsType, Names.primaryOwnedFields) + insertChecks(loop.body, loopPerms, callerPerms) + + // Create a new set of permissions and transfer invariant perms into it, + // before the loop. This eliminates the need for another temp set for + // separation checking. + loop.insertBefore( + checkFraming(loop.condition, perms, ValueContext) ++ + impl.init(loopPerms) ++ + impl.translate( + loop.invariant, + ValueContext, + (if (checkSpecs) AssertMode(perms) :: Nil else Nil) ::: + RemoveMode(perms) :: AddMode(loopPerms) :: Nil + ) + ) + + // At the end of the loop, first check framing of loop condition + loop.body ++= checkFraming(loop.condition, loopPerms, ValueContext) + // Initialize a new temp set + val newPerms = method.addVar(impl.permsType, Names.temporaryOwnedFields) + loop.body ++= impl.init(newPerms) + // Add invariant perms from the existing set to the new one + // No need to remove since the existing set will be discarded + loop.body ++= checkSpecFraming(loop.invariant, loopPerms, ValueContext) + loop.body ++= impl.translate( + loop.invariant, + ValueContext, + (if (checkSpecs) AssertMode(loopPerms) :: Nil else Nil) ::: + AddMode(newPerms) :: Nil + ) + // Replace the existing set with the new one + loop.body += new IR.Assign(loopPerms, newPerms) + + // After the loop ends, add the invariant perms back to the main set + loop.insertAfter( + impl.translate(loop.invariant, ValueContext, List(AddMode(perms))) + ) + } else { + // Imprecise invariant -- inherit the current permission set + insertChecks(loop.body, perms, callerPerms) + + // Check loop invariant and framing of loop condition before the loop + // and at the end of the loop body + loop.insertBefore( + checkFraming(loop.condition, perms, ValueContext) ++ + assertSpec(loop.invariant, perms, loop.method, ValueContext) + ) + + loop.body ++= checkFraming(loop.condition, perms, ValueContext) + loop.body ++= assertSpec( + loop.invariant, perms, loop.method, ValueContext) + } + } + } + + def checkFraming( + expr: IR.Expression, + perms: IR.Expression, + context: SpecificationContext = ValueContext + ): Seq[IR.Op] = { + if (checkFraming) { + val check: IR.Expression => Seq[IR.Op] = checkFraming(_, perms, context) + expr match { + case bin: IR.Binary => bin.operator match { + case IR.BinaryOp.And => + // Allow short-circuiting framing -- `false && x->y` does _not_ + // check `acc(x->y)`). + check(bin.left) ++ makeIf(bin.left, check(bin.right)) + case IR.BinaryOp.Or => + check(bin.left) ++ + makeIf(new IR.Unary(IR.UnaryOp.Not, bin.left), check(bin.right)) + case _ => + check(bin.left) ++ check(bin.right) + } + case cond: IR.Conditional => + check(cond.condition) ++ + makeIf(cond.condition, check(cond.ifTrue), check(cond.ifFalse)) + case unary: IR.Unary => + check(unary.operand) + case _: IR.Var | _: IR.Literal | _: IR.Result => + Seq.empty + case field: IR.FieldMember => + impl.translateFieldPermission(field, List(AssertMode(perms)), context) + case expr => + throw new WeaverException( + "Unexpected expression '" + IRPrinter.print(expr) + "'") + } + } else { + Seq.empty + } + } + + def checkSpecFraming( + spec: IR.Expression, + perms: IR.Expression, + context: SpecificationContext + ): Seq[IR.Op] = spec match { + case imp: IR.Imprecise if imp.precise.isDefined => + checkSpecFramingInternal(imp.precise.get, perms, context) + case _ => Seq.empty + } + + def checkSpecFramingInternal( + spec: IR.Expression, + perms: IR.Expression, + context: SpecificationContext + ): Seq[IR.Op] = spec match { + case imp: IR.Imprecise => + throw new WeaverException("Unexpected imprecise modifier") + case acc: IR.Accessibility => + checkFraming(acc.member.root, perms, context) + case bin: IR.Binary if bin.operator == IR.BinaryOp.And => + checkSpecFramingInternal(bin.left, perms, context) ++ + checkSpecFramingInternal(bin.right, perms, context) + case cond: IR.Conditional => + checkFraming(cond.condition, perms, context) ++ + makeIf(cond.condition, + checkSpecFramingInternal(cond.ifTrue, perms, context), + checkSpecFramingInternal(cond.ifFalse, perms, context)) + case pred: IR.PredicateInstance => + pred.arguments.flatMap(checkFraming(_, perms, context)) + case expr => + checkFraming(expr, perms, context) + } + + def postcondition( + context: SpecificationContext, + method: IR.Method, + perms: IR.Expression, + callerPerms: Option[IR.Expression], + ): Seq[IR.Op] = callerPerms match { + case Some(callerPerms) => { + // `calleePerms` is defined whenever the pre-condition is precise + // Need to add the post-condition permissions back to `outerPerms` + if (precision.isPrecise(method.postcondition)) { + method.postcondition match { + // No temporary set of perms is needed for separation since we are + // adding them to the callee set + case Some(post) if checkSpecs => + checkSpecFraming(post, perms, context) ++ + impl.translate( + post, + context, + List(AssertMode(perms), AddMode(callerPerms, guarded=true)) + ) + case Some(post) => + checkSpecFraming(post, perms, context) ++ + impl.translate( + post, + context, + List(AddMode(callerPerms, guarded=false)) + ) + case None => Seq.empty + } + } else { + // Imprecise post-condition, so pass everything back + assertSpec(method.postcondition, perms, method, context) ++ + impl.join(callerPerms, perms) + } + } + case None => { + // `calleePerms` is not defined, so the pre-condition is imprecise. + // All permissions will be passed back since we are using the caller's + // permissions already. + assertSpec(method.postcondition, perms, method, context) + } + } + + def assertSpec( + spec: IR.Expression, + perms: IR.Expression, + method: IR.Method, + context: SpecificationContext = ValueContext, + ): Seq[IR.Op] = { + if (checkSpecs) { + // Check if we need to add checks (and a corresponding temporary set of + // permissions) for separation + val framing = checkSpecFraming(spec, perms, context) + val (mode, init) = + if (SeparationChecks.canOverlap(spec)) { + val tempPerms = method.addVar(impl.permsType, Names.temporaryOwnedFields) + (AddMode(tempPerms, guarded=true) :: Nil, impl.init(tempPerms)) + } else { + (Nil, Seq.empty) + } + val assert = impl.translate(spec, context, AssertMode(perms) :: mode) + framing ++ init ++ assert + } else { + Seq.empty + } + } + + def assertSpec( + spec: Option[IR.Expression], + perms: IR.Expression, + method: IR.Method, + context: SpecificationContext + ): Seq[IR.Op] = spec match { + case None => Seq.empty + case Some(spec) => assertSpec(spec, perms, method, context) + } + + def makeIf( + cond: IR.Expression, + ifTrue: Seq[IR.Op], + ifFalse: Seq[IR.Op] = Seq.empty + ): Seq[IR.Op] = (ifTrue.isEmpty, ifFalse.isEmpty) match { + case (true, true) => Seq.empty + case (false, true) => { + val iff = new IR.If(cond) + iff.ifTrue ++= ifTrue + Seq(iff) + } + case (true, false) => { + val iff = new IR.If(new IR.Unary(IR.UnaryOp.Not, cond)) + iff.ifTrue ++= ifFalse + Seq(iff) + } + case _ => { + val iff = new IR.If(cond) + iff.ifTrue ++= ifTrue + iff.ifFalse ++= ifFalse + Seq(iff) + } + } +} \ No newline at end of file diff --git a/src/main/scala/gvc/benchmarking/BenchmarkExecutor.scala b/src/main/scala/gvc/benchmarking/BenchmarkExecutor.scala index 9267014a..d03dbebe 100644 --- a/src/main/scala/gvc/benchmarking/BenchmarkExecutor.scala +++ b/src/main/scala/gvc/benchmarking/BenchmarkExecutor.scala @@ -249,7 +249,7 @@ object BenchmarkExecutor { ir: IR.Program, onlyFraming: Boolean ): Option[Path] = { - BaselineChecker.check(ir, onlyFraming) + BaselineChecks.insert(ir, true, !onlyFraming) val sourceText = IRPrinter.print(ir, includeSpecs = false) diff --git a/src/main/scala/gvc/main.scala b/src/main/scala/gvc/main.scala index ab586a1a..c86e3b4c 100644 --- a/src/main/scala/gvc/main.scala +++ b/src/main/scala/gvc/main.scala @@ -6,7 +6,7 @@ import gvc.analyzer._ import gvc.benchmarking.BenchmarkExecutor.injectAndWrite import gvc.transformer._ import gvc.benchmarking.{ - BaselineChecker, + BaselineChecks, BenchmarkExecutor, BenchmarkExporter, BenchmarkExternalConfig, @@ -92,7 +92,7 @@ object Main extends App { val inputSource = readFile(config.sourceFile.get) val onlyFraming = config.mode == Config.FramingVerification val ir = generateIR(inputSource, linkedLibraries) - BaselineChecker.check(ir, onlyFraming) + BaselineChecks.insert(ir, true, !onlyFraming) val outputC0Source = Paths.get(fileNames.c0FileName) val outputBinary = Paths.get(fileNames.binaryName) injectAndWrite( diff --git a/src/main/scala/gvc/transformer/IR.scala b/src/main/scala/gvc/transformer/IR.scala index e756e150..727b620d 100644 --- a/src/main/scala/gvc/transformer/IR.scala +++ b/src/main/scala/gvc/transformer/IR.scala @@ -64,6 +64,13 @@ object IR { def struct(name: String): StructDefinition = _structs.getOrElseUpdate(name, new Struct(name)) + def structDef(name: String): Struct = + _structs(name) match { + case s: Struct => s + case _: DependencyStruct => + throw new IRException("Cannot get definition for struct declared in library") + } + // Adds a new struct, renaming it if necessary to avoid collisions def newStruct(name: String): Struct = { val actualName = Helpers.findAvailableName(_structs, name) @@ -385,7 +392,7 @@ object IR { def method: Method = _method } - class ChildBlock(op: Op) extends Block { + class ChildBlock(val op: Op) extends Block { def method = op.block.method } @@ -394,6 +401,8 @@ object IR { def contains(exp: Expression): Boolean = exp == this + + override def toString() = IRPrinter.print(this) } class Parameter(varType: Type, name: String) @@ -444,6 +453,8 @@ object IR { ) extends SpecificationExpression { override def contains(exp: Expression) = super.contains(exp) || arguments.exists(_.contains(exp)) + override def toString() = + predicate.name + "(" + arguments.map(IRPrinter.print).mkString(", ") + ")" } // Represents a \result expression in a specification @@ -507,18 +518,18 @@ object IR { sealed trait BinaryOp object BinaryOp { - object Add extends BinaryOp - object Subtract extends BinaryOp - object Divide extends BinaryOp - object Multiply extends BinaryOp - object And extends BinaryOp - object Or extends BinaryOp - object Equal extends BinaryOp - object NotEqual extends BinaryOp - object Less extends BinaryOp - object LessOrEqual extends BinaryOp - object Greater extends BinaryOp - object GreaterOrEqual extends BinaryOp + object Add extends BinaryOp { override def toString() = "+" } + object Subtract extends BinaryOp { override def toString() = "-" } + object Divide extends BinaryOp { override def toString() = "/" } + object Multiply extends BinaryOp { override def toString() = "*" } + object And extends BinaryOp { override def toString() = "&&" } + object Or extends BinaryOp { override def toString() = "||" } + object Equal extends BinaryOp { override def toString() = "==" } + object NotEqual extends BinaryOp { override def toString() = "!=" } + object Less extends BinaryOp { override def toString() = "<" } + object LessOrEqual extends BinaryOp { override def toString() = "<=" } + object Greater extends BinaryOp { override def toString() = ">" } + object GreaterOrEqual extends BinaryOp { override def toString() = ">=" } } class Unary( @@ -535,8 +546,8 @@ object IR { sealed trait UnaryOp object UnaryOp { - object Not extends UnaryOp - object Negate extends UnaryOp + object Not extends UnaryOp { override def toString() = "!" } + object Negate extends UnaryOp { override def toString() = "-" } } sealed trait Type { @@ -618,6 +629,8 @@ object IR { // Creates a copy of the current Op // The new copy will not be attached to any Block def copy: IR.Op + + def summary: String } class Invoke( @@ -626,6 +639,10 @@ object IR { var target: Option[Expression] ) extends Op { def copy = new Invoke(callee, arguments, target) + def summary = ( + target.map(e => IRPrinter.print(e) + " = ").getOrElse("") + + callee.name + "(" + arguments.map(IRPrinter.print) + ")" + ) } class AllocValue( @@ -633,6 +650,7 @@ object IR { var target: Var ) extends Op { def copy = new AllocValue(valueType, target) + def summary = target.name + " = alloc(" + valueType.name + ")" } class AllocStruct( @@ -640,6 +658,8 @@ object IR { var target: Expression ) extends Op { def copy = new AllocStruct(struct, target) + def summary = + IRPrinter.print(target) + " = alloc(struct " + struct.name + ")" } // TODO: Length should be an expression @@ -649,6 +669,11 @@ object IR { var target: Var ) extends Op { def copy = new AllocArray(valueType, length, target) + def summary = ( + IRPrinter.print(target) + + "= alloc_array(" + valueType.name + + ", " + IRPrinter.print(length) + ")" + ) } class Assign( @@ -656,6 +681,7 @@ object IR { var value: Expression ) extends Op { def copy = new Assign(target, value) + def summary = target.name + " = " + IRPrinter.print(value) } class AssignMember( @@ -663,6 +689,7 @@ object IR { var value: Expression ) extends Op { def copy = new AssignMember(member, value) + def summary = IRPrinter.print(member) + " = " + IRPrinter.print(value) } class Assert( @@ -670,6 +697,10 @@ object IR { var kind: AssertKind ) extends Op { def copy = new Assert(value, kind) + def summary = (kind match { + case IR.AssertKind.Imperative => "assert " + case IR.AssertKind.Specification => "//@assert " + }) + IRPrinter.print(value) } sealed trait AssertKind @@ -682,22 +713,29 @@ object IR { var instance: PredicateInstance ) extends Op { def copy = new Fold(instance) + def summary = "//@fold " + instance.toString() } class Unfold( var instance: PredicateInstance ) extends Op { def copy = new Unfold(instance) + def summary = "//@unfold " + instance.toString() } class Error( var value: Expression ) extends Op { def copy = new Error(value) + def summary = "error(" + IRPrinter.print(value) + ")" } class Return(var value: Option[Expression]) extends Op { def copy = new Return(value) + def summary = value match { + case None => "return" + case Some(e) => "return " + IRPrinter.print(e) + } } class If( @@ -718,6 +756,8 @@ object IR { falseBranch.foreach(newIf.ifFalse += _.copy) newIf } + + def summary = "if (" + IRPrinter.print(condition) + ") ..." } class While( @@ -739,6 +779,8 @@ object IR { newBody.foreach(newWhile.body += _.copy) newWhile } + + def summary = "while (" + IRPrinter.print(condition) + ") ..." } class Dependency( diff --git a/src/main/scala/gvc/transformer/IRPrinter.scala b/src/main/scala/gvc/transformer/IRPrinter.scala index ac636778..6b485d13 100644 --- a/src/main/scala/gvc/transformer/IRPrinter.scala +++ b/src/main/scala/gvc/transformer/IRPrinter.scala @@ -13,18 +13,136 @@ object IRPrinter { val Top = 9 } - def print(program: IR.Program, includeSpecs: Boolean): String = { - val p = new Printer() + private def printExpr( + p: Printer, + expr: IR.Expression, + precedence: Int = Precedence.Top + ): Unit = expr match { + case v: IR.Var => p.print(v.name) + case m: IR.FieldMember => { + printExpr(p, m.root) + p.print("->") + p.print(m.field.name) + } + case deref: IR.DereferenceMember => + wrapExpr(p, precedence, Precedence.Unary) { + p.print("*") + printExpr(p, deref.root, Precedence.Unary) + } + case acc: IR.Accessibility => { + p.print("acc(") + printExpr(p, acc.member) + p.print(")") + } + case pred: IR.PredicateInstance => { + p.print(pred.predicate.name) + p.print("(") + printList(p, pred.arguments) { arg => printExpr(p, arg) } + p.print(")") + } + case arr: IR.ArrayMember => { + printExpr(p, arr.root) + p.print("[") + printExpr(p, arr.index) + p.print("]") + } + case res: IR.Result => p.print("\\result") + case imp: IR.Imprecise => + imp.precise match { + case None => p.print("?") + case Some(precise) => + wrapExpr(p, precedence, Precedence.And) { + p.print("? && ") + printExpr(p, precise, Precedence.And) + } + } + case int: IR.IntLit => p.print(int.value.toString()) + case str: IR.StringLit => + p.print("\"") + p.print(str.value) + p.print("\"") + case char: IR.CharLit => { + p.print("'") + p.print(char.value match { + case '\\' => "\\\\" + case '\n' => "\\n" + case '\r' => "\\r" + case '\t' => "\\t" + case '\u0000' => "\\0" + case other => other.toString() + }) + p.print("'") + } + case bool: IR.BoolLit => p.print(if (bool.value) "true" else "false") + case _: IR.NullLit => p.print("NULL") + + case cond: IR.Conditional => + wrapExpr(p, precedence, Precedence.Conditional) { + printExpr(p, cond.condition, Precedence.Conditional) + p.print(" ? ") + printExpr(p, cond.ifTrue, Precedence.Conditional) + p.print(" : ") + printExpr(p, cond.ifFalse, Precedence.Conditional) + } - def printList[T](values: Seq[T])(action: T => Unit): Unit = { - var first = true - for (value <- values) { - if (first) first = false - else p.print(", ") - action(value) + case binary: IR.Binary => { + val (sep, opPrecedence) = binary.operator match { + case IR.BinaryOp.Add => (" + ", Precedence.Add) + case IR.BinaryOp.Subtract => (" - ", Precedence.Add) + case IR.BinaryOp.Divide => (" / ", Precedence.Multiply) + case IR.BinaryOp.Multiply => (" * ", Precedence.Multiply) + case IR.BinaryOp.And => (" && ", Precedence.And) + case IR.BinaryOp.Or => (" || ", Precedence.Or) + case IR.BinaryOp.Equal => (" == ", Precedence.Equality) + case IR.BinaryOp.NotEqual => (" != ", Precedence.Equality) + case IR.BinaryOp.Less => (" < ", Precedence.Inequality) + case IR.BinaryOp.LessOrEqual => (" <= ", Precedence.Inequality) + case IR.BinaryOp.Greater => (" > ", Precedence.Inequality) + case IR.BinaryOp.GreaterOrEqual => (" >= ", Precedence.Inequality) + } + + wrapExpr(p, precedence, opPrecedence) { + printExpr(p, binary.left, opPrecedence) + p.print(sep) + printExpr(p, binary.right, opPrecedence) } } + case unary: IR.Unary => + wrapExpr(p, precedence, Precedence.Unary) { + p.print(unary.operator match { + case IR.UnaryOp.Not => "!" + case IR.UnaryOp.Negate => "-" + }) + printExpr(p, unary.operand, Precedence.Unary) + } + } + + def printList[T](p: Printer, values: Seq[T])(action: T => Unit): Unit = { + var first = true + for (value <- values) { + if (first) first = false + else p.print(", ") + action(value) + } + } + + def wrapExpr(p: Printer, currentPrecedence: Int, exprPrecedence: Int)( + action: => Unit + ): Unit = { + if (currentPrecedence < exprPrecedence) { + p.print("(") + action + p.print(")") + } else { + action + } + } + + + def print(program: IR.Program, includeSpecs: Boolean): String = { + val p = new Printer() + def printDependency(dependency: IR.Dependency): Unit = { p.print("#use ") if (dependency.isLibrary) { @@ -59,7 +177,7 @@ object IRPrinter { p.print("//@predicate ") p.print(predicate.name) p.print("(") - printList(predicate.parameters) { param => + printList(p, predicate.parameters) { param => printType(param.varType) p.print(" ") p.print(param.name) @@ -70,7 +188,7 @@ object IRPrinter { def printPredicate(predicate: IR.Predicate): Unit = { printPredicateHeader(predicate) p.print(" = ") - printExpr(predicate.expression) + printExpr(p, predicate.expression) p.println(";") } @@ -85,7 +203,7 @@ object IRPrinter { p.print("(") var first = true - printList(method.parameters) { param => + printList(p, method.parameters) { param => printType(param.varType) p.print(" ") p.print(param.name) @@ -102,7 +220,7 @@ object IRPrinter { method.precondition.foreach { pre => p.withIndent { p.print("//@requires ") - printExpr(pre) + printExpr(p, pre) p.println(";") } } @@ -110,7 +228,7 @@ object IRPrinter { method.postcondition.foreach { post => p.withIndent { p.print("//@ensures ") - printExpr(post) + printExpr(p, post) p.println(";") } } @@ -131,7 +249,7 @@ object IRPrinter { case _: IR.ArrayType | _: IR.ReferenceArrayType => () case varType => { p.print(" = ") - printExpr(varType.default) + printExpr(p, varType.default) } } p.println(";") @@ -146,70 +264,70 @@ object IRPrinter { def printOp(op: IR.Op): Unit = op match { case invoke: IR.Invoke => { invoke.target.foreach { target => - printExpr(target) + printExpr(p, target) p.print(" = ") } p.print(invoke.callee.name) p.print("(") - printList(invoke.arguments) { arg => - printExpr(arg) + printList(p, invoke.arguments) { arg => + printExpr(p, arg) } p.println(");") } case alloc: IR.AllocValue => { - printExpr(alloc.target) + printExpr(p, alloc.target) p.print(" = alloc(") printType(alloc.valueType) p.println(");") } case alloc: IR.AllocArray => { - printExpr(alloc.target) + printExpr(p, alloc.target) p.print(" = alloc_array(") printType(alloc.valueType) p.print(", ") - printExpr(alloc.length) + printExpr(p, alloc.length) p.println(");") } case alloc: IR.AllocStruct => { - printExpr(alloc.target) + printExpr(p, alloc.target) p.print(" = alloc(struct ") p.print(alloc.struct.name) p.println(");") } case assign: IR.Assign => { - printExpr(assign.target) + printExpr(p, assign.target) p.print(" = ") - printExpr(assign.value) + printExpr(p, assign.value) p.println(";") } case assign: IR.AssignMember => { assign.member match { case member: IR.FieldMember => { - printExpr(member.root) + printExpr(p, member.root) p.print("->") p.print(member.field.name) } case member: IR.DereferenceMember => { p.print("*") - printExpr(member.root, Precedence.Unary) + printExpr(p, member.root, Precedence.Unary) } case member: IR.ArrayMember => { - printExpr(member.root) + printExpr(p, member.root) p.print("[") - printExpr(member.index) + printExpr(p, member.index) p.print("]") } } p.print(" = ") - printExpr(assign.value) + printExpr(p, assign.value) p.println(";") } @@ -218,12 +336,12 @@ object IRPrinter { case IR.AssertKind.Specification => if (includeSpecs) { p.print("//@assert ") - printExpr(assert.value) + printExpr(p, assert.value) p.println(";") } case IR.AssertKind.Imperative => { p.print("assert(") - printExpr(assert.value) + printExpr(p, assert.value) p.println(");") } } @@ -231,20 +349,20 @@ object IRPrinter { case fold: IR.Fold => if (includeSpecs) { p.print("//@fold ") - printExpr(fold.instance) + printExpr(p, fold.instance) p.println(";") } case unfold: IR.Unfold => if (includeSpecs) { p.print("//@unfold ") - printExpr(unfold.instance) + printExpr(p, unfold.instance) p.println(";") } case error: IR.Error => { p.print("error(") - printExpr(error.value) + printExpr(p, error.value) p.println(");") } @@ -252,14 +370,14 @@ object IRPrinter { p.print("return") ret.value.foreach { value => p.print(" ") - printExpr(value) + printExpr(p, value) } p.println(";") } case iff: IR.If => { p.print("if (") - printExpr(iff.condition) + printExpr(p, iff.condition) p.println(")") printBlock(iff.ifTrue) @@ -271,12 +389,12 @@ object IRPrinter { case w: IR.While => { p.print("while (") - printExpr(w.condition) + printExpr(p, w.condition) p.println(")") if (includeSpecs) { p.withIndent { p.print("//@loop_invariant ") - printExpr(w.invariant) + printExpr(p, w.invariant) p.println(";") } } @@ -284,121 +402,6 @@ object IRPrinter { } } - def wrapExpr(currentPrecedence: Int, exprPrecedence: Int)( - action: => Unit - ): Unit = { - if (currentPrecedence < exprPrecedence) { - p.print("(") - action - p.print(")") - } else { - action - } - } - - def printExpr( - expr: IR.Expression, - precedence: Int = Precedence.Top - ): Unit = expr match { - case v: IR.Var => p.print(v.name) - case m: IR.FieldMember => { - printExpr(m.root) - p.print("->") - p.print(m.field.name) - } - case deref: IR.DereferenceMember => - wrapExpr(precedence, Precedence.Unary) { - p.print("*") - printExpr(deref.root, Precedence.Unary) - } - case acc: IR.Accessibility => { - p.print("acc(") - printExpr(acc.member) - p.print(")") - } - case pred: IR.PredicateInstance => { - p.print(pred.predicate.name) - p.print("(") - printList(pred.arguments) { arg => printExpr(arg) } - p.print(")") - } - case arr: IR.ArrayMember => { - printExpr(arr.root) - p.print("[") - printExpr(arr.index) - p.print("]") - } - case res: IR.Result => p.print("\\result") - case imp: IR.Imprecise => - imp.precise match { - case None => p.print("?") - case Some(precise) => - wrapExpr(precedence, Precedence.And) { - p.print("? && ") - printExpr(precise, Precedence.And) - } - } - case int: IR.IntLit => p.print(int.value.toString()) - case str: IR.StringLit => - p.print("\"") - p.print(str.value) - p.print("\"") - case char: IR.CharLit => { - p.print("'") - p.print(char.value match { - case '\\' => "\\\\" - case '\n' => "\\n" - case '\r' => "\\r" - case '\t' => "\\t" - case '\u0000' => "\\0" - case other => other.toString() - }) - p.print("'") - } - case bool: IR.BoolLit => p.print(if (bool.value) "true" else "false") - case _: IR.NullLit => p.print("NULL") - - case cond: IR.Conditional => - wrapExpr(precedence, Precedence.Conditional) { - printExpr(cond.condition, Precedence.Conditional) - p.print(" ? ") - printExpr(cond.ifTrue, Precedence.Conditional) - p.print(" : ") - printExpr(cond.ifFalse, Precedence.Conditional) - } - - case binary: IR.Binary => { - val (sep, opPrecedence) = binary.operator match { - case IR.BinaryOp.Add => (" + ", Precedence.Add) - case IR.BinaryOp.Subtract => (" - ", Precedence.Add) - case IR.BinaryOp.Divide => (" / ", Precedence.Multiply) - case IR.BinaryOp.Multiply => (" * ", Precedence.Multiply) - case IR.BinaryOp.And => (" && ", Precedence.And) - case IR.BinaryOp.Or => (" || ", Precedence.Or) - case IR.BinaryOp.Equal => (" == ", Precedence.Equality) - case IR.BinaryOp.NotEqual => (" != ", Precedence.Equality) - case IR.BinaryOp.Less => (" < ", Precedence.Inequality) - case IR.BinaryOp.LessOrEqual => (" <= ", Precedence.Inequality) - case IR.BinaryOp.Greater => (" > ", Precedence.Inequality) - case IR.BinaryOp.GreaterOrEqual => (" >= ", Precedence.Inequality) - } - - wrapExpr(precedence, opPrecedence) { - printExpr(binary.left, opPrecedence) - p.print(sep) - printExpr(binary.right, opPrecedence) - } - } - - case unary: IR.Unary => - wrapExpr(precedence, Precedence.Unary) { - p.print(unary.operator match { - case IR.UnaryOp.Not => "!" - case IR.UnaryOp.Negate => "-" - }) - printExpr(unary.operand, Precedence.Unary) - } - } var empty = true def printSeparator() = { @@ -461,6 +464,12 @@ object IRPrinter { p.toString() } + def print(expr: IR.Expression) = { + val p = new Printer() + printExpr(p, expr) + p.toString() + } + private class Printer { var indentLevel = 0 var startedLine = false diff --git a/src/main/scala/gvc/transformer/IRSilver.scala b/src/main/scala/gvc/transformer/IRSilver.scala index 20e2f436..a5f0fb3b 100644 --- a/src/main/scala/gvc/transformer/IRSilver.scala +++ b/src/main/scala/gvc/transformer/IRSilver.scala @@ -9,7 +9,6 @@ object IRSilver { object Names { val ReturnVar = "$result" - val TempResultPrefix = "$result_" val ReservedResult = "result" val RenamedResult = "_result$" } diff --git a/src/main/scala/gvc/weaver/CheckImplementation.scala b/src/main/scala/gvc/weaver/CheckImplementation.scala index e9aad5a3..4d3537a6 100644 --- a/src/main/scala/gvc/weaver/CheckImplementation.scala +++ b/src/main/scala/gvc/weaver/CheckImplementation.scala @@ -2,94 +2,132 @@ package gvc.weaver import scala.collection.mutable import gvc.transformer.IR +import gvc.transformer.IRPrinter sealed trait CheckMode { def prefix: String + def perms: IR.Expression + def withPerms(perms: IR.Expression): CheckMode + def guarded: Boolean + + def visitPerm( + runtime: CheckRuntime, + instanceId: IR.Expression, + fieldIndex: IR.Expression, + numFields: IR.Expression, + expression: String + ): Seq[IR.Op] + + def visitBool( + runtime: CheckRuntime, + value: IR.Expression + ): Seq[IR.Op] } -case object CheckAddRemoveMode extends CheckMode { - def prefix = "check_add_remove_" -} +case class AssertMode(val perms: IR.Expression) extends CheckMode { + def prefix = "assert_" + def withPerms(perms: IR.Expression): CheckMode = AssertMode(perms) + def guarded = true + + def visitPerm( + runtime: CheckRuntime, + instanceId: IR.Expression, + fieldIndex: IR.Expression, + numFields: IR.Expression, + expression: String + ): Seq[IR.Op] = { + val error = new IR.StringLit(s"No permission to access '$expression'") + val args = List(perms, instanceId, fieldIndex, error) + val invoke = new IR.Invoke(runtime.assert, args, None) + Seq(invoke) + } -case object CheckAddMode extends CheckMode { - def prefix = "check_add_" + def visitBool(runtime: CheckRuntime, value: IR.Expression): Seq[IR.Op] = { + val assert = new IR.Assert(value, IR.AssertKind.Imperative) + Seq(assert) + } } - -case object AddMode extends CheckMode { +case class AddMode(val perms: IR.Expression, val guarded: Boolean = true) extends CheckMode { def prefix = "add_" -} + def withPerms(perms: IR.Expression): CheckMode = AddMode(perms) + + def visitPerm(runtime: CheckRuntime, instanceId: IR.Expression, fieldIndex: IR.Expression, numFields: IR.Expression, expression: String): Seq[IR.Op] = { + // TODO (?): We need to also add permissions required for framing + val error = new IR.StringLit(s"Invalid aliasing - '$expression' overlaps with existing permission") + val args = List(perms, instanceId, numFields, fieldIndex, error) + val invoke = new IR.Invoke(runtime.add, args, None) + Seq(invoke) + } -case object AddRemoveMode extends CheckMode { - def prefix = "add_remove_" + def visitBool(runtime: CheckRuntime, value: IR.Expression): Seq[IR.Op] = + Seq.empty } - -case object RemoveMode extends CheckMode { +case class RemoveMode(val perms: IR.Expression) extends CheckMode { def prefix = "remove_" -} - -case object SeparationMode extends CheckMode { - def prefix = "sep_" -} + def withPerms(perms: IR.Expression): CheckMode = RemoveMode(perms) + def guarded = false + + def visitPerm( + runtime: CheckRuntime, + instanceId: IR.Expression, + fieldIndex: IR.Expression, + numFields: IR.Expression, + expression: String + ): Seq[IR.Op] = { + val error = new IR.StringLit(s"No permission to access '$expression'") + val args = List(perms, instanceId, fieldIndex, error) + val invoke = new IR.Invoke(runtime.remove, args, None) + Seq(invoke) + } -case object VerifyMode extends CheckMode { - def prefix = "" + def visitBool(runtime: CheckRuntime, value: IR.Expression): Seq[IR.Op] = + Seq.empty } -sealed trait CheckType - -case object Separation extends CheckType - -case object Verification extends CheckType - class CheckImplementation( - program: IR.Program, - val runtime: CheckRuntime, - structIds: Map[String, IR.StructField] - ) { + program: IR.Program, + val runtime: CheckRuntime, + structIds: Map[String, IR.StructField] +) { private val predicateImplementations = - mutable.Map[(CheckMode, String), Option[IR.MethodDefinition]]() + mutable.Map[(String, String), Option[IR.MethodDefinition]]() private def resolvePredicateDefinition( - mode: CheckMode, - pred: IR.Predicate - ): Option[IR.MethodDefinition] = { + modes: List[CheckMode], + pred: IR.Predicate + ): Option[IR.MethodDefinition] = { + val prefix = modes.map(_.prefix).mkString predicateImplementations.getOrElse( - (mode, pred.name), - implementPredicate(mode, pred) + (prefix, pred.name), + implementPredicate(modes, pred) ) } private def implementPredicate( - mode: CheckMode, - pred: IR.Predicate - ): Option[IR.MethodDefinition] = { - + modes: List[CheckMode], + pred: IR.Predicate + ): Option[IR.MethodDefinition] = { // TODO: allow name collisions when adding methods - val defn = program.addMethod(mode.prefix + pred.name, None) - predicateImplementations += ((mode, pred.name) -> Some(defn)) + val prefix = modes.map(_.prefix.mkString).mkString + val methodName = prefix + pred.name + + val defn = program.addMethod(methodName, None) + predicateImplementations += ((prefix, pred.name) -> Some(defn)) - val newParams = pred.parameters + val predParams = pred.parameters .map((p: IR.Var) => p -> defn.addParameter(p.varType, p.name)) .toMap - val permsPrimary = defn.addParameter( - runtime.ownedFieldsRef, - CheckRuntime.Names.primaryOwnedFields - ) - - val permsSecondary = - if (mode == AddRemoveMode || mode == CheckAddMode || mode == CheckAddRemoveMode) - Some( - defn.addParameter( - runtime.ownedFieldsRef, - CheckRuntime.Names.temporaryOwnedFields - )) - else None + val childModes = modes.map(mode => { + val paramName = mode.prefix + "perms" + val param = defn.addParameter(runtime.ownedFieldsRef, paramName) + mode.withPerms(param) + }) - val context = new PredicateContext(pred, newParams) + val context = new PredicateContext(pred, predParams) val ops = - translate(mode, pred.expression, permsPrimary, permsSecondary, context) + translate(pred.expression, context, childModes) if (ops.nonEmpty) { defn.body ++= ops @@ -97,7 +135,7 @@ class CheckImplementation( } else { // Otherwise, this predicate implementation is a no-op, and it can be ignored // TODO: Remove the no-op method definition - predicateImplementations.update((mode, pred.name), None) + predicateImplementations.update((prefix, pred.name), None) None } } @@ -106,49 +144,32 @@ class CheckImplementation( structIds(struct.name) def translate( - mode: CheckMode, - expr: IR.Expression, - permsPrimary: IR.Var, - permsSecondary: Option[IR.Var], - context: SpecificationContext, - ): Seq[IR.Op] = expr match { + expr: IR.Expression, + context: SpecificationContext, + modes: List[CheckMode] + ): Seq[IR.Op] = expr match { case acc: IR.Accessibility => acc.member match { case member: IR.FieldMember => - translateFieldPermission(mode, - member, - permsPrimary, - permsSecondary, - context) - case _ => - throw new WeaverException("Invalid conjunct in specification.") + translateFieldPermission(member, modes, context) + case member => + throw new WeaverException( + "Invalid member '" + IRPrinter.print(member) + "' in specification.") } - case pred: IR.PredicateInstance => - translatePredicateInstance(mode, - pred, - permsPrimary, - permsSecondary, - context) + case instance: IR.PredicateInstance => + translatePredicateInstance(instance, modes, context) case imp: IR.Imprecise => imp.precise match { case None => Seq.empty case Some(precise) => - translate(mode, precise, permsPrimary, permsSecondary, context) + translate(precise, context, modes) } case conditional: IR.Conditional => - val trueOps = translate(mode, - conditional.ifTrue, - permsPrimary, - permsSecondary, - context) - val falseOps = translate(mode, - conditional.ifFalse, - permsPrimary, - permsSecondary, - context) - val condition = context.convertExpression(conditional.condition) + val trueOps = translate(conditional.ifTrue, context, modes) + val falseOps = translate(conditional.ifFalse, context, modes) + val condition = context.convert(conditional.condition) (trueOps.isEmpty, falseOps.isEmpty) match { case (false, false) => val ifStmt = new IR.If(condition) @@ -169,254 +190,106 @@ class CheckImplementation( case (true, true) => Seq.empty - } - case binary: IR.Binary if binary.operator == IR.BinaryOp.And => - translate(mode, binary.left, permsPrimary, permsSecondary, context) ++ translate( - mode, - binary.right, - permsPrimary, - permsSecondary, - context + Seq.concat( + translate(binary.left, context, modes), + translate(binary.right, context, modes) ) - - case expr => - mode match { - case SeparationMode | AddMode | RemoveMode | AddRemoveMode => - Seq.empty - case VerifyMode | CheckAddMode | CheckAddRemoveMode => - val toAssert = context.convertExpression(expr) - Seq(new IR.Assert(toAssert, IR.AssertKind.Imperative)) - - } + case expr => { + val converted = context.convert(expr) + modes.flatMap(_.visitBool(runtime, converted)) + } } def translateFieldPermission( - mode: CheckMode, - member: IR.FieldMember, - permsPrimary: IR.Var, - permsSecondary: Option[IR.Var], - context: SpecificationContext - ): Seq[IR.Op] = { - val convertedMember = context.convertFieldMember(member) - val struct = program.structs.find(s => s.name == convertedMember.field.struct.name) match { - case Some(s) => s - case None => throw new WeaverException("struct not found for field member") + member: IR.FieldMember, + modes: List[CheckMode], + context: SpecificationContext + ): Seq[IR.Op] = { + val converted = context.convert(member) + val root = converted.root + val struct = converted.field.struct match { + case s: IR.Struct => + s + case s => + throw new WeaverException("Cannot access fields of struct " + s.name) } - val idFieldExists = struct.fields.exists(fld => { - fld.name == "_id" - }) - if (!idFieldExists) { - throw new WeaverException("Couldn't locate _id field") - } - val instanceId = - if (convertedMember.root.valueType.isDefined) { - mode match { - case SeparationMode | VerifyMode | CheckAddRemoveMode | - CheckAddMode => - new IR.Conditional( - new IR.Binary( - IR.BinaryOp.NotEqual, - convertedMember.root, - new IR.NullLit() - ), - new IR.FieldMember(convertedMember.root, structIdField(struct)), - new IR.IntLit(-1) - ) - case AddMode | RemoveMode | AddRemoveMode => - // If it's in add/remove, it doesn't need the null check - new IR.FieldMember(convertedMember.root, structIdField(struct)) - } - } else { - // If valueType is not defined, there is a NULL dereference in - // the expression, so we cannot compile it - mode match { - case SeparationMode | VerifyMode => + + // Get the expression to access the instance ID + val guarded = modes.exists(_.guarded) + val instanceId: IR.Expression = + root.valueType match { + case None => + // If valueType is not defined, there is a NULL dereference in + // the expression, so we cannot compile it + if (guarded) new IR.IntLit(-1) - case AddMode | RemoveMode | AddRemoveMode | CheckAddRemoveMode | - CheckAddMode => + else throw new WeaverException("Invalid NULL dereference") + case Some(_) => { + val id = new IR.FieldMember(root, structIdField(struct)) + if (guarded) { + // Convert to `root == null ? -1 : root->_id` + val nullCheck = new IR.Binary(IR.BinaryOp.Equal, root, new IR.NullLit()) + new IR.Conditional(nullCheck, new IR.IntLit(-1), id) + } else { + id + } } } - - val fname = convertedMember.field.name - val fvtype = convertedMember.field.valueType - val fieldIndex = new IR.IntLit(struct.fields.indexWhere(f => f.name == fname && f.valueType == fvtype)) - val numFields = new IR.IntLit(struct.fields.length-1) - //TODO: add support for IRPrinter.printExpr here - val fullName = s"struct ${struct.name}.${convertedMember.field.name}" - - mode match { - case SeparationMode => - val error = - new IR.StringLit(s"Overlapping field permissions for $fullName") - Seq( - new IR.Invoke( - runtime.addAccEnsureSeparate, - List(permsPrimary, instanceId, fieldIndex, numFields, error), - None - ) - ) - case VerifyMode => - val error = - new IR.StringLit(s"Field access runtime check failed for $fullName") - Seq( - new IR.Invoke( - runtime.assertAcc, - List(permsPrimary, instanceId, fieldIndex, error), - None - ) - ) - case RemoveMode => - Seq( - new IR.Invoke( - runtime.loseAcc, - List(permsPrimary, instanceId, fieldIndex), - None - ) - ) - - // TODO: We need to also add permissions required for framing - case AddMode => - Seq( - new IR.Invoke( - runtime.addAcc, - List(permsPrimary, instanceId, numFields, fieldIndex), - None - ) - ) - case CheckAddRemoveMode => - val error = - new IR.StringLit(s"Field access runtime check failed for $fullName") - permsSecondary match { - case Some(secondary) => - Seq( - new IR.Invoke( - runtime.assertAcc, - List(secondary, instanceId, fieldIndex, error), - None - ), - new IR.Invoke( - runtime.addAcc, - List(permsPrimary, instanceId, numFields, fieldIndex), - None - ), - new IR.Invoke( - runtime.loseAcc, - List(secondary, instanceId, fieldIndex), - None - ) - ) - case None => - throw new WeaverException( - "Missing temporary OwnedFields struct reference for CheckAddRemove mode.") - } - case CheckAddMode => - val error = - new IR.StringLit(s"Field access runtime check failed for $fullName") - permsSecondary match { - case Some(secondary) => - Seq( - new IR.Invoke( - runtime.assertAcc, - List(secondary, instanceId, fieldIndex, error), - None - ), - new IR.Invoke( - runtime.addAcc, - List(permsPrimary, instanceId, numFields, fieldIndex), - None - ) - ) - case None => - throw new WeaverException( - "Missing temporary OwnedFields struct reference for CheckAdd mode.") - } - case AddRemoveMode => - permsSecondary match { - case Some(secondary) => - Seq( - new IR.Invoke( - runtime.addAcc, - List(permsPrimary, instanceId, numFields, fieldIndex), - None - ), - new IR.Invoke( - runtime.loseAcc, - List(secondary, instanceId, fieldIndex), - None - ) - ) - case None => - throw new WeaverException( - "Missing temporary OwnedFields struct reference for AddRemove mode.") - } - - } + + val fieldName = converted.field.name + val fieldType = converted.field.valueType + // TODO: Can we just use regular `indexOf`? + val fieldIndex = new IR.IntLit(struct.fields.indexWhere( + f => f.name == fieldName && f.valueType == fieldType)) + val numFields = new IR.IntLit(struct.fields.length - 1) + val expression = IRPrinter.print(converted) + + modes.flatMap(_.visitPerm(runtime, instanceId, fieldIndex, numFields, expression)) } def translatePredicateInstance( - mode: CheckMode, - pred: IR.PredicateInstance, - permsPrimary: IR.Var, - permsSecondary: Option[IR.Var], - context: SpecificationContext - ): Seq[IR.Op] = { - val arguments = pred.arguments.map(context.convertExpression) - - val toAppend = mode match { - - case AddRemoveMode | CheckAddRemoveMode | CheckAddMode => - permsSecondary match { - case Some(value) => List(permsPrimary, value) - case None => - throw new WeaverException( - "Missing secondary OwnedFields reference for optimized permission tracking mode.") - } - case _ => List(permsPrimary) - } - resolvePredicateDefinition(mode, pred.predicate) - .map(new IR.Invoke(_, arguments ++ toAppend, None)) + instance: IR.PredicateInstance, + modes: List[CheckMode], + context: SpecificationContext + ): Seq[IR.Op] = { + // Pass the predicate arguments followed by the permission sets + val arguments = + instance.arguments.map(context.convert) ::: modes.map(_.perms) + + resolvePredicateDefinition(modes, instance.predicate) + .map(new IR.Invoke(_, arguments, None)) .toSeq } - def trackAllocation(alloc: IR.AllocStruct, perms: IR.Var): Unit = { - val structType = program.structs.find(s => s.name == alloc.struct.name) match { - case Some(s) => s - case None => throw new WeaverException("struct def not found for struct alloc") + def trackAllocation(alloc: IR.AllocStruct, perms: IR.Expression): Seq[IR.Op] = { + val struct = alloc.struct match { + case s: IR.Struct => s + case _: IR.DependencyStruct => + throw new WeaverException("Cannot allocate library struct") } val idField = new IR.FieldMember( alloc.target, - structIdField(alloc.struct) + structIdField(struct) ) - alloc.insertAfter( - new IR.Invoke( - runtime.addStructAcc, - List(perms, new IR.IntLit(structType.fields.length-1)), - Some(idField) - ) - ) + new IR.Invoke( + runtime.addStruct, + List(perms, idField, new IR.IntLit(struct.fields.length-1)), + None + ) :: Nil } - def idAllocation(alloc: IR.AllocStruct, - instanceCounter: IR.Expression): Unit = { - val idField = new IR.FieldMember( - alloc.target, - structIdField(alloc.struct) - ) + def permsType: IR.Type = runtime.ownedFieldsRef + + def init(target: IR.Expression): Seq[IR.Op] = + Seq(new IR.Invoke(runtime.init, Nil, Some(target))) - alloc.insertAfter( - Seq( - new IR.AssignMember(idField, new IR.DereferenceMember(instanceCounter)), - new IR.AssignMember( - new IR.DereferenceMember(instanceCounter), - new IR.Binary(IR.BinaryOp.Add, - new IR.DereferenceMember(instanceCounter), - new IR.IntLit(1))) - )) + def join(target: IR.Expression, source: IR.Expression): Seq[IR.Op] = { + Seq(new IR.Invoke(runtime.join, List(target, source), None)) } } diff --git a/src/main/scala/gvc/weaver/CheckRuntime.scala b/src/main/scala/gvc/weaver/CheckRuntime.scala index 945e0b93..17065016 100644 --- a/src/main/scala/gvc/weaver/CheckRuntime.scala +++ b/src/main/scala/gvc/weaver/CheckRuntime.scala @@ -35,19 +35,14 @@ object CheckRuntime { val primaryOwnedFields = "_ownedFields" val temporaryOwnedFields = "_tempFields" val contextOwnedFields = "_contextFields" - val initOwnedFields = "initOwnedFields" - val addStructAcc = "addStructAcc" - val addAcc = "addAcc" - val loseAcc = "loseAcc" - val join = "join" - val assertAcc = "assertAcc" - val addAccEnsureSeparate = "addAccEnsureSeparate" - val find = "find" + val initOwnedFields = "runtime_init" + val addStruct = "runtime_addAll" + val remove = "runtime_remove" + val join = "runtime_join" + val assert = "runtime_assert" + val add = "runtime_add" val instanceCounter = "_instanceCounter" val id = "_id" - val removePrefix = "remove_" - val addPrefix = "add_" - val checkPrefix = "check_" } } @@ -56,17 +51,13 @@ class CheckRuntime private (program: IR.Program) { val ownedFields: IR.StructDefinition = program.struct(Names.ownedFieldsStruct) val ownedFieldsRef = new IR.ReferenceType(ownedFields) - val ownedFieldInstanceCounter: IR.StructField = - ownedFields.fields.find(_.name == "instanceCounter").get - val initOwnedFields: IR.MethodDefinition = + val init: IR.MethodDefinition = program.method(Names.initOwnedFields) - val addStructAcc: IR.MethodDefinition = - program.method(Names.addStructAcc) - val addAcc: IR.MethodDefinition = program.method(Names.addAcc) - val addAccEnsureSeparate: IR.MethodDefinition = - program.method(Names.addAccEnsureSeparate) - val loseAcc: IR.MethodDefinition = program.method(Names.loseAcc) + val addStruct: IR.MethodDefinition = + program.method(Names.addStruct) + val add: IR.MethodDefinition = + program.method(Names.add) + val remove: IR.MethodDefinition = program.method(Names.remove) val join: IR.MethodDefinition = program.method(Names.join) - val assertAcc: IR.MethodDefinition = program.method(Names.assertAcc) - val find: IR.MethodDefinition = program.method(Names.find) + val assert: IR.MethodDefinition = program.method(Names.assert) } diff --git a/src/main/scala/gvc/weaver/CheckScope.scala b/src/main/scala/gvc/weaver/CheckScope.scala new file mode 100644 index 00000000..85704c69 --- /dev/null +++ b/src/main/scala/gvc/weaver/CheckScope.scala @@ -0,0 +1,102 @@ +package gvc.weaver + +import scala.collection.mutable +import gvc.transformer.IR + +sealed trait CheckScope { + def block: IR.Block + def children: Seq[WhileScope] + def checks: Seq[RuntimeCheck] +} + +sealed trait MethodScope extends CheckScope { + def method: IR.Method + def block: IR.Block = method.body + def conditions: Iterable[TrackedCondition] +} + +sealed trait WhileScope extends CheckScope { + def op: IR.While + def block: IR.Block = op.body +} + +class ProgramScope( + val program: IR.Program, + val methods: Map[String, MethodScope] +) + +object CheckScope { + private sealed abstract class CheckScopeImplementation extends CheckScope { + val children = mutable.ListBuffer[WhileScope]() + val checks = mutable.ArrayBuffer[RuntimeCheck]() + } + + private sealed class MethodScopeImplementation( + val method: IR.Method, + val conditions: Iterable[TrackedCondition] + ) extends CheckScopeImplementation with MethodScope + private sealed class WhileScopeImplementation(val op: IR.While) + extends CheckScopeImplementation with WhileScope + + def scope(collected: Collector.CollectedProgram): ProgramScope = + new ProgramScope( + collected.program, + collected.methods.map({ case(k, cm) => + (k, scope(cm.checks, cm.conditions, cm.method)) }) + ) + + def scope( + checks: Seq[RuntimeCheck], + conditions: Iterable[TrackedCondition], + method: IR.Method + ): MethodScope = { + val outer = new MethodScopeImplementation(method, conditions) + val inner = mutable.HashMap[IR.While, WhileScopeImplementation]() + + // Create and index all the child scopes + def initBlock(block: IR.Block, scope: CheckScopeImplementation): Unit = + block.foreach(init(_, scope)) + def init(op: IR.Op, scope: CheckScopeImplementation): Unit = + op match { + case w: IR.While => { + val child = new WhileScopeImplementation(w) + scope.children += child + inner += w -> child + + initBlock(w.body, child) + } + case i: IR.If => initBlock(i.ifTrue, scope); initBlock(i.ifFalse, scope) + case _ => () + } + + initBlock(method.body, outer) + + def getScope(op: IR.Op): CheckScopeImplementation = { + if (op.block == method.body) { + outer + } else { + op.block match { + case c: IR.ChildBlock => c.op match { + case cond: IR.If => getScope(cond) + case loop: IR.While => + inner.getOrElse(loop, throw new WeaverException("Missing inner scope")) + case _ => throw new WeaverException("Invalid IR structure") + } + + case _ => throw new WeaverException("Invalid IR structure") + } + } + } + + for (c <- checks) { + val scope = c.location match { + case at: AtOp => getScope(at.op) + case MethodPre | MethodPost => outer + } + + scope.checks += c + } + + outer + } +} \ No newline at end of file diff --git a/src/main/scala/gvc/weaver/Checker.scala b/src/main/scala/gvc/weaver/Checker.scala index fa01901b..c4fc3007 100644 --- a/src/main/scala/gvc/weaver/Checker.scala +++ b/src/main/scala/gvc/weaver/Checker.scala @@ -1,14 +1,19 @@ package gvc.weaver import gvc.transformer.IR -import Collector._ import scala.collection.mutable import scala.annotation.tailrec +import CheckRuntime.Names object Checker { + sealed trait CallStyle + case object PermissionsOptional extends CallStyle + case object PermissionsRequired extends CallStyle + case object PermissionsElided extends CallStyle + type StructIDTracker = Map[String, IR.StructField] - def insert(program: Collector.CollectedProgram): Unit = { + def insert(program: ProgramDependencies): Unit = { val runtime = CheckRuntime.addToIR(program.program) // Add the _id field to each struct @@ -24,374 +29,406 @@ object Checker { program.methods.values.foreach { method => insert(program, method, runtime, implementation) } + + // Use the methods from the ProgramDependencies object, to avoid injecting + // instance counter into predicate implementations + InstanceCounter.inject( + program.methods.values.map(_.method).toSeq, + structIdFields) } - private def insert( - programData: CollectedProgram, - methodData: CollectedMethod, - runtime: CheckRuntime, - implementation: CheckImplementation - ): Unit = { - val program = programData.program - val method = methodData.method + // Assumes that there is a single return statement at the end of the method. + // This assumption is guaranteed during transformation by the + // ReturnSimplification pass. + private def insertBeforeReturn( + block: IR.Block, + ops: Option[IR.Expression] => Seq[IR.Op] + ) : Unit = block.lastOption match { + case Some(ret: IR.Return) => ret.insertBefore(ops(ret.value)) + case _ => block ++= ops(None) + } - val callsImprecise: Boolean = methodData.calls.exists(c => - programData.methods.get(c.ir.callee.name) match { - case Some(value) => value.callStyle != PreciseCallStyle - case None => false - }) - - // `ops` is a function that generates the operations, given the current return value at that - // position. DO NOT construct the ops before passing them to this method since multiple copies - // may be required. - def insertAt(at: Location, ops: Option[IR.Expression] => Seq[IR.Op]): Unit = - at match { - case LoopStart(op: IR.While) => ops(None) ++=: op.body - case LoopEnd(op: IR.While) => op.body ++= ops(None) - case Pre(op) => op.insertBefore(ops(None)) - case Post(op) => op.insertAfter(ops(None)) - case MethodPre => ops(None) ++=: method.body - case MethodPost => - methodData.returns.foreach(e => e.insertBefore(ops(e.value))) - if (methodData.hasImplicitReturn) { - method.body ++= ops(None) - } - case _ => throw new WeaverException(s"Invalid location '$at'") + private def insertAt( + at: Location, + method: IR.Method, + ops: Option[IR.Expression] => Seq[IR.Op] + ): Unit = at match { + case LoopStart(op: IR.While) => ops(None) ++=: op.body + case LoopEnd(op: IR.While) => op.body ++= ops(None) + case Pre(op) => op.insertBefore(ops(None)) + case Post(op) => op.insertAfter(ops(None)) + case MethodPre => ops(None) ++=: method.body + case MethodPost => insertBeforeReturn(method.body, ops) + case _ => throw new WeaverException(s"Invalid location '$at'") + } + + trait PermissionScope { + def requirePermissions: IR.Expression + def optionalPermissions(generate: IR.Expression => Seq[IR.Op]): Seq[IR.Op] + def optionalPermissions: IR.Expression + def trackingPermissions: Boolean + } + + class RequiredPermissions(permissions: IR.Expression) extends PermissionScope { + def requirePermissions = permissions + def optionalPermissions(generate: IR.Expression => Seq[IR.Op]): Seq[IR.Op] = + generate(permissions) + def optionalPermissions: IR.Expression = permissions + def trackingPermissions = true + } + + class OptionalPermissions(permissions: IR.Expression) extends PermissionScope { + def requirePermissions: IR.Expression = + throw new WeaverException("Required permissions inside optional permission scope") + def optionalPermissions(generate: IR.Expression => Seq[IR.Op]): Seq[IR.Op] = { + val cond = new IR.If( + new IR.Binary(IR.BinaryOp.NotEqual, permissions, new IR.NullLit())) + cond.ifTrue ++= generate(permissions) + List(cond) + } + def optionalPermissions: IR.Expression = permissions + + def trackingPermissions: Boolean = false + } + + object NoPermissions extends PermissionScope { + def requirePermissions: IR.Expression = + throw new WeaverException("No permission object available") + def optionalPermissions(generate: IR.Expression => Seq[IR.Op]): Seq[IR.Op] = + Seq.empty + def optionalPermissions: IR.Expression = new IR.NullLit() + def trackingPermissions: Boolean = false + } + + private def addPermissionsVar(method: IR.Method, impl: CheckImplementation): IR.Var = { + method.addVar(impl.runtime.ownedFieldsRef, Names.primaryOwnedFields) + } + + private def addPermissionsParam(method: IR.Method, impl: CheckImplementation): IR.Var = + method.addParameter(impl.runtime.ownedFieldsRef, Names.primaryOwnedFields) + + private def initPermissions(perms: IR.Var, impl: CheckImplementation): IR.Op = { + new IR.Invoke(impl.runtime.init, Nil, Some(perms)) + } + + private def addMethodPerms(method: IR.Method, impl: CheckImplementation): IR.Var = + method.name match { + case "main" => { + val perms = addPermissionsVar(method, impl) + initPermissions(perms, impl) +=: method.body + perms } - var nextConditionalId = 1 - val conditionVars = methodData.conditions.map { c => - val flag = method.addVar(IR.BoolType, s"_cond_$nextConditionalId") - nextConditionalId += 1 - c -> flag - }.toMap + case _ => addPermissionsParam(method, impl) + } + + private def getCallStyle(dep: MethodDependencies): CallStyle = { + if (dep.returnsPerms) { + if (dep.requiresPerms) PermissionsRequired + else if (dep.modifiesPerms && dep.method.name != "main") PermissionsOptional + else PermissionsElided + } else { + PermissionsElided + } + } - def foldConditionList(conds: List[Condition], - op: IR.BinaryOp): IR.Expression = { - conds - .foldLeft[Option[IR.Expression]](None) { - case (Some(expr), cond) => - Some(new IR.Binary(op, expr, getCondition(cond))) - case (None, cond) => Some(getCondition(cond)) + private def getPermissions( + dep: ScopeDependencies, + impl: CheckImplementation, + parent: Option[PermissionScope] = None + ): PermissionScope = { + dep match { + case m: MethodDependencies if (m.returnsPerms && m.requiresPerms) => + new RequiredPermissions(addMethodPerms(m.method, impl)) + case m: MethodDependencies if (m.returnsPerms && m.modifiesPerms) => + new OptionalPermissions(addMethodPerms(m.method, impl)) + case w: WhileDependencies if w.returnsPerms => + parent.getOrElse(throw new WeaverException("Parent permissions required")) + case dep if dep.requiresPerms => { + // Requires perms but does not return (or inherit) them + // Need to reconstruct from the spec + val variable = addPermissionsVar(dep.block.method, impl) + val spec = dep match { + case m: MethodDependencies => m.method.precondition + case w: WhileDependencies => Some(w.op.invariant) } - .getOrElse(throw new WeaverException("Invalid empty condition list")) + spec.foreach(p => + impl.translate(p, ValueContext, List(AddMode(variable, guarded=false))) ++=: dep.block) + new IR.Invoke(impl.runtime.init, Nil, Some(variable)) +=: dep.block + new RequiredPermissions(variable) + } + case _ => NoPermissions } + } - def getCondition(cond: Condition): IR.Expression = cond match { - case ImmediateCondition(expr) => expr.toIR(program, method, None) - case cond: TrackedCondition => conditionVars(cond) - case NotCondition(value) => - new IR.Unary(IR.UnaryOp.Not, getCondition(value)) - case AndCondition(values) => foldConditionList(values, IR.BinaryOp.And) - case OrCondition(values) => foldConditionList(values, IR.BinaryOp.Or) - } + private def insert( + programData: ProgramDependencies, + methodData: MethodDependencies, + runtime: CheckRuntime, + impl: CheckImplementation + ): Unit = { + val method = methodData.method + + // Create the permissions scope + // Adds a parameter to receive OwnedFields, if necessary + val permissions = getPermissions(methodData, impl, None) - val initializeOps = mutable.ListBuffer[IR.Op]() + val conditions = methodData.conditions.map(c => + c -> method.addVar(IR.BoolType, "_cond")).toMap - def methodContainsImprecision(methodData: CollectedMethod): Boolean = { - val contractImprecise = methodData.callStyle match { - case ImpreciseCallStyle | PrecisePreCallStyle => true - case _ => false + val context = CheckContext( + program = programData.program, + method = method, + conditions = conditions, + permissions = permissions, + implementation = impl, + runtime = runtime) + + insert(programData, methodData, context) + + // Add all conditions that need tracked + // Group all conditions for a single location and insert in sequence + // to preserve the correct ordering of conditions. + methodData.conditions + .groupBy(_.location) + .foreach { + case (loc, conds) => + insertAt(loc, method, retVal => { + val instrs = mutable.ListBuffer[IR.Op]() + conds.foreach( + c => + instrs += new IR.Assign(conditions(c), + c.value.toIR(programData.program, method, retVal))) + instrs + }) } + } - contractImprecise || - methodData.bodyContainsImprecision || - methodData.calls.exists( - c => - c.ir.callee.isInstanceOf[IR.Method] && - (programData.methods(c.ir.callee.name).callStyle match { - case ImpreciseCallStyle | PrecisePreCallStyle => true - case _ => false - }) - ) + // Creates a temporary set of permissions, + private def useTempPermissions(call: IR.Invoke, perms: PermissionScope, context: CheckContext) = { + // Need to create temporary set and merge after + val impl = context.implementation + val runtime = context.runtime + val tempPerms = context.method.addVar( + runtime.ownedFieldsRef, Names.temporaryOwnedFields) + + call.insertBefore( + new IR.Invoke(runtime.init, Nil, Some(tempPerms))) + + val pre = call.callee match { + case m: IR.Method => m.precondition + case _: IR.DependencyMethod => + throw new WeaverException("Attempting to add permissions to library method") } - var (primaryOwnedFields, instanceCounter) = methodData.callStyle match { - case MainCallStyle => { - val instanceCounter = - method.addVar( - new IR.PointerType(IR.IntType), - CheckRuntime.Names.instanceCounter - ) - initializeOps += new IR.AllocValue(IR.IntType, instanceCounter) - (None, instanceCounter) - } + call.arguments :+= tempPerms + + val specContext = new CallSiteContext(call) + perms match { + case NoPermissions => { + pre.foreach(pre => + call.insertBefore(impl.translate( + pre, + specContext, + List(AddMode(tempPerms, guarded=false)) + )) + ) - case PreciseCallStyle => { - if (methodContainsImprecision(methodData)) { - val ownedFields: IR.Var = method.addParameter( - runtime.ownedFieldsRef, - CheckRuntime.Names.primaryOwnedFields) - val instanceCounter = - new IR.FieldMember(ownedFields, runtime.ownedFieldInstanceCounter) - (Some(ownedFields), instanceCounter) - } else { - val instanceCounter = - method.addParameter( - new IR.PointerType(IR.IntType), - CheckRuntime.Names.instanceCounter - ) - (None, instanceCounter) - } + // No permission tracking, so no need to join } - case ImpreciseCallStyle | PrecisePreCallStyle => { - val ownedFields: IR.Var = - method.addParameter( - runtime.ownedFieldsRef, - CheckRuntime.Names.primaryOwnedFields + case perms: OptionalPermissions => { + pre.foreach(pre => { + val permsVal = perms.optionalPermissions + val cond = new IR.If( + new IR.Binary(IR.BinaryOp.Equal, permsVal, new IR.NullLit())) + // Use AddMode if perms have not been passed + cond.ifTrue ++= impl.translate( + pre, + specContext, + List(AddMode(tempPerms)) ) - val instanceCounter = - new IR.FieldMember(ownedFields, runtime.ownedFieldInstanceCounter) - (Some(ownedFields), instanceCounter) + // Use AddRemoveMode if perms have been passed + cond.ifFalse ++= impl.translate( + pre, + specContext, + List(AddMode(tempPerms), RemoveMode(permsVal)) + ) + call.insertBefore(cond) + }) + + // Join if perms have been passed + call.insertAfter(perms.optionalPermissions( + permsVal => impl.join(permsVal, tempPerms))) } - } - def getPrimaryOwnedFields(): IR.Var = primaryOwnedFields.getOrElse { - val ownedFields = method.addVar( - runtime.ownedFieldsRef, - CheckRuntime.Names.primaryOwnedFields - ) - primaryOwnedFields = Some(ownedFields) + case perms: RequiredPermissions => { + // Permissions are always required, so always use AddRemoveMode and + // always join after + pre.foreach(pre => + call.insertBefore(impl.translate( + pre, + specContext, + List(AddMode(tempPerms), RemoveMode(perms.requirePermissions)) + )) + ) - initializeOps += new IR.Invoke( - runtime.initOwnedFields, - List(instanceCounter), - primaryOwnedFields - ) - ownedFields + call.insertAfter(impl.join(perms.requirePermissions, tempPerms)) + } } + } + + // Adds permissions to a method + private def addPermissions(call: IR.Invoke, context: CheckContext, program: ProgramDependencies) = { + val dep = program.methods(call.callee.name) + getCallStyle(dep) match { + case PermissionsRequired if dep.inheritsPerms => + call.arguments :+= context.permissions.requirePermissions + case PermissionsRequired => + // Permissions are returned, but not inherited (i.e., precise pre, + // imprecise post) + useTempPermissions(call, context.permissions, context) + case PermissionsOptional => + // Since permissions are optional, the presence of permissions is never + // checked, so it doesn't matter that we send more permissions than are + // required. (Thus we don't need to special-case precise pre, imprecise + // post.) + call.arguments :+= context.permissions.optionalPermissions + case PermissionsElided if dep.returnsPerms => + // Returns permissions dynamically, but they are elided since it does + // not modify (or check) permissions + () + case PermissionsElided => { + // Elided because the method is static + + // Remove permissions in the pre-condition before calling + dep.method.precondition.foreach(pre => { + call.insertBefore(context.permissions.optionalPermissions( + perms => context.implementation.translate( + pre, + new CallSiteContext(call), + List(RemoveMode(perms)) + ) + )) + }) + + // Add permissions in the post-condition after the call + dep.method.postcondition.foreach(post => { + call.insertAfter(context.permissions.optionalPermissions( + perms => context.implementation.translate( + post, + new CallSiteContext(call), + List(AddMode(perms)) + ) + )) + }) + } + } + } + + private def insert( + programData: ProgramDependencies, + scope: ScopeDependencies, + context: CheckContext + ): Unit = { + val program = programData.program // Insert the runtime checks // Group them by location and condition, so that multiple checks can be contained in a single // if block. - val context = CheckContext(program, method, implementation, runtime) - for ((loc, checkData) <- groupChecks(methodData.checks)) { - insertAt( - loc, - retVal => { - val ops = mutable.ListBuffer[IR.Op]() - - // Create a temporary owned fields instance when it is required - var temporaryOwnedFields: Option[IR.Var] = None - - def getTemporaryOwnedFields(): IR.Var = - temporaryOwnedFields.getOrElse { - val tempVar = context.method.addVar( - context.runtime.ownedFieldsRef, - CheckRuntime.Names.temporaryOwnedFields - ) - temporaryOwnedFields = Some(tempVar) - tempVar - } - - for ((cond, checks) <- checkData) { - val condition = cond.map(getCondition(_)) - ops ++= implementChecks( - condition, - checks.map(_.check), - retVal, - getPrimaryOwnedFields, - getTemporaryOwnedFields, - instanceCounter, - context + for ((loc, checkData) <- groupChecks(scope.checks)) { + insertAt(loc, context.method, retVal => { + val ops = mutable.ListBuffer[IR.Op]() + + // Create a temporary owned fields instance when it is required + var temporaryOwnedFields: Option[IR.Var] = None + + def getTemporaryOwnedFields(): IR.Var = + temporaryOwnedFields.getOrElse { + val tempVar = context.method.addVar( + context.runtime.ownedFieldsRef, + CheckRuntime.Names.temporaryOwnedFields ) + temporaryOwnedFields = Some(tempVar) + tempVar } - // Prepend op to initialize owned fields if it is required - temporaryOwnedFields.foreach { tempOwned => - new IR.Invoke( - context.runtime.initOwnedFields, - List(instanceCounter), - Some(tempOwned) - ) +=: ops - } + for ((cond, checks) <- checkData) { + val condition = cond.map(context.getCondition(_)) + ops ++= implementChecks( + condition, + checks.map(_.check), + retVal, + getTemporaryOwnedFields, + context + ) + } - ops + // Prepend op to initialize owned fields if it is required + temporaryOwnedFields.foreach { tempOwned => + new IR.Invoke( + context.runtime.init, + Nil, + Some(tempOwned) + ) +=: ops } - ) - } - val needsToTrackPrecisePerms = methodContainsImprecision(methodData) - - if (needsToTrackPrecisePerms && methodData.callStyle == PreciseCallStyle) { - primaryOwnedFields match { - case Some(_) => - initializeOps ++= methodData.method.precondition.toSeq.flatMap( - implementation.translate( - AddMode, - _, - getPrimaryOwnedFields, - None, - ValueContext - ) - ) - case None => - } + ops + }) } - // Update the call sites to add any required parameters - for (call <- methodData.calls) { - call.ir.callee match { + + // Update the call sites to add permission tracking/passing. + // It is important that this gets done after checks are inserted so that + // the permission handling code binds closer to the call sites than checks. + // For example, checks required for a callee's pre-condition must be done + // before the permissions in the callee's pre-condition are removed. + for (call <- scope.calls) { + call.callee match { + // No parameters can be added to a main method or library methods case _: IR.DependencyMethod => () - case callee: IR.Method => - val calleeData = programData.methods(callee.name) - calleeData.callStyle match { - // No parameters can be added to a main method - case MainCallStyle => () - - // Imprecise methods always get the primary owned fields instance directly - case ImpreciseCallStyle => - call.ir.arguments :+= getPrimaryOwnedFields() - - case PreciseCallStyle => { - val context = new CallSiteContext(call.ir, method) - - if (methodContainsImprecision(calleeData)) { - val tempSet = method.addVar( - runtime.ownedFieldsRef, - CheckRuntime.Names.temporaryOwnedFields - ) - call.ir.arguments :+= tempSet - - val initTemp = new IR.Invoke( - runtime.initOwnedFields, - List(instanceCounter), - Some(tempSet) - ) - - call.ir.insertBefore(initTemp) - if (needsToTrackPrecisePerms) { - call.ir.insertBefore( - callee.precondition.toSeq - .flatMap( - implementation - .translate(AddRemoveMode, - _, - tempSet, - Some(getPrimaryOwnedFields()), - context) - ) - .toList - ) - } - } else { - call.ir.arguments :+= instanceCounter - if (needsToTrackPrecisePerms) { - val removePermsPrior = callee.precondition.toSeq - .flatMap( - implementation - .translate(RemoveMode, - _, - getPrimaryOwnedFields(), - None, - context) - ) - .toList - call.ir.insertBefore(removePermsPrior) - } - } - if (needsToTrackPrecisePerms) { - val addPermsAfter = callee.postcondition.toSeq - .flatMap( - implementation - .translate(AddMode, - _, - getPrimaryOwnedFields(), - None, - context) - ) - .toList - call.ir.insertAfter(addPermsAfter) - } - } - // For precise-pre/imprecise-post, create a temporary set of permissions, add the - // permissions from the precondition, call the method, and add the temporary set to the - // primary set - case PrecisePreCallStyle => { - val tempSet = method.addVar( - runtime.ownedFieldsRef, - CheckRuntime.Names.temporaryOwnedFields - ) - - val createTemp = new IR.Invoke( - runtime.initOwnedFields, - List(instanceCounter), - Some(tempSet) - ) - - val context = new CallSiteContext(call.ir, method) - - val resolvePermissions = callee.precondition.toSeq - .flatMap( - implementation.translate(AddRemoveMode, - _, - tempSet, - Some(getPrimaryOwnedFields), - context) - ) - .toList - call.ir.insertBefore( - createTemp :: resolvePermissions - ) - call.ir.arguments :+= tempSet - call.ir.insertAfter( - new IR.Invoke( - runtime.join, - List(getPrimaryOwnedFields, tempSet), - None - ) - ) - } - } + case m if m.name == "main" => () + case _: IR.Method => addPermissions(call, context, programData) } } // If a primary owned fields instance is required for this method, add all allocations into it - addAllocationTracking( - primaryOwnedFields, - instanceCounter, - methodData.allocations, - implementation, - runtime - ) + for (alloc <- scope.allocations) { + addAllocationTracking(alloc, context) + } - // Add all conditions that need tracked - // Group all conditions for a single location and insert in sequence - // to preserve the correct ordering of conditions. - methodData.conditions - .groupBy(_.location) - .foreach { - case (loc, conds) => - insertAt(loc, retVal => { - conds.map( - c => - new IR.Assign(conditionVars(c), - c.value.toIR(program, method, retVal))) - }) + for (child <- scope.children) { + val perms = getPermissions( + child, + context.implementation, + Some(context.permissions)) + insert(programData, child, context.copy(permissions = perms)) + + if (!child.returnsPerms && child.modifiesPerms) { + // While loop that maintains its perms internally but may modify the + // outer permissions. Insert code before and after to remove and + // then add (respectively) the permissions to/from the outer scope, + // if the outer scope is tracking permissions. + val op = child.op + op.insertBefore(context.permissions.optionalPermissions(perms => { + context.implementation.translate( + op.invariant, ValueContext, RemoveMode(perms) :: Nil) + })) + op.insertAfter(context.permissions.optionalPermissions(perms => { + context.implementation.translate( + op.invariant, ValueContext, AddMode(perms) :: Nil) + })) } - - // Finally, add all the initialization ops to the beginning - initializeOps ++=: method.body + } } def addAllocationTracking( - primaryOwnedFields: Option[IR.Var], - instanceCounter: IR.Expression, - allocations: List[IR.Op], - implementation: CheckImplementation, - runtime: CheckRuntime + alloc: IR.AllocStruct, + context: CheckContext ): Unit = { - for (alloc <- allocations) { - alloc match { - case alloc: IR.AllocStruct => - primaryOwnedFields match { - case Some(primary) => implementation.trackAllocation(alloc, primary) - case None => implementation.idAllocation(alloc, instanceCounter) - } - case _ => - throw new WeaverException( - "Tracking is only currently supported for struct allocations." - ) + context.permissions match { + case NoPermissions => () + case _ => { + alloc.insertAfter(context.implementation.trackAllocation(alloc, context.permissions.optionalPermissions)) } } } @@ -403,17 +440,15 @@ object Checker { context: CheckContext ): Seq[IR.Op] = { val field = check.field.toIR(context.program, context.method, returnValue) - val (mode, perms) = check match { + val mode = check match { case _: FieldSeparationCheck => - (SeparationMode, fields.temporaryOwnedFields()) + AddMode(fields.temporaryOwnedFields(), guarded=true) case _: FieldAccessibilityCheck => - (VerifyMode, fields.primaryOwnedFields()) + AssertMode(fields.primaryOwnedFields()) } - context.implementation.translateFieldPermission(mode, - field, - perms, - None, - ValueContext) + + val impl = context.implementation + impl.translateFieldPermission(field, List(mode), ValueContext) } def implementPredicateCheck( @@ -426,35 +461,55 @@ object Checker { context.program.predicate(check.predicateName), check.arguments.map(_.toIR(context.program, context.method, returnValue)) ) - val (mode, perms) = check match { + val mode = check match { case _: PredicateSeparationCheck => - (SeparationMode, fields.temporaryOwnedFields()) + AddMode(fields.temporaryOwnedFields(), guarded=true) case _: PredicateAccessibilityCheck => - (VerifyMode, fields.primaryOwnedFields()) + AssertMode(fields.primaryOwnedFields()) } - context.implementation.translatePredicateInstance(mode, - instance, - perms, - None, - ValueContext) + + val impl = context.implementation + impl.translatePredicateInstance(instance, List(mode), ValueContext) } case class FieldCollection( - primaryOwnedFields: () => IR.Var, - temporaryOwnedFields: () => IR.Var + primaryOwnedFields: () => IR.Expression, + temporaryOwnedFields: () => IR.Expression ) case class CheckContext( program: IR.Program, method: IR.Method, + conditions: Map[TrackedCondition, IR.Var], + permissions: PermissionScope, implementation: CheckImplementation, runtime: CheckRuntime - ) + ) { + private def foldConditionList(conds: List[Condition], + op: IR.BinaryOp): IR.Expression = { + conds + .foldLeft[Option[IR.Expression]](None) { + case (Some(expr), cond) => + Some(new IR.Binary(op, expr, getCondition(cond))) + case (None, cond) => Some(getCondition(cond)) + } + .getOrElse(throw new WeaverException("Invalid empty condition list")) + } + + def getCondition(cond: Condition): IR.Expression = cond match { + case ImmediateCondition(expr) => expr.toIR(program, method, None) + case cond: TrackedCondition => conditions(cond) + case NotCondition(value) => + new IR.Unary(IR.UnaryOp.Not, getCondition(value)) + case AndCondition(values) => foldConditionList(values, IR.BinaryOp.And) + case OrCondition(values) => foldConditionList(values, IR.BinaryOp.Or) + } + } def implementCheck( check: Check, returnValue: Option[IR.Expression], - fields: FieldCollection, + getTemporaryOwnedFields: () => IR.Expression, context: CheckContext ): Seq[IR.Op] = { check match { @@ -462,14 +517,18 @@ object Checker { implementAccCheck( acc, returnValue, - fields, + FieldCollection( + () => context.permissions.requirePermissions, + getTemporaryOwnedFields), context ) case pc: PredicatePermissionCheck => implementPredicateCheck( pc, returnValue, - fields, + FieldCollection( + () => context.permissions.requirePermissions, + getTemporaryOwnedFields), context ) case expr: CheckExpression => @@ -486,9 +545,7 @@ object Checker { cond: Option[IR.Expression], checks: List[Check], returnValue: Option[IR.Expression], - getPrimaryOwnedFields: () => IR.Var, getTemporaryOwnedFields: () => IR.Var, - instanceCounter: IR.Expression, context: CheckContext ): Seq[IR.Op] = { // Collect all the ops for the check @@ -497,7 +554,7 @@ object Checker { implementCheck( _, returnValue, - FieldCollection(getPrimaryOwnedFields, getTemporaryOwnedFields), + getTemporaryOwnedFields, context ) ) @@ -513,9 +570,10 @@ object Checker { } } - def groupChecks(items: List[RuntimeCheck]) + def groupChecks(items: Seq[RuntimeCheck]) : List[(Location, List[(Option[Condition], List[RuntimeCheck])])] = { items + .toList .groupBy(_.location) .toList .map { diff --git a/src/main/scala/gvc/weaver/Checks.scala b/src/main/scala/gvc/weaver/Checks.scala index 1a37c362..d685126c 100644 --- a/src/main/scala/gvc/weaver/Checks.scala +++ b/src/main/scala/gvc/weaver/Checks.scala @@ -5,14 +5,10 @@ import gvc.transformer.{IR, IRSilver} sealed trait Check object Check { - def fromViper( - check: vpr.Exp, - program: IR.Program, - method: IR.Method - ): Check = { + def fromViper(check: vpr.Exp): Check = { check match { case fieldAccess: vpr.FieldAccessPredicate => - CheckExpression.fromViper(fieldAccess.loc, method) match { + CheckExpression.fromViper(fieldAccess.loc) match { case field: CheckExpression.Field => FieldAccessibilityCheck(field) case _ => throw new WeaverException( @@ -24,15 +20,15 @@ object Check { PredicateAccessibilityCheck( predicate.predicateName, predicate.args - .map(CheckExpression.fromViper(_, method)) + .map(CheckExpression.fromViper) .toList ) case predicateAccess: vpr.PredicateAccessPredicate => - Check.fromViper(predicateAccess.loc, program, method) + Check.fromViper(predicateAccess.loc) case _ => - CheckExpression.fromViper(check, method) + CheckExpression.fromViper(check) } } } @@ -53,19 +49,39 @@ sealed trait PredicatePermissionCheck extends PermissionCheck { case class FieldSeparationCheck(field: CheckExpression.Field) extends FieldPermissionCheck with SeparationCheck +{ + override def toString(): String = s"sep($field)" +} + case class FieldAccessibilityCheck(field: CheckExpression.Field) extends FieldPermissionCheck with AccessibilityCheck +{ + override def toString() = s"acc($field)" +} + case class PredicateSeparationCheck( predicateName: String, arguments: List[CheckExpression] ) extends PredicatePermissionCheck with SeparationCheck +{ + override def toString() = { + val args = arguments.map(_.toString()).mkString(", ") + s"sep($predicateName($args))" + } +} case class PredicateAccessibilityCheck( predicateName: String, arguments: List[CheckExpression] ) extends PredicatePermissionCheck with AccessibilityCheck +{ + override def toString() = { + val args = arguments.map(_.toString()).mkString(", ") + s"$predicateName($args)" + } +} sealed trait CheckExpression extends Check { def toIR( @@ -104,6 +120,8 @@ object CheckExpression { new IR.Binary(op, left.toIR(p, m, r), right.toIR(p, m, r)) def guard = and(left.guard, right.guard) + + override def toString() = s"($left) $op ($right)" } case class And(left: Expr, right: Expr) extends Binary { @@ -154,6 +172,8 @@ object CheckExpression { ): IR.Unary = new IR.Unary(op, operand.toIR(p, m, r)) def guard = operand.guard + + override def toString() = s"$op($operand)" } case class Not(operand: Expr) extends Unary { def op = IR.UnaryOp.Not @@ -167,6 +187,7 @@ object CheckExpression { m.variable(name) } def guard = None + override def toString() = name } case class Field(root: Expr, structName: String, fieldName: String) @@ -183,12 +204,16 @@ object CheckExpression { new IR.FieldMember(root.toIR(p, m, r), getIRField(p)) def guard = Some(and(root.guard, Not(Eq(root, NullLit)))) + + override def toString() = s"$root.$fieldName" } case class Deref(operand: Expr) extends Expr { def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) = new IR.DereferenceMember(operand.toIR(p, m, r)) def guard = Some(and(operand.guard, Not(Eq(operand, NullLit)))) + + override def toString() = s"*($operand)" } sealed trait Literal extends Expr { @@ -198,23 +223,28 @@ object CheckExpression { case class IntLit(value: Int) extends Literal { def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) = new IR.IntLit(value) + override def toString() = value.toString() } case class CharLit(value: Char) extends Literal { def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) = new IR.CharLit(value) + override def toString() = s"'$value'" } case class StrLit(value: String) extends Literal { def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) = new IR.StringLit(value) + override def toString() = "\"" + value + "\"" } case object NullLit extends Literal { def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) = new IR.NullLit() + override def toString() = "NULL" } sealed trait BoolLit extends Literal { def value: Boolean def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) = new IR.BoolLit(value) + override def toString() = value.toString() } object BoolLit { def apply(value: Boolean): BoolLit = if (value) TrueLit else FalseLit @@ -249,6 +279,8 @@ object CheckExpression { // Either the true path is taken, or the false guard is satisifed case (None, Some(fg)) => Some(Or(cond, fg)) }) + + override def toString() = s"($cond) ? ($ifTrue) : ($ifFalse)" } case object Result extends Expr { @@ -261,6 +293,7 @@ object CheckExpression { throw new WeaverException("Invalid \\result expression") ) def guard = None + override def toString() = "\\result" } def irValue(value: IR.Expression): Expr = { @@ -307,35 +340,33 @@ object CheckExpression { } def fromViper( - value: vpr.Exp, - method: IR.Method + value: vpr.Exp ): Expr = { - def expr(e: vpr.Exp) = fromViper(e, method) value match { - case eq: vpr.EqCmp => Eq(expr(eq.left), expr(eq.right)) - case ne: vpr.NeCmp => Not(Eq(expr(ne.left), expr(ne.right))) - case lt: vpr.LtCmp => Lt(expr(lt.left), expr(lt.right)) - case lte: vpr.LeCmp => LtEq(expr(lte.left), expr(lte.right)) - case gt: vpr.GtCmp => Gt(expr(gt.left), expr(gt.right)) - case gte: vpr.GeCmp => GtEq(expr(gte.left), expr(gte.right)) - - case and: vpr.And => And(expr(and.left), expr(and.right)) - case or: vpr.Or => Or(expr(or.left), expr(or.right)) - - case add: vpr.Add => Add(expr(add.left), expr(add.right)) - case sub: vpr.Sub => Sub(expr(sub.left), expr(sub.right)) - case mul: vpr.Mul => Mul(expr(mul.left), expr(mul.right)) - case div: vpr.Div => Div(expr(div.left), expr(div.right)) - - case minus: vpr.Minus => Neg(expr(minus.exp)) + case eq: vpr.EqCmp => Eq(fromViper(eq.left), fromViper(eq.right)) + case ne: vpr.NeCmp => Not(Eq(fromViper(ne.left), fromViper(ne.right))) + case lt: vpr.LtCmp => Lt(fromViper(lt.left), fromViper(lt.right)) + case lte: vpr.LeCmp => LtEq(fromViper(lte.left), fromViper(lte.right)) + case gt: vpr.GtCmp => Gt(fromViper(gt.left), fromViper(gt.right)) + case gte: vpr.GeCmp => GtEq(fromViper(gte.left), fromViper(gte.right)) + + case and: vpr.And => And(fromViper(and.left), fromViper(and.right)) + case or: vpr.Or => Or(fromViper(or.left), fromViper(or.right)) + + case add: vpr.Add => Add(fromViper(add.left), fromViper(add.right)) + case sub: vpr.Sub => Sub(fromViper(sub.left), fromViper(sub.right)) + case mul: vpr.Mul => Mul(fromViper(mul.left), fromViper(mul.right)) + case div: vpr.Div => Div(fromViper(div.left), fromViper(div.right)) + + case minus: vpr.Minus => Neg(fromViper(minus.exp)) case not: vpr.Not => - expr(not.exp) match { + fromViper(not.exp) match { case Not(f) => f case x => Not(x) } case access: vpr.FieldAccess => { - val root = expr(access.rcv) + val root = fromViper(access.rcv) access.field.name match { case field => { val segments = field.split('.') @@ -378,3 +409,51 @@ object CheckExpression { } } } + +sealed trait Location + +sealed trait AtOp extends Location { val op: IR.Op } +case class Pre(override val op: IR.Op) extends AtOp { + override def toString() = "PRE:" + op.summary +} +case class Post(override val op: IR.Op) extends AtOp { + override def toString() = "POST:" + op.summary +} +case class LoopStart(override val op: IR.Op) extends AtOp { + override def toString() = "START:" + op.summary +} +case class LoopEnd(override val op: IR.Op) extends AtOp { + override def toString() = "END:" + op.summary +} +case object MethodPre extends Location { + override def toString() = "requires" +} +case object MethodPost extends Location { + override def toString() = "ensures" +} + +sealed trait Condition +case class NotCondition(value: Condition) extends Condition +case class AndCondition(values: List[Condition]) extends Condition +case class OrCondition(values: List[Condition]) extends Condition +case class ImmediateCondition(value: CheckExpression) extends Condition +case class TrackedCondition( + location: Location, + value: CheckExpression +) extends Condition + +case class RuntimeCheck( + location: Location, + check: Check, + when: Option[Condition] +) + +object RuntimeCheck { + def dump(checks: Seq[RuntimeCheck]) = { + System.out.println( + checks + .map(c => c.location.toString() + "\n" + c.check.toString()) + .mkString("\n\n") + ) + } +} diff --git a/src/main/scala/gvc/weaver/Collector.scala b/src/main/scala/gvc/weaver/Collector.scala index c3043ec1..195b66ec 100644 --- a/src/main/scala/gvc/weaver/Collector.scala +++ b/src/main/scala/gvc/weaver/Collector.scala @@ -1,76 +1,25 @@ package gvc.weaver -import gvc.transformer.IR.Predicate - import scala.collection.mutable import gvc.transformer.IR -import viper.silver.ast.MethodCall import viper.silver.{ast => vpr} -import viper.silicon.state.CheckPosition -import viper.silicon.state.LoopPosition -import viper.silicon.state.BranchCond object Collector { - sealed trait Location - sealed trait AtOp extends Location { val op: IR.Op } - case class Pre(override val op: IR.Op) extends AtOp - case class Post(override val op: IR.Op) extends AtOp - case class LoopStart(override val op: IR.Op) extends AtOp - case class LoopEnd(override val op: IR.Op) extends AtOp - case object MethodPre extends Location - case object MethodPost extends Location - - sealed trait Condition - case class NotCondition(value: Condition) extends Condition - case class AndCondition(values: List[Condition]) extends Condition - case class OrCondition(values: List[Condition]) extends Condition - case class ImmediateCondition(value: CheckExpression) extends Condition - case class TrackedCondition( - location: Location, - value: CheckExpression - ) extends Condition - - case class CheckInfo( - check: Check, - when: Option[Condition] - ) - case class RuntimeCheck( - location: Location, - check: Check, - when: Option[Condition] - ) - - sealed trait CallStyle - case object PreciseCallStyle extends CallStyle - case object PrecisePreCallStyle extends CallStyle - case object ImpreciseCallStyle extends CallStyle - case object MainCallStyle extends CallStyle - - class CollectedMethod( - val method: IR.Method, - val conditions: List[TrackedCondition], - val checks: List[RuntimeCheck], - val returns: List[IR.Return], - val hasImplicitReturn: Boolean, - val calls: List[CollectedInvocation], - val allocations: List[IR.Op], - val callStyle: CallStyle, - val bodyContainsImprecision: Boolean, - val checkedSpecificationLocations: Set[Location] - ) + class CollectedChecks(val method: IR.Method) { + val conditions = mutable.LinkedHashSet[TrackedCondition]() + val checks = mutable.ListBuffer[RuntimeCheck]() + } class CollectedProgram( val program: IR.Program, - val methods: Map[String, CollectedMethod] + val methods: Map[String, CollectedChecks] ) - case class CollectedInvocation(ir: IR.Invoke, vpr: MethodCall) - def collect( irProgram: IR.Program, vprProgram: vpr.Program ): CollectedProgram = { - val checks = collectChecks(vprProgram) + val checks = ViperChecks.collect(vprProgram) val methods = irProgram.methods .map( @@ -93,160 +42,6 @@ object Collector { ) } - private class ConditionTerm(val id: Int) { - val conditions = mutable.Set[Logic.Conjunction]() - } - - private sealed trait ViperLocation - private object ViperLocation { - case object Value extends ViperLocation - case object PreInvoke extends ViperLocation - case object PostInvoke extends ViperLocation - case object PreLoop extends ViperLocation - case object PostLoop extends ViperLocation - case object Fold extends ViperLocation - case object Unfold extends ViperLocation - case object InvariantLoopStart extends ViperLocation - case object InvariantLoopEnd extends ViperLocation - - def loop(loopPosition: LoopPosition): ViperLocation = loopPosition match { - case LoopPosition.After => ViperLocation.PostLoop - case LoopPosition.Before => ViperLocation.PreLoop - case LoopPosition.Beginning => ViperLocation.InvariantLoopStart - case LoopPosition.End => ViperLocation.InvariantLoopEnd - } - - def forIR(irLocation: Location, vprLocation: ViperLocation): Location = - irLocation match { - case at: AtOp => - vprLocation match { - case ViperLocation.PreInvoke | ViperLocation.PreLoop | - ViperLocation.Fold | ViperLocation.Unfold | - ViperLocation.Value => - Pre(at.op) - case ViperLocation.PostInvoke | ViperLocation.PostLoop => - Post(at.op) - case ViperLocation.InvariantLoopStart => LoopStart(at.op) - case ViperLocation.InvariantLoopEnd => LoopEnd(at.op) - } - case _ => { - if (vprLocation != ViperLocation.Value) - throw new WeaverException("Invalid location") - irLocation - } - } - } - - private case class ViperBranch( - at: vpr.Node, - location: ViperLocation, - condition: vpr.Exp - ) - - private object ViperBranch { - def apply( - branch: BranchCond, - program: vpr.Program - ) = branch match { - case BranchCond( - condition, - position, - Some(CheckPosition.GenericNode(invoke: vpr.MethodCall)) - ) => { - // This must be a method pre-condition or post-condition - val callee = program.findMethod(invoke.methodName) - - val location: ViperLocation = - if (isContained(position, callee.posts)) ViperLocation.PostInvoke - else if (isContained(position, callee.pres)) ViperLocation.PreInvoke - else ViperLocation.Value - new ViperBranch(invoke, location, condition) - } - - case BranchCond( - condition, - position, - Some(CheckPosition.GenericNode(unfold: vpr.Unfold)) - ) => - new ViperBranch(unfold, ViperLocation.Fold, condition) - case BranchCond( - condition, - position, - Some(CheckPosition.GenericNode(unfold: vpr.Fold)) - ) => - new ViperBranch(unfold, ViperLocation.Unfold, condition) - - case BranchCond( - condition, - _, - Some(CheckPosition.Loop(inv, position)) - ) => { - // This must be an invariant - if (inv.isEmpty || !inv.tail.isEmpty) - throw new WeaverException("Invalid loop invariant") - - new ViperBranch(inv.head, ViperLocation.loop(position), condition) - } - - case BranchCond(condition, position, None) => { - new ViperBranch(position, ViperLocation.Value, condition) - } - - case _ => throw new WeaverException("Invalid branch condition") - } - } - - private case class ViperCheck( - check: vpr.Exp, - conditions: List[ViperBranch], - location: ViperLocation, - context: vpr.Exp - ) - - private type ViperCheckMap = - mutable.HashMap[Int, mutable.ListBuffer[ViperCheck]] - - // Convert the verifier's check map into a ViperCheckMap - private def collectChecks(vprProgram: vpr.Program): ViperCheckMap = { - val vprChecks = viper.silicon.state.runtimeChecks.getChecks - val collected = new ViperCheckMap() - - for ((pos, checks) <- vprChecks) { - val (node, location) = pos match { - case CheckPosition.GenericNode(node) => (node, ViperLocation.Value) - case CheckPosition.Loop(invariants, position) => { - if (invariants.tail.nonEmpty) - throw new WeaverException("Invalid loop invariant") - (invariants.head, ViperLocation.loop(position)) - } - } - - val list = - collected.getOrElseUpdate(node.uniqueIdentifier, mutable.ListBuffer()) - for (c <- checks) { - val conditions = c.branchInfo.map(ViperBranch(_, vprProgram)).toList - list += ViperCheck(c.checks, conditions, location, c.context) - } - } - - collected - } - - private def isContained(node: vpr.Node, container: vpr.Node): Boolean = { - container.visit { - case n => { - if (n.uniqueIdentifier == node.uniqueIdentifier) { - return true - } - } - } - - false - } - - private def isContained(node: vpr.Node, containers: Seq[vpr.Node]): Boolean = - containers.exists(isContained(node, _)) - private def unwrap(expr: CheckExpression, value: Boolean = true): (CheckExpression, Boolean) = { expr match { @@ -260,36 +55,15 @@ object Collector { vprProgram: vpr.Program, irMethod: IR.Method, vprMethod: vpr.Method, - vprChecks: ViperCheckMap - ): CollectedMethod = { + vprChecks: ViperChecks.CheckMap + ): CollectedChecks = { // A mapping of Viper node IDs to the corresponding IR op. // This is used for locating the correct insertion of conditionals. val locations = mutable.Map[Int, Location]() - // A list of `return` statements in the IR method, used for inserting any runtime checks that - // the postcondition may require. - val exits = mutable.ListBuffer[IR.Return]() - // A list of invocations and allocations, used for inserting permission tracking - val invokes = mutable.ListBuffer[CollectedInvocation]() - val allocations = mutable.ListBuffer[IR.Op]() - - // The collection of conditions that are used in runtime checks - val trackedConditions = mutable.LinkedHashSet[TrackedCondition]() - - // The collection of runtime checks that are required, mapping a runtime check to the list of - // conjuncts where one conjunct must be true in order for the runtime check to occur. - // Note: Uses a List as a Map so that the order is preserved in the way that the verifier - // determines (this is important for acc checks of a nested field, for example). - val checks = - mutable.Map[Location, - mutable.ListBuffer[ - (Check, mutable.ListBuffer[Option[Condition]]) - ]]() - - // A set of all locations that need the full specification walked to verify separation. Used - // to implement the semantics of the separating conjunction. Pre-calculates a set so that the - // same location is not checked twice. - val needsFullPermissionChecking = mutable.Set[Location]() + val collected = new CollectedChecks(irMethod) + + val separationLocations = mutable.HashSet[Location]() // Indexing adds the node to the mapping of Viper locations to IR locations def index(node: vpr.Node, location: Location): Unit = @@ -308,16 +82,30 @@ object Collector { methodCall: Option[vpr.Method], loopInvs: List[vpr.Exp] ): Unit = { - for (vprCheck <- vprChecks.get(node.uniqueIdentifier).toSeq.flatten) { - val (checkLocation, inSpec) = loc match { + vprChecks.get(node.uniqueIdentifier) match { + case None => () + case Some(checks) => + for (c <- checks) + collected.checks += convertCheck(c, node, loc, methodCall, loopInvs) + } + } + + def convertCheck( + vprCheck: ViperCheck, + node: vpr.Node, + loc: Location, + methodCall: Option[vpr.Method], + loopInvs: List[vpr.Exp] + ): RuntimeCheck = { + val (checkLocation, inSpec) = loc match { case at: AtOp => vprCheck.location match { case ViperLocation.Value => methodCall match { case Some(method) => - if (isContained(vprCheck.context, method.posts)) + if (ViperChecks.isContained(vprCheck.context, method.posts)) (Post(at.op), true) - else if (isContained(vprCheck.context, method.pres)) + else if (ViperChecks.isContained(vprCheck.context, method.pres)) (Pre(at.op), true) else (Pre(at.op), false) @@ -340,35 +128,22 @@ object Collector { } } - val condition = - branchCondition(checkLocation, vprCheck.conditions, loopInvs) - - // TODO: Split apart ANDed checks? - val check = Check.fromViper(vprCheck.check, irProgram, irMethod) - - val locationChecks = - checks.getOrElseUpdate(checkLocation, mutable.ListBuffer()) - val conditions = locationChecks.find { - case (c, _) => - c == check - } match { - case Some((_, conditions)) => conditions - case None => - val conditions = mutable.ListBuffer[Option[Condition]]() - locationChecks += (check -> conditions) - conditions + val check = Check.fromViper(vprCheck.check) + if (inSpec) { + check match { + case _: AccessibilityCheck => separationLocations += loc + case _ => () + } } - conditions += condition - - if (check.isInstanceOf[ - AccessibilityCheck - ] && inSpec) { - needsFullPermissionChecking += checkLocation - } - } + RuntimeCheck( + loc, + check, + branchCondition(checkLocation, vprCheck.conditions, loopInvs) + ) } + // Recursively collects all runtime checks def checkAll( node: vpr.Node, @@ -420,27 +195,22 @@ object Collector { irBlock: IR.Block, vprBlock: vpr.Seqn, loopInvs: List[vpr.Exp] - ): Boolean = { - var containsImprecision = false + ): Unit = { var vprOps = vprBlock.ss.toList for (irOp <- irBlock) { vprOps = (irOp, vprOps) match { case (irIf: IR.If, (vprIf: vpr.If) :: vprRest) => { visit(irIf, vprIf, loopInvs) - containsImprecision = visitBlock(irIf.ifTrue, vprIf.thn, loopInvs) || containsImprecision - containsImprecision = visitBlock(irIf.ifFalse, vprIf.els, loopInvs) || containsImprecision + visitBlock(irIf.ifTrue, vprIf.thn, loopInvs) + visitBlock(irIf.ifFalse, vprIf.els, loopInvs) vprRest } case (irWhile: IR.While, (vprWhile: vpr.While) :: vprRest) => { visit(irWhile, vprWhile, loopInvs) // Supports only a single invariant - containsImprecision = containsImprecision || isImprecise( - Some(irWhile.invariant)) val newInvs = vprWhile.invs.headOption.map(_ :: loopInvs).getOrElse(loopInvs) - containsImprecision = visitBlock(irWhile.body, - vprWhile.body, - newInvs) || containsImprecision + visitBlock(irWhile.body, vprWhile.body, newInvs) // Check invariants after loop body since they may depend on conditions from body vprWhile.invs.foreach { i => @@ -450,17 +220,14 @@ object Collector { vprRest } case (irInvoke: IR.Invoke, (vprInvoke: vpr.MethodCall) :: vprRest) => { - invokes += CollectedInvocation(irInvoke, vprInvoke) visit(irInvoke, vprInvoke, loopInvs) vprRest } case (irAlloc: IR.AllocValue, (vprAlloc: vpr.NewStmt) :: vprRest) => { - allocations += irAlloc visit(irAlloc, vprAlloc, loopInvs) vprRest } case (irAlloc: IR.AllocStruct, (vprAlloc: vpr.NewStmt) :: vprRest) => { - allocations += irAlloc visit(irAlloc, vprAlloc, loopInvs) vprRest } @@ -487,16 +254,10 @@ object Collector { vprRest } case (irFold: IR.Fold, (vprFold: vpr.Fold) :: vprRest) => { - containsImprecision = containsImprecision || isImprecise( - Some(irFold.instance.predicate.expression), - mutable.Set(irFold.instance.predicate)) visit(irFold, vprFold, loopInvs) vprRest } case (irUnfold: IR.Unfold, (vprUnfold: vpr.Unfold) :: vprRest) => { - containsImprecision = containsImprecision || isImprecise( - Some(irUnfold.instance.predicate.expression), - mutable.Set(irUnfold.instance.predicate)) visit(irUnfold, vprUnfold, loopInvs) vprRest } @@ -505,13 +266,11 @@ object Collector { vprRest } case (irReturn: IR.Return, vprRest) if irReturn.value.isEmpty => { - exits += irReturn vprRest } case (irReturn: IR.Return, (vprReturn: vpr.LocalVarAssign) :: vprRest) if irReturn.value.isDefined => { visit(irReturn, vprReturn, loopInvs) - exits += irReturn vprRest } @@ -530,7 +289,6 @@ object Collector { s"Unexpected Silver statement: ${vprOps.head}" ) } - containsImprecision } def normalizeLocation(loc: Location): Location = loc match { @@ -569,7 +327,7 @@ object Collector { // Special case for when the verifier uses positions tagged as the beginning of the loop // outside of the loop body. In this case, just use the after loop position. case ViperLocation.InvariantLoopStart - if !isContained(b.at, loopInvs) => + if !ViperChecks.isContained(b.at, loopInvs) => ViperLocation.PostLoop case p => p } @@ -577,15 +335,15 @@ object Collector { val conditionLocation = normalizeLocation(ViperLocation.forIR(irLoc, position)) val (expr, flag) = - unwrap(CheckExpression.fromViper(b.condition, irMethod)) + unwrap(CheckExpression.fromViper(b.condition)) val unwrappedCondition: Condition = if (conditionLocation == normalizeLocation(location)) { ImmediateCondition(expr) } else { - val tracked = TrackedCondition(conditionLocation, expr.guarded) - trackedConditions += tracked - tracked + val cond = TrackedCondition(conditionLocation, expr.guarded) + collected.conditions += cond + cond } val cond = @@ -604,345 +362,16 @@ object Collector { vprMethod.pres.foreach(checkAll(_, MethodPre, None, Nil)) // Loop through each operation and collect checks - val bodyContainsImprecision = - visitBlock(irMethod.body, vprMethod.body.get, Nil) + visitBlock(irMethod.body, vprMethod.body.get, Nil) // Index post-conditions and add required runtime checks vprMethod.posts.foreach(indexAll(_, MethodPost)) vprMethod.posts.foreach(checkAll(_, MethodPost, None, Nil)) - // Check if execution can fall-through to the end of the method - // It is valid to only check the last operation since we don't allow early returns - val implicitReturn: Boolean = hasImplicitReturn(irMethod) - - // Get all checks (grouped by their location) and simplify their conditions - val collectedChecks = mutable.ListBuffer[RuntimeCheck]() - for ((loc, locChecks) <- checks) - for ((check, conditions) <- locChecks) { - val condition = - if (conditions.isEmpty || conditions.contains(None)) { - None - } else if (conditions.size == 1) { - conditions.head - } else { - Some(OrCondition(conditions.map(_.get).toList)) - } - - collectedChecks += RuntimeCheck(loc, check, condition) - } - - // Traverse the specifications for statements that require full permission checks - for (location <- needsFullPermissionChecking) { - val (spec, arguments) = location match { - case at: AtOp => - at.op match { - case op: IR.Invoke => - op.callee match { - case callee: IR.Method if callee.precondition.isDefined => - ( - callee.precondition.get, - Some( - op.callee.parameters - .zip(op.arguments.map(resolveValue(_))) - .toMap - ) - ) - case _ => - throw new WeaverException( - s"Could not locate specification at invoke: $location") - } - // TODO: Do we need unfold? - case op: IR.Fold => - ( - op.instance.predicate.expression, - Some( - op.instance.predicate.parameters - .zip(op.instance.arguments.map(resolveValue(_))) - .toMap - ) - ) - case op: IR.While => (op.invariant, None) - case op: IR.Assert => (op.value, None) - case _ => - throw new WeaverException( - "Could not locate specification for permission checking: " + location - .toString() - ) - } - case MethodPost if irMethod.postcondition.isDefined => - (irMethod.postcondition.get, None) - case _ => - throw new WeaverException( - "Could not locate specification for permission checking: " + location - .toString() - ) - } - - val separationChecks = - traversePermissions(spec, arguments, None, Separation).map(info => - RuntimeCheck(location, info.check, info.when)) - - // Since the checks are for separation, only include them if there is more than one - // otherwise, there can be no overlap - val needsSeparationCheck = - separationChecks.length > 1 || - separationChecks.length == 1 && !separationChecks.head.check - .isInstanceOf[FieldSeparationCheck] - if (needsSeparationCheck) { - collectedChecks ++= separationChecks - } - } - - // Wrap up all the results - new CollectedMethod( - method = irMethod, - conditions = trackedConditions.toList, - checks = collectedChecks.toList, - returns = exits.toList, - hasImplicitReturn = implicitReturn, - calls = invokes.toList, - allocations = allocations.toList, - callStyle = getCallstyle(irMethod), - bodyContainsImprecision = bodyContainsImprecision, - checkedSpecificationLocations = needsFullPermissionChecking.toSet - ) - } - // TODO: Factor this out - def traversePermissions( - spec: IR.Expression, - arguments: Option[Map[IR.Parameter, CheckExpression]], - condition: Option[CheckExpression], - checkType: CheckType - ): Seq[CheckInfo] = spec match { - // Imprecise expressions just needs the precise part checked. - // TODO: This should also enable framing checks. - case imp: IR.Imprecise => { - imp.precise.toSeq.flatMap( - traversePermissions(_, arguments, condition, checkType) - ) - } - - // And expressions just traverses both parts - case and: IR.Binary if and.operator == IR.BinaryOp.And => { - val left = traversePermissions(and.left, arguments, condition, checkType) - val right = - traversePermissions(and.right, arguments, condition, checkType) - left ++ right - } + // Add separation checks + SeparationChecks.inject(separationLocations, irMethod, collected.checks) - // A condition expression traverses each side with its respective condition, joined with the - // existing condition if provided to support nested conditionals. - case cond: IR.Conditional => { - val baseCond = resolveValue(cond.condition, arguments) - val negCond = CheckExpression.Not(baseCond) - val (trueCond, falseCond) = condition match { - case None => (baseCond, negCond) - case Some(otherCond) => - ( - CheckExpression.And(otherCond, baseCond), - CheckExpression.And(otherCond, negCond) - ) - } - - val truePerms = - traversePermissions(cond.ifTrue, arguments, Some(trueCond), checkType) - val falsePerms = traversePermissions( - cond.ifFalse, - arguments, - Some(falseCond), - checkType - ) - truePerms ++ falsePerms - } - - // A single accessibility check - case acc: IR.Accessibility => { - val field = resolveValue(acc.member, arguments) match { - case f: CheckExpression.Field => f - case invalid => - throw new WeaverException(s"Invalid acc() argument: '$invalid'") - } - - checkType match { - case Separation => - Seq( - CheckInfo( - FieldSeparationCheck(field), - condition.map(ImmediateCondition) - ) - ) - case Verification => - Seq( - CheckInfo( - FieldAccessibilityCheck(field), - condition.map(ImmediateCondition) - ) - ) - } - - } - case pred: IR.PredicateInstance => { - checkType match { - case Separation => - Seq( - CheckInfo( - PredicateSeparationCheck( - pred.predicate.name, - pred.arguments.map(resolveValue(_, arguments)) - ), - condition.map(ImmediateCondition) - ) - ) - case Verification => - Seq( - CheckInfo( - PredicateAccessibilityCheck( - pred.predicate.name, - pred.arguments.map(resolveValue(_, arguments)) - ), - condition.map(ImmediateCondition) - ) - ) - } - - } - case _ => { - // Otherwise there can be no permission specifiers in this term or its children - Seq.empty - } - } - - def hasImplicitReturn(method: IR.Method): Boolean = - method.body.lastOption match { - case None => true - case Some(tailOp) => hasImplicitReturn(tailOp) - } - - // Checks if execution can fall-through a given Op - def hasImplicitReturn(tailOp: IR.Op): Boolean = tailOp match { - case r: IR.Return => false - case _: IR.While => true - case iff: IR.If => - (iff.ifTrue.lastOption, iff.ifFalse.lastOption) match { - case (Some(t), Some(f)) => hasImplicitReturn(t) || hasImplicitReturn(f) - case _ => true - } - case _ => true - } - - def isImprecise( - cond: Option[IR.Expression], - visited: mutable.Set[Predicate] = mutable.Set.empty[Predicate]): Boolean = - cond match { - case Some(expr: IR.Expression) => - expr match { - case instance: IR.PredicateInstance => - if (visited.contains(instance.predicate)) { - false - } else { - visited += instance.predicate - isImprecise(Some(instance.predicate.expression), visited) - } - case _: IR.Imprecise => true - case conditional: IR.Conditional => - isImprecise(Some(conditional.condition), visited) || isImprecise( - Some(conditional.ifTrue), - visited) || isImprecise(Some(conditional.ifFalse), visited) - case binary: IR.Binary => - isImprecise(Some(binary.left), visited) || isImprecise( - Some(binary.right), - visited) - case unary: IR.Unary => isImprecise(Some(unary.operand), visited) - case _ => false - } - case None => false - } - - def getCallstyle(irMethod: IR.Method) = - if (irMethod.name == "main") - MainCallStyle - else if (isImprecise(irMethod.precondition)) - ImpreciseCallStyle - else if (isImprecise(irMethod.postcondition)) - PrecisePreCallStyle - else PreciseCallStyle - - // Changes an expression from an IR expression into a CheckExpression. If an argument lookup - // mapping is given, it will use this mapping to resolve variables. Otherwise, it will assume - // any variables are accessible in the current scope. - def resolveValue( - input: IR.Expression, - arguments: Option[Map[IR.Parameter, CheckExpression]] = None - ): CheckExpression = { - def resolve(input: IR.Expression) = resolveValue(input, arguments) - - input match { - // These types can only be used at the "root" of a specification, not in an arbitrary - // expression - case _: IR.ArrayMember | _: IR.Imprecise | _: IR.PredicateInstance | - _: IR.Accessibility => - throw new WeaverException("Invalid specification value") - - case n: IR.Var => - arguments match { - case None => CheckExpression.Var(n.name) - case Some(arguments) => - n match { - case p: IR.Parameter => - arguments.getOrElse( - p, - throw new WeaverException(s"Unknown parameter '${p.name}'") - ) - case v => - throw new WeaverException(s"Unknown variable '${v.name}'") - } - } - - case n: IR.FieldMember => - CheckExpression.Field( - resolve(n.root), - n.field.struct.name, - n.field.name - ) - case n: IR.DereferenceMember => CheckExpression.Deref(resolve(n.root)) - case n: IR.Result => CheckExpression.Result - case n: IR.IntLit => CheckExpression.IntLit(n.value) - case n: IR.CharLit => CheckExpression.CharLit(n.value) - case n: IR.BoolLit => CheckExpression.BoolLit(n.value) - case n: IR.StringLit => CheckExpression.StrLit(n.value) - case n: IR.NullLit => CheckExpression.NullLit - case n: IR.Conditional => - CheckExpression.Cond( - resolve(n.condition), - resolve(n.ifTrue), - resolve(n.ifFalse) - ) - case n: IR.Binary => { - val l = resolve(n.left) - val r = resolve(n.right) - n.operator match { - case IR.BinaryOp.Add => CheckExpression.Add(l, r) - case IR.BinaryOp.Subtract => CheckExpression.Sub(l, r) - case IR.BinaryOp.Divide => CheckExpression.Div(l, r) - case IR.BinaryOp.Multiply => CheckExpression.Mul(l, r) - case IR.BinaryOp.And => CheckExpression.And(l, r) - case IR.BinaryOp.Or => CheckExpression.Or(l, r) - case IR.BinaryOp.Equal => CheckExpression.Eq(l, r) - case IR.BinaryOp.NotEqual => - CheckExpression.Not(CheckExpression.Eq(l, r)) - case IR.BinaryOp.Less => CheckExpression.Lt(l, r) - case IR.BinaryOp.LessOrEqual => CheckExpression.LtEq(l, r) - case IR.BinaryOp.Greater => CheckExpression.Gt(l, r) - case IR.BinaryOp.GreaterOrEqual => CheckExpression.GtEq(l, r) - } - } - case n: IR.Unary => { - val o = resolve(n.operand) - n.operator match { - case IR.UnaryOp.Not => CheckExpression.Not(o) - case IR.UnaryOp.Negate => CheckExpression.Neg(o) - } - } - } + collected } + } diff --git a/src/main/scala/gvc/weaver/Dependencies.scala b/src/main/scala/gvc/weaver/Dependencies.scala new file mode 100644 index 00000000..90147c54 --- /dev/null +++ b/src/main/scala/gvc/weaver/Dependencies.scala @@ -0,0 +1,275 @@ +package gvc.weaver + +import scala.collection.mutable +import gvc.transformer.IR + +// This helper class handles calculation of the permission requirements for a +// method; namely, whether a method requires permissions to be passed to it, and +// whether a method may produce permissions. This enables optimization in cases +// when calling an imprecise method that does not have acc runtime checks. + +sealed trait ScopeDependencies { + // The syntatic block that introduces this scope + def block: IR.Block + + // The runtime checks that may occur within this scope + def checks: Seq[RuntimeCheck] + + // The allocations that may occur within this scope + def allocations: Seq[IR.AllocStruct] + + // The calls that may occur within this scope + def calls: Seq[IR.Invoke] + + // Indicates whether this scope may require a dynamic permission set + def requiresPerms: Boolean + + // Indicates whether this scope may modify a dynamic permission set + def modifiesPerms: Boolean + + // Indicates whether this scope inherits the caller's dynamic permission + // set. For example, a method with imprecise pre-condition. + def inheritsPerms: Boolean + + // Indicates whether this scope returns a dynamic permission set. For + // example, a method with imprecise post-condition. We assume that + // `inheritsPerms` implies `returnsPerms`. + def returnsPerms: Boolean + + def children: Seq[WhileDependencies] +} + +sealed trait MethodDependencies extends ScopeDependencies { + def method: IR.Method + def precisePre: Boolean + def precisePost: Boolean + def conditions: Iterable[TrackedCondition] + + def inheritsPerms: Boolean = !precisePre + def returnsPerms: Boolean = !precisePost || !precisePre +} + +sealed trait WhileDependencies extends ScopeDependencies { + def op: IR.While + def preciseInvariant: Boolean + + def inheritsPerms: Boolean = !preciseInvariant + def returnsPerms: Boolean = !preciseInvariant +} + +sealed trait ProgramDependencies { + def program: IR.Program + def methods: Map[String, MethodDependencies] +} + +object Dependencies { + private sealed abstract class DependencyScope extends ScopeDependencies { + val calls = mutable.ListBuffer[IR.Invoke]() + val allocations = mutable.ListBuffer[IR.AllocStruct]() + val permDependencies = mutable.HashSet[String]() + val children = mutable.ListBuffer[WhileDependenciesImpl]() + var requiresPerms = false + var modifiesPerms = false + } + + private class MethodDependenciesImpl( + val method: IR.Method, + val conditions: Iterable[TrackedCondition], + val checks: Seq[RuntimeCheck], + val precisePre: Boolean, + val precisePost: Boolean + ) extends DependencyScope with MethodDependencies { + + def block: IR.Block = method.body + } + + private class WhileDependenciesImpl( + val op: IR.While, + val checks: Seq[RuntimeCheck], + val preciseInvariant: Boolean) + extends DependencyScope with WhileDependencies { + + def block: IR.Block = op.body + } + + private class ProgramDependenciesImpl( + val program: IR.Program, + val methods: Map[String, MethodDependenciesImpl] + ) extends ProgramDependencies + + def calculate(scope: ProgramScope): ProgramDependencies = { + val program = scope.program + val precision = new EquirecursivePrecision(program) + + val graph = scope.methods.map({ case (k, v) => (k, initDependencies(v, precision)) }) + val deps = new ProgramDependenciesImpl(program, graph) + + val collect = mutable.HashSet[String]() + graph.values.foreach(c => { + setRequiresPerms(c, deps, collect) + setModifiesPerms(c, deps, collect) + }) + + deps + } + + private def initDependencies( + scope: MethodScope, + precision: EquirecursivePrecision + ): MethodDependenciesImpl = { + val method = scope.method + val dep = new MethodDependenciesImpl( + method, + scope.conditions, + scope.checks, + precision.isPrecise(method.precondition), + precision.isPrecise(method.postcondition) + ) + + traverseBlock(method.body, dep) + dep.children ++= scope.children.map(initDependencies(_, precision)) + + dep.requiresPerms = requiresPerms(scope.checks) + dep.modifiesPerms = + if (dep.returnsPerms) !dep.allocations.isEmpty + else refsPerms(method.precondition) || refsPerms(method.postcondition) + + dep + } + + private def initDependencies( + scope: WhileScope, + precision: EquirecursivePrecision + ): WhileDependenciesImpl = { + val op = scope.op + val dep = new WhileDependenciesImpl( + op, + scope.checks, + precision.isPrecise(op.invariant) + ) + + traverseBlock(op.body, dep) + dep.children ++= scope.children.map(initDependencies(_, precision)) + + dep.requiresPerms = requiresPerms(scope.checks) + dep.modifiesPerms = + if (dep.preciseInvariant) refsPerms(dep.op.invariant) + else !dep.allocations.isEmpty + + dep + } + + private def traverseBlock(block: IR.Block, dep: DependencyScope) { + // Note that we do not traverse into `while` statements, because we assume + // that each `while` will have its own Scope instance + block.foreach(_ match { + case _: IR.AllocArray | _: IR.AllocValue => + throw new WeaverException("Unsupported allocation") + case alloc: IR.AllocStruct => + dep.allocations += alloc + case call: IR.Invoke => + dep.calls += call + case cond: IR.If => { + traverseBlock(cond.ifTrue, dep) + traverseBlock(cond.ifFalse, dep) + } + case _ => + () + }) + } + + private def setRequiresPerms( + scope: DependencyScope, + program: ProgramDependenciesImpl, + collect: mutable.HashSet[String] + ): Unit = { + scope.children.foreach(setRequiresPerms(_, program, collect)) + + if (!scope.requiresPerms) { + scope.requiresPerms = deepRequiresPerms(scope, program, collect) + collect.clear() + } + } + + private def setModifiesPerms( + scope: DependencyScope, + program: ProgramDependenciesImpl, + collect: mutable.HashSet[String] + ): Unit = { + scope.children.foreach(setModifiesPerms(_, program, collect)) + + if (!scope.modifiesPerms) { + scope.modifiesPerms = deepModifiesPerms(scope, program, collect) + collect.clear() + } + } + + // Given a set of methods already explored, checks whether there are any child + // scopes or method calls that require a dynamic set of permissions + private def deepRequiresPerms( + scope: DependencyScope, + program: ProgramDependenciesImpl, + collect: mutable.HashSet[String] + ): Boolean = { + // Check the current scope + scope.requiresPerms || + // Recursively check child scopes that inherit permissions + scope.children.exists(c => + c.inheritsPerms && + deepRequiresPerms(c, program, collect)) || + // Recursively check methods when they have not been visited + scope.calls.exists(c => program.methods.get(c.callee.name) match { + case None => false // Ignore external methods + case Some(m) => + m.inheritsPerms && + collect.add(m.method.name) && + deepRequiresPerms(m, program, collect) + }) + } + + // Given a set of methods already explored, checks whether there are any child + // scopes or method calls that can modify a dynamic set of permissions + private def deepModifiesPerms( + scope: DependencyScope, + program: ProgramDependenciesImpl, + collect: mutable.HashSet[String]): Boolean = { + // Check the current scope + // Assume that if the method is precise (`returnsPerms` is false), then + // `modifiesPerms` has been correctly set, since it can be determined by + // analyzing the specification without recursing + scope.modifiesPerms || (scope.returnsPerms && ( + // Recursively check child scopes that return permissions + scope.children.exists(c => + c.returnsPerms && + deepModifiesPerms(c, program, collect)) || + // Recursively check methods when they have not been visited + scope.calls.exists(c => program.methods.get(c.callee.name) match { + case None => false // Ignore external methods + case Some(m) => + collect.add(m.method.name) && deepModifiesPerms(m, program, collect) + })) + ) + } + + private def requiresPerms(checks: Seq[RuntimeCheck]): Boolean = + checks.exists(_.check match { + case _: AccessibilityCheck => true + case _ => false + }) + + private def refsPerms(spec: IR.Expression): Boolean = { + spec match { + case _: IR.Accessibility | _: IR.PredicateInstance | _: IR.Imprecise => true + case x: IR.Conditional => refsPerms(x.ifTrue) || refsPerms(x.ifFalse) + case x: IR.Binary if x.operator == IR.BinaryOp.And => + refsPerms(x.left) || refsPerms(x.right) + case _ => false + } + } + + private def refsPerms(spec: Option[IR.Expression]): Boolean = + spec match { + case None => false + case Some(e) => refsPerms(e) + } +} diff --git a/src/main/scala/gvc/weaver/EquirecursivePrecision.scala b/src/main/scala/gvc/weaver/EquirecursivePrecision.scala new file mode 100644 index 00000000..3ddd44a1 --- /dev/null +++ b/src/main/scala/gvc/weaver/EquirecursivePrecision.scala @@ -0,0 +1,34 @@ +package gvc.weaver +import gvc.transformer.IR +import scala.collection.mutable + +// Helper class for determining equi-recursive precision +class EquirecursivePrecision(program: IR.Program) { + private val predicatePrecision = mutable.HashMap[String, Boolean]() + + def isPrecise(pred: IR.Predicate): Boolean = { + predicatePrecision.get(pred.name) match { + case Some(v) => v + case None => { + predicatePrecision.update(pred.name, true) + val result = isPrecise(pred.expression) + predicatePrecision.update(pred.name, result) + result + } + } + } + + def isPrecise(e: IR.Expression): Boolean = e match { + case _: IR.Imprecise => false + case b: IR.Binary if b.operator == IR.BinaryOp.And => + isPrecise(b.left) && isPrecise(b.right) + case c: IR.Conditional => isPrecise(c.ifTrue) && isPrecise(c.ifFalse) + case p: IR.PredicateInstance => isPrecise(p.predicate) + case _ => true + } + + def isPrecise(e: Option[IR.Expression]): Boolean = e match { + case None => true + case Some(e) => isPrecise(e) + } +} \ No newline at end of file diff --git a/src/main/scala/gvc/weaver/InstanceCounter.scala b/src/main/scala/gvc/weaver/InstanceCounter.scala new file mode 100644 index 00000000..cfd8f2ff --- /dev/null +++ b/src/main/scala/gvc/weaver/InstanceCounter.scala @@ -0,0 +1,71 @@ +package gvc.weaver + +import gvc.transformer.IR +import scala.collection.mutable + +object InstanceCounter { + private val counterRef = new IR.PointerType(IR.IntType) + private val counterName = "_instanceCounter" + + def inject(methods: Seq[IR.Method], idFields: Map[String, IR.StructField]): Unit = { + val names = mutable.HashSet[String]() + for (m <- methods) { + val name = m.name + if (name != "main") + names += name + } + + methods.foreach(inject(_, names, idFields)) + } + + private def inject( + method: IR.Method, + methods: mutable.HashSet[String], + idFields: Map[String, IR.StructField] + ): Unit = { + val counter = method.name match { + case "main" => { + val counter = method.addVar(counterRef, counterName) + new IR.AllocValue(IR.IntType, counter) +=: method.body + counter + } + case _ => method.addParameter(counterRef, counterName) + } + + inject(method.body, methods, counter, idFields) + } + + private def inject( + block: IR.Block, + methods: mutable.HashSet[String], + counter: IR.Var, + idFields: Map[String, IR.StructField] + ): Unit = { + block.foreach(_ match { + case call: IR.Invoke => call.callee match { + case m: IR.Method if methods.contains(m.name) => + call.arguments :+= counter + case _ => () + } + + case cond: IR.If => { + inject(cond.ifTrue, methods, counter, idFields) + inject(cond.ifFalse, methods, counter, idFields) + } + case loop: IR.While => { + inject(loop.body, methods, counter, idFields) + } + case alloc: IR.AllocStruct => { + idFields.get(alloc.struct.name).foreach(field => { + val deref = new IR.DereferenceMember(counter) + val idField = new IR.FieldMember(alloc.target, field) + alloc.insertAfter(List( + new IR.AssignMember(idField, deref), + new IR.AssignMember(deref, new IR.Binary(IR.BinaryOp.Add, deref, new IR.IntLit(1))) + )) + }) + } + case _ => () + }) + } +} \ No newline at end of file diff --git a/src/main/scala/gvc/weaver/SeparationChecks.scala b/src/main/scala/gvc/weaver/SeparationChecks.scala new file mode 100644 index 00000000..aba4b6c6 --- /dev/null +++ b/src/main/scala/gvc/weaver/SeparationChecks.scala @@ -0,0 +1,136 @@ +package gvc.weaver + +import scala.collection.mutable +import gvc.transformer.IR + +object SeparationChecks { + def canOverlap(spec: IR.Expression): Boolean = countAccs(spec) > 1 + + // Count the number of acc'd heap locations. Optimized so that if at least 2 + // are found, it may stop counting. + private def countAccs(spec: IR.Expression): Int = spec match { + case _: IR.Accessibility => 1 + + case b: IR.Binary => { + val left = countAccs(b.left) + if (left > 1) left else left + countAccs(b.right) + } + + case c: IR.Conditional => { + val left = countAccs(c.ifTrue) + if (left > 1) left else Math.max(left, countAccs(c.ifFalse)) + } + + case i: IR.Imprecise => i.precise match { + case None => 0 + case Some(precise) => countAccs(precise) + } + + // Could optimize so that it explores predicates (would have to implement + // handling for mutually-recursive predicates). + case p: IR.PredicateInstance => 2 + + case _ => 0 + } + + def inject(locations: mutable.HashSet[Location], method: IR.Method, checks: mutable.ListBuffer[RuntimeCheck]): Unit = { + locations.foreach(loc => { + val (spec, context) = loc match { + case at: AtOp => at.op match { + case call: IR.Invoke => { + call.callee match { + case m: IR.Method if m.precondition.isDefined => + (m.precondition.get, new CallSiteContext(call)) + case _ => + throw new WeaverException("Invalid method definition") + } + } + + case f: IR.Fold => { + val p = f.instance.predicate + val params: Seq[IR.Var] = p.parameters + val args = params.zip(f.instance.arguments).toMap + (p.expression, new PredicateContext(p, args)) + } + + case w: IR.While => + (w.invariant, ValueContext) + + case a: IR.Assert => + (a.value, ValueContext) + + case op => + throw new WeaverException(s"Cannot check separation at $op") + } + + case MethodPost if method.postcondition.isDefined => { + (method.postcondition.get, IdentityContext) + } + } + + if (canOverlap(spec)) + inject(spec, loc, None, context, checks) + }) + } + + private def inject( + spec: IR.Expression, + loc: Location, + cond: Option[CheckExpression], + context: SpecificationContext, + checks: mutable.ListBuffer[RuntimeCheck] + ): Unit = { + spec match { + case acc: IR.Accessibility => + CheckExpression.irValue(context.convert(acc.member)) match { + case f: CheckExpression.Field => { + checks += RuntimeCheck( + loc, + FieldSeparationCheck(f), + cond.map(ImmediateCondition) + ) + } + case _ => throw new WeaverException("Invalid acc value") + } + + case b: IR.Binary if b.operator == IR.BinaryOp.And => { + inject(b.left, loc, cond, context, checks) + inject(b.right, loc, cond, context, checks) + } + + case c: IR.Conditional => { + val t = CheckExpression.irValue(context.convert(c.condition)) + val f = CheckExpression.Not(t) + val (trueCond, falseCond) = cond match { + case None => + (t, f) + case Some(cond) => + (CheckExpression.And(cond, t), CheckExpression.And(cond, f)) + } + + inject(c.ifTrue, loc, Some(trueCond), context, checks) + inject(c.ifFalse, loc, Some(falseCond), context, checks) + } + + case i: IR.Imprecise => + i.precise match { + case None => () + case Some(spec) => inject(spec, loc, cond, context, checks) + } + + case p: IR.PredicateInstance => { + checks += RuntimeCheck( + loc, + PredicateSeparationCheck( + p.predicate.name, + p.arguments.map(arg => + CheckExpression.irValue(context.convert(arg))) + ), + cond.map(ImmediateCondition) + ) + } + + case _ => () + } + } +} diff --git a/src/main/scala/gvc/weaver/SpecificationContext.scala b/src/main/scala/gvc/weaver/SpecificationContext.scala index b3e4a177..73960bb4 100644 --- a/src/main/scala/gvc/weaver/SpecificationContext.scala +++ b/src/main/scala/gvc/weaver/SpecificationContext.scala @@ -1,52 +1,45 @@ package gvc.weaver import gvc.transformer.IR import gvc.transformer.IR +import gvc.transformer.IR.PredicateInstance abstract class SpecificationContext { - def convertVar(source: IR.Var): IR.Expression - def convertResult: IR.Expression - - def convertFieldMember(member: IR.FieldMember): IR.FieldMember = { - new IR.FieldMember( - convertExpression(member.root), - member.field - ) - } - - def convertExpression(expr: IR.Expression): IR.Expression = { + def convert(v: IR.Var): IR.Expression + def convert(r: IR.Result): IR.Expression + + def convert(f: IR.FieldMember): IR.FieldMember = + new IR.FieldMember(convert(f.root), f.field) + def convert(d: IR.DereferenceMember): IR.DereferenceMember = + new IR.DereferenceMember(convert(d.root)) + def convert(l: IR.Literal): IR.Literal = + l + def convert(b: IR.Binary): IR.Binary = + new IR.Binary(b.operator, convert(b.left), convert(b.right)) + def convert(u: IR.Unary): IR.Unary = + new IR.Unary(u.operator, u.operand) + def convert(c: IR.Conditional): IR.Conditional = + new IR.Conditional(convert(c.condition), convert(c.ifTrue), convert(c.ifFalse)) + def convert(a: IR.Accessibility): IR.Accessibility = + throw new WeaverException("Cannot translate acc(...) to new context") + def convert(p: IR.PredicateInstance): IR.PredicateInstance = + throw new WeaverException("Cannot translate " + p.predicate.name + "(...) to new context") + def convert(a: IR.ArrayMember): IR.ArrayMember = + throw new WeaverException("Cannot translate array access to new context") + + def convert(expr: IR.Expression): IR.Expression = { expr match { - case value: IR.Var => convertVar(value) - - case fieldMember: IR.FieldMember => - convertFieldMember(fieldMember) - - case derefMember: IR.DereferenceMember => - new IR.DereferenceMember(convertExpression(derefMember.root)) - - case _: IR.Result => convertResult - - case literal: IR.Literal => literal - - case binary: IR.Binary => - new IR.Binary( - binary.operator, - convertExpression(binary.left), - convertExpression(binary.right) - ) - - case unary: IR.Unary => - new IR.Unary(unary.operator, convertExpression(unary.operand)) - - case cond: IR.Conditional => - new IR.Conditional(convertExpression(cond.condition), - convertExpression(cond.ifTrue), - convertExpression(cond.ifFalse)) - - case _: IR.Accessibility | _: IR.Imprecise | _: IR.ArrayMember | - _: IR.PredicateInstance => - throw new WeaverException( - "Invalid expression; cannot convert to new context." - ) + case v: IR.Var => convert(v) + case f: IR.FieldMember => convert(f) + case d: IR.DereferenceMember => convert(d) + case r: IR.Result => convert(r) + case l: IR.Literal => convert(l) + case b: IR.Binary => convert(b) + case u: IR.Unary => convert(u) + case c: IR.Conditional => convert(c) + case a: IR.Accessibility => convert(a) + case i: IR.Imprecise => convert(i) + case p: PredicateInstance => convert(p) + case a: IR.ArrayMember => convert(a) } } } @@ -54,18 +47,28 @@ abstract class SpecificationContext { // A context implementation that only validates that invalid expressions // like \result are not used incorrectly object ValueContext extends SpecificationContext { - def convertResult: IR.Expression = + def convert(source: IR.Result): IR.Expression = throw new WeaverException("Invalid result expression") - def convertVar(source: IR.Var): IR.Expression = source + def convert(source: IR.Var): IR.Expression = source } -class PredicateContext(pred: IR.Predicate, params: Map[IR.Var, IR.Var]) +object IdentityContext extends SpecificationContext { + def convert(source: IR.Result): IR.Expression = source + def convert(source: IR.Var): IR.Expression = source + override def convert(expr: IR.Expression): IR.Expression = expr + override def convert(f: IR.FieldMember): IR.FieldMember = f + override def convert(p: IR.PredicateInstance): IR.PredicateInstance = p + override def convert(a: IR.Accessibility): IR.Accessibility = a + override def convert(a: IR.ArrayMember): IR.ArrayMember = a +} + +class PredicateContext(pred: IR.Predicate, params: Map[IR.Var, IR.Expression]) extends SpecificationContext { - def convertResult: IR.Expression = + def convert(source: IR.Result): IR.Expression = throw new WeaverException(s"Invalid \result expression in '${pred.name}'") - def convertVar(source: IR.Var): IR.Expression = + def convert(source: IR.Var): IR.Expression = params.getOrElse( source, throw new WeaverException( @@ -74,10 +77,10 @@ class PredicateContext(pred: IR.Predicate, params: Map[IR.Var, IR.Var]) } class ReturnContext(returnValue: IR.Expression) extends SpecificationContext { - def convertVar(source: IR.Var): IR.Expression = + def convert(source: IR.Var): IR.Expression = source - def convertResult: IR.Expression = + def convert(source: IR.Result): IR.Expression = returnValue } @@ -85,41 +88,31 @@ class ReturnContext(returnValue: IR.Expression) extends SpecificationContext { // the arguments specified at a given call site // If 'NULL' is passed as a parameter, it is replaced with a temporary variable to avoid // generating runtime checks or permission tracking operations with dereferences of the form 'NULL->' -class CallSiteContext(call: IR.Invoke, caller: IR.Method) +class CallSiteContext(call: IR.Invoke) extends SpecificationContext { val variableMapping: Map[IR.Var, IR.Expression] = (call.callee.parameters zip call.arguments) - .map(pair => { - pair._2 match { - case _: IR.NullLit => - val validValueType = pair._1.valueType match { + .map(_ match { + case (param, _: IR.NullLit) => { + val validValueType = param.valueType match { case Some(value) => value case None => throw new WeaverException( - s"Couldn't resolve parameter value type for parameter ${pair._1.name} of method ${call.callee.name}") + s"Couldn't resolve parameter value type for parameter ${param.name} of method ${call.callee.name}") } - (pair._1, caller.addVar(validValueType)) - case _ => pair + (param, call.method.addVar(validValueType)) } + case pair => pair }) .toMap - def convertVar(source: IR.Var): IR.Expression = + def convert(source: IR.Var): IR.Expression = variableMapping.getOrElse( source, throw new WeaverException( s"Could not find variable '${source.name} at call site of '${call.callee.name}'" )) - def convertResult: IR.Expression = call.target.getOrElse { - call.callee.returnType match { - case Some(returnType) => - val target = caller.addVar(returnType) - call.target = Some(target) - target - case None => - throw new WeaverException( - s"Invalid \result expression for void '${call.callee.name}'") - } - } + def convert(source: IR.Result): IR.Expression = call.target.getOrElse( + throw new WeaverException("Invoke of non-void method must have a target")) } diff --git a/src/main/scala/gvc/weaver/ViperChecks.scala b/src/main/scala/gvc/weaver/ViperChecks.scala new file mode 100644 index 00000000..26e00449 --- /dev/null +++ b/src/main/scala/gvc/weaver/ViperChecks.scala @@ -0,0 +1,160 @@ +package gvc.weaver + +import scala.collection.mutable +import viper.silicon.state.{CheckPosition, LoopPosition, BranchCond} +import viper.silver.{ast => vpr} + +case class ViperCheck( + check: vpr.Exp, + conditions: List[ViperBranch], + location: ViperLocation, + context: vpr.Exp +) + +sealed trait ViperLocation +object ViperLocation { + case object Value extends ViperLocation + case object PreInvoke extends ViperLocation + case object PostInvoke extends ViperLocation + case object PreLoop extends ViperLocation + case object PostLoop extends ViperLocation + case object Fold extends ViperLocation + case object Unfold extends ViperLocation + case object InvariantLoopStart extends ViperLocation + case object InvariantLoopEnd extends ViperLocation + + def loop(loopPosition: LoopPosition): ViperLocation = loopPosition match { + case LoopPosition.After => ViperLocation.PostLoop + case LoopPosition.Before => ViperLocation.PreLoop + case LoopPosition.Beginning => ViperLocation.InvariantLoopStart + case LoopPosition.End => ViperLocation.InvariantLoopEnd + } + + def forIR(irLocation: Location, vprLocation: ViperLocation): Location = + irLocation match { + case at: AtOp => + vprLocation match { + case ViperLocation.PreInvoke | ViperLocation.PreLoop | + ViperLocation.Fold | ViperLocation.Unfold | + ViperLocation.Value => + Pre(at.op) + case ViperLocation.PostInvoke | ViperLocation.PostLoop => + Post(at.op) + case ViperLocation.InvariantLoopStart => LoopStart(at.op) + case ViperLocation.InvariantLoopEnd => LoopEnd(at.op) + } + case _ => { + if (vprLocation != ViperLocation.Value) + throw new WeaverException("Invalid location") + irLocation + } + } +} + +case class ViperBranch( + at: vpr.Node, + location: ViperLocation, + condition: vpr.Exp +) + +object ViperBranch { + def apply( + branch: BranchCond, + program: vpr.Program + ) = branch match { + case BranchCond( + condition, + position, + Some(CheckPosition.GenericNode(invoke: vpr.MethodCall)) + ) => { + // This must be a method pre-condition or post-condition + val callee = program.findMethod(invoke.methodName) + + val location: ViperLocation = + if (ViperChecks.isContained(position, callee.posts)) + ViperLocation.PostInvoke + else if (ViperChecks.isContained(position, callee.pres)) + ViperLocation.PreInvoke + else + ViperLocation.Value + new ViperBranch(invoke, location, condition) + } + + case BranchCond( + condition, + position, + Some(CheckPosition.GenericNode(unfold: vpr.Unfold)) + ) => + new ViperBranch(unfold, ViperLocation.Fold, condition) + case BranchCond( + condition, + position, + Some(CheckPosition.GenericNode(unfold: vpr.Fold)) + ) => + new ViperBranch(unfold, ViperLocation.Unfold, condition) + + case BranchCond( + condition, + _, + Some(CheckPosition.Loop(inv, position)) + ) => { + // This must be an invariant + if (inv.isEmpty || !inv.tail.isEmpty) + throw new WeaverException("Invalid loop invariant") + + new ViperBranch(inv.head, ViperLocation.loop(position), condition) + } + + case BranchCond(condition, position, None) => { + new ViperBranch(position, ViperLocation.Value, condition) + } + + case _ => throw new WeaverException("Invalid branch condition") + } +} + +object ViperChecks { + type CheckMap = + mutable.HashMap[Int, mutable.ListBuffer[ViperCheck]] + + // Convert the verifier's check map into a ViperCheckMap + def collect(vprProgram: vpr.Program): CheckMap = { + val vprChecks = viper.silicon.state.runtimeChecks.getChecks + val collected = new CheckMap() + + for ((pos, checks) <- vprChecks) { + val (node, location) = pos match { + case CheckPosition.GenericNode(node) => (node, ViperLocation.Value) + case CheckPosition.Loop(invariants, position) => { + if (invariants.tail.nonEmpty) + throw new WeaverException("Invalid loop invariant") + (invariants.head, ViperLocation.loop(position)) + } + } + + val list = + collected.getOrElseUpdate(node.uniqueIdentifier, mutable.ListBuffer()) + for (c <- checks) { + val conditions = c.branchInfo.map(ViperBranch(_, vprProgram)).toList + list += ViperCheck(c.checks, conditions, location, c.context) + } + } + + collected + } + + def isContained(node: vpr.Node, container: vpr.Node): Boolean = { + container.visit { + case n => { + if (n.uniqueIdentifier == node.uniqueIdentifier) { + return true + } + } + } + + false + } + + def isContained(node: vpr.Node, containers: Seq[vpr.Node]): Boolean = + containers.exists(isContained(node, _)) +} diff --git a/src/main/scala/gvc/weaver/Weaver.scala b/src/main/scala/gvc/weaver/Weaver.scala index 43943811..2b2218a7 100644 --- a/src/main/scala/gvc/weaver/Weaver.scala +++ b/src/main/scala/gvc/weaver/Weaver.scala @@ -6,6 +6,17 @@ class WeaverException(message: String) extends Exception(message) object Weaver { def weave(ir: IR.Program, silver: vpr.Program): Unit = { - Checker.insert(Collector.collect(ir, silver)) + val collected = Collector.collect(ir, silver) + /* + // Dump collected checks (uncomment for debugging) + for ((k, v) <- collected.methods) { + System.out.println(k + "\n--------------------\n") + RuntimeCheck.dump(v.checks) + System.out.print("\n\n") + } + */ + val scoped = CheckScope.scope(collected) + val deps = Dependencies.calculate(scoped) + Checker.insert(deps) } } diff --git a/src/test/resources/baseline/acc.baseline.c0 b/src/test/resources/baseline/acc.baseline.c0 index eec5d629..93e3fcec 100644 --- a/src/test/resources/baseline/acc.baseline.c0 +++ b/src/test/resources/baseline/acc.baseline.c0 @@ -7,35 +7,33 @@ struct Test int _id; }; -struct Test* createTest(struct OwnedFields* _ownedFields); -int getValue(struct Test* test, struct OwnedFields* _ownedFields); +struct Test* createTest(struct OwnedFields* _ownedFields, int* _instanceCounter); +int getValue(struct Test* test, struct OwnedFields* _ownedFields, int* _instanceCounter); int main(); -struct Test* createTest(struct OwnedFields* _ownedFields) +struct Test* createTest(struct OwnedFields* _ownedFields, int* _instanceCounter) { struct Test* _ = NULL; struct OwnedFields* _tempFields = NULL; - struct OwnedFields* _contextFields = NULL; - _contextFields = initOwnedFields(_ownedFields->instanceCounter); + _tempFields = runtime_init(); assert(true); _ = alloc(struct Test); - _->_id = addStructAcc(_contextFields, 1); - assertAcc(_contextFields, _ != NULL ? _->_id : -1, 0, "Field access runtime check failed for struct Test.value"); - _tempFields = initOwnedFields(_contextFields->instanceCounter); - addAccEnsureSeparate(_tempFields, _ != NULL ? _->_id : -1, 0, 1, "Overlapping field permissions for struct Test.value"); - addAcc(_ownedFields, _->_id, 1, 0); + _->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_tempFields, _->_id, 1); + runtime_assert(_tempFields, _ == NULL ? -1 : _->_id, 0, "No permission to access '_->value'"); + runtime_add(_ownedFields, _ == NULL ? -1 : _->_id, 1, 0, "Invalid aliasing - '_->value' overlaps with existing permission"); return _; } -int getValue(struct Test* test, struct OwnedFields* _ownedFields) +int getValue(struct Test* test, struct OwnedFields* _ownedFields, int* _instanceCounter) { struct OwnedFields* _tempFields = NULL; - struct OwnedFields* _contextFields = NULL; - _contextFields = initOwnedFields(_ownedFields->instanceCounter); - assertAcc(_ownedFields, test != NULL ? test->_id : -1, 0, "Field access runtime check failed for struct Test.value"); - addAcc(_contextFields, test != NULL ? test->_id : -1, 1, 0); - loseAcc(_ownedFields, test != NULL ? test->_id : -1, 0); - assertAcc(_contextFields, test != NULL ? test->_id : -1, 0, "Field access runtime check failed for struct Test.value"); + _tempFields = runtime_init(); + runtime_assert(_ownedFields, test == NULL ? -1 : test->_id, 0, "No permission to access 'test->value'"); + runtime_remove(_ownedFields, test == NULL ? -1 : test->_id, 0, "No permission to access 'test->value'"); + runtime_add(_tempFields, test == NULL ? -1 : test->_id, 1, 0, "Invalid aliasing - 'test->value' overlaps with existing permission"); + runtime_assert(_tempFields, test == NULL ? -1 : test->_id, 0, "No permission to access 'test->value'"); assert(true); return test->value; } @@ -45,13 +43,12 @@ int main() struct Test* _ = NULL; int _1 = 0; struct OwnedFields* _ownedFields = NULL; - struct OwnedFields* _tempFields = NULL; int* _instanceCounter = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); + _ownedFields = runtime_init(); assert(true); - _ = createTest(_ownedFields); - _1 = getValue(_, _ownedFields); + _ = createTest(_ownedFields, _instanceCounter); + _1 = getValue(_, _ownedFields, _instanceCounter); assert(true); return _1; } diff --git a/src/test/resources/baseline/framing.baseline.c0 b/src/test/resources/baseline/framing.baseline.c0 index 6f8deff0..256d7046 100644 --- a/src/test/resources/baseline/framing.baseline.c0 +++ b/src/test/resources/baseline/framing.baseline.c0 @@ -14,58 +14,53 @@ struct Outer int _id; }; -struct Outer* createOuter(struct OwnedFields* _ownedFields); -int getValue(struct Outer* outer, struct OwnedFields* _ownedFields); -int getValueSafe(struct Outer* outer, struct OwnedFields* _ownedFields); -int getValueStatic(struct Outer* outer, struct OwnedFields* _ownedFields); +struct Outer* createOuter(struct OwnedFields* _ownedFields, int* _instanceCounter); +int getValue(struct Outer* outer, struct OwnedFields* _ownedFields, int* _instanceCounter); +int getValueSafe(struct Outer* outer, struct OwnedFields* _ownedFields, int* _instanceCounter); +int getValueStatic(struct Outer* outer, struct OwnedFields* _ownedFields, int* _instanceCounter); int main(); -struct Outer* createOuter(struct OwnedFields* _ownedFields) +struct Outer* createOuter(struct OwnedFields* _ownedFields, int* _instanceCounter) { struct Inner* inner = NULL; struct Outer* outer = NULL; struct OwnedFields* _tempFields = NULL; - struct OwnedFields* _contextFields = NULL; - _contextFields = initOwnedFields(_ownedFields->instanceCounter); + _tempFields = runtime_init(); assert(true); inner = alloc(struct Inner); - inner->_id = addStructAcc(_contextFields, 1); + inner->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_tempFields, inner->_id, 1); outer = alloc(struct Outer); - outer->_id = addStructAcc(_contextFields, 1); - assertAcc(_contextFields, outer != NULL ? outer->_id : -1, 0, "Field access runtime check failed for struct Outer.inner"); + outer->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_tempFields, outer->_id, 1); + runtime_assert(_tempFields, outer == NULL ? -1 : outer->_id, 0, "No permission to access 'outer->inner'"); outer->inner = inner; - assertAcc(_contextFields, outer != NULL ? outer->_id : -1, 0, "Field access runtime check failed for struct Outer.inner"); - assertAcc(_contextFields, outer->inner != NULL ? outer->inner->_id : -1, 0, "Field access runtime check failed for struct Inner.value"); - _tempFields = initOwnedFields(_contextFields->instanceCounter); - addAccEnsureSeparate(_tempFields, outer->inner != NULL ? outer->inner->_id : -1, 0, 1, "Overlapping field permissions for struct Inner.value"); - join(_ownedFields, _contextFields); + runtime_assert(_tempFields, outer == NULL ? -1 : outer->_id, 0, "No permission to access 'outer->inner'"); + runtime_assert(_tempFields, outer->inner == NULL ? -1 : outer->inner->_id, 0, "No permission to access 'outer->inner->value'"); + runtime_join(_ownedFields, _tempFields); return outer; } -int getValue(struct Outer* outer, struct OwnedFields* _ownedFields) +int getValue(struct Outer* outer, struct OwnedFields* _ownedFields, int* _instanceCounter) { - struct OwnedFields* _tempFields = NULL; - assertAcc(_ownedFields, outer != NULL ? outer->_id : -1, 0, "Field access runtime check failed for struct Outer.inner"); - assertAcc(_ownedFields, outer->inner != NULL ? outer->inner->_id : -1, 0, "Field access runtime check failed for struct Inner.value"); - _tempFields = initOwnedFields(_ownedFields->instanceCounter); - addAccEnsureSeparate(_tempFields, outer->inner != NULL ? outer->inner->_id : -1, 0, 1, "Overlapping field permissions for struct Inner.value"); - assertAcc(_ownedFields, outer != NULL ? outer->_id : -1, 0, "Field access runtime check failed for struct Outer.inner"); - assertAcc(_ownedFields, outer->inner != NULL ? outer->inner->_id : -1, 0, "Field access runtime check failed for struct Inner.value"); + runtime_assert(_ownedFields, outer == NULL ? -1 : outer->_id, 0, "No permission to access 'outer->inner'"); + runtime_assert(_ownedFields, outer->inner == NULL ? -1 : outer->inner->_id, 0, "No permission to access 'outer->inner->value'"); + runtime_assert(_ownedFields, outer->inner == NULL ? -1 : outer->inner->_id, 0, "No permission to access 'outer->inner->value'"); return outer->inner->value; } -int getValueSafe(struct Outer* outer, struct OwnedFields* _ownedFields) +int getValueSafe(struct Outer* outer, struct OwnedFields* _ownedFields, int* _instanceCounter) { int result = 0; - struct OwnedFields* _tempFields = NULL; if (outer != NULL) { - assertAcc(_ownedFields, outer != NULL ? outer->_id : -1, 0, "Field access runtime check failed for struct Outer.inner"); + runtime_assert(_ownedFields, outer == NULL ? -1 : outer->_id, 0, "No permission to access 'outer->inner'"); } if (outer != NULL && outer->inner != NULL) { - assertAcc(_ownedFields, outer != NULL ? outer->_id : -1, 0, "Field access runtime check failed for struct Outer.inner"); - assertAcc(_ownedFields, outer->inner != NULL ? outer->inner->_id : -1, 0, "Field access runtime check failed for struct Inner.value"); + runtime_assert(_ownedFields, outer->inner == NULL ? -1 : outer->inner->_id, 0, "No permission to access 'outer->inner->value'"); result = outer->inner->value; } else @@ -75,26 +70,21 @@ int getValueSafe(struct Outer* outer, struct OwnedFields* _ownedFields) return result; } -int getValueStatic(struct Outer* outer, struct OwnedFields* _ownedFields) +int getValueStatic(struct Outer* outer, struct OwnedFields* _ownedFields, int* _instanceCounter) { struct OwnedFields* _tempFields = NULL; - struct OwnedFields* _contextFields = NULL; - _contextFields = initOwnedFields(_ownedFields->instanceCounter); - assertAcc(_ownedFields, outer != NULL ? outer->_id : -1, 0, "Field access runtime check failed for struct Outer.inner"); - addAcc(_contextFields, outer != NULL ? outer->_id : -1, 1, 0); - loseAcc(_ownedFields, outer != NULL ? outer->_id : -1, 0); - assertAcc(_ownedFields, outer->inner != NULL ? outer->inner->_id : -1, 0, "Field access runtime check failed for struct Inner.value"); - addAcc(_contextFields, outer->inner != NULL ? outer->inner->_id : -1, 1, 0); - loseAcc(_ownedFields, outer->inner != NULL ? outer->inner->_id : -1, 0); - assertAcc(_contextFields, outer != NULL ? outer->_id : -1, 0, "Field access runtime check failed for struct Outer.inner"); - assertAcc(_contextFields, outer->inner != NULL ? outer->inner->_id : -1, 0, "Field access runtime check failed for struct Inner.value"); - assertAcc(_contextFields, outer != NULL ? outer->_id : -1, 0, "Field access runtime check failed for struct Outer.inner"); - assertAcc(_contextFields, outer->inner != NULL ? outer->inner->_id : -1, 0, "Field access runtime check failed for struct Inner.value"); - _tempFields = initOwnedFields(_contextFields->instanceCounter); - addAccEnsureSeparate(_tempFields, outer != NULL ? outer->_id : -1, 0, 1, "Overlapping field permissions for struct Outer.inner"); - addAccEnsureSeparate(_tempFields, outer->inner != NULL ? outer->inner->_id : -1, 0, 1, "Overlapping field permissions for struct Inner.value"); - addAcc(_ownedFields, outer->_id, 1, 0); - addAcc(_ownedFields, outer->inner->_id, 1, 0); + _tempFields = runtime_init(); + runtime_assert(_ownedFields, outer == NULL ? -1 : outer->_id, 0, "No permission to access 'outer->inner'"); + runtime_remove(_ownedFields, outer == NULL ? -1 : outer->_id, 0, "No permission to access 'outer->inner'"); + runtime_add(_tempFields, outer == NULL ? -1 : outer->_id, 1, 0, "Invalid aliasing - 'outer->inner' overlaps with existing permission"); + runtime_assert(_ownedFields, outer->inner == NULL ? -1 : outer->inner->_id, 0, "No permission to access 'outer->inner->value'"); + runtime_remove(_ownedFields, outer->inner == NULL ? -1 : outer->inner->_id, 0, "No permission to access 'outer->inner->value'"); + runtime_add(_tempFields, outer->inner == NULL ? -1 : outer->inner->_id, 1, 0, "Invalid aliasing - 'outer->inner->value' overlaps with existing permission"); + runtime_assert(_tempFields, outer->inner == NULL ? -1 : outer->inner->_id, 0, "No permission to access 'outer->inner->value'"); + runtime_assert(_tempFields, outer == NULL ? -1 : outer->_id, 0, "No permission to access 'outer->inner'"); + runtime_add(_ownedFields, outer == NULL ? -1 : outer->_id, 1, 0, "Invalid aliasing - 'outer->inner' overlaps with existing permission"); + runtime_assert(_tempFields, outer->inner == NULL ? -1 : outer->inner->_id, 0, "No permission to access 'outer->inner->value'"); + runtime_add(_ownedFields, outer->inner == NULL ? -1 : outer->inner->_id, 1, 0, "Invalid aliasing - 'outer->inner->value' overlaps with existing permission"); return outer->inner->value; } @@ -103,13 +93,12 @@ int main() struct Outer* _ = NULL; int _1 = 0; struct OwnedFields* _ownedFields = NULL; - struct OwnedFields* _tempFields = NULL; int* _instanceCounter = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); + _ownedFields = runtime_init(); assert(true); - _ = createOuter(_ownedFields); - _1 = getValue(_, _ownedFields); + _ = createOuter(_ownedFields, _instanceCounter); + _1 = getValue(_, _ownedFields, _instanceCounter); assert(true); return _1; } diff --git a/src/test/resources/baseline/main.baseline.c0 b/src/test/resources/baseline/main.baseline.c0 index db529c86..89995862 100644 --- a/src/test/resources/baseline/main.baseline.c0 +++ b/src/test/resources/baseline/main.baseline.c0 @@ -5,10 +5,9 @@ int main() { int a = 0; struct OwnedFields* _ownedFields = NULL; - struct OwnedFields* _tempFields = NULL; int* _instanceCounter = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); + _ownedFields = runtime_init(); a = 1; return a; } diff --git a/src/test/resources/baseline/postcondition.baseline.c0 b/src/test/resources/baseline/postcondition.baseline.c0 index d7649c9f..f49633c6 100644 --- a/src/test/resources/baseline/postcondition.baseline.c0 +++ b/src/test/resources/baseline/postcondition.baseline.c0 @@ -1,26 +1,24 @@ #use int main(); -int test(int x, struct OwnedFields* _ownedFields); +int test(int x, struct OwnedFields* _ownedFields, int* _instanceCounter); int main() { int _ = 0; struct OwnedFields* _ownedFields = NULL; - struct OwnedFields* _tempFields = NULL; int* _instanceCounter = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); + _ownedFields = runtime_init(); assert(true); - _ = test(2, _ownedFields); + _ = test(2, _ownedFields, _instanceCounter); assert(true); return 0; } -int test(int x, struct OwnedFields* _ownedFields) +int test(int x, struct OwnedFields* _ownedFields, int* _instanceCounter) { struct OwnedFields* _tempFields = NULL; - struct OwnedFields* _contextFields = NULL; - _contextFields = initOwnedFields(_ownedFields->instanceCounter); + _tempFields = runtime_init(); assert(true); assert(x + 1 == x + 1); return x + 1; diff --git a/src/test/resources/baseline/precondition.baseline.c0 b/src/test/resources/baseline/precondition.baseline.c0 index ffb4802a..06f437af 100644 --- a/src/test/resources/baseline/precondition.baseline.c0 +++ b/src/test/resources/baseline/precondition.baseline.c0 @@ -1,25 +1,23 @@ #use int main(); -void test(int x, struct OwnedFields* _ownedFields); +void test(int x, struct OwnedFields* _ownedFields, int* _instanceCounter); int main() { struct OwnedFields* _ownedFields = NULL; - struct OwnedFields* _tempFields = NULL; int* _instanceCounter = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); + _ownedFields = runtime_init(); assert(true); - test(2, _ownedFields); + test(2, _ownedFields, _instanceCounter); assert(true); return 0; } -void test(int x, struct OwnedFields* _ownedFields) +void test(int x, struct OwnedFields* _ownedFields, int* _instanceCounter) { struct OwnedFields* _tempFields = NULL; - struct OwnedFields* _contextFields = NULL; - _contextFields = initOwnedFields(_ownedFields->instanceCounter); + _tempFields = runtime_init(); assert(x > 1); assert(true); } diff --git a/src/test/resources/baseline/predicate.baseline.c0 b/src/test/resources/baseline/predicate.baseline.c0 index 4684d22d..b3e86071 100644 --- a/src/test/resources/baseline/predicate.baseline.c0 +++ b/src/test/resources/baseline/predicate.baseline.c0 @@ -8,24 +8,27 @@ struct Node int _id; }; -void add_list(struct Node* node, struct OwnedFields* _ownedFields); -void check_add_remove_list(struct Node* node, struct OwnedFields* _ownedFields, struct OwnedFields* _tempFields); -struct Node* emptyList(struct OwnedFields* _ownedFields); -void list(struct Node* node, struct OwnedFields* _ownedFields); +void assert_add_list(struct Node* node, struct OwnedFields* assert_perms, struct OwnedFields* add_perms); +void assert_remove_add_list(struct Node* node, struct OwnedFields* assert_perms, struct OwnedFields* remove_perms, struct OwnedFields* add_perms); +struct Node* emptyList(struct OwnedFields* _ownedFields, int* _instanceCounter); int main(); -struct Node* prependList(int value, struct Node* node, struct OwnedFields* _ownedFields); -void sep_list(struct Node* node, struct OwnedFields* _ownedFields); +struct Node* prependList(int value, struct Node* node, struct OwnedFields* _ownedFields, int* _instanceCounter); -void add_list(struct Node* node, struct OwnedFields* _ownedFields) +void assert_add_list(struct Node* node, struct OwnedFields* assert_perms, struct OwnedFields* add_perms) { - if (!(node == NULL)) + if (node == NULL) + { + assert(true); + } + else { - addAcc(_ownedFields, node->_id, 2, 0); - add_list(node->next, _ownedFields); + runtime_assert(assert_perms, node == NULL ? -1 : node->_id, 0, "No permission to access 'node->value'"); + runtime_add(add_perms, node == NULL ? -1 : node->_id, 2, 0, "Invalid aliasing - 'node->value' overlaps with existing permission"); + assert_add_list(node->next, assert_perms, add_perms); } } -void check_add_remove_list(struct Node* node, struct OwnedFields* _ownedFields, struct OwnedFields* _tempFields) +void assert_remove_add_list(struct Node* node, struct OwnedFields* assert_perms, struct OwnedFields* remove_perms, struct OwnedFields* add_perms) { if (node == NULL) { @@ -33,78 +36,49 @@ void check_add_remove_list(struct Node* node, struct OwnedFields* _ownedFields, } else { - assertAcc(_tempFields, node != NULL ? node->_id : -1, 0, "Field access runtime check failed for struct Node.value"); - addAcc(_ownedFields, node != NULL ? node->_id : -1, 2, 0); - loseAcc(_tempFields, node != NULL ? node->_id : -1, 0); - check_add_remove_list(node->next, _ownedFields, _tempFields); + runtime_assert(assert_perms, node == NULL ? -1 : node->_id, 0, "No permission to access 'node->value'"); + runtime_remove(remove_perms, node == NULL ? -1 : node->_id, 0, "No permission to access 'node->value'"); + runtime_add(add_perms, node == NULL ? -1 : node->_id, 2, 0, "Invalid aliasing - 'node->value' overlaps with existing permission"); + assert_remove_add_list(node->next, assert_perms, remove_perms, add_perms); } } -struct Node* emptyList(struct OwnedFields* _ownedFields) +struct Node* emptyList(struct OwnedFields* _ownedFields, int* _instanceCounter) { struct Node* nullList = NULL; struct OwnedFields* _tempFields = NULL; - struct OwnedFields* _contextFields = NULL; - _contextFields = initOwnedFields(_ownedFields->instanceCounter); + _tempFields = runtime_init(); assert(true); nullList = NULL; - list(nullList, _contextFields); - _tempFields = initOwnedFields(_contextFields->instanceCounter); - sep_list(nullList, _tempFields); - add_list(nullList, _ownedFields); + assert_add_list(nullList, _tempFields, _ownedFields); return nullList; } -void list(struct Node* node, struct OwnedFields* _ownedFields) -{ - if (node == NULL) - { - assert(true); - } - else - { - assertAcc(_ownedFields, node != NULL ? node->_id : -1, 0, "Field access runtime check failed for struct Node.value"); - list(node->next, _ownedFields); - } -} - int main() { struct OwnedFields* _ownedFields = NULL; - struct OwnedFields* _tempFields = NULL; int* _instanceCounter = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); + _ownedFields = runtime_init(); assert(true); assert(true); return 0; } -struct Node* prependList(int value, struct Node* node, struct OwnedFields* _ownedFields) +struct Node* prependList(int value, struct Node* node, struct OwnedFields* _ownedFields, int* _instanceCounter) { struct Node* newNode = NULL; struct OwnedFields* _tempFields = NULL; - struct OwnedFields* _contextFields = NULL; - _contextFields = initOwnedFields(_ownedFields->instanceCounter); - check_add_remove_list(node, _contextFields, _ownedFields); + _tempFields = runtime_init(); + assert_remove_add_list(node, _ownedFields, _ownedFields, _tempFields); newNode = alloc(struct Node); - newNode->_id = addStructAcc(_contextFields, 2); - assertAcc(_contextFields, newNode != NULL ? newNode->_id : -1, 1, "Field access runtime check failed for struct Node.next"); + newNode->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_tempFields, newNode->_id, 2); + runtime_assert(_tempFields, newNode == NULL ? -1 : newNode->_id, 1, "No permission to access 'newNode->next'"); newNode->next = node; - assertAcc(_contextFields, newNode != NULL ? newNode->_id : -1, 0, "Field access runtime check failed for struct Node.value"); + runtime_assert(_tempFields, newNode == NULL ? -1 : newNode->_id, 0, "No permission to access 'newNode->value'"); newNode->value = value; - list(newNode, _contextFields); - _tempFields = initOwnedFields(_contextFields->instanceCounter); - sep_list(newNode, _tempFields); - add_list(newNode, _ownedFields); + assert_add_list(newNode, _tempFields, _ownedFields); return newNode; } - -void sep_list(struct Node* node, struct OwnedFields* _ownedFields) -{ - if (!(node == NULL)) - { - addAccEnsureSeparate(_ownedFields, node != NULL ? node->_id : -1, 0, 2, "Overlapping field permissions for struct Node.value"); - sep_list(node->next, _ownedFields); - } -} diff --git a/src/test/resources/c0/quickcheck.c0 b/src/test/resources/c0/quickcheck.c0 new file mode 100644 index 00000000..08504ebf --- /dev/null +++ b/src/test/resources/c0/quickcheck.c0 @@ -0,0 +1,308 @@ +#use +#use +#use + +// Use some internal methods for testing +FieldArray* runtime_find(OwnedFields* fields, int id); +bool runtime_tryAdd(OwnedFields* fields, int id, int numFields, int fieldIndex); +bool runtime_tryRemove(OwnedFields* fields, int id, int fieldIndex); +int runtime_hash(int index, int arrayLength); + +struct List { + int[] values; + int length; + int capacity; +}; + +typedef struct List List; + +List* list_new() { + List* l = alloc(struct List); + l->capacity = 2048; + l->values = alloc_array(int, l->capacity); + return l; +} + +void list_add(List* l, int value) { + int length = l->length; + if (length == l->capacity) { + int capacity = l->capacity; + int[] values = l->values; + int newCapacity = capacity * 2; + int[] newValues = alloc_array(int, newCapacity); + for (int i = 0; i < capacity; i++) { + newValues[i] = values[i]; + } + l->values = newValues; + l->capacity = newCapacity; + } + l->values[length] = value; + l->length = length + 1; +} + +int list_index(List* l, int value) { + int length = l->length; + int[] values = l->values; + for (int i = 0; i < length; i++) { + if (values[i] == value) return i; + } + return -1; +} + +bool list_contains(List* l, int value) { + return list_index(l, value) != -1; +} + +void list_remove(List* l, int value) { + int i = list_index(l, value); + if (i == -1) return; + + int length = l->length - 1; + l->length = length; + int[] values = l->values; + for (; i < length; i++) + values[i] = values[i + 1]; + + assert(!list_contains(l, value)); +} + +// Expected value: 25.5 +int field_count(int id) { + return ((id + 13) % 51) + 1; +} + +int tag(int id, int field) { + return (id << 16) | field; +} + +int tag_id(int t) { + return t >> 16; +} + +int tag_field(int t) { + return t & 0xFFFF; +} + +bool contains_id(List* l, int id) { + int length = l->length; + int[] values = l->values; + for (int i = 0; i < length; i++) { + if (values[i] >> 16 == id) + return true; + } + return false; +} + +void check_fields(OwnedFields* fields) { + // Checks properties: + // (1) Entries should be sorted by hash key (modulo wrap-around) + // (2) `length` should count the number of non-NULL entries + // (3) The number of fields marked `true` should be equal to `accessible + // (4) Hash + PSL == index + + int capacity = fields->capacity; + FieldArray*[] contents = fields->contents; + + int start = 0; + int min = int_max(); + for (int a = 0; a < capacity; a++) { + FieldArray* entry = contents[a]; + if (entry != NULL) { + int k = runtime_hash(entry->id, capacity); + if (k < min) { + min = k; + start = a; + } + } + } + + int i = start; + int count = 0; + int lastKey = -1; + for (int c = 0; c < capacity; c++) { + if (contents[i] != NULL) { + count++; + + FieldArray* entry = contents[i]; + int key = runtime_hash(entry->id, capacity); + if (key < lastKey) { // (1) + runtime_print(fields); + printf("Error: Hash key %d is smaller than previous %d (at %d)\n", key, lastKey, i); + assert(false); + } + if ((key + entry->psl) % capacity != i) { // (4) + runtime_print(fields); + printf("Error: Index %d != key %d + PSL %d\n", i, key, entry->psl); + assert(false); + } + + int acc = 0; + for (int j = 0; j < entry->length; j++) { + if (entry->contents[j]) acc++; + } + if (acc != entry->accessible) { // (3) + runtime_print(fields); + printf("Invalid accessible count @ %d: expected %d, found %d\n", i, entry->accessible, acc); + assert(false); + } + + lastKey = key; + } + + i = (i + 1) % capacity; + } + + assert(count == fields->length); // (2) +} + +bool removeNext(int id, List* tags, OwnedFields* fields) { + int length = tags->length; + int[] values = tags->values; + for (int j = 0; j < length; j++) { + int t = values[j]; + if (tag_id(t) == id) { + assert(runtime_tryRemove(fields, id, tag_field(t))); + list_remove(tags, t); + return true; + } + } + return false; +} + +int main() { + int INITIAL_SIZE = 2048; + int MAX_ID = 1024; + int REPS = 5000; + bool PRINT = false; + + rand_t r = init_rand(1); + + List* tags = list_new(); + int length = 0; + + OwnedFields* fields = runtime_init(); + + int i = 0; + while (i < INITIAL_SIZE) { + check_fields(fields); + + int id = abs(rand(r)) % MAX_ID; + int numFields = field_count(id); + int field = abs(rand(r)) % numFields; + int t = tag(id, field); + if (!list_contains(tags, t)) { + if (PRINT) printf("add(%d, %d, %d)\n", id, numFields, field); + list_add(tags, t); + runtime_add(fields, id, numFields, field, "Could not add (initial)"); + } + i++; + } + + for (i = 0; i < REPS; i++) { + check_fields(fields); + + // Check that tags coincides with the entries + // This takes a bit of time + /*for (int j = 0; j < tags->length; j++) { + int tag = tags->values[j]; + runtime_assert(fields, tag_id(tag), tag_field(tag), "Missing field"); + } + + for (int j = 0; j < fields->length; j++) { + FieldArray* entry = fields->contents[j]; + if (entry != NULL) { + for (int f = 0; f < entry->length; f++) { + if (entry->contents[f] && !list_contains(tags, tag(entry->id, f))) { + runtime_print(fields); + printf("Unexpected field %d.%d @ %d", entry->id, f, j); + assert(false); + } + } + } + } + */ + + int type = abs(rand(r)) % 100; + int id = abs(rand(r)) % MAX_ID; + int numFields = field_count(id); + if (type < 15) { + // 15% of the time: Add a struct + if (!contains_id(tags, id)) { + if (PRINT) printf("addAll(%d, %d)\n", id, numFields); + runtime_addAll(fields, id, numFields); + for (int j = 0; j< numFields; j++) + list_add(tags, tag(id, j)); + + FieldArray* entry = runtime_find(fields, id); + assert(entry != NULL); + for (int k = 0; k < numFields; k++) { + if (!entry->contents[k]) { + runtime_print(fields); + assert(false); + } + } + assert(entry->length >= numFields); + } else { + // if (PRINT) printf("Skipped: addAll(%d, %d)\n", id, numFields); + i--; // Don't count this iteration + } + } else if (type < 30) { + // 15% of the time: Remove all fields for a single ID + if (contains_id(tags, id)) { + if (PRINT) printf("removeAll(%d, %d)\n", id, numFields); + while (removeNext(id, tags, fields)) + true; + } else { + // if (PRINT) printf("Skipped: removeAll(%d)\n", id); + i--; + } + } else if (type < 70) { + // 40% of the time: Add a single field + int field = abs(rand(r)) % numFields; + int t = tag(id, field); + bool contained = list_contains(tags, t); + if (PRINT) { + printf("add(%d, %d, %d)", id, numFields, field); + println(contained ? " - overlaps" : ""); + } + bool result = runtime_tryAdd(fields, id, numFields, field); + if (result != !contained) { + runtime_print(fields); + printf("Unexpected result\n"); + assert(false); + } + + if (result) + list_add(tags, t); + + FieldArray* entry = runtime_find(fields, id); + if (entry == NULL || !entry->contents[field]) { + runtime_print(fields); + printf("Added field missing\n"); + assert(false); + } + } else { + // 30% of the time: Remove a single field + int field = abs(rand(r)) % numFields; + int t = tag(id, field); + bool contained = list_contains(tags, t); + if (PRINT) { + printf("remove(%d, %d)", id, field); + println(contained ? "" : " - missing"); + } + bool result = runtime_tryRemove(fields, id, field); + assert(result == contained); + + FieldArray* entry = runtime_find(fields, id); + if (entry != NULL && entry->contents[field]) { + runtime_print(fields); + printf("Removed field still exists\n"); + assert(false); + } + + list_remove(tags, t); + } + } + + return 0; +} \ No newline at end of file diff --git a/src/test/resources/c0/test.c0 b/src/test/resources/c0/test.c0 index b4f5147c..f1b37730 100644 --- a/src/test/resources/c0/test.c0 +++ b/src/test/resources/c0/test.c0 @@ -1,264 +1,303 @@ #use -bool NassertAcc(OwnedFields* fields, int _id, int fieldIndex, string errorMessage){ - FieldArray* toCheck = find(fields, _id); - if(toCheck == NULL || !toCheck->contents[fieldIndex]){ - return false; - }else{ - return true; - } -} +// Use some internal methods for testing +FieldArray* runtime_find(OwnedFields* fields, int id); +bool runtime_tryAdd(OwnedFields* fields, int id, int numFields, int fieldIndex); -bool NaddAccEnsureSeparate(OwnedFields* fields, int _id, int fieldIndex, int numFields, string errorMessage){ - FieldArray* toCheck = find(fields, _id); - if (toCheck == NULL) { - toCheck = newFieldArray(fields, _id, numFields, false); - } else if (toCheck->contents[fieldIndex]) { - return false; - } - toCheck->contents[fieldIndex] = true; - toCheck->numAccessible += 1; - return true; +bool checkAcc(OwnedFields* fields, int id, int fieldIndex) { + FieldArray* entry = runtime_find(fields, id); + return entry != NULL && entry->contents[fieldIndex]; } -void debugOwnedFields(OwnedFields* fields, string name) { - print("---[ "); - print(name); - print(" ]---\n"); - print("* Num. of structs: "); - printint(fields->length); - print("\n"); - print("* Capacity: "); - printint(fields->capacity); - print("\n"); - print("["); - if(fields->capacity > 0){ - if(fields->contents[0] != NULL){ - printint(fields->contents[0]->_id); - }else{ - print("NULL"); - } - for(int i = 1; i < fields->capacity; i+=1){ - print(", "); - if(fields->contents[i] != NULL){ - printint(fields->contents[i]->_id); - }else{ - print("NULL"); - } - } - } - print("]\n"); -} +void test_init() { + print("test_init: "); -bool test(string header, bool condition, string message){ - if(!condition){ - error(string_join(string_join(header, " - "), message)); - return false; - }else{ - return true; - } -} - -bool testOwnedFieldsInitialization(string header){ - int * _id_counter = alloc(int); - *(_id_counter) = 0; - - OwnedFields* fields = initOwnedFields(_id_counter); - if(!test(header, fields->capacity > 0, "(initOwnedFields) OwnedFields must have a nonzero default capacity after intialization.")){ - return false; - } - if(!test(header, fields->length == 0, "(initOwnedFields) OwnedFields must have a length of zero after initialization.")){ - return false; - } + OwnedFields* fields = runtime_init(); + assert(fields->capacity > 0); + assert(fields->length == 0); + for (int i = 0; i < fields->capacity; i++) + assert(fields->contents[i] == NULL); - bool allNULL = true; - for(int i = 0; i < fields->length; i+=1){ - if(fields->contents[i] != NULL){ - allNULL = false; - } - } - if(!test(header, allNULL, "(initOwnedFields) The contents of OwnedFields must all be set to NULL after initialization.")){ - return false; - } - return true; + println("PASS"); } -bool testStructCreation(string header){ - int * _id_counter = alloc(int); - *(_id_counter) = 0; +void test_addAll() { + print("test_addAll: "); + OwnedFields* fields = runtime_init(); - OwnedFields* fields = initOwnedFields(_id_counter); + int id = 123; + int numFields = 80; + runtime_addAll(fields, id, numFields); - int _id = addStructAcc(fields, 80); + FieldArray* entry = runtime_find(fields, id); + assert(entry != NULL); + assert(entry->accessible == numFields); - for(int i = 0; i<80; i += 1){ - assertAcc(fields, _id, i, "(addStructAcc) Failed to verify access to the fields of a newly created struct."); + for (int i = 0; i < numFields; i++) { + runtime_assert(fields, id, i, + "(addAll) Failed to verify access to the fields of a newly created struct."); } - return true; + + println("PASS"); } +void test_add(){ + print("test_add: "); + + OwnedFields* fields_1 = runtime_init(); + OwnedFields* fields_2 = runtime_init(); + + int id = 0; + runtime_addAll(fields_1, id, 4); + runtime_add(fields_2, id, 4, 2, "add failed (2)"); + runtime_add(fields_2, id, 4, 3, "add failed (3)"); -bool testFieldAddition(string header){ - int * _id_counter = alloc(int); - *(_id_counter) = 0; + // Has the fields we added + runtime_assert(fields_2, id, 2, "assert failed (2)"); + runtime_assert(fields_2, id, 3, "assert failed (3)"); - OwnedFields* fields_1 = initOwnedFields(_id_counter); - OwnedFields* fields_2 = initOwnedFields(_id_counter); + // Doesn't have another field + assert(!checkAcc(fields_2, id, 1)); - int _id = addStructAcc(fields_1, 4); + println("PASS"); +} - addAcc(fields_2, _id, 4, 2); - assertAcc(fields_2, _id, 2, "(addAcc/assertAcc) Failed to add an arbitrary singular field access permission to an OwnedFields struct."); - return true; +void test_remove() { + print("test_remove: "); + + OwnedFields* fields = runtime_init(); + + int numFields = 2; + runtime_addAll(fields, 0, numFields); + runtime_addAll(fields, 1, numFields); + runtime_addAll(fields, 2, numFields); + + runtime_remove(fields, 0, 0, "remove failed (0.0)"); + runtime_remove(fields, 0, 1, "remove failed (0.1)"); + runtime_remove(fields, 1, 1, "remove failed (1.1)"); + runtime_remove(fields, 2, 0, "remove failed (2.0)"); + + assert(!checkAcc(fields, 0, 0)); + assert(!checkAcc(fields, 0, 1)); + assert(checkAcc(fields, 1, 0)); + assert(!checkAcc(fields, 1, 1)); + assert(!checkAcc(fields, 2, 0)); + assert(checkAcc(fields, 2, 1)); + + println("PASS"); } +void test_resizing() { + print("test_resizing: "); + + OwnedFields* fields = runtime_init(); -bool testFieldMerging(string header){ - int * _id_counter = alloc(int); - *(_id_counter) = 0; + int numFields = 5; + for (int id = 0; id < 1024; id++) { + runtime_add(fields, id, numFields, id % numFields, "add failed"); + } - int num_fields = 20; + for (int id = 0; id < 1024; id++) { + assert(checkAcc(fields, id, id % numFields)); + } - OwnedFields* source = initOwnedFields(_id_counter); + for (int id = 0; id < 1024; id++) { + for (int f = 0; f < numFields; f++) { + assert((f == id % numFields) || !checkAcc(fields, id, f)); + } + } - int _id_node_1 = addStructAcc(source, num_fields); - int _id_node_2 = addStructAcc(source, num_fields); + println("PASS"); +} - OwnedFields* fields_1 = initOwnedFields(_id_counter); +void test_resizingRemove() { + print("test_resizingRemove: "); - OwnedFields* fields_2 = initOwnedFields(_id_counter); + OwnedFields* fields = runtime_init(); - for(int i = 0; icapacity == 128); - addAcc(fields_2, _id_node_2, num_fields, i); - assertAcc(fields_2, _id_node_2, i, "(addAcc/assertAcc) Failed to add permission to new OwnedFields from source OwnedFields."); + // Add 1024 records, forcing a resize + int numFields = 2; + for (int i = 0; i < 1023; i++) { + runtime_addAll(fields, i, numFields); } + assert(fields->length == 1023); + assert(fields->capacity == 1280); // 80% of 1280 is 1024, and we added 1023 + + // Remove all (both) fields from the first 1000 records + for (int i = 0; i < 1000; i++) { + runtime_remove(fields, i, 0, "remove failed (0)"); + runtime_remove(fields, i, 1, "remove failed (1)"); + } - join(fields_1, fields_2); + assert(fields->length == 1023); // Unused records are included in length + assert(fields->capacity == 1280); - for(int i = 0; ilength == 1023); + assert(fields->capacity == 1280); - int num_fields = 10; - int _id_node = addStructAcc(source, num_fields); + // Verify that we still have access to IDs 1000-1022 and 2000-2999 + for (int i = 1000; i < 1023; i++) { + assert(checkAcc(fields, i, 0)); + assert(checkAcc(fields, i, 1)); + } + for (int i = 2000; i < 3000; i++) { + assert(checkAcc(fields, i, 0)); + assert(checkAcc(fields, i, 1)); + } + println("PASS"); } -*/ - -bool testNegativeAccessibility(string header) { - int * _id_counter = alloc(int); - *(_id_counter) = 0; - OwnedFields* source = initOwnedFields(_id_counter); - int num_fields = 10; - int _id_node = addStructAcc(source, num_fields); +void test_join() { + print("test_join: "); + + int num_fields = 4; + int num_items = 20; + + OwnedFields* source = runtime_init(); + OwnedFields* target = runtime_init(); + + for (int i = 0; i < num_items; i++) { + if (i % 2 == 0) { + // Even IDs: field 0 added to source, 1 added to target + runtime_add(source, i, num_fields, 0, "add failed"); + runtime_add(target, i, num_fields, 1, "add failed"); + } else { + // Odd IDs: fields 0 & 2 added to source, 1 & 3 added to target + runtime_add(source, i, num_fields, 0, "add failed"); + runtime_add(source, i, num_fields, 2, "add failed"); + runtime_add(target, i, num_fields, 1, "add failed"); + runtime_add(target, i, num_fields, 3, "add failed"); + } + } - bool found = NassertAcc(source, _id_node+1, num_fields, ""); - if(found) error("ACK discovered."); + runtime_join(target, source); + + for (int i = 0; i < num_items; i++) { + if (i % 2 == 0) { + // Even IDs: should have fields 0, 1 + assert(checkAcc(target, i, 0)); + assert(checkAcc(target, i, 1)); + assert(!checkAcc(target, i, 2)); + assert(!checkAcc(target, i, 3)); + } else { + // Odd IDs: should have fields 0, 1, 2, 3 + assert(checkAcc(target, i, 0)); + assert(checkAcc(target, i, 1)); + assert(checkAcc(target, i, 2)); + assert(checkAcc(target, i, 3)); + } + } - return true; + println("PASS"); } -bool testEnsureSeparate(string header) { - int * _id_counter = alloc(int); - *(_id_counter) = 0; - OwnedFields* source = initOwnedFields(_id_counter); +// Tests join in the case where `target` needs to be resized +void test_joinLarger() { + print("test_joinLarger: "); - int num_fields = 10; - int _id_node = addStructAcc(source, num_fields); + int numFields = 2; + OwnedFields* source = runtime_init(); + OwnedFields* target = runtime_init(); - for(int i = 0; ivalue = 1; } -void test(struct Test* t, struct OwnedFields* _ownedFields) +void test(struct Test* t, struct OwnedFields* _ownedFields, int* _instanceCounter) { - assertAcc(_ownedFields, t != NULL ? t->_id : -1, 0, "Field access runtime check failed for struct Test.value"); + runtime_assert(_ownedFields, t == NULL ? -1 : t->_id, 0, "No permission to access 't->value'"); assert(!(t == NULL)); - loseAcc(_ownedFields, t->_id, 0); - setValue(t, _ownedFields->instanceCounter); - addAcc(_ownedFields, t->_id, 1, 0); + runtime_remove(_ownedFields, t->_id, 0, "No permission to access 't->value'"); + setValue(t, _instanceCounter); + runtime_add(_ownedFields, t == NULL ? -1 : t->_id, 1, 0, "Invalid aliasing - 't->value' overlaps with existing permission"); } diff --git a/src/test/resources/verifier/addtwo.output.c0 b/src/test/resources/verifier/addtwo.output.c0 index 158a8721..813f61f0 100644 --- a/src/test/resources/verifier/addtwo.output.c0 +++ b/src/test/resources/verifier/addtwo.output.c0 @@ -1,8 +1,8 @@ #use -int add(int a, int b, struct OwnedFields* _ownedFields); +int add(int a, int b, int* _instanceCounter); int main(); -int add(int a, int b, struct OwnedFields* _ownedFields) +int add(int a, int b, int* _instanceCounter) { return a + b; } @@ -11,13 +11,8 @@ int main() { int sum = 0; int* _instanceCounter = NULL; - struct OwnedFields* _tempFields = NULL; - struct OwnedFields* _ownedFields = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); - _tempFields = initOwnedFields(_instanceCounter); - sum = add(1, 3, _tempFields); - join(_ownedFields, _tempFields); + sum = add(1, 3, _instanceCounter); assert(sum == 4); return 0; } diff --git a/src/test/resources/verifier/alloc_precise.output.c0 b/src/test/resources/verifier/alloc_precise.output.c0 index b88eaf22..84fb2d19 100644 --- a/src/test/resources/verifier/alloc_precise.output.c0 +++ b/src/test/resources/verifier/alloc_precise.output.c0 @@ -9,7 +9,7 @@ struct A struct A* create(int* _instanceCounter); int main(); -void test(struct A* x, struct A* y, struct OwnedFields* _ownedFields); +void test(struct A* x, struct A* y, struct OwnedFields* _ownedFields, int* _instanceCounter); void test2(struct A* x, struct A* y, int* _instanceCounter); struct A* create(int* _instanceCounter) @@ -25,31 +25,31 @@ int main() { struct A* _ = NULL; struct A* _1 = NULL; - int* _instanceCounter = NULL; struct OwnedFields* _ownedFields = NULL; + int* _instanceCounter = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); + _ownedFields = runtime_init(); _ = create(_instanceCounter); - addAcc(_ownedFields, _->_id, 1, 0); + runtime_add(_ownedFields, _ == NULL ? -1 : _->_id, 1, 0, "Invalid aliasing - '_->value' overlaps with existing permission"); _1 = create(_instanceCounter); - addAcc(_ownedFields, _1->_id, 1, 0); - test(_, _1, _ownedFields); + runtime_add(_ownedFields, _1 == NULL ? -1 : _1->_id, 1, 0, "Invalid aliasing - '_1->value' overlaps with existing permission"); + test(_, _1, _ownedFields, _instanceCounter); return 0; } -void test(struct A* x, struct A* y, struct OwnedFields* _ownedFields) +void test(struct A* x, struct A* y, struct OwnedFields* _ownedFields, int* _instanceCounter) { struct OwnedFields* _tempFields = NULL; - _tempFields = initOwnedFields(_ownedFields->instanceCounter); - assertAcc(_ownedFields, y != NULL ? y->_id : -1, 0, "Field access runtime check failed for struct A.value"); - assertAcc(_ownedFields, x != NULL ? x->_id : -1, 0, "Field access runtime check failed for struct A.value"); + _tempFields = runtime_init(); + runtime_assert(_ownedFields, y == NULL ? -1 : y->_id, 0, "No permission to access 'y->value'"); + runtime_assert(_ownedFields, x == NULL ? -1 : x->_id, 0, "No permission to access 'x->value'"); assert(!(x == NULL)); assert(!(y == NULL)); - addAccEnsureSeparate(_tempFields, x != NULL ? x->_id : -1, 0, 1, "Overlapping field permissions for struct A.value"); - addAccEnsureSeparate(_tempFields, y != NULL ? y->_id : -1, 0, 1, "Overlapping field permissions for struct A.value"); - loseAcc(_ownedFields, x->_id, 0); - loseAcc(_ownedFields, y->_id, 0); - test2(x, y, _ownedFields->instanceCounter); + runtime_add(_tempFields, x == NULL ? -1 : x->_id, 1, 0, "Invalid aliasing - 'x->value' overlaps with existing permission"); + runtime_add(_tempFields, y == NULL ? -1 : y->_id, 1, 0, "Invalid aliasing - 'y->value' overlaps with existing permission"); + runtime_remove(_ownedFields, x->_id, 0, "No permission to access 'x->value'"); + runtime_remove(_ownedFields, y->_id, 0, "No permission to access 'y->value'"); + test2(x, y, _instanceCounter); } void test2(struct A* x, struct A* y, int* _instanceCounter) diff --git a/src/test/resources/verifier/bare_pointers.output.c0 b/src/test/resources/verifier/bare_pointers.output.c0 index 38801456..01ec3d63 100644 --- a/src/test/resources/verifier/bare_pointers.output.c0 +++ b/src/test/resources/verifier/bare_pointers.output.c0 @@ -49,9 +49,16 @@ struct _ptr_struct_Test_ int _id; }; +void add_accInt(struct _ptr_int* ptr, struct OwnedFields* add_perms); int main(); +void remove_accInt(struct _ptr_int* ptr, struct OwnedFields* remove_perms); void test(struct _ptr_int* input, int* _instanceCounter); +void add_accInt(struct _ptr_int* ptr, struct OwnedFields* add_perms) +{ + runtime_add(add_perms, ptr == NULL ? -1 : ptr->_id, 1, 0, "Invalid aliasing - 'ptr->value' overlaps with existing permission"); +} + int main() { struct _ptr_struct_Test_* refPtr = NULL; @@ -64,50 +71,75 @@ int main() struct _ptr_int* _2 = NULL; struct _ptr_struct_Test_* _3 = NULL; struct Test* _4 = NULL; + struct OwnedFields* _ownedFields = NULL; int* _instanceCounter = NULL; _instanceCounter = alloc(int); + _ownedFields = runtime_init(); refPtr = alloc(struct _ptr_struct_Test_); refPtr->_id = *_instanceCounter; *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, refPtr->_id, 1); _ = alloc(struct Test); _->_id = *_instanceCounter; *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, _->_id, 1); refPtr->value = _; intTest = alloc(struct _ptr_int__); intTest->_id = *_instanceCounter; *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, intTest->_id, 1); _1 = alloc(struct _ptr_int_); _1->_id = *_instanceCounter; *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, _1->_id, 1); intTest->value = _1; _2 = alloc(struct _ptr_int); _2->_id = *_instanceCounter; *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, _2->_id, 1); intTest->value->value = _2; intTest->value->value->value = -1; c = alloc(struct _ptr_char); c->_id = *_instanceCounter; *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, c->_id, 1); c->value = '\0'; ptr = alloc(struct _ptr_int); ptr->_id = *_instanceCounter; *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, ptr->_id, 1); + if (_ownedFields != NULL) + { + remove_accInt(ptr, _ownedFields); + } test(ptr, _instanceCounter); + if (_ownedFields != NULL) + { + add_accInt(ptr, _ownedFields); + } wrapper = alloc(struct Wrapper); wrapper->_id = *_instanceCounter; *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, wrapper->_id, 1); _3 = alloc(struct _ptr_struct_Test_); _3->_id = *_instanceCounter; *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, _3->_id, 1); wrapper->test = _3; _4 = alloc(struct Test); _4->_id = *_instanceCounter; *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, _4->_id, 1); wrapper->test->value = _4; wrapper->test->value->value = 1; return 0; } +void remove_accInt(struct _ptr_int* ptr, struct OwnedFields* remove_perms) +{ + runtime_remove(remove_perms, ptr->_id, 0, "No permission to access 'ptr->value'"); +} + void test(struct _ptr_int* input, int* _instanceCounter) { } diff --git a/src/test/resources/verifier/check_field_value.output.c0 b/src/test/resources/verifier/check_field_value.output.c0 index 6b7eaecf..f0d3e4b9 100644 --- a/src/test/resources/verifier/check_field_value.output.c0 +++ b/src/test/resources/verifier/check_field_value.output.c0 @@ -8,7 +8,7 @@ struct Test }; int main(); -void test(struct Test* value, struct OwnedFields* _ownedFields); +void test(struct Test* value, int* _instanceCounter); int main() { @@ -17,7 +17,7 @@ int main() return 0; } -void test(struct Test* value, struct OwnedFields* _ownedFields) +void test(struct Test* value, int* _instanceCounter) { assert(value->value == 0); } diff --git a/src/test/resources/verifier/conditional.output.c0 b/src/test/resources/verifier/conditional.output.c0 index bb17bde1..d60810db 100644 --- a/src/test/resources/verifier/conditional.output.c0 +++ b/src/test/resources/verifier/conditional.output.c0 @@ -1,6 +1,6 @@ #use int main(); -void test(int x, int y, struct OwnedFields* _ownedFields); +void test(int x, int y, int* _instanceCounter); int main() { @@ -9,17 +9,17 @@ int main() return 0; } -void test(int x, int y, struct OwnedFields* _ownedFields) +void test(int x, int y, int* _instanceCounter) { - bool _cond_1 = false; - bool _cond_2 = false; - _cond_1 = x > 1; + bool _cond = false; + bool _cond1 = false; + _cond = x > 1; if (x > 1) { - _cond_2 = x > 2; + _cond1 = x > 2; if (x > 2) { - if (_cond_1 && _cond_2) + if (_cond && _cond1) { assert(y == 0); } diff --git a/src/test/resources/verifier/conditional_embedding.output.c0 b/src/test/resources/verifier/conditional_embedding.output.c0 index c534de20..ba913b35 100644 --- a/src/test/resources/verifier/conditional_embedding.output.c0 +++ b/src/test/resources/verifier/conditional_embedding.output.c0 @@ -1,6 +1,6 @@ #use int main(); -int test(int input, struct OwnedFields* _ownedFields); +int test(int input, int* _instanceCounter); int main() { @@ -9,7 +9,7 @@ int main() return 0; } -int test(int input, struct OwnedFields* _ownedFields) +int test(int input, int* _instanceCounter) { if (input > 0) { diff --git a/src/test/resources/verifier/conditional_guard.output.c0 b/src/test/resources/verifier/conditional_guard.output.c0 index 3ec50f11..99452974 100644 --- a/src/test/resources/verifier/conditional_guard.output.c0 +++ b/src/test/resources/verifier/conditional_guard.output.c0 @@ -8,36 +8,36 @@ struct Test int _id; }; -struct Test* getTest(struct OwnedFields* _ownedFields); +struct Test* getTest(struct OwnedFields* _ownedFields, int* _instanceCounter); int main(); -struct Test* getTest(struct OwnedFields* _ownedFields) +struct Test* getTest(struct OwnedFields* _ownedFields, int* _instanceCounter) { struct Test* _ = NULL; _ = alloc(struct Test); - _->_id = addStructAcc(_ownedFields, 2); + _->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, _->_id, 2); return _; } int main() { struct Test* test = NULL; - bool _cond_1 = false; - int* _instanceCounter = NULL; struct OwnedFields* _ownedFields = NULL; - struct OwnedFields* _tempFields = NULL; + bool _cond = false; + int* _instanceCounter = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); - _tempFields = initOwnedFields(_instanceCounter); - test = getTest(_tempFields); - join(_ownedFields, _tempFields); - assertAcc(_ownedFields, test != NULL ? test->_id : -1, 0, "Field access runtime check failed for struct Test.a"); - _cond_1 = !(test == NULL) && test->a == 0; + _ownedFields = runtime_init(); + test = getTest(_ownedFields, _instanceCounter); + runtime_assert(_ownedFields, test == NULL ? -1 : test->_id, 0, "No permission to access 'test->a'"); + runtime_assert(_ownedFields, test == NULL ? -1 : test->_id, 0, "No permission to access 'test->a'"); + _cond = !(test == NULL) && test->a == 0; if (test->a == 0) { - if (_cond_1) + if (_cond) { - assertAcc(_ownedFields, test != NULL ? test->_id : -1, 1, "Field access runtime check failed for struct Test.b"); + runtime_assert(_ownedFields, test == NULL ? -1 : test->_id, 1, "No permission to access 'test->b'"); } test->b = 1; } diff --git a/src/test/resources/verifier/conditional_ordering.output.c0 b/src/test/resources/verifier/conditional_ordering.output.c0 index 0d79b480..2ce28ddc 100644 --- a/src/test/resources/verifier/conditional_ordering.output.c0 +++ b/src/test/resources/verifier/conditional_ordering.output.c0 @@ -8,33 +8,33 @@ struct Node int _id; }; -void appendLemmaAfterLoopBody(struct Node* a, struct Node* b, struct Node* c, int aPrev, int bVal, int cVal, struct OwnedFields* _ownedFields); +void appendLemmaAfterLoopBody(struct Node* a, struct Node* b, struct Node* c, int aPrev, int bVal, int cVal, struct OwnedFields* _ownedFields, int* _instanceCounter); int main(); -void appendLemmaAfterLoopBody(struct Node* a, struct Node* b, struct Node* c, int aPrev, int bVal, int cVal, struct OwnedFields* _ownedFields) +void appendLemmaAfterLoopBody(struct Node* a, struct Node* b, struct Node* c, int aPrev, int bVal, int cVal, struct OwnedFields* _ownedFields, int* _instanceCounter) { - bool _cond_1 = false; - bool _cond_2 = false; - bool _cond_3 = false; - _cond_1 = b == c; - _cond_2 = c == b; + bool _cond = false; + bool _cond1 = false; + bool _cond2 = false; + _cond = b == c; + _cond1 = c == b; if (!(b == c)) { } else { - _cond_3 = a == c; + _cond2 = a == c; if (a == b) { } else { - if (_cond_1 && _cond_2 && _cond_2 && !_cond_3) + if (_cond && _cond1 && _cond1 && !_cond2) { - assertAcc(_ownedFields, a != NULL ? a->_id : -1, 1, "Field access runtime check failed for struct Node.next"); - assertAcc(_ownedFields, a != NULL ? a->_id : -1, 0, "Field access runtime check failed for struct Node.val"); + runtime_assert(_ownedFields, a == NULL ? -1 : a->_id, 1, "No permission to access 'a->next'"); + runtime_assert(_ownedFields, a == NULL ? -1 : a->_id, 0, "No permission to access 'a->val'"); } - appendLemmaAfterLoopBody(a->next, b, c, a->val, bVal, cVal, _ownedFields); + appendLemmaAfterLoopBody(a->next, b, c, a->val, bVal, cVal, _ownedFields, _instanceCounter); } } } diff --git a/src/test/resources/verifier/conditional_separation.output.c0 b/src/test/resources/verifier/conditional_separation.output.c0 index 7329ce63..f03e30cd 100644 --- a/src/test/resources/verifier/conditional_separation.output.c0 +++ b/src/test/resources/verifier/conditional_separation.output.c0 @@ -14,35 +14,36 @@ struct _ptr_int_ int _id; }; -struct _ptr_int* create(struct OwnedFields* _ownedFields); -struct _ptr_int_* createNested(struct OwnedFields* _ownedFields); -int get(struct OwnedFields* _ownedFields); +struct _ptr_int* create(struct OwnedFields* _ownedFields, int* _instanceCounter); +struct _ptr_int_* createNested(struct OwnedFields* _ownedFields, int* _instanceCounter); +int get(int* _instanceCounter); int main(); -void test(int x, struct _ptr_int_* y, struct _ptr_int* z, struct OwnedFields* _ownedFields); +void test(int x, struct _ptr_int_* y, struct _ptr_int* z, int* _instanceCounter); -struct _ptr_int* create(struct OwnedFields* _ownedFields) +struct _ptr_int* create(struct OwnedFields* _ownedFields, int* _instanceCounter) { struct _ptr_int* _ = NULL; _ = alloc(struct _ptr_int); - _->_id = addStructAcc(_ownedFields, 1); + _->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, _->_id, 1); return _; } -struct _ptr_int_* createNested(struct OwnedFields* _ownedFields) +struct _ptr_int_* createNested(struct OwnedFields* _ownedFields, int* _instanceCounter) { struct _ptr_int_* ptr = NULL; struct _ptr_int* _ = NULL; - struct OwnedFields* _tempFields = NULL; ptr = alloc(struct _ptr_int_); - ptr->_id = addStructAcc(_ownedFields, 1); - _tempFields = initOwnedFields(_ownedFields->instanceCounter); - _ = create(_tempFields); - join(_ownedFields, _tempFields); + ptr->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, ptr->_id, 1); + _ = create(_ownedFields, _instanceCounter); ptr->value = _; return ptr; } -int get(struct OwnedFields* _ownedFields) +int get(int* _instanceCounter) { return 1; } @@ -52,41 +53,32 @@ int main() int v = 0; struct _ptr_int_* a = NULL; struct _ptr_int* b = NULL; - int* _instanceCounter = NULL; struct OwnedFields* _ownedFields = NULL; struct OwnedFields* _tempFields = NULL; - struct OwnedFields* _tempFields1 = NULL; - struct OwnedFields* _tempFields2 = NULL; - struct OwnedFields* _tempFields3 = NULL; + int* _instanceCounter = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); - _tempFields1 = initOwnedFields(_instanceCounter); - v = get(_tempFields1); - join(_ownedFields, _tempFields1); - _tempFields2 = initOwnedFields(_instanceCounter); - a = createNested(_tempFields2); - join(_ownedFields, _tempFields2); - _tempFields3 = initOwnedFields(_instanceCounter); - b = create(_tempFields3); - join(_ownedFields, _tempFields3); - _tempFields = initOwnedFields(_instanceCounter); - assertAcc(_ownedFields, a != NULL ? a->_id : -1, 0, "Field access runtime check failed for struct _ptr_int_.value"); + _ownedFields = runtime_init(); + v = get(_instanceCounter); + a = createNested(_ownedFields, _instanceCounter); + b = create(_ownedFields, _instanceCounter); + _tempFields = runtime_init(); + runtime_assert(_ownedFields, a == NULL ? -1 : a->_id, 0, "No permission to access 'a->value'"); if (v == 1) { - assertAcc(_ownedFields, b != NULL ? b->_id : -1, 0, "Field access runtime check failed for struct _ptr_int.value"); + runtime_assert(_ownedFields, b == NULL ? -1 : b->_id, 0, "No permission to access 'b->value'"); } - assertAcc(_ownedFields, a->value != NULL ? a->value->_id : -1, 0, "Field access runtime check failed for struct _ptr_int.value"); + runtime_assert(_ownedFields, a->value == NULL ? -1 : a->value->_id, 0, "No permission to access 'a->value->value'"); assert(!(a->value == NULL)); - addAccEnsureSeparate(_tempFields, a->value != NULL ? a->value->_id : -1, 0, 1, "Overlapping field permissions for struct _ptr_int.value"); + runtime_add(_tempFields, a->value == NULL ? -1 : a->value->_id, 1, 0, "Invalid aliasing - 'a->value->value' overlaps with existing permission"); if (v == 1) { assert(!(b == NULL)); - addAccEnsureSeparate(_tempFields, b != NULL ? b->_id : -1, 0, 1, "Overlapping field permissions for struct _ptr_int.value"); + runtime_add(_tempFields, b == NULL ? -1 : b->_id, 1, 0, "Invalid aliasing - 'b->value' overlaps with existing permission"); } - test(v, a, b, _ownedFields); + test(v, a, b, _instanceCounter); return 0; } -void test(int x, struct _ptr_int_* y, struct _ptr_int* z, struct OwnedFields* _ownedFields) +void test(int x, struct _ptr_int_* y, struct _ptr_int* z, int* _instanceCounter) { } diff --git a/src/test/resources/verifier/conditional_version.output.c0 b/src/test/resources/verifier/conditional_version.output.c0 index a980a32b..9d036576 100644 --- a/src/test/resources/verifier/conditional_version.output.c0 +++ b/src/test/resources/verifier/conditional_version.output.c0 @@ -1,7 +1,7 @@ #use int main(); -void test(int x, int y, struct OwnedFields* _ownedFields); -int testCall(int a, struct OwnedFields* _ownedFields); +void test(int x, int y, int* _instanceCounter); +int testCall(int a, int* _instanceCounter); int main() { @@ -10,28 +10,29 @@ int main() return 0; } -void test(int x, int y, struct OwnedFields* _ownedFields) +void test(int x, int y, int* _instanceCounter) { int z = 0; - bool _cond_1 = false; - bool _cond_2 = false; - bool _cond_3 = false; - bool _cond_4 = false; - struct OwnedFields* _tempFields = NULL; - _cond_1 = x > 1; + bool _cond = false; + bool _cond1 = false; + bool _cond2 = false; + bool _cond3 = false; + _cond = x > 1; if (x > 1) { - _cond_2 = x > 2; + _cond1 = x > 2; if (x > 2) { - _tempFields = initOwnedFields(_ownedFields->instanceCounter); - z = testCall(y, _tempFields); - join(_ownedFields, _tempFields); - _cond_3 = y == 0; - _cond_4 = z == 0; + z = testCall(y, _instanceCounter); + _cond2 = y == 0; + _cond3 = z == 0; if (z == 0) { - if (_cond_1 && _cond_2 && !_cond_3 && _cond_4 || _cond_1 && _cond_2 && _cond_3 && _cond_4) + if (_cond && _cond1 && !_cond2 && _cond3) + { + assert(x > 3); + } + if (_cond && _cond1 && _cond2 && _cond3) { assert(x > 3); } @@ -40,7 +41,7 @@ void test(int x, int y, struct OwnedFields* _ownedFields) } } -int testCall(int a, struct OwnedFields* _ownedFields) +int testCall(int a, int* _instanceCounter) { return a; } diff --git a/src/test/resources/verifier/loop.c0 b/src/test/resources/verifier/loop.c0 new file mode 100644 index 00000000..7c5f7077 --- /dev/null +++ b/src/test/resources/verifier/loop.c0 @@ -0,0 +1,43 @@ +struct Box { + int value; +}; + +void fact(struct Box* box) + //@requires ?; + //@ensures ?; +{ + int total = 1; + + // This should require a runtime check for acc(box->value) before the loop + // begins executing, but the loop body should run without permission tracking + // or any checks. + + while (box->value > 1) + //@loop_invariant acc(box->value); + { + total *= box->value; + box->value--; + + // No-op alloc to test if permission tracking is happening + alloc(int); + } + + box->value = total; +} + +int main() + //@requires true; +{ + struct Box* box = alloc(struct Box); + box->value = 3; + fact(box); + + int result; + if (box->value == 6) + result = 0; + else + result = -1; + + return result; +} + diff --git a/src/test/resources/verifier/loop.output.c0 b/src/test/resources/verifier/loop.output.c0 new file mode 100644 index 00000000..bc02e5ee --- /dev/null +++ b/src/test/resources/verifier/loop.output.c0 @@ -0,0 +1,65 @@ +#use +struct Box; +struct _ptr_int; + +struct Box +{ + int value; + int _id; +}; + +struct _ptr_int +{ + int value; + int _id; +}; + +void fact(struct Box* box, struct OwnedFields* _ownedFields, int* _instanceCounter); +int main(); + +void fact(struct Box* box, struct OwnedFields* _ownedFields, int* _instanceCounter) +{ + int total = 0; + struct _ptr_int* _ = NULL; + total = 1; + runtime_assert(_ownedFields, box == NULL ? -1 : box->_id, 0, "No permission to access 'box->value'"); + assert(!(box == NULL)); + runtime_remove(_ownedFields, box->_id, 0, "No permission to access 'box->value'"); + while (box->value > 1) + { + total = total * box->value; + box->value = box->value - 1; + _ = alloc(struct _ptr_int); + _->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + } + runtime_add(_ownedFields, box == NULL ? -1 : box->_id, 1, 0, "Invalid aliasing - 'box->value' overlaps with existing permission"); + box->value = total; +} + +int main() +{ + struct Box* box = NULL; + int result = 0; + struct OwnedFields* _ownedFields = NULL; + int* _instanceCounter = NULL; + _instanceCounter = alloc(int); + _ownedFields = runtime_init(); + box = alloc(struct Box); + box->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, box->_id, 1); + box->value = 3; + fact(box, _ownedFields, _instanceCounter); + runtime_assert(_ownedFields, box == NULL ? -1 : box->_id, 0, "No permission to access 'box->value'"); + runtime_assert(_ownedFields, box == NULL ? -1 : box->_id, 0, "No permission to access 'box->value'"); + if (box->value == 6) + { + result = 0; + } + else + { + result = -1; + } + return result; +} diff --git a/src/test/resources/verifier/loop2.c0 b/src/test/resources/verifier/loop2.c0 new file mode 100644 index 00000000..783c73f8 --- /dev/null +++ b/src/test/resources/verifier/loop2.c0 @@ -0,0 +1,50 @@ +struct Box { + int value; +}; + +void dec(int* value) + //@requires ?; + //@ensures ?; +{ + *value = *value - 1; +} + +void fact(struct Box* box) + //@requires ?; + //@ensures ?; +{ + int total = 1; + + // This should require a runtime check for acc(box->value) before the loop + // begins executing, but the loop body should run without permission tracking + // or any checks. + + while (box->value > 1) + //@loop_invariant acc(box->value); + { + total *= box->value; + box->value--; + + // No-op call to incur permission tracking and passing inside a precise loop + dec(alloc(int)); + } + + box->value = total; +} + +int main() + //@requires true; +{ + struct Box* box = alloc(struct Box); + box->value = 3; + fact(box); + + int result; + if (box->value == 6) + result = 0; + else + result = -1; + + return result; +} + diff --git a/src/test/resources/verifier/loop2.output.c0 b/src/test/resources/verifier/loop2.output.c0 new file mode 100644 index 00000000..fe092398 --- /dev/null +++ b/src/test/resources/verifier/loop2.output.c0 @@ -0,0 +1,84 @@ +#use +struct Box; +struct _ptr_int; + +struct Box +{ + int value; + int _id; +}; + +struct _ptr_int +{ + int value; + int _id; +}; + +void dec(struct _ptr_int* value, struct OwnedFields* _ownedFields, int* _instanceCounter); +void fact(struct Box* box, struct OwnedFields* _ownedFields, int* _instanceCounter); +int main(); + +void dec(struct _ptr_int* value, struct OwnedFields* _ownedFields, int* _instanceCounter) +{ + runtime_assert(_ownedFields, value == NULL ? -1 : value->_id, 0, "No permission to access 'value->value'"); + value->value = value->value - 1; +} + +void fact(struct Box* box, struct OwnedFields* _ownedFields, int* _instanceCounter) +{ + int total = 0; + struct _ptr_int* _ = NULL; + bool _cond = false; + struct OwnedFields* _ownedFields1 = NULL; + total = 1; + runtime_assert(_ownedFields, box == NULL ? -1 : box->_id, 0, "No permission to access 'box->value'"); + if (_cond) + { + runtime_assert(_ownedFields, box == NULL ? -1 : box->_id, 0, "No permission to access 'box->value'"); + runtime_assert(_ownedFields, box == NULL ? -1 : box->_id, 0, "No permission to access 'box->value'"); + } + assert(!(box == NULL)); + runtime_remove(_ownedFields, box->_id, 0, "No permission to access 'box->value'"); + _cond = !(box == NULL) && box->value > 1; + while (box->value > 1) + { + _ownedFields1 = runtime_init(); + runtime_add(_ownedFields1, box->_id, 1, 0, "Invalid aliasing - 'box->value' overlaps with existing permission"); + total = total * box->value; + box->value = box->value - 1; + _ = alloc(struct _ptr_int); + _->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields1, _->_id, 1); + dec(_, _ownedFields1, _instanceCounter); + } + runtime_add(_ownedFields, box == NULL ? -1 : box->_id, 1, 0, "Invalid aliasing - 'box->value' overlaps with existing permission"); + box->value = total; +} + +int main() +{ + struct Box* box = NULL; + int result = 0; + struct OwnedFields* _ownedFields = NULL; + int* _instanceCounter = NULL; + _instanceCounter = alloc(int); + _ownedFields = runtime_init(); + box = alloc(struct Box); + box->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, box->_id, 1); + box->value = 3; + fact(box, _ownedFields, _instanceCounter); + runtime_assert(_ownedFields, box == NULL ? -1 : box->_id, 0, "No permission to access 'box->value'"); + runtime_assert(_ownedFields, box == NULL ? -1 : box->_id, 0, "No permission to access 'box->value'"); + if (box->value == 6) + { + result = 0; + } + else + { + result = -1; + } + return result; +} diff --git a/src/test/resources/verifier/loop3.c0 b/src/test/resources/verifier/loop3.c0 new file mode 100644 index 00000000..1187a32b --- /dev/null +++ b/src/test/resources/verifier/loop3.c0 @@ -0,0 +1,24 @@ +//@predicate nested(int* value) = ? && true; +//@predicate test(int* value) = nested(value); + +int main() + //@requires true; +{ + int* value = alloc(int); + //@fold nested(value); + //@fold test(value); + + // Loop invariant is precise but equi-recursively imprecise + for (int i = 0; i < 10; i++) + //@loop_invariant test(value); + { + //@unfold test(value); + //@unfold nested(value); + // If equi-recursive imprecision is not respected, this will fail + (*value)++; + } + + assert(*value == 10); + return 0; +} + diff --git a/src/test/resources/verifier/loop3.output.c0 b/src/test/resources/verifier/loop3.output.c0 new file mode 100644 index 00000000..460965e6 --- /dev/null +++ b/src/test/resources/verifier/loop3.output.c0 @@ -0,0 +1,66 @@ +#use +struct _ptr_int; + +struct _ptr_int +{ + int value; + int _id; +}; + +void add_nested(struct _ptr_int* value, struct OwnedFields* add_perms); +void add_test(struct _ptr_int* value, struct OwnedFields* add_perms); +void assert_nested(struct _ptr_int* value, struct OwnedFields* assert_perms); +void assert_test(struct _ptr_int* value, struct OwnedFields* assert_perms); +int main(); + +void add_nested(struct _ptr_int* value, struct OwnedFields* add_perms) +{ +} + +void add_test(struct _ptr_int* value, struct OwnedFields* add_perms) +{ +} + +void assert_nested(struct _ptr_int* value, struct OwnedFields* assert_perms) +{ + assert(true); +} + +void assert_test(struct _ptr_int* value, struct OwnedFields* assert_perms) +{ + assert_nested(value, assert_perms); +} + +int main() +{ + struct _ptr_int* value = NULL; + int i = 0; + struct OwnedFields* _ownedFields = NULL; + bool _cond = false; + struct OwnedFields* _tempFields = NULL; + int* _instanceCounter = NULL; + _instanceCounter = alloc(int); + _ownedFields = runtime_init(); + value = alloc(struct _ptr_int); + value->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, value->_id, 1); + i = 0; + _tempFields = runtime_init(); + if (_cond) + { + assert_test(value, _ownedFields); + } + _cond = i < 10; + while (i < 10) + { + if (_cond) + { + runtime_assert(_ownedFields, value == NULL ? -1 : value->_id, 0, "No permission to access 'value->value'"); + } + value->value = value->value + 1; + i = i + 1; + } + assert(value->value == 10); + return 0; +} diff --git a/src/test/resources/verifier/loop4.c0 b/src/test/resources/verifier/loop4.c0 new file mode 100644 index 00000000..1da9c700 --- /dev/null +++ b/src/test/resources/verifier/loop4.c0 @@ -0,0 +1,44 @@ +struct Node { + int value; + struct Node* next; +}; + +/*@ predicate list(struct Node* node) = + node == NULL ? true : acc(node->value) && acc(node->next) && list(node->next); @*/ + +struct Node* cons(struct Node* tl, int value) + //@requires list(tl); + //@ensures list(\result); +{ + struct Node* hd = alloc(struct Node); + hd->value = value; + hd->next = tl; + //@fold list(hd); + return hd; +} + +int sum_list(struct Node* node) + //@requires ?; + //@ensures ?; +{ + int sum = node->value; + if (node->next != NULL) + sum += sum_list(node->next); + return sum; +} + +int main() + //@requires true; + //@ensures true; +{ + struct Node* l = NULL; + //@fold list(l); + for (int i = 0; i < 10; i++) + //@loop_invariant list(l); + l = cons(l, i); + + int total = sum_list(l); + assert(total == 1+2+3+4+5+6+7+8+9); + + return 0; +} \ No newline at end of file diff --git a/src/test/resources/verifier/missing_result.output.c0 b/src/test/resources/verifier/missing_result.output.c0 index 620b1b70..e8fc2817 100644 --- a/src/test/resources/verifier/missing_result.output.c0 +++ b/src/test/resources/verifier/missing_result.output.c0 @@ -1,7 +1,7 @@ #use int doSomething(int x, int* _instanceCounter); int main(); -int random(struct OwnedFields* _ownedFields); +int random(int* _instanceCounter); int doSomething(int x, int* _instanceCounter) { @@ -13,26 +13,25 @@ int main() int x = 0; int _ = 0; int _1 = 0; - bool _cond_1 = false; + bool _cond = false; int* _instanceCounter = NULL; - struct OwnedFields* _ownedFields = NULL; - struct OwnedFields* _tempFields = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); _ = doSomething(0, _instanceCounter); - _tempFields = initOwnedFields(_instanceCounter); - _cond_1 = _ == 0; - x = random(_tempFields); - join(_ownedFields, _tempFields); + _cond = _ == 0; + x = random(_instanceCounter); _1 = doSomething(x, _instanceCounter); - if (!_cond_1 && !(_1 == 0) || _cond_1 && !(_1 == 0)) + if (!_cond && !(_1 == 0)) + { + assert(x == 0); + } + if (_cond && !(_1 == 0)) { assert(x == 0); } return 0; } -int random(struct OwnedFields* _ownedFields) +int random(int* _instanceCounter) { return 0; } diff --git a/src/test/resources/verifier/predicates.output.c0 b/src/test/resources/verifier/predicates.output.c0 index 43a915ed..47622c4f 100644 --- a/src/test/resources/verifier/predicates.output.c0 +++ b/src/test/resources/verifier/predicates.output.c0 @@ -8,60 +8,54 @@ struct Node int _id; }; -void add_remove_wrappedAcc(struct Node* node, struct OwnedFields* _ownedFields, struct OwnedFields* _tempFields); -void add_wrappedAcc(struct Node* node, struct OwnedFields* _ownedFields); -int fullyImprecise(struct Node* a, struct OwnedFields* _ownedFields); -int fullyPrecise(struct Node* a, struct OwnedFields* _ownedFields); -int imprecisePostcondition(struct Node* a, struct OwnedFields* _ownedFields); -int imprecisePrecondition(struct Node* a, struct OwnedFields* _ownedFields); +void add_wrappedAcc(struct Node* node, struct OwnedFields* add_perms); +void assert_wrappedAcc(struct Node* node, struct OwnedFields* assert_perms); +int fullyImprecise(struct Node* a, struct OwnedFields* _ownedFields, int* _instanceCounter); +int fullyPrecise(struct Node* a, int* _instanceCounter); +int imprecisePostcondition(struct Node* a, int* _instanceCounter); +int imprecisePrecondition(struct Node* a, struct OwnedFields* _ownedFields, int* _instanceCounter); int main(); -void sep_wrappedAcc(struct Node* node, struct OwnedFields* _ownedFields); -void wrappedAcc(struct Node* node, struct OwnedFields* _ownedFields); -void add_remove_wrappedAcc(struct Node* node, struct OwnedFields* _ownedFields, struct OwnedFields* _tempFields) +void add_wrappedAcc(struct Node* node, struct OwnedFields* add_perms) { - addAcc(_ownedFields, node->_id, 2, 0); - loseAcc(_tempFields, node->_id, 0); - addAcc(_ownedFields, node->_id, 2, 1); - loseAcc(_tempFields, node->_id, 1); + runtime_add(add_perms, node == NULL ? -1 : node->_id, 2, 0, "Invalid aliasing - 'node->value' overlaps with existing permission"); + runtime_add(add_perms, node == NULL ? -1 : node->_id, 2, 1, "Invalid aliasing - 'node->next' overlaps with existing permission"); } -void add_wrappedAcc(struct Node* node, struct OwnedFields* _ownedFields) +void assert_wrappedAcc(struct Node* node, struct OwnedFields* assert_perms) { - addAcc(_ownedFields, node->_id, 2, 0); - addAcc(_ownedFields, node->_id, 2, 1); + runtime_assert(assert_perms, node == NULL ? -1 : node->_id, 0, "No permission to access 'node->value'"); + runtime_assert(assert_perms, node == NULL ? -1 : node->_id, 1, "No permission to access 'node->next'"); } -int fullyImprecise(struct Node* a, struct OwnedFields* _ownedFields) +int fullyImprecise(struct Node* a, struct OwnedFields* _ownedFields, int* _instanceCounter) { - assertAcc(_ownedFields, a != NULL ? a->_id : -1, 0, "Field access runtime check failed for struct Node.value"); + runtime_assert(_ownedFields, a == NULL ? -1 : a->_id, 0, "No permission to access 'a->value'"); return a->value; } -int fullyPrecise(struct Node* a, struct OwnedFields* _ownedFields) +int fullyPrecise(struct Node* a, int* _instanceCounter) { int _ = 0; + struct OwnedFields* _ownedFields = NULL; struct OwnedFields* _tempFields = NULL; - struct OwnedFields* _tempFields1 = NULL; + _ownedFields = runtime_init(); add_wrappedAcc(a, _ownedFields); - _tempFields1 = initOwnedFields(_ownedFields->instanceCounter); - add_remove_wrappedAcc(a, _tempFields1, _ownedFields); - _ = imprecisePostcondition(a, _tempFields1); - join(_ownedFields, _tempFields1); - _tempFields = initOwnedFields(_ownedFields->instanceCounter); - wrappedAcc(a, _ownedFields); - sep_wrappedAcc(a, _tempFields); + _ = imprecisePostcondition(a, _instanceCounter); + _tempFields = runtime_init(); + assert_wrappedAcc(a, _ownedFields); + add_wrappedAcc(a, _tempFields); return _; } -int imprecisePostcondition(struct Node* a, struct OwnedFields* _ownedFields) +int imprecisePostcondition(struct Node* a, int* _instanceCounter) { return a->value; } -int imprecisePrecondition(struct Node* a, struct OwnedFields* _ownedFields) +int imprecisePrecondition(struct Node* a, struct OwnedFields* _ownedFields, int* _instanceCounter) { - assertAcc(_ownedFields, a != NULL ? a->_id : -1, 0, "Field access runtime check failed for struct Node.value"); + runtime_assert(_ownedFields, a == NULL ? -1 : a->_id, 0, "No permission to access 'a->value'"); return a->value; } @@ -70,25 +64,11 @@ int main() struct Node* a = NULL; int _ = 0; int* _instanceCounter = NULL; - struct OwnedFields* _tempFields = NULL; _instanceCounter = alloc(int); a = alloc(struct Node); a->_id = *_instanceCounter; *_instanceCounter = *_instanceCounter + 1; a->next = NULL; - _tempFields = initOwnedFields(_instanceCounter); - _ = fullyPrecise(a, _tempFields); + _ = fullyPrecise(a, _instanceCounter); return _; } - -void sep_wrappedAcc(struct Node* node, struct OwnedFields* _ownedFields) -{ - addAccEnsureSeparate(_ownedFields, node != NULL ? node->_id : -1, 0, 2, "Overlapping field permissions for struct Node.value"); - addAccEnsureSeparate(_ownedFields, node != NULL ? node->_id : -1, 1, 2, "Overlapping field permissions for struct Node.next"); -} - -void wrappedAcc(struct Node* node, struct OwnedFields* _ownedFields) -{ - assertAcc(_ownedFields, node != NULL ? node->_id : -1, 0, "Field access runtime check failed for struct Node.value"); - assertAcc(_ownedFields, node != NULL ? node->_id : -1, 1, "Field access runtime check failed for struct Node.next"); -} diff --git a/src/test/resources/verifier/result.output.c0 b/src/test/resources/verifier/result.output.c0 index 4656bfdb..a1bc059f 100644 --- a/src/test/resources/verifier/result.output.c0 +++ b/src/test/resources/verifier/result.output.c0 @@ -1,8 +1,8 @@ #use -int get(struct OwnedFields* _ownedFields); +int get(int* _instanceCounter); int main(); -int get(struct OwnedFields* _ownedFields) +int get(int* _instanceCounter) { return 2; } @@ -11,13 +11,8 @@ int main() { int result = 0; int* _instanceCounter = NULL; - struct OwnedFields* _tempFields = NULL; - struct OwnedFields* _ownedFields = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); - _tempFields = initOwnedFields(_instanceCounter); - result = get(_tempFields); - join(_ownedFields, _tempFields); + result = get(_instanceCounter); result = result - 1; assert(result > 0); return 0; diff --git a/src/test/resources/verifier/result_acc.output.c0 b/src/test/resources/verifier/result_acc.output.c0 index a5eb4275..f6163216 100644 --- a/src/test/resources/verifier/result_acc.output.c0 +++ b/src/test/resources/verifier/result_acc.output.c0 @@ -7,26 +7,27 @@ struct Test int _id; }; -struct Test* getTest(struct OwnedFields* _ownedFields); -struct Test* getTestPrecise(struct OwnedFields* _ownedFields); +struct Test* getTest(struct OwnedFields* _ownedFields, int* _instanceCounter); +struct Test* getTestPrecise(int* _instanceCounter); int main(); -struct Test* getTest(struct OwnedFields* _ownedFields) +struct Test* getTest(struct OwnedFields* _ownedFields, int* _instanceCounter) { struct Test* _ = NULL; _ = alloc(struct Test); - _->_id = addStructAcc(_ownedFields, 1); + _->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, _->_id, 1); return _; } -struct Test* getTestPrecise(struct OwnedFields* _ownedFields) +struct Test* getTestPrecise(int* _instanceCounter) { struct Test* _ = NULL; - struct OwnedFields* _tempFields = NULL; - _tempFields = initOwnedFields(_ownedFields->instanceCounter); - _ = getTest(_tempFields); - join(_ownedFields, _tempFields); - assertAcc(_ownedFields, _ != NULL ? _->_id : -1, 0, "Field access runtime check failed for struct Test.value"); + struct OwnedFields* _ownedFields = NULL; + _ownedFields = runtime_init(); + _ = getTest(_ownedFields, _instanceCounter); + runtime_assert(_ownedFields, _ == NULL ? -1 : _->_id, 0, "No permission to access '_->value'"); assert(!(_ == NULL)); return _; } diff --git a/src/test/resources/verifier/separation.output.c0 b/src/test/resources/verifier/separation.output.c0 index 7f910cfd..b5f5305c 100644 --- a/src/test/resources/verifier/separation.output.c0 +++ b/src/test/resources/verifier/separation.output.c0 @@ -7,16 +7,18 @@ struct _ptr_int int _id; }; -struct _ptr_int* create(struct OwnedFields* _ownedFields); +struct _ptr_int* create(struct OwnedFields* _ownedFields, int* _instanceCounter); void ensureSeparate(struct _ptr_int* x, struct _ptr_int* y, int* _instanceCounter); int main(); -void test(struct _ptr_int* x, struct _ptr_int* y, struct OwnedFields* _ownedFields); +void test(struct _ptr_int* x, struct _ptr_int* y, struct OwnedFields* _ownedFields, int* _instanceCounter); -struct _ptr_int* create(struct OwnedFields* _ownedFields) +struct _ptr_int* create(struct OwnedFields* _ownedFields, int* _instanceCounter) { struct _ptr_int* _ = NULL; _ = alloc(struct _ptr_int); - _->_id = addStructAcc(_ownedFields, 1); + _->_id = *_instanceCounter; + *_instanceCounter = *_instanceCounter + 1; + runtime_addAll(_ownedFields, _->_id, 1); return _; } @@ -28,33 +30,27 @@ int main() { struct _ptr_int* _ = NULL; struct _ptr_int* _1 = NULL; - int* _instanceCounter = NULL; - struct OwnedFields* _tempFields = NULL; struct OwnedFields* _ownedFields = NULL; - struct OwnedFields* _tempFields1 = NULL; + int* _instanceCounter = NULL; _instanceCounter = alloc(int); - _ownedFields = initOwnedFields(_instanceCounter); - _tempFields = initOwnedFields(_instanceCounter); - _ = create(_tempFields); - join(_ownedFields, _tempFields); - _tempFields1 = initOwnedFields(_instanceCounter); - _1 = create(_tempFields1); - join(_ownedFields, _tempFields1); - test(_, _1, _ownedFields); + _ownedFields = runtime_init(); + _ = create(_ownedFields, _instanceCounter); + _1 = create(_ownedFields, _instanceCounter); + test(_, _1, _ownedFields, _instanceCounter); return 0; } -void test(struct _ptr_int* x, struct _ptr_int* y, struct OwnedFields* _ownedFields) +void test(struct _ptr_int* x, struct _ptr_int* y, struct OwnedFields* _ownedFields, int* _instanceCounter) { struct OwnedFields* _tempFields = NULL; - _tempFields = initOwnedFields(_ownedFields->instanceCounter); - assertAcc(_ownedFields, y != NULL ? y->_id : -1, 0, "Field access runtime check failed for struct _ptr_int.value"); - assertAcc(_ownedFields, x != NULL ? x->_id : -1, 0, "Field access runtime check failed for struct _ptr_int.value"); + _tempFields = runtime_init(); + runtime_assert(_ownedFields, y == NULL ? -1 : y->_id, 0, "No permission to access 'y->value'"); + runtime_assert(_ownedFields, x == NULL ? -1 : x->_id, 0, "No permission to access 'x->value'"); assert(!(x == NULL)); assert(!(y == NULL)); - addAccEnsureSeparate(_tempFields, x != NULL ? x->_id : -1, 0, 1, "Overlapping field permissions for struct _ptr_int.value"); - addAccEnsureSeparate(_tempFields, y != NULL ? y->_id : -1, 0, 1, "Overlapping field permissions for struct _ptr_int.value"); - loseAcc(_ownedFields, x->_id, 0); - loseAcc(_ownedFields, y->_id, 0); - ensureSeparate(x, y, _ownedFields->instanceCounter); + runtime_add(_tempFields, x == NULL ? -1 : x->_id, 1, 0, "Invalid aliasing - 'x->value' overlaps with existing permission"); + runtime_add(_tempFields, y == NULL ? -1 : y->_id, 1, 0, "Invalid aliasing - 'y->value' overlaps with existing permission"); + runtime_remove(_ownedFields, x->_id, 0, "No permission to access 'x->value'"); + runtime_remove(_ownedFields, y->_id, 0, "No permission to access 'y->value'"); + ensureSeparate(x, y, _instanceCounter); } diff --git a/src/test/resources/verifier/simple.output.c0 b/src/test/resources/verifier/simple.output.c0 index 52ceeccd..85515f5e 100644 --- a/src/test/resources/verifier/simple.output.c0 +++ b/src/test/resources/verifier/simple.output.c0 @@ -1,6 +1,6 @@ #use int main(); -int test(int value, struct OwnedFields* _ownedFields); +int test(int value, int* _instanceCounter); int main() { @@ -9,7 +9,7 @@ int main() return 0; } -int test(int value, struct OwnedFields* _ownedFields) +int test(int value, int* _instanceCounter) { assert(value == 0); return value; diff --git a/src/test/scala/integration/BaselineCompilerSpec.scala b/src/test/scala/integration/BaselineCompilerSpec.scala new file mode 100644 index 00000000..0f0551c1 --- /dev/null +++ b/src/test/scala/integration/BaselineCompilerSpec.scala @@ -0,0 +1,83 @@ +package gvc.specs.integration + +import org.scalatest.funsuite.AnyFunSuite +import java.lang.ProcessBuilder.Redirect + +import gvc.specs._ +import org.scalatest._ +import java.nio.file._ +import gvc.specs.BaseFileSpec +import scala.io.Source +import gvc.parser.Parser +import fastparse.Parsed.{Failure, Success} +import gvc.analyzer.{ErrorSink, Validator} +import gvc.transformer.{IRTransformer, IRPrinter} +import gvc.benchmarking.BaselineChecks + +class BaselineCompilerSpec extends AnyFunSuite with BaseFileSpec with ParallelTestExecution { + var output: Path = null + val includeDirs = gvc.Main.Defaults.includeDirectories + + override protected def withFixture(test: NoArgTest): Outcome = { + try { + output = Files.createTempDirectory("gvc0") + super.withFixture(test) + } finally { + TestUtils.deleteDirectory(output) + } + } + + def compile(args: String*): Unit = + execute(("cc0" :: includeDirs.flatMap(dir => List("-L", dir)) ::: args.toList):_*) + + def execute(command: String*): Unit = { + val proc = new ProcessBuilder(command:_*) + .redirectError(Redirect.INHERIT) + .redirectOutput(Redirect.PIPE) + .start() + + val exit = proc.waitFor() + if (exit != 0) { + info("Output: " + Source.fromInputStream(proc.getInputStream()).mkString) + fail(s"Command '${command.mkString(" ")}' exited with code $exit") + } + } + + val inputFiles = + TestUtils.groupResources("verifier") ++ + TestUtils.groupResources("quant-study") ++ + TestUtils.groupResources("baseline") + + inputFiles.foreach { input => + test(s"baseline compile and run ${input.name}") { + val source = input(".c0").read() + val id = input.id + + val parsed = Parser.parseProgram(source) match { + case err: Failure => + fail(s"Parse error:\n${err.trace().longAggregateMsg}") + case Success(value, _) => value + } + val errors = new ErrorSink() + val resolved = Validator + .validateParsed(parsed, includeDirs, errors) + .getOrElse( + fail(errors.errors.map(_.toString()).mkString("\n")) + ) + val program = IRTransformer.transform(resolved) + + BaselineChecks.insert(program) + + val baselineSource = IRPrinter.print(program, false) + val baselineFile = output.resolve(id + ".baseline.c0") + Files.writeString(baselineFile, baselineSource) + + assertFile(input.get(".baseline.c0"), baselineSource) + + val outputExe = output.resolve(id) + compile(s"--output=${outputExe}", baselineFile.toString()) + + execute(outputExe.toString()) + } + } +} \ No newline at end of file diff --git a/src/test/scala/permutation/BaselineSpec.scala b/src/test/scala/permutation/BaselineSpec.scala index 2f64fd5f..9fbcd877 100644 --- a/src/test/scala/permutation/BaselineSpec.scala +++ b/src/test/scala/permutation/BaselineSpec.scala @@ -5,14 +5,14 @@ import org.scalatest.funsuite.AnyFunSuite import gvc.specs._ import gvc.specs.BaseFileSpec import gvc.transformer.IRPrinter -import gvc.benchmarking.BaselineChecker +import gvc.benchmarking.BaselineChecks class BaselineSpec extends AnyFunSuite with BaseFileSpec { for (input <- TestUtils.groupResources("baseline")) { test("test " + input.name) { val ir = TestUtils.program(input(".c0").read()).ir - BaselineChecker.check(ir) + BaselineChecks.insert(ir) val output = IRPrinter.print(ir, false) assertFile(input(".baseline.c0"), output) } diff --git a/src/test/scala/weaver/RuntimeSpec.scala b/src/test/scala/weaver/RuntimeSpec.scala index e7f5758a..cefd3133 100644 --- a/src/test/scala/weaver/RuntimeSpec.scala +++ b/src/test/scala/weaver/RuntimeSpec.scala @@ -13,7 +13,7 @@ class RuntimeSpec extends AnyFunSuite with BaseFileSpec { assert(IRPrinter.print(program, true).trim() == "#use ") - assert(program.method(CheckRuntime.Names.assertAcc) == runtime.assertAcc) + assert(program.method(CheckRuntime.Names.assert) == runtime.assert) assert( program.struct( CheckRuntime.Names.ownedFieldsStruct