Skip to content

Commit

Permalink
spirv: only set LocalSizeId mode when necessary
Browse files Browse the repository at this point in the history
SPIR-V 1.6 added the LocalSizeId execution mode that allows using
spec constants for setting the work-group size, however it does not
deprecate the LocalSize mode. This change causes the LocalSizeId mode to
only be used when at least one of the workgroup size is actually
specified with a spec constant.

Fixes #3200
  • Loading branch information
arcady-lunarg committed Oct 12, 2023
1 parent 4ce1a1a commit 48f9ed8
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 64 deletions.
37 changes: 23 additions & 14 deletions SPIRV/GlslangToSpv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1741,23 +1741,31 @@ TGlslangToSpvTraverser::TGlslangToSpvTraverser(unsigned int spvVersion,
}
break;

case EShLangCompute:
case EShLangCompute: {
builder.addCapability(spv::CapabilityShader);
if (glslangIntermediate->getSpv().spv >= glslang::EShTargetSpv_1_6) {
std::vector<spv::Id> dimConstId;
for (int dim = 0; dim < 3; ++dim) {
bool specConst = (glslangIntermediate->getLocalSizeSpecId(dim) != glslang::TQualifier::layoutNotSet);
dimConstId.push_back(builder.makeUintConstant(glslangIntermediate->getLocalSize(dim), specConst));
if (specConst) {
builder.addDecoration(dimConstId.back(), spv::DecorationSpecId,
glslangIntermediate->getLocalSizeSpecId(dim));
bool needSizeId = false;
for (int dim = 0; dim < 3; ++dim) {
if ((glslangIntermediate->getLocalSizeSpecId(dim) != glslang::TQualifier::layoutNotSet)) {
needSizeId = true;
break;
}
}
builder.addExecutionModeId(shaderEntry, spv::ExecutionModeLocalSizeId, dimConstId);
}
if (glslangIntermediate->getSpv().spv >= glslang::EShTargetSpv_1_6 && needSizeId) {
std::vector<spv::Id> dimConstId;
for (int dim = 0; dim < 3; ++dim) {
bool specConst = (glslangIntermediate->getLocalSizeSpecId(dim) != glslang::TQualifier::layoutNotSet);
dimConstId.push_back(builder.makeUintConstant(glslangIntermediate->getLocalSize(dim), specConst));
if (specConst) {
builder.addDecoration(dimConstId.back(), spv::DecorationSpecId,
glslangIntermediate->getLocalSizeSpecId(dim));
needSizeId = true;
}
}
builder.addExecutionModeId(shaderEntry, spv::ExecutionModeLocalSizeId, dimConstId);
} else {
builder.addExecutionMode(shaderEntry, spv::ExecutionModeLocalSize, glslangIntermediate->getLocalSize(0),
glslangIntermediate->getLocalSize(1),
glslangIntermediate->getLocalSize(2));
builder.addExecutionMode(shaderEntry, spv::ExecutionModeLocalSize, glslangIntermediate->getLocalSize(0),
glslangIntermediate->getLocalSize(1),
glslangIntermediate->getLocalSize(2));
}
if (glslangIntermediate->getLayoutDerivativeModeNone() == glslang::LayoutDerivativeGroupQuads) {
builder.addCapability(spv::CapabilityComputeDerivativeGroupQuadsNV);
Expand All @@ -1769,6 +1777,7 @@ TGlslangToSpvTraverser::TGlslangToSpvTraverser(unsigned int spvVersion,
builder.addExtension(spv::E_SPV_NV_compute_shader_derivatives);
}
break;
}
case EShLangTessEvaluation:
case EShLangTessControl:
builder.addCapability(spv::CapabilityTessellation);
Expand Down
100 changes: 50 additions & 50 deletions Test/baseResults/hlsl.structcopylogical.comp.out
Original file line number Diff line number Diff line change
Expand Up @@ -248,17 +248,17 @@ local_size = (128, 1, 1)
Capability Shader
1: ExtInstImport "GLSL.std.450"
MemoryModel Logical GLSL450
EntryPoint GLCompute 4 "main" 17 32 57 74
ExecutionModeId 4 LocalSizeId 7 8 8
EntryPoint GLCompute 4 "main" 16 32 57 74
ExecutionMode 4 LocalSize 128 1 1
Source HLSL 500
Name 4 "main"
Name 12 "@main(u1;"
Name 11 "id"
Name 14 "MyStruct"
MemberName 14(MyStruct) 0 "a"
MemberName 14(MyStruct) 1 "b"
MemberName 14(MyStruct) 2 "c"
Name 17 "s"
Name 10 "@main(u1;"
Name 9 "id"
Name 12 "MyStruct"
MemberName 12(MyStruct) 0 "a"
MemberName 12(MyStruct) 1 "b"
MemberName 12(MyStruct) 2 "c"
Name 16 "s"
Name 25 "count"
Name 26 "MyStruct"
MemberName 26(MyStruct) 0 "a"
Expand Down Expand Up @@ -300,20 +300,20 @@ local_size = (128, 1, 1)
2: TypeVoid
3: TypeFunction 2
6: TypeInt 32 0
7: 6(int) Constant 128
8: 6(int) Constant 1
9: TypePointer Function 6(int)
10: TypeFunction 2 9(ptr)
14(MyStruct): TypeStruct 6(int) 6(int) 6(int)
15: TypeArray 14(MyStruct) 7
16: TypePointer Workgroup 15
17(s): 16(ptr) Variable Workgroup
18: TypeInt 32 1
19: 18(int) Constant 0
7: TypePointer Function 6(int)
8: TypeFunction 2 7(ptr)
12(MyStruct): TypeStruct 6(int) 6(int) 6(int)
13: 6(int) Constant 128
14: TypeArray 12(MyStruct) 13
15: TypePointer Workgroup 14
16(s): 15(ptr) Variable Workgroup
17: TypeInt 32 1
18: 17(int) Constant 0
19: 6(int) Constant 1
20: 6(int) Constant 2
21: 6(int) Constant 3
22:14(MyStruct) ConstantComposite 8 20 21
23: TypePointer Workgroup 14(MyStruct)
22:12(MyStruct) ConstantComposite 19 20 21
23: TypePointer Workgroup 12(MyStruct)
26(MyStruct): TypeStruct 6(int) 6(int) 6(int)
27: TypeRuntimeArray 26(MyStruct)
28(MyStructs): TypeStruct 6(int) 27
Expand All @@ -322,64 +322,64 @@ local_size = (128, 1, 1)
31: TypePointer StorageBuffer 30(sb)
32(sb): 31(ptr) Variable StorageBuffer
33: TypePointer StorageBuffer 6(int)
36: TypePointer Function 14(MyStruct)
36: TypePointer Function 12(MyStruct)
40: TypeBool
47: 18(int) Constant 1
47: 17(int) Constant 1
49: TypePointer StorageBuffer 26(MyStruct)
54: TypeRuntimeArray 26(MyStruct)
55(o): TypeStruct 54
56: TypePointer StorageBuffer 55(o)
57(o): 56(ptr) Variable StorageBuffer
61: 6(int) Constant 0
67: 18(int) Constant 2
67: 17(int) Constant 2
73: TypePointer Input 6(int)
74(id): 73(ptr) Variable Input
4(main): 2 Function None 3
5: Label
72(id): 9(ptr) Variable Function
76(param): 9(ptr) Variable Function
72(id): 7(ptr) Variable Function
76(param): 7(ptr) Variable Function
75: 6(int) Load 74(id)
Store 72(id) 75
77: 6(int) Load 72(id)
Store 76(param) 77
78: 2 FunctionCall 12(@main(u1;) 76(param)
78: 2 FunctionCall 10(@main(u1;) 76(param)
Return
FunctionEnd
12(@main(u1;): 2 Function None 10
11(id): 9(ptr) FunctionParameter
13: Label
25(count): 9(ptr) Variable Function
10(@main(u1;): 2 Function None 8
9(id): 7(ptr) FunctionParameter
11: Label
25(count): 7(ptr) Variable Function
37(ms): 36(ptr) Variable Function
24: 23(ptr) AccessChain 17(s) 19
24: 23(ptr) AccessChain 16(s) 18
Store 24 22
34: 33(ptr) AccessChain 32(sb) 19 19 19
34: 33(ptr) AccessChain 32(sb) 18 18 18
35: 6(int) Load 34
Store 25(count) 35
38: 6(int) Load 11(id)
38: 6(int) Load 9(id)
39: 6(int) Load 25(count)
41: 40(bool) UGreaterThan 38 39
42: 6(int) Load 11(id)
42: 6(int) Load 9(id)
43: 6(int) Load 25(count)
44: 6(int) ISub 42 43
45: 23(ptr) AccessChain 17(s) 44
46:14(MyStruct) Load 45
48: 6(int) Load 11(id)
50: 49(ptr) AccessChain 32(sb) 19 19 47 48
45: 23(ptr) AccessChain 16(s) 44
46:12(MyStruct) Load 45
48: 6(int) Load 9(id)
50: 49(ptr) AccessChain 32(sb) 18 18 47 48
51:26(MyStruct) Load 50
52:14(MyStruct) CopyLogical 51
53:14(MyStruct) Select 41 46 52
52:12(MyStruct) CopyLogical 51
53:12(MyStruct) Select 41 46 52
Store 37(ms) 53
58: 33(ptr) AccessChain 57(o) 19 19 19
59: 9(ptr) AccessChain 37(ms) 19
58: 33(ptr) AccessChain 57(o) 18 18 18
59: 7(ptr) AccessChain 37(ms) 18
60: 6(int) Load 59
62: 6(int) AtomicIAdd 58 8 61 60
63: 33(ptr) AccessChain 57(o) 19 19 47
64: 9(ptr) AccessChain 37(ms) 47
62: 6(int) AtomicIAdd 58 19 61 60
63: 33(ptr) AccessChain 57(o) 18 18 47
64: 7(ptr) AccessChain 37(ms) 47
65: 6(int) Load 64
66: 6(int) AtomicIAdd 63 8 61 65
68: 33(ptr) AccessChain 57(o) 19 19 67
69: 9(ptr) AccessChain 37(ms) 67
66: 6(int) AtomicIAdd 63 19 61 65
68: 33(ptr) AccessChain 57(o) 18 18 67
69: 7(ptr) AccessChain 37(ms) 67
70: 6(int) Load 69
71: 6(int) AtomicIAdd 68 8 61 70
71: 6(int) AtomicIAdd 68 19 61 70
Return
FunctionEnd

0 comments on commit 48f9ed8

Please sign in to comment.