[AMDGPU] Optionally Use GCNRPTrackers during scheduling #93090

Open: wants to merge 27 commits into base: main

Commits (27)
7740991  [AMDGPU] NFC: Add BBLiveOutMap & LiveOut Cache (jrbyrnes, May 21, 2024)
d2f8020  [AMDGPU] NFC: Provide RPTracker interface for external iterators (jrbyrnes, May 21, 2024)
819fb01  [AMDGPU] Optionally Use AMDGPU RPTrackers during scheduling (jrbyrnes, May 22, 2024)
b538f21  Formatting (jrbyrnes, Jun 14, 2024)
653e153  Actually use the iterative trackers (jrbyrnes, May 27, 2024)
5d92149  Review Comments (jrbyrnes, May 28, 2024)
b1b81cc  Use DAG.MRI (jrbyrnes, May 28, 2024)
83fea0a  Formatting (jrbyrnes, May 28, 2024)
3947cbb  Review comments (jrbyrnes, Jun 14, 2024)
a8600c8  Allocate Pressure vector (jrbyrnes, Jun 14, 2024)
4d3a3ca  Remove flag from upward RPTracker (jrbyrnes, Jun 18, 2024)
d468ede  Review comments (jrbyrnes, Jun 19, 2024)
3d072b4  Dont modify existing PreRARematStage LiveIn handling (jrbyrnes, Jun 20, 2024)
bb1e241  Use GCNTracker RP speculation (jrbyrnes, Aug 12, 2024)
366b90f  Port changes from pull/93088 (jrbyrnes, Aug 20, 2024)
0b3e08f  Port changes from pull/93088 (jrbyrnes, Aug 21, 2024)
9de7cc2  Feed SIRegisterInfo to Trackers + Propagate unused AGPR speculative p… (jrbyrnes, Aug 21, 2024)
a97ee42  Review comments (jrbyrnes, Sep 5, 2024)
66c42b0  Avoid const_cast (jrbyrnes, Sep 18, 2024)
dbd6812  Fix shouldTrackVGPRs calculation (jrbyrnes, Sep 23, 2024)
2714af5  Add lit tests (jrbyrnes, Sep 27, 2024)
7096cb0  Remove CurrLIS (jrbyrnes, Sep 27, 2024)
9a6563e  Mark speculative query methods as const (jrbyrnes, Oct 3, 2024)
e976308  Fix lit tests (jrbyrnes, Oct 6, 2024)
b6e86d8  Remove bumpUpwardPressure (jrbyrnes, Oct 7, 2024)
35ab173  Changes from pull/111452 + use the new recede (jrbyrnes, Oct 8, 2024)
aa74786  Code / comment cleanup (jrbyrnes, Oct 8, 2024)
llvm/lib/Target/AMDGPU/GCNIterativeScheduler.cpp (1 addition, 1 deletion)
@@ -480,7 +480,7 @@ void GCNIterativeScheduler::scheduleLegacyMaxOccupancy(
LLVM_DEBUG(dbgs() << "Scheduling using default scheduler, "
"target occupancy = "
<< TgtOcc << '\n');
GCNMaxOccupancySchedStrategy LStrgy(Context);
GCNMaxOccupancySchedStrategy LStrgy(Context, /*IsLegacyScheduler=*/true);
unsigned FinalOccupancy = std::min(Occ, MFI->getOccupancy());

for (int I = 0; I < NumPasses; ++I) {
llvm/lib/Target/AMDGPU/GCNRegPressure.cpp (176 additions, 23 deletions)
@@ -296,6 +296,63 @@ collectVirtualRegUses(SmallVectorImpl<RegisterMaskPair> &RegMaskPairs,
}
}

/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
static LaneBitmask getLanesWithProperty(
const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
LaneBitmask SafeDefault,
function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) {
if (RegUnit.isVirtual()) {
const LiveInterval &LI = LIS.getInterval(RegUnit);
LaneBitmask Result;
if (TrackLaneMasks && LI.hasSubRanges()) {
for (const LiveInterval::SubRange &SR : LI.subranges()) {
if (Property(SR, Pos))
Result |= SR.LaneMask;
}
} else if (Property(LI, Pos)) {
Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit)
: LaneBitmask::getAll();
}

return Result;
}

const LiveRange *LR = LIS.getCachedRegUnit(RegUnit);
if (LR == nullptr)
return SafeDefault;
return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone();
}

/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
/// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}.
/// The query starts with a lane bitmask which gets lanes/bits removed for every
/// use we find.
static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
const MachineRegisterInfo &MRI,
const SIRegisterInfo *TRI,
const LiveIntervals *LIS,
bool Upward = false) {
for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
if (MO.isUndef())
continue;
Review comment (Contributor) on lines +338 to +339:

No isUndef check; it is redundant with the readsReg case for uses and is broken for the def-of-subregister case.
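One way to apply this suggestion (a sketch of an assumed fix, not code from this PR) is to filter on readsReg() rather than isUndef():

  for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
    // Per the review comment, readsReg() is assumed to subsume the isUndef()
    // check for uses: operands that do not actually read the register are
    // skipped either way.
    if (!MO.readsReg())
      continue;
    // ... rest of the findUseBetween loop unchanged ...
  }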

const MachineInstr *MI = MO.getParent();
SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot();
bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
: (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
if (!InRange)
continue;

unsigned SubRegIdx = MO.getSubReg();
LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubRegIdx);
LastUseMask &= ~UseMask;
if (LastUseMask.none())
return LaneBitmask::getNone();
}
return LastUseMask;
}

///////////////////////////////////////////////////////////////////////////////
// GCNRPTracker

@@ -354,17 +411,28 @@ void GCNRPTracker::reset(const MachineInstr &MI,
MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
}

////////////////////////////////////////////////////////////////////////////////
// GCNUpwardRPTracker

void GCNUpwardRPTracker::reset(const MachineRegisterInfo &MRI_,
const LiveRegSet &LiveRegs_) {
void GCNRPTracker::reset(const MachineRegisterInfo &MRI_,
const LiveRegSet &LiveRegs_) {
MRI = &MRI_;
LiveRegs = LiveRegs_;
LastTrackedMI = nullptr;
MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_);
}

/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
LaneBitmask GCNRPTracker::getLastUsedLanes(Register RegUnit,
SlotIndex Pos) const {
return getLanesWithProperty(
LIS, *MRI, true, RegUnit, Pos.getBaseIndex(), LaneBitmask::getNone(),
[](const LiveRange &LR, SlotIndex Pos) {
const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
return S != nullptr && S->end == Pos.getRegSlot();
});
}

////////////////////////////////////////////////////////////////////////////////
// GCNUpwardRPTracker

void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
assert(MRI && "call reset first");

@@ -441,25 +509,37 @@ bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
return true;
}

bool GCNDownwardRPTracker::advanceBeforeNext() {
bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI,
bool UseInternalIterator) {
assert(MRI && "call reset first");
if (!LastTrackedMI)
return NextMI == MBBEnd;

assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
SlotIndex SI;
const MachineInstr *CurrMI;
if (UseInternalIterator) {
if (!LastTrackedMI)
return NextMI == MBBEnd;

assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
CurrMI = LastTrackedMI;

SI = NextMI == MBBEnd
? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
: LIS.getInstructionIndex(*NextMI).getBaseIndex();
} else { // !UseInternalIterator
SI = LIS.getInstructionIndex(*MI).getBaseIndex();
CurrMI = MI;
}

SlotIndex SI = NextMI == MBBEnd
? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
: LIS.getInstructionIndex(*NextMI).getBaseIndex();
assert(SI.isValid());

// Remove dead registers or mask bits.
SmallSet<Register, 8> SeenRegs;
for (auto &MO : LastTrackedMI->operands()) {
for (auto &MO : CurrMI->operands()) {
if (!MO.isReg() || !MO.getReg().isVirtual())
continue;
if (MO.isUse() && !MO.readsReg())
continue;
if (!UseInternalIterator && MO.isDef())
continue;
if (!SeenRegs.insert(MO.getReg()).second)
continue;
const LiveInterval &LI = LIS.getInterval(MO.getReg());
@@ -492,15 +572,22 @@ bool GCNDownwardRPTracker::advanceBeforeNext() {

LastTrackedMI = nullptr;

return NextMI == MBBEnd;
return UseInternalIterator && (NextMI == MBBEnd);
}

void GCNDownwardRPTracker::advanceToNext() {
LastTrackedMI = &*NextMI++;
NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI,
bool UseInternalIterator) {
if (UseInternalIterator) {
LastTrackedMI = &*NextMI++;
NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
} else {
LastTrackedMI = MI;
}

const MachineInstr *CurrMI = LastTrackedMI;

// Add new registers or mask bits.
for (const auto &MO : LastTrackedMI->all_defs()) {
for (const auto &MO : CurrMI->all_defs()) {
Register Reg = MO.getReg();
if (!Reg.isVirtual())
continue;
@@ -513,11 +600,16 @@ void GCNDownwardRPTracker::advanceToNext() {
MaxPressure = max(MaxPressure, CurPressure);
}

bool GCNDownwardRPTracker::advance() {
if (NextMI == MBBEnd)
bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) {
if (UseInternalIterator && NextMI == MBBEnd)
return false;
advanceBeforeNext();
advanceToNext();

advanceBeforeNext(MI, UseInternalIterator);
advanceToNext(MI, UseInternalIterator);
if (!UseInternalIterator) {
// We must remove any dead def lanes from the current RP
advanceBeforeNext(MI, true);
}
return true;
}
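
With the new MI / UseInternalIterator parameters, a caller can feed instructions to the tracker in its own order instead of walking the block. A minimal usage sketch (assumed caller-side code, not part of this diff; RegionBegin and ProposedOrder are placeholder names):

  GCNDownwardRPTracker RPT(*LIS);
  RPT.reset(*RegionBegin);                         // seed live-ins at the region start
  for (MachineInstr *MI : ProposedOrder)           // order chosen by the scheduler
    RPT.advance(MI, /*UseInternalIterator=*/false);
  GCNRegPressure MaxRP = RPT.moveMaxPressure();    // peak pressure over the visited MIs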

@@ -559,6 +651,67 @@ Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
});
}

GCNRegPressure
GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI,
const SIRegisterInfo *TRI) const {
assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction.");

SlotIndex SlotIdx;
SlotIdx = LIS.getInstructionIndex(*MI).getRegSlot();

// Account for register pressure similar to RegPressureTracker::recede().
RegisterOperands RegOpers;
RegOpers.collect(*MI, *TRI, *MRI, true, /*IgnoreDead=*/false);
RegOpers.adjustLaneLiveness(LIS, *MRI, SlotIdx);
GCNRegPressure TempPressure = CurPressure;

for (const RegisterMaskPair &Use : RegOpers.Uses) {
Register Reg = Use.RegUnit;
if (!Reg.isVirtual())
continue;
LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
if (LastUseMask.none())
continue;
// The LastUseMask is queried from the liveness information of instruction
// which may be further down the schedule. Some lanes may actually not be
// last uses for the current position.
// FIXME: allow the caller to pass in the list of vreg uses that remain
// to be bottom-scheduled to avoid searching uses at each query.
SlotIndex CurrIdx;
const MachineBasicBlock *MBB = MI->getParent();
MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward(
LastTrackedMI ? LastTrackedMI : MBB->begin(), MBB->end());
if (IdxPos == MBB->end()) {
CurrIdx = LIS.getMBBEndIdx(MBB);
} else {
CurrIdx = LIS.getInstructionIndex(*IdxPos).getRegSlot();
}

LastUseMask =
findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS);
if (LastUseMask.none())
continue;

LaneBitmask LiveMask =
LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0);
LaneBitmask NewMask = LiveMask & ~LastUseMask;
TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
}

// Generate liveness for defs.
for (const RegisterMaskPair &Def : RegOpers.Defs) {
Register Reg = Def.RegUnit;
if (!Reg.isVirtual())
continue;
LaneBitmask LiveMask =
LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0);
LaneBitmask NewMask = LiveMask | Def.LaneMask;
TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
}

return TempPressure;
}
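
Since bumpDownwardPressure() is const, it can serve as a speculative query: it reports the pressure an instruction would produce without changing the tracker's state. A sketch of a possible caller (assumed code, not part of this diff; RPT, CandA, CandB, and SRI are placeholder names):

  // Compare the speculative pressure of two candidate instructions before
  // committing either one to the tracker.
  GCNRegPressure RPA = RPT.bumpDownwardPressure(CandA, SRI);
  GCNRegPressure RPB = RPT.bumpDownwardPressure(CandB, SRI);
  bool PreferA = RPA.getVGPRNum(false) <= RPB.getVGPRNum(false);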

bool GCNUpwardRPTracker::isValid() const {
const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);