Skip to content

Commit

Permalink
Implement Metal parallel shader compilation (google#7205)
Browse files Browse the repository at this point in the history
  • Loading branch information
bejado authored and plepers committed Dec 9, 2023
1 parent 9a45afb commit 1535978
Show file tree
Hide file tree
Showing 8 changed files with 391 additions and 86 deletions.
1 change: 1 addition & 0 deletions filament/backend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ if (FILAMENT_SUPPORTS_METAL)
src/metal/MetalExternalImage.mm
src/metal/MetalHandles.mm
src/metal/MetalPlatform.mm
src/metal/MetalShaderCompiler.mm
src/metal/MetalState.mm
src/metal/MetalTimerQuery.mm
src/metal/MetalUtils.mm
Expand Down
3 changes: 3 additions & 0 deletions filament/backend/src/metal/MetalContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#define TNT_METALCONTEXT_H

#include "MetalResourceTracker.h"
#include "MetalShaderCompiler.h"
#include "MetalState.h"

#include <CoreVideo/CVMetalTextureCache.h>
Expand Down Expand Up @@ -149,6 +150,8 @@ struct MetalContext {

MTLViewport currentViewport;

MetalShaderCompiler* shaderCompiler = nullptr;

#if defined(FILAMENT_METAL_PROFILING)
// Logging and profiling.
os_log_t log;
Expand Down
2 changes: 1 addition & 1 deletion filament/backend/src/metal/MetalDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ namespace backend {
class MetalPlatform;

class MetalBuffer;
class MetalProgram;
class MetalSamplerGroup;
class MetalTexture;
struct MetalUniformBuffer;
struct MetalContext;
struct MetalProgram;
struct BufferState;

#ifndef FILAMENT_METAL_HANDLE_ARENA_SIZE_IN_MB
Expand Down
37 changes: 26 additions & 11 deletions filament/backend/src/metal/MetalDriver.mm
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@
mContext->eventListener = [[MTLSharedEventListener alloc] initWithDispatchQueue:queue];
}

mContext->shaderCompiler = new MetalShaderCompiler(mContext->device, *this);
mContext->shaderCompiler->init();

#if defined(FILAMENT_METAL_PROFILING)
mContext->log = os_log_create("com.google.filament", "Metal");
mContext->signpostId = os_signpost_id_generate(mContext->log);
Expand All @@ -158,6 +161,7 @@
delete mContext->bufferPool;
delete mContext->blitter;
delete mContext->timerQueryImpl;
delete mContext->shaderCompiler;
delete mContext;
}

Expand Down Expand Up @@ -319,7 +323,7 @@
}

void MetalDriver::createProgramR(Handle<HwProgram> rph, Program&& program) {
construct_handle<MetalProgram>(rph, mContext->device, program);
construct_handle<MetalProgram>(rph, *mContext, std::move(program));
}

void MetalDriver::createDefaultRenderTargetR(Handle<HwRenderTarget> rth, int dummy) {
Expand Down Expand Up @@ -584,6 +588,7 @@

MetalExternalImage::shutdown(*mContext);
mContext->blitter->shutdown();
mContext->shaderCompiler->terminate();
}

ShaderModel MetalDriver::getShaderModel() const noexcept {
Expand Down Expand Up @@ -719,7 +724,7 @@
}

bool MetalDriver::isParallelShaderCompileSupported() {
return false;
return true;
}

bool MetalDriver::isWorkaroundNeeded(Workaround workaround) {
Expand Down Expand Up @@ -957,7 +962,7 @@
void MetalDriver::compilePrograms(CompilerPriorityQueue priority,
CallbackHandler* handler, CallbackHandler::Callback callback, void* user) {
if (callback) {
scheduleCallback(handler, user, callback);
mContext->shaderCompiler->notifyWhenAllProgramsAreReady(handler, callback, user);
}
}

Expand Down Expand Up @@ -1473,14 +1478,19 @@
auto program = handle_cast<MetalProgram>(ps.program);
const auto& rs = ps.rasterState;

// This might block until the shader compilation has finished.
auto functions = program->getFunctions();

// If the material debugger is enabled, avoid fatal (or cascading) errors and that can occur
// during the draw call when the program is invalid. The shader compile error has already been
// dumped to the console at this point, so it's fine to simply return early.
if (FILAMENT_ENABLE_MATDBG && UTILS_UNLIKELY(!program->isValid)) {
if (FILAMENT_ENABLE_MATDBG && UTILS_UNLIKELY(!functions)) {
return;
}

ASSERT_PRECONDITION(program->isValid, "Attempting to draw with an invalid Metal program.");
ASSERT_PRECONDITION(bool(functions), "Attempting to draw with an invalid Metal program.");

auto [fragment, vertex] = functions.getRasterFunctions();

// Pipeline state
MTLPixelFormat colorPixelFormat[MRT::MAX_SUPPORTED_RENDER_TARGET_COUNT] = { MTLPixelFormatInvalid };
Expand All @@ -1503,8 +1513,8 @@
assert_invariant(isMetalFormatStencil(stencilPixelFormat));
}
MetalPipelineState pipelineState {
.vertexFunction = program->vertexFunction,
.fragmentFunction = program->fragmentFunction,
.vertexFunction = vertex,
.fragmentFunction = fragment,
.vertexDescription = primitive->vertexDescription,
.colorAttachmentPixelFormat = {
colorPixelFormat[0],
Expand Down Expand Up @@ -1649,7 +1659,7 @@
if (!samplerGroup) {
continue;
}
const auto& stageFlags = program->samplerGroupInfo[s].stageFlags;
const auto& stageFlags = program->getSamplerGroupInfo()[s].stageFlags;
if (stageFlags == ShaderStageFlags::NONE) {
continue;
}
Expand Down Expand Up @@ -1722,21 +1732,26 @@

auto mtlProgram = handle_cast<MetalProgram>(program);

// This might block until the shader compilation has finished.
auto functions = mtlProgram->getFunctions();

// If the material debugger is enabled, avoid fatal (or cascading) errors and that can occur
// during the draw call when the program is invalid. The shader compile error has already been
// dumped to the console at this point, so it's fine to simply return early.
if (FILAMENT_ENABLE_MATDBG && UTILS_UNLIKELY(!mtlProgram->isValid)) {
if (FILAMENT_ENABLE_MATDBG && UTILS_UNLIKELY(!functions)) {
return;
}

assert_invariant(mtlProgram->isValid && mtlProgram->computeFunction);
auto compute = functions.getComputeFunction();

assert_invariant(bool(functions) && compute);

id<MTLComputeCommandEncoder> computeEncoder =
[getPendingCommandBuffer(mContext) computeCommandEncoder];

NSError* error = nil;
id<MTLComputePipelineState> computePipelineState =
[mContext->device newComputePipelineStateWithFunction:mtlProgram->computeFunction
[mContext->device newComputePipelineStateWithFunction:compute
error:&error];
if (error) {
auto description = [error.localizedDescription cStringUsingEncoding:NSUTF8StringEncoding];
Expand Down
18 changes: 11 additions & 7 deletions filament/backend/src/metal/MetalHandles.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,20 @@ struct MetalRenderPrimitive : public HwRenderPrimitive {
VertexDescription vertexDescription = {};
};

struct MetalProgram : public HwProgram {
MetalProgram(id<MTLDevice> device, const Program& program) noexcept;
class MetalProgram : public HwProgram {
public:
MetalProgram(MetalContext& context, Program&& program) noexcept;

id<MTLFunction> vertexFunction;
id<MTLFunction> fragmentFunction;
id<MTLFunction> computeFunction;
const MetalShaderCompiler::MetalFunctionBundle& getFunctions();
const Program::SamplerGroupInfo& getSamplerGroupInfo() { return samplerGroupInfo; }

Program::SamplerGroupInfo samplerGroupInfo;
private:
void initialize();

bool isValid = false;
Program::SamplerGroupInfo samplerGroupInfo;
MetalContext& mContext;
MetalShaderCompiler::MetalFunctionBundle mFunctionBundle;
MetalShaderCompiler::program_token_t mToken;
};

struct PixelBufferShape {
Expand Down
84 changes: 17 additions & 67 deletions filament/backend/src/metal/MetalHandles.mm
Original file line number Diff line number Diff line change
Expand Up @@ -346,78 +346,28 @@ void presentDrawable(bool presentFrame, void* user) {
};
}

MetalProgram::MetalProgram(id<MTLDevice> device, const Program& program) noexcept
: HwProgram(program.getName()), vertexFunction(nil), fragmentFunction(nil),
computeFunction(nil), isValid(false) {
MetalProgram::MetalProgram(MetalContext& context, Program&& program) noexcept
: HwProgram(program.getName()), mContext(context) {

using MetalFunctionPtr = __strong id<MTLFunction>*;

static_assert(Program::SHADER_TYPE_COUNT == 3, "Only vertex, fragment, and/or compute shaders expected.");
MetalFunctionPtr shaderFunctions[3] = { &vertexFunction, &fragmentFunction, &computeFunction };

const auto& sources = program.getShadersSource();
for (size_t i = 0; i < Program::SHADER_TYPE_COUNT; i++) {
const auto& source = sources[i];
// It's okay for some shaders to be empty, they shouldn't be used in any draw calls.
if (source.empty()) {
continue;
}
// Save this program's SamplerGroupInfo, it's used during draw calls to bind sampler groups to
// the appropriate stage(s).
samplerGroupInfo = program.getSamplerGroupInfo();

assert_invariant( source[source.size() - 1] == '\0' );

// the shader string is null terminated and the length includes the null character
NSString* objcSource = [[NSString alloc] initWithBytes:source.data()
length:source.size() - 1
encoding:NSUTF8StringEncoding];
NSError* error = nil;
// When options is nil, Metal uses the most recent language version available.
id<MTLLibrary> library = [device newLibraryWithSource:objcSource
options:nil
error:&error];
if (library == nil) {
if (error) {
auto description =
[error.localizedDescription cStringUsingEncoding:NSUTF8StringEncoding];
utils::slog.w << description << utils::io::endl;
}
PANIC_LOG("Failed to compile Metal program.");
return;
}
mToken = context.shaderCompiler->createProgram(program.getName(), std::move(program));
assert_invariant(mToken);
}

MTLFunctionConstantValues* constants = [MTLFunctionConstantValues new];
auto const& specializationConstants = program.getSpecializationConstants();
for (auto const& sc : specializationConstants) {
const std::array<MTLDataType, 3> types{
MTLDataTypeInt, MTLDataTypeFloat, MTLDataTypeBool };
std::visit([&sc, constants, type = types[sc.value.index()]](auto&& arg) {
[constants setConstantValue:&arg
type:type
atIndex:sc.id];
}, sc.value);
}
const MetalShaderCompiler::MetalFunctionBundle& MetalProgram::getFunctions() {
initialize();
return mFunctionBundle;
}

id<MTLFunction> function = [library newFunctionWithName:@"main0"
constantValues:constants
error:&error];
if (!program.getName().empty()) {
function.label = @(program.getName().c_str());
}
assert_invariant(function);
*shaderFunctions[i] = function;
void MetalProgram::initialize() {
if (!mToken) {
return;
}

UTILS_UNUSED_IN_RELEASE const bool isRasterizationProgram =
vertexFunction != nil && fragmentFunction != nil;
UTILS_UNUSED_IN_RELEASE const bool isComputeProgram = computeFunction != nil;
// The program must be either a rasterization program XOR a compute program.
assert_invariant(isRasterizationProgram != isComputeProgram);

// All stages of the program have compiled successfully, this is a valid program.
isValid = true;

// Save this program's SamplerGroupInfo, it's used during draw calls to bind sampler groups to
// the appropriate stage(s).
samplerGroupInfo = program.getSamplerGroupInfo();
mFunctionBundle = mContext.shaderCompiler->getProgram(mToken);
assert_invariant(!mToken);
}

MetalTexture::MetalTexture(MetalContext& context, SamplerType target, uint8_t levels,
Expand Down
110 changes: 110 additions & 0 deletions filament/backend/src/metal/MetalShaderCompiler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef TNT_FILAMENT_BACKEND_METAL_METALSHADERCOMPILER_H
#define TNT_FILAMENT_BACKEND_METAL_METALSHADERCOMPILER_H

#include "CompilerThreadPool.h"

#include "CallbackManager.h"

#include <backend/CallbackHandler.h>
#include <backend/Program.h>

#include <utils/CString.h>

#include <Metal/Metal.h>

#include <array>
#include <memory>

namespace filament::backend {

class MetalDriver;

class MetalShaderCompiler {
struct MetalProgramToken;

public:
class MetalFunctionBundle {
public:
MetalFunctionBundle() = default;
MetalFunctionBundle(id<MTLFunction> fragment, id<MTLFunction> vertex)
: functions{fragment, vertex} {
assert_invariant(fragment && vertex);
assert_invariant(fragment.functionType == MTLFunctionTypeFragment);
assert_invariant(vertex.functionType == MTLFunctionTypeVertex);
}
explicit MetalFunctionBundle(id<MTLFunction> compute) : functions{compute, nil} {
assert_invariant(compute);
assert_invariant(compute.functionType == MTLFunctionTypeKernel);
}

std::pair<id<MTLFunction>, id<MTLFunction>> getRasterFunctions() const noexcept {
assert_invariant(functions[0].functionType == MTLFunctionTypeFragment);
assert_invariant(functions[1].functionType == MTLFunctionTypeVertex);
return {functions[0], functions[1]};
}

id<MTLFunction> getComputeFunction() const noexcept {
assert_invariant(functions[0].functionType == MTLFunctionTypeKernel);
return functions[0];
}

explicit operator bool() const { return functions[0] != nil; }

private:
// Can hold two functions, either:
// - fragment and vertex (for rasterization pipelines)
// - compute (for compute pipelines)
id<MTLFunction> functions[2] = {nil, nil};
};

using program_token_t = std::shared_ptr<MetalProgramToken>;

explicit MetalShaderCompiler(id<MTLDevice> device, MetalDriver& driver);

MetalShaderCompiler(MetalShaderCompiler const& rhs) = delete;
MetalShaderCompiler(MetalShaderCompiler&& rhs) = delete;
MetalShaderCompiler& operator=(MetalShaderCompiler const& rhs) = delete;
MetalShaderCompiler& operator=(MetalShaderCompiler&& rhs) = delete;

void init() noexcept;
void terminate() noexcept;

// Creates a program asynchronously
program_token_t createProgram(utils::CString const& name, Program&& program);

// Returns the functions, blocking if necessary. The Token is destroyed and becomes invalid.
MetalFunctionBundle getProgram(program_token_t& token);

// Destroys a valid token and all associated resources. Used to "cancel" a program compilation.
static void terminate(program_token_t& token);

void notifyWhenAllProgramsAreReady(
CallbackHandler* handler, CallbackHandler::Callback callback, void* user);

private:
static MetalFunctionBundle compileProgram(const Program& program, id<MTLDevice> device);

CompilerThreadPool mCompilerThreadPool;
id<MTLDevice> mDevice;
CallbackManager mCallbackManager;
};

} // namespace filament::backend

#endif // TNT_FILAMENT_BACKEND_METAL_METALSHADERCOMPILER_H
Loading

0 comments on commit 1535978

Please sign in to comment.