From 5d62e9bba5e5959ed8b841ec5d004f776c35540d Mon Sep 17 00:00:00 2001 From: Archermmt Date: Thu, 26 Oct 2023 11:47:22 +0800 Subject: [PATCH] [Unity][MSC][M1.3] Add translate && codegen for tensorrt (#15950) * add tensorrt test * add runtime * remove useless * destroy runtime --- cmake/modules/contrib/MSC.cmake | 4 + python/tvm/contrib/msc/core/ir/translate.py | 1 + .../contrib/msc/core/transform/transform.py | 18 + .../msc/framework/tensorrt/__init__.py | 17 + .../msc/framework/tensorrt/_ffi_api.py | 21 + .../framework/tensorrt/codegen/__init__.py | 19 + .../msc/framework/tensorrt/codegen/codegen.py | 162 ++++ .../msc/framework/tensorrt/codegen/sources.py | 318 +++++++ .../msc/framework/tensorrt/codegen/utils.py | 98 +++ .../framework/tensorrt/frontend/__init__.py | 17 + .../framework/tensorrt/frontend/translate.py | 66 ++ .../framework/tensorrt/transform/__init__.py | 20 + .../framework/tensorrt/transform/pattern.py | 352 ++++++++ .../framework/tensorrt/transform/transform.py | 42 + src/contrib/msc/core/codegen/codegen_json.cc | 66 ++ src/contrib/msc/core/codegen/codegen_json.h | 108 +++ src/contrib/msc/core/codegen/cpp_codegen.h | 109 +++ src/contrib/msc/core/printer/cpp_printer.cc | 243 ++++++ src/contrib/msc/core/printer/cpp_printer.h | 140 +++ src/contrib/msc/core/transform/fuse_tuple.cc | 206 +++++ src/contrib/msc/framework/tensorrt/codegen.cc | 515 +++++++++++ src/contrib/msc/framework/tensorrt/codegen.h | 87 ++ .../msc/framework/tensorrt/codegen_utils.h | 118 +++ .../msc/framework/tensorrt/tensorrt_opcode.cc | 808 +++++++++++++++++ .../msc/framework/tensorrt/tensorrt_opcode.h | 125 +++ .../framework/tensorrt/transform_tensorrt.cc | 748 ++++++++++++++++ src/runtime/contrib/msc/tensorrt_runtime.cc | 293 +++++++ .../test_msc/test_translate_tensorrt.py | 815 ++++++++++++++++++ 28 files changed, 5536 insertions(+) create mode 100644 python/tvm/contrib/msc/framework/tensorrt/__init__.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/codegen/__init__.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/codegen/utils.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/frontend/__init__.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/transform/__init__.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/transform/transform.py create mode 100644 src/contrib/msc/core/codegen/codegen_json.cc create mode 100644 src/contrib/msc/core/codegen/codegen_json.h create mode 100644 src/contrib/msc/core/codegen/cpp_codegen.h create mode 100644 src/contrib/msc/core/printer/cpp_printer.cc create mode 100644 src/contrib/msc/core/printer/cpp_printer.h create mode 100644 src/contrib/msc/core/transform/fuse_tuple.cc create mode 100644 src/contrib/msc/framework/tensorrt/codegen.cc create mode 100644 src/contrib/msc/framework/tensorrt/codegen.h create mode 100644 src/contrib/msc/framework/tensorrt/codegen_utils.h create mode 100644 src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc create mode 100644 src/contrib/msc/framework/tensorrt/tensorrt_opcode.h create mode 100644 src/contrib/msc/framework/tensorrt/transform_tensorrt.cc create mode 100644 
src/runtime/contrib/msc/tensorrt_runtime.cc create mode 100644 tests/python/contrib/test_msc/test_translate_tensorrt.py diff --git a/cmake/modules/contrib/MSC.cmake b/cmake/modules/contrib/MSC.cmake index 45ce776a0864..d2dd6fc14fb1 100644 --- a/cmake/modules/contrib/MSC.cmake +++ b/cmake/modules/contrib/MSC.cmake @@ -22,5 +22,9 @@ if(USE_MSC) tvm_file_glob(GLOB_RECURSE MSC_RUNTIME_SOURCE "src/runtime/contrib/msc/*.cc") list(APPEND RUNTIME_SRCS ${MSC_RUNTIME_SOURCE}) + if(USE_TENSORRT_RUNTIME) + add_definitions("-DTENSORRT_ROOT_DIR=\"${TENSORRT_ROOT_DIR}\"") + endif() + message(STATUS "Build with MSC support...") endif() diff --git a/python/tvm/contrib/msc/core/ir/translate.py b/python/tvm/contrib/msc/core/ir/translate.py index f59b9f7ce888..b5bfa12b677a 100644 --- a/python/tvm/contrib/msc/core/ir/translate.py +++ b/python/tvm/contrib/msc/core/ir/translate.py @@ -314,6 +314,7 @@ def _partition_mod(mod, as_msc=True): passes.extend( [ msc_transform.BindShape(), + msc_transform.FuseTuple(target), tvm.relax.transform.MergeCompositeFunctions(), msc_transform.SetExprName(target=target), msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)), diff --git a/python/tvm/contrib/msc/core/transform/transform.py b/python/tvm/contrib/msc/core/transform/transform.py index 991e1bbf7cc6..24f7d38426f3 100644 --- a/python/tvm/contrib/msc/core/transform/transform.py +++ b/python/tvm/contrib/msc/core/transform/transform.py @@ -100,3 +100,21 @@ def BindShape(entry_name: str = "main") -> tvm.ir.transform.Pass: """ return relax_api.BindShape(entry_name) # type: ignore + + +def FuseTuple(target, entry_name: str = "main") -> tvm.ir.transform.Pass: + """Fuse Tuple and TupleGetItem to target + + Parameters + ---------- + target: str + The byoc target name + entry_name: str + The entry name + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + + return relax_api.FuseTuple(target, entry_name) # type: ignore diff --git a/python/tvm/contrib/msc/framework/tensorrt/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/__init__.py new file mode 100644 index 000000000000..a1c3d532efc6 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorrt""" diff --git a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py new file mode 100644 index 000000000000..c0fa9c2c0559 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorrt._ffi_api""" + +import tvm._ffi + +tvm._ffi._init_api("msc.framework.tensorrt", __name__) diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/__init__.py new file mode 100644 index 000000000000..618a178a2d5b --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorrt.codegen""" + +from .codegen import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py new file mode 100644 index 000000000000..574c2cc31b0a --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
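+# Usage sketch (illustrative only, not part of the public API contract): partition +# a relax module into msc_tensorrt sub-functions, then lower each sub graph to an +# engine. `mod` and `params` are assumed to be supplied by the caller: +# +#   from tvm.contrib.msc.framework.tensorrt.frontend.translate import partition_for_tensorrt +#   from tvm.contrib.msc.framework.tensorrt.codegen import to_tensorrt +# +#   mod, graph_infos = partition_for_tensorrt(mod, params) +#   mod = to_tensorrt(mod, graph_infos)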
+"""tvm.contrib.msc.framework.tensorrt.codegen.codegen""" + +import os +import subprocess +from typing import Dict, Optional, Tuple, List +import numpy as np + +import tvm +from tvm.contrib.msc.core.ir import MSCGraph +from tvm.contrib.msc.core.codegen import CodeGen +from tvm.contrib.msc.core import utils as msc_utils +from tvm.contrib.msc.core.utils import MSCFramework +from tvm.contrib.msc.framework.tensorrt import _ffi_api +from .sources import get_trt_sources +from .utils import write_weight + + +def to_sub_tensorrt( + graph: MSCGraph, + weights: Optional[Dict[str, tvm.nd.array]] = None, + codegen_config: Optional[Dict[str, str]] = None, + print_config: Optional[Dict[str, str]] = None, + build_folder: msc_utils.MSCDirectory = None, + output_folder: msc_utils.MSCDirectory = None, +) -> str: + """Change MSCGraph to TensorRT engine file. + + Parameters + ---------- + graph: tvm.contrib.msc.core.ir.MSCGraph + The translated graph. + weights: dict of + The parameters of the IRModule. + codegen_config: dict + The config for codegen. + print_config: dict + The config for print. + build_folder: MSCDirectory + The folder for saving sources and datas. + export_folder: MSCDirectory + The folder for saving outputs. + + Returns + ------- + engine: str + The engine file. + """ + + codegen_config = codegen_config or {} + codegen_config["version"] = msc_utils.get_version(MSCFramework.TENSORRT) + if "tensorrt_root" not in codegen_config: + codegen_config["tensorrt_root"] = _ffi_api.GetTensorRTRoot() + build_folder = build_folder or msc_utils.msc_dir(keep_history=False, cleanup=True) + output_folder = output_folder or msc_utils.msc_dir("msc_output") + + def _create_depends(folder: msc_utils.MSCDirectory) -> str: + if weights: + # fill fake weights + runtime_weights = weights + for node in graph.get_nodes(): + if node.optype in ("nn.conv2d", "msc.linear"): + weight = node.weight_at("weight") + bias = np.zeros([weight.dim_at("O")], dtype=weight.dtype_name) + runtime_weights[node.name + ".bias"] = bias + # write weights file + with open(folder.relpath(graph.name + ".wts"), "w") as f: + f.write("{}\n".format(len(runtime_weights))) + for name, data in runtime_weights.items(): + if isinstance(data, np.ndarray): + write_weight(name, data, f) + else: + write_weight(name, data.asnumpy(), f) + # save utils sources + with folder.create_dir("utils") as utils_folder: + for name, source in get_trt_sources().items(): + utils_folder.add_file(name, source) + + def _build_engine(engine_name: str, folder: msc_utils.MSCDirectory) -> str: + with open("engine.log", "w") as log_f: + process = subprocess.Popen("./" + engine_name, stdout=log_f, stderr=log_f, shell=True) + process.wait() + assert ( + process.returncode == 0 + ), "Failed to test engine {} under {}, check engine.log for detail".format( + engine_name, os.getcwd() + ) + return folder.move_file(engine_name + ".trt", output_folder.create_dir(graph.name)) + + codegen = CodeGen( + graph, + _ffi_api.GetTensorRTSources, + codegen_config, + print_config, + build_folder.create_dir(graph.name), + code_format="cpp", + ) + engine_file = codegen.load([], pre_load=_create_depends, post_load=_build_engine) + return { + "graph_json": graph.to_json(), + "engine": engine_file, + } + + +def to_tensorrt( + mod: tvm.IRModule, + graph_infos: List[Tuple[str, MSCGraph, Dict[str, tvm.nd.array]]], + codegen_config: Optional[Dict[str, str]] = None, + print_config: Optional[Dict[str, str]] = None, + build_folder: msc_utils.MSCDirectory = None, + output_folder: msc_utils.MSCDirectory = None, 
+) -> Dict[str, str]: + """Change all MSCGraphs to TensorRT engine files. + + Parameters + ---------- + mod: IRModule + The IRModule of relax. + graph_infos: list + The translated graphs, each element is a (name, graph, weights) tuple. + codegen_config: dict + The config for codegen. + print_config: dict + The config for print. + build_folder: MSCDirectory + The folder for saving sources and data. + output_folder: MSCDirectory + The folder for saving outputs. + + Returns + ------- + mod: IRModule + The IRModule after codegen for the msc_tensorrt target functions. + """ + + target_options = {} + for name, graph, weights in graph_infos: + options = to_sub_tensorrt( + graph, weights, codegen_config, print_config, build_folder, output_folder + ) + target_options[name] = msc_utils.dump_dict(options) + mod = tvm.transform.Sequential( + [ + tvm.relax.transform.RunCodegen({"msc_tensorrt": target_options}), + ] + )(mod) + return mod diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py new file mode 100644 index 000000000000..b6497e9258b7 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py @@ -0,0 +1,318 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.  The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.  You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorrt.codegen.sources""" + +from typing import Dict + +from tvm.contrib.msc.core.codegen import get_base_sources + + +def get_trt_common_h_code() -> str: + """Create trt_common header file codes + + Returns + ------- + source: str + The trt_common header source.
+ """ + + return """#ifndef TVM_CONTRIB_MSC_UTILS_TRT_COMMON_H_ +#define TVM_CONTRIB_MSC_UTILS_TRT_COMMON_H_ + +#include +#include +#include +#include +#include +#include + +#include "NvInfer.h" + +namespace tvm { +namespace contrib { +namespace msc { + +using namespace nvinfer1; + +#ifndef TRT_VERSION_GE +#define TRT_VERSION_GE(major, minor, patch) \\ + ((TRT_MAJOR > major) || (TRT_MAJOR == major && TRT_MINOR > minor) || \\ + (TRT_MAJOR == major && TRT_MINOR == minor && TRT_PATCH >= patch)) +#endif + +#if TRT_VERSION_GE(8, 0, 0) +#define TRT_NOEXCEPT noexcept +#else +#define TRT_NOEXCEPT +#endif + +#define CHECK(status) \\ + do { \\ + auto ret = (status); \\ + if (ret != 0) { \\ + std::cout << "Cuda failure: " << ret << std::endl; \\ + abort(); \\ + } \\ + } while (0) + +class TRTLogger : public ILogger { + public: + TRTLogger() : TRTLogger(Severity::kINFO) {} + explicit TRTLogger(Severity severity) { severity_ = severity; } + void log(Severity severity, const char* msg) noexcept override { + if (severity > severity_) return; + + switch (severity) { + case Severity::kINTERNAL_ERROR: + std::cout << "[MSC.INTERNAL_ERROR]: " << msg << std::endl; + break; + case Severity::kERROR: + std::cout << "[MSC.ERROR]: " << msg << std::endl; + break; + case Severity::kWARNING: + std::cout << "[MSC.WARNING]: " << msg << std::endl; + break; + case Severity::kINFO: + std::cout << "[MSC.INFO]: " << msg << std::endl; + break; + case Severity::kVERBOSE: + std::cout << "[MSC.VERBOSE]: " << msg << std::endl; + break; + default: + std::cout << "[MSC.UNKNOWN]: " << msg << std::endl; + break; + } + } + + void setLogSeverity(Severity severity) { severity_ = severity; } + + private: + Severity severity_; +}; + +struct InferDeleter { + template + void operator()(T* obj) const { + if (obj) { +#if TRT_VERSION_GE(8, 0, 0) + delete obj; +#else + obj->destroy(); +#endif + } + } +}; + +template +using TRTPtr = std::unique_ptr; + +class TRTUtils { + public: + static const std::string TensorInfo(ILayer* layer, size_t id = 0); + + static std::map LoadWeights(const std::string& file); + +#if TRT_VERSION_GE(6, 0, 0) + static bool SerializeEngineToFile(const std::string& file, TRTPtr& builder, + TRTPtr& network, + TRTPtr& config, TRTLogger& logger); +#else + static bool SerializeEngineToFile(const std::string& file, TRTPtr& builder, + TRTPtr& network, TRTLogger& logger); + +#endif + + static bool DeserializeEngineFromFile(const std::string& file, + std::shared_ptr& engine, TRTLogger& logger); +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm + +#endif // TVM_CONTRIB_MSC_UTILS_TRT_COMMON_H_ +""" + + +def get_trt_common_cc_code() -> str: + """Create trt_common cc file codes + + Returns + ------- + source: str + The trt_common cc source. 
+ """ + + return """#include "trt_common.h" + +namespace tvm { +namespace contrib { +namespace msc { + +const std::string TRTUtils::TensorInfo(ILayer* layer, size_t id) { + std::string info = "S:"; + Dims dims = layer->getOutput(id)->getDimensions(); + for (int i = 0; i < dims.nbDims; i++) { + info += std::to_string(dims.d[i]) + ';'; + } + DataType dtype = layer->getOutput(id)->getType(); + info += " D:"; + if (dtype == DataType::kFLOAT) { + info += "float32"; + } else if (dtype == DataType::kHALF) { + info += "float16"; + } else if (dtype == DataType::kINT32) { + info += "int32"; + } else if (dtype == DataType::kINT8) { + info += "int8"; + } else if (dtype == DataType::kBOOL) { + info += "bool"; + } else { + info += "unknown"; + } + return info; +} + +std::map TRTUtils::LoadWeights(const std::string& file) { + std::map weightMap; + // Open weights file + std::ifstream input(file, std::ios::binary); + assert(input.is_open() && ("Failed to open file " + file).c_str()); + + // Read number of weight blobs + int32_t count; + input >> count; + assert(count > 0 && "Invalid weight map file."); + std::cout << "Find " << count << " weigths in the file : " << file << std::endl; + + while (count--) { + Weights wt{DataType::kFLOAT, nullptr, 0}; + uint32_t type, size; + // Read name and type of blob + std::string name; + input >> name >> std::dec >> type >> size; + wt.type = static_cast(type); + + // Load blob + if (wt.type == DataType::kFLOAT) { + uint32_t* val = reinterpret_cast(malloc(sizeof(val) * size)); + for (uint32_t x = 0; x < size; ++x) { + input >> std::hex >> val[x]; + } + wt.values = val; + } else if (wt.type == DataType::kHALF) { + uint16_t* val = reinterpret_cast(malloc(sizeof(val) * size)); + for (uint32_t x = 0; x < size; ++x) { + input >> std::hex >> val[x]; + } + wt.values = val; + } + wt.count = size; + weightMap[name] = wt; + } + input.close(); + return weightMap; +} + +#if TRT_VERSION_GE(6, 0, 0) +bool TRTUtils::SerializeEngineToFile(const std::string& file, TRTPtr& builder, + TRTPtr& network, + TRTPtr& config, TRTLogger& logger) { +#if TRT_VERSION_GE(8, 0, 0) + auto plan = TRTPtr(builder->buildSerializedNetwork(*network, *config)); +#else + auto engine = TRTPtr(builder->buildEngineWithConfig(*network, *config)); + if (!engine) { + logger.log(ILogger::Severity::kERROR, "Failed to build engine"); + return false; + } + auto plan = TRTPtr(engine->serialize()); +#endif + if (!plan) { + logger.log(ILogger::Severity::kERROR, "Failed to serialize network"); + return false; + } + std::ofstream ofs(file, std::ios::out | std::ios::binary); + assert(ofs.is_open() && ("Failed to open file " + file).c_str()); + ofs.write((char*)(plan->data()), plan->size()); + ofs.close(); + return true; +} +#else +bool TRTUtils::SerializeEngineToFile(const std::string& file, TRTPtr& builder, + TRTPtr& network, TRTLogger& logger) { + auto engine = TRTPtr(builder->buildCudaEngine(*network)); + if (!engine) { + logger.log(ILogger::Severity::kERROR, "Failed to build engine"); + return false; + } + auto plan = TRTPtr(engine->serialize()); + if (!plan) { + logger.log(ILogger::Severity::kERROR, "Failed to serialize network"); + return false; + } + std::ofstream ofs(file, std::ios::out | std::ios::binary); + assert(ofs.is_open() && ("Failed to open file " + file).c_str()); + ofs.write((char*)(plan->data()), plan->size()); + ofs.close(); + return true; +} +#endif + +bool TRTUtils::DeserializeEngineFromFile(const std::string& file, + std::shared_ptr& engine, TRTLogger& logger) { + std::vector stream; + size_t size{0}; 
+ std::ifstream input(file, std::ifstream::binary); + assert(input.is_open() && ("Failed to open file " + file).c_str()); + if (input.good()) { + input.seekg(0, input.end); + size = input.tellg(); + input.seekg(0, input.beg); + stream.resize(size); + input.read(stream.data(), size); + input.close(); + } + logger.log(ILogger::Severity::kINFO, + ("size of engine from " + file + " is " + std::to_string(size)).c_str()); + auto runtime = TRTPtr<IRuntime>(createInferRuntime(logger)); + engine = std::shared_ptr<ICudaEngine>( + runtime->deserializeCudaEngine(stream.data(), size, nullptr), InferDeleter()); + return true; +} + +} // namespace msc +} // namespace contrib +} // namespace tvm +""" + + +def get_trt_sources() -> Dict[str, str]: + """Create trt sources for cpp codegen + + Returns + ------- + sources: dict + The trt utils sources. + """ + + sources = get_base_sources() + sources.update( + {"trt_common.h": get_trt_common_h_code(), "trt_common.cc": get_trt_common_cc_code()} + ) + return sources diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/utils.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/utils.py new file mode 100644 index 000000000000..803d5a2afa23 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/utils.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.  The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.  You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorrt.codegen.utils""" + +import io +import struct +import numpy as np + + +def enum_dtype(array: np.ndarray) -> int: + """Get TensorRT DType enum from array. + + Parameters + ---------- + array: np.ndarray + The source array. + + Returns + ------- + dtype: int + The dtype enum. + """ + + if array.dtype == np.float32: + return 0 + if array.dtype == np.float16: + return 1 + if array.dtype == np.int8: + return 2 + if array.dtype == np.int32: + return 3 + raise Exception("Unexpected dtype {}, no matching tensorrt dtype".format(array.dtype)) + + +def float_to_hex(value: float) -> str: + """Change float to hex. + + Parameters + ---------- + value: float + The float value. + + Returns + ------- + hex: str + The hex format string. + """ + + return hex(struct.unpack("<I", struct.pack("<f", value))[0]) + + +def array_to_hex(array: np.ndarray) -> str: + """Change array to hex. + + Parameters + ---------- + array: np.ndarray + The source array. + + Returns + ------- + hex: str + The hex format string. + """ + + return " ".join([float_to_hex(float(f))[2:] for f in array.flatten()]) + + +def write_weight(name: str, weight: np.ndarray, f_handler: io.TextIOWrapper): + """Write array to file in TensorRT format. + + Parameters + ---------- + name: str + The array name + weight: np.ndarray + The weight data.
+ f_handler: io.TextIOWrapper + The file handler + """ + + f_handler.write( + "{} {} {} {}\n".format(name, enum_dtype(weight), weight.size, array_to_hex(weight)) + ) diff --git a/python/tvm/contrib/msc/framework/tensorrt/frontend/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/frontend/__init__.py new file mode 100644 index 000000000000..85b163d7c667 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/frontend/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.  The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.  You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorrt.frontend""" diff --git a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py new file mode 100644 index 000000000000..845a66139645 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.  The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.  You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorrt.frontend.translate""" + +from typing import Dict, Optional, Tuple, List + +import tvm +from tvm import relax +from tvm.contrib.msc.core import transform as msc_transform +from tvm.contrib.msc.core.ir import MSCGraph, byoc_partition +from tvm.contrib.msc.framework.tensorrt import transform as trt_transform + + +def partition_for_tensorrt( + mod: tvm.IRModule, + params: Optional[Dict[str, tvm.nd.array]] = None, + trans_config: Optional[Dict[str, str]] = None, + build_config: Optional[Dict[str, str]] = None, + allow_incomplete: bool = True, +) -> Tuple[tvm.IRModule, List[Tuple[str, MSCGraph, Dict[str, tvm.nd.array]]]]: + """Partition module to tensorrt sub functions. + + Parameters + ---------- + mod: IRModule + The IRModule of relax. + params: dict of <string:tvm.nd.array> + The parameters of the IRModule. + trans_config: dict + The config for transform IRModule. + build_config: dict + The config for build MSCGraph. + allow_incomplete: bool + Whether to allow some ops not supported on tensorrt. + + Returns + ------- + mod: IRModule + The IRModule of partitioned relax. + graphs_info: list + The graph infos, one (name, MSCGraph, weights) tuple per sub graph.
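+ +    Examples +    -------- +    A minimal sketch (assumes a relax IRModule `mod` and its weights `params`): + +    .. code-block:: python + +        mod, graphs_info = partition_for_tensorrt(mod, params) +        for name, graph, weights in graphs_info: +            print("sub graph {}: {}".format(name, graph))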
+ """ + + trans_config = trans_config or {} + mod = tvm.transform.Sequential( + [ + msc_transform.SetExprName(), + trt_transform.TransformTensorRT(trans_config.get("version")), + relax.transform.FoldConstant(), + ] + )(mod) + return byoc_partition("msc_tensorrt", mod, params, trans_config, build_config, allow_incomplete) diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/transform/__init__.py new file mode 100644 index 000000000000..7bf054f5461b --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/transform/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorrt.transform""" + +from .pattern import * +from .transform import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py new file mode 100644 index 000000000000..fe42ee2ae7c7 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py @@ -0,0 +1,352 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""tvm.contrib.msc.framework.tensorrt.transform.pattern""" + +from typing import Mapping, Tuple, List, Union, Callable +from functools import wraps + +from tvm import relax +from tvm.relax.dpl import pattern +from tvm.relax.transform import PatternCheckContext, FusionPattern +from tvm.relax.backend.pattern_registry import register_patterns +from tvm.contrib.msc.core.transform import pattern as msc_pattern + + +def basic_pattern( + op_name: str, input_types: List[str] = None +) -> Tuple[pattern.DFPattern, Mapping[str, pattern.DFPattern]]: + """create basic pattern for tensorrt support ops. + + Parameters + ---------- + op_name: str + The name of a Relax op, such as "relax.nn.conv2d" + input_types: list + The input types, elach element can be input| constant + + Returns + ------- + out: tvm.relax.dpl.pattern.DFPattern + The resulting pattern describing the operation. 
+ + annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. + """ + + input_types = input_types or ["input"] + inputs = [] + for i_type in input_types: + if i_type == "input": + inputs.append(pattern.wildcard()) + elif i_type == "constant": + inputs.append(pattern.is_const()) + else: + raise Exception("Unexpected input type " + str(i_type)) + out = pattern.is_op(op_name)(*inputs) + annotations = {"input_" + str(idx): arg for idx, arg in enumerate(inputs)} + annotations["out"] = out + return out, annotations + + +def elemwise_pattern(op_name: str) -> Tuple[pattern.DFPattern, Mapping[str, pattern.DFPattern]]: + """Create elemwise pattern for tensorrt supported ops. + + Parameters + ---------- + op_name: str + The name of a Relax op, such as "relax.add" + + Returns + ------- + out: tvm.relax.dpl.pattern.DFPattern + The resulting pattern describing the operation. + + annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. + """ + + return basic_pattern(op_name, ["input", "input"]) + + +def argmaxmin_pattern(op_name: str) -> Tuple[pattern.DFPattern, Mapping[str, pattern.DFPattern]]: + """Create argmaxmin pattern for tensorrt supported ops. + + Parameters + ---------- + op_name: str + The name of a Relax op, such as "relax.argmax" + + Returns + ------- + out: tvm.relax.dpl.pattern.DFPattern + The resulting pattern describing the operation. + + annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. + """ + + data = pattern.wildcard() + argmaxmin = pattern.is_op(op_name)(data) + out = pattern.is_op("relax.astype")(argmaxmin) + return out, {"input": data, "argmaxmin": argmaxmin, "out": out} + + +def _check_expr(expr: relax.Expr, dtypes: Tuple[str] = None) -> bool: + """Check if the expr can be fused on tensorrt. + + Parameters + ---------- + expr: relax.Expr + The expr to be checked + dtypes: tuple + The accepted dtypes + + Returns + ------- + pass: bool + Whether the expr is correct. + """ + + if isinstance(expr, relax.ShapeExpr): + return True + if isinstance(expr, relax.PrimValue): + return True + if isinstance(expr, relax.Tuple): + return all(_check_expr(field) for field in expr.fields) + if any(i < 0 for i in expr.struct_info.shape.values): + return False + dtypes = dtypes or ("float32", "float16") + if expr.struct_info.dtype not in dtypes: + return False + return True + + +def _basic_check(context: PatternCheckContext) -> bool: + """Check if the basic pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + + for _, expr in context.annotated_expr.items(): + if not _check_expr(expr): + return False + return True + + +def _argmaxmin_check(context: PatternCheckContext) -> bool: + """Check if the argmaxmin pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + + if not _check_expr(context.annotated_expr["input"]): + return False + return _check_expr(context.annotated_expr["out"], ("int32",)) + + +def _compare_check(context: PatternCheckContext) -> bool: + """Check if the compare pattern is correct.
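+ +    Compare ops are fused only when both inputs pass the basic check, the +    output dtype is bool, and the two inputs share the same rank.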
+ + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + + if any(not _check_expr(context.annotated_expr[key]) for key in ["input_0", "input_1"]): + return False + if not _check_expr(context.annotated_expr["out"], ("bool",)): + return False + ndim_a = len(context.annotated_expr["input_0"].struct_info.shape.values) + ndim_b = len(context.annotated_expr["input_1"].struct_info.shape.values) + return ndim_a == ndim_b + + +def _elemwise_check(context: PatternCheckContext) -> bool: + """Check if the elemwise pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + + if not _basic_check(context): + return False + ndim_a = len(context.annotated_expr["input_0"].struct_info.shape.values) + ndim_b = len(context.annotated_expr["input_1"].struct_info.shape.values) + return ndim_a == ndim_b + + +def _reshape_check(context: PatternCheckContext) -> bool: + """Check if the reshape pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + + dtypes = ("float32", "float16", "int32") + if any(not _check_expr(context.annotated_expr[key], dtypes) for key in ["input_0", "out"]): + return False + return True + + +def _take_check(context: PatternCheckContext) -> bool: + """Check if the take pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + + if any(not _check_expr(context.annotated_expr[key]) for key in ["input_0", "out"]): + return False + return _check_expr(context.annotated_expr["input_1"], ("int32",)) + + +def wrap_basic_check( + func: Callable[[PatternCheckContext], bool] +) -> Callable[[PatternCheckContext], bool]: + """Wrap a checker with the basic check + + Returns + ------- + checker: Callable + The wrapped checker. + """ + + @wraps(func) + def wrapper(context): + if not _basic_check(context): + return False + return func(context) + + return wrapper + + +CheckFunc = Callable[[Mapping[pattern.DFPattern, relax.Expr], relax.Expr], bool] +Pattern = Union[ + FusionPattern, + Tuple[str, pattern.DFPattern], + Tuple[str, pattern.DFPattern, Mapping[str, pattern.DFPattern]], + Tuple[str, pattern.DFPattern, Mapping[str, pattern.DFPattern], CheckFunc], +] + + +def get_patterns(target) -> List[Pattern]: + """Get all the tensorrt patterns. + + Parameters + ---------- + target: str + The target name for tensorrt patterns. + + Returns + ------- + patterns: list + The patterns + """ + + basic_ops = { + "nn.adaptive_avg_pool2d": ["input"], + "nn.avg_pool2d": ["input"], + "nn.conv2d": ["input", "constant"], + "nn.max_pool2d": ["input"], + "concat": ["input"], + "clip": ["input", "input", "input"], + "image.resize2d": ["input", "input"], + "matmul": ["input", "input"], + "permute_dims": ["input"], + "strided_slice": ["input"], + } + activation_ops = ["nn.relu", "nn.softmax", "sigmoid", "tanh"] + reduce_ops = ["max", "min", "mean", "sum"] + unary_ops = ["cos", "exp", "negative", "round", "sin", "square", "sqrt", "tan"] + elemwise_ops = [ + "add", + "divide", + "floor_divide", + "maximum", + "minimum", + "multiply", + "power", + "subtract", + ] + compare_ops = ["greater", "less"] + patterns = [] + # basic ops + for op, in_types in basic_ops.items(): + patterns.append((target + "." + op, *basic_pattern("relax." + op, in_types), _basic_check)) + # activation ops + for op in activation_ops: + patterns.append((target + "." + op, *basic_pattern("relax." + op, ["input"]), _basic_check)) + # reduce ops + for op in reduce_ops: + patterns.append((target + "."
+ op, *basic_pattern("relax." + op, ["input"]), _basic_check)) + # unary ops + for op in unary_ops: + patterns.append((target + "." + op, *basic_pattern("relax." + op, ["input"]), _basic_check)) + # elemwise ops + for op in elemwise_ops: + patterns.append((target + "." + op, *elemwise_pattern("relax." + op), _elemwise_check)) + # compare ops + for op in compare_ops: + patterns.append((target + "." + op, *elemwise_pattern("relax." + op), _compare_check)) + + # special ops + patterns.extend( + [ + (target + ".take", *basic_pattern("relax.take", ["input", "input"]), _take_check), + (target + ".argmax", *argmaxmin_pattern("relax.argmax"), _argmaxmin_check), + (target + ".argmin", *argmaxmin_pattern("relax.argmin"), _argmaxmin_check), + ( + target + ".reshape", + *basic_pattern("relax.reshape", ["input", "input"]), + _reshape_check, + ), + ] + ) + # fusable ops + patterns.extend( + [ + ( + target + ".msc.conv2d_bias", + *msc_pattern.make_opt_relax_conv_bias_pattern("relax.nn.conv2d"), + wrap_basic_check(msc_pattern._check_opt_relax_conv_bias), + ), + ] + ) + return patterns + + +register_patterns(get_patterns("msc_tensorrt")) diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py b/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py new file mode 100644 index 000000000000..d6f15c43dacd --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""tvm.contrib.msc.framework.tensorrt.transform.transform""" + +from typing import List + +import tvm +from tvm.relax.transform import _ffi_api as relax_api +from tvm.contrib.msc.core.utils import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils + + +def TransformTensorRT(version: List[int] = None) -> tvm.ir.transform.Pass: + """Transform the Function to fit TensorRT. + + Parameters + ---------- + version: list + The tensorrt version. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + + version = version or msc_utils.get_version(MSCFramework.TENSORRT) + return relax_api.TransformTensorRT(version) # type: ignore diff --git a/src/contrib/msc/core/codegen/codegen_json.cc b/src/contrib/msc/core/codegen/codegen_json.cc new file mode 100644 index 000000000000..7bbe576b6bfe --- /dev/null +++ b/src/contrib/msc/core/codegen/codegen_json.cc @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.  See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/codegen/codegen_json.cc + */ + +#include "codegen_json.h" + +#include <memory> + +namespace tvm { +namespace contrib { +namespace msc { + +std::vector<JSONGraphNodeEntry> MSCJSONSerializer::VisitExpr_(const CallNode* call_node) { + const auto& ref_node = graph_->FindNode(SpanUtils::GetAttr(call_node->span, "name")); + std::vector<JSONGraphNodeEntry> inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared<JSONGraphNode>(ref_node->name, "kernel", inputs, + ref_node->outputs.size()); + // add attributes + AddNodeAttr(node, "optype", ref_node->optype); + for (const auto& pair : ref_node->attrs) { + AddNodeAttr(node, pair.first, pair.second); + } + if (!global_options_set_) { + AddNodeAttr(node, "msc_global_options_num", std::to_string(options_.size())); + for (const auto& pair : options_) { + AddNodeAttr(node, "msc_global_" + pair.first, pair.second); + } + global_options_set_ = true; + } + return AddNode(node, GetRef<Expr>(call_node)); +} + +void MSCJSONSerializer::AddNodeAttr(JSONGraphObjectPtr node, const String& key, + const String& value) { + std::vector<std::string> array_value{std::string(value)}; + std::vector<dmlc::any> dmlc_value; + dmlc_value.emplace_back(array_value); + node->SetAttr(std::string(key), dmlc_value); +} + +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/msc/core/codegen/codegen_json.h b/src/contrib/msc/core/codegen/codegen_json.h new file mode 100644 index 000000000000..dfc2d699a968 --- /dev/null +++ b/src/contrib/msc/core/codegen/codegen_json.h @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.  The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.  You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.  See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/codegen/codegen_json.h + * \brief Basic JSONSerializer for MSC runnable BYOC.
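+ * + * The serializer is configured by a JSON options string; a minimal sketch of + * the expected shape (any key other than "graph_json" is kept as an option): + *   {"graph_json": "<serialized MSCGraph>", "some_option": "some_value"}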
+ */ +#ifndef TVM_CONTRIB_MSC_CORE_CODEGEN_CODEGEN_JSON_H_ +#define TVM_CONTRIB_MSC_CORE_CODEGEN_CODEGEN_JSON_H_ + +#include <string> +#include <unordered_map> +#include <vector> + +#include "../../../../relax/backend/contrib/codegen_json/codegen_json.h" +#include "../ir/graph.h" + +namespace tvm { +namespace contrib { +namespace msc { + +using namespace tvm::relax; + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; +using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr; +using JSONSerializer = backend::contrib::JSONSerializer; + +/*! + * \brief MSCCompileConfig defines config for all BYOC + */ +struct MSCCompileConfig { + std::string graph_json; + std::unordered_map<std::string, std::string> options; + void Load(dmlc::JSONReader* reader) { + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "graph_json") { + reader->Read(&graph_json); + } else { + std::string value; + reader->Read(&value); + options.insert({key, value}); + } + } + } +}; + +class MSCJSONSerializer : public JSONSerializer { + public: + /*! + * \brief Constructor + * \param constant_names The names of all constants in the original module. + */ + explicit MSCJSONSerializer(const Map<Constant, String>& constant_names, + const std::string& options) + : JSONSerializer(constant_names) { + MSCCompileConfig config; + std::istringstream is(options); + dmlc::JSONReader reader(&is); + reader.Read(&config); + ICHECK(config.graph_json.size() > 0) << "graph_json is needed to init MSCGraph"; + graph_ = MSCGraph(config.graph_json); + for (const auto& pair : config.options) { + options_.Set(pair.first, pair.second); + } + global_options_set_ = false; + } + + std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* call_node) final; + + const String GetOption(const String& key) { + ICHECK(options_.count(key)) << "Can not find option " << key; + return options_[key]; + } + + const Map<String, String> GetOptions() { return options_; } + + protected: + void AddNodeAttr(JSONGraphObjectPtr node, const String& key, const String& value); + + private: + MSCGraph graph_; + Map<String, String> options_; + bool global_options_set_; +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif  // TVM_CONTRIB_MSC_CORE_CODEGEN_CODEGEN_JSON_H_ diff --git a/src/contrib/msc/core/codegen/cpp_codegen.h b/src/contrib/msc/core/codegen/cpp_codegen.h new file mode 100644 index 000000000000..0f4f68c63669 --- /dev/null +++ b/src/contrib/msc/core/codegen/cpp_codegen.h @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.  The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.  You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.  See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/codegen/cpp_codegen.h + * \brief CPP codegen for MSCGraph.
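+ * + * GetSources() emits four files per graph: "<graph>.h" (class declare), + * "<graph>.cc" (class define), "main.cc" and "CMakeLists.txt".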
+ */ +#ifndef TVM_CONTRIB_MSC_CORE_CODEGEN_CPP_CODEGEN_H_ +#define TVM_CONTRIB_MSC_CORE_CODEGEN_CPP_CODEGEN_H_ + +#include <string> +#include <vector> + +#include <set> + +#include "../printer/cpp_printer.h" +#include "base_codegen.h" +#include "code_stack.h" +#include "codegen_utils.h" + +namespace tvm { +namespace contrib { +namespace msc { + +using namespace tvm::script::printer; + +template <typename ConfigType, typename HelperType> +class CppCodeGen : public BaseCodeGen<ConfigType, HelperType> { + public: + /*! + * \brief The constructor of CppCodeGen + * \param graph the graph to be generated. + * \param config the options for codegen. + */ + explicit CppCodeGen(const MSCGraph& graph, const std::string& config = "") + : BaseCodeGen<ConfigType, HelperType>(graph, config) {} + + /*! \brief Stack the docs for the class declare*/ + virtual void CodeGenClassDeclare() = 0; + + /*! \brief Stack the docs for the class define*/ + virtual void CodeGenClassDefine() = 0; + + /*! \brief Stack the docs for the main func*/ + virtual void CodeGenMain() = 0; + + /*! \brief Stack the docs for the cmake file*/ + virtual void CodeGenCmake() = 0; + + /*! \brief Get sources*/ + virtual const Map<String, String> GetSources(const std::string& print_options = "") { + Map<String, String> sources; + auto add_source = [&print_options, &sources, this](const String& file) { + CppPrinter printer(print_options); + for (const auto& d : this->stack_.GetDocs()) { + printer.Append(d); + } + sources.Set(file, printer.GetString()); + this->stack_.Reset(); + }; + // class declare + CodeGenClassDeclare(); + add_source(this->graph()->name + ".h"); + // class define + CodeGenClassDefine(); + add_source(this->graph()->name + ".cc"); + // main func + CodeGenMain(); + add_source("main.cc"); + // cmakelists + CodeGenCmake(); + add_source("CMakeLists.txt"); + return sources; + } + + protected: + void StartNamespace() { + this->stack_.line("namespace tvm {").line("namespace contrib {").line("namespace msc {").line(); + } + + void EndNamespace() { + this->stack_.line() + .line("}  // namespace msc") + .line("}  // namespace contrib") + .line("}  // namespace tvm") + .line(); + } +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif  // TVM_CONTRIB_MSC_CORE_CODEGEN_CPP_CODEGEN_H_ diff --git a/src/contrib/msc/core/printer/cpp_printer.cc b/src/contrib/msc/core/printer/cpp_printer.cc new file mode 100644 index 000000000000..728f609c00df --- /dev/null +++ b/src/contrib/msc/core/printer/cpp_printer.cc @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.  The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.  You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.  See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*!
+ * \file src/contrib/msc/core/printer/cpp_printer.cc + */ + +#include "cpp_printer.h" + +namespace tvm { +namespace contrib { +namespace msc { + +void CppPrinter::PrintTypedDoc(const LiteralDoc& doc) { + const ObjectRef& value = doc->value; + bool defined = false; + if (!value.defined()) { + output_ << "nullptr"; + defined = true; + } else if (const auto* int_imm = value.as<IntImmNode>()) { + if (int_imm->dtype.is_bool()) { + output_ << (int_imm->value ? "true" : "false"); + defined = true; + } + } + if (!defined) { + MSCBasePrinter::PrintTypedDoc(doc); + } +} + +void CppPrinter::PrintTypedDoc(const IndexDoc& doc) { + ICHECK(doc->indices.size() == 1) << "CppPrinter only supports a single index"; + PrintDoc(doc->value, false); + output_ << "["; + PrintDoc(doc->indices[0], false); + output_ << "]"; +} + +void CppPrinter::PrintTypedDoc(const AttrAccessDoc& doc) { + PrintDoc(doc->value, false); + if (!doc->value->IsInstance<PointerDocNode>()) { + output_ << "."; + } + output_ << doc->name; +} + +void CppPrinter::PrintTypedDoc(const CallDoc& doc) { + EnterEndlineScope(false); + PrintDoc(doc->callee, false); + output_ << "("; + PrintJoinedDocs(doc->args); + ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) + << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values."; + if (doc->args.size() > 0 && doc->kwargs_keys.size() > 0) { + output_ << ", "; + } + PrintJoinedDocs(doc->kwargs_values); + output_ << ")"; + ExitEndlineScope(); + Endline(); +} + +void CppPrinter::PrintTypedDoc(const AssignDoc& doc) { + ICHECK(doc->lhs.defined()) << "lhs should be given for assign"; + if (doc->annotation.defined()) { + PrintDoc(doc->annotation.value(), false); + output_ << " "; + } + PrintDoc(doc->lhs, false); + if (doc->rhs.defined()) { + output_ << " = "; + EnterEndlineScope(false); + PrintDoc(doc->rhs.value(), false); + ExitEndlineScope(); + Endline(); + } +} + +void CppPrinter::PrintTypedDoc(const IfDoc& doc) { + MaybePrintComment(doc, true); + output_ << "if ("; + PrintDoc(doc->predicate, false); + output_ << ") {"; + PrintIndentedBlock(doc->then_branch); + if (!doc->else_branch.empty()) { + NewLine(); + output_ << "} else {"; + PrintIndentedBlock(doc->else_branch); + } + NewLine(); + output_ << "}"; +} + +void CppPrinter::PrintTypedDoc(const WhileDoc& doc) { + MaybePrintComment(doc, true); + output_ << "while ("; + PrintDoc(doc->predicate, false); + output_ << ") {"; + PrintIndentedBlock(doc->body); + NewLine(); + output_ << "}"; +} + +void CppPrinter::PrintTypedDoc(const ForDoc& doc) { + MaybePrintComment(doc, true); + if (doc->rhs->IsInstance<TupleDocNode>()) { + const auto& tuple = Downcast<TupleDoc>(doc->rhs); + ICHECK_EQ(tuple->elements.size(), 2) << "For with tuple should have 2 elements"; + output_ << "for (size_t "; + PrintDoc(doc->lhs, false); + output_ << " = "; + PrintDoc(tuple->elements[0], false); + output_ << "; "; + PrintDoc(doc->lhs, false); + output_ << " < "; + PrintDoc(tuple->elements[1], false); + output_ << "; "; + PrintDoc(doc->lhs, false); + output_ << "++"; + } else { + output_ << "for (const auto& "; + PrintDoc(doc->lhs, false); + output_ << " : "; + PrintDoc(doc->rhs, false); + } + output_ << ") {"; + PrintIndentedBlock(doc->body); + NewLine(); + output_ << "}"; +} + +void CppPrinter::PrintTypedDoc(const ScopeDoc& doc) { + MaybePrintComment(doc, true); + ICHECK(doc->rhs.defined()) << "rhs should be given for scope"; + PrintDoc(doc->rhs); + PrintIndentedBlock(doc->body); +} + +void CppPrinter::PrintTypedDoc(const FunctionDoc& doc) { + MaybePrintComment(doc, true); + for (const AssignDoc& arg_doc :
doc->args) { + ICHECK(arg_doc->comment == nullptr) << "Function args cannot have comments attached to them."; + } + if (doc->return_type.defined()) { + PrintDoc(doc->return_type.value(), false); + } else { + output_ << "void"; + } + output_ << " "; + PrintDoc(doc->name, false); + output_ << "("; + PrintJoinedDocs(doc->args, ", "); + output_ << ")"; + if (doc->body.size() > 0) { + output_ << " {"; + PrintIndentedBlock(doc->body); + if (doc->return_type.defined()) { + Endline(); + } + NewLine(); + output_ << "}"; + } else { + Endline(); + } + NewLine(false); +} + +void CppPrinter::PrintTypedDoc(const ClassDoc& doc) { + MaybePrintComment(doc, true); + output_ << "class "; + PrintDoc(doc->name, false); + output_ << " {"; + for (const StmtDoc& d : doc->body) { + PrintDoc(d); + } + NewLine(false); + output_ << "}"; + Endline(); +} + +void CppPrinter::PrintTypedDoc(const CommentDoc& doc) { + if (doc->comment.defined()) { + output_ << "// " << doc->comment.value(); + } +} + +void CppPrinter::PrintTypedDoc(const DeclareDoc& doc) { + if (doc->type.defined()) { + PrintDoc(doc->type.value(), false); + output_ << " "; + } + PrintDoc(doc->variable, false); + if (doc->init_args.size() > 0) { + if (doc->use_constructor) { + output_ << "("; + PrintJoinedDocs(doc->init_args, ", "); + output_ << ")"; + } else { + output_ << "{"; + PrintJoinedDocs(doc->init_args, ", "); + output_ << "}"; + } + } + Endline(); +} + +void CppPrinter::PrintTypedDoc(const PointerDoc& doc) { output_ << doc->name << "->"; } + +void CppPrinter::PrintTypedDoc(const StrictListDoc& doc) { + if (doc->allow_empty || doc->list->elements.size() > 0) { + PrintDoc(doc->list, false); + } else { + output_ << "{}"; + } +} + +void CppPrinter::PrintIndentedBlock(const Array<StmtDoc>& docs) { + IncreaseIndent(); + for (const StmtDoc& d : docs) { + PrintDoc(d); + } + DecreaseIndent(); +} + +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/msc/core/printer/cpp_printer.h b/src/contrib/msc/core/printer/cpp_printer.h new file mode 100644 index 000000000000..870ff517f61a --- /dev/null +++ b/src/contrib/msc/core/printer/cpp_printer.h @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.  The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.  You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.  See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/printer/cpp_printer.h + * \brief Cpp Printer. + */ + +#ifndef TVM_CONTRIB_MSC_CORE_PRINTER_CPP_PRINTER_H_ +#define TVM_CONTRIB_MSC_CORE_PRINTER_CPP_PRINTER_H_ + +#include <string> +#include <vector> + +#include "msc_base_printer.h" + +namespace tvm { +namespace contrib { +namespace msc { + +using namespace tvm::script::printer; + +/*! + * \brief CppPrinter change list of docs to cpp format + * \sa Doc + */ +class CppPrinter : public MSCBasePrinter { + public: + /*! + * \brief The constructor of CppPrinter + * \param options the options for printer.
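+ * The printer keeps a stack of "endline" scopes: statements emitted while + * the innermost scope is enabled are terminated with ";".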
+ */ + explicit CppPrinter(const std::string& options = "") : MSCBasePrinter(options) { + endlines_.push_back(true); + } + + protected: + /*! * \brief Print a LiteralDoc to python format*/ + void PrintTypedDoc(const LiteralDoc& doc) final; + + /*! \brief Virtual method to print an IndexDoc*/ + void PrintTypedDoc(const IndexDoc& doc) final; + + /*! * \brief Print a AttrAccessDoc to python format*/ + void PrintTypedDoc(const AttrAccessDoc& doc) final; + + /*! * \brief Print a CallDoc to python format*/ + void PrintTypedDoc(const CallDoc& doc) final; + + /*! * \brief Print a AssignDoc to python format*/ + void PrintTypedDoc(const AssignDoc& doc) final; + + /*! * \brief Print a IfDoc to python format*/ + void PrintTypedDoc(const IfDoc& doc) final; + + /*! * \brief Print a WhileDoc to python format*/ + void PrintTypedDoc(const WhileDoc& doc) final; + + /*! \brief Virtual method to print a ForDoc*/ + void PrintTypedDoc(const ForDoc& doc) final; + + /*! * \brief Print a ScopeDoc to python format*/ + void PrintTypedDoc(const ScopeDoc& doc) final; + + /*! * \brief Print a FunctionDoc to python format*/ + void PrintTypedDoc(const FunctionDoc& doc) final; + + /*! * \brief Print a ClassDoc to python format*/ + void PrintTypedDoc(const ClassDoc& doc) final; + + /*! * \brief Print a CommentDoc to python format*/ + void PrintTypedDoc(const CommentDoc& doc) final; + + /*! \brief Virtual method to print a DeclareDoc*/ + void PrintTypedDoc(const DeclareDoc& doc) final; + + /*! \brief Virtual method to print a PointerDoc*/ + void PrintTypedDoc(const PointerDoc& doc) final; + + /*! \brief Virtual method to print a StrictListDoc*/ + void PrintTypedDoc(const StrictListDoc& doc) final; + + private: + /*! \brief endline scopes*/ + std::vector endlines_; + + /*! \brief Enter a endline scope*/ + void EnterEndlineScope(bool endline = false) { endlines_.push_back(endline); } + + /*! \brief Exit a endline scope*/ + void ExitEndlineScope() { + ICHECK(endlines_.size() > 1) << "No endline scope found"; + endlines_.pop_back(); + } + + /*! \brief enable enbline*/ + void EnableEndline() { + ICHECK(endlines_.size() > 0) << "No endline scope found"; + endlines_[endlines_.size() - 1] = true; + } + + /*! \brief disable enbline*/ + void DisableEndline() { + ICHECK(endlines_.size() > 0) << "No endline scope found"; + endlines_[endlines_.size() - 1] = false; + } + + /*! \brief Print endline*/ + void Endline() { + ICHECK(endlines_.size() > 0) << "No endline scope found"; + if (endlines_[endlines_.size() - 1]) { + output_ << ";"; + } + } + + /*! \brief Print block with indent*/ + void PrintIndentedBlock(const Array& docs); +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm + +#endif // TVM_CONTRIB_MSC_CORE_PRINTER_CPP_PRINTER_H_ diff --git a/src/contrib/msc/core/transform/fuse_tuple.cc b/src/contrib/msc/core/transform/fuse_tuple.cc new file mode 100644 index 000000000000..e18d2cc35fe1 --- /dev/null +++ b/src/contrib/msc/core/transform/fuse_tuple.cc @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/fuse_tuple.cc
+ * \brief Pass for fusing Tuple and TupleGetItem to the BYOC target.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/var.h>
+
+#include "../../../../relax/transform/utils.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relax {
+
+using namespace tvm::contrib::msc;
+
+/*!
+ * \brief Fuse Tuple and TupleGetItem to BYOC
+ */
+class TupleFuser : public ExprMutator {
+ public:
+  explicit TupleFuser(IRModule ctx_module, const String& target, const String& entry_name)
+      : ExprMutator(ctx_module) {
+    mod_ = ctx_module;
+    target_ = target + ".";
+    entry_name_ = entry_name;
+  }
+
+  IRModule Fuse() {
+    GlobalVar main_var;
+    for (const auto& [gv, func] : mod_->functions) {
+      if (gv->name_hint == entry_name_) {
+        main_var = gv;
+      } else {
+        const auto& name_opt = func->GetAttr<String>(attr::kComposite);
+        if (name_opt.defined() && StringUtils::StartsWith(name_opt.value(), target_)) {
+          target_funcs_.Set(gv, Downcast<Function>(func));
+        }
+      }
+    }
+    // update main
+    ICHECK(main_var.defined()) << "Cannot find entry func " << entry_name_;
+    const auto& new_func = Downcast<Function>(VisitExpr(mod_->Lookup(entry_name_)));
+    builder_->UpdateFunction(main_var, new_func);
+    return builder_->GetContextIRModule();
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* val) final {
+    bool is_tuple_call = false;
+    if (target_funcs_.count(val->op)) {
+      if (val->args.size() == 1 && val->args[0]->IsInstance<TupleNode>()) {
+        const auto& func_call = AddFunc(val->args[0]);
+        const auto& tuple_out = builder_->Emit(func_call);
+        ICHECK(target_funcs_.count(func_call->op)) << "Cannot find target func " << func_call->op;
+        target_funcs_.Set(tuple_out, target_funcs_[func_call->op]);
+        const auto& new_call = Call(val->op, {tuple_out}, val->attrs, val->sinfo_args, val->span);
+        ReEmitBinding(binding, builder_->Normalize(new_call));
+        is_tuple_call = true;
+      }
+      target_funcs_.Set(binding->var, target_funcs_[val->op]);
+    }
+    if (!is_tuple_call) {
+      ExprMutator::VisitBinding_(binding, val);
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final {
+    bool on_target = true;
+    for (const auto& f : val->fields) {
+      if (!target_funcs_.count(f)) {
+        on_target = false;
+        break;
+      }
+    }
+    if (on_target) {
+      ReEmitFunc(binding, GetRef<Tuple>(val));
+    } else {
+      ExprMutator::VisitBinding_(binding, val);
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final {
+    if (target_funcs_.count(val->tuple)) {
+      ReEmitFunc(binding, GetRef<TupleGetItem>(val));
+    } else {
+      ExprMutator::VisitBinding_(binding, val);
+    }
+  }
+
+ private:
+  Call AddFunc(const Expr& expr) {
+    builder_->BeginDataflowBlock();
+    Array<Expr> inputs;
+    if (const auto* v_node = expr.as<TupleNode>()) {
+      inputs = v_node->fields;
+    } else if (const auto* g_node = expr.as<TupleGetItemNode>()) {
+      inputs = {g_node->tuple};
+    } else {
+      LOG_FATAL << "Unexpected expr " << expr;
+    }
+    Array<Expr> func_inputs;
+    Array<Expr> call_inputs;
+    Array<Var> params;
+    Map<Expr, Var> added_params;
+    for (size_t i = 0; i < inputs.size(); i++) {
+      if (!added_params.count(inputs[i])) {
+        const auto& name = String("param_" + std::to_string(i));
+        const auto& var = Var(std::move(name), GetStructInfo(inputs[i]));
+        added_params.Set(inputs[i], var);
+      }
+      call_inputs.push_back(inputs[i]);
+      func_inputs.push_back(added_params[inputs[i]]);
+      params.push_back(added_params[inputs[i]]);
+    }
+
+    Expr out_expr;
+    String func_name;
+    if (expr->IsInstance<TupleNode>()) {
+      out_expr = Tuple(func_inputs, expr->span);
+      func_name = "tuple";
+    } else if (const auto* g_node = expr.as<TupleGetItemNode>()) {
+      out_expr = TupleGetItem(func_inputs[0], g_node->index, expr->span);
+      func_name = "get_item";
+    } else {
+      LOG_FATAL << "Unexpected expr " << expr;
+    }
+
+    const auto& output = builder_->EmitOutput(out_expr);
+    BindingBlock new_block = builder_->EndBlock();
+    Expr body = builder_->Normalize(output);
+    body = builder_->Normalize(SeqExpr({new_block}, body));
+    Map<String, ObjectRef> func_attrs;
+    func_attrs.Set(tvm::relax::attr::kComposite, target_ + func_name);
+    func_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
+    Function function = Function(/*params=*/params,            //
+                                 /*body=*/body,                //
+                                 /*ret_struct_info=*/NullOpt,  //
+                                 /*is_pure=*/true,             //
+                                 /*attrs=*/DictAttrs(func_attrs));
+    Array<PrimExpr> free_vars =
+        FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; });
+    if (!free_vars.empty()) {
+      params.push_back(Var("tir_vars", ShapeStructInfo(free_vars)));
+      function = Function(/*params=*/params,            //
+                          /*body=*/body,                //
+                          /*ret_struct_info=*/NullOpt,  //
+                          /*is_pure=*/true,             //
+                          /*attrs=*/DictAttrs(func_attrs));
+    }
+    function = SymbolicVarRenewMutator::Renew(function);
+    GlobalVar gv = builder_->AddFunction(function, "fused_" + func_name);
+    target_funcs_.Set(gv, function);
+    return Call(gv, call_inputs);
+  }
+
+  void ReEmitFunc(const VarBindingNode* binding, const Expr& expr) {
+    const auto& func_call = AddFunc(expr);
+    ReEmitBinding(binding, builder_->Normalize(func_call));
+    ICHECK(target_funcs_.count(func_call->op)) << "Cannot find target func " << func_call->op;
+    target_funcs_.Set(binding->var, target_funcs_[func_call->op]);
+  }
+
+  IRModule mod_;
+  String target_;
+  String entry_name_;
+  Map<Expr, Function> target_funcs_;
+};
+
+IRModule FuseTuple(IRModule mod, const String& target, const String& entry_name) {
+  return TupleFuser(mod, target, entry_name).Fuse();
+}
+
+namespace transform {
+
+Pass FuseTuple(const String& target, const String& entry_name) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule m, PassContext pc) { return relax::FuseTuple(m, target, entry_name); };
+  return CreateModulePass(pass_func, 0, "FuseTuple", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.FuseTuple").set_body_typed(FuseTuple);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc
new file mode 100644
index 000000000000..b8b2335da16f
--- /dev/null
+++ b/src/contrib/msc/framework/tensorrt/codegen.cc
@@ -0,0 +1,515 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/framework/tensorrt/codegen.cc
+ * \brief Codegen related classes.
+ */
+
+#include "codegen.h"
+
+#include <memory>
+#include <string>
+
+#include "../../core/codegen/codegen_json.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+using namespace tvm::relax;
+
+void TensorRTCodeGen::CodeGenClassDeclare() {
+  stack_.line("#include \"NvInfer.h\"")
+      .line("#include \"NvInferRuntimeCommon.h\"")
+      .line("#include \"utils/base.h\"")
+      .line("#include \"utils/trt_common.h\"")
+      .line()
+      .line("using namespace nvinfer1;")
+      .line();
+  StartNamespace();
+  // start class declare
+  stack_.class_def(graph()->name).class_start().scope_start("public:");
+  // declare build method
+  stack_.func_def("Build", "bool")
+      .func_arg("builder", "TRTPtr<IBuilder>&")
+      .func_arg("network", "TRTPtr<INetworkDefinition>&");
+  if (CompareVersion(6, 0, 0) >= 0) {
+    stack_.func_arg("config", "TRTPtr<IBuilderConfig>&");
+  }
+  stack_.func_arg("logger", "TRTLogger&").func_start().func_end();
+  // define cleanup method
+  stack_.func_def("CleanUp", "bool")
+      .func_start()
+      .for_start("mem", "mWeights")
+      .func_call("free")
+      .call_arg("(void*) (mem.second.values)")
+      .for_end()
+      .func_end("true");
+  // end public scope
+  stack_.scope_end();
+  // private scope
+  stack_.scope_start("private:").declare("std::map<std::string, Weights>", "mWeights").scope_end();
+  // end class declare
+  stack_.class_end().line();
+  // declare test function
+  stack_.func_def("test_" + graph()->name, "bool")
+      .func_arg("engine", "std::shared_ptr<ICudaEngine>&")
+      .func_arg("reader", "DatasetReader&")
+      .func_arg("logger", "TRTLogger&")
+      .func_start()
+      .func_end();
+  EndNamespace();
+}
+
+void TensorRTCodeGen::CodeGenClassDefine() {
+  auto malloc_buffer = [this](const MSCTensor& tensor) {
+    const String& idx_var = "idx_" + IdxTensor(tensor);
+    this->stack_
+        .func_call("getBindingIndex", DocUtils::ToDeclareDoc("int", idx_var),
+                   DocUtils::ToPtrDoc("engine"))
+        .call_arg(DocUtils::ToStrDoc(tensor->name))
+        .func_call("CHECK")
+        .func_call("cudaMalloc")
+        .call_arg("&gpu_buffers[" + idx_var + "]")
+        .call_arg(GetTensorBytes(tensor))
+        .pop_nest()
+        .func_call("malloc", "cpu_buffers[" + idx_var + "]")
+        .call_arg(GetTensorBytes(tensor));
+  };
+
+  stack_.line("#include \"" + graph()->name + ".h\"").line();
+  StartNamespace();
+  // start define build method
+  stack_.func_def(graph()->name + "::Build", "bool")
+      .func_arg("builder", "TRTPtr<IBuilder>&")
+      .func_arg("network", "TRTPtr<INetworkDefinition>&");
+  if (CompareVersion(6, 0, 0) >= 0) {
+    stack_.func_arg("config", "TRTPtr<IBuilderConfig>&");
+  }
+  stack_.func_arg("logger", "TRTLogger&").func_start();
+  if (graph()->weight_holders.size() > 0) {
+    stack_.assign("mWeights", "TRTUtils::LoadWeights(\"" + graph()->name + ".wts\")");
+  }
+  // build layers
+  for (const auto& n : graph()->node_names) {
+    const auto& node = graph()->FindNode(n);
+    for (const auto& d : GetOpCodes(node)) {
+      stack_.line(d);
+    }
+  }
+  // mark outputs
+  stack_.comment("Mark outputs");
+  for (const auto& o : graph()->GetOutputs()) {
+    const auto& pair = graph()->FindProducerAndIdx(o);
+    stack_.func_call("markOutput", NullOpt, DocUtils::ToPtrDoc("network"))
+        .call_arg("*" + IdxOutputBase(pair.first, pair.second));
+  }
+  // mark batch_size
+  stack_.comment("Mark batch size");
+  stack_.func_call("createOptimizationProfile", DocUtils::ToDeclareDoc("auto", "profile"),
+                   DocUtils::ToPtrDoc("builder"));
+  Array<String> batch_flags{"MIN", "MAX", "OPT"};
+  for (const auto& i : graph()->GetInputs()) {
+    for (const auto& f : batch_flags) {
+      stack_.func_call("setDimensions", NullOpt, DocUtils::ToPtrDoc("profile"))
+          .call_arg(DocUtils::ToStrDoc(i->name))
+          .call_arg("OptProfileSelector::k" + f)
+          .call_arg(ToDims(i->shape));
+    }
+  }
+  // set max workspace
+  stack_.comment("Set max workspace");
+  if (CompareVersion(6, 0, 0) >= 0) {
+    stack_.func_call("setMaxWorkspaceSize", NullOpt, DocUtils::ToPtrDoc("config"))
+        .call_arg(config()->max_workspace);
+  } else {
+    stack_.func_call("setMaxWorkspaceSize", NullOpt, DocUtils::ToPtrDoc("builder"))
+        .call_arg(config()->max_workspace);
+  }
+  // end define build method
+  stack_.func_end("true");
+  // start define test function
+  stack_.func_def("test_" + graph()->name, "bool")
+      .func_arg("engine", "std::shared_ptr<ICudaEngine>&")
+      .func_arg("reader", "DatasetReader&")
+      .func_arg("logger", "TRTLogger&")
+      .func_start();
+  stack_.comment("Create context")
+      .func_call("TRTPtr<IExecutionContext>", DocUtils::ToDeclareDoc("auto", "context"))
+      .func_call("createExecutionContext", NullOpt, DocUtils::ToPtrDoc("engine"))
+      .pop_nest();
+  ReturnOnFail("context", "Failed to create the context");
+  // prepare variables
+  stack_.declare("bool", "pass", 0, false)
+      .declare_arg("true")
+      .declare("cudaStream_t", "stream")
+      .func_call("CHECK")
+      .func_call("cudaStreamCreate")
+      .call_arg("&stream")
+      .pop_nest();
+  // malloc buffers
+  size_t binding_num = graph()->input_names.size() + graph()->output_names.size();
+  stack_.comment("Malloc and copy the buffers")
+      .declare("void*", "cpu_buffers", binding_num)
+      .declare("void*", "gpu_buffers", binding_num);
+  for (const auto& i : graph()->GetInputs()) {
+    malloc_buffer(i);
+  }
+  for (const auto& o : graph()->GetOutputs()) {
+    malloc_buffer(o);
+    stack_.declare(CppDType(o->dtype), "output_" + IdxTensor(o),
+                   static_cast<size_t>(o->GetSize()->value));
+  }
+  // read and test data
+  stack_.comment("Read and test data")
+      .while_start("reader.ReadNext(cpu_buffers)")
+      .comment("Memcopy inputs host to device");
+  // copy inputs
+  for (const auto& i : graph()->GetInputs()) {
+    stack_.func_call("CHECK")
+        .func_call("cudaMemcpyAsync")
+        .call_arg("gpu_buffers[idx_" + IdxTensor(i) + "]")
+        .call_arg("cpu_buffers[idx_" + IdxTensor(i) + "]")
+        .call_arg(GetTensorBytes(i))
+        .call_arg("cudaMemcpyHostToDevice")
+        .call_arg("stream")
+        .pop_nest();
+  }
+  // enqueue
+  stack_.func_call("cudaStreamSynchronize")
+      .call_arg("stream")
+      .comment("enqueue with gpu buffers")
+      .func_call("enqueueV2", NullOpt, DocUtils::ToPtrDoc("context"))
+      .call_arg("gpu_buffers")
+      .call_arg("stream")
+      .call_arg("nullptr")
+      .comment("Memcopy outputs device to host");
+  // copy outputs
+  for (const auto& o : graph()->GetOutputs()) {
+    stack_.func_call("CHECK")
+        .func_call("cudaMemcpyAsync")
+        .call_arg("output_" + IdxTensor(o))
+        .call_arg("gpu_buffers[idx_" + IdxTensor(o) + "]")
+        .call_arg(GetTensorBytes(o))
+        .call_arg("cudaMemcpyDeviceToHost")
+        .call_arg("stream")
+        .pop_nest();
+  }
+  stack_.func_call("cudaStreamSynchronize").call_arg("stream");
+  // compare outputs
+  for (const auto& o : graph()->GetOutputs()) {
+    stack_.func_call("CommonUtils::CompareBuffers", "pass")
+        .call_arg("(" + CppDType(o->dtype) + "*)cpu_buffers[idx_" + IdxTensor(o) + "]")
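+        // Together with the args continued below, this chain emits a check of
+        // roughly this shape (a hedged sketch; the tensor names are illustrative
+        // and follow the generated numbering):
+        //   pass = CommonUtils::CompareBuffers((float*)cpu_buffers[idx_tensor_0],
+        //                                      output_tensor_0, size);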
.call_arg("output_" + IdxTensor(o)) + .call_arg(o->GetSize()); + ReturnOnFail("pass", "Failed to test the output " + o->name); + } + stack_.while_end(); + // clean up + stack_.comment("Clean up the buffers and stream") + .func_call("cudaStreamDestroy") + .call_arg("stream") + .for_start("i", 0, binding_num) + .func_call("CHECK") + .func_call("cudaFree") + .call_arg("gpu_buffers[i]") + .pop_nest() + .func_call("free") + .call_arg("cpu_buffers[i]") + .for_end(); + // end define test method + stack_.func_end("true"); + EndNamespace(); +} + +void TensorRTCodeGen::CodeGenMain() { + stack_.line("#include \"" + graph()->name + ".h\"") + .line() + .line("using namespace nvinfer1;") + .line("using namespace tvm::contrib::msc;") + .line() + .func_def("main", "int") + .func_arg("argc", "int") + .func_arg("argv", "char**") + .func_start() + .declare("TRTLogger", "logger") + .func_call("logger.setLogSeverity"); + if (config()->log_level == 0) { + stack_.call_arg("ILogger::Severity::kINFO"); + } else if (config()->log_level == 1) { + stack_.call_arg("ILogger::Severity::kVERBOSE"); + } else { + stack_.call_arg("ILogger::Severity::kWARNING"); + } + // prepare for build + stack_.comment("Define arguments") + .assign("pass", "true", "bool") + .assign("repeat_num", "1000", "int") + .assign("profile_level", std::to_string(config()->profile_level), "int") + .cond_if("argc > 1") + .assign("profile_level", "atoi(argv[1])") + .cond_end(); + + // start build the engine + stack_.comment("Build engine if not exist") + .cond_if("!FileUtils::FileExist(\"" + graph()->name + ".trt\")"); + // create builder + stack_.comment("Create TensorRT tools") + .func_call("TRTPtr", DocUtils::ToDeclareDoc("auto", "builder")) + .func_call("createInferBuilder") + .call_arg("logger") + .pop_nest(); + ReturnOnFail("builder", "Failed to create builder"); + // create network + if (CompareVersion(6, 0, 0) >= 0) { + stack_ + .assign("flags", + "1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)", + "uint32_t") + .func_call("TRTPtr", DocUtils::ToDeclareDoc("auto", "network")) + .func_call("createNetworkV2", NullOpt, DocUtils::ToPtrDoc("builder")) + .call_arg("flags") + .pop_nest(); + } else { + stack_.func_call("TRTPtr", DocUtils::ToDeclareDoc("auto", "network")) + .func_call("createNetwork", NullOpt, DocUtils::ToPtrDoc("builder")) + .pop_nest(); + } + ReturnOnFail("network", "Failed to create network"); + // create config + stack_.func_call("TRTPtr", DocUtils::ToDeclareDoc("auto", "config")) + .func_call("createBuilderConfig", NullOpt, DocUtils::ToPtrDoc("builder")) + .pop_nest(); + ReturnOnFail("config", "Failed to create config"); + // build model + stack_.comment("Build model") + .declare(graph()->name, "model") + .func_call("model.Build", "pass") + .call_arg("builder") + .call_arg("network"); + if (CompareVersion(6, 0, 0) >= 0) { + stack_.call_arg("config"); + } + stack_.call_arg("logger"); + ReturnOnFail("pass", "Failed to build model"); + // Set profile flag + stack_.comment("Set profile flag") + .declare("ProfilingVerbosity", "profile_verbose") + .cond_if("profile_level == 2") + .assign("profile_verbose", "ProfilingVerbosity::kDETAILED") + .cond_else() + .cond_if("profile_level == 1") + .assign("profile_verbose", "ProfilingVerbosity::kLAYER_NAMES_ONLY") + .cond_else() + .assign("profile_verbose", "ProfilingVerbosity::kNONE") + .cond_end() + .cond_end() + .func_call("setProfilingVerbosity", NullOpt, DocUtils::ToPtrDoc("config")) + .call_arg("profile_verbose"); + // Serialize engine + stack_.comment("Serialize engine") + 
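+      // The calls chained below emit, roughly (a hedged sketch; <name> stands in
+      // for the graph name, and the config arg is only present for TensorRT >= 6):
+      //   pass = TRTUtils::SerializeEngineToFile("<name>.trt", builder, network,
+      //                                          config, logger);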
.func_call("TRTUtils::SerializeEngineToFile", "pass") + .call_arg(DocUtils::ToStrDoc(graph()->name + ".trt")) + .call_arg("builder") + .call_arg("network"); + if (CompareVersion(6, 0, 0) >= 0) { + stack_.call_arg("config"); + } + stack_.call_arg("logger"); + ReturnOnFail("pass", "Failed to serialize the engine"); + // end build the engine + stack_.cond_end(); + // start deserialize engine + stack_.comment("Deserialize engine") + .declare("std::shared_ptr", "engine") + .func_call("TRTUtils::DeserializeEngineFromFile", "pass") + .call_arg(DocUtils::ToStrDoc(graph()->name + ".trt")) + .call_arg("engine") + .call_arg("logger"); + ReturnOnFail("pass", "Failed to deserialize the engine"); + // dump info by inspector + stack_.comment("Dump info by inspector") + .cond_if("profile_level > 0") + .func_call("TRTPtr", DocUtils::ToDeclareDoc("auto", "inspector")) + .func_call("createEngineInspector", NullOpt, DocUtils::ToPtrDoc("engine")) + .pop_nest() + .func_call("getEngineInformation", DocUtils::ToDeclareDoc("std::string", "result"), + DocUtils::ToPtrDoc("inspector")) + .call_arg("LayerInformationFormat::kJSON") + .declare("std::ofstream", "os") + .declare_arg(DocUtils::ToStrDoc(graph()->name + "_info.json")) + .declare_arg("std::ofstream::trunc") + .line("os << result << std::flush;") + .cond_end(); + // test engine + if (config()->test_iter > 0) { + stack_.comment("Prepare dataset") + .declare("DatasetReader", "reader") + .declare_arg(DocUtils::ToStrDoc(config()->dataset)) + .declare_arg(config()->test_iter); + stack_.comment("Test engine by datas") + .func_call("test_" + graph()->name, "pass") + .call_arg("engine") + .call_arg("reader") + .call_arg("logger"); + } + ReturnOnFail("pass", "Failed to test the engine"); + stack_.func_end("pass ? 0 : 1"); +} + +void TensorRTCodeGen::CodeGenCmake() { + stack_.line("cmake_minimum_required(VERSION " + config()->cmake_version + " FATAL_ERROR)") + .line("project(" + graph()->name + ")") + .line("find_package(CUDA)") + .line("find_path(TENSORRT_INCLUDE_DIR NvInfer.h HINTS " + config()->tensorrt_root + + " PATH_SUFFIXES include)") + .line("find_library(TENSORRT_LIB_DIR nvinfer HINTS " + config()->tensorrt_root + + " PATH_SUFFIXES lib)") + .line( + "message(STATUS \"Build project with TENSORRT_INCLUDE_DIR ${TENSORRT_INCLUDE_DIR} and " + "TENSORRT_LIB_DIR " + "${TENSORRT_LIB_DIR}\")") + .line("add_definitions(-DTRT_MAJOR=" + std::to_string(config()->version[0]) + ")") + .line("add_definitions(-DTRT_MINOR=" + std::to_string(config()->version[1]) + ")") + .line("add_definitions(-DTRT_PATCH=" + std::to_string(config()->version[2]) + ")") + .line("file(GLOB_RECURSE TRT_SRCS *.cc)") + .line("cuda_add_executable(" + graph()->name + " ${TRT_SRCS})") + .line("target_include_directories(" + graph()->name + " PUBLIC ${TENSORRT_INCLUDE_DIR})") + .line("target_link_libraries(" + graph()->name + " ${TENSORRT_LIB_DIR})"); +} + +const String TensorRTCodeGen::IdxTensor(const MSCTensor& tensor) { + const auto& pair = graph()->FindProducerAndIdx(tensor); + const String& prefix = "tensor_" + std::to_string(pair.first->index); + if (pair.first->outputs.size() > 1) { + return prefix + "_" + std::to_string(pair.second); + } + return prefix; +} + +const String TensorRTCodeGen::CppDType(const DataType& dtype) { + const String& dtype_name = CppCodeGen::DType(dtype); + if (dtype_name == "int32") { + return "int"; + } + if (dtype_name == "int64") { + return "int64_t"; + } + if (dtype_name == "float32") { + return "float"; + } + if (dtype_name == "float64") { + return "double"; + } + 
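+  // any dtype without a special case above keeps its MSC name unchanged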
return dtype_name; +} + +const String TensorRTCodeGen::GetTensorBytes(const MSCTensor& tensor) { + return std::to_string(tensor->GetSize()->value) + " * sizeof(" + CppDType(tensor->dtype) + ")"; +} + +void TensorRTCodeGen::ReturnOnFail(const String& flag, const String& err) { + stack_.cond_if("!" + flag) + .func_call("logger.log") + .call_arg("ILogger::Severity::kERROR") + .call_arg(DocUtils::ToStrDoc(err)) + .line("return -1;") + .cond_end(); +} + +template +const String TensorRTCodeGen::ToDims(const std::vector& dims, bool use_ndim) { + if (dims.size() == 2 && !use_ndim) { + return "DimsHW{" + std::to_string(dims[0]) + "," + std::to_string(dims[1]) + "}"; + } + String dims_str = "Dims({" + std::to_string(dims.size()) + ",{"; + for (size_t i = 0; i < dims.size(); i++) { + dims_str = dims_str + std::to_string(dims[i]) + (i < dims.size() - 1 ? "," : ""); + } + dims_str = dims_str + "}})"; + return dims_str; +} + +const String TensorRTCodeGen::ToDims(const Array& dims, bool use_ndim) { + std::vector int_dims; + for (const auto& d : dims) { + int_dims.push_back(d->value); + } + return ToDims(int_dims, use_ndim); +} + +const Array TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { + const auto& ops_map = GetTensorRTOpCodes(); + auto it = ops_map->find(node->optype); + ICHECK(it != ops_map->end()) << "Unsupported tensorrt op(" << node->optype << "): " << node; + it->second->Config(node, config()); + try { + return it->second->GetDocs(); + } catch (runtime::InternalError& err) { + LOG(WARNING) << "Failed to get docs for " << node << " : " << err.message(); + throw err; + } +} + +TVM_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTSources") + .set_body_typed([](const MSCGraph& graph, const String& codegen_config, + const String print_config) -> Map { + TensorRTCodeGen codegen = TensorRTCodeGen(graph, codegen_config); + return codegen.GetSources(print_config); + }); + +TVM_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTRoot").set_body_typed([]() -> String { +#ifdef TENSORRT_ROOT_DIR + return TENSORRT_ROOT_DIR; +#else + return ""; +#endif +}); + +/*! + * \brief Create runtime modules for MSC TensorRT. + * \param functions The extern functions to be compiled via TensorRT + * \return Runtime modules. 
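+ * \param target_option The per-function codegen options, keyed by function name.
+ * \param constant_names The constant names used by the serializer.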
+ */
+Array<runtime::Module> MSCTensorRTCompiler(Array<Function> functions,
+                                           Map<String, ObjectRef> target_option,
+                                           Map<Constant, String> constant_names) {
+  Array<runtime::Module> compiled_functions;
+  for (const auto& func : functions) {
+    VLOG(1) << "MSC.TensorRT partition:" << std::endl << func;
+    std::string func_name = GetExtSymbol(func);
+    ICHECK(target_option.count(func_name)) << "Cannot find target option for " << func_name;
+    const auto& options = Downcast<String>(target_option[func_name]);
+    MSCJSONSerializer serializer(constant_names, options);
+    serializer.serialize(func);
+    std::string graph_json = serializer.GetJSON();
+    const auto* pf = runtime::Registry::Get("runtime.msc_tensorrt_runtime_create");
+    ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function.";
+    VLOG(1) << "Creating msc_tensorrt runtime::Module for '" << func_name << "'";
+    compiled_functions.push_back((*pf)(func_name, graph_json, serializer.GetConstantNames()));
+  }
+  return compiled_functions;
+}
+
+TVM_REGISTER_GLOBAL("relax.ext.msc_tensorrt").set_body_typed(MSCTensorRTCompiler);
+
+}  // namespace msc
+}  // namespace contrib
+}  // namespace tvm
diff --git a/src/contrib/msc/framework/tensorrt/codegen.h b/src/contrib/msc/framework/tensorrt/codegen.h
new file mode 100644
index 000000000000..28d69d3a4f5c
--- /dev/null
+++ b/src/contrib/msc/framework/tensorrt/codegen.h
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/framework/tensorrt/codegen.h
+ * \brief TensorRT codegen for MSCGraph.
+ */
+#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_H_
+#define TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_H_
+
+#include <string>
+#include <vector>
+
+#include "../../core/codegen/base_codegen.h"
+#include "../../core/codegen/cpp_codegen.h"
+#include "codegen_utils.h"
+#include "tensorrt_opcode.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+class TensorRTCodeGen : public CppCodeGen<TensorRTCodeGenConfig, TensorRTCodeGenHelper> {
+ public:
+  /*!
+   * \brief The constructor of TensorRTCodeGen
+   * \param graph the graph to be generated.
+   * \param config the options for codegen.
+   */
+  explicit TensorRTCodeGen(const MSCGraph& graph, const std::string& config = "")
+      : CppCodeGen<TensorRTCodeGenConfig, TensorRTCodeGenHelper>(graph, config) {}
+
+  /*! \brief Stack the docs for the class declare*/
+  void CodeGenClassDeclare() final;
+
+  /*! \brief Stack the docs for the class define*/
+  void CodeGenClassDefine() final;
+
+  /*! \brief Stack the docs for the main func*/
+  void CodeGenMain() final;
+
+  /*! \brief Stack the docs for the cmake file*/
+  void CodeGenCmake() final;
+
+ protected:
+  /*! \brief Get the docs for the op*/
+  const Array<Doc> GetOpCodes(const MSCJoint& node) final;
+
+  /*! \brief Generate return on fail codes*/
+  void ReturnOnFail(const String& flag, const String& err);
+
+  /*! \brief Get the index tensor*/
+  const String IdxTensor(const MSCTensor& tensor);
+
+  /*!
\brief Get the dtype from the datatype*/ + const String CppDType(const DataType& dtype); + + /*! \brief Generate describe for tensor bytes*/ + const String GetTensorBytes(const MSCTensor& tensor); + + /*! \brief Get the tensorrt dims from dims*/ + template + const String ToDims(const std::vector& dims, bool use_ndim = true); + const String ToDims(const Array& dims, bool use_ndim = true); +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm + +#endif // TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_H_ diff --git a/src/contrib/msc/framework/tensorrt/codegen_utils.h b/src/contrib/msc/framework/tensorrt/codegen_utils.h new file mode 100644 index 000000000000..8249444d9d2d --- /dev/null +++ b/src/contrib/msc/framework/tensorrt/codegen_utils.h @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/framework/tensorrt/codegen_utils.h + * \brief TensorRT config for codegen. + */ +#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_UTILS_H_ +#define TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_UTILS_H_ + +#include + +#include "../../core/codegen/base_codegen.h" +#include "../../core/codegen/codegen_utils.h" + +namespace tvm { +namespace contrib { +namespace msc { + +/*! + * \brief CodeGen helper for tensorrt codegen + */ +class TensorRTCodeGenHelper : public BaseCodeGenHelper { + public: + /*! \brief Get describe for default node input*/ + const String IdxInputBase(const MSCJoint& node, const String& prefix = "", int idx = 0, + const String& suffix = "") final { + const auto& pair = node->ProducerAndIdxOf(idx); + if (pair.first->optype == "input") { + return "*" + IdxNodeBase(pair.first, prefix, suffix); + } + if (pair.first->optype == "tuple" || pair.first->optype == "get_item") { + return IdxNodeBase(pair.first, prefix, suffix); + } + return "*" + IdxOutputBase(pair.first, prefix, pair.second, suffix); + } + + /*! \brief Get describe for default node output*/ + const String IdxOutputBase(const MSCJoint& node, const String& prefix = "", int idx = 0, + const String& suffix = "") final { + if (node->optype == "argmax" || node->optype == "argmin") { + ICHECK_EQ(idx, 0) << "argmax and argmin only has 1 output, get " << idx; + return IdxNodeBase(node, prefix, suffix) + "->getOutput(1)"; + } + if (node->optype == "tuple") { + return IdxNodeBase(node, prefix, suffix) + "[" + std::to_string(idx) + "]"; + } + if (node->optype == "get_item") { + ICHECK_EQ(idx, 0) << "get item only has 1 output, get " << idx; + return IdxNodeBase(node, prefix, suffix); + } + return IdxNodeBase(node, prefix, suffix) + "->getOutput(" + std::to_string(idx) + ")"; + } + + /*! 
\brief Get describe for default node weight*/ + const String IdxWeightBase(const MSCJoint& node, const String& wtype, + const String& suffix = "") final { + return "mWeights[\"" + node->WeightAt(wtype)->name + "\"]"; + } +}; + +/*! + * \brief CodeGen config for tensorrt codegen + */ +struct TensorRTCodeGenConfig { + int log_level{0}; + int profile_level{0}; + int test_iter{0}; + size_t max_workspace{1 << 20}; + std::string cmake_version{"3.5"}; + std::string dataset{"Dataset"}; + std::string tensorrt_root{"/usr/local/cuda"}; + CODEGEN_CONFIG_MEMBERS + void Load(dmlc::JSONReader* reader) { + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "log_level") { + reader->Read(&log_level); + } else if (key == "profile_level") { + reader->Read(&profile_level); + } else if (key == "test_iter") { + reader->Read(&test_iter); + } else if (key == "max_workspace") { + reader->Read(&max_workspace); + } else if (key == "cmake_version") { + reader->Read(&cmake_version); + } else if (key == "dataset") { + reader->Read(&dataset); + } else if (key == "tensorrt_root") { + reader->Read(&tensorrt_root); + } else { + CODEGEN_CONFIG_PARSE + } + } + } +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_UTILS_H_ diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc new file mode 100644 index 000000000000..df5d4f343c88 --- /dev/null +++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc @@ -0,0 +1,808 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc + */ +#include "tensorrt_opcode.h" + +#include +#include + +#include "../../core/utils.h" + +namespace tvm { +namespace contrib { +namespace msc { + +const Array TensorRTOpCode::GetDocs() { + stack_.Config(this); + CodeGenBuild(); + if (node()->optype == "tuple") { + for (size_t i = 0; i < node()->outputs.size(); i++) { + stack_.func_call("setName", NullOpt, DocUtils::ToPtrDoc(IdxOutput(i))) + .call_arg(DocUtils::ToStrDoc(node()->OutputAt(i)->name)); + } + } else if (node()->optype == "get_item") { + stack_.func_call("setName", NullOpt, DocUtils::ToPtrDoc(IdxNode())) + .call_arg(DocUtils::ToStrDoc(node()->OutputAt(0)->name)); + } else if (node()->optype != "input") { + SetLayerByValue("Name", DocUtils::ToStrDoc(node()->name)); + for (size_t i = 0; i < node()->outputs.size(); i++) { + stack_.func_call("setName", NullOpt, DocUtils::ToPtrDoc(IdxOutput(i))) + .call_arg(DocUtils::ToStrDoc(node()->OutputAt(i)->name)); + } + } + return stack_.GetDocs(); +} + +void TensorRTOpCode::SetPadding(const String& key) { + const auto& padding = node()->GetTypeArrayAttr("padding"); + if (padding.size() == 1) { + SetLayerByDimsValue("Padding", std::vector{padding[0], padding[0]}, false); + } else if (padding.size() == 2) { + SetLayerByDimsValue("PrePadding", padding, false); + SetLayerByDimsValue("PostPadding", padding, false); + } else if (padding.size() == 4) { + SetLayerByDimsValue("PrePadding", std::vector{padding[0], padding[1]}, false); + SetLayerByDimsValue("PostPadding", std::vector{padding[2], padding[3]}, false); + } else { + LOG_FATAL << "Unexpected padding size" << padding.size(); + } +} + +const String TensorRTOpCode::DeclareInputs(bool simplify) { + const String& inputs_ref = "inputs_" + std::to_string(node()->index); + if (node()->parents.size() == 1 && simplify) { + const auto& idx_input = StringUtils::Replace(IdxInput(), "*", ""); + stack_.declare("std::vector", inputs_ref + "_vec") + .declare_arg(node()->inputs.size()) + .declare_arg(idx_input) + .assign(inputs_ref, inputs_ref + "_vec.data()", "ITensor**"); + } else { + stack_.declare("std::vector", IdxNode(), 0, false); + for (size_t i = 0; i < node()->inputs.size(); i++) { + const auto& idx_input = StringUtils::Replace(IdxInput(i), "*", ""); + stack_.declare_arg(idx_input); + } + } + return inputs_ref; +} + +const String TensorRTOpCode::DType(const DataType& dtype) { + const String& dtype_name = BaseOpCode::DType(dtype); + String dtype_enum; + if (dtype_name == "int8") { + dtype_enum = "DataType::kINT8"; + } else if (dtype_name == "int32") { + dtype_enum = "DataType::kINT32"; + } else if (dtype_name == "float16") { + dtype_enum = "DataType::kHALF"; + } else if (dtype_name == "float32") { + dtype_enum = "DataType::kFLOAT"; + } else { + LOG_FATAL << "Unexpected dtype for TensorRT " << dtype_name; + } + return dtype_enum; +} + +template +const String TensorRTOpCode::ToDims(const std::vector& dims, bool use_ndim) { + if (dims.size() == 2 && !use_ndim) { + return "DimsHW{" + std::to_string(dims[0]) + "," + std::to_string(dims[1]) + "}"; + } + String dims_str = "Dims({" + std::to_string(dims.size()) + ",{"; + for (size_t i = 0; i < dims.size(); i++) { + dims_str = dims_str + std::to_string(dims[i]) + (i < dims.size() - 1 ? 
"," : ""); + } + dims_str = dims_str + "}})"; + return dims_str; +} + +const String TensorRTOpCode::ToDims(const Array& dims, bool use_ndim) { + std::vector int_dims; + for (const auto& d : dims) { + int_dims.push_back(d->value); + } + return ToDims(int_dims, use_ndim); +} + +const String TensorRTOpCode::AttrToDims(const String& key, bool use_ndim) { + const auto& dims = node()->GetTypeArrayAttr(key); + return ToDims(dims, use_ndim); +} + +const size_t TensorRTOpCode::ToReduceAxis(const std::vector& axes, size_t ndim) { + size_t valid_ndim = ndim == 0 ? node()->InputAt(0)->Ndim() : ndim; + size_t reduce_axis = 0; + for (const auto& a : axes) { + reduce_axis += 1 << CommonUtils::GetIndex(a, valid_ndim); + } + return reduce_axis; +} + +const size_t TensorRTOpCode::AttrToReduceAxis(const String& key, size_t ndim) { + std::vector axes; + if (node()->GetAttr(key, &axes)) { + return ToReduceAxis(axes, ndim); + } + int axis; + ICHECK(node()->GetAttr(key, &axis)) << "Can not get axes from attribute key " << key; + return ToReduceAxis(std::vector{axis}, ndim); +} + +const size_t TensorRTOpCode::AttrToAxis(const String& key, size_t ndim) { + size_t valid_ndim = ndim == 0 ? node()->InputAt(0)->Ndim() : ndim; + int axis = node()->GetTypeAttr(key); + return CommonUtils::GetIndex(axis, valid_ndim); +} + +template +void TensorRTOpCode::SetLayerByAttr(const String& method, const String& key) { + stack_.func_call("set" + method, NullOpt, DocUtils::ToPtrDoc(IdxNode())).op_arg(key, ""); +} + +template +void TensorRTOpCode::SetLayerByValue(const String& method, const T& value) { + stack_.func_call("set" + method, NullOpt, DocUtils::ToPtrDoc(IdxNode())).call_arg(value); +} + +void TensorRTOpCode::SetLayerByDimsAttr(const String& method, const String& key, bool use_ndim) { + stack_.func_call("set" + method, NullOpt, DocUtils::ToPtrDoc(IdxNode())) + .call_arg(AttrToDims(key, use_ndim)); +} + +template +void TensorRTOpCode::SetLayerByDimsValue(const String& method, const std::vector& value, + bool use_ndim) { + stack_.func_call("set" + method, NullOpt, DocUtils::ToPtrDoc(IdxNode())) + .call_arg(ToDims(value, use_ndim)); +} + +void TensorRTOpCode::SetLayerByDimsValue(const String& method, const Array& value, + bool use_ndim) { + stack_.func_call("set" + method, NullOpt, DocUtils::ToPtrDoc(IdxNode())) + .call_arg(ToDims(value, use_ndim)); +} + +#define TENSORRT_OP_CODEGEN_METHODS(TypeName) \ + public: \ + TypeName(const String& func_name) : TensorRTOpCode(func_name) {} + +#define TENSORRT_FLAG_OP_CODEGEN_METHODS(TypeName) \ + public: \ + TypeName(const String& func_name, const String& symbol) : TensorRTOpCode(func_name) { \ + symbol_ = symbol; \ + } \ + \ + private: \ + String symbol_; + +class TensorRTActivationCodeGen : public TensorRTOpCode { + public: + explicit TensorRTActivationCodeGen(const String& symbol) : TensorRTOpCode("Activation") { + symbol_ = symbol; + } + + protected: + void CodeGenBuild() final { + stack_.op_call().op_input_arg().call_arg("ActivationType::k" + symbol_); + if (node()->optype == "nn.leaky_relu") { + SetLayerByAttr("Alpha", "alpha"); + } else if (node()->optype == "clip") { + SetLayerByAttr("Alpha", "min"); + SetLayerByAttr("Beta", "max"); + } + } + + private: + String symbol_; +}; + +class TensorRTAdaptivePool2dCodeGen : public TensorRTOpCode { + public: + TENSORRT_FLAG_OP_CODEGEN_METHODS(TensorRTAdaptivePool2dCodeGen) + + protected: + void CodeGenBuild() final { + const auto& input = node()->InputAt(0); + const auto& output = node()->OutputAt(0); + std::vector 
in_sizes{input->DimAt("H")->value, input->DimAt("W")->value}; + std::vector out_sizes{output->DimAt("H")->value, output->DimAt("W")->value}; + std::vector stride, kernel; + for (size_t i = 0; i < 2; i++) { + stride.push_back(in_sizes[i] / out_sizes[i]); + kernel.push_back((in_sizes[i] - (out_sizes[i] - 1) * stride[i])); + } + stack_.op_call() + .op_input_arg() + .call_arg("PoolingType::k" + symbol_) + .call_arg(ToDims(kernel, false)); + SetLayerByDimsValue("Stride", stride, false); + } +}; + +class TensorRTArgmaxminCodeGen : public TensorRTOpCode { + public: + explicit TensorRTArgmaxminCodeGen(const String& symbol) : TensorRTOpCode("TopK") { + symbol_ = symbol; + } + + protected: + void CodeGenBuild() final { + ICHECK(node()->GetTypeAttr("keepdims")) << "Only support argsort with keepdims"; + stack_.op_call() + .op_input_arg() + .call_arg("TopKOperation::k" + symbol_) + .op_arg("keepdims", "") + .call_arg(AttrToReduceAxis()); + } + + private: + String symbol_; +}; + +class TensorRTAstypeCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTAstypeCodeGen) + + protected: + void CodeGenBuild() final { + stack_.op_call() + .op_input_arg() + .func_call("setOutput", NullOpt, DocUtils::ToPtrDoc(IdxNode())) + .call_arg(0) + .op_dtype_arg(node()->OutputAt(0)->dtype); + } +}; + +class TensorRTBatchMatmulCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTBatchMatmulCodeGen) + + protected: + void CodeGenBuild() final { + bool trans_a = node()->GetTypeAttr("transpose_a"); + bool trans_b = node()->GetTypeAttr("transpose_b"); + stack_.op_call() + .op_input_arg() + .call_arg(trans_a ? "MatrixOperation::kTRANSPOSE" : "MatrixOperation::kNONE") + .op_input_arg(1) + .call_arg(trans_b ? "MatrixOperation::kTRANSPOSE" : "MatrixOperation::kNONE"); + } +}; + +class TensorRTConcatCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTConcatCodeGen) + + protected: + void CodeGenBuild() final { + const auto& producer = node()->ProducerOf(0); + ICHECK(node()->parents.size() == 1 && producer->optype == "tuple") + << "Concat expect parent as tuple, get " << node(); + stack_.op_call().call_arg(IdxNodeBase(producer) + ".data()").call_arg(producer->inputs.size()); + SetLayerByValue("Axis", AttrToAxis()); + } +}; + +class TensorRTConstantCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTConstantCodeGen) + + protected: + void CodeGenBuild() final { + ICHECK(!node()->HasAttr("scalar")) << "Scalar constant is not supported"; + stack_.op_call().call_arg(ToDims(node()->OutputAt(0)->shape)).op_weight_arg("const"); + } +}; + +class TensorRTConvCodeGen : public TensorRTOpCode { + public: + TensorRTConvCodeGen(const String& func_name, bool use_bias) : TensorRTOpCode(func_name) { + use_bias_ = use_bias; + } + + protected: + void CodeGenBuild() final { + const auto& weight = node()->WeightAt("weight"); + std::vector kernel_size; + for (size_t i = 0; i < weight->Ndim(); i++) { + if (weight->layout[i].name() == "I" || weight->layout[i].name() == "O") { + continue; + } + kernel_size.push_back(weight->DimAt(i)->value); + } + stack_.op_call() + .op_input_arg() + .call_arg(weight->DimAt("O")) + .call_arg(ToDims(kernel_size, false)) + .op_weight_arg("weight"); + if (use_bias_) { + stack_.op_weight_arg("bias"); + } else { + stack_.call_arg("mWeights[\"" + node()->name + ".bias\"]"); + } + SetLayerByDimsAttr("Stride", "strides", false); + SetLayerByDimsAttr("Dilation", "dilation", false); + SetLayerByAttr("NbGroups", "groups"); 
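+    // For reference, the calls stacked above emit TensorRT network-building code
+    // of roughly this shape (a hedged sketch; tensor/weight names follow the
+    // generated numbering and are illustrative only):
+    //   auto layer_1 = network->addConvolutionNd(*tensor_0, 64, DimsHW{3, 3},
+    //                                            mWeights["conv.weight"],
+    //                                            mWeights["conv.bias"]);
+    //   layer_1->setStride(DimsHW{1, 1});
+    //   layer_1->setDilation(DimsHW{1, 1});
+    //   layer_1->setNbGroups(1);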
+ SetPadding(); + } + + private: + bool use_bias_; +}; + +class TensorRTElemwiseCodeGen : public TensorRTOpCode { + public: + explicit TensorRTElemwiseCodeGen(const String& symbol) : TensorRTOpCode("ElementWise") { + symbol_ = symbol; + } + + protected: + void CodeGenBuild() final { + stack_.op_call().op_inputs_arg(false).call_arg("ElementWiseOperation::k" + symbol_); + } + + private: + String symbol_; +}; + +class TensorRTGetItemCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTGetItemCodeGen) + + protected: + void CodeGenBuild() final { + int index = node()->GetTypeAttr("index"); + const auto& producer = node()->ProducerOf(0); + stack_.assign(IdxNode(), IdxOutputBase(producer, index), "auto"); + } +}; + +class TensorRTInputCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTInputCodeGen) + + protected: + void CodeGenBuild() final { + const auto& output = node()->OutputAt(0); + stack_.op_call() + .call_arg(DocUtils::ToStrDoc(output->name)) + .op_dtype_arg(output->dtype) + .call_arg(ToDims(output->shape)); + } +}; + +class TensorRTLinearCodeGen : public TensorRTOpCode { + public: + TensorRTLinearCodeGen(const String& func_name, bool use_bias) : TensorRTOpCode(func_name) { + use_bias_ = use_bias; + } + + protected: + void CodeGenBuild() final { + const auto& weight = node()->WeightAt("weight"); + stack_.op_call().op_input_arg().call_arg(weight->DimAt("O")).op_weight_arg("weight"); + if (use_bias_) { + stack_.op_weight_arg("bias"); + } else { + stack_.call_arg("mWeights[\"" + node()->name + ".bias\"]"); + } + } + + private: + bool use_bias_; +}; + +class TensorRTMatmulCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTMatmulCodeGen) + + protected: + void CodeGenBuild() final { + stack_.op_call() + .op_input_arg() + .call_arg("MatrixOperation::kNONE") + .op_input_arg(1) + .call_arg("MatrixOperation::kNONE"); + } +}; + +class TensorRTPadCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTPadCodeGen) + + protected: + void CodeGenBuild() final { + const auto& pad_width = node()->GetTypeArrayAttr("pad_width"); + ICHECK(pad_width.size() % 2 == 0) << "pad_width should be multiple of 2, get " << node(); + std::vector pre_padding{2, 0}, post_padding{2, 0}; + const auto& input = node()->InputAt(0); + for (size_t i = 0; i < input->Ndim(); i++) { + if (input->layout[i].name() == "H") { + pre_padding[0] = pad_width[i * 2]; + post_padding[0] = pad_width[i * 2 + 1]; + } else if (input->layout[i].name() == "W") { + pre_padding[1] = pad_width[i * 2]; + post_padding[1] = pad_width[i * 2 + 1]; + } + } + stack_.op_call().op_input_arg().call_arg(ToDims(pre_padding)).call_arg(ToDims(post_padding)); + } +}; + +class TensorRTPermuteDimsCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTPermuteDimsCodeGen) + + protected: + void CodeGenBuild() final { + std::vector axes; + if (!node()->GetAttr("axes", &axes)) { + for (size_t i = node()->InputAt(0)->Ndim(); i > 0; i--) { + axes.push_back(i - 1); + } + } + const String& perm_ref = "perm_" + std::to_string(node()->index); + stack_.op_call().op_input_arg().declare("Permutation", perm_ref); + for (size_t i = 0; i < axes.size(); i++) { + stack_.assign(perm_ref + ".order[" + std::to_string(i) + "]", + CommonUtils::GetIndex(axes[i], node()->InputAt(0)->Ndim())); + } + SetLayerByValue("FirstTranspose", perm_ref); + } +}; + +class TensorRTPool2dCodeGen : public TensorRTOpCode { + public: + explicit TensorRTPool2dCodeGen(const 
String& symbol) : TensorRTOpCode("Pooling") { + symbol_ = symbol; + } + + protected: + void CodeGenBuild() final { + stack_.op_call() + .op_input_arg() + .call_arg("PoolingType::k" + symbol_) + .call_arg(AttrToDims("pool_size", false)); + SetLayerByDimsAttr("Stride", "strides", false); + if (node()->GetTypeAttr("ceil_mode")) { + SetLayerByValue("PaddingMode", "PaddingMode::kEXPLICIT_ROUND_UP"); + } + if (node()->optype == "nn.avg_pool2d") { + SetLayerByValue("AverageCountExcludesPadding", false); + } + SetPadding(); + } + + private: + String symbol_; +}; + +class TensorRTReduceCodeGen : public TensorRTOpCode { + public: + explicit TensorRTReduceCodeGen(const String& symbol) : TensorRTOpCode("Reduce") { + symbol_ = symbol; + } + + protected: + void CodeGenBuild() final { + stack_.op_call() + .op_input_arg() + .call_arg("ReduceOperation::k" + symbol_) + .call_arg(AttrToReduceAxis()) + .op_arg("keepdims", ""); + } + + private: + String symbol_; +}; + +class TensorRTReshapeCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTReshapeCodeGen) + + protected: + void CodeGenBuild() final { + const auto& output = node()->OutputAt(0); + stack_.op_call().op_input_arg(); + SetLayerByDimsValue("ReshapeDimensions", output->shape); + } +}; + +class TensorRTResize2dCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTResize2dCodeGen) + + protected: + void CodeGenBuild() final { + stack_.op_call().op_input_arg(); + const auto& method = node()->GetTypeAttr("method"); + String resize_mode; + if (method == "linear") { + resize_mode = "LINEAR"; + } else if (method == "nearest_neighbor") { + resize_mode = "NEAREST"; + } else { + LOG_FATAL << "Unexpected resize method " << method; + } + SetLayerByValue("ResizeMode", "ResizeMode::k" + resize_mode); + SetLayerByValue("SelectorForSinglePixel", "ResizeSelector::kFORMULA"); + const auto& transformation_mode = + node()->GetTypeAttr("coordinate_transformation_mode"); + // set transformation + if (transformation_mode == "align_corners") { + SetLayerByValue("CoordinateTransformation", "ResizeCoordinateTransformation::kALIGN_CORNERS"); + } else if (transformation_mode == "asymmetric") { + SetLayerByValue("CoordinateTransformation", "ResizeCoordinateTransformation::kASYMMETRIC"); + } else if (transformation_mode == "tf_half_pixel_for_nn") { + SetLayerByValue("CoordinateTransformation", "ResizeCoordinateTransformation::kHALF_PIXEL"); + } else if (transformation_mode == "pytorch_half_pixel") { + SetLayerByValue("CoordinateTransformation", "ResizeCoordinateTransformation::kHALF_PIXEL"); + } else if (transformation_mode == "half_pixel") { + SetLayerByValue("CoordinateTransformation", "ResizeCoordinateTransformation::kHALF_PIXEL"); + } else { + LOG_FATAL << "Unexpected transformation_mode " << transformation_mode; + } + // set round + const auto& rounding_method = node()->GetTypeAttr("rounding_method"); + if (transformation_mode == "tf_half_pixel_for_nn") { + SetLayerByValue("NearestRounding", "ResizeRoundMode::kCEIL"); + } else if (rounding_method == "floor") { + SetLayerByValue("NearestRounding", "ResizeRoundMode::kFLOOR"); + } else if (rounding_method == "ceil") { + SetLayerByValue("NearestRounding", "ResizeRoundMode::kCEIL"); + } else if (rounding_method == "round_prefer_floor") { + SetLayerByValue("NearestRounding", "ResizeRoundMode::kHALF_DOWN"); + } else if (rounding_method == "round_prefer_ceil") { + SetLayerByValue("NearestRounding", "ResizeRoundMode::kHALF_UP"); + } else if (rounding_method == "round") { + 
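+      // "round", "round_prefer_ceil" and an empty rounding_method all map to
+      // ResizeRoundMode::kHALF_UP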
SetLayerByValue("NearestRounding", "ResizeRoundMode::kHALF_UP"); + } else if (rounding_method == "") { + SetLayerByValue("NearestRounding", "ResizeRoundMode::kHALF_UP"); + } else { + LOG_FATAL << "Unexpected rounding_method " << rounding_method; + } + // set output dims + SetLayerByDimsValue("OutputDimensions", node()->OutputAt(0)->shape); + } +}; + +class TensorRTSoftmaxCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTSoftmaxCodeGen) + + protected: + void CodeGenBuild() final { + stack_.op_call().op_input_arg(); + SetLayerByValue("Axes", AttrToReduceAxis()); + } +}; + +class TensorRTSquareCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTSquareCodeGen) + + protected: + void CodeGenBuild() final { + stack_.op_call().op_input_arg().op_input_arg().call_arg("ElementWiseOperation::kPROD"); + } +}; + +class TensorRTStridedSliceCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTStridedSliceCodeGen) + + protected: + void CodeGenBuild() final { + std::vector axes; + if (!node()->GetAttr("axes", &axes)) { + for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { + axes.push_back(i); + } + } + std::vector begin(node()->InputAt(0)->Ndim(), 0); + std::vector strides(node()->InputAt(0)->Ndim(), 1); + const auto& attr_begin = node()->GetTypeArrayAttr("begin"); + for (size_t i = 0; i < axes.size(); i++) { + size_t max_dim = static_cast(node()->InputAt(0)->DimAt(axes[i])->value); + begin[axes[i]] = CommonUtils::GetIndex(attr_begin[i], max_dim); + } + std::vector attr_strides; + if (node()->GetAttr("strides", &attr_strides)) { + for (size_t i = 0; i < axes.size(); i++) { + strides[axes[i]] = static_cast(attr_strides[i]); + } + } + stack_.op_call() + .op_input_arg() + .call_arg(ToDims(begin)) + .call_arg(ToDims(node()->OutputAt(0)->shape)) + .call_arg(ToDims(strides)); + } +}; + +class TensorRTTakeCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTTakeCodeGen) + + protected: + void CodeGenBuild() final { + stack_.op_call().op_inputs_arg(false).call_arg(AttrToAxis()); + if (node()->InputAt(0)->Ndim() == node()->InputAt(1)->Ndim()) { + SetLayerByValue("Mode", "GatherMode::kELEMENT"); + } + } +}; + +class TensorRTTopkCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTTopkCodeGen) + + protected: + void CodeGenBuild() final { + const String& symbol = node()->GetTypeAttr("is_asend") ? 
"MIN" : "MAX"; + stack_.op_call() + .op_input_arg() + .call_arg("TopKOperation::k" + symbol) + .op_arg("k", "") + .call_arg(AttrToReduceAxis()); + } +}; + +class TensorRTTupleCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTTupleCodeGen) + + protected: + void CodeGenBuild() final { + stack_.declare("std::vector", IdxNode(), 0, false); + for (size_t i = 0; i < node()->inputs.size(); i++) { + const auto& idx_input = StringUtils::Replace(IdxInput(i), "*", ""); + stack_.declare_arg(idx_input); + } + } +}; + +class TensorRTUnaryCodeGen : public TensorRTOpCode { + public: + explicit TensorRTUnaryCodeGen(const String& symbol) : TensorRTOpCode("Unary") { + symbol_ = symbol; + } + + protected: + void CodeGenBuild() final { + stack_.op_call().op_input_arg().call_arg("UnaryOperation::k" + symbol_); + } + + private: + String symbol_; +}; + +class TensorRTWhereCodeGen : public TensorRTOpCode { + public: + TENSORRT_OP_CODEGEN_METHODS(TensorRTWhereCodeGen) + + protected: + void CodeGenBuild() final { stack_.op_call().op_inputs_arg(false); } +}; + +const std::shared_ptr>> +GetTensorRTOpCodes() { + static auto map = std::make_shared>>(); + if (!map->empty()) return map; + // unary ops + map->emplace("abs", std::make_shared("ABS")); + map->emplace("acos", std::make_shared("ACOS")); + map->emplace("acosh", std::make_shared("ACOSH")); + map->emplace("asin", std::make_shared("ASIN")); + map->emplace("asinh", std::make_shared("ASINH")); + map->emplace("atan", std::make_shared("ATAN")); + map->emplace("atanh", std::make_shared("ATANH")); + map->emplace("ceil", std::make_shared("CEIL")); + map->emplace("cos", std::make_shared("COS")); + map->emplace("cosh", std::make_shared("COSH")); + map->emplace("erf", std::make_shared("ERF")); + map->emplace("exp", std::make_shared("EXP")); + map->emplace("floor", std::make_shared("FLOOR")); + map->emplace("log", std::make_shared("LOG")); + map->emplace("negative", std::make_shared("NEG")); + map->emplace("round", std::make_shared("ROUND")); + map->emplace("sin", std::make_shared("SIN")); + map->emplace("sinh", std::make_shared("SINH")); + map->emplace("sqrt", std::make_shared("SQRT")); + map->emplace("tan", std::make_shared("TAN")); + + // elemwise ops + map->emplace("add", std::make_shared("SUM")); + map->emplace("divide", std::make_shared("DIV")); + map->emplace("equal", std::make_shared("EQUAL")); + map->emplace("floor_divide", std::make_shared("FLOOR_DIV")); + map->emplace("greater", std::make_shared("GREATER")); + map->emplace("less", std::make_shared("LESS")); + map->emplace("maximum", std::make_shared("MAX")); + map->emplace("minimum", std::make_shared("MIN")); + map->emplace("multiply", std::make_shared("PROD")); + map->emplace("power", std::make_shared("POW")); + map->emplace("subtract", std::make_shared("SUB")); + + // reduce ops + map->emplace("max", std::make_shared("MAX")); + map->emplace("mean", std::make_shared("AVG")); + map->emplace("min", std::make_shared("MIN")); + map->emplace("sum", std::make_shared("SUM")); + + // math ops + map->emplace("argmax", std::make_shared("MAX")); + map->emplace("argmin", std::make_shared("MIN")); + map->emplace("astype", std::make_shared("Identity")); + map->emplace("concat", std::make_shared("Concatenation")); + map->emplace("expand_dims", std::make_shared("Shuffle")); + map->emplace("matmul", std::make_shared("MatrixMultiply")); + map->emplace("permute_dims", std::make_shared("Shuffle")); + map->emplace("reshape", std::make_shared("Shuffle")); + map->emplace("square", 
std::make_shared("ElementWise")); + map->emplace("squeeze", std::make_shared("Shuffle")); + map->emplace("strided_slice", std::make_shared("Slice")); + map->emplace("take", std::make_shared("Gather")); + map->emplace("topk", std::make_shared("TopK")); + map->emplace("where", std::make_shared("Select")); + + // create ops + map->emplace("constant", std::make_shared("Constant")); + + // activation ops + map->emplace("clip", std::make_shared("CLIP")); + map->emplace("sigmoid", std::make_shared("SIGMOID")); + map->emplace("tanh", std::make_shared("TANH")); + map->emplace("nn.relu", std::make_shared("RELU")); + map->emplace("nn.leaky_relu", std::make_shared("LEAKY_RELU")); + + // nn ops + map->emplace("nn.adaptive_avg_pool2d", + std::make_shared("Pooling", "AVERAGE")); + map->emplace("nn.avg_pool2d", std::make_shared("AVERAGE")); + map->emplace("nn.batch_matmul", std::make_shared("MatrixMultiply")); + map->emplace("nn.conv2d", std::make_shared("ConvolutionNd", false)); + map->emplace("nn.max_pool2d", std::make_shared("MAX")); + map->emplace("nn.pad", std::make_shared("Padding")); + map->emplace("nn.softmax", std::make_shared("SoftMax")); + + // image ops + map->emplace("image.resize2d", std::make_shared("Resize")); + + // special op + map->emplace("input", std::make_shared("Input")); + + // msc ops + map->emplace("msc.conv2d_bias", std::make_shared("ConvolutionNd", true)); + map->emplace("msc.linear", std::make_shared("FullyConnected", false)); + map->emplace("msc.linear_bias", std::make_shared("FullyConnected", true)); + + // special op + map->emplace("get_item", std::make_shared("")); + map->emplace("tuple", std::make_shared("")); + + return map; +} + +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h new file mode 100644 index 000000000000..89942930ac22 --- /dev/null +++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/framework/tensorrt/tensorrt_opcode.h + * \brief TensorRT codegen for MSCJoint. + */ +#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_TENSORRT_OPCODE_H_ +#define TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_TENSORRT_OPCODE_H_ + +#include +#include +#include +#include + +#include "../../core/codegen/base_codegen.h" +#include "codegen_utils.h" + +namespace tvm { +namespace contrib { +namespace msc { + +class TensorRTOpCode; +typedef OpCodeStack TensorRTOpCodeStack; + +/*! + * \brief CodeGen for relax op + */ +class TensorRTOpCode : public BaseOpCode { + public: + /*! + * \brief The constructor of BaseOpDocsifier + * \param func_name the function name for the node. 
+   */
+  explicit TensorRTOpCode(const String& func_name) : BaseOpCode(func_name) {}
+
+  /*! \brief Convert node to docs*/
+  const Array<Doc> GetDocs() final;
+
+  /*! \brief Get func_name for the default node*/
+  const String callee_name() final {
+    return "network->add" + BaseOpCode::callee_name();
+  }
+
+  /*! \brief Get valid return name for the default node*/
+  const String ret_name() final { return "auto " + IdxNode(true); }
+
+  /*! \brief Get the dtype from the datatype*/
+  const String DType(const DataType& dtype) final;
+
+ protected:
+  TensorRTOpCodeStack stack_;
+
+  /*! \brief Convert op build*/
+  virtual void CodeGenBuild() = 0;
+
+  /*! \brief Set padding for the layer*/
+  void SetPadding(const String& key = "padding");
+
+  /*! \brief Declare the inputs*/
+  const String DeclareInputs(bool simplify = true);
+
+  /*! \brief Get the tensorrt dims from dims*/
+  template <typename T>
+  const String ToDims(const std::vector<T>& dims, bool use_ndim = true);
+  const String ToDims(const Array<Integer>& dims, bool use_ndim = true);
+
+  /*! \brief Get the tensorrt dims from attribute*/
+  const String AttrToDims(const String& key, bool use_ndim = true);
+
+  /*! \brief Get the tensorrt reduce axis from dims*/
+  const size_t ToReduceAxis(const std::vector<int64_t>& axes, size_t ndim = 0);
+
+  /*! \brief Get the tensorrt reduce axis from attribute*/
+  const size_t AttrToReduceAxis(const String& key = "axis", size_t ndim = 0);
+
+  /*! \brief Get the attribute axis from attribute*/
+  const size_t AttrToAxis(const String& key = "axis", size_t ndim = 0);
+
+  /*! \brief Set layer by attribute*/
+  template <typename T>
+  void SetLayerByAttr(const String& method, const String& key);
+
+  /*! \brief Set layer by value*/
+  template <typename T>
+  void SetLayerByValue(const String& method, const T& value);
+
+  /*! \brief Set layer by dims attribute*/
+  void SetLayerByDimsAttr(const String& method, const String& key, bool use_ndim = true);
+
+  /*! \brief Set layer by dims value*/
+  template <typename T>
+  void SetLayerByDimsValue(const String& method, const std::vector<T>& value, bool use_ndim = true);
+  void SetLayerByDimsValue(const String& method, const Array<Integer>& value, bool use_ndim = true);
+};
+
+/*!
+ * \brief Get the map of available TensorRTOpCode, use optype as key
+ * \return Map of <optype, TensorRTOpCode>
+ */
+const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TensorRTOpCode>>>
+GetTensorRTOpCodes();
+
+}  // namespace msc
+}  // namespace contrib
+}  // namespace tvm
+#endif  // TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_TENSORRT_OPCODE_H_
diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
new file mode 100644
index 000000000000..ca01d5fbea3c
--- /dev/null
+++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
@@ -0,0 +1,748 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
+ * \brief Pass for transforming the function to tensorrt.
+ */
+
+#include
+#include
+#include
+
+#include "../../../../relax/transform/utils.h"
+#include "../../../../support/scalars.h"
+#include "../../core/utils.h"
+
+namespace tvm {
+namespace relax {
+using namespace tvm::contrib::msc;
+
+const Array<PrimExpr> GetShape(const Expr& var) {
+  const auto& shape_opt = Downcast<TensorStructInfo>(GetStructInfo(var))->GetShape();
+  ICHECK(shape_opt.defined()) << "Shape is not defined for " << var;
+  return shape_opt.value();
+}
+
+Var EmitCall(BlockBuilder builder, const Expr& expr, const Span& src_span, const String& suffix) {
+  const auto& name = SpanUtils::GetAttr(src_span, "name") + "_" + suffix;
+  expr->span = SpanUtils::SetAttr(expr->span, "name", name);
+  return builder->Emit(expr, name);
+}
+
+Var MakeCall(BlockBuilder builder, const Span& src_span, const String& suffix, Expr op,
+             Array<Expr> args, Attrs attrs = Attrs()) {
+  const auto& call = Call(op, args, attrs);
+  return EmitCall(builder, call, src_span, suffix);
+}
+
+Expr MakeConstant(double value, const DataType& dtype, const String& name) {
+  const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value));
+  const auto& span = SpanUtils::SetAttr(Span(), "name", name);
+  return Constant(data, NullOpt, span);
+}
+
+using FRewriteTensorRT =
+    runtime::TypedPackedFunc<Expr(BlockBuilder builder, const Var& var, const Call& src_call,
+                                  const Map<Expr, Call>& new_calls, const Array<Integer>& version)>;
+
+Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call& src_call,
+                     const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  const auto& shape_a = GetShape(call->args[0]);
+  const auto& shape_b = GetShape(call->args[1]);
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  if (shape_a.size() > shape_b.size()) {
+    Array<PrimExpr> exp_shape(shape_a.size(), Integer(1));
+    if (shape_b.size() == 1) {
+      exp_shape.Set(shape_a.size() - 1, shape_b[0]);
+    } else if (shape_b.size() == 0) {
+      LOG_DEBUG << "Expand scalar argument to " << exp_shape;
+    } else {
+      LOG_FATAL << "broadcast only supports 1-dim and scalar, got " << shape_b;
+    }
+    const auto& expand_b = MakeCall(builder, call->span, "expand_b", reshape_op,
+                                    {call->args[1], ShapeExpr(exp_shape)});
+    return Call(call->op, {call->args[0], expand_b}, call->attrs, call->sinfo_args, call->span);
+  }
+  if (shape_a.size() < shape_b.size()) {
+    Array<PrimExpr> exp_shape(shape_b.size(), Integer(1));
+    if (shape_a.size() == 1) {
+      exp_shape.Set(shape_b.size() - 1, shape_a[0]);
+    } else if (shape_a.size() == 0) {
+      LOG_DEBUG << "Expand scalar argument to " << exp_shape;
+    } else {
+      LOG_FATAL << "broadcast only supports 1-dim and scalar, got " << shape_a;
+    }
+    const auto& expand_a = MakeCall(builder, call->span, "expand_a", reshape_op,
+                                    {call->args[0], ShapeExpr(exp_shape)});
+    return Call(call->op, {expand_a, call->args[1]}, call->attrs, call->sinfo_args, call->span);
+  }
+  return call;
+}
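TensorRT's elementwise layer needs equal-rank operands, so the rewrite above reshapes a scalar or 1-D operand up to the rank of the other side before the binary op is emitted. A minimal numpy sketch of the same expansion (expand_for_trt is an illustrative name, not part of the patch):

import numpy as np

def expand_for_trt(a: np.ndarray, b: np.ndarray):
    """Reshape the lower-rank operand to the rank of the higher-rank one,
    mirroring the relax.reshape calls that RewriteElemwise emits."""
    if a.ndim == b.ndim:
        return a, b
    lo, hi = (a, b) if a.ndim < b.ndim else (b, a)
    exp_shape = [1] * hi.ndim
    if lo.ndim == 1:  # a 1-D operand lands on the last axis
        exp_shape[-1] = lo.shape[0]
    elif lo.ndim != 0:
        raise ValueError("broadcast only supports 1-dim and scalar operands")
    lo = lo.reshape(exp_shape)
    return (lo, hi) if a.ndim < b.ndim else (hi, lo)

x = np.ones((2, 3, 4), "float32")
bias = np.ones((4,), "float32")
a, b = expand_for_trt(x, bias)
assert (a + b).shape == (2, 3, 4)

RewriteAdd below layers a conv1d-specific fast path on top of this expansion before falling back to it.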
+Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call,
+                const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  if (new_calls.count(call->args[0]) &&
+      new_calls[call->args[0]]->op == Op::Get("relax.nn.conv1d")) {
+    const auto& reshape = Downcast<Call>(builder->LookupBinding(Downcast<Var>(call->args[0])));
+    if (reshape->op != Op::Get("relax.reshape")) {
+      return call;
+    }
+    const auto& conv2d = Downcast<Call>(builder->LookupBinding(Downcast<Var>(reshape->args[0])));
+    if (conv2d->op != Op::Get("relax.nn.conv2d")) {
+      return call;
+    }
+    const auto& input_shape = GetShape(call->args[0]);
+    const auto& bias_shape = GetShape(call->args[1]);
+    const auto* conv_attrs = conv2d->attrs.as<Conv2DAttrs>();
+    if (conv_attrs->data_layout == "NCHW") {
+      // expand bias reshape
+      Array<PrimExpr> exp_bias_shape{bias_shape[0], bias_shape[1], Integer(1), bias_shape[2]};
+      static const Op& reshape_op = Op::Get("relax.reshape");
+      const auto& exp_bias = MakeCall(builder, call->span, "exp_bias", reshape_op,
+                                      {call->args[1], ShapeExpr(exp_bias_shape)});
+      // redirect to conv2d
+      static const Op& add_op = Op::Get("relax.add");
+      const auto& exp_add =
+          MakeCall(builder, call->span, "exp_add", add_op, {reshape->args[0], exp_bias});
+      // reduce output
+      return Call(reshape_op, {exp_add, ShapeExpr(input_shape)}, Attrs(), call->sinfo_args,
+                  call->span);
+    } else {
+      LOG_FATAL << "Unexpected data layout " << conv_attrs->data_layout;
+    }
+  }
+  return RewriteElemwise(builder, var, call, new_calls, version);
+}
+
+Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, const Call& src_call,
+                      const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  const auto& out_dtype = Downcast<TensorStructInfo>(GetStructInfo(var))->dtype;
+  const auto* src_attrs = src_call->attrs.as<ArgmaxArgminAttrs>();
+  Expr raw_var;
+  if (src_attrs->keepdims) {
+    raw_var = EmitCall(builder, call, call->span, "raw");
+  } else {
+    auto new_attrs = make_object<ArgmaxArgminAttrs>();
+    new_attrs->axis = src_attrs->axis;
+    new_attrs->keepdims = true;
+    raw_var =
+        MakeCall(builder, call->span, "keepdims", call->op, {call->args[0]}, Attrs(new_attrs));
+  }
+  static const Op& astype_op = Op::Get("relax.astype");
+  auto cast_to_attrs = make_object<AstypeAttrs>();
+  cast_to_attrs->dtype = DataType::Int(32);
+  Expr res = MakeCall(builder, call->span, "cast_to", astype_op, {raw_var}, Attrs(cast_to_attrs));
+  // reshape back
+  if (!src_attrs->keepdims) {
+    const auto& output_shape = GetShape(var);
+    static const Op& reshape_op = Op::Get("relax.reshape");
+    res = MakeCall(builder, call->span, "reshape", reshape_op, {res, ShapeExpr(output_shape)});
+  }
+  auto cast_from_attrs = make_object<AstypeAttrs>();
+  cast_from_attrs->dtype = out_dtype;
+  return Call(astype_op, {res}, Attrs(cast_from_attrs), call->sinfo_args, call->span);
+}
+
+Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call,
+                      const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ?
new_calls[src_call] : src_call;
+  const auto& in_dtype = Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+  const auto* src_attrs = src_call->attrs.as<AttentionAttrs>();
+
+  // define dims
+  const auto& in_q_shape = GetShape(call->args[0]);
+  const auto& in_v_shape = GetShape(call->args[2]);
+  const auto& batch_size = in_q_shape[0];
+  const auto& seq_len = in_q_shape[1];
+  const auto& num_head = in_q_shape[2];
+  const auto& head_dim = in_q_shape[3];
+  const auto& seq_len_kv = in_v_shape[1];
+  const auto& head_dim_v = in_v_shape[3];
+
+  // create ops
+  static const Op& permute_dims_op = Op::Get("relax.permute_dims");
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  static const Op& matmul_op = Op::Get("relax.matmul");
+  static const Op& multiply_op = Op::Get("relax.multiply");
+  static const Op& add_op = Op::Get("relax.add");
+  static const Op& divide_op = Op::Get("relax.divide");
+  static const Op& sqrt_op = Op::Get("relax.sqrt");
+  static const Op& softmax_op = Op::Get("relax.nn.softmax");
+  static const Op& tril_op = Op::Get("relax.tril");
+  static const Op& max_op = Op::Get("relax.max");
+  static const Op& sum_op = Op::Get("relax.sum");
+  static const Op& subtract_op = Op::Get("relax.subtract");
+  static const Op& exp_op = Op::Get("relax.exp");
+
+  // prepare q,k,v
+  auto permute_attrs = make_object<PermuteDimsAttrs>();
+  Array<Integer> axes{Integer(0), Integer(2), Integer(1), Integer(3)};
+  permute_attrs->axes = axes;
+  const auto& q_trans = MakeCall(builder, call->span, "q_trans", permute_dims_op, {call->args[0]},
+                                 Attrs(permute_attrs));
+  const auto& k_trans = MakeCall(builder, call->span, "k_trans", permute_dims_op, {call->args[1]},
+                                 Attrs(permute_attrs));
+  const auto& v_trans = MakeCall(builder, call->span, "v_trans", permute_dims_op, {call->args[2]},
+                                 Attrs(permute_attrs));
+  Array<PrimExpr> q_shape({batch_size * num_head, seq_len, head_dim});
+  const auto& q_reshape =
+      MakeCall(builder, call->span, "q_reshape", reshape_op, {q_trans, ShapeExpr(q_shape)});
+  Array<PrimExpr> k_shape({batch_size * num_head, seq_len_kv, head_dim});
+  const auto& k_reshape =
+      MakeCall(builder, call->span, "k_reshape", reshape_op, {k_trans, ShapeExpr(k_shape)});
+  Array<PrimExpr> v_shape({batch_size * num_head, seq_len_kv, head_dim_v});
+  const auto& v_reshape =
+      MakeCall(builder, call->span, "v_reshape", reshape_op, {v_trans, ShapeExpr(v_shape)});
+  auto reduce_permute_attrs = make_object<PermuteDimsAttrs>();
+  Array<Integer> v_axes{Integer(0), Integer(2), Integer(1)};
+  reduce_permute_attrs->axes = v_axes;
+  // transpose for batch_matmul
+  const auto& k_reshape_trans = MakeCall(builder, call->span, "k_reshape_trans", permute_dims_op,
+                                         {k_reshape}, Attrs(reduce_permute_attrs));
+
+  // calculate product
+  auto matmul_attrs = make_object<MatmulAttrs>();
+  matmul_attrs->out_dtype = in_dtype;
+  const auto& qk_prod = MakeCall(builder, call->span, "qk_prod", matmul_op,
+                                 {q_reshape, k_reshape_trans}, Attrs(matmul_attrs));
+  Expr p_scale;
+  if (src_attrs->scale.defined()) {
+    const auto& scale = MakeConstant(static_cast<double>(src_attrs->scale.value()->value), in_dtype,
+                                     SpanUtils::GetAttr(call->span, "name") + "_scale");
+    Array<PrimExpr> exp_shape(3, Integer(1));
+    const auto& exp_scale =
+        MakeCall(builder, call->span, "exp_scale", reshape_op, {scale, ShapeExpr(exp_shape)});
+    p_scale = MakeCall(builder, call->span, "p_scale", multiply_op, {qk_prod, exp_scale});
+  } else {
+    const auto& scale = MakeConstant(static_cast<double>(Downcast<IntImm>(head_dim)->value),
+                                     in_dtype, SpanUtils::GetAttr(call->span, "name") + "_scale");
+    Array<PrimExpr> exp_shape(3, Integer(1));
+    const auto& exp_scale =
+        MakeCall(builder, call->span, "exp_scale", reshape_op, {scale, ShapeExpr(exp_shape)});
+    const auto& sqrt_scale = MakeCall(builder, call->span, "sqrt_scale", sqrt_op, {exp_scale});
+    p_scale = MakeCall(builder, call->span, "p_scale", divide_op, {qk_prod, sqrt_scale});
+  }
+
+  // bias
+  Expr prod = p_scale;
+  if (call->args.size() == 4) {
+    Array<PrimExpr> exp_shape{batch_size, num_head, seq_len, seq_len_kv};
+    Array<PrimExpr> reduce_shape{batch_size * num_head, seq_len, seq_len_kv};
+    const auto& prod_exp =
+        MakeCall(builder, call->span, "prod_exp", reshape_op, {prod, ShapeExpr(exp_shape)});
+    const auto& prod_add =
+        MakeCall(builder, call->span, "prod_add", add_op, {prod_exp, call->args[3]});
+    prod = MakeCall(builder, call->span, "prod_reduce", reshape_op,
+                    {prod_add, ShapeExpr(reduce_shape)});
+  }
+
+  // causal_mask
+  Expr s_value;
+  if (!src_attrs->causal_mask.defined()) {
+    auto softmax_attrs = make_object<SoftmaxAttrs>();
+    softmax_attrs->axis = 2;
+    s_value = MakeCall(builder, call->span, "act", softmax_op, {prod}, Attrs(softmax_attrs));
+  } else {
+    const auto& causal_mask = src_attrs->causal_mask.value();
+    PrimValue tril_k;
+    if (causal_mask == "TopLeft") {
+      tril_k = PrimValue(Integer(0));
+    } else if (causal_mask == "BottomRight") {
+      tril_k = PrimValue(seq_len - seq_len_kv);
+    } else {
+      LOG_FATAL << "Unexpected causal_mask " << causal_mask;
+    }
+    const auto& p_masked = MakeCall(builder, call->span, "p_masked", tril_op, {prod, tril_k});
+    auto reduce_attrs = make_object<StatisticalAttrs>();
+    Array<Integer> axis{Integer(2)};
+    reduce_attrs->axis = axis;
+    reduce_attrs->keepdims = true;
+    const auto& p_max = MakeCall(builder, call->span, "p_max", max_op, {prod}, Attrs(reduce_attrs));
+    const auto& p_diff = MakeCall(builder, call->span, "p_diff", subtract_op, {p_masked, p_max});
+    const auto& p_exp = MakeCall(builder, call->span, "p_exp", exp_op, {p_diff});
+    const auto& p_masked_exp =
+        MakeCall(builder, call->span, "p_masked_exp", tril_op, {p_exp, tril_k});
+    const auto& p_masked_sum =
+        MakeCall(builder, call->span, "p_masked_sum", sum_op, {p_masked_exp}, Attrs(reduce_attrs));
+    s_value = MakeCall(builder, call->span, "act", divide_op, {p_masked_exp, p_masked_sum});
+  }
+
+  // final calculation
+  const auto& o_prod =
+      MakeCall(builder, call->span, "o_prod", matmul_op, {s_value, v_reshape}, Attrs(matmul_attrs));
+  Array<PrimExpr> o_shape{batch_size, num_head, seq_len, head_dim_v};
+  return Call(reshape_op, {o_prod, ShapeExpr(o_shape)}, Attrs(), call->sinfo_args, call->span);
+}
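The decomposition above is ordinary scaled dot-product attention built from matmul, softmax, and reshape layers TensorRT can ingest. A compact numpy sketch of the non-masked path, using the same (batch*heads, seq, dim) layout the rewrite reshapes q/k/v into (illustrative only; the max-subtraction is the standard numerically-stable softmax, mathematically identical to the plain softmax the rewrite emits):

import numpy as np

def sdpa_reference(q, k, v):
    """softmax(q @ k^T / sqrt(head_dim)) @ v in 3-D layout."""
    prod = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    prod = prod - prod.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(prod)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

q = np.random.rand(8, 16, 64).astype("float32")  # (batch*heads, seq, dim)
k = np.random.rand(8, 16, 64).astype("float32")
v = np.random.rand(8, 16, 64).astype("float32")
assert sdpa_reference(q, k, v).shape == (8, 16, 64)

The batch-norm rewrite below follows the same strategy of expanding one op into TensorRT-friendly primitives.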
+Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call& src_call,
+                      const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  const auto& input_shape = GetShape(call->args[0]);
+  const auto& in_dtype = Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+  const auto* src_attrs = src_call->attrs.as<BatchNormAttrs>();
+  // define expand shape
+  Array<PrimExpr> exp_shape(input_shape.size(), Integer(1));
+  exp_shape.Set(src_attrs->axis, input_shape[src_attrs->axis]);
+
+  // create eps constant
+  const auto& eps =
+      MakeConstant(src_attrs->epsilon, in_dtype, SpanUtils::GetAttr(call->span, "name") + "_eps");
+
+  // create ops
+  static const Op& add_op = Op::Get("relax.add");
+  static const Op& divide_op = Op::Get("relax.divide");
+  static const Op& multiply_op = Op::Get("relax.multiply");
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  static const Op& sqrt_op = Op::Get("relax.sqrt");
+  static const Op& subtract_op = Op::Get("relax.subtract");
+
+  // scale factor: gamma/sqrt(var + epsilon)
+  const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op, {call->args[4], eps});
+  const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add});
+  const auto& scale_factor =
+      MakeCall(builder, call->span, "scale_factor", divide_op, {call->args[1], sqrt});
+  Expr res = call->args[0];
+  // scale
+  if (src_attrs->scale) {
+    const auto& exp_scale = MakeCall(builder, call->span, "exp_scale", reshape_op,
+                                     {scale_factor, ShapeExpr(exp_shape)});
+    res = MakeCall(builder, call->span, "scale", multiply_op, {res, exp_scale});
+  }
+  // offset
+  if (src_attrs->center) {
+    // offset factor: beta-mean*scale_factor
+    const auto& average =
+        MakeCall(builder, call->span, "average", multiply_op, {call->args[3], scale_factor});
+    const auto& offset_factor =
+        MakeCall(builder, call->span, "offset_factor", subtract_op, {call->args[2], average});
+    const auto& exp_offset = MakeCall(builder, call->span, "exp_offset", reshape_op,
+                                      {offset_factor, ShapeExpr(exp_shape)});
+    res = MakeCall(builder, call->span, "offset", add_op, {res, exp_offset});
+  }
+  return Tuple(Array<Expr>{res}, call->span);
+}
+
+Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_call,
+                        const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  const auto& input_shape = GetShape(call->args[0]);
+  const auto& output_shape = GetShape(var);
+  Expr concat_input = call->args[0];
+  static const Op& concat_op = Op::Get("relax.concat");
+  for (size_t i = 0; i < input_shape.size(); i++) {
+    int64_t in_dim = Downcast<IntImm>(input_shape[i])->value;
+    int64_t out_dim = Downcast<IntImm>(output_shape[i])->value;
+    if (in_dim != out_dim) {
+      Array<Expr> concat_inputs(out_dim / in_dim, concat_input);
+      auto concat_attrs = make_object<ConcatAttrs>();
+      concat_attrs->axis = Integer(i);
+      concat_input = MakeCall(builder, call->span, "concat_" + std::to_string(i), concat_op,
+                              {Tuple(concat_inputs)}, Attrs(concat_attrs));
+    }
+  }
+  return concat_input;
+}
+
+Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call,
+                   const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  const auto* src_attrs = src_call->attrs.as<Conv1DAttrs>();
+  const auto& input_shape = GetShape(call->args[0]);
+  const auto& weight_shape = GetShape(call->args[1]);
+  const auto& output_shape = GetShape(var);
+  if (src_attrs->data_layout == "NCW") {
+    Array<Expr> new_args;
+    // expand inputs
+    Array<PrimExpr> exp_input_shape{input_shape[0], input_shape[1], Integer(1), input_shape[2]};
+    Array<PrimExpr> exp_weight_shape{weight_shape[0], weight_shape[1], Integer(1), weight_shape[2]};
+    static const Op& reshape_op = Op::Get("relax.reshape");
+    new_args.push_back(MakeCall(builder, call->span, "exp_input", reshape_op,
+                                {call->args[0], ShapeExpr(exp_input_shape)}));
+    new_args.push_back(MakeCall(builder, call->span, "exp_weight", reshape_op,
+                                {call->args[1], ShapeExpr(exp_weight_shape)}));
+    // change to conv2d
+    static const Op& conv2d_op = Op::Get("relax.nn.conv2d");
+    auto conv_attrs = make_object<Conv2DAttrs>();
+    conv_attrs->strides = Array<IntImm>{src_attrs->strides[0], Integer(1)};
+    conv_attrs->padding =
+        Array<IntImm>{Integer(0), src_attrs->padding[0], Integer(0), src_attrs->padding[1]};
+    conv_attrs->dilation = Array<IntImm>{src_attrs->dilation[0], Integer(1)};
+    conv_attrs->groups = src_attrs->groups;
+    conv_attrs->data_layout = "NCHW";
+    conv_attrs->kernel_layout = "OIHW";
+    conv_attrs->out_layout = "NCHW";
+    conv_attrs->out_dtype = src_attrs->out_dtype;
+    const auto& conv2d =
+        MakeCall(builder, call->span, "exp", conv2d_op, new_args, Attrs(conv_attrs));
+    // reduce output
+    return Call(reshape_op, {conv2d, ShapeExpr(output_shape)}, Attrs(), call->sinfo_args,
+                call->span);
+  } else {
+    LOG_FATAL << "Unexpected data layout " << src_attrs->data_layout;
+  }
+  return call;
+}
+
+Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call,
+                      const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  const auto& input_shape = GetShape(call->args[0]);
+  const auto& in_dtype = Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+  const auto* src_attrs = src_call->attrs.as<GroupNormAttrs>();
+  Array<PrimExpr> group_shape = input_shape;
+  Array<PrimExpr> exp_shape(input_shape.size(), Integer(1));
+  size_t axis = CommonUtils::GetIndex(src_attrs->channel_axis, input_shape.size());
+  int64_t channel_dim = Downcast<IntImm>(input_shape[axis])->value *
+                        Downcast<IntImm>(input_shape[axis + 1])->value / src_attrs->num_groups;
+  group_shape.Set(axis, Integer(src_attrs->num_groups));
+  group_shape.Set(axis + 1, Integer(channel_dim));
+  exp_shape.Set(axis, Integer(src_attrs->num_groups));
+
+  // create eps constant
+  const auto& eps =
+      MakeConstant(src_attrs->epsilon, in_dtype, SpanUtils::GetAttr(call->span, "name") + "_eps");
+
+  // create ops
+  static const Op& add_op = Op::Get("relax.add");
+  static const Op& divide_op = Op::Get("relax.divide");
+  static const Op& mean_op = Op::Get("relax.mean");
+  static const Op& multiply_op = Op::Get("relax.multiply");
+  static const Op& square_op = Op::Get("relax.square");
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  static const Op& sqrt_op = Op::Get("relax.sqrt");
+  static const Op& subtract_op = Op::Get("relax.subtract");
+
+  // reshape input
+  const auto& reshape_in = MakeCall(builder, call->span, "reshape_in", reshape_op,
+                                    {call->args[0], ShapeExpr(group_shape)});
+
+  // mean(input)
+  auto mean_attrs = make_object<StatisticalAttrs>();
+  mean_attrs->axis = src_attrs->axes;
+  mean_attrs->keepdims = true;
+  const auto& mean =
+      MakeCall(builder, call->span, "mean", mean_op, {reshape_in}, Attrs(mean_attrs));
+
+  // variance: mean((input-mean)*(input-mean))
+  const auto& diff = MakeCall(builder, call->span, "diff", subtract_op, {reshape_in, mean});
+  const auto& square = MakeCall(builder, call->span, "square", square_op, {diff});
+  const auto& variance =
+      MakeCall(builder, call->span, "variance", mean_op, {square}, Attrs(mean_attrs));
+
+  // sqrt(var + epsilon)
+  Array<PrimExpr> exp_eps_shape(input_shape.size(), Integer(1));
+  const auto& exp_eps =
+      MakeCall(builder, call->span, "exp_eps", reshape_op, {eps, ShapeExpr(exp_eps_shape)});
+  const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op, {variance, exp_eps});
+  const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add});
+
+  // diff/sqrt
+  Expr res = MakeCall(builder, call->span, "divide", divide_op, {diff, sqrt});
+
+  // scale
+  if (src_attrs->scale) {
+    const auto& exp_gamma = MakeCall(builder, call->span, "exp_gamma", reshape_op,
+                                     {call->args[1], ShapeExpr(exp_shape)});
+    res = MakeCall(builder, call->span, "scale", multiply_op, {res, exp_gamma});
+  }
+  // offset
+  if (src_attrs->center) {
+    const auto& exp_beta = MakeCall(builder, call->span, "exp_beta", reshape_op,
+                                    {call->args[2], ShapeExpr(exp_shape)});
+    res = MakeCall(builder, call->span, "offset", add_op, {res, exp_beta});
+  }
+  // reshape output
+  return Call(reshape_op, {res, ShapeExpr(input_shape)}, Attrs(), call->sinfo_args, call->span);
+}
+
+Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call,
+                      const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  const auto& input_shape = GetShape(call->args[0]);
+  const auto& in_dtype = Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+  const auto* src_attrs = src_call->attrs.as<LayerNormAttrs>();
+  Array<PrimExpr> exp_shape(input_shape.size(), Integer(1));
+  for (const auto& a : src_attrs->axes) {
+    size_t index = CommonUtils::GetIndex(static_cast<int>(a->value), input_shape.size());
+    exp_shape.Set(index, input_shape[index]);
+  }
+  // create eps constant
+  const auto& eps =
+      MakeConstant(src_attrs->epsilon, in_dtype, SpanUtils::GetAttr(call->span, "name") + "_eps");
+
+  // create ops
+  static const Op& add_op = Op::Get("relax.add");
+  static const Op& divide_op = Op::Get("relax.divide");
+  static const Op& mean_op = Op::Get("relax.mean");
+  static const Op& multiply_op = Op::Get("relax.multiply");
+  static const Op& square_op = Op::Get("relax.square");
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  static const Op& sqrt_op = Op::Get("relax.sqrt");
+  static const Op& subtract_op = Op::Get("relax.subtract");
+
+  // mean(input)
+  auto mean_attrs = make_object<StatisticalAttrs>();
+  mean_attrs->axis = src_attrs->axes;
+  mean_attrs->keepdims = true;
+  const auto& mean =
+      MakeCall(builder, call->span, "mean", mean_op, {call->args[0]}, Attrs(mean_attrs));
+
+  // variance: mean((input-mean)*(input-mean))
+  const auto& diff = MakeCall(builder, call->span, "diff", subtract_op, {call->args[0], mean});
+  const auto& square = MakeCall(builder, call->span, "square", square_op, {diff});
+  const auto& variance =
+      MakeCall(builder, call->span, "variance", mean_op, {square}, Attrs(mean_attrs));
+
+  // sqrt(var + epsilon)
+  Array<PrimExpr> exp_eps_shape(input_shape.size(), Integer(1));
+  const auto& exp_eps =
+      MakeCall(builder, call->span, "exp_eps", reshape_op, {eps, ShapeExpr(exp_eps_shape)});
+  const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op, {variance, exp_eps});
+  const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add});
+
+  // diff/sqrt
+  Call res = Call(divide_op, {diff, sqrt}, Attrs(), call->sinfo_args, call->span);
+
+  // scale
+  if (src_attrs->scale) {
+    const auto& exp_gamma = MakeCall(builder, call->span, "exp_gamma", reshape_op,
+                                     {call->args[1], ShapeExpr(exp_shape)});
+    const auto& res_var = EmitCall(builder, res, call->span, "pre_scale");
+    if (src_attrs->center) {
+      res = Call(multiply_op, {res_var, exp_gamma});
+    } else {
+      res = Call(multiply_op, {res_var, exp_gamma}, Attrs(), call->sinfo_args, call->span);
+    }
+  }
+  // offset
+  if (src_attrs->center) {
+    const auto& exp_beta = MakeCall(builder, call->span, "exp_beta", reshape_op,
+                                    {call->args[2], ShapeExpr(exp_shape)});
+    const auto& res_var = EmitCall(builder, res, call->span, "pre_offset");
+    res = Call(add_op, {res_var, exp_beta}, Attrs(), call->sinfo_args, call->span);
+  }
+  return res;
+}
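The layer-norm rewrite thus reduces to the textbook formula gamma * (x - mean) / sqrt(var + eps) + beta, built only from ops the TensorRT translator supports. An equivalent numpy sketch (illustrative, not part of the patch):

import numpy as np

def layer_norm_reference(x, gamma, beta, axes=(-1,), eps=1e-5):
    """Decomposed layer norm matching the mean/subtract/square/sqrt
    chain emitted by RewriteLayerNorm."""
    mean = x.mean(axis=axes, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=axes, keepdims=True)  # mean((x-mean)^2)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

x = np.random.rand(2, 8, 16).astype("float32")
gamma = np.ones((16,), "float32")
beta = np.zeros((16,), "float32")
assert layer_norm_reference(x, gamma, beta).shape == x.shape

The matmul rewrite below handles the remaining rank-mismatch cases the same way the elemwise rewrite does.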
+Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call,
+                   const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  const auto& shape_a = GetShape(call->args[0]);
+  const auto& shape_b = GetShape(call->args[1]);
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  if (shape_a.size() > shape_b.size()) {
+    Array<PrimExpr> exp_shape(shape_a.size(), Integer(1));
+    for (size_t i = shape_b.size(); i < shape_a.size(); i++) {
+      exp_shape.Set(i, shape_b[i - shape_b.size()]);
+    }
+    const auto& expand_b = MakeCall(builder, call->span, "expand_b", reshape_op,
+                                    {call->args[1], ShapeExpr(exp_shape)});
+    return Call(call->op, {call->args[0], expand_b}, call->attrs, call->sinfo_args, call->span);
+  }
+  if (shape_a.size() < shape_b.size()) {
+    Array<PrimExpr> exp_shape(shape_b.size(), Integer(1));
+    for (size_t i = shape_a.size(); i < shape_b.size(); i++) {
+      exp_shape.Set(i, shape_a[i - shape_a.size()]);
+    }
+    const auto& expand_a = MakeCall(builder, call->span, "expand_a", reshape_op,
+                                    {call->args[0], ShapeExpr(exp_shape)});
+    return Call(call->op, {expand_a, call->args[1]}, call->attrs, call->sinfo_args, call->span);
+  }
+  return call;
+}
+
+Expr RewriteRsqrt(BlockBuilder builder, const Var& var, const Call& src_call,
+                  const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  const auto& input_shape = GetShape(call->args[0]);
+  const auto& in_dtype = Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+  Array<PrimExpr> exp_shape(input_shape.size(), Integer(1));
+  // create 1 constant
+  const auto& one = MakeConstant(1, in_dtype, SpanUtils::GetAttr(call->span, "name") + "_one");
+
+  // create ops
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  static const Op& divide_op = Op::Get("relax.divide");
+  static const Op& sqrt_op = Op::Get("relax.sqrt");
+
+  // expand and divide
+  const auto& exp_one =
+      MakeCall(builder, call->span, "exp_one", reshape_op, {one, ShapeExpr(exp_shape)});
+  const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {call->args[0]});
+  return Call(divide_op, {exp_one, sqrt}, Attrs(), call->sinfo_args, call->span);
+}
+
+Expr RewriteSilu(BlockBuilder builder, const Var& var, const Call& src_call,
+                 const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  // create ops
+  static const Op& multiply_op = Op::Get("relax.multiply");
+  static const Op& sigmoid_op = Op::Get("relax.sigmoid");
+  // silu=input*sigmoid(input)
+  const auto& sigmoid = MakeCall(builder, call->span, "sigmoid", sigmoid_op, {call->args[0]});
+  return Call(multiply_op, {call->args[0], sigmoid}, Attrs(), call->sinfo_args, call->span);
+}
+
+Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call,
+                      const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  const auto& output_shape = GetShape(var);
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  return Call(reshape_op, {call->args[0], ShapeExpr(output_shape)}, Attrs(), call->sinfo_args,
+              call->span);
+}
+
+Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call,
+                  const Map<Expr, Call>& new_calls, const Array<Integer>& version) {
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
+  const auto& input_shape = GetShape(call->args[0]);
+  const auto* src_attrs = src_call->attrs.as<SplitAttrs>();
+  size_t axis = CommonUtils::GetIndex(src_attrs->axis, input_shape.size());
+  std::vector<int64_t> split_begins, split_ends;
+  // get split begins and ends
+  if (src_attrs->indices_or_sections->IsInstance<IntImmNode>()) {
+    int64_t sections = Downcast<IntImm>(src_attrs->indices_or_sections)->value;
+    int64_t size = Downcast<IntImm>(input_shape[axis])->value / sections;
+    for (int64_t i = 0; i < sections; i++) {
+      split_begins.push_back(i * size);
+      split_ends.push_back(i * size + size);
+    }
+  } else if (src_attrs->indices_or_sections->IsInstance<ArrayNode>()) {
+    const auto& indices = Downcast<Array<IntImm>>(src_attrs->indices_or_sections);
+    int64_t last_index = 0;
+    for (size_t i = 0; i < indices.size(); ++i) {
+      split_begins.push_back(last_index);
+      last_index = indices[i]->value;
+      split_ends.push_back(last_index);
+    }
+    split_begins.push_back(last_index);
+    split_ends.push_back(Downcast<IntImm>(input_shape[axis])->value);
+  } else {
+    LOG_FATAL << "Unexpected indices_or_sections " << src_attrs->indices_or_sections << "("
+              << src_attrs->indices_or_sections->GetTypeKey() << ")";
+  }
+  // create strided_slices
+  static const Op& slice_op = Op::Get("relax.strided_slice");
+  Array<Expr> outputs;
+  for (size_t i = 0; i < split_begins.size(); i++) {
+    auto slice_attrs = make_object<StridedSliceAttrs>();
+    slice_attrs->axes.push_back(Integer(axis));
+    slice_attrs->begin.push_back(Integer(split_begins[i]));
+    slice_attrs->end.push_back(Integer(split_ends[i]));
+    const auto& slice = MakeCall(builder, call->span, "slice_" + std::to_string(i), slice_op,
+                                 {call->args[0]}, Attrs(slice_attrs));
+    outputs.push_back(slice);
+  }
+  return Tuple(outputs, call->span);
+}
+
+// nn ops
+TVM_REGISTER_OP("relax.nn.attention")
+    .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteAttention);
+TVM_REGISTER_OP("relax.nn.attention_bias")
+    .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteAttention);
+TVM_REGISTER_OP("relax.nn.batch_norm")
+    .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteBatchNorm);
+TVM_REGISTER_OP("relax.nn.conv1d").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteConv1d);
+TVM_REGISTER_OP("relax.nn.group_norm")
+    .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteGroupNorm);
+TVM_REGISTER_OP("relax.nn.layer_norm")
+    .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteLayerNorm);
+TVM_REGISTER_OP("relax.nn.silu").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteSilu);
+
+// elemwise ops
+TVM_REGISTER_OP("relax.add").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteAdd);
+TVM_REGISTER_OP("relax.divide").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteElemwise);
+TVM_REGISTER_OP("relax.floor_divide")
+    .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteElemwise);
+TVM_REGISTER_OP("relax.greater").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteElemwise);
+TVM_REGISTER_OP("relax.less").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteElemwise);
+TVM_REGISTER_OP("relax.maximum").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteElemwise);
+TVM_REGISTER_OP("relax.minimum").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteElemwise);
+TVM_REGISTER_OP("relax.multiply").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteElemwise);
+TVM_REGISTER_OP("relax.power").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteElemwise);
+TVM_REGISTER_OP("relax.subtract").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteElemwise);
+
+// math ops
+TVM_REGISTER_OP("relax.argmax").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteArgmaxmin);
+TVM_REGISTER_OP("relax.argmin").set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteArgmaxmin);
+TVM_REGISTER_OP("relax.broadcast_to")
+    .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteBroadcastTo);
+TVM_REGISTER_OP("relax.expand_dims")
+    .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteShapeLike);
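RewriteSplit above replaces relax.split with one strided_slice per chunk, since TensorRT has no native split layer; the begin/end bookkeeping matches this numpy sketch (split_reference is an illustrative name):

import numpy as np

def split_reference(x, indices_or_sections, axis=0):
    """Emulate relax.split with explicit slices, as the rewrite does."""
    dim = x.shape[axis]
    if isinstance(indices_or_sections, int):  # equal sections
        size = dim // indices_or_sections
        bounds = [(i * size, i * size + size) for i in range(indices_or_sections)]
    else:  # explicit split indices
        edges = [0] + list(indices_or_sections) + [dim]
        bounds = list(zip(edges[:-1], edges[1:]))
    slicer = [slice(None)] * x.ndim
    outputs = []
    for begin, end in bounds:
        slicer[axis] = slice(begin, end)
        outputs.append(x[tuple(slicer)])
    return outputs

x = np.arange(12).reshape(3, 4)
assert [o.shape for o in split_reference(x, 2, axis=1)] == [(3, 2), (3, 2)]

The remaining registrations below attach the same FRewriteTensorRT hook to the rest of the covered ops.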
+TVM_REGISTER_OP("relax.matmul").set_attr("FRewriteTensorRT", RewriteMatmul); +TVM_REGISTER_OP("relax.rsqrt").set_attr("FRewriteTensorRT", RewriteRsqrt); +TVM_REGISTER_OP("relax.squeeze").set_attr("FRewriteTensorRT", RewriteShapeLike); +TVM_REGISTER_OP("relax.split").set_attr("FRewriteTensorRT", RewriteSplit); + +class TensorRTTransformer : public ExprMutator { + public: + explicit TensorRTTransformer(IRModule ctx_module, const Array& version) + : ExprMutator(ctx_module) { + version_ = version; + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { + if (const auto* op_node = call_node->op.as()) { + const auto& op = Downcast(GetRef(op_node)); + const auto& rewrite_map = Op::GetAttrMap("FRewriteTensorRT"); + if (rewrite_map.count(op)) { + const auto& call = GetRef(call_node); + FRewriteTensorRT f = rewrite_map[op]; + const auto& new_call = f(builder_, binding->var, call, new_calls_, version_); + if (new_call != call) { + ReEmitBinding(binding, builder_->Normalize(new_call)); + new_calls_.Set(binding->var, call); + } + } + } + if (!new_calls_.count(binding->var)) { + ExprMutator::VisitBinding_(binding, call_node); + } + } + + private: + Map new_calls_; + Array version_; +}; + +Function TransformTensorRT(const Function& func, const IRModule& module, + const Array& version) { + return Downcast(TensorRTTransformer(module, version).VisitExpr(func)); +} + +namespace transform { + +Pass TransformTensorRT(const Array& version) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return relax::TransformTensorRT(f, m, version); + }; + return CreateFunctionPass(pass_func, 0, "TransformTensorRT", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.TransformTensorRT").set_body_typed(TransformTensorRT); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc new file mode 100644 index 000000000000..8a4f5fe4bae0 --- /dev/null +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/tensorrt/tensorrt_runtime.cc + * \brief JSON runtime implementation for TensorRT. 
+ */
+
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <fstream>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "../json/json_runtime.h"
+
+#ifdef TVM_GRAPH_EXECUTOR_TENSORRT
+#include "../../../runtime/cuda/cuda_common.h"
+#include "../tensorrt/tensorrt_logger.h"
+#include "../tensorrt/tensorrt_utils.h"
+#endif
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime::json;
+
+#ifdef TVM_GRAPH_EXECUTOR_TENSORRT
+using namespace nvinfer1;
+#endif
+
+class MSCTensorRTRuntime : public JSONRuntimeBase {
+ public:
+  /*!
+   * \brief The MSC TensorRT runtime module. Deserialize the provided functions
+   * on creation and store in the layer cache.
+   *
+   * \param symbol_name The name of the function.
+   * \param graph_json serialized JSON representation of a sub-graph.
+   * \param const_names The names of each constant in the sub-graph.
+   */
+  explicit MSCTensorRTRuntime(const std::string& symbol_name, const std::string& graph_json,
+                              const Array<String>& const_names)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+  ~MSCTensorRTRuntime() override {
+    VLOG(1) << "Destroying MSC TensorRT runtime";
+    DestroyEngine();
+  }
+
+  /*!
+   * \brief The type key of the module.
+   *
+   * \return module type key.
+   */
+  const char* type_key() const final { return "msc_tensorrt"; }
+
+  /*! \brief Get the property of the runtime module .*/
+  int GetPropertyMask() const final {
+    return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable;
+  }
+
+  /*!
+   * \brief Initialize runtime.
+   *
+   * \param consts The constant params from compiled model.
+   */
+  void Init(const Array<NDArray>& consts) override {
+    ICHECK_EQ(consts.size(), const_idx_.size())
+        << "The number of input constants must match the number required.";
+    LoadGlobalOptions();
+    LoadEngine(engine_file_);
+  }
+
+  void LoadGlobalOptions() {
+    // These settings are global to the entire subgraph. Codegen will add them as attributes to all
+    // op nodes. Read from first one.
+    for (size_t i = 0; i < nodes_.size(); ++i) {
+      if (nodes_[i].HasAttr("msc_global_options_num")) {
+        engine_file_ = nodes_[i].GetAttr<std::vector<std::string>>("msc_global_engine")[0];
+      }
+    }
+  }
+
+#ifdef TVM_GRAPH_EXECUTOR_TENSORRT
+  void Run() override {
+    SetInputOutputBinds();
+    auto tvm_stream = CUDAThreadEntry::ThreadLocal()->stream;
+#if TRT_VERSION_GE(6, 0, 1)
+    ICHECK(context_->enqueueV2(bindings_.data(), tvm_stream, nullptr))
+        << "Running TensorRT failed.";
+#else
+    LOG_FATAL << "Only TensorRT version >= 6.0.1 is supported";
+#endif
+    // Copy outputs from GPU buffers if needed.
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      auto nid = outputs_[i].id_;
+      uint32_t eid = EntryID(outputs_[i]);
+      const auto& name = nodes_[nid].GetOpName() + ":" + std::to_string(outputs_[i].index_);
+      int binding_index = engine_->getBindingIndex(name.c_str());
+      ICHECK_NE(binding_index, -1);
+      if (data_entry_[eid]->device.device_type != kDLCUDA) {
+        auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index);
+        device_buffer.CopyTo(const_cast<DLTensor*>(data_entry_[eid]));
+      }
+    }
+  }
+
+  bool LoadEngine(const String& engine_file) {
+    IRuntime* runtime = createInferRuntime(logger_);
+    // build engine
+    std::ifstream input(engine_file_, std::ifstream::binary);
+    if (!input.is_open() || !input.good()) {
+      LOG_ERROR << "Failed to open engine file " << engine_file_;
+      return false;
+    }
+    std::vector<char> stream;
+    size_t size = 0;
+    input.seekg(0, input.end);
+    size = input.tellg();
+    input.seekg(0, input.beg);
+    stream.resize(size);
+    input.read(stream.data(), size);
+    input.close();
+
+#if TRT_VERSION_GE(8, 0, 0)
+    engine_ = runtime->deserializeCudaEngine(stream.data(), size);
+#else
+    engine_ = runtime->deserializeCudaEngine(stream.data(), size, nullptr);
+#endif
+    if (!engine_) {
+      LOG_ERROR << "Failed to load engine";
+      return false;
+    }
+    // create context
+    context_ = engine_->createExecutionContext();
+    if (!context_) {
+      LOG_ERROR << "Failed to create context";
+      return false;
+    }
+    // resize bindings
+    size_t num_binding = static_cast<size_t>(engine_->getNbBindings());
+    bindings_.resize(num_binding);
+    binding_sizes_.resize(num_binding);
+    for (size_t i = 0; i < num_binding; i++) {
+      bindings_[i] = nullptr;
+      binding_sizes_[i] = 0;
+    }
+    // destroy runtime
+#if TRT_VERSION_GE(8, 0, 0)
+    delete runtime;
+#else
+    runtime->destroy();
+#endif
+    return true;
+  }
+
+  void DestroyEngine() {
+#if TRT_VERSION_GE(8, 0, 0)
+    delete context_;
+    delete engine_;
+#else
+    context_->destroy();
+    engine_->destroy();
+#endif
+    engine_ = nullptr;
+    context_ = nullptr;
+  }
+
+  void SetInputOutputBinds() {
+    // Setup input bindings
+    for (size_t i = 0; i < input_nodes_.size(); ++i) {
+      auto nid = input_nodes_[i];
+      if (nodes_[nid].GetOpType() == "input") {
+        for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) {
+          uint32_t eid = EntryID(nid, j);
+          const auto& name = nodes_[nid].GetOpName() + ":" + std::to_string(j);
+          int binding_index = engine_->getBindingIndex(name.c_str());
+          ICHECK_NE(binding_index, -1);
+#if TRT_VERSION_GE(6, 0, 1)
+          std::vector<int64_t> shape(data_entry_[eid]->shape,
+                                     data_entry_[eid]->shape + data_entry_[eid]->ndim);
+          ICHECK(context_->setBindingDimensions(binding_index, VectorToTrtDims(shape)));
+#endif
+          if (data_entry_[eid]->device.device_type == kDLCUDA) {
+            bindings_[binding_index] = data_entry_[eid]->data;
+          } else {
+            auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index);
+            device_buffer.CopyFrom(data_entry_[eid]);
+            bindings_[binding_index] = device_buffer->data;
+          }
+          auto dims = engine_->getBindingDimensions(binding_index);
+          int num_elements = 1;
+          for (int i = 0; i < dims.nbDims; ++i) num_elements *= dims.d[i];
+          binding_sizes_[binding_index] = num_elements;
+        }
+      }
+    }
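Both binding loops resolve TensorRT bindings through the same naming convention the MSC codegen assigns to graph inputs and outputs: the op name joined with the output index. A small Python sketch of that convention (names illustrative):

def binding_name(op_name: str, output_index: int) -> str:
    """MSC's binding key, 'op_name:output_index', matched against
    engine.getBindingIndex() on the C++ side."""
    return f"{op_name}:{output_index}"

bindings = {"input:0": 0, "dense:0": 1}  # toy name -> binding-index table
assert bindings[binding_name("input", 0)] == 0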
+    // Setup output bindings.
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      auto nid = outputs_[i].id_;
+      uint32_t eid = EntryID(outputs_[i]);
+      const auto& name = nodes_[nid].GetOpName() + ":" + std::to_string(outputs_[i].index_);
+      int binding_index = engine_->getBindingIndex(name.c_str());
+      ICHECK_NE(binding_index, -1);
+      if (data_entry_[eid]->device.device_type == kDLCUDA) {
+        bindings_[binding_index] = data_entry_[eid]->data;
+      } else {
+        auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index);
+        bindings_[binding_index] = device_buffer->data;
+      }
+    }
+  }
+
+  NDArray GetOrAllocateDeviceBuffer(int entry_id, int binding_index) {
+    std::vector<int64_t> shape(data_entry_[entry_id]->shape,
+                               data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim);
+    if (device_buffers_.count(binding_index)) {
+      // Buffer is already initialized.
+      if (shape[0] > device_buffers_[binding_index]->shape[0]) {
+        // Buffer is too small. Need to allocate bigger buffer.
+        device_buffers_[binding_index] =
+            runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0});
+      } else if (shape[0] < device_buffers_[binding_index]->shape[0]) {
+        // Buffer is too large. Create view.
+        return device_buffers_[binding_index].CreateView(shape, data_entry_[entry_id]->dtype);
+      }
+    } else {
+      // Buffer not initialized yet.
+      device_buffers_[binding_index] =
+          runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0});
+    }
+    return device_buffers_.at(binding_index);
+  }
+
+#else  // TVM_GRAPH_EXECUTOR_TENSORRT
+  void Run() override {
+    LOG(FATAL) << "TensorRT runtime is not enabled. "
+               << "Please build with USE_TENSORRT_RUNTIME.";
+  }
+
+  bool LoadEngine(const String& engine_file) { return false; }
+
+  void DestroyEngine() {}
+#endif  // TVM_GRAPH_EXECUTOR_TENSORRT
+
+ private:
+  String engine_file_;
+#ifdef TVM_GRAPH_EXECUTOR_TENSORRT
+  TensorRTLogger logger_;
+  ICudaEngine* engine_{nullptr};
+  IExecutionContext* context_{nullptr};
+  std::vector<void*> bindings_;
+  std::vector<size_t> binding_sizes_;
+  std::unordered_map<int, NDArray> device_buffers_;
+#endif
+};
+
+runtime::Module MSCTensorRTRuntimeCreate(const String& symbol_name, const String& graph_json,
+                                         const Array<String>& const_names) {
+  auto n = make_object<MSCTensorRTRuntime>(symbol_name, graph_json, const_names);
+  return runtime::Module(n);
+}
+
+TVM_REGISTER_GLOBAL("runtime.msc_tensorrt_runtime_create").set_body_typed(MSCTensorRTRuntimeCreate);
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_msc_tensorrt")
+    .set_body_typed(JSONRuntimeBase::LoadFromBinary<MSCTensorRTRuntime>);
+
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
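GetOrAllocateDeviceBuffer above implements a grow-or-view policy: a cached buffer is reallocated only when the new batch dimension is larger, and sliced into a view when it is smaller. A hedged Python sketch of the same policy (function and variable names are illustrative, not the runtime's API):

import numpy as np

_buffers = {}  # binding_index -> cached array

def get_or_allocate(binding_index, shape, dtype="float32"):
    """Grow-or-view buffer cache mirroring GetOrAllocateDeviceBuffer."""
    cached = _buffers.get(binding_index)
    if cached is None or shape[0] > cached.shape[0]:
        _buffers[binding_index] = np.empty(shape, dtype)  # (re)allocate
    elif shape[0] < cached.shape[0]:
        return cached[: shape[0]]  # smaller batch: return a view
    return _buffers[binding_index]

a = get_or_allocate(0, (8, 16))
b = get_or_allocate(0, (4, 16))  # reuses the bigger buffer as a view
assert b.base is a or b is a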
diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py b/tests/python/contrib/test_msc/test_translate_tensorrt.py
new file mode 100644
index 000000000000..e9981237be78
--- /dev/null
+++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py
@@ -0,0 +1,815 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Test translate for TensorRT. """
+
+import pytest
+import numpy as np
+
+import torch
+from torch import fx
+from torch.nn import Module
+
+import tvm.testing
+from tvm.relax.frontend.torch import from_fx
+from tvm.contrib.msc.framework.tensorrt.frontend import translate
+from tvm.contrib.msc.framework.tensorrt import codegen
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+requires_tensorrt = pytest.mark.skipif(
+    tvm.get_global_func("relax.ext.tensorrt", True) is None,
+    reason="TENSORRT is not enabled",
+)
+
+
+def build_and_run(mod, inputs):
+    """Build and run the virtual machine"""
+
+    target = tvm.target.Target("cuda")
+    mod = tvm.relax.transform.LegalizeOps()(mod)
+    with target:
+        mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+    with tvm.transform.PassContext(opt_level=3):
+        rt_mod = tvm.relax.build(mod, target)
+        runnable = tvm.relax.VirtualMachine(rt_mod, tvm.cuda())
+        res = runnable["main"](*inputs)
+    if isinstance(res, tvm.runtime.NDArray):
+        return [res.asnumpy()]
+    return [e.asnumpy() for e in res]
+
+
+def verify_model(torch_model, input_info, allow_incomplete=False):
+    """Build model and verify results"""
+
+    graph_model = fx.symbolic_trace(torch_model)
+    datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
+    torch_datas = [torch.from_numpy(i) for i in datas]
+    with torch.no_grad():
+        golden = torch_model(*torch_datas)
+        mod = from_fx(graph_model, input_info)
+    if not isinstance(golden, (list, tuple)):
+        golden = [golden]
+    golden = [g.detach().cpu().numpy() for g in golden]
+    # partition module for tensorrt
+    mod, graph_infos = translate.partition_for_tensorrt(mod, allow_incomplete=allow_incomplete)
+    output_folder = msc_utils.msc_dir()
+    # translate to tensorrt
+    mod = codegen.to_tensorrt(mod, graph_infos, output_folder=output_folder)
+    tvm_datas = [tvm.nd.array(i, device=tvm.cuda()) for i in datas]
+    results = build_and_run(mod, tvm_datas)
+    for gol, res in zip(golden, results):
+        tvm.testing.assert_allclose(gol, res, atol=1e-3, rtol=1e-3)
+    output_folder.destory()
+
+
+@requires_tensorrt
+def test_conv1d():
+    """test tensorrt translator for conv1d"""
+
+    class Conv1D1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv1d(3, 6, 7, bias=True)
+
+        def forward(self, data):
+            return self.conv(data)
+
+    class Conv1D2(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)
+
+        def forward(self, data):
+            return self.conv(data)
+
+    input_info = [([1, 3, 10], "float32")]
+    verify_model(Conv1D1(), input_info)
+    verify_model(Conv1D2(), input_info)
+
+
+@requires_tensorrt
+def test_conv2d():
+    """test tensorrt translator for conv2d"""
+
+    class Conv2D1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)
+
+        def forward(self, data):
+            return self.conv(data)
+
+    class Conv2D2(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)
+
+        def forward(self, data):
+            return self.conv(data)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Conv2D1(), input_info)
+    verify_model(Conv2D2(), input_info)
+
+
+@requires_tensorrt
+def test_linear():
+    """test tensorrt translator for linear"""
+
+    class Dense1(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(10, 7, bias=True)
+
+        def forward(self, data):
+            return self.linear(data)
+
+    class
Dense2(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=False) + + def forward(self, data): + return self.linear(data) + + class MatMul1(Module): + def forward(self, x, y): + return torch.matmul(x, y) + + input_info = [([1, 3, 10, 10], "float32")] + verify_model(Dense1(), input_info) + verify_model(Dense2(), input_info) + verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")]) + + +@requires_tensorrt +def test_bmm(): + """test tensorrt translator for bmm""" + + class BMM(Module): + def forward(self, x, y): + return torch.bmm(x, y) + + input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")] + verify_model(BMM(), input_info) + + +@requires_tensorrt +def test_baddbmm(): + """test tensorrt translator for baddbmm""" + + class BAddBMM1(Module): + def forward(self, c, x, y): + return torch.baddbmm(c, x, y) + + class BAddBMM2(Module): + def forward(self, c, x, y): + return torch.baddbmm(c, x, y, alpha=2, beta=0) + + input_info = [ + ((4, 128, 512), "float32"), + ((4, 128, 256), "float32"), + ((4, 256, 512), "float32"), + ] + verify_model(BAddBMM1(), input_info) + verify_model(BAddBMM2(), input_info) + + +@requires_tensorrt +def test_relu(): + """test tensorrt translator for relu""" + + class ReLU(Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, data): + return self.relu(data) + + input_info = [([10, 10], "float32")] + verify_model(ReLU(), input_info) + + +@requires_tensorrt +def test_relu6(): + """test tensorrt translator for relu6""" + + class ReLU6(Module): + def __init__(self): + super().__init__() + self.relu6 = torch.nn.ReLU6() + + def forward(self, data): + return self.relu6(data) + + input_info = [([10, 10], "float32")] + verify_model(ReLU6(), input_info) + + +@requires_tensorrt +def test_maxpool2d(): + """test tensorrt translator for maxpool2d""" + + class MaxPool2d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1]) + + def forward(self, data): + return self.pool(data) + + class MaxPool2d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2) + + def forward(self, data): + return self.pool(data) + + input_info = [([1, 3, 10, 10], "float32")] + verify_model(MaxPool2d(), input_info) + verify_model(MaxPool2d2(), input_info) + + +@requires_tensorrt +def test_avgpool2d(): + """test tensorrt translator for avgpool2d""" + + class AvgPool2d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1]) + + def forward(self, data): + return self.pool(data) + + class AvgPool2d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True) + + def forward(self, data): + return self.pool(data) + + input_info = [([1, 3, 10, 10], "float32")] + verify_model(AvgPool2d(), input_info) + verify_model(AvgPool2d2(), input_info) + + +@requires_tensorrt +def test_adaptive_avgpool2d(): + """test tensorrt translator for adaptive_avgpool2d""" + + class AdaptiveAvgPool2d0(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d([10, 10]) + + def forward(self, data): + return self.pool(data) + + input_info = [([1, 3, 10, 10], "float32")] + verify_model(AdaptiveAvgPool2d0(), input_info) + + +@requires_tensorrt +def test_flatten(): + """test tensorrt translator for flatten""" + + class Flatten(Module): + def 
__init__(self):
+            super().__init__()
+            self.f = torch.nn.Flatten(2, -1)
+
+        def forward(self, data):
+            return self.f(data)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Flatten(), input_info)
+    verify_model(torch.nn.Flatten(2, -1), input_info)
+
+
+@requires_tensorrt
+def test_batchnorm2d():
+    """test tensorrt translator for batchnorm2d"""
+
+    class BatchNorm2d(Module):
+        def __init__(self):
+            super().__init__()
+            self.batchnorm = torch.nn.BatchNorm2d(3)
+
+        def forward(self, data):
+            return self.batchnorm(data)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(BatchNorm2d().eval(), input_info)
+
+
+@requires_tensorrt
+def test_embedding():
+    """test tensorrt translator for embedding"""
+
+    class Embedding(Module):
+        def __init__(self):
+            super().__init__()
+            self.embedding = torch.nn.Embedding(10, 3)
+
+        def forward(self, data):
+            return self.embedding(data)
+
+    verify_model(Embedding(), [([4], "int64")], allow_incomplete=True)
+    verify_model(Embedding(), [([4, 5], "int64")], allow_incomplete=True)
+
+
+@requires_tensorrt
+def test_layernorm():
+    """test tensorrt translator for layernorm"""
+
+    class LayerNorm(Module):
+        def __init__(self):
+            super().__init__()
+            self.layernorm = torch.nn.LayerNorm((10, 10))
+
+        def forward(self, data):
+            return self.layernorm(data)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(LayerNorm(), input_info)
+
+
+@requires_tensorrt
+def test_silu():
+    """test tensorrt translator for silu"""
+
+    class SiLU(Module):
+        def __init__(self):
+            super().__init__()
+            self.silu = torch.nn.SiLU()
+
+        def forward(self, data):
+            return self.silu(data)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(SiLU(), input_info)
+
+
+@requires_tensorrt
+def test_groupnorm():
+    """test tensorrt translator for groupnorm"""
+
+    class GroupNorm(Module):
+        def __init__(self):
+            super().__init__()
+            self.groupnorm = torch.nn.GroupNorm(3, 3)
+
+        def forward(self, data):
+            return self.groupnorm(data)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(GroupNorm(), input_info)
+
+
+@requires_tensorrt
+def test_softmax():
+    """test tensorrt translator for softmax"""
+
+    class Softmax(Module):
+        def __init__(self):
+            super().__init__()
+            self.softmax = torch.nn.Softmax(dim=1)
+
+        def forward(self, data):
+            return self.softmax(data)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Softmax(), input_info)
+
+
+@requires_tensorrt
+def test_binary():
+    """test tensorrt translator for binary"""
+
+    input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
+    input_info2 = [([1, 3, 10, 10], "float32")]
+
+    # Add
+    class Add1(Module):
+        def forward(self, lhs, rhs):
+            return lhs + rhs
+
+    class Add2(Module):
+        def forward(self, lhs):
+            return lhs + 1.0
+
+    verify_model(Add1(), input_info1)
+    verify_model(Add2(), input_info2)
+
+    # Sub
+    class Sub1(Module):
+        def forward(self, lhs, rhs):
+            return lhs - rhs
+
+    class Sub2(Module):
+        def forward(self, lhs):
+            return lhs - 1.0
+
+    verify_model(Sub1(), input_info1)
+    verify_model(Sub2(), input_info2)
+
+    # Mul
+    class Mul1(Module):
+        def forward(self, lhs, rhs):
+            return lhs * rhs
+
+    class Mul2(Module):
+        def forward(self, lhs):
+            return lhs * 1.0
+
+    verify_model(Mul1(), input_info1)
+    verify_model(Mul2(), input_info2)
+
+    # True div
+    class TrueDiv1(Module):
+        def forward(self, lhs, rhs):
+            return lhs / rhs
+
+    class TrueDiv2(Module):
+        def forward(self, lhs):
+            return lhs / 1.0
+
+    verify_model(TrueDiv1(), input_info1)
+    verify_model(TrueDiv2(), input_info2)
+
+    # Floor div
+    class FloorDiv1(Module):
+        def forward(self, lhs, rhs):
+            return lhs // rhs
+
+    class FloorDiv2(Module):
+        def forward(self, lhs):
+            return lhs // 1.0
+
+    verify_model(FloorDiv1(), input_info1)
+    verify_model(FloorDiv2(), input_info2)
+
+    # Power
+    class Power1(Module):
+        def forward(self, lhs, rhs):
+            return lhs**rhs
+
+    class Power2(Module):
+        def forward(self, lhs):
+            return lhs**1.0
+
+    verify_model(Power1(), input_info1)
+    verify_model(Power2(), input_info2)
+
+
+@requires_tensorrt
+def test_squeeze():
+    """test tensorrt translator for squeeze"""
+
+    class Squeeze1(Module):
+        def forward(self, data):
+            return data.squeeze(1)
+
+    class Squeeze2(Module):
+        def forward(self, data):
+            return data.squeeze()
+
+    input_info = [([3, 1, 4, 1], "float32")]
+    verify_model(Squeeze1(), input_info)
+    verify_model(Squeeze2(), input_info)
+
+
+@requires_tensorrt
+def test_unsqueeze():
+    """test tensorrt translator for unsqueeze"""
+
+    class Unsqueeze1(Module):
+        def forward(self, data):
+            return data.unsqueeze(1)
+
+    class Unsqueeze2(Module):
+        def forward(self, data):
+            return data.unsqueeze(-1)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Unsqueeze1(), input_info)
+    verify_model(Unsqueeze2(), input_info)
+
+
+@requires_tensorrt
+def test_getitem():
+    """test tensorrt translator for getitem"""
+
+    class Slice1(Module):
+        def forward(self, x):
+            return x[0:1, 1::2, :, :3]
+
+    class Slice2(Module):
+        def forward(self, x):
+            return x[:, None, None, :, None]
+
+    verify_model(Slice1(), [([1, 3, 10, 10], "float32")])
+    verify_model(Slice2(), [([8, 16], "float32")])
+
+
+@requires_tensorrt
+def test_unary():
+    """test tensorrt translator for unary"""
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    # sin
+    class Sin(Module):
+        def forward(self, data):
+            return torch.sin(data)
+
+    verify_model(Sin(), input_info)
+
+    # cos
+    class Cos(Module):
+        def forward(self, data):
+            return torch.cos(data)
+
+    verify_model(Cos(), input_info)
+
+    # exp
+    class Exp(Module):
+        def forward(self, data):
+            return torch.exp(data)
+
+    verify_model(Exp(), input_info)
+
+    # sqrt
+    class Sqrt(Module):
+        def forward(self, data):
+            return torch.sqrt(data)
+
+    verify_model(Sqrt(), input_info)
+
+    # sigmoid
+    class Sigmoid(Module):
+        def forward(self, data):
+            return torch.sigmoid(data)
+
+    verify_model(Sigmoid(), input_info)
+
+    # round
+    class Round(Module):
+        def forward(self, data):
+            return torch.round(data)
+
+    verify_model(Round(), input_info)
+
+
+@requires_tensorrt
+def test_tanh():
+    """test tensorrt translator for tanh"""
+
+    class Tanh(Module):
+        def forward(self, data):
+            return torch.tanh(data)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Tanh(), input_info)
+
+
+@requires_tensorrt
+def test_clamp():
+    """test tensorrt translator for clamp"""
+
+    class Clamp(Module):
+        def forward(self, data):
+            return torch.clamp(data, min=0.1, max=0.5)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Clamp(), input_info)
+
+
+@requires_tensorrt
+def test_interpolate():
+    """test tensorrt translator for interpolate"""
+
+    class Interpolate(Module):
+        def forward(self, data):
+            return torch.nn.functional.interpolate(data, (5, 5))
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Interpolate(), input_info)
+
+
+@requires_tensorrt
+def test_addmm():
+    """test tensorrt translator for addmm"""
+
+    class Addmm(Module):
+        def forward(self, x_1, x_2, x_3):
+            return torch.addmm(x_1, x_2, x_3)
+
+    input_info = [
+        ([10, 10], "float32"),
+        ([10, 10], "float32"),
+        ([10, 10], "float32"),
+    ]
+    verify_model(Addmm(), input_info)
+
+
+@requires_tensorrt
+def test_split():
+    """test tensorrt translator for split"""
+
+    class Split(Module):
+        def forward(self, data):
+            return torch.split(data, 1, dim=1)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Split(), input_info)
+
+
+@requires_tensorrt
+def test_chunk():
+    """test tensorrt translator for chunk"""
+
+    class Chunk(Module):
+        def forward(self, data):
+            return torch.chunk(data, 3, dim=1)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Chunk(), input_info)
+
+
+@requires_tensorrt
+def test_expand():
+    """test tensorrt translator for expand"""
+
+    class Expand(Module):
+        def forward(self, x):
+            x = x + 1.0
+            return x.expand(4, 2, 3, 4)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(Expand(), input_info)
+
+
+@requires_tensorrt
+def test_reduce():
+    """test tensorrt translator for reduce"""
+
+    # sum
+    class Sum(Module):
+        def forward(self, x):
+            return torch.sum(x, (2, 1))
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(Sum(), input_info)
+
+
+@requires_tensorrt
+def test_permute():
+    """test tensorrt translator for permute"""
+
+    class Permute(Module):
+        def forward(self, x):
+            return x.permute(0, 3, 2, 1)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(Permute(), input_info)
+
+
+@requires_tensorrt
+def test_reshape():
+    """test tensorrt translator for reshape"""
+
+    class Reshape(Module):
+        def forward(self, x):
+            return x.reshape(2, 12)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(Reshape(), input_info)
+
+
+@requires_tensorrt
+def test_transpose():
+    """test tensorrt translator for transpose"""
+
+    class Transpose(Module):
+        def forward(self, x):
+            return x.transpose(1, 3)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(Transpose(), input_info)
+
+
+@requires_tensorrt
+def test_view():
+    """test tensorrt translator for view"""
+
+    class View(Module):
+        def forward(self, x):
+            return x.view(2, 12)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+    verify_model(View(), input_info)
+
+
+@requires_tensorrt
+def test_argmax():
+    """test tensorrt translator for argmax"""
+
+    class Argmax1(Module):
+        def forward(self, data):
+            return torch.argmax(data, dim=-1)
+
+    class Argmax2(Module):
+        def forward(self, data):
+            return torch.argmax(data, dim=-1, keepdim=True)
+
+    verify_model(Argmax1(), [([256, 256], "float32")], allow_incomplete=True)
+    verify_model(Argmax2(), [([256, 256], "float32")], allow_incomplete=True)
+
+
+@requires_tensorrt
+def test_argmin():
+    """test tensorrt translator for argmin"""
+
+    class Argmin1(Module):
+        def forward(self, data):
+            return torch.argmin(data, dim=-1)
+
+    class Argmin2(Module):
+        def forward(self, data):
+            return torch.argmin(data, dim=-1, keepdim=True)
+
+    verify_model(Argmin1(), [([256, 256], "float32")], allow_incomplete=True)
+    verify_model(Argmin2(), [([256, 256], "float32")], allow_incomplete=True)
+
+
+@requires_tensorrt
+def test_mean():
+    """test tensorrt translator for mean"""
+
+    class Mean(Module):
+        def forward(self, data):
+            return data.mean(-1)
+
+    class MeanKeepDim(Module):
+        def forward(self, data):
+            return data.mean(-1, keepdim=True)
+
+    verify_model(Mean(), [([256, 256], "float32")])
+    verify_model(MeanKeepDim(), [([256, 256], "float32")])
+
+
+@requires_tensorrt
+def test_rsqrt():
+    """test tensorrt translator for rsqrt"""
+
+    class Rsqrt(Module):
+        def forward(self, data):
+            return torch.rsqrt(data)
+
+    verify_model(Rsqrt(), [([256, 256], "float32")])
+
+
+@requires_tensorrt
+def test_neg():
+    """test tensorrt translator for neg"""
+
+    class Neg(Module):
+        def forward(self, data):
+            return -data
+
+    verify_model(Neg(), [([256, 256], "float32")])
+
+
+@requires_tensorrt
+def test_max():
+    """test tensorrt translator for max"""
+
+    class Max(Module):
+        def forward(self, x, y):
+            return torch.max(x, y)
+
+    verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")])
+
+
+if __name__ == "__main__":
+    tvm.testing.main()