Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Scalarizer][DirectX] support structs return types #111569

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

farzonl
Copy link
Member

@farzonl farzonl commented Oct 8, 2024

Based on this RFC: https://discourse.llvm.org/t/rfc-allow-the-scalarizer-pass-to-scalarize-vectors-returned-in-structs/82306

LLVM intrinsics do not support out params. To get around this limitation implementers will make intrinsics return structs to capture a return type and an out param. This implementation detail should not impact scalarization since these cases should be elementwise operations.

Three changes are needed.

  • The CallInst visitor needs to be updated to handle Structs
  • A new visitor is needed for ExtractValue instructions
  • finsh needs to be update to handle structs so that insert elements are properly propogated.

Testing changes

  • Add support for llvm.frexp
  • Add support for llvm.dx.splitdouble

fixes #111437

@llvmbot
Copy link
Collaborator

llvmbot commented Oct 8, 2024

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-directx

Author: Farzon Lotfi (farzonl)

Changes

Based on this RFC: https://discourse.llvm.org/t/rfc-allow-the-scalarizer-pass-to-scalarize-vectors-returned-in-structs/82306

LLVM intrinsics do not support out params. To get around this limitation implementers will make intrinsics return structs to capture a return type and an out param. This implementation detail should not impact scalarization since these cases should be elementwise operations.

Three changes are needed.

  • The CallInst visitor needs to be updated to handle Structs
  • A new visitor is needed for ExtractValue instructions
  • finsh needs to be update to handle structs so that insert elements are properly propogated.

Testing changes

  • Add support for llvm.frexp
  • Add support for llvm.dx.splitdouble

Full diff: https://github.com/llvm/llvm-project/pull/111569.diff

5 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+2)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp (+1)
  • (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+97-2)
  • (added) llvm/test/CodeGen/DirectX/split-double.ll (+40)
  • (added) llvm/test/Transforms/Scalarizer/frexp.ll (+67)
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index f2b9e286ebb476..5f0f856df8e2b0 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -86,5 +86,7 @@ def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
 def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
 def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>], 
+    [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
 def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 }
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index be714b5c87895a..4ddf39a4337df6 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -28,6 +28,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
   switch (ID) {
   case Intrinsic::dx_frac:
   case Intrinsic::dx_rsqrt:
+  case Intrinsic::dx_splitdouble:
     return true;
   default:
     return false;
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 72728c0f839e5d..d8b052061c1ad5 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -197,6 +197,23 @@ struct VectorLayout {
   uint64_t SplitSize = 0;
 };
 
+static bool isStructAllVectors(Type *Ty) {
+  if (!isa<StructType>(Ty))
+    return false;
+  if (Ty->getNumContainedTypes() < 1)
+    return false;
+  FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(0));
+  if (!VecTy)
+    return false;
+  unsigned VecSize = VecTy->getNumElements();
+  for (unsigned I = 1; I < Ty->getNumContainedTypes(); I++) {
+    VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(I));
+    if (!VecTy || VecSize != VecTy->getNumElements())
+      return false;
+  }
+  return true;
+}
+
 /// Concatenate the given fragments to a single vector value of the type
 /// described in @p VS.
 static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
@@ -276,6 +293,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
   bool visitBitCastInst(BitCastInst &BCI);
   bool visitInsertElementInst(InsertElementInst &IEI);
   bool visitExtractElementInst(ExtractElementInst &EEI);
+  bool visitExtractValueInst(ExtractValueInst &EVI);
   bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
   bool visitPHINode(PHINode &PHI);
   bool visitLoadInst(LoadInst &LI);
@@ -667,6 +685,11 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
 bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
   if (isTriviallyVectorizable(ID))
     return true;
+  // TODO: investigate vectorizable frexp
+  switch (ID) {
+  case Intrinsic::frexp:
+    return true;
+  }
   return Intrinsic::isTargetIntrinsic(ID) &&
          TTI->isTargetIntrinsicTriviallyScalarizable(ID);
 }
@@ -674,7 +697,13 @@ bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
 /// If a call to a vector typed intrinsic function, split into a scalar call per
 /// element if possible for the intrinsic.
 bool ScalarizerVisitor::splitCall(CallInst &CI) {
-  std::optional<VectorSplit> VS = getVectorSplit(CI.getType());
+  Type *CallType = CI.getType();
+  bool AreAllVectors = isStructAllVectors(CallType);
+  std::optional<VectorSplit> VS;
+  if (AreAllVectors)
+    VS = getVectorSplit(CallType->getContainedType(0));
+  else
+    VS = getVectorSplit(CallType);
   if (!VS)
     return false;
 
@@ -699,6 +728,18 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
   if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
     Tys.push_back(VS->SplitTy);
 
+  if (AreAllVectors) {
+    Type *PrevType = CallType->getContainedType(0);
+    Type *CallType = CI.getType();
+    for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
+      Type *CurrType = cast<FixedVectorType>(CallType->getContainedType(I));
+      if (PrevType != CurrType) {
+        std::optional<VectorSplit> CurrVS = getVectorSplit(CurrType);
+        Tys.push_back(CurrVS->SplitTy);
+        PrevType = CurrType;
+      }
+    }
+  }
   // Assumes that any vector type has the same number of elements as the return
   // vector type, which is true for all current intrinsics.
   for (unsigned I = 0; I != NumArgs; ++I) {
@@ -1029,6 +1070,31 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
   return true;
 }
 
+bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
+  Value *Op = EVI.getOperand(0);
+  Type *OpTy = Op->getType();
+  ValueVector Res;
+  if (!isStructAllVectors(OpTy))
+    return false;
+  Type *VecType = cast<FixedVectorType>(OpTy->getContainedType(0));
+  std::optional<VectorSplit> VS = getVectorSplit(VecType);
+  if (!VS)
+    return false;
+  IRBuilder<> Builder(&EVI);
+  Scatterer Op0 = scatter(&EVI, Op, *VS);
+  assert(!EVI.getIndices().empty() && "Make sure an index exists");
+  // Note for our use case we only care about the top level index.
+  unsigned Index = EVI.getIndices()[0];
+  for (unsigned OpIdx = 0; OpIdx < Op0.size(); ++OpIdx) {
+    Value *ResElem = Builder.CreateExtractValue(
+        Op0[OpIdx], Index, EVI.getName() + ".elem" + std::to_string(Index));
+    Res.push_back(ResElem);
+  }
+
+  gather(&EVI, Res, *VS);
+  return true;
+}
+
 bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
   std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
   if (!VS)
@@ -1195,7 +1261,7 @@ bool ScalarizerVisitor::finish() {
     if (!Op->use_empty()) {
       // The value is still needed, so recreate it using a series of
       // insertelements and/or shufflevectors.
-      Value *Res;
+      Value *Res = nullptr;
       if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
         BasicBlock *BB = Op->getParent();
         IRBuilder<> Builder(Op);
@@ -1208,6 +1274,35 @@ bool ScalarizerVisitor::finish() {
         Res = concatenate(Builder, CV, VS, Op->getName());
 
         Res->takeName(Op);
+      } else if (auto *Ty = dyn_cast<StructType>(Op->getType())) {
+        BasicBlock *BB = Op->getParent();
+        IRBuilder<> Builder(Op);
+        if (isa<PHINode>(Op))
+          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
+
+        // Iterate over each element in the struct
+        unsigned NumOfStructElements = Ty->getNumElements();
+        SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
+        for (unsigned I = 0; I < NumOfStructElements; ++I) {
+          for (auto *CVelem : CV) {
+            Value *Elem = Builder.CreateExtractValue(
+                CVelem, I, Op->getName() + ".elem" + std::to_string(I));
+            ElemCV[I].push_back(Elem);
+          }
+        }
+        Res = PoisonValue::get(Ty);
+        for (unsigned I = 0; I < NumOfStructElements; ++I) {
+          Type *ElemTy = Ty->getElementType(I);
+          assert(isa<FixedVectorType>(ElemTy) &&
+                 "Only Structs of all FixedVectorType supported");
+          VectorSplit VS = *getVectorSplit(ElemTy);
+          assert(VS.NumFragments == CV.size());
+
+          Value *ConcatenatedVector =
+              concatenate(Builder, ElemCV[I], VS, Op->getName());
+          Res = Builder.CreateInsertValue(Res, ConcatenatedVector, I,
+                                          Op->getName() + ".insert");
+        }
       } else {
         assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
         Res = CV[0];
diff --git a/llvm/test/CodeGen/DirectX/split-double.ll b/llvm/test/CodeGen/DirectX/split-double.ll
new file mode 100644
index 00000000000000..9b70e87ba4794e
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/split-double.ll
@@ -0,0 +1,40 @@
+; RUN: opt -passes='function(scalarizer<load-store>)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; CHECK-LABEL: @test_vector_double_split_void
+define void @test_vector_double_split_void(<2 x double> noundef %d) {
+  ; CHECK: [[ee0:%.*]] = extractelement <2 x double> %d, i64 0
+  ; CHECK: [[ie0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee0]])
+  ; CHECK: [[ee1:%.*]] = extractelement <2 x double> %d, i64 1
+  ; CHECK: [[ie1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee1]])
+  ; CHECK-NOT: extractvalue { i32, i32 } {{.*}}, 0
+  ; CHECK-NOT: insertelement <2 x i32> {{.*}}, i32 {{.*}}, i64 0
+  %hlsl.asuint = call { <2 x i32>, <2 x i32> }  @llvm.dx.splitdouble.v2i32(<2 x double> %d)
+  ret void
+}
+
+; CHECK-LABEL: @test_vector_double_split
+define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) {
+  ; CHECK: [[ee0:%.*]] = extractelement <3 x double> %d, i64 0
+  ; CHECK: [[ie0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee0]])
+  ; CHECK: [[ee1:%.*]] = extractelement <3 x double> %d, i64 1
+  ; CHECK: [[ie1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee1]])
+  ; CHECK: [[ee2:%.*]] = extractelement <3 x double> %d, i64 2
+  ; CHECK: [[ie2:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee2]])
+  ; CHECK: [[ev00:%.*]] = extractvalue { i32, i32 } [[ie0]], 0
+  ; CHECK: [[ev01:%.*]] = extractvalue { i32, i32 } [[ie1]], 0
+  ; CHECK: [[ev02:%.*]] = extractvalue { i32, i32 } [[ie2]], 0
+  ; CHECK: [[ev10:%.*]] = extractvalue { i32, i32 } [[ie0]], 1
+  ; CHECK: [[ev11:%.*]] = extractvalue { i32, i32 } [[ie1]], 1
+  ; CHECK: [[ev12:%.*]] = extractvalue { i32, i32 } [[ie2]], 1
+  ; CHECK: [[add1:%.*]] = add i32 [[ev00]], [[ev10]]
+  ; CHECK: [[add2:%.*]] = add i32 [[ev01]], [[ev11]]
+  ; CHECK: [[add3:%.*]] = add i32 [[ev02]], [[ev12]]
+  ; CHECK: insertelement <3 x i32> poison, i32 [[add1]], i64 0
+  ; CHECK: insertelement <3 x i32> %{{.*}}, i32 [[add2]], i64 1
+  ; CHECK: insertelement <3 x i32> %{{.*}}, i32 [[add3]], i64 2
+  %hlsl.asuint = call { <3 x i32>, <3 x i32> }  @llvm.dx.splitdouble.v3i32(<3 x double> %d)
+  %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
+  %2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
+  %3 = add <3 x i32> %1, %2
+  ret <3 x i32> %3
+}
diff --git a/llvm/test/Transforms/Scalarizer/frexp.ll b/llvm/test/Transforms/Scalarizer/frexp.ll
new file mode 100644
index 00000000000000..48159b45c18960
--- /dev/null
+++ b/llvm/test/Transforms/Scalarizer/frexp.ll
@@ -0,0 +1,67 @@
+; RUN: opt %s -passes='function(scalarizer<load-store>)' -S | FileCheck %s
+
+; CHECK-LABEL: @test_vector_half_frexp_half
+define noundef <2 x half> @test_vector_half_frexp_half(<2 x half> noundef %h) {
+  ; CHECK: [[ee0:%.*]] = extractelement <2 x half> %h, i64 0
+  ; CHECK-NEXT: [[ie0:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee0]])
+  ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x half> %h, i64 1
+  ; CHECK-NEXT: [[ie1:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee1]])
+  ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { half, i32 } [[ie0]], 0
+  ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { half, i32 } [[ie1]], 0
+  ; CHECK-NEXT: insertelement <2 x half> poison, half [[ev00]], i64 0
+  ; CHECK-NEXT: insertelement <2 x half> %{{.*}}, half [[ev01]], i64 1
+  %r =  call { <2 x half>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x half> %h)
+  %e0 = extractvalue { <2 x half>, <2 x i32> } %r, 0
+  ret <2 x half> %e0
+}
+
+; CHECK-LABEL: @test_vector_half_frexp_int
+define noundef <2 x i32> @test_vector_half_frexp_int(<2 x half> noundef %h) {
+  ; CHECK: [[ee0:%.*]] = extractelement <2 x half> %h, i64 0
+  ; CHECK-NEXT: [[ie0:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee0]])
+  ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x half> %h, i64 1
+  ; CHECK-NEXT: [[ie1:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee1]])
+  ; CHECK-NEXT: [[ev10:%.*]] = extractvalue { half, i32 } [[ie0]], 1
+  ; CHECK-NEXT: [[ev11:%.*]] = extractvalue { half, i32 } [[ie1]], 1
+  ; CHECK-NEXT: insertelement <2 x i32> poison, i32 [[ev10]], i64 0
+  ; CHECK-NEXT: insertelement <2 x i32> %{{.*}}, i32 [[ev11]], i64 1
+  %r =  call { <2 x half>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x half> %h)
+  %e1 = extractvalue { <2 x half>, <2 x i32> } %r, 1
+  ret <2 x i32> %e1
+}
+
+; CHECK-LABEL: @test_vector_float_frexp_int
+define noundef <2 x float> @test_vector_float_frexp_int(<2 x float> noundef %f) {
+  ; CHECK: [[ee0:%.*]] = extractelement <2 x float> %f, i64 0
+  ; CHECK-NEXT: [[ie0:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[ee0]])
+  ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x float> %f, i64 1
+  ; CHECK-NEXT: [[ie1:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[ee1]])
+  ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { float, i32 } [[ie0]], 0
+  ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { float, i32 } [[ie1]], 0
+  ; CHECK-NEXT: insertelement <2 x float> poison, float [[ev00]], i64 0
+  ; CHECK-NEXT: insertelement <2 x float> %{{.*}}, float [[ev01]], i64 1
+  ; CHECK-NEXT: extractvalue { float, i32 } [[ie0]], 1
+  ; CHECK-NEXT: extractvalue { float, i32 } [[ie1]], 1
+  %1 =  call { <2 x float>, <2 x i32> } @llvm.frexp.v2f16.v2i32(<2 x float> %f)
+  %2 = extractvalue { <2 x float>, <2 x i32> } %1, 0
+  %3 = extractvalue { <2 x float>, <2 x i32> } %1, 1
+  ret <2 x float> %2
+}
+
+; CHECK-LABEL: @test_vector_double_frexp_int
+define noundef <2 x double> @test_vector_double_frexp_int(<2 x double> noundef %d) {
+  ; CHECK: [[ee0:%.*]] = extractelement <2 x double> %d, i64 0
+  ; CHECK-NEXT: [[ie0:%.*]] = call { double, i32 } @llvm.frexp.f64.i32(double [[ee0]])
+  ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x double> %d, i64 1
+  ; CHECK-NEXT: [[ie1:%.*]] = call { double, i32 } @llvm.frexp.f64.i32(double [[ee1]])
+  ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { double, i32 } [[ie0]], 0
+  ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { double, i32 } [[ie1]], 0
+  ; CHECK-NEXT: insertelement <2 x double> poison, double [[ev00]], i64 0
+  ; CHECK-NEXT: insertelement <2 x double> %{{.*}}, double [[ev01]], i64 1
+  ; CHECK-NEXT: extractvalue { double, i32 } [[ie0]], 1
+  ; CHECK-NEXT: extractvalue { double, i32 } [[ie1]], 1
+  %1 =  call { <2 x double>, <2 x i32> } @llvm.frexp.v2f64.v2i32(<2 x double> %d)
+  %2 = extractvalue { <2 x double>, <2 x i32> } %1, 0
+  %3 = extractvalue { <2 x double>, <2 x i32> } %1, 1
+  ret <2 x double> %2
+}

@farzonl farzonl self-assigned this Oct 8, 2024
@@ -667,14 +685,25 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
if (isTriviallyVectorizable(ID))
return true;
// TODO: investigate vectorizable frexp
switch (ID) {
case Intrinsic::frexp:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making this vectorizable was out of scope for this work so did not add frexp to isTriviallyVectorizable. We should do a follow up task for this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

[DirectX][Scalarizer] Add support to scalarize struct of vector return types
2 participants