[RISCV][TTI] Avoid an infinite recursion issue in getCastInstrCost (llvm#110164)

Calling into BasicTTI is not always safe. In particular, BasicTTI does
not have a full legalization implementation (vector widening is
missing), and falls back on scalarization. The problem is that
scalarization for <N x i1> vectors is costed in terms of the cast API, so
we can end up in an infinite recursive cycle.
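
The cycle described above can be reproduced in a toy model. The sketch below is not LLVM code: `ToyType`, `ToyBaseTTI`, `scalarizationCost`, `baseCastCost`, and `terminates` are hypothetical names, the cost numbers are arbitrary, and the depth guard exists only so the demonstration fails loudly instead of overflowing the stack. It assumes, as the message says, that a fallback without vector widening scalarizes every vector cast, and that scalarizing an <N x i1> vector is priced as an extend to i8 through the same cast entry point.

```cpp
#include <cassert>
#include <stdexcept>

// Toy stand-ins for illustration only; none of these names exist in LLVM.
struct ToyType {
  unsigned NumElts; // 1 means scalar
  unsigned EltBits; // element width in bits
};

struct ToyBaseTTI {
  int Depth = 0;
  static constexpr int MaxDepth = 64; // guard so the demo fails loudly

  virtual ~ToyBaseTTI() = default;

  // Public entry point, analogous in role to getCastInstrCost.
  virtual int castCost(ToyType Dst, ToyType Src) {
    return baseCastCost(Dst, Src);
  }

  // Scalarizing an <N x i1> vector is priced as an extend to i8, and that
  // extend is priced through castCost again.
  int scalarizationCost(ToyType Ty) {
    int Cost = static_cast<int>(Ty.NumElts); // one extract per element
    if (Ty.NumElts > 1 && Ty.EltBits == 1)
      Cost += castCost(ToyType{Ty.NumElts, 8}, Ty);
    return Cost;
  }

  // Fallback without vector widening: every vector cast is scalarized.
  int baseCastCost(ToyType Dst, ToyType Src) {
    if (Src.NumElts == 1)
      return 1; // scalar cast: one operation in this toy model
    if (++Depth > MaxDepth)
      throw std::runtime_error("cast cost recursion did not terminate");
    int Cost = scalarizationCost(Src) + scalarizationCost(Dst) +
               static_cast<int>(Src.NumElts);
    --Depth;
    return Cost;
  }
};

// Returns true if the query finishes without tripping the depth guard.
bool terminates(ToyType Dst, ToyType Src) {
  ToyBaseTTI TTI;
  try {
    TTI.castCost(Dst, Src);
    return true;
  } catch (const std::runtime_error &) {
    return false;
  }
}
```

In this model, `terminates(ToyType{4, 32}, ToyType{4, 8})` is true, while any query with an i1 vector on either side, such as `terminates(ToyType{4, 8}, ToyType{4, 1})`, re-enters the same extend query until the guard fires.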

The "right" fix for this would be to teach BasicTTI how to model the full
legalization state machine, but several attempts at doing so have
resulted in dead ends or undesirable cost changes for targets I don't
understand.

This patch instead papers over the issue by avoiding the call to the
base class when dealing with an i1 source or dest. This doesn't
necessarily produce correct costs, but it should at least return
something semi-sensible and not crash.
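
As a rough illustration of what the special case in the diff returns, the sketch below mirrors the shape of the added expression, `LT.first * getRISCVInstructionCost(...) + LT.first - 1`, using plain integers. `isMaskCast` and `maskCastCost` are hypothetical helpers, `int` stands in for `InstructionCost`, and the sequence cost of 2 used in the usage note is an assumed placeholder; the real value comes from `getRISCVInstructionCost` and depends on the legalized type and the cost kind.

```cpp
#include <cassert>

// Hypothetical helper: true when the cast touches an i1 (mask) element type
// in the way the added switch checks for.
bool isMaskCast(bool IsExtend, bool IsTruncate, unsigned SrcEltBits,
                unsigned DstEltBits) {
  return (IsExtend && SrcEltBits == 1) || (IsTruncate && DstEltBits == 1);
}

// Same shape as the returned expression in the diff:
//   LT.first * <cost of the two-instruction sequence> + LT.first - 1
// Plain ints stand in for InstructionCost here.
int maskCastCost(int LegalizationFactor, int SequenceCost) {
  return LegalizationFactor * SequenceCost + LegalizationFactor - 1;
}
```

With a legalization factor of 1, `maskCastCost(1, 2)` reduces to the sequence cost alone, which has the same form as the expression the removed code returned. With a factor of 2, `maskCastCost(2, 2)` evaluates to 5, so the result grows with the factor rather than being handed to the base class.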

Fixes llvm#108708
preames authored Sep 27, 2024
1 parent 296901f commit 1a9569c
Showing 3 changed files with 158 additions and 115 deletions.
59 changes: 38 additions & 21 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1163,9 +1163,47 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
Dst->getScalarSizeInBits() > ST->getELen())
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);

int ISD = TLI->InstructionOpcodeToISD(Opcode);
assert(ISD && "Invalid opcode");
std::pair<InstructionCost, MVT> SrcLT = getTypeLegalizationCost(Src);
std::pair<InstructionCost, MVT> DstLT = getTypeLegalizationCost(Dst);

// Handle i1 source and dest cases *before* calling logic in BasicTTI.
// The shared implementation doesn't model vector widening during legalization
// and instead assumes scalarization. In order to scalarize an <N x i1>
// vector, we need to extend/trunc to/from i8. If we don't special case
// this, we can get an infinite recursion cycle.
switch (ISD) {
default:
break;
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
if (Src->getScalarSizeInBits() == 1) {
// We do not use vsext/vzext to extend from mask vector.
// Instead we use the following instructions to extend from mask vector:
// vmv.v.i v8, 0
// vmerge.vim v8, v8, -1, v0
return DstLT.first *
getRISCVInstructionCost({RISCV::VMV_V_I, RISCV::VMERGE_VIM},
DstLT.second, CostKind) +
DstLT.first - 1;
}
break;
case ISD::TRUNCATE:
if (Dst->getScalarSizeInBits() == 1) {
// We do not use several vncvt to truncate to mask vector. So we could
// not use PowDiff to calculate it.
// Instead we use the following instructions to truncate to mask vector:
// vand.vi v8, v8, 1
// vmsne.vi v0, v8, 0
return SrcLT.first *
getRISCVInstructionCost({RISCV::VAND_VI, RISCV::VMSNE_VI},
SrcLT.second, CostKind) +
SrcLT.first - 1;
}
break;
};

// Our actual lowering for the case where a wider legal type is available
// uses promotion to the wider type. This is reflected in the result of
// getTypeLegalizationCost, but BasicTTI assumes the widened cases are
@@ -1181,22 +1219,11 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
// The split cost is handled by the base getCastInstrCost
assert((SrcLT.first == 1) && (DstLT.first == 1) && "Illegal type");

int ISD = TLI->InstructionOpcodeToISD(Opcode);
assert(ISD && "Invalid opcode");

int PowDiff = (int)Log2_32(DstLT.second.getScalarSizeInBits()) -
(int)Log2_32(SrcLT.second.getScalarSizeInBits());
switch (ISD) {
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND: {
if (Src->getScalarSizeInBits() == 1) {
// We do not use vsext/vzext to extend from mask vector.
// Instead we use the following instructions to extend from mask vector:
// vmv.v.i v8, 0
// vmerge.vim v8, v8, -1, v0
return getRISCVInstructionCost({RISCV::VMV_V_I, RISCV::VMERGE_VIM},
DstLT.second, CostKind);
}
if ((PowDiff < 1) || (PowDiff > 3))
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
unsigned SExtOp[] = {RISCV::VSEXT_VF2, RISCV::VSEXT_VF4, RISCV::VSEXT_VF8};
@@ -1206,16 +1233,6 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
return getRISCVInstructionCost(Op, DstLT.second, CostKind);
}
case ISD::TRUNCATE:
if (Dst->getScalarSizeInBits() == 1) {
// We do not use several vncvt to truncate to mask vector. So we could
// not use PowDiff to calculate it.
// Instead we use the following instructions to truncate to mask vector:
// vand.vi v8, v8, 1
// vmsne.vi v0, v8, 0
return getRISCVInstructionCost({RISCV::VAND_VI, RISCV::VMSNE_VI},
SrcLT.second, CostKind);
}
[[fallthrough]];
case ISD::FP_EXTEND:
case ISD::FP_ROUND: {
// Counts of narrow/widen instructions.