diff --git a/llvm/lib/IR/ConstantFPRange.cpp b/llvm/lib/IR/ConstantFPRange.cpp index 3f63ca8e62c258..74c9797d969f9d 100644 --- a/llvm/lib/IR/ConstantFPRange.cpp +++ b/llvm/lib/IR/ConstantFPRange.cpp @@ -108,11 +108,114 @@ ConstantFPRange ConstantFPRange::getNonNaN(const fltSemantics &Sem) { /*MayBeQNaN=*/false, /*MayBeSNaN=*/false); } +/// Return true for ULT/UGT/OLT/OGT +static bool fcmpPredExcludesEqual(FCmpInst::Predicate Pred) { + return !(Pred & FCmpInst::FCMP_OEQ); +} + +/// Return [-inf, V) or [-inf, V] +static ConstantFPRange makeLessThan(APFloat V, FCmpInst::Predicate Pred) { + const fltSemantics &Sem = V.getSemantics(); + if (fcmpPredExcludesEqual(Pred)) { + if (V.isNegInfinity()) + return ConstantFPRange::getEmpty(Sem); + V.next(/*nextDown=*/true); + } + return ConstantFPRange::getNonNaN(APFloat::getInf(Sem, /*Negative=*/true), + std::move(V)); +} + +/// Return (V, +inf] or [V, +inf] +static ConstantFPRange makeGreaterThan(APFloat V, FCmpInst::Predicate Pred) { + const fltSemantics &Sem = V.getSemantics(); + if (fcmpPredExcludesEqual(Pred)) { + if (V.isPosInfinity()) + return ConstantFPRange::getEmpty(Sem); + V.next(/*nextDown=*/false); + } + return ConstantFPRange::getNonNaN(std::move(V), + APFloat::getInf(Sem, /*Negative=*/false)); +} + +/// Make sure that +0/-0 are both included in the range. +static ConstantFPRange extendZeroIfEqual(const ConstantFPRange &CR, + FCmpInst::Predicate Pred) { + if (fcmpPredExcludesEqual(Pred)) + return CR; + + APFloat Lower = CR.getLower(); + APFloat Upper = CR.getUpper(); + if (Lower.isPosZero()) + Lower = APFloat::getZero(Lower.getSemantics(), /*Negative=*/true); + if (Upper.isNegZero()) + Upper = APFloat::getZero(Upper.getSemantics(), /*Negative=*/false); + return ConstantFPRange(std::move(Lower), std::move(Upper), CR.containsQNaN(), + CR.containsSNaN()); +} + +static ConstantFPRange setNaNField(const ConstantFPRange &CR, + FCmpInst::Predicate Pred) { + bool ContainsNaN = FCmpInst::isUnordered(Pred); + return ConstantFPRange(CR.getLower(), CR.getUpper(), + /*MayBeQNaN=*/ContainsNaN, /*MayBeSNaN=*/ContainsNaN); +} + ConstantFPRange ConstantFPRange::makeAllowedFCmpRegion(FCmpInst::Predicate Pred, const ConstantFPRange &Other) { - // TODO - return getFull(Other.getSemantics()); + if (Other.isEmptySet()) + return Other; + if (Other.containsNaN() && FCmpInst::isUnordered(Pred)) + return getFull(Other.getSemantics()); + if (Other.isNaNOnly() && FCmpInst::isOrdered(Pred)) + return getEmpty(Other.getSemantics()); + + switch (Pred) { + case FCmpInst::FCMP_TRUE: + return getFull(Other.getSemantics()); + case FCmpInst::FCMP_FALSE: + return getEmpty(Other.getSemantics()); + case FCmpInst::FCMP_ORD: + return getNonNaN(Other.getSemantics()); + case FCmpInst::FCMP_UNO: + return getNaNOnly(Other.getSemantics(), /*MayBeQNaN=*/true, + /*MayBeSNaN=*/true); + case FCmpInst::FCMP_OEQ: + case FCmpInst::FCMP_UEQ: + return setNaNField(extendZeroIfEqual(Other, Pred), Pred); + case FCmpInst::FCMP_ONE: + case FCmpInst::FCMP_UNE: + if (const APFloat *SingleElement = + Other.getSingleElement(/*ExcludesNaN=*/true)) { + const fltSemantics &Sem = SingleElement->getSemantics(); + if (SingleElement->isPosInfinity()) + return setNaNField( + getNonNaN(APFloat::getInf(Sem, /*Negative=*/true), + APFloat::getLargest(Sem, /*Negative=*/false)), + Pred); + if (SingleElement->isNegInfinity()) + return setNaNField( + getNonNaN(APFloat::getLargest(Sem, /*Negative=*/true), + APFloat::getInf(Sem, /*Negative=*/false)), + Pred); + } + return Pred == FCmpInst::FCMP_ONE ? getNonNaN(Other.getSemantics()) + : getFull(Other.getSemantics()); + case FCmpInst::FCMP_OLT: + case FCmpInst::FCMP_OLE: + case FCmpInst::FCMP_ULT: + case FCmpInst::FCMP_ULE: + return setNaNField( + extendZeroIfEqual(makeLessThan(Other.getUpper(), Pred), Pred), Pred); + case FCmpInst::FCMP_OGT: + case FCmpInst::FCMP_OGE: + case FCmpInst::FCMP_UGT: + case FCmpInst::FCMP_UGE: + return setNaNField( + extendZeroIfEqual(makeGreaterThan(Other.getLower(), Pred), Pred), Pred); + default: + llvm_unreachable("Unexpected predicate"); + } } ConstantFPRange diff --git a/llvm/unittests/IR/ConstantFPRangeTest.cpp b/llvm/unittests/IR/ConstantFPRangeTest.cpp index d228651b5129cc..17a08207fe1ba0 100644 --- a/llvm/unittests/IR/ConstantFPRangeTest.cpp +++ b/llvm/unittests/IR/ConstantFPRangeTest.cpp @@ -161,6 +161,19 @@ static void EnumerateValuesInConstantFPRange(const ConstantFPRange &CR, } } +template +static bool AnyOfValueInConstantFPRange(const ConstantFPRange &CR, Fn TestFn) { + const fltSemantics &Sem = CR.getSemantics(); + unsigned Bits = APFloat::semanticsSizeInBits(Sem); + assert(Bits < 32 && "Too many bits"); + for (unsigned I = 0, E = (1U << Bits) - 1; I != E; ++I) { + APFloat V(Sem, APInt(Bits, I)); + if (CR.contains(V) && TestFn(V)) + return true; + } + return false; +} + TEST_F(ConstantFPRangeTest, Basics) { EXPECT_TRUE(Full.isFullSet()); EXPECT_FALSE(Full.isEmptySet()); @@ -429,4 +442,32 @@ TEST_F(ConstantFPRangeTest, MismatchedSemantics) { #endif #endif +TEST_F(ConstantFPRangeTest, makeAllowedFCmpRegion) { + for (auto Pred : FCmpInst::predicates()) { + EnumerateConstantFPRanges( + [Pred](const ConstantFPRange &CR) { + ConstantFPRange Res = + ConstantFPRange::makeAllowedFCmpRegion(Pred, CR); + ConstantFPRange Optimal = + ConstantFPRange::getEmpty(CR.getSemantics()); + EnumerateValuesInConstantFPRange( + ConstantFPRange::getFull(CR.getSemantics()), + [&](const APFloat &V) { + if (AnyOfValueInConstantFPRange(CR, [&](const APFloat &U) { + return FCmpInst::compare(V, U, Pred); + })) + Optimal = Optimal.unionWith(ConstantFPRange(V)); + }); + + EXPECT_TRUE(Res.contains(Optimal)) + << "Wrong result for makeAllowedFCmpRegion(" << Pred << ", " << CR + << "). Expected " << Optimal << ", but got " << Res; + EXPECT_EQ(Res, Optimal) + << "Suboptimal result for makeAllowedFCmpRegion(" << Pred << ", " + << CR << ")"; + }, + /*Exhaustive=*/false); + } +} + } // anonymous namespace