Skip to content

Commit

Permalink
[rtl] refactor mask unit.
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li committed Sep 25, 2024
1 parent d7395a2 commit 4be97be
Show file tree
Hide file tree
Showing 11 changed files with 940 additions and 838 deletions.
67 changes: 67 additions & 0 deletions t1/src/Bundles.scala
Original file line number Diff line number Diff line change
Expand Up @@ -698,3 +698,70 @@ class T1Retire(xLen: Int) extends Bundle {
val csr: ValidIO[T1CSRRetire] = Valid(new T1CSRRetire)
val mem: ValidIO[EmptyBundle] = Valid(new EmptyBundle)
}

class MaskUnitGroupState(parameter: T1Parameter) extends Bundle {
val executeIndex: UInt = UInt(2.W)
val groupReadState: UInt = UInt(parameter.laneNumber.W)
val needRead: UInt = UInt(parameter.laneNumber.W)
}

class MaskUnitInstReq(parameter: T1Parameter) extends Bundle {
val instructionIndex: UInt = UInt(parameter.instructionIndexBits.W)
val decodeResult: DecodeBundle = Decoder.bundle(parameter.decoderParam)
val readFromScala: UInt = UInt(parameter.datapathWidth.W)
val eew: UInt = UInt(2.W)
val vm: Bool = Bool()
val vxrm: UInt = UInt(3.W)
}

class MaskUnitExeReq(parameter: LaneParameter) extends Bundle {
// source1, read vs
val source1: UInt = UInt(parameter.datapathWidth.W)
// source2, read offset
val source2: UInt = UInt(parameter.datapathWidth.W)
val readOverlap = UInt(4.W)
val readLane: Vec[UInt] = Vec(4, UInt(log2Ceil(parameter.laneNumber).W))
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
}

class MaskUnitExeResponse(parameter: LaneParameter) extends Bundle {
val ffoSuccess: Bool = Bool()
val writeData = new MaskUnitWriteBundle(parameter)
}

class MaskUnitReadReq(parameter: T1Parameter) extends Bundle {
val vs: UInt = UInt(5.W)
// source2, read offset
val offset: UInt = UInt(parameter.laneParam.vrfOffsetBits.W)
// Read which lane
val readLane: UInt = UInt(log2Ceil(parameter.laneNumber).W)
}

class MaskUnitReadQueue(parameter: T1Parameter) extends Bundle {
val vs: UInt = UInt(5.W)
// source2, read offset
val offset: UInt = UInt(parameter.laneParam.vrfOffsetBits.W)
// Which channel will this read request be written to?
val writeIndex: UInt = UInt(log2Ceil(parameter.laneNumber).W)
}

class MaskUnitWaitReadQueue(parameter: T1Parameter) extends Bundle {
// source1
val source1: Vec[UInt] = Vec(parameter.laneNumber, UInt(parameter.datapathWidth.W))
// source2
val source2: Vec[UInt] = Vec(parameter.laneNumber, UInt(parameter.datapathWidth.W))

val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W)
val executeIndex: UInt = UInt(2.W)
val lastGroup: Bool = Bool()
val sourceValid: UInt = UInt(parameter.laneNumber.W)
val writeMask: Vec[UInt] = Vec(parameter.laneNumber, UInt((parameter.datapathWidth / 4).W))

val needRead: UInt = UInt(parameter.laneNumber.W)
}

class MaskUnitWriteBundle(parameter: LaneParameter) extends Bundle {
val data: UInt = UInt(parameter.datapathWidth.W)
val mask: UInt = UInt((parameter.datapathWidth / 8).W)
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
}
56 changes: 36 additions & 20 deletions t1/src/Lane.scala
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
@public
val csrInterface: CSRInterface = IO(Input(new CSRInterface(parameter.vlMaxBits)))

@public
val maskUnitRequest: DecoupledIO[MaskUnitExeReq] = IO(Decoupled(new MaskUnitExeReq(parameter)))

@public
val maskUnitResponse: ValidIO[MaskUnitExeResponse] = IO(Flipped(Valid(new MaskUnitExeResponse(parameter))))

/** response to [[T1.lsu]] or mask unit in [[T1]] */
@public
val laneResponse: ValidIO[LaneResponse] = IO(Valid(new LaneResponse(parameter)))
Expand Down Expand Up @@ -318,6 +324,7 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[

// TODO: remove
dontTouch(writeBusPort)
maskUnitRequest <> DontCare // todo

/** VRF instantces. */
val vrf: Instance[VRF] = Instantiate(new VRF(parameter.vrfParam))
Expand Down Expand Up @@ -571,7 +578,16 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
val executionUnit: Instance[LaneExecutionBridge] = Instantiate(
new LaneExecutionBridge(parameter, isLastSlot, index)
)
val maskStage: Option[Instance[MaskExchangeUnit]] = Option.when(isLastSlot)(Instantiate(new MaskExchangeUnit(parameter)))
val stage3: Instance[LaneStage3] = Instantiate(new LaneStage3(parameter, isLastSlot))
val stage3EnqWire: DecoupledIO[LaneStage3Enqueue] = Wire(Decoupled(new LaneStage3Enqueue(parameter, isLastSlot)))
val stage3EnqSelect: DecoupledIO[LaneStage3Enqueue] = maskStage.map { mask =>
mask.enqueue <> stage3EnqWire
maskUnitRequest <> mask.maskReq
mask.maskUnitResponse := maskUnitResponse
mask.dequeue
}.getOrElse(stage3EnqWire)
stage3.enqueue <> stage3EnqSelect

// slot state
laneState.vSew1H := vSew1H
Expand Down Expand Up @@ -753,43 +769,43 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
0.U(parameter.chainingSize.W)
)
AssertProperty(BoolSequence(!executionUnit.dequeue.valid || stage2.dequeue.valid))
stage3.enqueue.valid := executionUnit.dequeue.valid
executionUnit.dequeue.ready := stage3.enqueue.ready
stage3EnqWire.valid := executionUnit.dequeue.valid
executionUnit.dequeue.ready := stage3EnqWire.ready
stage2.dequeue.ready := executionUnit.dequeue.fire

if (!isLastSlot) {
stage3.enqueue.bits := DontCare
stage3EnqWire.bits := DontCare
}

// pipe state from stage0
stage3.enqueue.bits.decodeResult := stage2.dequeue.bits.decodeResult
stage3.enqueue.bits.instructionIndex := stage2.dequeue.bits.instructionIndex
stage3.enqueue.bits.loadStore := stage2.dequeue.bits.loadStore
stage3.enqueue.bits.vd := stage2.dequeue.bits.vd
stage3.enqueue.bits.ffoByOtherLanes := ffoRecord.ffoByOtherLanes
stage3.enqueue.bits.groupCounter := stage2.dequeue.bits.groupCounter
stage3.enqueue.bits.mask := stage2.dequeue.bits.mask
stage3EnqWire.bits.decodeResult := stage2.dequeue.bits.decodeResult
stage3EnqWire.bits.instructionIndex := stage2.dequeue.bits.instructionIndex
stage3EnqWire.bits.loadStore := stage2.dequeue.bits.loadStore
stage3EnqWire.bits.vd := stage2.dequeue.bits.vd
stage3EnqWire.bits.ffoByOtherLanes := ffoRecord.ffoByOtherLanes
stage3EnqWire.bits.groupCounter := stage2.dequeue.bits.groupCounter
stage3EnqWire.bits.mask := stage2.dequeue.bits.mask
if (isLastSlot) {
stage3.enqueue.bits.sSendResponse := stage2.dequeue.bits.sSendResponse.get
stage3.enqueue.bits.ffoSuccess := executionUnit.dequeue.bits.ffoSuccess.get
stage3.enqueue.bits.fpReduceValid.zip(executionUnit.dequeue.bits.fpReduceValid).foreach { case (sink, source) =>
stage3EnqWire.bits.sSendResponse := stage2.dequeue.bits.sSendResponse.get
stage3EnqWire.bits.ffoSuccess := executionUnit.dequeue.bits.ffoSuccess.get
stage3EnqWire.bits.fpReduceValid.zip(executionUnit.dequeue.bits.fpReduceValid).foreach { case (sink, source) =>
sink := source
}
}
stage3.enqueue.bits.data := executionUnit.dequeue.bits.data
stage3.enqueue.bits.pipeData := stage2.dequeue.bits.pipeData.getOrElse(DontCare)
stage3.enqueue.bits.ffoIndex := executionUnit.dequeue.bits.ffoIndex
executionUnit.dequeue.bits.crossWriteData.foreach(data => stage3.enqueue.bits.crossWriteData := data)
stage2.dequeue.bits.sSendResponse.foreach(_ => stage3.enqueue.bits.sSendResponse := _)
executionUnit.dequeue.bits.ffoSuccess.foreach(_ => stage3.enqueue.bits.ffoSuccess := _)
stage3EnqWire.bits.data := executionUnit.dequeue.bits.data
stage3EnqWire.bits.pipeData := stage2.dequeue.bits.pipeData.getOrElse(DontCare)
stage3EnqWire.bits.ffoIndex := executionUnit.dequeue.bits.ffoIndex
executionUnit.dequeue.bits.crossWriteData.foreach(data => stage3EnqWire.bits.crossWriteData := data)
stage2.dequeue.bits.sSendResponse.foreach(_ => stage3EnqWire.bits.sSendResponse := _)
executionUnit.dequeue.bits.ffoSuccess.foreach(_ => stage3EnqWire.bits.ffoSuccess := _)

if (isLastSlot) {
when(laneResponseFeedback.valid) {
when(laneResponseFeedback.bits.complete) {
ffoRecord.ffoByOtherLanes := true.B
}
}
when(stage3.enqueue.fire) {
when(stage3EnqWire.fire) {
executionUnit.dequeue.bits.ffoSuccess.foreach(ffoRecord.selfCompleted := _)
// This group found means the next group ended early
ffoRecord.ffoByOtherLanes := ffoRecord.ffoByOtherLanes || ffoRecord.selfCompleted
Expand Down
Loading

0 comments on commit 4be97be

Please sign in to comment.