Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ValueTracking] Handle icmp pred (trunc X), C in computeKnownBitsFromCmp #82803

Merged
merged 2 commits into from
Mar 6, 2024

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Feb 23, 2024

This patch handles the pattern icmp pred (trunc X), C in computeKnownBitsFromCmp to infer low bits of X from dominating conditions.

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 23, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes

This patch handles the pattern icmp pred (trunc X), C in computeKnownBitsFromCmp to infer low bits of X from dominating conditions.


Full diff: https://github.com/llvm/llvm-project/pull/82803.diff

3 Files Affected:

  • (modified) llvm/lib/Analysis/DomConditionCache.cpp (+2-1)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+16-4)
  • (modified) llvm/test/Transforms/InstCombine/known-bits.ll (+124)
diff --git a/llvm/lib/Analysis/DomConditionCache.cpp b/llvm/lib/Analysis/DomConditionCache.cpp
index 274f3ff44b2a6f..b1d9cdff943b80 100644
--- a/llvm/lib/Analysis/DomConditionCache.cpp
+++ b/llvm/lib/Analysis/DomConditionCache.cpp
@@ -27,7 +27,8 @@ static void findAffectedValues(Value *Cond,
 
       // Peek through unary operators to find the source of the condition.
       Value *Op;
-      if (match(I, m_PtrToInt(m_Value(Op)))) {
+      if (match(I,
+                m_CombineOr(m_PtrToInt(m_Value(Op)), m_Trunc(m_Value(Op))))) {
         if (isa<Instruction>(Op) || isa<Argument>(Op))
           Affected.push_back(Op);
       }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 04f317228b3ea7..dbac5bc8aa7626 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -717,10 +717,22 @@ static void computeKnownBitsFromCond(const Value *V, Value *Cond,
     computeKnownBitsFromCond(V, B, Known, Depth + 1, SQ, Invert);
   }
 
-  if (auto *Cmp = dyn_cast<ICmpInst>(Cond))
-    computeKnownBitsFromCmp(
-        V, Invert ? Cmp->getInversePredicate() : Cmp->getPredicate(),
-        Cmp->getOperand(0), Cmp->getOperand(1), Known, SQ);
+  if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
+    ICmpInst::Predicate Pred =
+        Invert ? Cmp->getInversePredicate() : Cmp->getPredicate();
+    Value *LHS = Cmp->getOperand(0);
+    Value *RHS = Cmp->getOperand(1);
+
+    // Handle icmp pred (trunc V), C
+    if (match(LHS, m_Trunc(m_Specific(V)))) {
+      KnownBits DstKnown(LHS->getType()->getScalarSizeInBits());
+      computeKnownBitsFromCmp(LHS, Pred, LHS, RHS, DstKnown, SQ);
+      Known = Known.unionWith(DstKnown.anyext(Known.getBitWidth()));
+      return;
+    }
+
+    computeKnownBitsFromCmp(V, Pred, LHS, RHS, Known, SQ);
+  }
 }
 
 void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
diff --git a/llvm/test/Transforms/InstCombine/known-bits.ll b/llvm/test/Transforms/InstCombine/known-bits.ll
index 246579cc4cd0c0..3581256131a270 100644
--- a/llvm/test/Transforms/InstCombine/known-bits.ll
+++ b/llvm/test/Transforms/InstCombine/known-bits.ll
@@ -284,5 +284,129 @@ exit:
   ret i8 %or2
 }
 
+define i32 @test_icmp_trunc1(i32 %x){
+; CHECK-LABEL: @test_icmp_trunc1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[Y:%.*]] = trunc i32 [[X:%.*]] to i16
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[Y]], 7
+; CHECK-NEXT:    br i1 [[CMP]], label [[THEN:%.*]], label [[ELSE:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    ret i32 7
+; CHECK:       else:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %y = trunc i32 %x to i16
+  %cmp = icmp eq i16 %y, 7
+  br i1 %cmp, label %then, label %else
+then:
+  %z = and i32 %x, 15
+  ret i32 %z
+else:
+  ret i32 0
+}
+
+define i64 @test_icmp_trunc2(i64 %x) {
+; CHECK-LABEL: @test_icmp_trunc2(
+; CHECK-NEXT:    [[CONV:%.*]] = trunc i64 [[X:%.*]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[CONV]], 12
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[SEXT:%.*]] = and i64 [[X]], 2147483647
+; CHECK-NEXT:    ret i64 [[SEXT]]
+; CHECK:       if.else:
+; CHECK-NEXT:    ret i64 0
+;
+  %conv = trunc i64 %x to i32
+  %cmp = icmp sgt i32 %conv, 12
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  %sext = shl i64 %x, 32
+  %ret = ashr exact i64 %sext, 32
+  ret i64 %ret
+if.else:
+  ret i64 0
+}
+
+define i64 @test_icmp_trunc3(i64 %n) {
+; CHECK-LABEL: @test_icmp_trunc3(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CONV:%.*]] = trunc i64 [[N:%.*]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[CONV]], 96
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[RET:%.*]] = and i64 [[N]], 127
+; CHECK-NEXT:    ret i64 [[RET]]
+; CHECK:       if.else:
+; CHECK-NEXT:    ret i64 0
+;
+entry:
+  %conv = trunc i64 %n to i32
+  %cmp = icmp ult i32 %conv, 96
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  %ret = and i64 %n, 4294967295
+  ret i64 %ret
+
+if.else:
+  ret i64 0
+}
+
+define i8 @test_icmp_trunc4(i64 %n) {
+; CHECK-LABEL: @test_icmp_trunc4(
+; CHECK-NEXT:    [[CONV:%.*]] = trunc i64 [[N:%.*]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[CONV]], 10
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[CONV2:%.*]] = trunc i64 [[N]] to i8
+; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i8 [[CONV2]], 48
+; CHECK-NEXT:    ret i8 [[ADD]]
+; CHECK:       if.else:
+; CHECK-NEXT:    ret i8 0
+;
+  %conv = trunc i64 %n to i32
+  %cmp = icmp ult i32 %conv, 10
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  %conv2 = trunc i64 %n to i8
+  %add = add i8 %conv2, 48
+  ret i8 %add
+
+if.else:
+  ret i8 0
+}
+
+define i64 @test_icmp_trunc5(i64 %n) {
+; CHECK-LABEL: @test_icmp_trunc5(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SHR:%.*]] = ashr i64 [[N:%.*]], 47
+; CHECK-NEXT:    [[CONV1:%.*]] = trunc i64 [[SHR]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[CONV1]], -13
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[TMP0:%.*]] = and i64 [[SHR]], 15
+; CHECK-NEXT:    [[NOT:%.*]] = xor i64 [[TMP0]], 15
+; CHECK-NEXT:    ret i64 [[NOT]]
+; CHECK:       if.else:
+; CHECK-NEXT:    ret i64 13
+;
+entry:
+  %shr = ashr i64 %n, 47
+  %conv1 = trunc i64 %shr to i32
+  %cmp = icmp ugt i32 %conv1, -13
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  %and = and i64 %shr, 4294967295
+  %not = xor i64 %and, 4294967295
+  ret i64 %not
+
+if.else:
+  ret i64 13
+}
+
 declare void @use(i1)
 declare void @sink(i8)

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 23, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Yingwei Zheng (dtcxzyw)

Changes

This patch handles the pattern icmp pred (trunc X), C in computeKnownBitsFromCmp to infer low bits of X from dominating conditions.


Full diff: https://github.com/llvm/llvm-project/pull/82803.diff

3 Files Affected:

  • (modified) llvm/lib/Analysis/DomConditionCache.cpp (+2-1)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+16-4)
  • (modified) llvm/test/Transforms/InstCombine/known-bits.ll (+124)
diff --git a/llvm/lib/Analysis/DomConditionCache.cpp b/llvm/lib/Analysis/DomConditionCache.cpp
index 274f3ff44b2a6f..b1d9cdff943b80 100644
--- a/llvm/lib/Analysis/DomConditionCache.cpp
+++ b/llvm/lib/Analysis/DomConditionCache.cpp
@@ -27,7 +27,8 @@ static void findAffectedValues(Value *Cond,
 
       // Peek through unary operators to find the source of the condition.
       Value *Op;
-      if (match(I, m_PtrToInt(m_Value(Op)))) {
+      if (match(I,
+                m_CombineOr(m_PtrToInt(m_Value(Op)), m_Trunc(m_Value(Op))))) {
         if (isa<Instruction>(Op) || isa<Argument>(Op))
           Affected.push_back(Op);
       }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 04f317228b3ea7..dbac5bc8aa7626 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -717,10 +717,22 @@ static void computeKnownBitsFromCond(const Value *V, Value *Cond,
     computeKnownBitsFromCond(V, B, Known, Depth + 1, SQ, Invert);
   }
 
-  if (auto *Cmp = dyn_cast<ICmpInst>(Cond))
-    computeKnownBitsFromCmp(
-        V, Invert ? Cmp->getInversePredicate() : Cmp->getPredicate(),
-        Cmp->getOperand(0), Cmp->getOperand(1), Known, SQ);
+  if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
+    ICmpInst::Predicate Pred =
+        Invert ? Cmp->getInversePredicate() : Cmp->getPredicate();
+    Value *LHS = Cmp->getOperand(0);
+    Value *RHS = Cmp->getOperand(1);
+
+    // Handle icmp pred (trunc V), C
+    if (match(LHS, m_Trunc(m_Specific(V)))) {
+      KnownBits DstKnown(LHS->getType()->getScalarSizeInBits());
+      computeKnownBitsFromCmp(LHS, Pred, LHS, RHS, DstKnown, SQ);
+      Known = Known.unionWith(DstKnown.anyext(Known.getBitWidth()));
+      return;
+    }
+
+    computeKnownBitsFromCmp(V, Pred, LHS, RHS, Known, SQ);
+  }
 }
 
 void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
diff --git a/llvm/test/Transforms/InstCombine/known-bits.ll b/llvm/test/Transforms/InstCombine/known-bits.ll
index 246579cc4cd0c0..3581256131a270 100644
--- a/llvm/test/Transforms/InstCombine/known-bits.ll
+++ b/llvm/test/Transforms/InstCombine/known-bits.ll
@@ -284,5 +284,129 @@ exit:
   ret i8 %or2
 }
 
+define i32 @test_icmp_trunc1(i32 %x){
+; CHECK-LABEL: @test_icmp_trunc1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[Y:%.*]] = trunc i32 [[X:%.*]] to i16
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[Y]], 7
+; CHECK-NEXT:    br i1 [[CMP]], label [[THEN:%.*]], label [[ELSE:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    ret i32 7
+; CHECK:       else:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %y = trunc i32 %x to i16
+  %cmp = icmp eq i16 %y, 7
+  br i1 %cmp, label %then, label %else
+then:
+  %z = and i32 %x, 15
+  ret i32 %z
+else:
+  ret i32 0
+}
+
+define i64 @test_icmp_trunc2(i64 %x) {
+; CHECK-LABEL: @test_icmp_trunc2(
+; CHECK-NEXT:    [[CONV:%.*]] = trunc i64 [[X:%.*]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[CONV]], 12
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[SEXT:%.*]] = and i64 [[X]], 2147483647
+; CHECK-NEXT:    ret i64 [[SEXT]]
+; CHECK:       if.else:
+; CHECK-NEXT:    ret i64 0
+;
+  %conv = trunc i64 %x to i32
+  %cmp = icmp sgt i32 %conv, 12
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  %sext = shl i64 %x, 32
+  %ret = ashr exact i64 %sext, 32
+  ret i64 %ret
+if.else:
+  ret i64 0
+}
+
+define i64 @test_icmp_trunc3(i64 %n) {
+; CHECK-LABEL: @test_icmp_trunc3(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CONV:%.*]] = trunc i64 [[N:%.*]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[CONV]], 96
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[RET:%.*]] = and i64 [[N]], 127
+; CHECK-NEXT:    ret i64 [[RET]]
+; CHECK:       if.else:
+; CHECK-NEXT:    ret i64 0
+;
+entry:
+  %conv = trunc i64 %n to i32
+  %cmp = icmp ult i32 %conv, 96
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  %ret = and i64 %n, 4294967295
+  ret i64 %ret
+
+if.else:
+  ret i64 0
+}
+
+define i8 @test_icmp_trunc4(i64 %n) {
+; CHECK-LABEL: @test_icmp_trunc4(
+; CHECK-NEXT:    [[CONV:%.*]] = trunc i64 [[N:%.*]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[CONV]], 10
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[CONV2:%.*]] = trunc i64 [[N]] to i8
+; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i8 [[CONV2]], 48
+; CHECK-NEXT:    ret i8 [[ADD]]
+; CHECK:       if.else:
+; CHECK-NEXT:    ret i8 0
+;
+  %conv = trunc i64 %n to i32
+  %cmp = icmp ult i32 %conv, 10
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  %conv2 = trunc i64 %n to i8
+  %add = add i8 %conv2, 48
+  ret i8 %add
+
+if.else:
+  ret i8 0
+}
+
+define i64 @test_icmp_trunc5(i64 %n) {
+; CHECK-LABEL: @test_icmp_trunc5(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SHR:%.*]] = ashr i64 [[N:%.*]], 47
+; CHECK-NEXT:    [[CONV1:%.*]] = trunc i64 [[SHR]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[CONV1]], -13
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[TMP0:%.*]] = and i64 [[SHR]], 15
+; CHECK-NEXT:    [[NOT:%.*]] = xor i64 [[TMP0]], 15
+; CHECK-NEXT:    ret i64 [[NOT]]
+; CHECK:       if.else:
+; CHECK-NEXT:    ret i64 13
+;
+entry:
+  %shr = ashr i64 %n, 47
+  %conv1 = trunc i64 %shr to i32
+  %cmp = icmp ugt i32 %conv1, -13
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  %and = and i64 %shr, 4294967295
+  %not = xor i64 %and, 4294967295
+  ret i64 %not
+
+if.else:
+  ret i64 13
+}
+
 declare void @use(i1)
 declare void @sink(i8)

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Feb 23, 2024
Known = Known.unionWith(DstKnown.anyext(Known.getBitWidth()));
return;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code all looks correct, i'm wondering, however, if this may become bugprone given the code "distance" between adding affected ops and actually supporting them.

Think there is any convenient way to fix that?

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@dtcxzyw dtcxzyw merged commit 3589cac into llvm:main Mar 6, 2024
4 checks passed
@dtcxzyw dtcxzyw deleted the perf/knownbits-icmp-trunc branch March 6, 2024 17:05
@nikic
Copy link
Contributor

nikic commented Mar 6, 2024

Hm, I wonder why this causes code size regressions, especially on mafft (https://llvm-compile-time-tracker.com/compare.php?from=0cbbcf1ef006ce13a1fa94960067723982ae955a&to=3589cacfa8da89b9b5051e4dba659caa575e6b3f&stat=size-text).

@dtcxzyw
Copy link
Member Author

dtcxzyw commented Mar 7, 2024

Hm, I wonder why this causes code size regressions, especially on mafft (https://llvm-compile-time-tracker.com/compare.php?from=0cbbcf1ef006ce13a1fa94960067723982ae955a&to=3589cacfa8da89b9b5051e4dba659caa575e6b3f&stat=size-text).

I will add CTMark to my benchmark.

dtcxzyw added a commit to dtcxzyw/llvm-project that referenced this pull request Mar 7, 2024
@dtcxzyw
Copy link
Member Author

dtcxzyw commented Mar 7, 2024

Hm, I wonder why this causes code size regressions, especially on mafft (https://llvm-compile-time-tracker.com/compare.php?from=0cbbcf1ef006ce13a1fa94960067723982ae955a&to=3589cacfa8da89b9b5051e4dba659caa575e6b3f&stat=size-text).

Reverting this patch cannot improve the code size.
plctlab/llvm-ci#1107 (comment)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants