diff --git a/src/modules/graphics/vulkan/Shader.cpp b/src/modules/graphics/vulkan/Shader.cpp index f2cfa3e3c..491700b63 100644 --- a/src/modules/graphics/vulkan/Shader.cpp +++ b/src/modules/graphics/vulkan/Shader.cpp @@ -113,6 +113,18 @@ static VkShaderStageFlagBits getStageBit(ShaderStageType type) } } +static VkShaderStageFlags getStageFlags(ShaderStageMask mask) +{ + VkShaderStageFlags flags = 0; + if (mask & SHADERSTAGEMASK_VERTEX) + flags |= VK_SHADER_STAGE_VERTEX_BIT; + if (mask & SHADERSTAGEMASK_PIXEL) + flags |= VK_SHADER_STAGE_FRAGMENT_BIT; + if (mask & SHADERSTAGEMASK_COMPUTE) + flags |= VK_SHADER_STAGE_COMPUTE_BIT; + return flags; +} + static EShLanguage getGlslShaderType(ShaderStageType stage) { switch (stage) @@ -876,12 +888,6 @@ void Shader::createDescriptorSetLayout() { std::vector bindings; - VkShaderStageFlags stageFlags; - if (isCompute) - stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; - else - stageFlags = VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT; - for (auto const &entry : reflection.allUniforms) { if (!entry.second->active) @@ -895,7 +901,7 @@ void Shader::createDescriptorSetLayout() layoutBinding.binding = entry.second->location; layoutBinding.descriptorType = type; layoutBinding.descriptorCount = entry.second->count; - layoutBinding.stageFlags = stageFlags; + layoutBinding.stageFlags = getStageFlags((ShaderStageMask)entry.second->stageMask); bindings.push_back(layoutBinding); } @@ -907,7 +913,10 @@ void Shader::createDescriptorSetLayout() uniformBinding.binding = localUniformLocation; uniformBinding.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; uniformBinding.descriptorCount = 1; - uniformBinding.stageFlags = stageFlags; + if (isCompute) + uniformBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + else + uniformBinding.stageFlags = VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT; bindings.push_back(uniformBinding); }