Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMDGPU] RFC: Add and optionally use GCNIterativeRPTrackers #88797

Closed
wants to merge 9 commits into from
163 changes: 163 additions & 0 deletions llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,48 @@ void GCNRPTracker::reset(const MachineInstr &MI,
MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
}

DenseMap<int, GCNRPTracker::LiveRegSet>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what the key is here. I'm also concerned by building maps for liveness info, that's already encoded in the LiveIntervals. Most of this code is also replicating logic from the generic tracker, and all of this is really hard to get correct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what to do with this comment.

I'm not sure what the key is here.

Currently, these maps are keyed on the initial first instructions of the regions. However, if we are going to requery, we need a different key as these first instructions are likely to change. I've used instead RegionIdx -- which tracks the index in the Regions array. This is similarly used to index all these https://github.com/llvm/llvm-project/blob/a5044e6d505deb79f1b00bb39d11096d29b9c910/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp#L623C3-L623C10 I guess we could restructure as vector.

I'm concerned by building maps for liveness info, that's already encoded in the LiveIntervals.

While this patch does introduce a LiveOutMap, it doesn't introduce the notion of using a Map to cache live-ins -- this already exists in the code. In this patch, these maps optimize compile time by avoiding recalculation of the live regs.

Most of this code is also replicating logic from the generic tracker, and all of this is really hard to get correct.

This is mostly copy-paste from existing code in GCNRegPressure, but, nonetheless, we're running a CQE cycle to root out any correctness / performance issues. Per discussion, I think it's better to slightly duplicate logic for time being, and potentially rework the generic tracker in a long term project.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what the key is here. I'm also concerned by building maps for liveness info, that's already encoded in the LiveIntervals.

I've tried to address this in #93089

I've left getLiveRegMap unmodified and used the last MI per region as the key to the liveness info (as is similarly done for BBLiveInMap currently). The result is that we do not introduce any new way of building maps for liveness info. The maps are not strictly necessary, and we could recompute them each time we encounter the same region during scheduling, but they do eliminate redundant calculations.

llvm::getLiveRegMap(DenseMap<MachineInstr *, int> &R, bool After,
LiveIntervals &LIS) {
std::vector<SlotIndex> Indexes;
// Indexes.reserve(R.size());
auto &SII = *LIS.getSlotIndexes();
for (std::pair<MachineInstr *, int> &Entry : R) {
auto SI = SII.getInstructionIndex(*Entry.first);
Indexes.push_back(After ? SI.getDeadSlot() : SI.getBaseIndex());
}
llvm::sort(Indexes);

auto &MRI = (*R.begin()).first->getParent()->getParent()->getRegInfo();
DenseMap<int, GCNRPTracker::LiveRegSet> LiveRegMap;
SmallVector<SlotIndex, 32> LiveIdxs, SRLiveIdxs;
for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
auto Reg = Register::index2VirtReg(I);
if (!LIS.hasInterval(Reg))
continue;
auto &LI = LIS.getInterval(Reg);
LiveIdxs.clear();
if (!LI.findIndexesLiveAt(Indexes, std::back_inserter(LiveIdxs)))
continue;
if (!LI.hasSubRanges()) {
for (auto SI : LiveIdxs) {
auto Idx = R[SII.getInstructionFromIndex(SI)];
LiveRegMap[Idx][Reg] = MRI.getMaxLaneMaskForVReg(Reg);
}
} else
for (const auto &S : LI.subranges()) {
// constrain search for subranges by indexes live at main range
SRLiveIdxs.clear();
S.findIndexesLiveAt(LiveIdxs, std::back_inserter(SRLiveIdxs));
for (auto SI : SRLiveIdxs) {
auto Idx = R[SII.getInstructionFromIndex(SI)];
LiveRegMap[Idx][Reg] |= S.LaneMask;
}
}
}
return LiveRegMap;
}

////////////////////////////////////////////////////////////////////////////////
// GCNUpwardRPTracker

Expand Down Expand Up @@ -570,6 +612,127 @@ bool GCNUpwardRPTracker::isValid() const {
return true;
}

////////////////////////////////////////////////////////////////////////////////
// GCNIterativeRPTrackers

void GCNIterativeRPTracker::reset(const MachineRegisterInfo *MRI_,
const LiveRegSet *LiveRegsCopy) {

MRI = MRI_;
if (LiveRegsCopy && &LiveRegs != LiveRegsCopy)
LiveRegs = *LiveRegsCopy;
if (!LiveRegsCopy)
LiveRegs.clear();
MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
}

// Mostly copy+paste from GCNUpwardRPTracker::recede
void GCNIterativeUpwardRPTracker::recede(const MachineInstr &MI,
LiveIntervals *LIS) {
assert(MRI && "call reset first");

if (MI.isDebugInstr())
return;

SmallVector<RegisterMaskPair, 8> RegUses;
collectVirtualRegUses(RegUses, MI, *LIS, *MRI);

// calc pressure at the MI (defs + uses)
auto AtMIPressure = CurPressure;
for (const auto &U : RegUses) {
auto LiveMask = LiveRegs[U.RegUnit];
AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
}
// update max pressure
MaxPressure = max(AtMIPressure, MaxPressure);

for (const auto &MO : MI.all_defs()) {
if (!MO.getReg().isVirtual() || MO.isDead())
continue;

auto Reg = MO.getReg();
auto I = LiveRegs.find(Reg);
if (I == LiveRegs.end())
continue;
auto &LiveMask = I->second;
auto PrevMask = LiveMask;
LiveMask &= ~getDefRegMask(MO, *MRI);
CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
if (LiveMask.none())
LiveRegs.erase(I);
}
for (const auto &U : RegUses) {
auto &LiveMask = LiveRegs[U.RegUnit];
auto PrevMask = LiveMask;
LiveMask |= U.LaneMask;
CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
}
assert(CurPressure == getRegPressure(*MRI, LiveRegs));
}

// Mostly copy+paste from GCNDownwardRPTracker::(advanceBeforeNext +
// advanceToNext)
void GCNIterativeDownwardRPTracker::advance(const MachineInstr &MI,
LiveIntervals *LIS) {
assert(MRI && "call reset first");
// Add new registers or mask bits.
for (const auto &MO : MI.all_defs()) {
Register Reg = MO.getReg();
if (!Reg.isVirtual())
continue;
if (MO.isDead())
continue;
auto &LiveMask = LiveRegs[Reg];
auto PrevMask = LiveMask;
LiveMask |= getDefRegMask(MO, *MRI);
CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
}

SlotIndex SI = LIS->getInstructionIndex(MI).getBoundaryIndex();
assert(SI.isValid());

// Remove dead registers or mask bits.
SmallSet<Register, 8> SeenRegs;
for (auto &MO : MI.operands()) {
if (!MO.isReg() || !MO.getReg().isVirtual())
continue;
if (!MO.isUse())
continue;
if (!MO.readsReg())
continue;
jrbyrnes marked this conversation as resolved.
Show resolved Hide resolved
if (!SeenRegs.insert(MO.getReg()).second)
continue;

const LiveInterval &LI = LIS->getInterval(MO.getReg());
if (LI.hasSubRanges()) {
auto It = LiveRegs.end();
for (const auto &S : LI.subranges()) {
if (S.expiredAt(SI)) {
if (It == LiveRegs.end()) {
It = LiveRegs.find(MO.getReg());
if (It == LiveRegs.end())
llvm_unreachable("register isn't live");
}
auto PrevMask = It->second;
It->second &= ~S.LaneMask;
CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI);
}
}
if (It != LiveRegs.end() && It->second.none()) {
LiveRegs.erase(It);
}
} else if (LI.expiredAt(SI)) {
auto It = LiveRegs.find(MO.getReg());
if (It == LiveRegs.end())
llvm_unreachable("register isn't live");
CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI);
LiveRegs.erase(It);
}
}

MaxPressure = max(MaxPressure, CurPressure);
}

Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
const MachineRegisterInfo &MRI) {
return Printable([&LiveRegs, &MRI](raw_ostream &OS) {
Expand Down
75 changes: 37 additions & 38 deletions llvm/lib/Target/AMDGPU/GCNRegPressure.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,41 @@ class GCNDownwardRPTracker : public GCNRPTracker {
const LiveRegSet *LiveRegsCopy = nullptr);
};

class GCNIterativeRPTracker {
public:
using LiveRegSet = DenseMap<unsigned, LaneBitmask>;

protected:
LiveRegSet LiveRegs;
GCNRegPressure CurPressure, MaxPressure;

mutable const MachineRegisterInfo *MRI = nullptr;

GCNIterativeRPTracker() {};

public:
void reset(const MachineRegisterInfo *MRI_, const LiveRegSet *LiveRegsCopy);

GCNRegPressure getPressure() const { return CurPressure; }
GCNRegPressure getMaxPressure() const { return MaxPressure; }
};

class GCNIterativeUpwardRPTracker : public GCNIterativeRPTracker {
public:
GCNIterativeUpwardRPTracker() {};

// Move to the state just before the MI.
void recede(const MachineInstr &MI, LiveIntervals *TheLIS);
};

class GCNIterativeDownwardRPTracker : public GCNIterativeRPTracker {
public:
GCNIterativeDownwardRPTracker() {};

// Move to the state just after the MI.
void advance(const MachineInstr &MI, LiveIntervals *TheLIS);
};

LaneBitmask getLiveLaneMask(unsigned Reg,
SlotIndex SI,
const LiveIntervals &LIS,
Expand All @@ -275,44 +310,8 @@ GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
/// After - upon entry or exit of every instruction
/// Note: there is no entry in the map for instructions with empty live reg set
/// Complexity = O(NumVirtRegs * averageLiveRangeSegmentsPerReg * lg(R))
template <typename Range>
DenseMap<MachineInstr*, GCNRPTracker::LiveRegSet>
getLiveRegMap(Range &&R, bool After, LiveIntervals &LIS) {
std::vector<SlotIndex> Indexes;
Indexes.reserve(std::distance(R.begin(), R.end()));
auto &SII = *LIS.getSlotIndexes();
for (MachineInstr *I : R) {
auto SI = SII.getInstructionIndex(*I);
Indexes.push_back(After ? SI.getDeadSlot() : SI.getBaseIndex());
}
llvm::sort(Indexes);

auto &MRI = (*R.begin())->getParent()->getParent()->getRegInfo();
DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> LiveRegMap;
SmallVector<SlotIndex, 32> LiveIdxs, SRLiveIdxs;
for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
auto Reg = Register::index2VirtReg(I);
if (!LIS.hasInterval(Reg))
continue;
auto &LI = LIS.getInterval(Reg);
LiveIdxs.clear();
if (!LI.findIndexesLiveAt(Indexes, std::back_inserter(LiveIdxs)))
continue;
if (!LI.hasSubRanges()) {
for (auto SI : LiveIdxs)
LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] =
MRI.getMaxLaneMaskForVReg(Reg);
} else
for (const auto &S : LI.subranges()) {
// constrain search for subranges by indexes live at main range
SRLiveIdxs.clear();
S.findIndexesLiveAt(LiveIdxs, std::back_inserter(SRLiveIdxs));
for (auto SI : SRLiveIdxs)
LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] |= S.LaneMask;
}
}
return LiveRegMap;
}
DenseMap<int, GCNRPTracker::LiveRegSet>
getLiveRegMap(DenseMap<MachineInstr *, int> &R, bool After, LiveIntervals &LIS);

inline GCNRPTracker::LiveRegSet getLiveRegsAfter(const MachineInstr &MI,
const LiveIntervals &LIS) {
Expand Down
Loading
Loading