-
Notifications
You must be signed in to change notification settings - Fork 11.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Scalarizer][DirectX] support structs return types #111569
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-backend-directx Author: Farzon Lotfi (farzonl) Changes: Based on this RFC: https://discourse.llvm.org/t/rfc-allow-the-scalarizer-pass-to-scalarize-vectors-returned-in-structs/82306 LLVM intrinsics do not support out params. To get around this limitation implementers will make intrinsics return structs to capture a return type and an out param. This implementation detail should not impact scalarization since these cases should be elementwise operations. Three changes are needed.
Testing changes
Full diff: https://github.com/llvm/llvm-project/pull/111569.diff 5 Files Affected:
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index f2b9e286ebb476..5f0f856df8e2b0 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -86,5 +86,7 @@ def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>],
+ [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
}
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index be714b5c87895a..4ddf39a4337df6 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -28,6 +28,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
switch (ID) {
case Intrinsic::dx_frac:
case Intrinsic::dx_rsqrt:
+ case Intrinsic::dx_splitdouble:
return true;
default:
return false;
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 72728c0f839e5d..d8b052061c1ad5 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -197,6 +197,23 @@ struct VectorLayout {
uint64_t SplitSize = 0;
};
+static bool isStructAllVectors(Type *Ty) {
+ if (!isa<StructType>(Ty))
+ return false;
+ if (Ty->getNumContainedTypes() < 1)
+ return false;
+ FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(0));
+ if (!VecTy)
+ return false;
+ unsigned VecSize = VecTy->getNumElements();
+ for (unsigned I = 1; I < Ty->getNumContainedTypes(); I++) {
+ VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(I));
+ if (!VecTy || VecSize != VecTy->getNumElements())
+ return false;
+ }
+ return true;
+}
+
/// Concatenate the given fragments to a single vector value of the type
/// described in @p VS.
static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
@@ -276,6 +293,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
bool visitBitCastInst(BitCastInst &BCI);
bool visitInsertElementInst(InsertElementInst &IEI);
bool visitExtractElementInst(ExtractElementInst &EEI);
+ bool visitExtractValueInst(ExtractValueInst &EVI);
bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
bool visitPHINode(PHINode &PHI);
bool visitLoadInst(LoadInst &LI);
@@ -667,6 +685,11 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
if (isTriviallyVectorizable(ID))
return true;
+ // TODO: investigate vectorizable frexp
+ switch (ID) {
+ case Intrinsic::frexp:
+ return true;
+ }
return Intrinsic::isTargetIntrinsic(ID) &&
TTI->isTargetIntrinsicTriviallyScalarizable(ID);
}
@@ -674,7 +697,13 @@ bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
/// If a call to a vector typed intrinsic function, split into a scalar call per
/// element if possible for the intrinsic.
bool ScalarizerVisitor::splitCall(CallInst &CI) {
- std::optional<VectorSplit> VS = getVectorSplit(CI.getType());
+ Type *CallType = CI.getType();
+ bool AreAllVectors = isStructAllVectors(CallType);
+ std::optional<VectorSplit> VS;
+ if (AreAllVectors)
+ VS = getVectorSplit(CallType->getContainedType(0));
+ else
+ VS = getVectorSplit(CallType);
if (!VS)
return false;
@@ -699,6 +728,18 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
Tys.push_back(VS->SplitTy);
+ if (AreAllVectors) {
+ Type *PrevType = CallType->getContainedType(0);
+ Type *CallType = CI.getType();
+ for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
+ Type *CurrType = cast<FixedVectorType>(CallType->getContainedType(I));
+ if (PrevType != CurrType) {
+ std::optional<VectorSplit> CurrVS = getVectorSplit(CurrType);
+ Tys.push_back(CurrVS->SplitTy);
+ PrevType = CurrType;
+ }
+ }
+ }
// Assumes that any vector type has the same number of elements as the return
// vector type, which is true for all current intrinsics.
for (unsigned I = 0; I != NumArgs; ++I) {
@@ -1029,6 +1070,31 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
return true;
}
+bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
+ Value *Op = EVI.getOperand(0);
+ Type *OpTy = Op->getType();
+ ValueVector Res;
+ if (!isStructAllVectors(OpTy))
+ return false;
+ Type *VecType = cast<FixedVectorType>(OpTy->getContainedType(0));
+ std::optional<VectorSplit> VS = getVectorSplit(VecType);
+ if (!VS)
+ return false;
+ IRBuilder<> Builder(&EVI);
+ Scatterer Op0 = scatter(&EVI, Op, *VS);
+ assert(!EVI.getIndices().empty() && "Make sure an index exists");
+ // Note for our use case we only care about the top level index.
+ unsigned Index = EVI.getIndices()[0];
+ for (unsigned OpIdx = 0; OpIdx < Op0.size(); ++OpIdx) {
+ Value *ResElem = Builder.CreateExtractValue(
+ Op0[OpIdx], Index, EVI.getName() + ".elem" + std::to_string(Index));
+ Res.push_back(ResElem);
+ }
+
+ gather(&EVI, Res, *VS);
+ return true;
+}
+
bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
if (!VS)
@@ -1195,7 +1261,7 @@ bool ScalarizerVisitor::finish() {
if (!Op->use_empty()) {
// The value is still needed, so recreate it using a series of
// insertelements and/or shufflevectors.
- Value *Res;
+ Value *Res = nullptr;
if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
BasicBlock *BB = Op->getParent();
IRBuilder<> Builder(Op);
@@ -1208,6 +1274,35 @@ bool ScalarizerVisitor::finish() {
Res = concatenate(Builder, CV, VS, Op->getName());
Res->takeName(Op);
+ } else if (auto *Ty = dyn_cast<StructType>(Op->getType())) {
+ BasicBlock *BB = Op->getParent();
+ IRBuilder<> Builder(Op);
+ if (isa<PHINode>(Op))
+ Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
+
+ // Iterate over each element in the struct
+ unsigned NumOfStructElements = Ty->getNumElements();
+ SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
+ for (unsigned I = 0; I < NumOfStructElements; ++I) {
+ for (auto *CVelem : CV) {
+ Value *Elem = Builder.CreateExtractValue(
+ CVelem, I, Op->getName() + ".elem" + std::to_string(I));
+ ElemCV[I].push_back(Elem);
+ }
+ }
+ Res = PoisonValue::get(Ty);
+ for (unsigned I = 0; I < NumOfStructElements; ++I) {
+ Type *ElemTy = Ty->getElementType(I);
+ assert(isa<FixedVectorType>(ElemTy) &&
+ "Only Structs of all FixedVectorType supported");
+ VectorSplit VS = *getVectorSplit(ElemTy);
+ assert(VS.NumFragments == CV.size());
+
+ Value *ConcatenatedVector =
+ concatenate(Builder, ElemCV[I], VS, Op->getName());
+ Res = Builder.CreateInsertValue(Res, ConcatenatedVector, I,
+ Op->getName() + ".insert");
+ }
} else {
assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
Res = CV[0];
diff --git a/llvm/test/CodeGen/DirectX/split-double.ll b/llvm/test/CodeGen/DirectX/split-double.ll
new file mode 100644
index 00000000000000..9b70e87ba4794e
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/split-double.ll
@@ -0,0 +1,40 @@
+; RUN: opt -passes='function(scalarizer<load-store>)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; CHECK-LABEL: @test_vector_double_split_void
+define void @test_vector_double_split_void(<2 x double> noundef %d) {
+ ; CHECK: [[ee0:%.*]] = extractelement <2 x double> %d, i64 0
+ ; CHECK: [[ie0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee0]])
+ ; CHECK: [[ee1:%.*]] = extractelement <2 x double> %d, i64 1
+ ; CHECK: [[ie1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee1]])
+ ; CHECK-NOT: extractvalue { i32, i32 } {{.*}}, 0
+ ; CHECK-NOT: insertelement <2 x i32> {{.*}}, i32 {{.*}}, i64 0
+ %hlsl.asuint = call { <2 x i32>, <2 x i32> } @llvm.dx.splitdouble.v2i32(<2 x double> %d)
+ ret void
+}
+
+; CHECK-LABEL: @test_vector_double_split
+define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) {
+ ; CHECK: [[ee0:%.*]] = extractelement <3 x double> %d, i64 0
+ ; CHECK: [[ie0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee0]])
+ ; CHECK: [[ee1:%.*]] = extractelement <3 x double> %d, i64 1
+ ; CHECK: [[ie1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee1]])
+ ; CHECK: [[ee2:%.*]] = extractelement <3 x double> %d, i64 2
+ ; CHECK: [[ie2:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee2]])
+ ; CHECK: [[ev00:%.*]] = extractvalue { i32, i32 } [[ie0]], 0
+ ; CHECK: [[ev01:%.*]] = extractvalue { i32, i32 } [[ie1]], 0
+ ; CHECK: [[ev02:%.*]] = extractvalue { i32, i32 } [[ie2]], 0
+ ; CHECK: [[ev10:%.*]] = extractvalue { i32, i32 } [[ie0]], 1
+ ; CHECK: [[ev11:%.*]] = extractvalue { i32, i32 } [[ie1]], 1
+ ; CHECK: [[ev12:%.*]] = extractvalue { i32, i32 } [[ie2]], 1
+ ; CHECK: [[add1:%.*]] = add i32 [[ev00]], [[ev10]]
+ ; CHECK: [[add2:%.*]] = add i32 [[ev01]], [[ev11]]
+ ; CHECK: [[add3:%.*]] = add i32 [[ev02]], [[ev12]]
+ ; CHECK: insertelement <3 x i32> poison, i32 [[add1]], i64 0
+ ; CHECK: insertelement <3 x i32> %{{.*}}, i32 [[add2]], i64 1
+ ; CHECK: insertelement <3 x i32> %{{.*}}, i32 [[add3]], i64 2
+ %hlsl.asuint = call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %d)
+ %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
+ %2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
+ %3 = add <3 x i32> %1, %2
+ ret <3 x i32> %3
+}
diff --git a/llvm/test/Transforms/Scalarizer/frexp.ll b/llvm/test/Transforms/Scalarizer/frexp.ll
new file mode 100644
index 00000000000000..48159b45c18960
--- /dev/null
+++ b/llvm/test/Transforms/Scalarizer/frexp.ll
@@ -0,0 +1,67 @@
+; RUN: opt %s -passes='function(scalarizer<load-store>)' -S | FileCheck %s
+
+; CHECK-LABEL: @test_vector_half_frexp_half
+define noundef <2 x half> @test_vector_half_frexp_half(<2 x half> noundef %h) {
+ ; CHECK: [[ee0:%.*]] = extractelement <2 x half> %h, i64 0
+ ; CHECK-NEXT: [[ie0:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee0]])
+ ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x half> %h, i64 1
+ ; CHECK-NEXT: [[ie1:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee1]])
+ ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { half, i32 } [[ie0]], 0
+ ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { half, i32 } [[ie1]], 0
+ ; CHECK-NEXT: insertelement <2 x half> poison, half [[ev00]], i64 0
+ ; CHECK-NEXT: insertelement <2 x half> %{{.*}}, half [[ev01]], i64 1
+ %r = call { <2 x half>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x half> %h)
+ %e0 = extractvalue { <2 x half>, <2 x i32> } %r, 0
+ ret <2 x half> %e0
+}
+
+; CHECK-LABEL: @test_vector_half_frexp_int
+define noundef <2 x i32> @test_vector_half_frexp_int(<2 x half> noundef %h) {
+ ; CHECK: [[ee0:%.*]] = extractelement <2 x half> %h, i64 0
+ ; CHECK-NEXT: [[ie0:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee0]])
+ ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x half> %h, i64 1
+ ; CHECK-NEXT: [[ie1:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee1]])
+ ; CHECK-NEXT: [[ev10:%.*]] = extractvalue { half, i32 } [[ie0]], 1
+ ; CHECK-NEXT: [[ev11:%.*]] = extractvalue { half, i32 } [[ie1]], 1
+ ; CHECK-NEXT: insertelement <2 x i32> poison, i32 [[ev10]], i64 0
+ ; CHECK-NEXT: insertelement <2 x i32> %{{.*}}, i32 [[ev11]], i64 1
+ %r = call { <2 x half>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x half> %h)
+ %e1 = extractvalue { <2 x half>, <2 x i32> } %r, 1
+ ret <2 x i32> %e1
+}
+
+; CHECK-LABEL: @test_vector_float_frexp_int
+define noundef <2 x float> @test_vector_float_frexp_int(<2 x float> noundef %f) {
+ ; CHECK: [[ee0:%.*]] = extractelement <2 x float> %f, i64 0
+ ; CHECK-NEXT: [[ie0:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[ee0]])
+ ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x float> %f, i64 1
+ ; CHECK-NEXT: [[ie1:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[ee1]])
+ ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { float, i32 } [[ie0]], 0
+ ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { float, i32 } [[ie1]], 0
+ ; CHECK-NEXT: insertelement <2 x float> poison, float [[ev00]], i64 0
+ ; CHECK-NEXT: insertelement <2 x float> %{{.*}}, float [[ev01]], i64 1
+ ; CHECK-NEXT: extractvalue { float, i32 } [[ie0]], 1
+ ; CHECK-NEXT: extractvalue { float, i32 } [[ie1]], 1
+ %1 = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f16.v2i32(<2 x float> %f)
+ %2 = extractvalue { <2 x float>, <2 x i32> } %1, 0
+ %3 = extractvalue { <2 x float>, <2 x i32> } %1, 1
+ ret <2 x float> %2
+}
+
+; CHECK-LABEL: @test_vector_double_frexp_int
+define noundef <2 x double> @test_vector_double_frexp_int(<2 x double> noundef %d) {
+ ; CHECK: [[ee0:%.*]] = extractelement <2 x double> %d, i64 0
+ ; CHECK-NEXT: [[ie0:%.*]] = call { double, i32 } @llvm.frexp.f64.i32(double [[ee0]])
+ ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x double> %d, i64 1
+ ; CHECK-NEXT: [[ie1:%.*]] = call { double, i32 } @llvm.frexp.f64.i32(double [[ee1]])
+ ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { double, i32 } [[ie0]], 0
+ ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { double, i32 } [[ie1]], 0
+ ; CHECK-NEXT: insertelement <2 x double> poison, double [[ev00]], i64 0
+ ; CHECK-NEXT: insertelement <2 x double> %{{.*}}, double [[ev01]], i64 1
+ ; CHECK-NEXT: extractvalue { double, i32 } [[ie0]], 1
+ ; CHECK-NEXT: extractvalue { double, i32 } [[ie1]], 1
+ %1 = call { <2 x double>, <2 x i32> } @llvm.frexp.v2f64.v2i32(<2 x double> %d)
+ %2 = extractvalue { <2 x double>, <2 x i32> } %1, 0
+ %3 = extractvalue { <2 x double>, <2 x i32> } %1, 1
+ ret <2 x double> %2
+}
|
@@ -667,14 +685,25 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) { | |||
bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) { | |||
if (isTriviallyVectorizable(ID)) | |||
return true; | |||
// TODO: investigate vectorizable frexp | |||
switch (ID) { | |||
case Intrinsic::frexp: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Making this vectorizable was out of scope for this work, so I did not add `frexp` to `isTriviallyVectorizable`. We should do a follow-up task for this.
Based on this RFC: https://discourse.llvm.org/t/rfc-allow-the-scalarizer-pass-to-scalarize-vectors-returned-in-structs/82306
LLVM intrinsics do not support out params. To get around this limitation implementers will make intrinsics return structs to capture a return type and an out param. This implementation detail should not impact scalarization since these cases should be elementwise operations.
Three changes are needed.
`ExtractValue` instructions
Testing changes
llvm.frexp
llvm.dx.splitdouble
fixes #111437