Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Metal parallel shader compilation #7205

Merged
merged 4 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions filament/backend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ if (FILAMENT_SUPPORTS_METAL)
src/metal/MetalExternalImage.mm
src/metal/MetalHandles.mm
src/metal/MetalPlatform.mm
src/metal/MetalShaderCompiler.mm
src/metal/MetalState.mm
src/metal/MetalTimerQuery.mm
src/metal/MetalUtils.mm
Expand Down
3 changes: 3 additions & 0 deletions filament/backend/src/metal/MetalContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#define TNT_METALCONTEXT_H

#include "MetalResourceTracker.h"
#include "MetalShaderCompiler.h"
#include "MetalState.h"

#include <CoreVideo/CVMetalTextureCache.h>
Expand Down Expand Up @@ -136,6 +137,8 @@ struct MetalContext {

MTLViewport currentViewport;

MetalShaderCompiler* shaderCompiler = nullptr;

#if defined(FILAMENT_METAL_PROFILING)
// Logging and profiling.
os_log_t log;
Expand Down
2 changes: 1 addition & 1 deletion filament/backend/src/metal/MetalDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ namespace backend {
class MetalPlatform;

class MetalBuffer;
class MetalProgram;
class MetalSamplerGroup;
class MetalTexture;
struct MetalUniformBuffer;
struct MetalContext;
struct MetalProgram;
struct BufferState;

#ifndef FILAMENT_METAL_HANDLE_ARENA_SIZE_IN_MB
Expand Down
37 changes: 26 additions & 11 deletions filament/backend/src/metal/MetalDriver.mm
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@
mContext->eventListener = [[MTLSharedEventListener alloc] initWithDispatchQueue:queue];
}

mContext->shaderCompiler = new MetalShaderCompiler(mContext->device, *this);
mContext->shaderCompiler->init();

#if defined(FILAMENT_METAL_PROFILING)
mContext->log = os_log_create("com.google.filament", "Metal");
mContext->signpostId = os_signpost_id_generate(mContext->log);
Expand All @@ -157,6 +160,7 @@
delete mContext->bufferPool;
delete mContext->blitter;
delete mContext->timerQueryImpl;
delete mContext->shaderCompiler;
delete mContext;
}

Expand Down Expand Up @@ -317,7 +321,7 @@
}

void MetalDriver::createProgramR(Handle<HwProgram> rph, Program&& program) {
construct_handle<MetalProgram>(rph, mContext->device, program);
construct_handle<MetalProgram>(rph, *mContext, std::move(program));
}

void MetalDriver::createDefaultRenderTargetR(Handle<HwRenderTarget> rth, int dummy) {
Expand Down Expand Up @@ -566,6 +570,7 @@

MetalExternalImage::shutdown(*mContext);
mContext->blitter->shutdown();
mContext->shaderCompiler->terminate();
}

ShaderModel MetalDriver::getShaderModel() const noexcept {
Expand Down Expand Up @@ -701,7 +706,7 @@
}

bool MetalDriver::isParallelShaderCompileSupported() {
return false;
return true;
}

bool MetalDriver::isWorkaroundNeeded(Workaround workaround) {
Expand Down Expand Up @@ -935,7 +940,7 @@
void MetalDriver::compilePrograms(CompilerPriorityQueue priority,
CallbackHandler* handler, CallbackHandler::Callback callback, void* user) {
if (callback) {
scheduleCallback(handler, user, callback);
mContext->shaderCompiler->notifyWhenAllProgramsAreReady(handler, callback, user);
}
}

Expand Down Expand Up @@ -1447,14 +1452,19 @@
auto program = handle_cast<MetalProgram>(ps.program);
const auto& rs = ps.rasterState;

// This might block until the shader compilation has finished.
auto functions = program->getFunctions();

// If the material debugger is enabled, avoid fatal (or cascading) errors and that can occur
// during the draw call when the program is invalid. The shader compile error has already been
// dumped to the console at this point, so it's fine to simply return early.
if (FILAMENT_ENABLE_MATDBG && UTILS_UNLIKELY(!program->isValid)) {
if (FILAMENT_ENABLE_MATDBG && UTILS_UNLIKELY(!functions)) {
return;
}

ASSERT_PRECONDITION(program->isValid, "Attempting to draw with an invalid Metal program.");
ASSERT_PRECONDITION(bool(functions), "Attempting to draw with an invalid Metal program.");

auto [fragment, vertex] = functions.getRasterFunctions();

// Pipeline state
MTLPixelFormat colorPixelFormat[MRT::MAX_SUPPORTED_RENDER_TARGET_COUNT] = { MTLPixelFormatInvalid };
Expand All @@ -1477,8 +1487,8 @@
assert_invariant(isMetalFormatStencil(stencilPixelFormat));
}
MetalPipelineState pipelineState {
.vertexFunction = program->vertexFunction,
.fragmentFunction = program->fragmentFunction,
.vertexFunction = vertex,
.fragmentFunction = fragment,
.vertexDescription = primitive->vertexDescription,
.colorAttachmentPixelFormat = {
colorPixelFormat[0],
Expand Down Expand Up @@ -1623,7 +1633,7 @@
if (!samplerGroup) {
continue;
}
const auto& stageFlags = program->samplerGroupInfo[s].stageFlags;
const auto& stageFlags = program->getSamplerGroupInfo()[s].stageFlags;
if (stageFlags == ShaderStageFlags::NONE) {
continue;
}
Expand Down Expand Up @@ -1696,21 +1706,26 @@

auto mtlProgram = handle_cast<MetalProgram>(program);

// This might block until the shader compilation has finished.
auto functions = mtlProgram->getFunctions();

// If the material debugger is enabled, avoid fatal (or cascading) errors and that can occur
// during the draw call when the program is invalid. The shader compile error has already been
// dumped to the console at this point, so it's fine to simply return early.
if (FILAMENT_ENABLE_MATDBG && UTILS_UNLIKELY(!mtlProgram->isValid)) {
if (FILAMENT_ENABLE_MATDBG && UTILS_UNLIKELY(!functions)) {
return;
}

assert_invariant(mtlProgram->isValid && mtlProgram->computeFunction);
auto compute = functions.getComputeFunction();

assert_invariant(bool(functions) && compute);

id<MTLComputeCommandEncoder> computeEncoder =
[getPendingCommandBuffer(mContext) computeCommandEncoder];

NSError* error = nil;
id<MTLComputePipelineState> computePipelineState =
[mContext->device newComputePipelineStateWithFunction:mtlProgram->computeFunction
[mContext->device newComputePipelineStateWithFunction:compute
error:&error];
if (error) {
auto description = [error.localizedDescription cStringUsingEncoding:NSUTF8StringEncoding];
Expand Down
18 changes: 11 additions & 7 deletions filament/backend/src/metal/MetalHandles.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,20 @@ struct MetalRenderPrimitive : public HwRenderPrimitive {
VertexDescription vertexDescription = {};
};

struct MetalProgram : public HwProgram {
MetalProgram(id<MTLDevice> device, const Program& program) noexcept;
class MetalProgram : public HwProgram {
public:
MetalProgram(MetalContext& context, Program&& program) noexcept;

id<MTLFunction> vertexFunction;
id<MTLFunction> fragmentFunction;
id<MTLFunction> computeFunction;
const MetalShaderCompiler::MetalFunctionBundle& getFunctions();
const Program::SamplerGroupInfo& getSamplerGroupInfo() { return samplerGroupInfo; }

Program::SamplerGroupInfo samplerGroupInfo;
private:
void initialize();

bool isValid = false;
Program::SamplerGroupInfo samplerGroupInfo;
MetalContext& mContext;
MetalShaderCompiler::MetalFunctionBundle mFunctionBundle;
MetalShaderCompiler::program_token_t mToken;
};

struct PixelBufferShape {
Expand Down
84 changes: 17 additions & 67 deletions filament/backend/src/metal/MetalHandles.mm
Original file line number Diff line number Diff line change
Expand Up @@ -327,78 +327,28 @@ void presentDrawable(bool presentFrame, void* user) {
};
}

MetalProgram::MetalProgram(id<MTLDevice> device, const Program& program) noexcept
: HwProgram(program.getName()), vertexFunction(nil), fragmentFunction(nil),
computeFunction(nil), isValid(false) {
MetalProgram::MetalProgram(MetalContext& context, Program&& program) noexcept
: HwProgram(program.getName()), mContext(context) {

using MetalFunctionPtr = __strong id<MTLFunction>*;

static_assert(Program::SHADER_TYPE_COUNT == 3, "Only vertex, fragment, and/or compute shaders expected.");
MetalFunctionPtr shaderFunctions[3] = { &vertexFunction, &fragmentFunction, &computeFunction };

const auto& sources = program.getShadersSource();
for (size_t i = 0; i < Program::SHADER_TYPE_COUNT; i++) {
const auto& source = sources[i];
// It's okay for some shaders to be empty, they shouldn't be used in any draw calls.
if (source.empty()) {
continue;
}
// Save this program's SamplerGroupInfo, it's used during draw calls to bind sampler groups to
// the appropriate stage(s).
samplerGroupInfo = program.getSamplerGroupInfo();

assert_invariant( source[source.size() - 1] == '\0' );

// the shader string is null terminated and the length includes the null character
NSString* objcSource = [[NSString alloc] initWithBytes:source.data()
length:source.size() - 1
encoding:NSUTF8StringEncoding];
NSError* error = nil;
// When options is nil, Metal uses the most recent language version available.
id<MTLLibrary> library = [device newLibraryWithSource:objcSource
options:nil
error:&error];
if (library == nil) {
if (error) {
auto description =
[error.localizedDescription cStringUsingEncoding:NSUTF8StringEncoding];
utils::slog.w << description << utils::io::endl;
}
PANIC_LOG("Failed to compile Metal program.");
return;
}
mToken = context.shaderCompiler->createProgram(program.getName(), std::move(program));
assert_invariant(mToken);
}

MTLFunctionConstantValues* constants = [MTLFunctionConstantValues new];
auto const& specializationConstants = program.getSpecializationConstants();
for (auto const& sc : specializationConstants) {
const std::array<MTLDataType, 3> types{
MTLDataTypeInt, MTLDataTypeFloat, MTLDataTypeBool };
std::visit([&sc, constants, type = types[sc.value.index()]](auto&& arg) {
[constants setConstantValue:&arg
type:type
atIndex:sc.id];
}, sc.value);
}
const MetalShaderCompiler::MetalFunctionBundle& MetalProgram::getFunctions() {
initialize();
return mFunctionBundle;
}

id<MTLFunction> function = [library newFunctionWithName:@"main0"
constantValues:constants
error:&error];
if (!program.getName().empty()) {
function.label = @(program.getName().c_str());
}
assert_invariant(function);
*shaderFunctions[i] = function;
void MetalProgram::initialize() {
if (!mToken) {
return;
}

UTILS_UNUSED_IN_RELEASE const bool isRasterizationProgram =
vertexFunction != nil && fragmentFunction != nil;
UTILS_UNUSED_IN_RELEASE const bool isComputeProgram = computeFunction != nil;
// The program must be either a rasterization program XOR a compute program.
assert_invariant(isRasterizationProgram != isComputeProgram);

// All stages of the program have compiled successfully, this is a valid program.
isValid = true;

// Save this program's SamplerGroupInfo, it's used during draw calls to bind sampler groups to
// the appropriate stage(s).
samplerGroupInfo = program.getSamplerGroupInfo();
mFunctionBundle = mContext.shaderCompiler->getProgram(mToken);
assert_invariant(!mToken);
}

MetalTexture::MetalTexture(MetalContext& context, SamplerType target, uint8_t levels,
Expand Down
110 changes: 110 additions & 0 deletions filament/backend/src/metal/MetalShaderCompiler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef TNT_FILAMENT_BACKEND_METAL_METALSHADERCOMPILER_H
#define TNT_FILAMENT_BACKEND_METAL_METALSHADERCOMPILER_H

#include "CompilerThreadPool.h"

#include "CallbackManager.h"

#include <backend/CallbackHandler.h>
#include <backend/Program.h>

#include <utils/CString.h>

#include <Metal/Metal.h>

#include <array>
#include <memory>

namespace filament::backend {

class MetalDriver;

class MetalShaderCompiler {
struct MetalProgramToken;

public:
class MetalFunctionBundle {
public:
MetalFunctionBundle() = default;
MetalFunctionBundle(id<MTLFunction> fragment, id<MTLFunction> vertex)
: functions{fragment, vertex} {
assert_invariant(fragment && vertex);
assert_invariant(fragment.functionType == MTLFunctionTypeFragment);
assert_invariant(vertex.functionType == MTLFunctionTypeVertex);
}
bejado marked this conversation as resolved.
Show resolved Hide resolved
explicit MetalFunctionBundle(id<MTLFunction> compute) : functions{compute, nil} {
assert_invariant(compute);
assert_invariant(compute.functionType == MTLFunctionTypeKernel);
}

std::pair<id<MTLFunction>, id<MTLFunction>> getRasterFunctions() const noexcept {
assert_invariant(functions[0].functionType == MTLFunctionTypeFragment);
assert_invariant(functions[1].functionType == MTLFunctionTypeVertex);
return {functions[0], functions[1]};
}

id<MTLFunction> getComputeFunction() const noexcept {
assert_invariant(functions[0].functionType == MTLFunctionTypeKernel);
return functions[0];
}

explicit operator bool() const { return functions[0] != nil; }

private:
// Can hold two functions, either:
// - fragment and vertex (for rasterization pipelines)
// - compute (for compute pipelines)
id<MTLFunction> functions[2] = {nil, nil};
};

using program_token_t = std::shared_ptr<MetalProgramToken>;

explicit MetalShaderCompiler(id<MTLDevice> device, MetalDriver& driver);

MetalShaderCompiler(MetalShaderCompiler const& rhs) = delete;
MetalShaderCompiler(MetalShaderCompiler&& rhs) = delete;
MetalShaderCompiler& operator=(MetalShaderCompiler const& rhs) = delete;
MetalShaderCompiler& operator=(MetalShaderCompiler&& rhs) = delete;

void init() noexcept;
void terminate() noexcept;

// Creates a program asynchronously
program_token_t createProgram(utils::CString const& name, Program&& program);

// Returns the functions, blocking if necessary. The Token is destroyed and becomes invalid.
MetalFunctionBundle getProgram(program_token_t& token);

// Destroys a valid token and all associated resources. Used to "cancel" a program compilation.
static void terminate(program_token_t& token);

void notifyWhenAllProgramsAreReady(
CallbackHandler* handler, CallbackHandler::Callback callback, void* user);

private:
static MetalFunctionBundle compileProgram(const Program& program, id<MTLDevice> device);

CompilerThreadPool mCompilerThreadPool;
id<MTLDevice> mDevice;
CallbackManager mCallbackManager;
};

} // namespace filament::backend

#endif // TNT_FILAMENT_BACKEND_METAL_METALSHADERCOMPILER_H
Loading