diff --git a/driver/builder.go b/driver/builder.go index 33627724..3f4dd7d1 100644 --- a/driver/builder.go +++ b/driver/builder.go @@ -87,7 +87,9 @@ func (b Builder) Build(name string) *Driver { driver.middlewares = append(driver.middlewares, globalStorageMemoryCopyMiddleware) } else { defaultMemoryCopyMiddleware := &defaultMemoryCopyMiddleware{ - driver: driver, + driver: driver, + cyclesPerD2H: 8500, + cyclesPerH2D: 14500, } driver.middlewares = append(driver.middlewares, defaultMemoryCopyMiddleware) } diff --git a/driver/memorycopy.go b/driver/memorycopy.go index a6fe27f2..2bbfd91e 100644 --- a/driver/memorycopy.go +++ b/driver/memorycopy.go @@ -12,6 +12,12 @@ import ( // communication. type defaultMemoryCopyMiddleware struct { driver *Driver + + cyclesPerH2D int + cyclesPerD2H int + cyclesLeft int + + awaitingReqs []sim.Msg } func (m *defaultMemoryCopyMiddleware) ProcessCommand( @@ -67,7 +73,8 @@ func (m *defaultMemoryCopyMiddleware) processMemCopyH2DCommand( rawBytes[offset:offset+sizeToCopy], pAddr) cmd.Reqs = append(cmd.Reqs, req) - m.driver.requestsToSend = append(m.driver.requestsToSend, req) + m.awaitingReqs = append(m.awaitingReqs, req) + // m.driver.requestsToSend = append(m.driver.requestsToSend, req) sizeLeft -= sizeToCopy addr += sizeToCopy @@ -76,6 +83,8 @@ func (m *defaultMemoryCopyMiddleware) processMemCopyH2DCommand( m.driver.logTaskToGPUInitiate(now, cmd, req) } + m.cyclesLeft = m.cyclesPerH2D + queue.IsRunning = true return true @@ -114,7 +123,8 @@ func (m *defaultMemoryCopyMiddleware) processMemCopyD2HCommand( m.driver.gpuPort, m.driver.GPUs[gpuID-1], pAddr, cmd.RawData[offset:offset+sizeToCopy]) cmd.Reqs = append(cmd.Reqs, req) - m.driver.requestsToSend = append(m.driver.requestsToSend, req) + m.awaitingReqs = append(m.awaitingReqs, req) + // m.driver.requestsToSend = append(m.driver.requestsToSend, req) sizeLeft -= sizeToCopy addr += sizeToCopy @@ -123,6 +133,8 @@ func (m *defaultMemoryCopyMiddleware) processMemCopyD2HCommand( m.driver.logTaskToGPUInitiate(now, cmd, req) } + m.cyclesLeft = m.cyclesPerD2H + queue.IsRunning = true return true } @@ -177,17 +189,29 @@ func (m *defaultMemoryCopyMiddleware) sendFlushRequest( func (m *defaultMemoryCopyMiddleware) Tick( now sim.VTimeInSec, ) (madeProgress bool) { + madeProgress = false + + if m.cyclesLeft > 0 { + m.cyclesLeft-- + madeProgress = true + } else if m.cyclesLeft == 0 { + m.driver.requestsToSend = append(m.driver.requestsToSend, m.awaitingReqs...) + m.awaitingReqs = nil + m.cyclesLeft = -1 + madeProgress = true + } + req := m.driver.gpuPort.Peek() if req == nil { - return false + return madeProgress } switch req := req.(type) { case *sim.GeneralRsp: - return m.processGeneralRsp(now, req) + madeProgress = m.processGeneralRsp(now, req) } - return false + return madeProgress } func (m *defaultMemoryCopyMiddleware) processGeneralRsp(