From 0e81af7d4c88b0d1ac33b42b561801c402ecd355 Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Tue, 20 Dec 2022 13:02:25 +0000 Subject: [PATCH] partialPoke/Expect: add more tests, improve implementation (#589) (#590) * partialPoke/Expect: add more tests, improve implementation * fix error reporting for partialExpect (cherry picked from commit 8f692f2f7102c4983adfac615a6080cf146fb7a4) Co-authored-by: Kevin Laeufer --- src/main/scala/chiseltest/exceptions.scala | 11 + .../internal/TestEnvInterface.scala | 10 +- src/main/scala/chiseltest/package.scala | 169 ++++++------ .../tests/PokeAndExpectPartialTests.scala | 250 ++++++++++++++++++ 4 files changed, 345 insertions(+), 95 deletions(-) create mode 100644 src/test/scala/chiseltest/tests/PokeAndExpectPartialTests.scala diff --git a/src/main/scala/chiseltest/exceptions.scala b/src/main/scala/chiseltest/exceptions.scala index b4c09d8d1..fff4f58c3 100644 --- a/src/main/scala/chiseltest/exceptions.scala +++ b/src/main/scala/chiseltest/exceptions.scala @@ -24,3 +24,14 @@ class StopException(message: String) extends Exception(message) /** Indicates that a Chisel `assert(...)` or `assume(...)` statement has failed. */ class ChiselAssertionError(message: String) extends Exception(message) + +/** Indicates that a value used in a poke/expect is not a literal. + * It could be hardware or a DontCare which is only allowed when using pokePartial/expectPartial. + */ +class NonLiteralValueError(val value: chisel3.Data, val signal: chisel3.Data, op: String) + extends Exception( + s"""Value $value for entry $signal is not a literal value! + |You need to fully specify all fields/entries when using $op. + |Maybe try using `${op}Partial` if you only want to use incomplete Vec/Bundle literals. + |""".stripMargin + ) diff --git a/src/main/scala/chiseltest/internal/TestEnvInterface.scala b/src/main/scala/chiseltest/internal/TestEnvInterface.scala index fcc1fa1d9..3b472993b 100644 --- a/src/main/scala/chiseltest/internal/TestEnvInterface.scala +++ b/src/main/scala/chiseltest/internal/TestEnvInterface.scala @@ -45,17 +45,19 @@ trait TestEnvInterface { def signalExpectFailure(message: String): Unit = { val trace = new Throwable val expectStackDepth = trace.getStackTrace.indexWhere(ste => - ste.getClassName.startsWith("chiseltest.package$") && ste.getMethodName == "expect" + ste.getClassName.startsWith( + "chiseltest.package$" + ) && (ste.getMethodName == "expect" || ste.getMethodName == "expectPartial") ) require( expectStackDepth != -1, s"Failed to find expect in stack trace:\r\n${trace.getStackTrace.mkString("\r\n")}" ) - val trimmedTrace = trace.getStackTrace.drop(expectStackDepth + 2) - val detailedTrace = topFileName.map(getExpectDetailedTrace(trimmedTrace.toSeq, _)).getOrElse("") + val trimmedTrace = trace.getStackTrace.drop(expectStackDepth) + val failureLocation: String = topFileName.map(getExpectDetailedTrace(trimmedTrace.toSeq, _)).getOrElse("") val stackIndex = expectStackDepth + 1 - batchedFailures += new FailedExpectException(message + detailedTrace, stackIndex) + batchedFailures += new FailedExpectException(message + failureLocation, stackIndex) } /** If there are any failures, reports them and end the test now. diff --git a/src/main/scala/chiseltest/package.scala b/src/main/scala/chiseltest/package.scala index 2d4fcd626..9617f52c3 100644 --- a/src/main/scala/chiseltest/package.scala +++ b/src/main/scala/chiseltest/package.scala @@ -196,19 +196,7 @@ package object chiseltest { */ def pokePartial(value: T): Unit = { require(DataMirror.checkTypeEquivalence(x, value), s"Record type mismatch") - x.elements.filter { case (k, v) => - DataMirror.directionOf(v) != ActualDirection.Output && { - value.elements(k) match { - case _: Record => true - case data: Data => data.isLit - } - } - }.foreach { case (k, v) => - v match { - case record: Record => record.pokePartial(value.elements(k).asInstanceOf[Record]) - case data: Data => data.poke(value.elements(k)) - } - } + x.pokeInternal(value, allowPartial = true) } /** Check the given signal with a [[Record.litValue()]]; @@ -216,17 +204,7 @@ package object chiseltest { */ def expectPartial(value: T): Unit = { require(DataMirror.checkTypeEquivalence(x, value), s"Record type mismatch") - x.elements.filter { case (k, _) => - value.elements(k) match { - case _: Record => true - case d: Data => d.isLit - } - }.foreach { case (k, v) => - v match { - case record: Record => record.expectPartial(value.elements(k).asInstanceOf[Record]) - case data: Data => data.expect(value.elements(k)) - } - } + x.expectInternal(value, None, allowPartial = true) } } @@ -242,19 +220,7 @@ package object chiseltest { */ def pokePartial(value: T): Unit = { require(DataMirror.checkTypeEquivalence(x, value), s"Vec type mismatch") - x.getElements.zipWithIndex.filter { case (v, index) => - DataMirror.directionOf(v) != ActualDirection.Output && { - value.getElements(index) match { - case _: T => true - case data: Data => data.isLit - } - } - }.foreach { case (v, index) => - v match { - case vec: T => vec.pokePartial(value.getElements(index).asInstanceOf[T]) - case data: Data => data.poke(value.getElements(index)) - } - } + x.pokeInternal(value, allowPartial = true) } /** Check the given signal with a [[Vec.litValue()]]; @@ -262,44 +228,55 @@ package object chiseltest { */ def expectPartial(value: T): Unit = { require(DataMirror.checkTypeEquivalence(x, value), s"Vec type mismatch") - x.getElements.zipWithIndex.filter { case (_, index) => - value.getElements(index) match { - case _: T => true - case d: Data => d.isLit - } - }.foreach { case (v, index) => - v match { - case vec: T => vec.expectPartial(value.getElements(index).asInstanceOf[T]) - case data: Data => data.expect(value.getElements(index)) - } - } + x.expectInternal(value, None, allowPartial = true) } } implicit class testableData[T <: Data](x: T) { import Utils._ - def poke(value: T): Unit = (x, value) match { - case (x: Bool, value: Bool) => x.poke(value) - case (x: UInt, value: UInt) => x.poke(value) - case (x: SInt, value: SInt) => x.poke(value) - case (x: FixedPoint, value: FixedPoint) => x.poke(value) - case (x: Interval, value: Interval) => x.poke(value) - case (x: Record, value: Record) => - require(DataMirror.checkTypeEquivalence(x, value), s"Record type mismatch") - x.elements.zip(value.elements).foreach { case ((_, x), (_, value)) => - x.poke(value) - } - case (x: Vec[_], value: Vec[_]) => - require(DataMirror.checkTypeEquivalence(x, value), s"Vec type mismatch") - x.getElements.zip(value.getElements).foreach { case (x, value) => - x.poke(value) + def poke(value: T): Unit = pokeInternal(value, allowPartial = false) + + private def isAllowedNonLitGround(value: T, allowPartial: Boolean, op: String): Boolean = { + val isGroundType = value match { + case _: Vec[_] | _: Record => false + case _ => true + } + // if we are dealing with a ground type non-literal, this is only allowed if we are doing a partial poke/expect + if (isGroundType && !value.isLit) { + if (allowPartial) { true } + else { + throw new NonLiteralValueError(value, x, op) } - case (x: EnumType, value: EnumType) => - require(DataMirror.checkTypeEquivalence(x, value), s"EnumType mismatch") - pokeBits(x, value.litValue) - case x => throw new LiteralTypeException(s"don't know how to poke $x") - // TODO: aggregate types + } else { + false + } + } + + private[chiseltest] def pokeInternal(value: T, allowPartial: Boolean): Unit = { + if (isAllowedNonLitGround(value, allowPartial, "poke")) return + (x, value) match { + case (x: Bool, value: Bool) => x.poke(value) + case (x: UInt, value: UInt) => x.poke(value) + case (x: SInt, value: SInt) => x.poke(value) + case (x: FixedPoint, value: FixedPoint) => x.poke(value) + case (x: Interval, value: Interval) => x.poke(value) + case (x: Record, value: Record) => + require(DataMirror.checkTypeEquivalence(x, value), s"Record type mismatch") + x.elements.zip(value.elements).foreach { case ((_, x), (_, value)) => + x.pokeInternal(value, allowPartial) + } + case (x: Vec[_], value: Vec[_]) => + require(DataMirror.checkTypeEquivalence(x, value), s"Vec type mismatch") + x.getElements.zip(value.getElements).foreach { case (x, value) => + x.pokeInternal(value, allowPartial) + } + case (x: EnumType, value: EnumType) => + require(DataMirror.checkTypeEquivalence(x, value), s"EnumType mismatch") + pokeBits(x, value.litValue) + case x => throw new LiteralTypeException(s"don't know how to poke $x") + // TODO: aggregate types + } } def peek(): T = x match { @@ -322,31 +299,41 @@ package object chiseltest { case x => throw new LiteralTypeException(s"don't know how to peek $x") } - protected def expectInternal(value: T, message: Option[() => String]): Unit = (x, value) match { - case (x: Bool, value: Bool) => x.expectInternal(value.litValue, message) - case (x: UInt, value: UInt) => x.expectInternal(value.litValue, message) - case (x: SInt, value: SInt) => x.expectInternal(value.litValue, message) - case (x: FixedPoint, value: FixedPoint) => x.expectInternal(value, message) - case (x: Interval, value: Interval) => x.expectInternal(value, message) - case (x: Record, value: Record) => - require(DataMirror.checkTypeEquivalence(x, value), s"Record type mismatch") - x.elements.zip(value.elements).foreach { case ((_, x), (_, value)) => - x.expectInternal(value, message) - } - case (x: Vec[_], value: Vec[_]) => - require(DataMirror.checkTypeEquivalence(x, value), s"Vec type mismatch") - x.getElements.zip(value.getElements).foreach { case (x, value) => - x.expectInternal(value, message) - } - case (x: EnumType, value: EnumType) => - require(DataMirror.checkTypeEquivalence(x, value), s"EnumType mismatch") - Utils.expectBits(x, value.litValue, message, Some(enumToString(x))) - case x => throw new LiteralTypeException(s"don't know how to expect $x") - // TODO: aggregate types + private[chiseltest] def expectInternal(value: T, message: Option[() => String], allowPartial: Boolean): Unit = { + if (isAllowedNonLitGround(value, allowPartial, "expect")) return + (x, value) match { + case (x: Bool, value: Bool) => x.expectInternal(value.litValue, message) + case (x: UInt, value: UInt) => x.expectInternal(value.litValue, message) + case (x: SInt, value: SInt) => x.expectInternal(value.litValue, message) + case (x: FixedPoint, value: FixedPoint) => x.expectInternal(value, message) + case (x: Interval, value: Interval) => x.expectInternal(value, message) + case (x: Record, value: Record) => + require(DataMirror.checkTypeEquivalence(x, value), s"Record type mismatch") + x.elements.zip(value.elements).foreach { case ((_, x), (_, value)) => + x.expectInternal(value, message, allowPartial) + } + case (x: Vec[_], value: Vec[_]) => + require(DataMirror.checkTypeEquivalence(x, value), s"Vec type mismatch") + x.getElements.zip(value.getElements).zipWithIndex.foreach { case ((subX, value), index) => + value match { + case DontCare => + throw new RuntimeException( + s"Vec $x needs to be fully specified when using expect. Index $index is missing." + + "Maybe try using `expectPartial` if you only want to check for some elements." + ) + case other => subX.expectInternal(other, message, allowPartial) + } + } + case (x: EnumType, value: EnumType) => + require(DataMirror.checkTypeEquivalence(x, value), s"EnumType mismatch") + Utils.expectBits(x, value.litValue, message, Some(enumToString(x))) + case x => throw new LiteralTypeException(s"don't know how to expect $x") + // TODO: aggregate types + } } - def expect(value: T): Unit = expectInternal(value, None) - def expect(value: T, message: => String): Unit = expectInternal(value, Some(() => message)) + def expect(value: T): Unit = expectInternal(value, None, allowPartial = false) + def expect(value: T, message: => String): Unit = expectInternal(value, Some(() => message), allowPartial = false) /** @return the single clock that drives the source of this signal. * @throws ClockResolutionException if sources of this signal have more than one, or zero clocks diff --git a/src/test/scala/chiseltest/tests/PokeAndExpectPartialTests.scala b/src/test/scala/chiseltest/tests/PokeAndExpectPartialTests.scala new file mode 100644 index 000000000..a8de08588 --- /dev/null +++ b/src/test/scala/chiseltest/tests/PokeAndExpectPartialTests.scala @@ -0,0 +1,250 @@ +package chiseltest.tests + +import chisel3._ +import chisel3.experimental.BundleLiterals.AddBundleLiteralConstructor +import chisel3.experimental.VecLiterals.AddVecLiteralConstructor +import chiseltest._ +import org.scalatest.flatspec.AnyFlatSpec + +class PokeAndExpectPartialTests extends AnyFlatSpec with ChiselScalatestTester{ + behavior of "pokePartial" + + it should "work with a bundle of uint" in { + val typ = new CustomBundle("foo" -> UInt(32.W), "bar" -> UInt(32.W)) + test(new PassthroughModule(typ)) { c => + c.in.pokePartial(typ.Lit( + _.elements("foo") -> 4.U + )) + c.out.expectPartial(typ.Lit( + _.elements("foo") -> 4.U + )) + c.clock.step() + c.in.pokePartial(typ.Lit( + _.elements("bar") -> 5.U + )) + c.out.expect(typ.Lit( + _.elements("foo") -> 4.U, + _.elements("bar") -> 5.U + )) + } + } + + it should "work with a bundle of bundle" in { + val innerTyp = new CustomBundle("0" -> UInt(8.W), "1" -> UInt(17.W), "2" -> UInt(100.W)) + val typ = new CustomBundle("0" -> innerTyp, "1" -> innerTyp) + test(new PassthroughModule(typ)) { c => + c.in.pokePartial(typ.Lit( + _.elements("0") -> innerTyp.Lit( + // full inner bundle + _.elements("0") -> 4.U, + _.elements("1") -> 4.U, + _.elements("2") -> 4.U, + ))) + c.out.expectPartial(typ.Lit( + _.elements("0") -> innerTyp.Lit( + // full inner bundle + _.elements("0") -> 4.U, + _.elements("1") -> 4.U, + _.elements("2") -> 4.U, + ))) + c.clock.step() + c.in.pokePartial(typ.Lit( + _.elements("1") -> innerTyp.Lit( + // partial inner bundle + _.elements("0") -> 3.U, + _.elements("2") -> 3.U, + ))) + c.out.expectPartial(typ.Lit( + _.elements("1") -> innerTyp.Lit( + // partial inner bundle + _.elements("0") -> 3.U, + _.elements("2") -> 3.U, + ))) + c.clock.step() + c.in.pokePartial(typ.Lit( + _.elements("1") -> innerTyp.Lit( + // partial inner bundle + _.elements("0") -> 7.U, // partial overwrite! + _.elements("1") -> 7.U, + ))) + c.out.expect(typ.Lit( + _.elements("0") -> innerTyp.Lit( + _.elements("0") -> 4.U, + _.elements("1") -> 4.U, + _.elements("2") -> 4.U, + ), + _.elements("1") -> innerTyp.Lit( + _.elements("0") -> 7.U, + _.elements("1") -> 7.U, + _.elements("2") -> 3.U, + ), + )) + } + } + + it should "work with a vector of uint" in { + val typ = Vec(4, UInt(32.W)) + test(new PassthroughModule(typ)) { c => + c.in.pokePartial(typ.Lit( + 0 -> 4.U, + )) + c.out.expectPartial(typ.Lit( + 0 -> 4.U, + )) + c.clock.step() + c.in.pokePartial(typ.Lit( + 3 -> 5.U, + 2 -> 321.U, + 1 -> 123.U, + )) + c.out.expect(typ.Lit( + 0 -> 4.U, + 1 -> 123.U, + 2 -> 321.U, + 3 -> 5.U, + )) + c.clock.step() + c.in.pokePartial(typ.Lit( + 2 -> 444.U, + )) + c.out.expect(typ.Lit( + 0 -> 4.U, + 1 -> 123.U, + 2 -> 444.U, + 3 -> 5.U, + )) + } + } + + it should "work with a vector of vector" in { + val innerTyp = Vec(3, UInt(32.W)) + val typ = Vec(2, innerTyp) + test(new PassthroughModule(typ)) { c => + c.in.pokePartial(typ.Lit( + 0 -> innerTyp.Lit( + // full inner vector + 0 -> 4.U, + 1 -> 4.U, + 2 -> 4.U, + ))) + c.out.expectPartial(typ.Lit( + 0 -> innerTyp.Lit( + // full inner vector + 0 -> 4.U, + 1 -> 4.U, + 2 -> 4.U, + ))) + c.clock.step() + c.in.pokePartial(typ.Lit( + 1 -> innerTyp.Lit( + // partial inner vector + 0 -> 3.U, + 2 -> 3.U, + ))) + c.out.expectPartial(typ.Lit( + 1 -> innerTyp.Lit( + // partial inner vector + 0 -> 3.U, + 2 -> 3.U, + ))) + c.clock.step() + c.in.pokePartial(typ.Lit( + 1 -> innerTyp.Lit( + // partial inner vector + 0 -> 7.U, // partial overwrite! + 1 -> 7.U, + ))) + c.out.expect(typ.Lit( + 0 -> innerTyp.Lit( + 0 -> 4.U, + 1 -> 4.U, + 2 -> 4.U, + ), + 1 -> innerTyp.Lit( + 0 -> 7.U, + 1 -> 7.U, + 2 -> 3.U, + ), + )) + } + } + + it should "work with a vector of bundle" in { + val innerTyp = new CustomBundle("0" -> UInt(8.W), "1" -> UInt(17.W), "2" -> UInt(100.W)) + val typ = Vec(2, innerTyp) + test(new PassthroughModule(typ)) { c => + c.in.pokePartial(typ.Lit( + 0 -> innerTyp.Lit( + // full inner bundle + _.elements("0") -> 4.U, + _.elements("1") -> 4.U, + _.elements("2") -> 4.U, + ))) + c.out.expectPartial(typ.Lit( + 0 -> innerTyp.Lit( + // full inner bundle + _.elements("0") -> 4.U, + _.elements("1") -> 4.U, + _.elements("2") -> 4.U, + ))) + c.clock.step() + c.in.pokePartial(typ.Lit( + 1 -> innerTyp.Lit( + // partial inner bundle + _.elements("0") -> 3.U, + _.elements("2") -> 3.U, + ))) + c.out.expectPartial(typ.Lit( + 1 -> innerTyp.Lit( + // partial inner bundle + _.elements("0") -> 3.U, + _.elements("2") -> 3.U, + ))) + c.clock.step() + c.in.pokePartial(typ.Lit( + 1 -> innerTyp.Lit( + // partial inner bundle + _.elements("0") -> 7.U, // partial overwrite! + _.elements("1") -> 7.U, + ))) + c.out.expect(typ.Lit( + 0 -> innerTyp.Lit( + _.elements("0") -> 4.U, + _.elements("1") -> 4.U, + _.elements("2") -> 4.U, + ), + 1 -> innerTyp.Lit( + _.elements("0") -> 7.U, + _.elements("1") -> 7.U, + _.elements("2") -> 3.U, + ), + )) + } + } + + behavior of "poke" + + it should "provide a good error message when used with partial bundle literals" in { + val typ = new CustomBundle("foo" -> UInt(32.W), "bar" -> UInt(32.W)) + assertThrows[NonLiteralValueError] { + test(new PassthroughModule(typ)) { c => + c.in.poke(typ.Lit( + _.elements("foo") -> 123.U + )) + } + } + } + + it should "provide a good error message when used with partial vector literals" in { + val typ = Vec(4, UInt(32.W)) + assertThrows[NonLiteralValueError] { + test(new PassthroughModule(typ)) { c => + c.in.poke(typ.Lit( + 0 -> 4.U, + 3 -> 5.U + )) + } + } + } + +}