From 054a2b87d9421d383d980adfba5488c952053c31 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sun, 4 Aug 2024 23:17:30 -0700 Subject: [PATCH 01/85] Added TRTWrapper Signed-off-by: Boris Fomitchev --- monai/utils/__init__.py | 2 + monai/utils/cast_utils.py | 101 +++++++ monai/utils/export_utils.py | 299 +++++++++++++++++++ monai/utils/trt_utils.py | 554 ++++++++++++++++++++++++++++++++++++ 4 files changed, 956 insertions(+) create mode 100644 monai/utils/cast_utils.py create mode 100644 monai/utils/export_utils.py create mode 100644 monai/utils/trt_utils.py diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 03fa1ceed1..b0919c2ce5 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -152,3 +152,5 @@ get_numpy_dtype_from_string, get_torch_dtype_from_string, ) + +from .trt_utils import TRTWrapper diff --git a/monai/utils/cast_utils.py b/monai/utils/cast_utils.py new file mode 100644 index 0000000000..329033ed42 --- /dev/null +++ b/monai/utils/cast_utils.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext + +import torch + + +def avoid_bfloat16_autocast_context(): + """ + If the current autocast context is bfloat16, + cast it to float32 + """ + + if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16: + return torch.cuda.amp.autocast(dtype=torch.float32) + else: + return nullcontext() + + +def avoid_float16_autocast_context(): + """ + If the current autocast context is float16, cast it to bfloat16 + if available (unless we're in jit) or float32 + """ + + if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return torch.cuda.amp.autocast(dtype=torch.float32) + + if torch.cuda.is_bf16_supported(): + return torch.cuda.amp.autocast(dtype=torch.bfloat16) + else: + return torch.cuda.amp.autocast(dtype=torch.float32) + else: + return nullcontext() + + +def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): + return x.to(dtype=to_dtype) if x.dtype == from_dtype else x + + +def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): + if isinstance(x, torch.Tensor): + return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) + else: + if isinstance(x, dict): + new_dict = {} + for k in x.keys(): + new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) + return new_dict + elif isinstance(x, tuple): + return tuple( + cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x + ) + + +class CastToFloat(torch.nn.Module): + def __init__(self, mod): + super(CastToFloat, self).__init__() + self.mod = mod + + def forward(self, x): + with torch.cuda.amp.autocast(enabled=False): + ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) + return ret + + +class CastToFloatAll(torch.nn.Module): + def __init__(self, mod): + super(CastToFloatAll, self).__init__() + self.mod = mod + + def forward(self, *args): + from_dtype = args[0].dtype + with torch.cuda.amp.autocast(enabled=False): + ret = self.mod.forward( + *cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32) + ) + return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py new file mode 100644 index 0000000000..d79cac3db0 --- /dev/null +++ b/monai/utils/export_utils.py @@ -0,0 +1,299 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, Optional, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .cast_utils import CastToFloat + + +class LinearWithBiasSkip(nn.Module): + def __init__(self, weight, bias, skip_bias_add): + super(LinearWithBiasSkip, self).__init__() + self.bias = bias + self.weight = weight + self.skip_bias_add = skip_bias_add + + def forward(self, x): + if self.skip_bias_add: + return F.linear(x, self.weight), self.bias + return F.linear(x, self.weight, self.bias), None + +apex_available = True + +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm + from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax + from apex.transformer.tensor_parallel.layers import ( + ColumnParallelLinear, + RowParallelLinear, + ) + + def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: + """ + Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export. + Args: + n: the FusedLayerNorm pytorch module to replace + Returns: + Equivalent LayerNorm module + """ + + p = next(n.parameters()) + if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm): + shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine + elif isinstance(n, FastLayerNorm): + shape, eps, affine = n.weight.shape, n.epsilon, True + else: + return None + + mod = nn.LayerNorm( + shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype + ) + n_state = n.state_dict() + mod.load_state_dict(n_state) + return mod + + def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]: + """ + Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export. + Args: + n: the FusedLayerNorm pytorch module to replace + Returns: + Equivalent LayerNorm module + """ + if not isinstance(n, RowParallelLinear): + raise ValueError( + "This function can only change the RowParallelLinear module." + ) + + dev = next(n.parameters()).device + mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev) + + n_state = n.state_dict() + mod.load_state_dict(n_state) + return mod + + def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]: + """ + Replaces Apex's ColumnParallelLinear or RowParallelLinear with nn.Linear + Args: + n: the nn.Module pytorch module to replace + Returns: + Equivalent Linear module + """ + if not ( + isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear) + ): + raise ValueError( + "This function can only change the ColumnParallelLinear or RowParallelLinear module." + ) + + dev = next(n.parameters()).device + mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev) + + n_state = n.state_dict() + mod.load_state_dict(n_state) + return mod + + def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: + """ + Replaces Apex's FusedScaleMaskSoftmax with nn.LayerNorm. This is required for ONNX export. + Args: + n: the FusedScaleMaskSoftmax module to replace + Returns: + Equivalent LayerNorm module + """ + if not isinstance(n, FusedScaleMaskSoftmax): + raise ValueError( + "This function can only change the FusedScaleMaskSoftmax module." + ) + + # disable the fusion only + mod = FusedScaleMaskSoftmax( + n.input_in_fp16, + n.input_in_bf16, + n.attn_mask_type, + False, + n.mask_func, + n.softmax_in_fp32, + n.scale, + ) + + return mod + + default_Apex_replacements = { + "FusedLayerNorm": replace_FusedLayerNorm, + "MixedFusedLayerNorm": replace_FusedLayerNorm, + "FastLayerNorm": replace_FusedLayerNorm, + "ESM1bLayerNorm": replace_FusedLayerNorm, + "RowParallelLinear": replace_ParallelLinear, + "ColumnParallelLinear": replace_ParallelLinear, + "FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax, + } + +except Exception: + default_Apex_replacements = {} + apex_available = False + + +def simple_replace( + BaseT: Type[nn.Module], DestT: Type[nn.Module] +) -> Callable[[nn.Module], Optional[nn.Module]]: + """ + Generic function generator to replace BaseT module with DestT. BaseT and DestT should have same atrributes. No weights are copied. + Args: + BaseT : module type to replace + DestT : destination module type + Returns: + swap function to replace BaseT module with DestT + """ + + def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: + if not isinstance(mod, BaseT): + return None + args = [getattr(mod, name, None) for name in mod.__constants__] + out = DestT(*args) + return out + + return expansion_fn + + +def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: + """ + Replaces MatchedScaleMaskSoftmax with exportable softmax layer + Args: + n: module to replace + Returns: + exportable module + """ + # including the import here to avoid circular imports + from nemo.collections.nlp.modules.common.megatron.fused_softmax import ( + MatchedScaleMaskSoftmax, + ) + + # disabling fusion for the MatchedScaleMaskSoftmax + mod = MatchedScaleMaskSoftmax( + n.input_in_fp16, + n.input_in_bf16, + n.attn_mask_type, + False, + n.mask_func, + n.softmax_in_fp32, + n.scale, + ) + return mod + + +def wrap_module( + BaseT: Type[nn.Module], DestT: Type[nn.Module] +) -> Callable[[nn.Module], Optional[nn.Module]]: + """ + Generic function generator to replace BaseT module with DestT wrapper. + Args: + BaseT : module type to replace + DestT : destination module type + Returns: + swap function to replace BaseT module with DestT + """ + + def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: + out = DestT(mod) + return out + + return expansion_fn + + +def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]): + """ + This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows + for swapping nested modules through arbitrary levels if children + + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. + + """ + for path, new_mod in mapping.items(): + expanded_path = path.split(".") + parent_mod = model + for sub_path in expanded_path[:-1]: + parent_mod = parent_mod._modules[sub_path] + parent_mod._modules[expanded_path[-1]] = new_mod + + return model + + +def replace_modules( + model: nn.Module, + expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None, +) -> nn.Module: + """ + Top-level function to replace modules in model, specified by class name with a desired replacement. + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. + Args: + model : top level module + expansions : replacement dictionary: module class name -> replacement function generator + Returns: + model, possibly modified in-place + """ + mapping: Dict[str, nn.Module] = {} + for name, m in model.named_modules(): + m_type = type(m).__name__ + if m_type in expansions: + # print (f"Found {m_type} in expansions ...") + swapped = expansions[m_type](m) + if swapped: + mapping[name] = swapped + + print(f"Swapped {len(mapping)} modules") + swap_modules(model, mapping) + return model + +def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module: + """ + Top-level function to replace default set of modules in model + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. + Args: + model : top level module + replace_1D_2D : include 1D -> 2D replacements + Returns: + model, possibly modified in-place + """ + if apex_available: + print("Replacing Apex layers ...") + replace_modules(model, default_Apex_replacements) + + if do_cast: + print("Adding casts around norms...") + cast_replacements = { + "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat), + "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat), + "BatchNorm3d": wrap_module(nn.BatchNorm2d, CastToFloat), + "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat), + "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat), + "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat), + } + replace_modules(model, cast_replacements) diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py new file mode 100644 index 0000000000..3d3893c239 --- /dev/null +++ b/monai/utils/trt_utils.py @@ -0,0 +1,554 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +# +# Copyright 2022 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import List +from copy import copy +import numpy as np +import os +import pickle +from PIL import Image +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx +from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx +from polygraphy.backend.trt import TrtRunner, CreateConfig, ModifyNetworkOutputs, Profile +from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine +from polygraphy.logger import G_LOGGER as L_ + +import random +from scipy import integrate +import tensorrt as trt +import torch +import traceback + +from io import BytesIO +from cuda import cudart +from enum import Enum, auto + +import threading + +from monai.apps.utils import get_logger +LOGGER=get_logger("run_cmd") + +lock_sm = threading.Lock() + +# Map of torch dtype -> numpy dtype +trt_to_torch_dtype_dict = { + trt.int32: torch.int32, + trt.float32: torch.float32, + trt.float16: torch.float16, + trt.bfloat16: torch.float16, + trt.int64: torch.int64, + trt.int8: torch.int8, + trt.bool: torch.bool, +} + +def get_dynamic_axes(profiles, extra_axes={}): + dynamic_axes=extra_axes + for profile in profiles: + for key in profile: + axes=[] + vals=profile[key] + for i in range(len(vals[0])): + if vals[0][i] != vals[2][i]: + axes.append(i) + if len(axes) > 0: + dynamic_axes[key] = axes + return dynamic_axes + +def CUASSERT(cuda_ret): + err = cuda_ret[0] + if err != 0: + raise RuntimeError(f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t") + if len(cuda_ret) > 1: + return cuda_ret[1] + return None + +class ShapeException(Exception): + pass + + +class Engine: + """ + An auxiliary class to implement running of TRT optimized engines + + """ + def __init__( + self, + engine_path, + ): + self.engine_path = engine_path + self.engine = None + self.context = None + self.tensors = OrderedDict() + self.cuda_graph_instance = None # cuda graph + + def build( + self, + onnx_path, + profiles=[], + fp16=False, + bf16=False, + tf32=True, + builder_optimization_level=3, + enable_all_tactics=True, + direct_io=False, + timing_cache=None, + update_output_names=None, + ): + L_.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + config_kwargs = { + "builder_optimization_level": builder_optimization_level, + "direct_io": direct_io, + } + if not enable_all_tactics: + config_kwargs["tactic_sources"] = [] + + network = network_from_onnx_path( + onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM] + ) + if update_output_names: + L_.info(f"Updating network outputs to {update_output_names}") + network = ModifyNetworkOutputs(network, update_output_names) + # with L.verbosity(0): + L_.info("Calling engine_from_network...") + + engine = engine_from_network( + network, + config=CreateConfig( + fp16=fp16, + bf16=bf16, + tf32=tf32, + profiles=profiles, + load_timing_cache=timing_cache, + **config_kwargs, + ), + save_timing_cache=timing_cache, + ) + self.engine = engine + + def save(self): + save_engine(self.engine, path=self.engine_path) + + def load(self): + L_.info(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self, profile_num=0, reuse_device_memory=None): + if reuse_device_memory: + self.context = self.engine.create_execution_context_without_device_memory() + self.context.device_memory = reuse_device_memory + else: + self.context = self.engine.create_execution_context() + self.input_names = [] + self.output_names = [] + self.dtypes = [] + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT: + self.input_names.append(binding) + elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT: + self.output_names.append(binding) + dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(binding)] + self.dtypes.append(dtype) + self.cur_profile = profile_num + # L_.info(self.input_names) + # L_.info(self.output_names) + + def allocate_buffers(self, device): + # allocate outputs + ctx = self.context + + for i, binding in enumerate(self.output_names): + shape = ctx.get_tensor_shape(binding) + t = torch.empty( + list(shape), dtype=self.dtypes[i], device=device + ).contiguous() + self.tensors[binding] = t + ctx.set_tensor_address(binding, t.data_ptr()) + + @staticmethod + def check_shape(shape, profile): + shape = list(shape) + minlist = profile[0] + maxlist = profile[2] + good = True + for i, s in enumerate(shape): + if s < minlist[i] or s > maxlist[i]: + good = False + return good + + def set_inputs(self, feed_dict, stream): + e = self.engine + ctx = self.context + last_profile = self.cur_profile + + def try_set_inputs(): + for binding, t in feed_dict.items(): + if t is not None: + t = t.contiguous() + shape = t.shape + # mincurmax = list(e.get_profile_shape(self.cur_profile, binding)) + # if not self.check_shape(shape, mincurmax): + # raise ShapeException(f"Input shape to be set is outside the bounds: {binding} -> {shape}, profile is {mincurmax}, trying another profile: {self.cur_profile}") + ctx.set_input_shape(binding, shape) + ctx.set_tensor_address(binding, t.data_ptr()) + + while True: + try: + try_set_inputs() + break + except ShapeException: + next_profile = (self.cur_profile + 1) % e.num_optimization_profiles + if next_profile == last_profile: + raise + self.cur_profile = next_profile + ctx.set_optimization_profile_async(self.cur_profile, stream) + # torch.cuda.synchronize() + + left = ctx.infer_shapes() + assert len(left) == 0 + + def infer(self, stream, use_cuda_graph=False): + if use_cuda_graph: + if self.cuda_graph_instance is not None: + CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + CUASSERT(cudart.cudaStreamSynchronize(stream)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + # capture cuda graph + CUASSERT( + cudart.cudaStreamBeginCapture( + stream, + cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal, + ) + ) + self.context.execute_async_v3(stream) + graph = CUASSERT(cudart.cudaStreamEndCapture(stream)) + self.cuda_graph_instance = CUASSERT( + cudart.cudaGraphInstantiate(graph, 0) + ) + LOGGER.info("CUDA Graph captured!") + else: + noerror = self.context.execute_async_v3(stream) + CUASSERT(cudart.cudaStreamSynchronize(stream)) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +class TRTWrapper(torch.nn.Module): + """ + This wrapper implements TRT, ONNX and Torchscript persistent export + and running with fallback to Torch (for TRT modules with limited profiles) + + """ + + def __init__(self, + path, + model=None, + input_names=None, + output_names=None, + use_cuda_graph=False, + timestamp=None): + super().__init__() + self.input_names = input_names + self.output_names = output_names + self.model = model + self.profiles = None + self.engine = None + self.jit_model = None + self.onnx_runner = None + self.path = path + self.use_cuda_graph = use_cuda_graph + + if os.path.exists(self.onnx_path): + ftime=os.path.getmtime(self.onnx_path) + if timestamp is not None and ftime < timestamp: + os.remove(self.onnx_path) + else: + timestamp = ftime + if timestamp is not None and os.path.exists(self.engine_path) and os.path.getmtime(self.engine_path) < timestamp: + os.remove(self.engine_path) + + + @property + def engine_path(self): + return self.path + ".plan" + + @property + def jit_path(self): + return self.path + ".ts" + + @property + def onnx_path(self): + return self.path + ".onnx" + + @property + def profiles_path(self): + return self.path + ".profiles.pkl" + + def has_engine(self): + return self.engine is not None + + def has_onnx(self): + return os.path.exists(self.onnx_path) + + def has_jit(self): + return os.path.exists(self.jit_path) + + def has_profiles(self): + return os.path.exists(self.profiles_path) + + def load_engine(self): + try: + engine = Engine(self.engine_path) + engine.load() + engine.activate() + self.engine = engine + except Exception as e: + LOGGER.debug(f"Exception while loading the engine:\n{e}") + pass + + def load_jit(self): + try: + self.jit_model = torch.jit.load(self.jit_path) + except Exception: + pass + + def load_onnx(self, providers=["CUDAExecutionProvider"]): + try: + onnx_runner = OnnxrtRunner( + session_from_onnx(self.onnx_path, providers=providers) + ) + onnx_runner.activate() + self.onnx_runner = onnx_runner + except Exception: + pass + + def load_profiles(self): + with open(self.profiles_path, "rb") as fp: + profiles = pickle.load(fp) + self.profiles = profiles + return profiles + + def save_profiles(self): + with open(self.profiles_path, "wb") as fp: + pickle.dump(self.profiles, fp) + + def inputs_to_dict(self, input_example): + trt_inputs = {} + for i, inp in enumerate(input_example): + input_name = self.engine.input_names[i] + trt_inputs[input_name] = inp + return trt_inputs + + def forward(self, **args): + try: + if self.engine is not None: + # forward_trt is not thread safe as we do not use per-thread execution contexts + with lock_sm: + return self.forward_trt(args) + elif self.jit_model is not None: + return self.jit_model.forward(**args) + elif self.onnx_runner is not None: + ret = self.onnx_runner.infer(args) + ret = list(ret.values()) + ret = [r.cuda() for r in ret] + if len(ret) == 1: + ret = ret[0] + return ret + except Exception as e: + LOGGER.info(f"Exception: {e}\nFalling back to Pytorch ...") + + return self.model.forward(**args) + + def forward_trt(self, trt_inputs): + stream = torch.cuda.Stream(device=torch.cuda.current_device()) + self.engine.set_inputs(trt_inputs, stream.cuda_stream) + self.engine.allocate_buffers(torch.device("cuda")) + # Need this to synchronize with Torch stream + stream.wait_stream(torch.cuda.current_stream()) + ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph) + ret = list(ret.values()) + + if len(ret) == 1: + ret = ret[0] + return ret + + def forward_trt_runner(self, trt_inputs): + with TrtRunner(self.engine) as runner: + ret = runner.infer(trt_inputs) + ret = list(ret.values()) + ret = [r.cuda() for r in ret] + if len(ret) == 1: + ret = ret[0] + return ret + + def build_engine( + self, + input_profiles=[], + fp16=False, + bf16=False, + tf32=False, + builder_optimization_level=3, + direct_io=False, + enable_all_tactics=True, + ): + profiles = [] + if len(input_profiles) > 0: + for input_profile in input_profiles: + if isinstance(input_profile, Profile): + profiles.append(input_profile) + else: + p = Profile() + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + profiles.append(p) + self.profiles = profiles + self.save_profiles() + + engine = Engine(self.path + ".plan") + engine.build( + self.onnx_path, + profiles, + fp16=fp16, + bf16=bf16, + tf32=tf32, + direct_io=direct_io, + builder_optimization_level=builder_optimization_level, + enable_all_tactics=enable_all_tactics, + ) + engine.activate() + self.engine = engine + + def jit_export( + self, + input_example, + verbose=False, + ): + self.jit_model = torch.jit.trace( + self.model, + input_example, + ).eval() + self.jit_model = torch.jit.freeze(self.jit_model) + torch.jit.save(self.jit_model, self.jit_path) + + def onnx_export( + self, + input_example, + dynamo=False, + onnx_registry=None, + dynamic_shapes=None, + verbose=False, + opset_version=18, + ): + L_.info(f"Exporting to ONNX, dynamic shapes: {dynamic_shapes}") + model = self.model + from .export_utils import replace_for_export + + replace_for_export(model, do_cast=True) + + if dynamo: + torch.onnx.export( + model, + input_example, + self.onnx_path, + dynamo=dynamo, + verbose=verbose, + opset_version=opset_version, + do_constant_folding=True, + input_names=self.input_names, + output_names=self.output_names, + dynamic_shapes=dynamic_shapes, + ) + else: + torch.onnx.export( + model, + input_example, + self.onnx_path, + verbose=verbose, + opset_version=opset_version, + do_constant_folding=True, + input_names=self.input_names, + output_names=self.output_names, + dynamic_axes=dynamic_shapes, + ) + L_.info("Folding constants...") + model_onnx = onnx_from_path(self.onnx_path) + fold_constants(model_onnx, allow_onnxruntime_shape_inference=False) + L_.info("Done folding constants.") + + L_.info("Saving model...") + save_onnx( + model_onnx, + self.onnx_path, + ) + L_.info("Done saving model.") + + def build_and_save( + self, + input_example, + dynamo=False, + verbose=False, + input_profiles=[], + fp16=False, + bf16=False, + tf32=True, + builder_optimization_level=3, + direct_io=False, + enable_all_tactics=True, + ): + if not self.has_engine(): + try: + if not self.has_onnx(): + self.onnx_export( + input_example, + dynamo=dynamo, + dynamic_shapes=get_dynamic_axes(input_profiles), + verbose=verbose, + ) + self.build_engine( + input_profiles=input_profiles, + fp16=fp16, tf32=tf32, + direct_io=direct_io, + builder_optimization_level=5, + enable_all_tactics=enable_all_tactics) + self.engine.save() + os.remove(self.onnx_path) + except Exception as e: + raise e + pass + + + From 3ab9c83a0acbd6109b4b1275855457dcccfee77b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 06:21:00 +0000 Subject: [PATCH 02/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/utils/cast_utils.py | 4 ++-- monai/utils/export_utils.py | 3 +-- monai/utils/trt_utils.py | 26 ++++++-------------------- 3 files changed, 9 insertions(+), 24 deletions(-) diff --git a/monai/utils/cast_utils.py b/monai/utils/cast_utils.py index 329033ed42..43caf7644a 100644 --- a/monai/utils/cast_utils.py +++ b/monai/utils/cast_utils.py @@ -78,7 +78,7 @@ def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): class CastToFloat(torch.nn.Module): def __init__(self, mod): - super(CastToFloat, self).__init__() + super().__init__() self.mod = mod def forward(self, x): @@ -89,7 +89,7 @@ def forward(self, x): class CastToFloatAll(torch.nn.Module): def __init__(self, mod): - super(CastToFloatAll, self).__init__() + super().__init__() self.mod = mod def forward(self, *args): diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index d79cac3db0..d1c2c37aee 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -24,7 +24,6 @@ from typing import Callable, Dict, Optional, Type -import torch import torch.nn as nn import torch.nn.functional as F @@ -33,7 +32,7 @@ class LinearWithBiasSkip(nn.Module): def __init__(self, weight, bias, skip_bias_add): - super(LinearWithBiasSkip, self).__init__() + super().__init__() self.bias = bias self.weight = weight self.skip_bias_add = skip_bias_add diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index 3d3893c239..a9034a81d6 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -27,12 +27,8 @@ # from collections import OrderedDict -from typing import List -from copy import copy -import numpy as np import os import pickle -from PIL import Image from polygraphy.backend.common import bytes_from_path from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx @@ -40,15 +36,10 @@ from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine from polygraphy.logger import G_LOGGER as L_ -import random -from scipy import integrate import tensorrt as trt import torch -import traceback -from io import BytesIO from cuda import cudart -from enum import Enum, auto import threading @@ -151,7 +142,7 @@ def build( save_timing_cache=timing_cache, ) self.engine = engine - + def save(self): save_engine(self.engine, path=self.engine_path) @@ -176,11 +167,11 @@ def activate(self, profile_num=0, reuse_device_memory=None): self.output_names.append(binding) dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(binding)] self.dtypes.append(dtype) - self.cur_profile = profile_num + self.cur_profile = profile_num # L_.info(self.input_names) # L_.info(self.output_names) - - def allocate_buffers(self, device): + + def allocate_buffers(self, device): # allocate outputs ctx = self.context @@ -269,7 +260,7 @@ def infer(self, stream, use_cuda_graph=False): class TRTWrapper(torch.nn.Module): """ This wrapper implements TRT, ONNX and Torchscript persistent export - and running with fallback to Torch (for TRT modules with limited profiles) + and running with fallback to Torch (for TRT modules with limited profiles) """ @@ -290,7 +281,7 @@ def __init__(self, self.onnx_runner = None self.path = path self.use_cuda_graph = use_cuda_graph - + if os.path.exists(self.onnx_path): ftime=os.path.getmtime(self.onnx_path) if timestamp is not None and ftime < timestamp: @@ -337,7 +328,6 @@ def load_engine(self): self.engine = engine except Exception as e: LOGGER.debug(f"Exception while loading the engine:\n{e}") - pass def load_jit(self): try: @@ -548,7 +538,3 @@ def build_and_save( os.remove(self.onnx_path) except Exception as e: raise e - pass - - - From fe7103055adcfd0595a101c03f3e17ebfc471785 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 5 Aug 2024 16:18:36 -0700 Subject: [PATCH 03/85] Addressing code review comments, adding docustrings, cleanup Signed-off-by: Boris Fomitchev --- monai/utils/__init__.py | 2 +- monai/utils/cast_utils.py | 14 +++ monai/utils/export_utils.py | 24 +---- monai/utils/trt_utils.py | 177 +++++++++++++++++++++--------------- requirements-dev.txt | 1 + 5 files changed, 126 insertions(+), 92 deletions(-) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index b0919c2ce5..144d7681b7 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -153,4 +153,4 @@ get_torch_dtype_from_string, ) -from .trt_utils import TRTWrapper +TRTWrapper, TRT_AVAILABLE = optional_import('monai.utils.trt_utils', name='TRTWrapper') diff --git a/monai/utils/cast_utils.py b/monai/utils/cast_utils.py index 43caf7644a..89f5a62873 100644 --- a/monai/utils/cast_utils.py +++ b/monai/utils/cast_utils.py @@ -58,10 +58,16 @@ def avoid_float16_autocast_context(): def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): + """ + Utility function to cast a single tensor from from_dtype to to_dtype + """ return x.to(dtype=to_dtype) if x.dtype == from_dtype else x def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): + """ + Utility function to cast all tensors in a tuple from from_dtype to to_dtype + """ if isinstance(x, torch.Tensor): return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) else: @@ -77,6 +83,10 @@ def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): class CastToFloat(torch.nn.Module): + """ + Class used to add autocast protection for ONNX export + for forward methods with single return vaue + """ def __init__(self, mod): super().__init__() self.mod = mod @@ -88,6 +98,10 @@ def forward(self, x): class CastToFloatAll(torch.nn.Module): + """ + Class used to add autocast protection for ONNX export + for forward methods with multiple return values + """ def __init__(self, mod): super().__init__() self.mod = mod diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index d1c2c37aee..6e5844a306 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -31,6 +31,10 @@ class LinearWithBiasSkip(nn.Module): + """ + Class used to replace Apex's RowParallelLinear and ColumnParallelLinear + for ONNX export + """ def __init__(self, weight, bias, skip_bias_add): super().__init__() self.bias = bias @@ -77,26 +81,6 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: mod.load_state_dict(n_state) return mod - def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]: - """ - Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export. - Args: - n: the FusedLayerNorm pytorch module to replace - Returns: - Equivalent LayerNorm module - """ - if not isinstance(n, RowParallelLinear): - raise ValueError( - "This function can only change the RowParallelLinear module." - ) - - dev = next(n.parameters()).device - mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev) - - n_state = n.state_dict() - mod.load_state_dict(n_state) - return mod - def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]: """ Replaces Apex's ColumnParallelLinear or RowParallelLinear with nn.Linear diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index a9034a81d6..6bbac6575c 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -32,9 +32,8 @@ from polygraphy.backend.common import bytes_from_path from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx -from polygraphy.backend.trt import TrtRunner, CreateConfig, ModifyNetworkOutputs, Profile +from polygraphy.backend.trt import CreateConfig, ModifyNetworkOutputs, Profile from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine -from polygraphy.logger import G_LOGGER as L_ import tensorrt as trt import torch @@ -48,7 +47,8 @@ lock_sm = threading.Lock() -# Map of torch dtype -> numpy dtype + +# Map of TRT dtype -> Torch dtype trt_to_torch_dtype_dict = { trt.int32: torch.int32, trt.float32: torch.float32, @@ -60,6 +60,10 @@ } def get_dynamic_axes(profiles, extra_axes={}): + """ + Given [[min,opt,max],...] list of profile dimensions, + this method calculates dynamic_axes to use in onnx.export() + """ dynamic_axes=extra_axes for profile in profiles: for key in profile: @@ -73,6 +77,9 @@ def get_dynamic_axes(profiles, extra_axes={}): return dynamic_axes def CUASSERT(cuda_ret): + """ + Error reporting method for CUDA calls + """ err = cuda_ret[0] if err != 0: raise RuntimeError(f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t") @@ -80,7 +87,11 @@ def CUASSERT(cuda_ret): return cuda_ret[1] return None + class ShapeException(Exception): + """ + Exception class to report errors from setting TRT plan input shapes + """ pass @@ -100,46 +111,36 @@ def __init__( self.cuda_graph_instance = None # cuda graph def build( - self, - onnx_path, - profiles=[], - fp16=False, - bf16=False, - tf32=True, - builder_optimization_level=3, - enable_all_tactics=True, - direct_io=False, - timing_cache=None, - update_output_names=None, + self, + onnx_path, + profiles=[], + update_output_names=False, + **config_kwargs ): - L_.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") - config_kwargs = { - "builder_optimization_level": builder_optimization_level, - "direct_io": direct_io, - } - if not enable_all_tactics: - config_kwargs["tactic_sources"] = [] + """ + Builds TRT engine from ONNX file at onnx_path and sets self.engine + Args: + update_output_names: if set, use update_output_names as output names + profiles, config_kwargs: passed to TRT's engine_from_network() + """ + + LOGGER.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") network = network_from_onnx_path( onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM] ) if update_output_names: - L_.info(f"Updating network outputs to {update_output_names}") + LOGGER.info(f"Updating network outputs to {update_output_names}") network = ModifyNetworkOutputs(network, update_output_names) - # with L.verbosity(0): - L_.info("Calling engine_from_network...") + + LOGGER.info("Calling engine_from_network...") engine = engine_from_network( network, config=CreateConfig( - fp16=fp16, - bf16=bf16, - tf32=tf32, profiles=profiles, - load_timing_cache=timing_cache, **config_kwargs, ), - save_timing_cache=timing_cache, ) self.engine = engine @@ -147,10 +148,13 @@ def save(self): save_engine(self.engine, path=self.engine_path) def load(self): - L_.info(f"Loading TensorRT engine: {self.engine_path}") + LOGGER.info(f"Loading TensorRT engine: {self.engine_path}") self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) def activate(self, profile_num=0, reuse_device_memory=None): + """ + Creates execution context for self.engine and activates it + """ if reuse_device_memory: self.context = self.engine.create_execution_context_without_device_memory() self.context.device_memory = reuse_device_memory @@ -168,11 +172,11 @@ def activate(self, profile_num=0, reuse_device_memory=None): dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(binding)] self.dtypes.append(dtype) self.cur_profile = profile_num - # L_.info(self.input_names) - # L_.info(self.output_names) def allocate_buffers(self, device): - # allocate outputs + """ + Allocates outputs to run TRT engine + """ ctx = self.context for i, binding in enumerate(self.output_names): @@ -195,6 +199,9 @@ def check_shape(shape, profile): return good def set_inputs(self, feed_dict, stream): + """ + Sets input bindings for TRT engine according to feed_dict + """ e = self.engine ctx = self.context last_profile = self.cur_profile @@ -204,6 +211,7 @@ def try_set_inputs(): if t is not None: t = t.contiguous() shape = t.shape + # TODO: port to new TRT10 API # mincurmax = list(e.get_profile_shape(self.cur_profile, binding)) # if not self.check_shape(shape, mincurmax): # raise ShapeException(f"Input shape to be set is outside the bounds: {binding} -> {shape}, profile is {mincurmax}, trying another profile: {self.cur_profile}") @@ -220,12 +228,15 @@ def try_set_inputs(): raise self.cur_profile = next_profile ctx.set_optimization_profile_async(self.cur_profile, stream) - # torch.cuda.synchronize() left = ctx.infer_shapes() assert len(left) == 0 def infer(self, stream, use_cuda_graph=False): + """ + Runs TRT engine. + Note use_cuda_graph requires all inputs to be the same GPU memory between calls. + """ if use_cuda_graph: if self.cuda_graph_instance is not None: CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) @@ -291,7 +302,9 @@ def __init__(self, if timestamp is not None and os.path.exists(self.engine_path) and os.path.getmtime(self.engine_path) < timestamp: os.remove(self.engine_path) - + """ + Auxiliary getters/setters + """ @property def engine_path(self): return self.path + ".plan" @@ -321,6 +334,9 @@ def has_profiles(self): return os.path.exists(self.profiles_path) def load_engine(self): + """ + Loads TRT plan from disk and activates its execution context. + """ try: engine = Engine(self.engine_path) engine.load() @@ -330,12 +346,18 @@ def load_engine(self): LOGGER.debug(f"Exception while loading the engine:\n{e}") def load_jit(self): + """ + Loads Torchscript from disk + """ try: self.jit_model = torch.jit.load(self.jit_path) except Exception: pass def load_onnx(self, providers=["CUDAExecutionProvider"]): + """ + Loads ONNX from disk and creates/activates OnnxrtRunner runner for it. + """ try: onnx_runner = OnnxrtRunner( session_from_onnx(self.onnx_path, providers=providers) @@ -346,16 +368,25 @@ def load_onnx(self, providers=["CUDAExecutionProvider"]): pass def load_profiles(self): + """ + Loads saved optimization profiles from disk + """ with open(self.profiles_path, "rb") as fp: profiles = pickle.load(fp) self.profiles = profiles return profiles def save_profiles(self): + """ + Saves optimization profiles to disk using pickle + """ with open(self.profiles_path, "wb") as fp: pickle.dump(self.profiles, fp) def inputs_to_dict(self, input_example): + """ + Converts list of inputs indo a dict usable with TRT engine + """ trt_inputs = {} for i, inp in enumerate(input_example): input_name = self.engine.input_names[i] @@ -363,6 +394,10 @@ def inputs_to_dict(self, input_example): return trt_inputs def forward(self, **args): + """ + Main forward method: depending on TRT/Torchscript/ONNX representation available, + runs appropriate accelerated method. If exception thrown, falls back to original Pytorch + """ try: if self.engine is not None: # forward_trt is not thread safe as we do not use per-thread execution contexts @@ -383,6 +418,10 @@ def forward(self, **args): return self.model.forward(**args) def forward_trt(self, trt_inputs): + """ + Auxiliary method to run TRT engine. + Sets input bindings from trt_inputs, allocates memory and runs activated TRT engine + """ stream = torch.cuda.Stream(device=torch.cuda.current_device()) self.engine.set_inputs(trt_inputs, stream.cuda_stream) self.engine.allocate_buffers(torch.device("cuda")) @@ -395,25 +434,16 @@ def forward_trt(self, trt_inputs): ret = ret[0] return ret - def forward_trt_runner(self, trt_inputs): - with TrtRunner(self.engine) as runner: - ret = runner.infer(trt_inputs) - ret = list(ret.values()) - ret = [r.cuda() for r in ret] - if len(ret) == 1: - ret = ret[0] - return ret - def build_engine( self, input_profiles=[], - fp16=False, - bf16=False, - tf32=False, - builder_optimization_level=3, - direct_io=False, - enable_all_tactics=True, + **build_args ): + """ + Builds TRT engine from ONNX file at self.onnx_path and sets self.engine + Args: + input_profiles, build_args - passed to engine.build() + """ profiles = [] if len(input_profiles) > 0: for input_profile in input_profiles: @@ -432,12 +462,7 @@ def build_engine( engine.build( self.onnx_path, profiles, - fp16=fp16, - bf16=bf16, - tf32=tf32, - direct_io=direct_io, - builder_optimization_level=builder_optimization_level, - enable_all_tactics=enable_all_tactics, + **build_args ) engine.activate() self.engine = engine @@ -447,6 +472,9 @@ def jit_export( input_example, verbose=False, ): + """ + Exports self.model to Torchscript at self.jit_path and sets self.jit_model + """ self.jit_model = torch.jit.trace( self.model, input_example, @@ -463,7 +491,12 @@ def onnx_export( verbose=False, opset_version=18, ): - L_.info(f"Exporting to ONNX, dynamic shapes: {dynamic_shapes}") + """ + Exports self.model to ONNX file at self.onnx_path + Args: passed to onnx.export() + """ + + LOGGER.info(f"Exporting to ONNX, dynamic shapes: {dynamic_shapes}") model = self.model from .export_utils import replace_for_export @@ -494,17 +527,17 @@ def onnx_export( output_names=self.output_names, dynamic_axes=dynamic_shapes, ) - L_.info("Folding constants...") + LOGGER.info("Folding constants...") model_onnx = onnx_from_path(self.onnx_path) fold_constants(model_onnx, allow_onnxruntime_shape_inference=False) - L_.info("Done folding constants.") + LOGGER.info("Done folding constants.") - L_.info("Saving model...") + LOGGER.info("Saving model...") save_onnx( model_onnx, self.onnx_path, ) - L_.info("Done saving model.") + LOGGER.info("Done saving model.") def build_and_save( self, @@ -512,13 +545,18 @@ def build_and_save( dynamo=False, verbose=False, input_profiles=[], - fp16=False, - bf16=False, - tf32=True, - builder_optimization_level=3, - direct_io=False, - enable_all_tactics=True, + **build_args ): + """ + If serialized engine is not found, exports self.model to ONNX, + builds TRT engine and saves serialized TRT engine to the disk. + Args: + input_example, dynamo, verbose: passed to self.onnx_export() + input_profiles: used to get dynamic axes for onnx_export(), + passed to self.build_engine() + build_args : passed to self.build_engine() + enable_all_tactics=True, + """ if not self.has_engine(): try: if not self.has_onnx(): @@ -530,10 +568,7 @@ def build_and_save( ) self.build_engine( input_profiles=input_profiles, - fp16=fp16, tf32=tf32, - direct_io=direct_io, - builder_optimization_level=5, - enable_all_tactics=enable_all_tactics) + **build_args) self.engine.save() os.remove(self.onnx_path) except Exception as e: diff --git a/requirements-dev.txt b/requirements-dev.txt index 72ba210093..3a61fbeef2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -59,3 +59,4 @@ nvidia-ml-py huggingface_hub pyamg>=5.0.0 git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd +polygraphy From 6a9727fcc192128e3540c12100865cc92c72699a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 23:19:22 +0000 Subject: [PATCH 04/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/utils/trt_utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index 6bbac6575c..ffffb76b08 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -61,7 +61,7 @@ def get_dynamic_axes(profiles, extra_axes={}): """ - Given [[min,opt,max],...] list of profile dimensions, + Given [[min,opt,max],...] list of profile dimensions, this method calculates dynamic_axes to use in onnx.export() """ dynamic_axes=extra_axes @@ -120,7 +120,7 @@ def build( """ Builds TRT engine from ONNX file at onnx_path and sets self.engine Args: - update_output_names: if set, use update_output_names as output names + update_output_names: if set, use update_output_names as output names profiles, config_kwargs: passed to TRT's engine_from_network() """ @@ -234,8 +234,8 @@ def try_set_inputs(): def infer(self, stream, use_cuda_graph=False): """ - Runs TRT engine. - Note use_cuda_graph requires all inputs to be the same GPU memory between calls. + Runs TRT engine. + Note use_cuda_graph requires all inputs to be the same GPU memory between calls. """ if use_cuda_graph: if self.cuda_graph_instance is not None: @@ -396,7 +396,7 @@ def inputs_to_dict(self, input_example): def forward(self, **args): """ Main forward method: depending on TRT/Torchscript/ONNX representation available, - runs appropriate accelerated method. If exception thrown, falls back to original Pytorch + runs appropriate accelerated method. If exception thrown, falls back to original Pytorch """ try: if self.engine is not None: @@ -420,7 +420,7 @@ def forward(self, **args): def forward_trt(self, trt_inputs): """ Auxiliary method to run TRT engine. - Sets input bindings from trt_inputs, allocates memory and runs activated TRT engine + Sets input bindings from trt_inputs, allocates memory and runs activated TRT engine """ stream = torch.cuda.Stream(device=torch.cuda.current_device()) self.engine.set_inputs(trt_inputs, stream.cuda_stream) @@ -493,7 +493,7 @@ def onnx_export( ): """ Exports self.model to ONNX file at self.onnx_path - Args: passed to onnx.export() + Args: passed to onnx.export() """ LOGGER.info(f"Exporting to ONNX, dynamic shapes: {dynamic_shapes}") @@ -552,9 +552,9 @@ def build_and_save( builds TRT engine and saves serialized TRT engine to the disk. Args: input_example, dynamo, verbose: passed to self.onnx_export() - input_profiles: used to get dynamic axes for onnx_export(), + input_profiles: used to get dynamic axes for onnx_export(), passed to self.build_engine() - build_args : passed to self.build_engine() + build_args : passed to self.build_engine() enable_all_tactics=True, """ if not self.has_engine(): From 29d9725f61d83d1e62478bf1fd4823950d5534a4 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 5 Aug 2024 23:37:55 -0700 Subject: [PATCH 05/85] Added TRT 10.3RC to Dockerfile Signed-off-by: Boris Fomitchev --- Dockerfile | 10 ++++++++++ requirements-dev.txt | 2 ++ 2 files changed, 12 insertions(+) diff --git a/Dockerfile b/Dockerfile index 8e255597d1..2580d80681 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,6 +24,16 @@ RUN if [[ $(uname -m) =~ "aarch64" ]]; then \ pip install numcodecs; \ fi +ARG TRT_URL=http://cuda-repo.nvidia.com/release-candidates/Libraries/TensorRT/v10.3/10.3.0.25-4abf3f29/12.5-r555/Ubuntu22_04-x64-manylinux_2_17/deb/ + +RUN rm -fr /tmp/trt && mkdir -p /tmp/trt && cd /tmp/trt && \ + curl ${TRT_URL} -o index.html && \ + for package in $(grep -o '[^ >"]*\.deb' index.html | uniq); do wget -nv ${TRT_URL}${package} & done && wait \ + && rm -f *-dev_* tensorrt_10* *-samples* \ + && dpkg -i *.deb \ + && apt-get --fix-broken install -y \ + && rm -rf index.html *.deb + WORKDIR /opt/monai # install full deps diff --git a/requirements-dev.txt b/requirements-dev.txt index 3a61fbeef2..84e91b52dd 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -60,3 +60,5 @@ huggingface_hub pyamg>=5.0.0 git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd polygraphy +cuda-python + From 5b8b4f294440c26c30ce485a8ffeb0230a3b51ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 06:38:28 +0000 Subject: [PATCH 06/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- requirements-dev.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 84e91b52dd..8cce6ac16b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -61,4 +61,3 @@ pyamg>=5.0.0 git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd polygraphy cuda-python - From f31d6dd88d1a0121ce3aa62947111e4d4bd82583 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 5 Aug 2024 23:49:56 -0700 Subject: [PATCH 07/85] Workaround for format check Signed-off-by: Boris Fomitchev --- monai/utils/trt_utils.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index ffffb76b08..6a0258d8ed 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -29,13 +29,19 @@ from collections import OrderedDict import os import pickle -from polygraphy.backend.common import bytes_from_path -from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx -from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx -from polygraphy.backend.trt import CreateConfig, ModifyNetworkOutputs, Profile -from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine -import tensorrt as trt +# To keep CI and format check happy +try: + from polygraphy.backend.common import bytes_from_path + from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx + from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx + from polygraphy.backend.trt import CreateConfig, ModifyNetworkOutputs, Profile + from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine + + import tensorrt as trt +except Exception: + pass + import torch from cuda import cudart From 9303c32bad1e67e333465fee38d2468eb5f71681 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 06:50:26 +0000 Subject: [PATCH 08/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/utils/trt_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index 6a0258d8ed..ebd4147054 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -37,7 +37,7 @@ from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx from polygraphy.backend.trt import CreateConfig, ModifyNetworkOutputs, Profile from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine - + import tensorrt as trt except Exception: pass From c1d0b19fc7801e079fc6648e441bf1dfbda2e389 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 5 Aug 2024 23:56:40 -0700 Subject: [PATCH 09/85] More format check workarounds Signed-off-by: Boris Fomitchev --- monai/utils/trt_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index ebd4147054..511f41512a 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -39,12 +39,12 @@ from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine import tensorrt as trt + from cuda import cudart except Exception: pass import torch -from cuda import cudart import threading From 63c4b70e7b375162163283b543fa7f6b99cef8b7 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 6 Aug 2024 00:00:35 -0700 Subject: [PATCH 10/85] More format check workarounds Signed-off-by: Boris Fomitchev --- monai/utils/trt_utils.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index 511f41512a..df9df91b00 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -40,6 +40,17 @@ import tensorrt as trt from cuda import cudart + + # Map of TRT dtype -> Torch dtype + trt_to_torch_dtype_dict = { + trt.int32: torch.int32, + trt.float32: torch.float32, + trt.float16: torch.float16, + trt.bfloat16: torch.float16, + trt.int64: torch.int64, + trt.int8: torch.int8, + trt.bool: torch.bool, + } except Exception: pass @@ -54,16 +65,6 @@ lock_sm = threading.Lock() -# Map of TRT dtype -> Torch dtype -trt_to_torch_dtype_dict = { - trt.int32: torch.int32, - trt.float32: torch.float32, - trt.float16: torch.float16, - trt.bfloat16: torch.float16, - trt.int64: torch.int64, - trt.int8: torch.int8, - trt.bool: torch.bool, -} def get_dynamic_axes(profiles, extra_axes={}): """ From 9a3d6a6519aac0b143909525d3f678add884eb8b Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 6 Aug 2024 00:03:37 -0700 Subject: [PATCH 11/85] More format check workarounds Signed-off-by: Boris Fomitchev --- monai/utils/trt_utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index df9df91b00..2ac0efafb4 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -29,6 +29,8 @@ from collections import OrderedDict import os import pickle +import torch +import threading # To keep CI and format check happy try: @@ -54,11 +56,6 @@ except Exception: pass -import torch - - -import threading - from monai.apps.utils import get_logger LOGGER=get_logger("run_cmd") From 8bf03006622b5764cf888e9dff2527a45876f8cf Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 6 Aug 2024 00:34:26 -0700 Subject: [PATCH 12/85] Using optional exports for trt_utils Signed-off-by: Boris Fomitchev --- monai/utils/export_utils.py | 122 +----------------------------------- monai/utils/trt_utils.py | 17 ++--- 2 files changed, 10 insertions(+), 129 deletions(-) diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index 6e5844a306..b605028a48 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -29,121 +29,6 @@ from .cast_utils import CastToFloat - -class LinearWithBiasSkip(nn.Module): - """ - Class used to replace Apex's RowParallelLinear and ColumnParallelLinear - for ONNX export - """ - def __init__(self, weight, bias, skip_bias_add): - super().__init__() - self.bias = bias - self.weight = weight - self.skip_bias_add = skip_bias_add - - def forward(self, x): - if self.skip_bias_add: - return F.linear(x, self.weight), self.bias - return F.linear(x, self.weight, self.bias), None - -apex_available = True - -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm - from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax - from apex.transformer.tensor_parallel.layers import ( - ColumnParallelLinear, - RowParallelLinear, - ) - - def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: - """ - Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export. - Args: - n: the FusedLayerNorm pytorch module to replace - Returns: - Equivalent LayerNorm module - """ - - p = next(n.parameters()) - if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm): - shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine - elif isinstance(n, FastLayerNorm): - shape, eps, affine = n.weight.shape, n.epsilon, True - else: - return None - - mod = nn.LayerNorm( - shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype - ) - n_state = n.state_dict() - mod.load_state_dict(n_state) - return mod - - def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]: - """ - Replaces Apex's ColumnParallelLinear or RowParallelLinear with nn.Linear - Args: - n: the nn.Module pytorch module to replace - Returns: - Equivalent Linear module - """ - if not ( - isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear) - ): - raise ValueError( - "This function can only change the ColumnParallelLinear or RowParallelLinear module." - ) - - dev = next(n.parameters()).device - mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev) - - n_state = n.state_dict() - mod.load_state_dict(n_state) - return mod - - def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: - """ - Replaces Apex's FusedScaleMaskSoftmax with nn.LayerNorm. This is required for ONNX export. - Args: - n: the FusedScaleMaskSoftmax module to replace - Returns: - Equivalent LayerNorm module - """ - if not isinstance(n, FusedScaleMaskSoftmax): - raise ValueError( - "This function can only change the FusedScaleMaskSoftmax module." - ) - - # disable the fusion only - mod = FusedScaleMaskSoftmax( - n.input_in_fp16, - n.input_in_bf16, - n.attn_mask_type, - False, - n.mask_func, - n.softmax_in_fp32, - n.scale, - ) - - return mod - - default_Apex_replacements = { - "FusedLayerNorm": replace_FusedLayerNorm, - "MixedFusedLayerNorm": replace_FusedLayerNorm, - "FastLayerNorm": replace_FusedLayerNorm, - "ESM1bLayerNorm": replace_FusedLayerNorm, - "RowParallelLinear": replace_ParallelLinear, - "ColumnParallelLinear": replace_ParallelLinear, - "FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax, - } - -except Exception: - default_Apex_replacements = {} - apex_available = False - - def simple_replace( BaseT: Type[nn.Module], DestT: Type[nn.Module] ) -> Callable[[nn.Module], Optional[nn.Module]]: @@ -255,20 +140,15 @@ def replace_modules( swap_modules(model, mapping) return model -def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module: +def replace_for_export(model: nn.Module, do_cast: bool = True) -> nn.Module: """ Top-level function to replace default set of modules in model NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. Args: model : top level module - replace_1D_2D : include 1D -> 2D replacements Returns: model, possibly modified in-place """ - if apex_available: - print("Replacing Apex layers ...") - replace_modules(model, default_Apex_replacements) - if do_cast: print("Adding casts around norms...") cast_replacements = { diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index 2ac0efafb4..b3b4ef2008 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -31,20 +31,22 @@ import pickle import torch import threading +from .module import optional_import -# To keep CI and format check happy -try: +P, P_imported = optional_import("polygraphy") +if P_imported: from polygraphy.backend.common import bytes_from_path from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx from polygraphy.backend.trt import CreateConfig, ModifyNetworkOutputs, Profile from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine - import tensorrt as trt - from cuda import cudart +trt, trt_imported = optional_import("tensorrt") +cudart, _ = optional_import("cuda" , name="cudart") # Map of TRT dtype -> Torch dtype - trt_to_torch_dtype_dict = { +def trt_to_torch_dtype_dict(): + return { trt.int32: torch.int32, trt.float32: torch.float32, trt.float16: torch.float16, @@ -53,8 +55,6 @@ trt.int8: torch.int8, trt.bool: torch.bool, } -except Exception: - pass from monai.apps.utils import get_logger LOGGER=get_logger("run_cmd") @@ -167,13 +167,14 @@ def activate(self, profile_num=0, reuse_device_memory=None): self.input_names = [] self.output_names = [] self.dtypes = [] + dtype_dict = trt_to_torch_dtype_dict() for idx in range(self.engine.num_io_tensors): binding = self.engine[idx] if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT: self.input_names.append(binding) elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT: self.output_names.append(binding) - dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(binding)] + dtype = dtype_dict[self.engine.get_tensor_dtype(binding)] self.dtypes.append(dtype) self.cur_profile = profile_num From c03e49b587c3b9300de2514ac85da3ebe12e2adf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 07:35:07 +0000 Subject: [PATCH 13/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/utils/export_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index b605028a48..31702db314 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -25,7 +25,6 @@ from typing import Callable, Dict, Optional, Type import torch.nn as nn -import torch.nn.functional as F from .cast_utils import CastToFloat From 39c94c2ecd51c36392e8054b360fb495c77396eb Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 6 Aug 2024 01:04:09 -0700 Subject: [PATCH 14/85] Fixing lint errors Signed-off-by: Boris Fomitchev --- monai/utils/export_utils.py | 31 +++------------------- monai/utils/trt_utils.py | 52 ++++++++++++++----------------------- 2 files changed, 22 insertions(+), 61 deletions(-) diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index b605028a48..5a3c2a326b 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -51,32 +51,6 @@ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: return expansion_fn -def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: - """ - Replaces MatchedScaleMaskSoftmax with exportable softmax layer - Args: - n: module to replace - Returns: - exportable module - """ - # including the import here to avoid circular imports - from nemo.collections.nlp.modules.common.megatron.fused_softmax import ( - MatchedScaleMaskSoftmax, - ) - - # disabling fusion for the MatchedScaleMaskSoftmax - mod = MatchedScaleMaskSoftmax( - n.input_in_fp16, - n.input_in_bf16, - n.attn_mask_type, - False, - n.mask_func, - n.softmax_in_fp32, - n.scale, - ) - return mod - - def wrap_module( BaseT: Type[nn.Module], DestT: Type[nn.Module] ) -> Callable[[nn.Module], Optional[nn.Module]]: @@ -96,7 +70,7 @@ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: return expansion_fn -def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]): +def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]) -> nn.Module: """ This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows for swapping nested modules through arbitrary levels if children @@ -116,7 +90,7 @@ def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]): def replace_modules( model: nn.Module, - expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None, + expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]], ) -> nn.Module: """ Top-level function to replace modules in model, specified by class name with a desired replacement. @@ -160,3 +134,4 @@ def replace_for_export(model: nn.Module, do_cast: bool = True) -> nn.Module: "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat), } replace_modules(model, cast_replacements) + return model diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index b3b4ef2008..076ac2f3b5 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -42,9 +42,14 @@ from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine trt, trt_imported = optional_import("tensorrt") -cudart, _ = optional_import("cuda" , name="cudart") +cudart, _ = optional_import("cuda.cudart") - # Map of TRT dtype -> Torch dtype +from monai.apps.utils import get_logger +LOGGER=get_logger("run_cmd") + +lock_sm = threading.Lock() + +# Map of TRT dtype -> Torch dtype def trt_to_torch_dtype_dict(): return { trt.int32: torch.int32, @@ -56,12 +61,6 @@ def trt_to_torch_dtype_dict(): trt.bool: torch.bool, } -from monai.apps.utils import get_logger -LOGGER=get_logger("run_cmd") - -lock_sm = threading.Lock() - - def get_dynamic_axes(profiles, extra_axes={}): """ @@ -507,31 +506,18 @@ def onnx_export( replace_for_export(model, do_cast=True) - if dynamo: - torch.onnx.export( - model, - input_example, - self.onnx_path, - dynamo=dynamo, - verbose=verbose, - opset_version=opset_version, - do_constant_folding=True, - input_names=self.input_names, - output_names=self.output_names, - dynamic_shapes=dynamic_shapes, - ) - else: - torch.onnx.export( - model, - input_example, - self.onnx_path, - verbose=verbose, - opset_version=opset_version, - do_constant_folding=True, - input_names=self.input_names, - output_names=self.output_names, - dynamic_axes=dynamic_shapes, - ) + torch.onnx.export( + model, + input_example, + self.onnx_path, + dynamo=dynamo, + verbose=verbose, + opset_version=opset_version, + do_constant_folding=True, + input_names=self.input_names, + output_names=self.output_names, + dynamic_axes=dynamic_shapes, + ) LOGGER.info("Folding constants...") model_onnx = onnx_from_path(self.onnx_path) fold_constants(model_onnx, allow_onnxruntime_shape_inference=False) From 9d867a79de0b2b9dd44ae0dc2c90025f9ee5678a Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 6 Aug 2024 01:29:03 -0700 Subject: [PATCH 15/85] Format fixed Signed-off-by: Boris Fomitchev --- monai/utils/__init__.py | 2 +- monai/utils/cast_utils.py | 12 +-- monai/utils/export_utils.py | 17 ++-- monai/utils/trt_utils.py | 160 +++++++++++++----------------------- 4 files changed, 72 insertions(+), 119 deletions(-) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 144d7681b7..03745dd1bc 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -153,4 +153,4 @@ get_torch_dtype_from_string, ) -TRTWrapper, TRT_AVAILABLE = optional_import('monai.utils.trt_utils', name='TRTWrapper') +TRTWrapper, TRT_AVAILABLE = optional_import("monai.utils.trt_utils", name="TRTWrapper") diff --git a/monai/utils/cast_utils.py b/monai/utils/cast_utils.py index 89f5a62873..8f529f0f87 100644 --- a/monai/utils/cast_utils.py +++ b/monai/utils/cast_utils.py @@ -22,6 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from contextlib import nullcontext import torch @@ -77,9 +79,7 @@ def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) return new_dict elif isinstance(x, tuple): - return tuple( - cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x - ) + return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) class CastToFloat(torch.nn.Module): @@ -87,6 +87,7 @@ class CastToFloat(torch.nn.Module): Class used to add autocast protection for ONNX export for forward methods with single return vaue """ + def __init__(self, mod): super().__init__() self.mod = mod @@ -102,6 +103,7 @@ class CastToFloatAll(torch.nn.Module): Class used to add autocast protection for ONNX export for forward methods with multiple return values """ + def __init__(self, mod): super().__init__() self.mod = mod @@ -109,7 +111,5 @@ def __init__(self, mod): def forward(self, *args): from_dtype = args[0].dtype with torch.cuda.amp.autocast(enabled=False): - ret = self.mod.forward( - *cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32) - ) + ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index dc82c75eef..97fdedfc0a 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -22,15 +22,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import Callable, Dict, Optional, Type import torch.nn as nn from .cast_utils import CastToFloat -def simple_replace( - BaseT: Type[nn.Module], DestT: Type[nn.Module] -) -> Callable[[nn.Module], Optional[nn.Module]]: + +def simple_replace(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: """ Generic function generator to replace BaseT module with DestT. BaseT and DestT should have same atrributes. No weights are copied. Args: @@ -50,9 +51,7 @@ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: return expansion_fn -def wrap_module( - BaseT: Type[nn.Module], DestT: Type[nn.Module] -) -> Callable[[nn.Module], Optional[nn.Module]]: +def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: """ Generic function generator to replace BaseT module with DestT wrapper. Args: @@ -87,10 +86,7 @@ def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]) -> nn.Module: return model -def replace_modules( - model: nn.Module, - expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]], -) -> nn.Module: +def replace_modules(model: nn.Module, expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]]) -> nn.Module: """ Top-level function to replace modules in model, specified by class name with a desired replacement. NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. @@ -113,6 +109,7 @@ def replace_modules( swap_modules(model, mapping) return model + def replace_for_export(model: nn.Module, do_cast: bool = True) -> nn.Module: """ Top-level function to replace default set of modules in model diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index 076ac2f3b5..26e4cbf3cb 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -26,29 +26,43 @@ # limitations under the License. # -from collections import OrderedDict +from __future__ import annotations + import os import pickle -import torch import threading +from collections import OrderedDict + +import torch + +from monai.apps.utils import get_logger + from .module import optional_import P, P_imported = optional_import("polygraphy") if P_imported: from polygraphy.backend.common import bytes_from_path - from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx + from polygraphy.backend.onnx import fold_constants, onnx_from_path, save_onnx from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx - from polygraphy.backend.trt import CreateConfig, ModifyNetworkOutputs, Profile - from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine + from polygraphy.backend.trt import ( + CreateConfig, + ModifyNetworkOutputs, + Profile, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, + ) trt, trt_imported = optional_import("tensorrt") -cudart, _ = optional_import("cuda.cudart") +cudart, _ = optional_import("cuda.cudart") -from monai.apps.utils import get_logger -LOGGER=get_logger("run_cmd") + +LOGGER = get_logger("run_cmd") lock_sm = threading.Lock() + # Map of TRT dtype -> Torch dtype def trt_to_torch_dtype_dict(): return { @@ -67,11 +81,11 @@ def get_dynamic_axes(profiles, extra_axes={}): Given [[min,opt,max],...] list of profile dimensions, this method calculates dynamic_axes to use in onnx.export() """ - dynamic_axes=extra_axes + dynamic_axes = extra_axes for profile in profiles: for key in profile: - axes=[] - vals=profile[key] + axes = [] + vals = profile[key] for i in range(len(vals[0])): if vals[0][i] != vals[2][i]: axes.append(i) @@ -79,13 +93,16 @@ def get_dynamic_axes(profiles, extra_axes={}): dynamic_axes[key] = axes return dynamic_axes + def CUASSERT(cuda_ret): """ Error reporting method for CUDA calls """ err = cuda_ret[0] if err != 0: - raise RuntimeError(f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t") + raise RuntimeError( + f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" + ) if len(cuda_ret) > 1: return cuda_ret[1] return None @@ -95,6 +112,7 @@ class ShapeException(Exception): """ Exception class to report errors from setting TRT plan input shapes """ + pass @@ -103,23 +121,15 @@ class Engine: An auxiliary class to implement running of TRT optimized engines """ - def __init__( - self, - engine_path, - ): + + def __init__(self, engine_path): self.engine_path = engine_path self.engine = None self.context = None self.tensors = OrderedDict() - self.cuda_graph_instance = None # cuda graph - - def build( - self, - onnx_path, - profiles=[], - update_output_names=False, - **config_kwargs - ): + self.cuda_graph_instance = None # cuda graph + + def build(self, onnx_path, profiles=[], update_output_names=False, **config_kwargs): """ Builds TRT engine from ONNX file at onnx_path and sets self.engine Args: @@ -129,22 +139,14 @@ def build( LOGGER.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") - network = network_from_onnx_path( - onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM] - ) + network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) if update_output_names: LOGGER.info(f"Updating network outputs to {update_output_names}") network = ModifyNetworkOutputs(network, update_output_names) LOGGER.info("Calling engine_from_network...") - engine = engine_from_network( - network, - config=CreateConfig( - profiles=profiles, - **config_kwargs, - ), - ) + engine = engine_from_network(network, config=CreateConfig(profiles=profiles, **config_kwargs)) self.engine = engine def save(self): @@ -185,9 +187,7 @@ def allocate_buffers(self, device): for i, binding in enumerate(self.output_names): shape = ctx.get_tensor_shape(binding) - t = torch.empty( - list(shape), dtype=self.dtypes[i], device=device - ).contiguous() + t = torch.empty(list(shape), dtype=self.dtypes[i], device=device).contiguous() self.tensors[binding] = t ctx.set_tensor_address(binding, t.data_ptr()) @@ -252,16 +252,11 @@ def infer(self, stream, use_cuda_graph=False): raise ValueError("ERROR: inference failed.") # capture cuda graph CUASSERT( - cudart.cudaStreamBeginCapture( - stream, - cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal, - ) + cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal) ) self.context.execute_async_v3(stream) graph = CUASSERT(cudart.cudaStreamEndCapture(stream)) - self.cuda_graph_instance = CUASSERT( - cudart.cudaGraphInstantiate(graph, 0) - ) + self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(graph, 0)) LOGGER.info("CUDA Graph captured!") else: noerror = self.context.execute_async_v3(stream) @@ -279,13 +274,7 @@ class TRTWrapper(torch.nn.Module): """ - def __init__(self, - path, - model=None, - input_names=None, - output_names=None, - use_cuda_graph=False, - timestamp=None): + def __init__(self, path, model=None, input_names=None, output_names=None, use_cuda_graph=False, timestamp=None): super().__init__() self.input_names = input_names self.output_names = output_names @@ -298,17 +287,22 @@ def __init__(self, self.use_cuda_graph = use_cuda_graph if os.path.exists(self.onnx_path): - ftime=os.path.getmtime(self.onnx_path) + ftime = os.path.getmtime(self.onnx_path) if timestamp is not None and ftime < timestamp: os.remove(self.onnx_path) else: timestamp = ftime - if timestamp is not None and os.path.exists(self.engine_path) and os.path.getmtime(self.engine_path) < timestamp: + if ( + timestamp is not None + and os.path.exists(self.engine_path) + and os.path.getmtime(self.engine_path) < timestamp + ): os.remove(self.engine_path) """ Auxiliary getters/setters """ + @property def engine_path(self): return self.path + ".plan" @@ -363,9 +357,7 @@ def load_onnx(self, providers=["CUDAExecutionProvider"]): Loads ONNX from disk and creates/activates OnnxrtRunner runner for it. """ try: - onnx_runner = OnnxrtRunner( - session_from_onnx(self.onnx_path, providers=providers) - ) + onnx_runner = OnnxrtRunner(session_from_onnx(self.onnx_path, providers=providers)) onnx_runner.activate() self.onnx_runner = onnx_runner except Exception: @@ -438,11 +430,7 @@ def forward_trt(self, trt_inputs): ret = ret[0] return ret - def build_engine( - self, - input_profiles=[], - **build_args - ): + def build_engine(self, input_profiles=[], **build_args): """ Builds TRT engine from ONNX file at self.onnx_path and sets self.engine Args: @@ -463,37 +451,20 @@ def build_engine( self.save_profiles() engine = Engine(self.path + ".plan") - engine.build( - self.onnx_path, - profiles, - **build_args - ) + engine.build(self.onnx_path, profiles, **build_args) engine.activate() self.engine = engine - def jit_export( - self, - input_example, - verbose=False, - ): + def jit_export(self, input_example, verbose=False): """ Exports self.model to Torchscript at self.jit_path and sets self.jit_model """ - self.jit_model = torch.jit.trace( - self.model, - input_example, - ).eval() + self.jit_model = torch.jit.trace(self.model, input_example).eval() self.jit_model = torch.jit.freeze(self.jit_model) torch.jit.save(self.jit_model, self.jit_path) def onnx_export( - self, - input_example, - dynamo=False, - onnx_registry=None, - dynamic_shapes=None, - verbose=False, - opset_version=18, + self, input_example, dynamo=False, onnx_registry=None, dynamic_shapes=None, verbose=False, opset_version=18 ): """ Exports self.model to ONNX file at self.onnx_path @@ -524,20 +495,10 @@ def onnx_export( LOGGER.info("Done folding constants.") LOGGER.info("Saving model...") - save_onnx( - model_onnx, - self.onnx_path, - ) + save_onnx(model_onnx, self.onnx_path) LOGGER.info("Done saving model.") - def build_and_save( - self, - input_example, - dynamo=False, - verbose=False, - input_profiles=[], - **build_args - ): + def build_and_save(self, input_example, dynamo=False, verbose=False, input_profiles=[], **build_args): """ If serialized engine is not found, exports self.model to ONNX, builds TRT engine and saves serialized TRT engine to the disk. @@ -552,14 +513,9 @@ def build_and_save( try: if not self.has_onnx(): self.onnx_export( - input_example, - dynamo=dynamo, - dynamic_shapes=get_dynamic_axes(input_profiles), - verbose=verbose, + input_example, dynamo=dynamo, dynamic_shapes=get_dynamic_axes(input_profiles), verbose=verbose ) - self.build_engine( - input_profiles=input_profiles, - **build_args) + self.build_engine(input_profiles=input_profiles, **build_args) self.engine.save() os.remove(self.onnx_path) except Exception as e: From 6e2733add60e4e9100960d091a4ed75938170f1b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 08:29:32 +0000 Subject: [PATCH 16/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/utils/export_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index 97fdedfc0a..77a91dfff2 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -24,14 +24,14 @@ from __future__ import annotations -from typing import Callable, Dict, Optional, Type +from typing import Callable import torch.nn as nn from .cast_utils import CastToFloat -def simple_replace(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: +def simple_replace(BaseT: type[nn.Module], DestT: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: """ Generic function generator to replace BaseT module with DestT. BaseT and DestT should have same atrributes. No weights are copied. Args: @@ -41,7 +41,7 @@ def simple_replace(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[ swap function to replace BaseT module with DestT """ - def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: + def expansion_fn(mod: nn.Module) -> nn.Module | None: if not isinstance(mod, BaseT): return None args = [getattr(mod, name, None) for name in mod.__constants__] @@ -51,7 +51,7 @@ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: return expansion_fn -def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: +def wrap_module(BaseT: type[nn.Module], DestT: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: """ Generic function generator to replace BaseT module with DestT wrapper. Args: @@ -61,14 +61,14 @@ def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn. swap function to replace BaseT module with DestT """ - def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: + def expansion_fn(mod: nn.Module) -> nn.Module | None: out = DestT(mod) return out return expansion_fn -def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]) -> nn.Module: +def swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module: """ This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows for swapping nested modules through arbitrary levels if children @@ -86,7 +86,7 @@ def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]) -> nn.Module: return model -def replace_modules(model: nn.Module, expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]]) -> nn.Module: +def replace_modules(model: nn.Module, expansions: dict[str, Callable[[nn.Module], nn.Module | None]]) -> nn.Module: """ Top-level function to replace modules in model, specified by class name with a desired replacement. NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. @@ -96,7 +96,7 @@ def replace_modules(model: nn.Module, expansions: Dict[str, Callable[[nn.Module] Returns: model, possibly modified in-place """ - mapping: Dict[str, nn.Module] = {} + mapping: dict[str, nn.Module] = {} for name, m in model.named_modules(): m_type = type(m).__name__ if m_type in expansions: From 848a42d9f9352ab204c998370598cd11c04714a4 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 6 Aug 2024 02:08:29 -0700 Subject: [PATCH 17/85] Fixing flake errors Signed-off-by: Boris Fomitchev --- monai/utils/export_utils.py | 27 ++++++++++++------------ monai/utils/trt_utils.py | 42 ++++++++++++++++++++----------------- 2 files changed, 37 insertions(+), 32 deletions(-) diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index 97fdedfc0a..fe628fc364 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -31,38 +31,39 @@ from .cast_utils import CastToFloat -def simple_replace(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: +def simple_replace(base_t: Type[nn.Module], dest_t: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: """ - Generic function generator to replace BaseT module with DestT. BaseT and DestT should have same atrributes. No weights are copied. + Generic function generator to replace base_t module with dest_t. + base_t and dest_t should have same atrributes. No weights are copied. Args: - BaseT : module type to replace - DestT : destination module type + base_t : module type to replace + dest_t : destination module type Returns: - swap function to replace BaseT module with DestT + swap function to replace base_t module with dest_t """ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: - if not isinstance(mod, BaseT): + if not isinstance(mod, base_t): return None args = [getattr(mod, name, None) for name in mod.__constants__] - out = DestT(*args) + out = dest_t(*args) return out return expansion_fn -def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: +def wrap_module(base_t: Type[nn.Module], dest_t: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: """ - Generic function generator to replace BaseT module with DestT wrapper. + Generic function generator to replace base_t module with dest_t wrapper. Args: - BaseT : module type to replace - DestT : destination module type + base_t : module type to replace + dest_t : destination module type Returns: - swap function to replace BaseT module with DestT + swap function to replace base_t module with dest_t """ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: - out = DestT(mod) + out = dest_t(mod) return out return expansion_fn diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index 26e4cbf3cb..3c0dec0e6c 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -76,12 +76,14 @@ def trt_to_torch_dtype_dict(): } -def get_dynamic_axes(profiles, extra_axes={}): +def get_dynamic_axes(profiles): """ Given [[min,opt,max],...] list of profile dimensions, this method calculates dynamic_axes to use in onnx.export() """ - dynamic_axes = extra_axes + dynamic_axes = {} + if not profiles: + return dynamic_axes for profile in profiles: for key in profile: axes = [] @@ -91,24 +93,24 @@ def get_dynamic_axes(profiles, extra_axes={}): axes.append(i) if len(axes) > 0: dynamic_axes[key] = axes - return dynamic_axes + return dynamic_axes -def CUASSERT(cuda_ret): +def cuassert(cuda_ret): """ Error reporting method for CUDA calls """ err = cuda_ret[0] if err != 0: raise RuntimeError( - f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" + f"CUDA ERROR: {err}" ) if len(cuda_ret) > 1: return cuda_ret[1] return None -class ShapeException(Exception): +class ShapeError(Exception): """ Exception class to report errors from setting TRT plan input shapes """ @@ -129,7 +131,7 @@ def __init__(self, engine_path): self.tensors = OrderedDict() self.cuda_graph_instance = None # cuda graph - def build(self, onnx_path, profiles=[], update_output_names=False, **config_kwargs): + def build(self, onnx_path, profiles=None, update_output_names=False, **config_kwargs): """ Builds TRT engine from ONNX file at onnx_path and sets self.engine Args: @@ -218,7 +220,7 @@ def try_set_inputs(): # TODO: port to new TRT10 API # mincurmax = list(e.get_profile_shape(self.cur_profile, binding)) # if not self.check_shape(shape, mincurmax): - # raise ShapeException(f"Input shape to be set is outside the bounds: {binding} -> {shape}, profile is {mincurmax}, trying another profile: {self.cur_profile}") + # raise ShapeError(f"Input shape to be set is outside the bounds: {binding} -> {shape}") ctx.set_input_shape(binding, shape) ctx.set_tensor_address(binding, t.data_ptr()) @@ -226,7 +228,7 @@ def try_set_inputs(): try: try_set_inputs() break - except ShapeException: + except ShapeError: next_profile = (self.cur_profile + 1) % e.num_optimization_profiles if next_profile == last_profile: raise @@ -243,24 +245,24 @@ def infer(self, stream, use_cuda_graph=False): """ if use_cuda_graph: if self.cuda_graph_instance is not None: - CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) - CUASSERT(cudart.cudaStreamSynchronize(stream)) + cuassert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + cuassert(cudart.cudaStreamSynchronize(stream)) else: # do inference before CUDA graph capture noerror = self.context.execute_async_v3(stream) if not noerror: raise ValueError("ERROR: inference failed.") # capture cuda graph - CUASSERT( + cuassert( cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal) ) self.context.execute_async_v3(stream) - graph = CUASSERT(cudart.cudaStreamEndCapture(stream)) - self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(graph, 0)) + graph = cuassert(cudart.cudaStreamEndCapture(stream)) + self.cuda_graph_instance = cuassert(cudart.cudaGraphInstantiate(graph, 0)) LOGGER.info("CUDA Graph captured!") else: noerror = self.context.execute_async_v3(stream) - CUASSERT(cudart.cudaStreamSynchronize(stream)) + cuassert(cudart.cudaStreamSynchronize(stream)) if not noerror: raise ValueError("ERROR: inference failed.") @@ -352,10 +354,12 @@ def load_jit(self): except Exception: pass - def load_onnx(self, providers=["CUDAExecutionProvider"]): + def load_onnx(self, providers=None): """ Loads ONNX from disk and creates/activates OnnxrtRunner runner for it. """ + if providers is None: + providers = ["CUDAExecutionProvider"] try: onnx_runner = OnnxrtRunner(session_from_onnx(self.onnx_path, providers=providers)) onnx_runner.activate() @@ -430,14 +434,14 @@ def forward_trt(self, trt_inputs): ret = ret[0] return ret - def build_engine(self, input_profiles=[], **build_args): + def build_engine(self, input_profiles=None, **build_args): """ Builds TRT engine from ONNX file at self.onnx_path and sets self.engine Args: input_profiles, build_args - passed to engine.build() """ profiles = [] - if len(input_profiles) > 0: + if input_profiles: for input_profile in input_profiles: if isinstance(input_profile, Profile): profiles.append(input_profile) @@ -498,7 +502,7 @@ def onnx_export( save_onnx(model_onnx, self.onnx_path) LOGGER.info("Done saving model.") - def build_and_save(self, input_example, dynamo=False, verbose=False, input_profiles=[], **build_args): + def build_and_save(self, input_example, dynamo=False, verbose=False, input_profiles=None, **build_args): """ If serialized engine is not found, exports self.model to ONNX, builds TRT engine and saves serialized TRT engine to the disk. From cf2c3b1647560cb11114eac51da84f4990e6e9e8 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 6 Aug 2024 13:13:28 -0700 Subject: [PATCH 18/85] Fixing CI Signed-off-by: Boris Fomitchev --- monai/utils/export_utils.py | 2 +- monai/utils/trt_utils.py | 4 +--- requirements-dev.txt | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index a1fd246d83..c52e93fa8b 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -24,7 +24,7 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, Type import torch.nn as nn diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index 3c0dec0e6c..8d3ddcf2c1 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -102,9 +102,7 @@ def cuassert(cuda_ret): """ err = cuda_ret[0] if err != 0: - raise RuntimeError( - f"CUDA ERROR: {err}" - ) + raise RuntimeError(f"CUDA ERROR: {err}") if len(cuda_ret) > 1: return cuda_ret[1] return None diff --git a/requirements-dev.txt b/requirements-dev.txt index 8cce6ac16b..ff374127ee 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -60,4 +60,4 @@ huggingface_hub pyamg>=5.0.0 git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd polygraphy -cuda-python + From e8b51f48d2a12c2a4ba9094f35c70fa8fa704705 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 20:13:59 +0000 Subject: [PATCH 19/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/utils/export_utils.py | 6 +++--- requirements-dev.txt | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index c52e93fa8b..97bc8f45bc 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -24,14 +24,14 @@ from __future__ import annotations -from typing import Callable, Type +from typing import Callable import torch.nn as nn from .cast_utils import CastToFloat -def simple_replace(base_t: Type[nn.Module], dest_t: Type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: +def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: """ Generic function generator to replace base_t module with dest_t. base_t and dest_t should have same atrributes. No weights are copied. @@ -52,7 +52,7 @@ def expansion_fn(mod: nn.Module) -> nn.Module | None: return expansion_fn -def wrap_module(base_t: Type[nn.Module], dest_t: Type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: +def wrap_module(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: """ Generic function generator to replace base_t module with dest_t wrapper. Args: diff --git a/requirements-dev.txt b/requirements-dev.txt index ff374127ee..3a61fbeef2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -60,4 +60,3 @@ huggingface_hub pyamg>=5.0.0 git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd polygraphy - From ddb5bc89cd77ede9f68a0ea9d704bb1945a3c92e Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 6 Aug 2024 17:22:21 -0700 Subject: [PATCH 20/85] Fixed mypy, Engine refactor Signed-off-by: Boris Fomitchev --- monai/utils/export_utils.py | 6 +- monai/utils/trt_utils.py | 115 +++++++++++------------------------- 2 files changed, 41 insertions(+), 80 deletions(-) diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index 97bc8f45bc..c1e8909611 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -81,7 +81,11 @@ def swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module: expanded_path = path.split(".") parent_mod = model for sub_path in expanded_path[:-1]: - parent_mod = parent_mod._modules[sub_path] + submod = parent_mod._modules[sub_path] + if submod is None: + break + else: + parent_mod = submod parent_mod._modules[expanded_path[-1]] = new_mod return model diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index 8d3ddcf2c1..d62b924fcb 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -81,7 +81,7 @@ def get_dynamic_axes(profiles): Given [[min,opt,max],...] list of profile dimensions, this method calculates dynamic_axes to use in onnx.export() """ - dynamic_axes = {} + dynamic_axes: dict[str, list[int]] = {} if not profiles: return dynamic_axes for profile in profiles: @@ -123,51 +123,19 @@ class Engine: """ def __init__(self, engine_path): - self.engine_path = engine_path - self.engine = None - self.context = None - self.tensors = OrderedDict() - self.cuda_graph_instance = None # cuda graph - - def build(self, onnx_path, profiles=None, update_output_names=False, **config_kwargs): """ - Builds TRT engine from ONNX file at onnx_path and sets self.engine - Args: - update_output_names: if set, use update_output_names as output names - profiles, config_kwargs: passed to TRT's engine_from_network() + Loads serialized engine, creates execution context and activates it """ - - LOGGER.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") - - network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) - if update_output_names: - LOGGER.info(f"Updating network outputs to {update_output_names}") - network = ModifyNetworkOutputs(network, update_output_names) - - LOGGER.info("Calling engine_from_network...") - - engine = engine_from_network(network, config=CreateConfig(profiles=profiles, **config_kwargs)) - self.engine = engine - - def save(self): - save_engine(self.engine, path=self.engine_path) - - def load(self): + self.engine_path = engine_path LOGGER.info(f"Loading TensorRT engine: {self.engine_path}") self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) - - def activate(self, profile_num=0, reuse_device_memory=None): - """ - Creates execution context for self.engine and activates it - """ - if reuse_device_memory: - self.context = self.engine.create_execution_context_without_device_memory() - self.context.device_memory = reuse_device_memory - else: - self.context = self.engine.create_execution_context() + self.tensors = OrderedDict() + self.cuda_graph_instance = None # cuda graph + self.context = self.engine.create_execution_context() self.input_names = [] self.output_names = [] self.dtypes = [] + self.cur_profile = 0 dtype_dict = trt_to_torch_dtype_dict() for idx in range(self.engine.num_io_tensors): binding = self.engine[idx] @@ -177,7 +145,6 @@ def activate(self, profile_num=0, reuse_device_memory=None): self.output_names.append(binding) dtype = dtype_dict[self.engine.get_tensor_dtype(binding)] self.dtypes.append(dtype) - self.cur_profile = profile_num def allocate_buffers(self, device): """ @@ -208,6 +175,7 @@ def set_inputs(self, feed_dict, stream): """ e = self.engine ctx = self.context + last_profile = self.cur_profile def try_set_inputs(): @@ -280,7 +248,7 @@ def __init__(self, path, model=None, input_names=None, output_names=None, use_cu self.output_names = output_names self.model = model self.profiles = None - self.engine = None + self.engine: Engine | None = None self.jit_model = None self.onnx_runner = None self.path = path @@ -336,10 +304,7 @@ def load_engine(self): Loads TRT plan from disk and activates its execution context. """ try: - engine = Engine(self.engine_path) - engine.load() - engine.activate() - self.engine = engine + self.engine = Engine(self.engine_path) except Exception as e: LOGGER.debug(f"Exception while loading the engine:\n{e}") @@ -381,16 +346,6 @@ def save_profiles(self): with open(self.profiles_path, "wb") as fp: pickle.dump(self.profiles, fp) - def inputs_to_dict(self, input_example): - """ - Converts list of inputs indo a dict usable with TRT engine - """ - trt_inputs = {} - for i, inp in enumerate(input_example): - input_name = self.engine.input_names[i] - trt_inputs[input_name] = inp - return trt_inputs - def forward(self, **args): """ Main forward method: depending on TRT/Torchscript/ONNX representation available, @@ -400,7 +355,18 @@ def forward(self, **args): if self.engine is not None: # forward_trt is not thread safe as we do not use per-thread execution contexts with lock_sm: - return self.forward_trt(args) + device = torch.cuda.current_device() + stream = torch.cuda.Stream(device=device) + self.engine.set_inputs(args, stream.cuda_stream) + self.engine.allocate_buffers(device=device) + # Need this to synchronize with Torch stream + stream.wait_stream(torch.cuda.current_stream()) + ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph) + ret = list(ret.values()) + + if len(ret) == 1: + ret = ret[0] + return ret elif self.jit_model is not None: return self.jit_model.forward(**args) elif self.onnx_runner is not None: @@ -415,29 +381,13 @@ def forward(self, **args): return self.model.forward(**args) - def forward_trt(self, trt_inputs): - """ - Auxiliary method to run TRT engine. - Sets input bindings from trt_inputs, allocates memory and runs activated TRT engine - """ - stream = torch.cuda.Stream(device=torch.cuda.current_device()) - self.engine.set_inputs(trt_inputs, stream.cuda_stream) - self.engine.allocate_buffers(torch.device("cuda")) - # Need this to synchronize with Torch stream - stream.wait_stream(torch.cuda.current_stream()) - ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph) - ret = list(ret.values()) - - if len(ret) == 1: - ret = ret[0] - return ret - def build_engine(self, input_profiles=None, **build_args): """ Builds TRT engine from ONNX file at self.onnx_path and sets self.engine Args: input_profiles, build_args - passed to engine.build() """ + profiles = [] if input_profiles: for input_profile in input_profiles: @@ -452,10 +402,17 @@ def build_engine(self, input_profiles=None, **build_args): self.profiles = profiles self.save_profiles() - engine = Engine(self.path + ".plan") - engine.build(self.onnx_path, profiles, **build_args) - engine.activate() - self.engine = engine + LOGGER.info(f"Building TensorRT engine for {self.onnx_path}: {self.engine_path}") + + network = network_from_onnx_path(self.onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) + if self.output_names and False: + LOGGER.info(f"Updating network outputs to {self.output_names}") + network = ModifyNetworkOutputs(network, self.output_names) + + LOGGER.info("Calling engine_from_network...") + + engine = engine_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) + save_engine(engine, path=self.engine_path) def jit_export(self, input_example, verbose=False): """ @@ -518,7 +475,7 @@ def build_and_save(self, input_example, dynamo=False, verbose=False, input_profi input_example, dynamo=dynamo, dynamic_shapes=get_dynamic_axes(input_profiles), verbose=verbose ) self.build_engine(input_profiles=input_profiles, **build_args) - self.engine.save() + self.engine = Engine(self.engine_path) os.remove(self.onnx_path) except Exception as e: - raise e + LOGGER.info(f"Failed to build engine: {e}") From b188237d3288054eb4607095c15d05c7e39c88a3 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 7 Aug 2024 17:02:18 -0700 Subject: [PATCH 21/85] Merged cast_utils, copyrights fixed. Signed-off-by: Boris Fomitchev --- monai/utils/cast_utils.py | 115 ------------------------------------ monai/utils/export_utils.py | 73 ++++++++++++++++++----- monai/utils/trt_utils.py | 22 +------ 3 files changed, 60 insertions(+), 150 deletions(-) delete mode 100644 monai/utils/cast_utils.py diff --git a/monai/utils/cast_utils.py b/monai/utils/cast_utils.py deleted file mode 100644 index 8f529f0f87..0000000000 --- a/monai/utils/cast_utils.py +++ /dev/null @@ -1,115 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from contextlib import nullcontext - -import torch - - -def avoid_bfloat16_autocast_context(): - """ - If the current autocast context is bfloat16, - cast it to float32 - """ - - if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16: - return torch.cuda.amp.autocast(dtype=torch.float32) - else: - return nullcontext() - - -def avoid_float16_autocast_context(): - """ - If the current autocast context is float16, cast it to bfloat16 - if available (unless we're in jit) or float32 - """ - - if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return torch.cuda.amp.autocast(dtype=torch.float32) - - if torch.cuda.is_bf16_supported(): - return torch.cuda.amp.autocast(dtype=torch.bfloat16) - else: - return torch.cuda.amp.autocast(dtype=torch.float32) - else: - return nullcontext() - - -def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): - """ - Utility function to cast a single tensor from from_dtype to to_dtype - """ - return x.to(dtype=to_dtype) if x.dtype == from_dtype else x - - -def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): - """ - Utility function to cast all tensors in a tuple from from_dtype to to_dtype - """ - if isinstance(x, torch.Tensor): - return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) - else: - if isinstance(x, dict): - new_dict = {} - for k in x.keys(): - new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) - return new_dict - elif isinstance(x, tuple): - return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) - - -class CastToFloat(torch.nn.Module): - """ - Class used to add autocast protection for ONNX export - for forward methods with single return vaue - """ - - def __init__(self, mod): - super().__init__() - self.mod = mod - - def forward(self, x): - with torch.cuda.amp.autocast(enabled=False): - ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) - return ret - - -class CastToFloatAll(torch.nn.Module): - """ - Class used to add autocast protection for ONNX export - for forward methods with multiple return values - """ - - def __init__(self, mod): - super().__init__() - self.mod = mod - - def forward(self, *args): - from_dtype = args[0].dtype - with torch.cuda.amp.autocast(enabled=False): - ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) - return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index c1e8909611..eafcf07251 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -1,21 +1,8 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# # http://www.apache.org/licenses/LICENSE-2.0 -# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -26,9 +13,65 @@ from typing import Callable +import torch import torch.nn as nn -from .cast_utils import CastToFloat + +def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): + """ + Utility function to cast a single tensor from from_dtype to to_dtype + """ + return x.to(dtype=to_dtype) if x.dtype == from_dtype else x + + +def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): + """ + Utility function to cast all tensors in a tuple from from_dtype to to_dtype + """ + if isinstance(x, torch.Tensor): + return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) + else: + if isinstance(x, dict): + new_dict = {} + for k in x.keys(): + new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) + return new_dict + elif isinstance(x, tuple): + return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) + + +class CastToFloat(torch.nn.Module): + """ + Class used to add autocast protection for ONNX export + for forward methods with single return vaue + """ + + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, x): + dtype = x.dtype + with torch.cuda.amp.autocast(enabled=False): + ret = self.mod.forward(x.to(torch.float32)).to(dtype) + return ret + + +class CastToFloatAll(torch.nn.Module): + """ + Class used to add autocast protection for ONNX export + for forward methods with multiple return values + """ + + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, *args): + from_dtype = args[0].dtype + with torch.cuda.amp.autocast(enabled=False): + ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) + return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index d62b924fcb..1147a94b8f 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -1,30 +1,13 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - -# -# Copyright 2022 The HuggingFace Inc. team. -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# +# http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# from __future__ import annotations @@ -57,7 +40,6 @@ trt, trt_imported = optional_import("tensorrt") cudart, _ = optional_import("cuda.cudart") - LOGGER = get_logger("run_cmd") lock_sm = threading.Lock() From 60cdd745638ee434724a62f88308dfee633a5371 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 7 Aug 2024 21:06:28 -0700 Subject: [PATCH 22/85] Added unit test Signed-off-by: Boris Fomitchev --- monai/utils/export_utils.py | 28 +++++++++- monai/utils/trt_utils.py | 104 ++++++++++++++++-------------------- tests/test_trt_wrapper.py | 73 +++++++++++++++++++++++++ 3 files changed, 145 insertions(+), 60 deletions(-) create mode 100644 tests/test_trt_wrapper.py diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py index eafcf07251..b78484e72b 100644 --- a/monai/utils/export_utils.py +++ b/monai/utils/export_utils.py @@ -16,6 +16,12 @@ import torch import torch.nn as nn +from .module import optional_import + +P, P_imported = optional_import("polygraphy") +if P_imported: + from polygraphy.backend.onnx import fold_constants, onnx_from_path, save_onnx + def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): """ @@ -52,7 +58,7 @@ def __init__(self, mod): def forward(self, x): dtype = x.dtype - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): ret = self.mod.forward(x.to(torch.float32)).to(dtype) return ret @@ -69,7 +75,7 @@ def __init__(self, mod): def forward(self, *args): from_dtype = args[0].dtype - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) @@ -179,3 +185,21 @@ def replace_for_export(model: nn.Module, do_cast: bool = True) -> nn.Module: } replace_modules(model, cast_replacements) return model + + +def onnx_export(model, input_example, onnx_path, **export_args): + """ + Exports model to ONNX file at onnx_path + Args: input_example, onnx_path, export_args : passed to onnx.export() + """ + + for p in model.parameters(): + p.requires_grad = False + + replace_for_export(model, do_cast=True) + + torch.onnx.export(model, input_example, onnx_path, **export_args) + # onnx.export is not very good at folding constants so we use polygraphy + model_onnx = onnx_from_path(onnx_path) + fold_constants(model_onnx) # , allow_onnxruntime_shape_inference=False) + save_onnx(model_onnx, onnx_path) diff --git a/monai/utils/trt_utils.py b/monai/utils/trt_utils.py index 1147a94b8f..51b119b22c 100644 --- a/monai/utils/trt_utils.py +++ b/monai/utils/trt_utils.py @@ -20,12 +20,12 @@ from monai.apps.utils import get_logger +from .export_utils import onnx_export from .module import optional_import P, P_imported = optional_import("polygraphy") if P_imported: from polygraphy.backend.common import bytes_from_path - from polygraphy.backend.onnx import fold_constants, onnx_from_path, save_onnx from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx from polygraphy.backend.trt import ( CreateConfig, @@ -220,11 +220,19 @@ def infer(self, stream, use_cuda_graph=False): class TRTWrapper(torch.nn.Module): """ This wrapper implements TRT, ONNX and Torchscript persistent export - and running with fallback to Torch (for TRT modules with limited profiles) - + and running with optional fallback to Torch (for TRT modules with limited profiles) """ - def __init__(self, path, model=None, input_names=None, output_names=None, use_cuda_graph=False, timestamp=None): + def __init__( + self, + path, + model=None, + input_names=None, + output_names=None, + use_cuda_graph=False, + timestamp=None, + fallback=False, + ): super().__init__() self.input_names = input_names self.output_names = output_names @@ -235,6 +243,7 @@ def __init__(self, path, model=None, input_names=None, output_names=None, use_cu self.onnx_runner = None self.path = path self.use_cuda_graph = use_cuda_graph + self.fallback = fallback if os.path.exists(self.onnx_path): ftime = os.path.getmtime(self.onnx_path) @@ -281,12 +290,18 @@ def has_jit(self): def has_profiles(self): return os.path.exists(self.profiles_path) + def delete_model(self): + if self.fallback and self.model is not None: + del self.model + self.model = None + def load_engine(self): """ Loads TRT plan from disk and activates its execution context. """ try: self.engine = Engine(self.engine_path) + self.delete_model() except Exception as e: LOGGER.debug(f"Exception while loading the engine:\n{e}") @@ -296,6 +311,7 @@ def load_jit(self): """ try: self.jit_model = torch.jit.load(self.jit_path) + self.delete_model() except Exception: pass @@ -309,6 +325,7 @@ def load_onnx(self, providers=None): onnx_runner = OnnxrtRunner(session_from_onnx(self.onnx_path, providers=providers)) onnx_runner.activate() self.onnx_runner = onnx_runner + self.delete_model() except Exception: pass @@ -359,13 +376,15 @@ def forward(self, **args): ret = ret[0] return ret except Exception as e: - LOGGER.info(f"Exception: {e}\nFalling back to Pytorch ...") - + if self.model: + LOGGER.info(f"Exception: {e}\nFalling back to Pytorch ...") + else: + raise e return self.model.forward(**args) - def build_engine(self, input_profiles=None, **build_args): + def onnx_to_trt(self, input_profiles=None, **build_args): """ - Builds TRT engine from ONNX file at self.onnx_path and sets self.engine + Builds TRT engine from ONNX file at self.onnx_path and saves to self.trt_path Args: input_profiles, build_args - passed to engine.build() """ @@ -396,68 +415,37 @@ def build_engine(self, input_profiles=None, **build_args): engine = engine_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) save_engine(engine, path=self.engine_path) - def jit_export(self, input_example, verbose=False): - """ - Exports self.model to Torchscript at self.jit_path and sets self.jit_model - """ - self.jit_model = torch.jit.trace(self.model, input_example).eval() - self.jit_model = torch.jit.freeze(self.jit_model) - torch.jit.save(self.jit_model, self.jit_path) - - def onnx_export( - self, input_example, dynamo=False, onnx_registry=None, dynamic_shapes=None, verbose=False, opset_version=18 - ): - """ - Exports self.model to ONNX file at self.onnx_path - Args: passed to onnx.export() - """ - - LOGGER.info(f"Exporting to ONNX, dynamic shapes: {dynamic_shapes}") - model = self.model - from .export_utils import replace_for_export - - replace_for_export(model, do_cast=True) - - torch.onnx.export( - model, - input_example, - self.onnx_path, - dynamo=dynamo, - verbose=verbose, - opset_version=opset_version, - do_constant_folding=True, - input_names=self.input_names, - output_names=self.output_names, - dynamic_axes=dynamic_shapes, - ) - LOGGER.info("Folding constants...") - model_onnx = onnx_from_path(self.onnx_path) - fold_constants(model_onnx, allow_onnxruntime_shape_inference=False) - LOGGER.info("Done folding constants.") - - LOGGER.info("Saving model...") - save_onnx(model_onnx, self.onnx_path) - LOGGER.info("Done saving model.") - - def build_and_save(self, input_example, dynamo=False, verbose=False, input_profiles=None, **build_args): + def build_and_save(self, input_example, export_args=None, input_profiles=None, **build_args): """ If serialized engine is not found, exports self.model to ONNX, builds TRT engine and saves serialized TRT engine to the disk. Args: - input_example, dynamo, verbose: passed to self.onnx_export() + input_example, export_args: passed to self.onnx_export() input_profiles: used to get dynamic axes for onnx_export(), passed to self.build_engine() - build_args : passed to self.build_engine() + build_args : passed to onnx_to_trt() enable_all_tactics=True, """ + if not export_args: + export_args: dict = {} + if input_profiles: + export_args.update({"dynamic_axes": get_dynamic_axes(input_profiles)}) + if not self.has_engine(): try: if not self.has_onnx(): - self.onnx_export( - input_example, dynamo=dynamo, dynamic_shapes=get_dynamic_axes(input_profiles), verbose=verbose + LOGGER.info(f"Exporting to {self.onnx_path}, export args: {export_args}") + onnx_export( + self.model, + input_example, + self.onnx_path, + input_names=self.input_names, + output_names=self.output_names, + **export_args, ) - self.build_engine(input_profiles=input_profiles, **build_args) - self.engine = Engine(self.engine_path) + LOGGER.info("Export to ONNX successful.") + self.onnx_to_trt(input_profiles=input_profiles, **build_args) + self.load_engine() os.remove(self.onnx_path) except Exception as e: LOGGER.info(f"Failed to build engine: {e}") diff --git a/tests/test_trt_wrapper.py b/tests/test_trt_wrapper.py new file mode 100644 index 0000000000..c3223eda13 --- /dev/null +++ b/tests/test_trt_wrapper.py @@ -0,0 +1,73 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import tempfile +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets import UNet +from monai.utils import optional_import +from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows + +TRTWrapper, has_trtwrapper = optional_import( + "monai.utils", name="TRTWrapper", descriptor="TRT wrapper is not available - check your installation!" +) + +TEST_CASE_1 = ["fp32"] +TEST_CASE_2 = ["fp16"] + + +@skip_if_windows +@skip_if_no_cuda +@skip_if_quick +class TestConvertToTRT(unittest.TestCase): + + def setUp(self): + self.gpu_device = torch.cuda.current_device() + + def tearDown(self): + current_device = torch.cuda.current_device() + if current_device != self.gpu_device: + torch.cuda.set_device(self.gpu_device) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @unittest.skipUnless(has_trtwrapper, "TensorRT wrapper is required for convert!") + def test_value(self, precision): + model = UNet( + spatial_dims=3, + in_channels=1, + out_channels=2, + channels=(2, 2, 4, 8, 4), + strides=(2, 2, 2, 2), + num_res_units=2, + norm="batch", + ).cuda() + with torch.no_grad(), tempfile.TemporaryDirectory() as _: + model.eval() + input_example = torch.randn(1, 1, 96, 96, 96).cuda() + output_example = model(input_example) + args: dict = {"tf32": True} + if precision == "fp16": + args["fp16"] = True + args["precision_constraints"] = "obey" + + trt_wrapper = TRTWrapper("test_wrapper", model, input_names=["x"]) + trt_wrapper.build_and_save(input_example, **args, builder_optimization_level=1) + trt_output = trt_wrapper(x=input_example) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + +if __name__ == "__main__": + unittest.main() From 0ab5d26226a4a46a748c38a2453cb5e69c394f45 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 8 Aug 2024 19:09:04 -0700 Subject: [PATCH 23/85] TRTWrapper moved to networks Signed-off-by: Boris Fomitchev --- monai/networks/__init__.py | 1 + .../trt_utils.py => networks/trt_wrapper.py} | 17 +- monai/networks/utils.py | 211 ++++++++++++++++-- monai/utils/__init__.py | 2 - monai/utils/export_utils.py | 205 ----------------- tests/test_trt_wrapper.py | 4 +- 6 files changed, 205 insertions(+), 235 deletions(-) rename monai/{utils/trt_utils.py => networks/trt_wrapper.py} (97%) delete mode 100644 monai/utils/export_utils.py diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 4c429ae813..3cf3aa0cec 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .utils import ( + add_casts_around_norms, convert_to_onnx, convert_to_torchscript, convert_to_trt, diff --git a/monai/utils/trt_utils.py b/monai/networks/trt_wrapper.py similarity index 97% rename from monai/utils/trt_utils.py rename to monai/networks/trt_wrapper.py index 51b119b22c..04e81ebb19 100644 --- a/monai/utils/trt_utils.py +++ b/monai/networks/trt_wrapper.py @@ -19,9 +19,8 @@ import torch from monai.apps.utils import get_logger - -from .export_utils import onnx_export -from .module import optional_import +from monai.networks.utils import add_casts_around_norms, convert_to_onnx +from monai.utils.module import optional_import P, P_imported = optional_import("polygraphy") if P_imported: @@ -427,7 +426,7 @@ def build_and_save(self, input_example, export_args=None, input_profiles=None, * enable_all_tactics=True, """ if not export_args: - export_args: dict = {} + export_args = {} if input_profiles: export_args.update({"dynamic_axes": get_dynamic_axes(input_profiles)}) @@ -435,10 +434,11 @@ def build_and_save(self, input_example, export_args=None, input_profiles=None, * try: if not self.has_onnx(): LOGGER.info(f"Exporting to {self.onnx_path}, export args: {export_args}") - onnx_export( + add_casts_around_norms(self.model) + convert_to_onnx( self.model, input_example, - self.onnx_path, + filename=self.onnx_path, input_names=self.input_names, output_names=self.output_names, **export_args, @@ -448,4 +448,7 @@ def build_and_save(self, input_example, export_args=None, input_profiles=None, * self.load_engine() os.remove(self.onnx_path) except Exception as e: - LOGGER.info(f"Failed to build engine: {e}") + if self.fallback: + LOGGER.info(f"Failed to build engine: {e}") + else: + raise e diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 6a97434215..279900f56f 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -36,6 +36,7 @@ onnx, _ = optional_import("onnx") onnxreference, _ = optional_import("onnx.reference") onnxruntime, _ = optional_import("onnxruntime") +polygraphy, polygraphy_imported = optional_import("polygraphy") __all__ = [ "one_hot", @@ -606,6 +607,7 @@ def convert_to_onnx( rtol: float = 1e-4, atol: float = 0.0, use_trace: bool = True, + do_constant_folding: bool = True, **kwargs, ): """ @@ -642,6 +644,7 @@ def convert_to_onnx( if use_trace: # let torch.onnx.export to trace the model. mode_to_export = model + torch_versioned_kwargs = kwargs else: if not pytorch_after(1, 10): if "example_outputs" not in kwargs: @@ -654,32 +657,35 @@ def convert_to_onnx( del kwargs["example_outputs"] mode_to_export = torch.jit.script(model, **kwargs) + if torch.is_tensor(inputs): + inputs = (inputs,) + if filename is None: f = io.BytesIO() - torch.onnx.export( - mode_to_export, - tuple(inputs), - f=f, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - opset_version=opset_version, - **torch_versioned_kwargs, - ) + else: + f = filename + + torch.onnx.export( + mode_to_export, + tuple(inputs), + f=f, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=opset_version, + do_constant_folding=do_constant_folding, + **torch_versioned_kwargs, + ) + if filename is None: onnx_model = onnx.load_model_from_string(f.getvalue()) else: - torch.onnx.export( - mode_to_export, - tuple(inputs), - f=filename, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - opset_version=opset_version, - **torch_versioned_kwargs, - ) onnx_model = onnx.load(filename) + if do_constant_folding and polygraphy_imported: + from polygraphy.backend.onnx import fold_constants + + fold_constants(onnx_model) + if verify: if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -1189,3 +1195,168 @@ def forward(self, x): if dtype == self.initial_type: x = x.to(self.initial_type) return x + + +def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): + """ + Utility function to cast a single tensor from from_dtype to to_dtype + """ + return x.to(dtype=to_dtype) if x.dtype == from_dtype else x + + +def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): + """ + Utility function to cast all tensors in a tuple from from_dtype to to_dtype + """ + if isinstance(x, torch.Tensor): + return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) + else: + if isinstance(x, dict): + new_dict = {} + for k in x.keys(): + new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) + return new_dict + elif isinstance(x, tuple): + return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) + + +class CastToFloat(torch.nn.Module): + """ + Class used to add autocast protection for ONNX export + for forward methods with single return vaue + """ + + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, x): + dtype = x.dtype + with torch.amp.autocast("cuda", enabled=False): + ret = self.mod.forward(x.to(torch.float32)).to(dtype) + return ret + + +class CastToFloatAll(torch.nn.Module): + """ + Class used to add autocast protection for ONNX export + for forward methods with multiple return values + """ + + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, *args): + from_dtype = args[0].dtype + with torch.amp.autocast("cuda", enabled=False): + ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) + return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) + + +def wrap_module(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: + """ + Generic function generator to replace base_t module with dest_t wrapper. + Args: + base_t : module type to replace + dest_t : destination module type + Returns: + swap function to replace base_t module with dest_t + """ + + def expansion_fn(mod: nn.Module) -> nn.Module | None: + out = dest_t(mod) + return out + + return expansion_fn + + +def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: + """ + Generic function generator to replace base_t module with dest_t. + base_t and dest_t should have same atrributes. No weights are copied. + Args: + base_t : module type to replace + dest_t : destination module type + Returns: + swap function to replace base_t module with dest_t + """ + + def expansion_fn(mod: nn.Module) -> nn.Module | None: + if not isinstance(mod, base_t): + return None + args = [getattr(mod, name, None) for name in mod.__constants__] + out = dest_t(*args) + return out + + return expansion_fn + + +def _swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module: + """ + This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows + for swapping nested modules through arbitrary levels if children + + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. + + """ + for path, new_mod in mapping.items(): + expanded_path = path.split(".") + parent_mod = model + for sub_path in expanded_path[:-1]: + submod = parent_mod._modules[sub_path] + if submod is None: + break + else: + parent_mod = submod + parent_mod._modules[expanded_path[-1]] = new_mod + + return model + + +def replace_modules_by_type( + model: nn.Module, expansions: dict[str, Callable[[nn.Module], nn.Module | None]] +) -> nn.Module: + """ + Top-level function to replace modules in model, specified by class name with a desired replacement. + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. + Args: + model : top level module + expansions : replacement dictionary: module class name -> replacement function generator + Returns: + model, possibly modified in-place + """ + mapping: dict[str, nn.Module] = {} + for name, m in model.named_modules(): + m_type = type(m).__name__ + if m_type in expansions: + # print (f"Found {m_type} in expansions ...") + swapped = expansions[m_type](m) + if swapped: + mapping[name] = swapped + + print(f"Swapped {len(mapping)} modules") + _swap_modules(model, mapping) + return model + + +def add_casts_around_norms(model: nn.Module) -> nn.Module: + """ + Top-level function to add cast wrappers around modules known to cause issues for FP16/autocast ONNX export + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. + Args: + model : top level module + Returns: + model, possibly modified in-place + """ + print("Adding casts around norms...") + cast_replacements = { + "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat), + "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat), + "BatchNorm3d": wrap_module(nn.BatchNorm2d, CastToFloat), + "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat), + "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat), + "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat), + } + replace_modules_by_type(model, cast_replacements) + return model diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 03745dd1bc..03fa1ceed1 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -152,5 +152,3 @@ get_numpy_dtype_from_string, get_torch_dtype_from_string, ) - -TRTWrapper, TRT_AVAILABLE = optional_import("monai.utils.trt_utils", name="TRTWrapper") diff --git a/monai/utils/export_utils.py b/monai/utils/export_utils.py deleted file mode 100644 index b78484e72b..0000000000 --- a/monai/utils/export_utils.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Callable - -import torch -import torch.nn as nn - -from .module import optional_import - -P, P_imported = optional_import("polygraphy") -if P_imported: - from polygraphy.backend.onnx import fold_constants, onnx_from_path, save_onnx - - -def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): - """ - Utility function to cast a single tensor from from_dtype to to_dtype - """ - return x.to(dtype=to_dtype) if x.dtype == from_dtype else x - - -def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): - """ - Utility function to cast all tensors in a tuple from from_dtype to to_dtype - """ - if isinstance(x, torch.Tensor): - return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) - else: - if isinstance(x, dict): - new_dict = {} - for k in x.keys(): - new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) - return new_dict - elif isinstance(x, tuple): - return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) - - -class CastToFloat(torch.nn.Module): - """ - Class used to add autocast protection for ONNX export - for forward methods with single return vaue - """ - - def __init__(self, mod): - super().__init__() - self.mod = mod - - def forward(self, x): - dtype = x.dtype - with torch.amp.autocast("cuda", enabled=False): - ret = self.mod.forward(x.to(torch.float32)).to(dtype) - return ret - - -class CastToFloatAll(torch.nn.Module): - """ - Class used to add autocast protection for ONNX export - for forward methods with multiple return values - """ - - def __init__(self, mod): - super().__init__() - self.mod = mod - - def forward(self, *args): - from_dtype = args[0].dtype - with torch.amp.autocast("cuda", enabled=False): - ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) - return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) - - -def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: - """ - Generic function generator to replace base_t module with dest_t. - base_t and dest_t should have same atrributes. No weights are copied. - Args: - base_t : module type to replace - dest_t : destination module type - Returns: - swap function to replace base_t module with dest_t - """ - - def expansion_fn(mod: nn.Module) -> nn.Module | None: - if not isinstance(mod, base_t): - return None - args = [getattr(mod, name, None) for name in mod.__constants__] - out = dest_t(*args) - return out - - return expansion_fn - - -def wrap_module(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: - """ - Generic function generator to replace base_t module with dest_t wrapper. - Args: - base_t : module type to replace - dest_t : destination module type - Returns: - swap function to replace base_t module with dest_t - """ - - def expansion_fn(mod: nn.Module) -> nn.Module | None: - out = dest_t(mod) - return out - - return expansion_fn - - -def swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module: - """ - This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows - for swapping nested modules through arbitrary levels if children - - NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. - - """ - for path, new_mod in mapping.items(): - expanded_path = path.split(".") - parent_mod = model - for sub_path in expanded_path[:-1]: - submod = parent_mod._modules[sub_path] - if submod is None: - break - else: - parent_mod = submod - parent_mod._modules[expanded_path[-1]] = new_mod - - return model - - -def replace_modules(model: nn.Module, expansions: dict[str, Callable[[nn.Module], nn.Module | None]]) -> nn.Module: - """ - Top-level function to replace modules in model, specified by class name with a desired replacement. - NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. - Args: - model : top level module - expansions : replacement dictionary: module class name -> replacement function generator - Returns: - model, possibly modified in-place - """ - mapping: dict[str, nn.Module] = {} - for name, m in model.named_modules(): - m_type = type(m).__name__ - if m_type in expansions: - # print (f"Found {m_type} in expansions ...") - swapped = expansions[m_type](m) - if swapped: - mapping[name] = swapped - - print(f"Swapped {len(mapping)} modules") - swap_modules(model, mapping) - return model - - -def replace_for_export(model: nn.Module, do_cast: bool = True) -> nn.Module: - """ - Top-level function to replace default set of modules in model - NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. - Args: - model : top level module - Returns: - model, possibly modified in-place - """ - if do_cast: - print("Adding casts around norms...") - cast_replacements = { - "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat), - "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat), - "BatchNorm3d": wrap_module(nn.BatchNorm2d, CastToFloat), - "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat), - "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat), - "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat), - } - replace_modules(model, cast_replacements) - return model - - -def onnx_export(model, input_example, onnx_path, **export_args): - """ - Exports model to ONNX file at onnx_path - Args: input_example, onnx_path, export_args : passed to onnx.export() - """ - - for p in model.parameters(): - p.requires_grad = False - - replace_for_export(model, do_cast=True) - - torch.onnx.export(model, input_example, onnx_path, **export_args) - # onnx.export is not very good at folding constants so we use polygraphy - model_onnx = onnx_from_path(onnx_path) - fold_constants(model_onnx) # , allow_onnxruntime_shape_inference=False) - save_onnx(model_onnx, onnx_path) diff --git a/tests/test_trt_wrapper.py b/tests/test_trt_wrapper.py index c3223eda13..6f9ff772b8 100644 --- a/tests/test_trt_wrapper.py +++ b/tests/test_trt_wrapper.py @@ -22,7 +22,9 @@ from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows TRTWrapper, has_trtwrapper = optional_import( - "monai.utils", name="TRTWrapper", descriptor="TRT wrapper is not available - check your installation!" + "monai.networks.trt_wrapper", + name="TRTWrapper", + descriptor="TRT wrapper is not available - check your installation!", ) TEST_CASE_1 = ["fp32"] From 7d449f5a091da80be2ab5034e7573d20ca8d210d Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 9 Aug 2024 21:56:14 -0700 Subject: [PATCH 24/85] Refactored TRTWrapper args Signed-off-by: Boris Fomitchev --- monai/networks/trt_wrapper.py | 290 ++++++++++++++++------------------ tests/test_trt_wrapper.py | 21 ++- 2 files changed, 152 insertions(+), 159 deletions(-) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 04e81ebb19..310cb47cd3 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -11,10 +11,11 @@ from __future__ import annotations +import inspect import os -import pickle import threading from collections import OrderedDict +from collections.abc import Mapping, Sequence import torch @@ -25,10 +26,8 @@ P, P_imported = optional_import("polygraphy") if P_imported: from polygraphy.backend.common import bytes_from_path - from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx from polygraphy.backend.trt import ( CreateConfig, - ModifyNetworkOutputs, Profile, engine_from_bytes, engine_from_network, @@ -97,7 +96,7 @@ class ShapeError(Exception): pass -class Engine: +class TRTEngine: """ An auxiliary class to implement running of TRT optimized engines @@ -218,38 +217,63 @@ def infer(self, stream, use_cuda_graph=False): class TRTWrapper(torch.nn.Module): """ - This wrapper implements TRT, ONNX and Torchscript persistent export - and running with optional fallback to Torch (for TRT modules with limited profiles) + This wrapper implements: + - TRT lazy persistent export + - Running TRT with optional fallback to Torch + (for TRT engines with limited profiles) """ def __init__( self, - path, - model=None, - input_names=None, - output_names=None, - use_cuda_graph=False, - timestamp=None, - fallback=False, + path: str, + model: torch.nn.Module | None = None, + precision: str = "tf32", + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + export_args: Mapping | None = None, + build_args: Mapping | None = None, + input_profiles: Mapping | None = None, + dynamic_batchsize: Sequence[int] | None = None, + use_cuda_graph: bool = False, + timestamp: int | None = None, + fallback: bool = False, ): + """ + Initialization method: + Tries to load persistent serialized TRT engine + Saves arguments for lazy TRT build on first forward() call + Args: + path : Path where to save persistent serialized TRT engine, + model: Model to "wrap". Can be None if TRT engine is supposed to exist. + input_names: Optional list of output names to pass to onnx.export() + output_names: Optional list of output names to pass to onnx.export() + export_args: Optional args to pass to onnx.export(). See onnx.export() for details. + build_args: Optional args to pass to TRT builder. See polygraphy.Config for details + input_profiles: Optional list of profiles for TRT builder and ONNX export. + Each profile is a map of "input name" -> [min,opt,max] values. + dynamic_batchsize: A sequence with three elements to define the batch size range of the input for the model to be + converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. + If both input_profiles and dynamic_batchsize are omitted, static shapes will be used to build TRT engine. + use_cuda_graph: Use CUDA Graph for inference(all inputs have to be the same GPU memory between calls!) + timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes) + fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile) + """ + super().__init__() - self.input_names = input_names - self.output_names = output_names - self.model = model - self.profiles = None - self.engine: Engine | None = None - self.jit_model = None - self.onnx_runner = None self.path = path + self.model = model + self.precision = precision + self.output_names = output_names or [] + self.profiles = input_profiles or [] + self.dynamic_batchsize = dynamic_batchsize + self.export_args = export_args or {} + self.build_args = build_args or {} + self.engine: TRTEngine | None = None self.use_cuda_graph = use_cuda_graph self.fallback = fallback + self.disabled = False - if os.path.exists(self.onnx_path): - ftime = os.path.getmtime(self.onnx_path) - if timestamp is not None and ftime < timestamp: - os.remove(self.onnx_path) - else: - timestamp = ftime + # Force engine rebuild if older than the timestamp passed if ( timestamp is not None and os.path.exists(self.engine_path) @@ -257,6 +281,14 @@ def __init__( ): os.remove(self.engine_path) + # Normally we read input_names from forward() but can be overridden + if input_names is None and self.model is not None: + argspec = inspect.getfullargspec(self.model.forward) + input_names = argspec.args[1:] + self.input_names = input_names + + self._load_engine() + """ Auxiliary getters/setters """ @@ -265,97 +297,54 @@ def __init__( def engine_path(self): return self.path + ".plan" - @property - def jit_path(self): - return self.path + ".ts" - @property def onnx_path(self): return self.path + ".onnx" - @property - def profiles_path(self): - return self.path + ".profiles.pkl" - - def has_engine(self): - return self.engine is not None - - def has_onnx(self): - return os.path.exists(self.onnx_path) - - def has_jit(self): - return os.path.exists(self.jit_path) + def _inputs_to_dict(self, input_example): + trt_inputs = {} + for i, inp in enumerate(input_example): + input_name = self.input_names[i] + trt_inputs[input_name] = inp + return trt_inputs - def has_profiles(self): - return os.path.exists(self.profiles_path) - - def delete_model(self): - if self.fallback and self.model is not None: - del self.model - self.model = None - - def load_engine(self): + def _load_engine(self): """ Loads TRT plan from disk and activates its execution context. """ try: - self.engine = Engine(self.engine_path) - self.delete_model() + self.engine = TRTEngine(self.engine_path) + self.input_names = self.engine.input_names + if self.fallback and self.model is not None: + del self.model + self.model = None except Exception as e: LOGGER.debug(f"Exception while loading the engine:\n{e}") - def load_jit(self): - """ - Loads Torchscript from disk + def forward(self, *argv, **kwargs): """ - try: - self.jit_model = torch.jit.load(self.jit_path) - self.delete_model() - except Exception: - pass + Main forward method: + Builds TRT engine if not available yet. + Tries to run TRT engine + If exception thrown and self.callback==True: falls back to original Pytorch - def load_onnx(self, providers=None): - """ - Loads ONNX from disk and creates/activates OnnxrtRunner runner for it. - """ - if providers is None: - providers = ["CUDAExecutionProvider"] - try: - onnx_runner = OnnxrtRunner(session_from_onnx(self.onnx_path, providers=providers)) - onnx_runner.activate() - self.onnx_runner = onnx_runner - self.delete_model() - except Exception: - pass - - def load_profiles(self): - """ - Loads saved optimization profiles from disk - """ - with open(self.profiles_path, "rb") as fp: - profiles = pickle.load(fp) - self.profiles = profiles - return profiles + Args: Passing through whatever args wrapped module's forward() has + Returns: Passing through wrapped module's forward() return value(s) - def save_profiles(self): - """ - Saves optimization profiles to disk using pickle - """ - with open(self.profiles_path, "wb") as fp: - pickle.dump(self.profiles, fp) - - def forward(self, **args): - """ - Main forward method: depending on TRT/Torchscript/ONNX representation available, - runs appropriate accelerated method. If exception thrown, falls back to original Pytorch """ try: + if len(argv) > 0: + kwargs.update(self._inputs_to_dict(argv)) + argv = [] + + if self.engine is None and not self.disabled: + self._build_and_save(kwargs) if self.engine is not None: # forward_trt is not thread safe as we do not use per-thread execution contexts with lock_sm: device = torch.cuda.current_device() stream = torch.cuda.Stream(device=device) - self.engine.set_inputs(args, stream.cuda_stream) + self.engine.set_inputs(kwargs, stream.cuda_stream) self.engine.allocate_buffers(device=device) # Need this to synchronize with Torch stream stream.wait_stream(torch.cuda.current_stream()) @@ -365,32 +354,21 @@ def forward(self, **args): if len(ret) == 1: ret = ret[0] return ret - elif self.jit_model is not None: - return self.jit_model.forward(**args) - elif self.onnx_runner is not None: - ret = self.onnx_runner.infer(args) - ret = list(ret.values()) - ret = [r.cuda() for r in ret] - if len(ret) == 1: - ret = ret[0] - return ret except Exception as e: if self.model: LOGGER.info(f"Exception: {e}\nFalling back to Pytorch ...") else: raise e - return self.model.forward(**args) + return self.model.forward(*argv, **kwargs) - def onnx_to_trt(self, input_profiles=None, **build_args): + def _onnx_to_trt(self): """ Builds TRT engine from ONNX file at self.onnx_path and saves to self.trt_path - Args: - input_profiles, build_args - passed to engine.build() """ profiles = [] - if input_profiles: - for input_profile in input_profiles: + if self.profiles: + for input_profile in self.profiles: if isinstance(input_profile, Profile): profiles.append(input_profile) else: @@ -399,56 +377,64 @@ def onnx_to_trt(self, input_profiles=None, **build_args): assert len(dims) == 3 p.add(name, min=dims[0], opt=dims[1], max=dims[2]) profiles.append(p) - self.profiles = profiles - self.save_profiles() - LOGGER.info(f"Building TensorRT engine for {self.onnx_path}: {self.engine_path}") + build_args = self.build_args.copy() + build_args["tf32"] = self.precision != "fp32" + build_args["fp16"] = self.precision == "fp16" + build_args["bf16"] = self.precision == "bf16" + LOGGER.info(f"Building TensorRT engine for {self.onnx_path}: {self.engine_path}") network = network_from_onnx_path(self.onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) - if self.output_names and False: - LOGGER.info(f"Updating network outputs to {self.output_names}") - network = ModifyNetworkOutputs(network, self.output_names) - - LOGGER.info("Calling engine_from_network...") - engine = engine_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) save_engine(engine, path=self.engine_path) - def build_and_save(self, input_example, export_args=None, input_profiles=None, **build_args): + def _build_and_save(self, input_example): """ - If serialized engine is not found, exports self.model to ONNX, + If TRT engine is not ready, exports self.model to ONNX, builds TRT engine and saves serialized TRT engine to the disk. Args: - input_example, export_args: passed to self.onnx_export() - input_profiles: used to get dynamic axes for onnx_export(), - passed to self.build_engine() - build_args : passed to onnx_to_trt() - enable_all_tactics=True, + input_example: passed to onnx.export() """ - if not export_args: - export_args = {} - if input_profiles: - export_args.update({"dynamic_axes": get_dynamic_axes(input_profiles)}) + if self.engine is not None: + return - if not self.has_engine(): - try: - if not self.has_onnx(): - LOGGER.info(f"Exporting to {self.onnx_path}, export args: {export_args}") - add_casts_around_norms(self.model) - convert_to_onnx( - self.model, - input_example, - filename=self.onnx_path, - input_names=self.input_names, - output_names=self.output_names, - **export_args, - ) - LOGGER.info("Export to ONNX successful.") - self.onnx_to_trt(input_profiles=input_profiles, **build_args) - self.load_engine() - os.remove(self.onnx_path) - except Exception as e: - if self.fallback: - LOGGER.info(f"Failed to build engine: {e}") - else: - raise e + if self.model is None: + raise ValueError("ERROR: self.model is None!") + + try: + export_args = self.export_args + dbs = self.dynamic_batchsize + if dbs: + if len(self.profiles) > 0: + raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TRTWrapper!") + if len(dbs) != 3: + raise ValueError("dynamic_batchsize has to have len ==3 ") + profiles = {} + for id, val in input_example.items(): + sh = val.shape[1:] + profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] + self.profiles = [profiles] + + if len(self.profiles) > 0: + export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) + + LOGGER.info(f"Exporting to {self.onnx_path}, export args: {export_args}") + add_casts_around_norms(self.model) + convert_to_onnx( + self.model, + (input_example,), + filename=self.onnx_path, + input_names=self.input_names, + output_names=self.output_names, + **export_args, + ) + LOGGER.info("Export to ONNX successful.") + self._onnx_to_trt() + self._load_engine() + os.remove(self.onnx_path) + except Exception as e: + if self.fallback: + LOGGER.info(f"Failed to build engine: {e}") + self.disabled = True + else: + raise e diff --git a/tests/test_trt_wrapper.py b/tests/test_trt_wrapper.py index 6f9ff772b8..a45ea8f55e 100644 --- a/tests/test_trt_wrapper.py +++ b/tests/test_trt_wrapper.py @@ -12,6 +12,7 @@ from __future__ import annotations import tempfile +import time import unittest import torch @@ -60,14 +61,20 @@ def test_value(self, precision): model.eval() input_example = torch.randn(1, 1, 96, 96, 96).cuda() output_example = model(input_example) - args: dict = {"tf32": True} - if precision == "fp16": - args["fp16"] = True - args["precision_constraints"] = "obey" + args: dict = {"builder_optimization_level": 1} - trt_wrapper = TRTWrapper("test_wrapper", model, input_names=["x"]) - trt_wrapper.build_and_save(input_example, **args, builder_optimization_level=1) - trt_output = trt_wrapper(x=input_example) + trt_wrapper = TRTWrapper( + "test_wrapper", + model, + precision=precision, + build_args=args, + dynamic_batchsize=[1, 4, 8], + timestamp=time.time(), + ) + assert trt_wrapper.engine is None + trt_output = trt_wrapper(input_example) + # Check that lazy TRT build succeeded + assert trt_wrapper.engine is not None torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) From 6846fd421568d817fcf1f6785115454ee65eed8f Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 9 Aug 2024 22:31:41 -0700 Subject: [PATCH 25/85] Added docstring for precision Signed-off-by: Boris Fomitchev --- monai/networks/trt_wrapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 310cb47cd3..1482861ab3 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -245,6 +245,7 @@ def __init__( Args: path : Path where to save persistent serialized TRT engine, model: Model to "wrap". Can be None if TRT engine is supposed to exist. + precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. input_names: Optional list of output names to pass to onnx.export() output_names: Optional list of output names to pass to onnx.export() export_args: Optional args to pass to onnx.export(). See onnx.export() for details. From d59859094abb20c022a407df3f4d014d0c7319a3 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sat, 10 Aug 2024 23:08:00 -0700 Subject: [PATCH 26/85] Fixed comments, reordered args Signed-off-by: Boris Fomitchev --- monai/networks/trt_wrapper.py | 51 +++++++++++++++++------------------ monai/networks/utils.py | 4 ++- tests/test_trt_wrapper.py | 12 +++------ 3 files changed, 31 insertions(+), 36 deletions(-) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 1482861ab3..f4c5174f73 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -15,7 +15,6 @@ import os import threading from collections import OrderedDict -from collections.abc import Mapping, Sequence import torch @@ -225,44 +224,44 @@ class TRTWrapper(torch.nn.Module): def __init__( self, - path: str, - model: torch.nn.Module | None = None, - precision: str = "tf32", - input_names: Sequence[str] | None = None, - output_names: Sequence[str] | None = None, - export_args: Mapping | None = None, - build_args: Mapping | None = None, - input_profiles: Mapping | None = None, - dynamic_batchsize: Sequence[int] | None = None, - use_cuda_graph: bool = False, - timestamp: int | None = None, - fallback: bool = False, + model, + path, + precision="tf32", + input_names=None, + output_names=None, + export_args=None, + build_args=None, + input_profiles=None, + dynamic_batchsize=None, + use_cuda_graph=False, + timestamp=None, + fallback=False, ): """ Initialization method: Tries to load persistent serialized TRT engine - Saves arguments for lazy TRT build on first forward() call + Saves its arguments for lazy TRT build on first forward() call Args: - path : Path where to save persistent serialized TRT engine, - model: Model to "wrap". Can be None if TRT engine is supposed to exist. + model: Model to "wrap". If None, TRT engine is supposed to already exist. + path : Path where to save persistent serialized TRT engine. precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. - input_names: Optional list of output names to pass to onnx.export() - output_names: Optional list of output names to pass to onnx.export() + input_names: Optional list of output names to pass to onnx.export(). + output_names: Optional list of output names to pass to onnx.export(). export_args: Optional args to pass to onnx.export(). See onnx.export() for details. - build_args: Optional args to pass to TRT builder. See polygraphy.Config for details + build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. input_profiles: Optional list of profiles for TRT builder and ONNX export. - Each profile is a map of "input name" -> [min,opt,max] values. + Each profile is a map of the form : {"input id" : [min_shape, opt_shape, max_shape], ...}. dynamic_batchsize: A sequence with three elements to define the batch size range of the input for the model to be converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. - If both input_profiles and dynamic_batchsize are omitted, static shapes will be used to build TRT engine. - use_cuda_graph: Use CUDA Graph for inference(all inputs have to be the same GPU memory between calls!) - timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes) - fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile) + [note]: If neither input_profiles nor dynamic_batchsize specified, static shapes will be used to build TRT engine. + use_cuda_graph: Use CUDA Graph for inference. Note: all inputs have to be the same GPU memory between calls! + timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes). + fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile). """ super().__init__() - self.path = path self.model = model + self.path = path self.precision = precision self.output_names = output_names or [] self.profiles = input_profiles or [] @@ -336,7 +335,7 @@ def forward(self, *argv, **kwargs): try: if len(argv) > 0: kwargs.update(self._inputs_to_dict(argv)) - argv = [] + argv = () if self.engine is None and not self.disabled: self._build_and_save(kwargs) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 279900f56f..6c81d80ad9 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -634,7 +634,9 @@ def convert_to_onnx( rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. use_trace: whether to use `torch.jit.trace` to export the torchscript model. - kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details: + do_constant_folding: passed to onnx.export(). If True, extra polygraphy folding pass is done. + kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export() + else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details: https://pytorch.org/docs/master/generated/torch.jit.script.html. """ diff --git a/tests/test_trt_wrapper.py b/tests/test_trt_wrapper.py index a45ea8f55e..2292d2325b 100644 --- a/tests/test_trt_wrapper.py +++ b/tests/test_trt_wrapper.py @@ -12,7 +12,6 @@ from __future__ import annotations import tempfile -import time import unittest import torch @@ -35,7 +34,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick -class TestConvertToTRT(unittest.TestCase): +class TestTRTWrapper(unittest.TestCase): def setUp(self): self.gpu_device = torch.cuda.current_device() @@ -57,19 +56,14 @@ def test_value(self, precision): num_res_units=2, norm="batch", ).cuda() - with torch.no_grad(), tempfile.TemporaryDirectory() as _: + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: model.eval() input_example = torch.randn(1, 1, 96, 96, 96).cuda() output_example = model(input_example) args: dict = {"builder_optimization_level": 1} trt_wrapper = TRTWrapper( - "test_wrapper", - model, - precision=precision, - build_args=args, - dynamic_batchsize=[1, 4, 8], - timestamp=time.time(), + model, f"{tmpdir}/test_wrapper", precision=precision, build_args=args, dynamic_batchsize=[1, 4, 8] ) assert trt_wrapper.engine is None trt_output = trt_wrapper(input_example) From 517c111ea1022fb745a881aa71b91f549e2efd9a Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 12 Aug 2024 09:55:01 -0700 Subject: [PATCH 27/85] Reduced test assert accuracy Signed-off-by: Boris Fomitchev --- tests/test_sure_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py index 903f9bd2ca..fb8f5dda72 100644 --- a/tests/test_sure_loss.py +++ b/tests/test_sure_loss.py @@ -65,7 +65,7 @@ def operator(x): loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False) loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True) - self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6) + self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=5) if __name__ == "__main__": From ed0d93d3816b900e80b38e8ec1195b5dbefe8268 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 14 Aug 2024 01:10:32 -0700 Subject: [PATCH 28/85] Addressing code review comments Signed-off-by: Boris Fomitchev --- monai/networks/trt_wrapper.py | 11 ++++++----- tests/test_trt_wrapper.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index f4c5174f73..3a0c0c3430 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -37,7 +37,7 @@ trt, trt_imported = optional_import("tensorrt") cudart, _ = optional_import("cuda.cudart") -LOGGER = get_logger("run_cmd") +LOGGER = get_logger("trt_wrapper") lock_sm = threading.Lock() @@ -132,10 +132,11 @@ def allocate_buffers(self, device): ctx = self.context for i, binding in enumerate(self.output_names): - shape = ctx.get_tensor_shape(binding) - t = torch.empty(list(shape), dtype=self.dtypes[i], device=device).contiguous() - self.tensors[binding] = t - ctx.set_tensor_address(binding, t.data_ptr()) + shape = list(ctx.get_tensor_shape(binding)) + if binding not in self.tensors or list(self.tensors[binding].shape) != shape: + t = torch.empty(shape, dtype=self.dtypes[i], device=device).contiguous() + self.tensors[binding] = t + ctx.set_tensor_address(binding, t.data_ptr()) @staticmethod def check_shape(shape, profile): diff --git a/tests/test_trt_wrapper.py b/tests/test_trt_wrapper.py index 2292d2325b..321a26af87 100644 --- a/tests/test_trt_wrapper.py +++ b/tests/test_trt_wrapper.py @@ -65,10 +65,10 @@ def test_value(self, precision): trt_wrapper = TRTWrapper( model, f"{tmpdir}/test_wrapper", precision=precision, build_args=args, dynamic_batchsize=[1, 4, 8] ) - assert trt_wrapper.engine is None + self.assertIsNone(trt_wrapper.engine) trt_output = trt_wrapper(input_example) # Check that lazy TRT build succeeded - assert trt_wrapper.engine is not None + self.assertIsNotNone(trt_wrapper.engine) torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) From fdcf11816f0418b80417d521c8373e38723b520d Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 15 Aug 2024 01:39:43 -0700 Subject: [PATCH 29/85] Added Torch-TRT option, cleaned up engine save method Signed-off-by: Boris Fomitchev --- monai/networks/trt_wrapper.py | 93 ++++++++++++++++++++--------------- monai/networks/utils.py | 9 ++-- 2 files changed, 60 insertions(+), 42 deletions(-) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 3a0c0c3430..0d72d360dd 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -29,12 +29,12 @@ CreateConfig, Profile, engine_from_bytes, - engine_from_network, + engine_bytes_from_network, network_from_onnx_path, - save_engine, ) trt, trt_imported = optional_import("tensorrt") +torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0") cudart, _ = optional_import("cuda.cudart") LOGGER = get_logger("trt_wrapper") @@ -228,6 +228,7 @@ def __init__( model, path, precision="tf32", + export_method="onnx", input_names=None, output_names=None, export_args=None, @@ -246,9 +247,10 @@ def __init__( model: Model to "wrap". If None, TRT engine is supposed to already exist. path : Path where to save persistent serialized TRT engine. precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. - input_names: Optional list of output names to pass to onnx.export(). - output_names: Optional list of output names to pass to onnx.export(). - export_args: Optional args to pass to onnx.export(). See onnx.export() for details. + export_method: One of 'onnx'|'onnx_dynamo'|'torch_trt'. + input_names: Optional list of input names to use for export. + output_names: Optional list of output names to use for export. + export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. input_profiles: Optional list of profiles for TRT builder and ONNX export. Each profile is a map of the form : {"input id" : [min_shape, opt_shape, max_shape], ...}. @@ -264,6 +266,7 @@ def __init__( self.model = model self.path = path self.precision = precision + self.export_method = export_method self.output_names = output_names or [] self.profiles = input_profiles or [] self.dynamic_batchsize = dynamic_batchsize @@ -333,13 +336,21 @@ def forward(self, *argv, **kwargs): Returns: Passing through wrapped module's forward() return value(s) """ - try: - if len(argv) > 0: - kwargs.update(self._inputs_to_dict(argv)) - argv = () + if len(argv) > 0: + kwargs.update(self._inputs_to_dict(argv)) + argv = () - if self.engine is None and not self.disabled: + if self.engine is None and not self.disabled: + try: self._build_and_save(kwargs) + self._load_engine() + except Exception as e: + if self.fallback: + LOGGER.info(f"Failed to build engine: {e}") + self.disabled = True + else: + raise e + try: if self.engine is not None: # forward_trt is not thread safe as we do not use per-thread execution contexts with lock_sm: @@ -386,8 +397,7 @@ def _onnx_to_trt(self): LOGGER.info(f"Building TensorRT engine for {self.onnx_path}: {self.engine_path}") network = network_from_onnx_path(self.onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) - engine = engine_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) - save_engine(engine, path=self.engine_path) + return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) def _build_and_save(self, input_example): """ @@ -402,40 +412,45 @@ def _build_and_save(self, input_example): if self.model is None: raise ValueError("ERROR: self.model is None!") - try: - export_args = self.export_args - dbs = self.dynamic_batchsize - if dbs: - if len(self.profiles) > 0: - raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TRTWrapper!") - if len(dbs) != 3: - raise ValueError("dynamic_batchsize has to have len ==3 ") - profiles = {} - for id, val in input_example.items(): - sh = val.shape[1:] - profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] - self.profiles = [profiles] - + export_args = self.export_args + dbs = self.dynamic_batchsize + if dbs: if len(self.profiles) > 0: - export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) - - LOGGER.info(f"Exporting to {self.onnx_path}, export args: {export_args}") - add_casts_around_norms(self.model) + raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TRTWrapper!") + if len(dbs) != 3: + raise ValueError("dynamic_batchsize has to have len ==3 ") + profiles = {} + for id, val in input_example.items(): + sh = val.shape[1:] + profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] + self.profiles = [profiles] + + if len(self.profiles) > 0: + export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) + + LOGGER.info(f"Exporting to {self.onnx_path}, export args: {export_args}") + add_casts_around_norms(self.model) + if self.export_method == 'torch_trt': + enabled_precisions = [torch.float32] + if self.precision == "fp16": + enabled_precisions.append(torch.float16) + elif self.precision == "bf16": + enabled_precisions.append(torch.bfloat16) + engine_bytes = torch_tensorrt.convert_method_to_trt_engine(self.model, + input_example, + enabled_precisions=enabled_precisions, + **export_args) + else: convert_to_onnx( self.model, - (input_example,), + input_example, filename=self.onnx_path, input_names=self.input_names, output_names=self.output_names, + dynamo=self.export_method == 'onnx_dynamo', **export_args, ) LOGGER.info("Export to ONNX successful.") - self._onnx_to_trt() - self._load_engine() + engine_bytes = self._onnx_to_trt() os.remove(self.onnx_path) - except Exception as e: - if self.fallback: - LOGGER.info(f"Failed to build engine: {e}") - self.disabled = True - else: - raise e + open(self.engine_path, 'wb').write(engine_bytes) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index c6fd46c4f1..61f7603dce 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -608,6 +608,7 @@ def convert_to_onnx( atol: float = 0.0, use_trace: bool = True, do_constant_folding: bool = True, + dynamo=False, **kwargs, ): """ @@ -635,6 +636,7 @@ def convert_to_onnx( atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. use_trace: whether to use `torch.jit.trace` to export the torchscript model. do_constant_folding: passed to onnx.export(). If True, extra polygraphy folding pass is done. + dynamo: passed to onnx.export(). [When dynamo export API is finalized] kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export() else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details: https://pytorch.org/docs/master/generated/torch.jit.script.html. @@ -669,13 +671,14 @@ def convert_to_onnx( torch.onnx.export( mode_to_export, - tuple(inputs), + inputs, f=f, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, opset_version=opset_version, do_constant_folding=do_constant_folding, + # dynamo=dynamo, **torch_versioned_kwargs, ) if filename is None: @@ -958,8 +961,6 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): # convert the torch model to a TorchScript model on target device model = model.eval().to(target_device) - ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace) - ir_model.eval() if use_onnx: # set the batch dim as dynamic @@ -981,6 +982,8 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): output_names=onnx_output_names, ) else: + ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace) + ir_model.eval() # convert the model through the Torch-TensorRT way ir_model.to(target_device) with torch.no_grad(): From 1009dc59c05b02c1d83b519fe986a55967c8a9d2 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 15 Aug 2024 18:03:29 -0700 Subject: [PATCH 30/85] Added trt_wrap adapter Signed-off-by: Boris Fomitchev --- monai/networks/__init__.py | 2 ++ monai/networks/trt_wrapper.py | 28 +++++++++++++++++++++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 3cf3aa0cec..f2dea8d491 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -33,3 +33,5 @@ to_norm_affine, train_mode, ) + +from .trt_wrapper import trt_wrap diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 0d72d360dd..6a3f1dd3ff 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -22,8 +22,8 @@ from monai.networks.utils import add_casts_around_norms, convert_to_onnx from monai.utils.module import optional_import -P, P_imported = optional_import("polygraphy") -if P_imported: +polygraphy, polygraphy_imported = optional_import("polygraphy") +if polygraphy_imported: from polygraphy.backend.common import bytes_from_path from polygraphy.backend.trt import ( CreateConfig, @@ -277,7 +277,16 @@ def __init__( self.fallback = fallback self.disabled = False - # Force engine rebuild if older than the timestamp passed + # if "path" filename point to existing file (e.g. checkpoint) + # it's also treated as dependency + if os.path.exists(path): + path_timestamp = os.path.getmtime(path) + if timestamp is None: + timestamp = path_timestamp + else: + timestamp = max(timestamp, path_timestamp) + + # Force engine rebuild if older than the timestamp if ( timestamp is not None and os.path.exists(self.engine_path) @@ -454,3 +463,16 @@ def _build_and_save(self, input_example): engine_bytes = self._onnx_to_trt() os.remove(self.onnx_path) open(self.engine_path, 'wb').write(engine_bytes) + + +def trt_wrap(model, path, trt_wrapper_args): + """ + TRTWrapper factory function and argument adapter + Args: + model, path: passed to TRTWrapper(). + trt_wrapper_args: dict : unpacked and passed to TRTWrapper(). + """ + if trt_imported and polygraphy_imported and torch.cuda.is_available(): + return TRTWrapper(model, path, **trt_wrapper_args) + else: + return model From fd679c040739583768333a3386b8c892d7ab48dd Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 15 Aug 2024 19:56:42 -0700 Subject: [PATCH 31/85] Refined trt_wrap Signed-off-by: Boris Fomitchev --- monai/networks/trt_wrapper.py | 48 +++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 6a3f1dd3ff..97fe8af920 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -277,15 +277,6 @@ def __init__( self.fallback = fallback self.disabled = False - # if "path" filename point to existing file (e.g. checkpoint) - # it's also treated as dependency - if os.path.exists(path): - path_timestamp = os.path.getmtime(path) - if timestamp is None: - timestamp = path_timestamp - else: - timestamp = max(timestamp, path_timestamp) - # Force engine rebuild if older than the timestamp if ( timestamp is not None @@ -465,14 +456,43 @@ def _build_and_save(self, input_example): open(self.engine_path, 'wb').write(engine_bytes) -def trt_wrap(model, path, trt_wrapper_args): +def trt_wrap(model, path, args=None, submodule=None): """ TRTWrapper factory function and argument adapter Args: model, path: passed to TRTWrapper(). - trt_wrapper_args: dict : unpacked and passed to TRTWrapper(). + args: dict : unpacked and passed to TRTWrapper(). + submodule : Hierarchical id of submodule to convert, e.g. 'image_decoder.decoder' + If None, TRTWrapper is applied to the whole model and returned. + Otherwise, submodule is replaced in-place with TRTWrapper. """ + if args is None: + args = {} if trt_imported and polygraphy_imported and torch.cuda.is_available(): - return TRTWrapper(model, path, **trt_wrapper_args) - else: - return model + # if "path" filename point to existing file (e.g. checkpoint) + # it's also treated as dependency + if os.path.exists(path): + timestamp = os.path.getmtime(path) + if 'timestamp' in args: + timestamp = max(args['timestamp'], timestamp) + args['timestamp'] = timestamp + + if submodule is None: + return TRTWrapper(model, path, **args) + else: + def find_sub(parent, submodule): + idx = submodule.find(".") + # if there is "." in name, call recursively + if idx != -1: + parent_name = submodule[:idx] + parent = getattr(parent, parent_name) + submodule = submodule[idx + 1 :] + return find_sub(parent, submodule) + return parent, submodule + + path = path + '.' + submodule + parent, submodule = find_sub(model, submodule) + submodel = getattr(parent, submodule) + wrapper = TRTWrapper(submodel, path, **args) + setattr(parent, submodule, wrapper) + return model From dc13b52139ed4bd7e1081072a149a34990037d8f Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 16 Aug 2024 17:50:36 -0700 Subject: [PATCH 32/85] Used tempdir for ONNX Signed-off-by: Boris Fomitchev --- monai/networks/trt_wrapper.py | 43 ++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 97fe8af920..6637b50dd2 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -13,6 +13,8 @@ import inspect import os +from pathlib import Path +import tempfile import threading from collections import OrderedDict @@ -301,10 +303,6 @@ def __init__( def engine_path(self): return self.path + ".plan" - @property - def onnx_path(self): - return self.path + ".onnx" - def _inputs_to_dict(self, input_example): trt_inputs = {} for i, inp in enumerate(input_example): @@ -373,9 +371,9 @@ def forward(self, *argv, **kwargs): raise e return self.model.forward(*argv, **kwargs) - def _onnx_to_trt(self): + def _onnx_to_trt(self, onnx_path): """ - Builds TRT engine from ONNX file at self.onnx_path and saves to self.trt_path + Builds TRT engine from ONNX file at onnx_path and saves to self.trt_path """ profiles = [] @@ -395,8 +393,8 @@ def _onnx_to_trt(self): build_args["fp16"] = self.precision == "fp16" build_args["bf16"] = self.precision == "bf16" - LOGGER.info(f"Building TensorRT engine for {self.onnx_path}: {self.engine_path}") - network = network_from_onnx_path(self.onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) + LOGGER.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) def _build_and_save(self, input_example): @@ -428,7 +426,6 @@ def _build_and_save(self, input_example): if len(self.profiles) > 0: export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) - LOGGER.info(f"Exporting to {self.onnx_path}, export args: {export_args}") add_casts_around_norms(self.model) if self.export_method == 'torch_trt': enabled_precisions = [torch.float32] @@ -441,18 +438,22 @@ def _build_and_save(self, input_example): enabled_precisions=enabled_precisions, **export_args) else: - convert_to_onnx( - self.model, - input_example, - filename=self.onnx_path, - input_names=self.input_names, - output_names=self.output_names, - dynamo=self.export_method == 'onnx_dynamo', - **export_args, - ) - LOGGER.info("Export to ONNX successful.") - engine_bytes = self._onnx_to_trt() - os.remove(self.onnx_path) + # Use temporary directory for easy cleanup in case of external weights + with tempfile.TemporaryDirectory() as tmpdir: + onnx_path = Path(tmpdir) / 'model.onnx' + LOGGER.info(f"Exporting to {onnx_path}, export args: {export_args}") + convert_to_onnx( + self.model, + input_example, + filename=onnx_path, + input_names=self.input_names, + output_names=self.output_names, + dynamo=self.export_method == 'onnx_dynamo', + **export_args, + ) + LOGGER.info("Export to ONNX successful.") + engine_bytes = self._onnx_to_trt(str(onnx_path)) + open(self.engine_path, 'wb').write(engine_bytes) From 779de928426b2dbe72648e4acb5778ce3261b884 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sun, 18 Aug 2024 13:51:07 -0700 Subject: [PATCH 33/85] Refactored trt wrapper, added trt handler Signed-off-by: Boris Fomitchev --- monai/handlers/__init__.py | 1 + monai/handlers/trt_handler.py | 72 ++++++++++++++++++++++++++ monai/networks/trt_wrapper.py | 96 ++++++++++++++++++++--------------- tests/test_trt_wrapper.py | 23 +++++---- 4 files changed, 140 insertions(+), 52 deletions(-) create mode 100644 monai/handlers/trt_handler.py diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 641f9aae7d..fa6e158be8 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -40,5 +40,6 @@ from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler +from .trt_handler import TrtHandler from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler diff --git a/monai/handlers/trt_handler.py b/monai/handlers/trt_handler.py new file mode 100644 index 0000000000..4c42c38bdb --- /dev/null +++ b/monai/handlers/trt_handler.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from monai.config import IgniteInfo +from monai.networks import trt_wrap +from monai.utils import min_version, optional_import + +Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + + +class TrtHandler: + """ + TrtHandler acts as an Ignite handler to apply TRT acceleration to the model. + Usage example:: + handler = TrtHandler(model=model, path="/test/checkpoint.pt", args={"precision": "fp16"}) + handler(trainer) + + Args: + path: the file path of checkpoint, it should be a PyTorch `pth` file. + args: dict : unpacked and passed to TrtWrapper(). + submodules : Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder' + If None, TrtWrapper is applied to the whole model and returned. + Otherwise, submodules are replaced in-place with TrtWrappers. + """ + + def __init__( + self, + model, + path, + args=None, + submodules=None, + enabled=True + ): + self.model = model + self.path = path + self.args = args + self.enabled = enabled + self.submodules = submodules or [""] + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + self.logger = engine.logger + engine.add_event_handler(Events.STARTED, self) + + def __call__(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + if self.enabled: + for submodule in self.submodules: + trt_wrap(self.model, self.path, args=self.args, submodule=submodule) + self.logger.info(f"Created TRT wrapper for {self.path}.{submodule}") diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 6637b50dd2..58879ac21c 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -19,6 +19,7 @@ from collections import OrderedDict import torch +from types import MethodType from monai.apps.utils import get_logger from monai.networks.utils import add_casts_around_norms, convert_to_onnx @@ -217,7 +218,7 @@ def infer(self, stream, use_cuda_graph=False): return self.tensors -class TRTWrapper(torch.nn.Module): +class TrtWrappper: """ This wrapper implements: - TRT lazy persistent export @@ -246,7 +247,7 @@ def __init__( Tries to load persistent serialized TRT engine Saves its arguments for lazy TRT build on first forward() call Args: - model: Model to "wrap". If None, TRT engine is supposed to already exist. + model: Model to "wrap". path : Path where to save persistent serialized TRT engine. precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. export_method: One of 'onnx'|'onnx_dynamo'|'torch_trt'. @@ -263,9 +264,6 @@ def __init__( timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes). fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile). """ - - super().__init__() - self.model = model self.path = path self.precision = precision self.export_method = export_method @@ -279,6 +277,13 @@ def __init__( self.fallback = fallback self.disabled = False + # Normally we read input_names from forward() but can be overridden + if input_names is None: + argspec = inspect.getfullargspec(model.forward) + input_names = argspec.args[1:] + self.input_names = input_names + self.old_forward = model.forward + # Force engine rebuild if older than the timestamp if ( timestamp is not None @@ -287,14 +292,6 @@ def __init__( ): os.remove(self.engine_path) - # Normally we read input_names from forward() but can be overridden - if input_names is None and self.model is not None: - argspec = inspect.getfullargspec(self.model.forward) - input_names = argspec.args[1:] - self.input_names = input_names - - self._load_engine() - """ Auxiliary getters/setters """ @@ -317,13 +314,10 @@ def _load_engine(self): try: self.engine = TRTEngine(self.engine_path) self.input_names = self.engine.input_names - if self.fallback and self.model is not None: - del self.model - self.model = None except Exception as e: LOGGER.debug(f"Exception while loading the engine:\n{e}") - def forward(self, *argv, **kwargs): + def forward(self, model, argv, kwargs): """ Main forward method: Builds TRT engine if not available yet. @@ -339,15 +333,28 @@ def forward(self, *argv, **kwargs): argv = () if self.engine is None and not self.disabled: + # Restore original forward for export + new_forward = model.forward + model.forward = self.old_forward try: - self._build_and_save(kwargs) self._load_engine() + if self.engine is None: + self._build_and_save(model, kwargs) + self._load_engine() except Exception as e: if self.fallback: LOGGER.info(f"Failed to build engine: {e}") self.disabled = True else: raise e + if not self.disabled and not self.fallback: + # Delete all parameters + for param in model.parameters(): + del param + # Call empty_cache to release GPU memory + torch.cuda.empty_cache() + model.forward = new_forward + # Run the engine try: if self.engine is not None: # forward_trt is not thread safe as we do not use per-thread execution contexts @@ -365,11 +372,11 @@ def forward(self, *argv, **kwargs): ret = ret[0] return ret except Exception as e: - if self.model: + if model is not None: LOGGER.info(f"Exception: {e}\nFalling back to Pytorch ...") else: raise e - return self.model.forward(*argv, **kwargs) + return self.old_forward(model, *argv, **kwargs) def _onnx_to_trt(self, onnx_path): """ @@ -397,24 +404,22 @@ def _onnx_to_trt(self, onnx_path): network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) - def _build_and_save(self, input_example): + def _build_and_save(self, model, input_example): """ - If TRT engine is not ready, exports self.model to ONNX, + If TRT engine is not ready, exports model to ONNX, builds TRT engine and saves serialized TRT engine to the disk. Args: input_example: passed to onnx.export() """ + if self.engine is not None: return - if self.model is None: - raise ValueError("ERROR: self.model is None!") - export_args = self.export_args dbs = self.dynamic_batchsize if dbs: if len(self.profiles) > 0: - raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TRTWrapper!") + raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtWrappper!") if len(dbs) != 3: raise ValueError("dynamic_batchsize has to have len ==3 ") profiles = {} @@ -426,14 +431,14 @@ def _build_and_save(self, input_example): if len(self.profiles) > 0: export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) - add_casts_around_norms(self.model) + add_casts_around_norms(model) if self.export_method == 'torch_trt': enabled_precisions = [torch.float32] if self.precision == "fp16": enabled_precisions.append(torch.float16) elif self.precision == "bf16": enabled_precisions.append(torch.bfloat16) - engine_bytes = torch_tensorrt.convert_method_to_trt_engine(self.model, + engine_bytes = torch_tensorrt.convert_method_to_trt_engine(model, input_example, enabled_precisions=enabled_precisions, **export_args) @@ -443,9 +448,9 @@ def _build_and_save(self, input_example): onnx_path = Path(tmpdir) / 'model.onnx' LOGGER.info(f"Exporting to {onnx_path}, export args: {export_args}") convert_to_onnx( - self.model, + model, input_example, - filename=onnx_path, + filename=str(onnx_path), input_names=self.input_names, output_names=self.output_names, dynamo=self.export_method == 'onnx_dynamo', @@ -457,15 +462,23 @@ def _build_and_save(self, input_example): open(self.engine_path, 'wb').write(engine_bytes) +def trt_forward(self, *argv, **kwargs): + """ + Patch function to replace original model's forward() with. + Redirects to TrtWrappper.forward() + """ + return self._trt_wrapper.forward(self, argv, kwargs) + + def trt_wrap(model, path, args=None, submodule=None): """ - TRTWrapper factory function and argument adapter + TrtWrappper factory function and argument adapter Args: - model, path: passed to TRTWrapper(). - args: dict : unpacked and passed to TRTWrapper(). + model, path: passed to TrtWrappper(). + args: dict : unpacked and passed to TrtWrappper(). submodule : Hierarchical id of submodule to convert, e.g. 'image_decoder.decoder' - If None, TRTWrapper is applied to the whole model and returned. - Otherwise, submodule is replaced in-place with TRTWrapper. + If None, TrtWrappper is applied to the whole model and returned. + Otherwise, submodule is replaced in-place with TrtWrappper. """ if args is None: args = {} @@ -478,9 +491,7 @@ def trt_wrap(model, path, args=None, submodule=None): timestamp = max(args['timestamp'], timestamp) args['timestamp'] = timestamp - if submodule is None: - return TRTWrapper(model, path, **args) - else: + if submodule: def find_sub(parent, submodule): idx = submodule.find(".") # if there is "." in name, call recursively @@ -493,7 +504,8 @@ def find_sub(parent, submodule): path = path + '.' + submodule parent, submodule = find_sub(model, submodule) - submodel = getattr(parent, submodule) - wrapper = TRTWrapper(submodel, path, **args) - setattr(parent, submodule, wrapper) - return model + model = getattr(parent, submodule) + + wrapper = TrtWrappper(model, path, **args) + model._trt_wrapper = wrapper + model.forward = MethodType(trt_forward, model) diff --git a/tests/test_trt_wrapper.py b/tests/test_trt_wrapper.py index 321a26af87..a5ae5bfc09 100644 --- a/tests/test_trt_wrapper.py +++ b/tests/test_trt_wrapper.py @@ -21,9 +21,9 @@ from monai.utils import optional_import from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows -TRTWrapper, has_trtwrapper = optional_import( - "monai.networks.trt_wrapper", - name="TRTWrapper", +trt_wrap, has_trtwrapper = optional_import( + "monai.networks", + name="trt_wrap", descriptor="TRT wrapper is not available - check your installation!", ) @@ -61,14 +61,17 @@ def test_value(self, precision): input_example = torch.randn(1, 1, 96, 96, 96).cuda() output_example = model(input_example) args: dict = {"builder_optimization_level": 1} - - trt_wrapper = TRTWrapper( - model, f"{tmpdir}/test_wrapper", precision=precision, build_args=args, dynamic_batchsize=[1, 4, 8] - ) - self.assertIsNone(trt_wrapper.engine) - trt_output = trt_wrapper(input_example) + trt_wrap(model, + f"{tmpdir}/test_wrapper", + args={"precision": precision, + "build_args": args, + "dynamic_batchsize": [1, 4, 8] + } + ) + self.assertIsNone(model._trt_wrapper.engine) + trt_output = model(input_example) # Check that lazy TRT build succeeded - self.assertIsNotNone(trt_wrapper.engine) + self.assertIsNotNone(model._trt_wrapper.engine) torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) From 6504dc9deeb18c03ad5dd881cdbebbe8b064b9d8 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sun, 18 Aug 2024 14:51:21 -0700 Subject: [PATCH 34/85] Adjusted refactor for use in config Signed-off-by: Boris Fomitchev --- monai/networks/trt_wrapper.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 58879ac21c..bd3b94b9de 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -472,16 +472,19 @@ def trt_forward(self, *argv, **kwargs): def trt_wrap(model, path, args=None, submodule=None): """ - TrtWrappper factory function and argument adapter + Instruments model or submodule with TrtWrappper and reppaces forward() with a hook. Args: model, path: passed to TrtWrappper(). args: dict : unpacked and passed to TrtWrappper(). submodule : Hierarchical id of submodule to convert, e.g. 'image_decoder.decoder' If None, TrtWrappper is applied to the whole model and returned. Otherwise, submodule is replaced in-place with TrtWrappper. + Returns: + Always returns same model passed in as argument. This is for ease of use in configs. """ if args is None: args = {} + orig_model = model if trt_imported and polygraphy_imported and torch.cuda.is_available(): # if "path" filename point to existing file (e.g. checkpoint) # it's also treated as dependency @@ -509,3 +512,4 @@ def find_sub(parent, submodule): wrapper = TrtWrappper(model, path, **args) model._trt_wrapper = wrapper model.forward = MethodType(trt_forward, model) + return orig_model From c1be72ce49c06e9f6ef0d7a6ab7e70ad2e53838b Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 20 Aug 2024 00:59:38 -0700 Subject: [PATCH 35/85] Added fold constant threshold param Signed-off-by: Boris Fomitchev --- monai/networks/nets/swin_unetr.py | 10 +++++----- monai/networks/trt_wrapper.py | 8 +++++++- monai/networks/utils.py | 9 +++++---- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 3900c866b3..ed0c6cd9f1 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -320,8 +320,8 @@ def _check_input_size(self, spatial_shape): ) def forward(self, x_in): - if not torch.jit.is_scripting(): - self._check_input_size(x_in.shape[2:]) + # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + # self._check_input_size(x_in.shape[2:]) hidden_states_out = self.swinViT(x_in, self.normalize) enc0 = self.encoder1(x_in) enc1 = self.encoder2(hidden_states_out[0]) @@ -1046,14 +1046,14 @@ def __init__( def proj_out(self, x, normalize=False): if normalize: - x_shape = x.size() + x_shape = x.shape + # Force trace() to generate a constant by casting to int + ch = int(x_shape[1]) if len(x_shape) == 5: - n, ch, d, h, w = x_shape x = rearrange(x, "n c d h w -> n d h w c") x = F.layer_norm(x, [ch]) x = rearrange(x, "n d h w c -> n c d h w") elif len(x_shape) == 4: - n, ch, h, w = x_shape x = rearrange(x, "n c h w -> n h w c") x = F.layer_norm(x, [ch]) x = rearrange(x, "n h w c -> n c h w") diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index bd3b94b9de..a0a1fd46b1 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -443,6 +443,12 @@ def _build_and_save(self, model, input_example): enabled_precisions=enabled_precisions, **export_args) else: + if self.export_method == 'onnx_dynamo': + dynamo = True + import torch_onnx + torch_onnx.patch_torch() + else: + dynamo = False # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: onnx_path = Path(tmpdir) / 'model.onnx' @@ -453,7 +459,7 @@ def _build_and_save(self, model, input_example): filename=str(onnx_path), input_names=self.input_names, output_names=self.output_names, - dynamo=self.export_method == 'onnx_dynamo', + dynamo=dynamo, **export_args, ) LOGGER.info("Export to ONNX successful.") diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 61f7603dce..7edcf300cb 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -608,6 +608,7 @@ def convert_to_onnx( atol: float = 0.0, use_trace: bool = True, do_constant_folding: bool = True, + constant_size_threshold: int = 16 * 1024 * 1024 * 1024, dynamo=False, **kwargs, ): @@ -636,6 +637,7 @@ def convert_to_onnx( atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. use_trace: whether to use `torch.jit.trace` to export the torchscript model. do_constant_folding: passed to onnx.export(). If True, extra polygraphy folding pass is done. + constant_size_threshold: passed to polygrapy conatant forling, default = 16M dynamo: passed to onnx.export(). [When dynamo export API is finalized] kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export() else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details: @@ -661,7 +663,7 @@ def convert_to_onnx( del kwargs["example_outputs"] mode_to_export = torch.jit.script(model, **kwargs) - if torch.is_tensor(inputs): + if not isinstance(inputs, tuple): inputs = (inputs,) if filename is None: @@ -687,9 +689,8 @@ def convert_to_onnx( onnx_model = onnx.load(filename) if do_constant_folding and polygraphy_imported: - from polygraphy.backend.onnx import fold_constants - - fold_constants(onnx_model) + from polygraphy.backend.onnx.loader import fold_constants + fold_constants(onnx_model, size_threshold=constant_size_threshold) if verify: if device is None: From 5c495b60bc045f4002413d67e9adca95dd193c9f Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 20 Aug 2024 13:40:09 -0700 Subject: [PATCH 36/85] Logger refactoring Signed-off-by: Boris Fomitchev --- monai/handlers/trt_handler.py | 5 +++-- monai/networks/trt_wrapper.py | 29 ++++++++++++++++------------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/monai/handlers/trt_handler.py b/monai/handlers/trt_handler.py index 4c42c38bdb..d924b8582e 100644 --- a/monai/handlers/trt_handler.py +++ b/monai/handlers/trt_handler.py @@ -29,7 +29,8 @@ class TrtHandler: TrtHandler acts as an Ignite handler to apply TRT acceleration to the model. Usage example:: handler = TrtHandler(model=model, path="/test/checkpoint.pt", args={"precision": "fp16"}) - handler(trainer) + handler.attach(engine) + engine.run() Args: path: the file path of checkpoint, it should be a PyTorch `pth` file. @@ -68,5 +69,5 @@ def __call__(self, engine: Engine) -> None: """ if self.enabled: for submodule in self.submodules: - trt_wrap(self.model, self.path, args=self.args, submodule=submodule) + trt_wrap(self.model, self.path, args=self.args, submodule=submodule, logger=self.logger) self.logger.info(f"Created TRT wrapper for {self.path}.{submodule}") diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index a0a1fd46b1..7cbb89c6d7 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -40,7 +40,6 @@ torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0") cudart, _ = optional_import("cuda.cudart") -LOGGER = get_logger("trt_wrapper") lock_sm = threading.Lock() @@ -104,12 +103,13 @@ class TRTEngine: """ - def __init__(self, engine_path): + def __init__(self, engine_path, logger=None): """ Loads serialized engine, creates execution context and activates it """ self.engine_path = engine_path - LOGGER.info(f"Loading TensorRT engine: {self.engine_path}") + self.logger = logger or get_logger("trt_wrapper") + self.logger.info(f"Loading TensorRT engine: {self.engine_path}") self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) self.tensors = OrderedDict() self.cuda_graph_instance = None # cuda graph @@ -208,7 +208,7 @@ def infer(self, stream, use_cuda_graph=False): self.context.execute_async_v3(stream) graph = cuassert(cudart.cudaStreamEndCapture(stream)) self.cuda_graph_instance = cuassert(cudart.cudaGraphInstantiate(graph, 0)) - LOGGER.info("CUDA Graph captured!") + self.logger.info("CUDA Graph captured!") else: noerror = self.context.execute_async_v3(stream) cuassert(cudart.cudaStreamSynchronize(stream)) @@ -241,6 +241,7 @@ def __init__( use_cuda_graph=False, timestamp=None, fallback=False, + logger=None, ): """ Initialization method: @@ -277,6 +278,8 @@ def __init__( self.fallback = fallback self.disabled = False + self.logger = logger or get_logger("trt_wrapper") + # Normally we read input_names from forward() but can be overridden if input_names is None: argspec = inspect.getfullargspec(model.forward) @@ -312,10 +315,10 @@ def _load_engine(self): Loads TRT plan from disk and activates its execution context. """ try: - self.engine = TRTEngine(self.engine_path) + self.engine = TRTEngine(self.engine_path, self.logger) self.input_names = self.engine.input_names except Exception as e: - LOGGER.debug(f"Exception while loading the engine:\n{e}") + self.logger.debug(f"Exception while loading the engine:\n{e}") def forward(self, model, argv, kwargs): """ @@ -343,7 +346,7 @@ def forward(self, model, argv, kwargs): self._load_engine() except Exception as e: if self.fallback: - LOGGER.info(f"Failed to build engine: {e}") + self.logger.info(f"Failed to build engine: {e}") self.disabled = True else: raise e @@ -373,7 +376,7 @@ def forward(self, model, argv, kwargs): return ret except Exception as e: if model is not None: - LOGGER.info(f"Exception: {e}\nFalling back to Pytorch ...") + self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...") else: raise e return self.old_forward(model, *argv, **kwargs) @@ -400,7 +403,7 @@ def _onnx_to_trt(self, onnx_path): build_args["fp16"] = self.precision == "fp16" build_args["bf16"] = self.precision == "bf16" - LOGGER.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) @@ -452,7 +455,7 @@ def _build_and_save(self, model, input_example): # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: onnx_path = Path(tmpdir) / 'model.onnx' - LOGGER.info(f"Exporting to {onnx_path}, export args: {export_args}") + self.logger.info(f"Exporting to {onnx_path}, export args: {export_args}") convert_to_onnx( model, input_example, @@ -462,7 +465,7 @@ def _build_and_save(self, model, input_example): dynamo=dynamo, **export_args, ) - LOGGER.info("Export to ONNX successful.") + self.logger.info("Export to ONNX successful.") engine_bytes = self._onnx_to_trt(str(onnx_path)) open(self.engine_path, 'wb').write(engine_bytes) @@ -476,7 +479,7 @@ def trt_forward(self, *argv, **kwargs): return self._trt_wrapper.forward(self, argv, kwargs) -def trt_wrap(model, path, args=None, submodule=None): +def trt_wrap(model, path, args=None, submodule=None, logger=None): """ Instruments model or submodule with TrtWrappper and reppaces forward() with a hook. Args: @@ -515,7 +518,7 @@ def find_sub(parent, submodule): parent, submodule = find_sub(model, submodule) model = getattr(parent, submodule) - wrapper = TrtWrappper(model, path, **args) + wrapper = TrtWrappper(model, path, logger=logger, **args) model._trt_wrapper = wrapper model.forward = MethodType(trt_forward, model) return orig_model From 48b85ce392b3239027c3b0a2a7d3e0d136fed492 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 21 Aug 2024 18:49:24 -0700 Subject: [PATCH 37/85] Addressing code review comments Signed-off-by: Boris Fomitchev --- monai/networks/nets/swin_unetr.py | 4 +- monai/networks/trt_wrapper.py | 66 ++++++++++++++++--------------- 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index ed0c6cd9f1..714d986f4b 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -320,8 +320,8 @@ def _check_input_size(self, spatial_shape): ) def forward(self, x_in): - # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - # self._check_input_size(x_in.shape[2:]) + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + self._check_input_size(x_in.shape[2:]) hidden_states_out = self.swinViT(x_in, self.normalize) enc0 = self.encoder1(x_in) enc1 = self.encoder2(hidden_states_out[0]) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 7cbb89c6d7..632fe643c9 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -103,14 +103,14 @@ class TRTEngine: """ - def __init__(self, engine_path, logger=None): + def __init__(self, plan_path, logger=None): """ Loads serialized engine, creates execution context and activates it """ - self.engine_path = engine_path + self.plan_path = plan_path self.logger = logger or get_logger("trt_wrapper") - self.logger.info(f"Loading TensorRT engine: {self.engine_path}") - self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + self.logger.info(f"Loading TensorRT engine: {self.plan_path}") + self.engine = engine_from_bytes(bytes_from_path(self.plan_path)) self.tensors = OrderedDict() self.cuda_graph_instance = None # cuda graph self.context = self.engine.create_execution_context() @@ -229,8 +229,8 @@ class TrtWrappper: def __init__( self, model, - path, - precision="tf32", + plan_path, + precision="fp16", export_method="onnx", input_names=None, output_names=None, @@ -249,7 +249,7 @@ def __init__( Saves its arguments for lazy TRT build on first forward() call Args: model: Model to "wrap". - path : Path where to save persistent serialized TRT engine. + plan_path : Path where to save persistent serialized TRT engine. precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. export_method: One of 'onnx'|'onnx_dynamo'|'torch_trt'. input_names: Optional list of input names to use for export. @@ -265,7 +265,7 @@ def __init__( timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes). fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile). """ - self.path = path + self.plan_path = plan_path self.precision = precision self.export_method = export_method self.output_names = output_names or [] @@ -290,18 +290,10 @@ def __init__( # Force engine rebuild if older than the timestamp if ( timestamp is not None - and os.path.exists(self.engine_path) - and os.path.getmtime(self.engine_path) < timestamp + and os.path.exists(self.plan_path) + and os.path.getmtime(self.plan_path) < timestamp ): - os.remove(self.engine_path) - - """ - Auxiliary getters/setters - """ - - @property - def engine_path(self): - return self.path + ".plan" + os.remove(self.plan_path) def _inputs_to_dict(self, input_example): trt_inputs = {} @@ -315,7 +307,7 @@ def _load_engine(self): Loads TRT plan from disk and activates its execution context. """ try: - self.engine = TRTEngine(self.engine_path, self.logger) + self.engine = TRTEngine(self.plan_path, self.logger) self.input_names = self.engine.input_names except Exception as e: self.logger.debug(f"Exception while loading the engine:\n{e}") @@ -383,7 +375,7 @@ def forward(self, model, argv, kwargs): def _onnx_to_trt(self, onnx_path): """ - Builds TRT engine from ONNX file at onnx_path and saves to self.trt_path + Builds TRT engine from ONNX file at onnx_path and saves to self.plan_path """ profiles = [] @@ -403,7 +395,7 @@ def _onnx_to_trt(self, onnx_path): build_args["fp16"] = self.precision == "fp16" build_args["bf16"] = self.precision == "bf16" - self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}") network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) @@ -468,7 +460,7 @@ def _build_and_save(self, model, input_example): self.logger.info("Export to ONNX successful.") engine_bytes = self._onnx_to_trt(str(onnx_path)) - open(self.engine_path, 'wb').write(engine_bytes) + open(self.plan_path, 'wb').write(engine_bytes) def trt_forward(self, *argv, **kwargs): @@ -479,11 +471,12 @@ def trt_forward(self, *argv, **kwargs): return self._trt_wrapper.forward(self, argv, kwargs) -def trt_wrap(model, path, args=None, submodule=None, logger=None): +def trt_wrap(model, ckpt_path, args=None, submodule=None, logger=None): """ Instruments model or submodule with TrtWrappper and reppaces forward() with a hook. Args: - model, path: passed to TrtWrappper(). + model: passed to TrtWrappper(). + ckpt_path: path to associated checkpoint, or just basename for .plan args: dict : unpacked and passed to TrtWrappper(). submodule : Hierarchical id of submodule to convert, e.g. 'image_decoder.decoder' If None, TrtWrappper is applied to the whole model and returned. @@ -491,14 +484,25 @@ def trt_wrap(model, path, args=None, submodule=None, logger=None): Returns: Always returns same model passed in as argument. This is for ease of use in configs. """ - if args is None: - args = {} + + default_args = { + "export_method": "onnx", + "precision": "fp16", + "build_args": { + "builder_optimization_level": 5, + "precision_constraints": "obey" + } + } + + default_args.update(args or {}) + args = default_args + orig_model = model if trt_imported and polygraphy_imported and torch.cuda.is_available(): # if "path" filename point to existing file (e.g. checkpoint) # it's also treated as dependency - if os.path.exists(path): - timestamp = os.path.getmtime(path) + if os.path.exists(ckpt_path): + timestamp = os.path.getmtime(ckpt_path) if 'timestamp' in args: timestamp = max(args['timestamp'], timestamp) args['timestamp'] = timestamp @@ -514,11 +518,11 @@ def find_sub(parent, submodule): return find_sub(parent, submodule) return parent, submodule - path = path + '.' + submodule + ckpt_path = ckpt_path + '.' + submodule parent, submodule = find_sub(model, submodule) model = getattr(parent, submodule) - wrapper = TrtWrappper(model, path, logger=logger, **args) + wrapper = TrtWrappper(model, ckpt_path + ".plan", logger=logger, **args) model._trt_wrapper = wrapper model.forward = MethodType(trt_forward, model) return orig_model From 1244c4965a80d199bbfe5bab4a0e10ceb1776a9b Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 21 Aug 2024 21:40:45 -0700 Subject: [PATCH 38/85] Added multiple submodules option to trt_wrap Signed-off-by: Boris Fomitchev --- monai/networks/trt_wrapper.py | 52 +++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 632fe643c9..b8de3d4184 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -475,12 +475,12 @@ def trt_wrap(model, ckpt_path, args=None, submodule=None, logger=None): """ Instruments model or submodule with TrtWrappper and reppaces forward() with a hook. Args: - model: passed to TrtWrappper(). + model: module to patch with TrtWrappper(). ckpt_path: path to associated checkpoint, or just basename for .plan args: dict : unpacked and passed to TrtWrappper(). - submodule : Hierarchical id of submodule to convert, e.g. 'image_decoder.decoder' - If None, TrtWrappper is applied to the whole model and returned. - Otherwise, submodule is replaced in-place with TrtWrappper. + submodule : Hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder'] + If None, TrtWrappper patch is applied to the whole model. + Otherwise, submodule (or list of) is being patched. Returns: Always returns same model passed in as argument. This is for ease of use in configs. """ @@ -497,7 +497,6 @@ def trt_wrap(model, ckpt_path, args=None, submodule=None, logger=None): default_args.update(args or {}) args = default_args - orig_model = model if trt_imported and polygraphy_imported and torch.cuda.is_available(): # if "path" filename point to existing file (e.g. checkpoint) # it's also treated as dependency @@ -507,22 +506,27 @@ def trt_wrap(model, ckpt_path, args=None, submodule=None, logger=None): timestamp = max(args['timestamp'], timestamp) args['timestamp'] = timestamp - if submodule: - def find_sub(parent, submodule): - idx = submodule.find(".") - # if there is "." in name, call recursively - if idx != -1: - parent_name = submodule[:idx] - parent = getattr(parent, parent_name) - submodule = submodule[idx + 1 :] - return find_sub(parent, submodule) - return parent, submodule - - ckpt_path = ckpt_path + '.' + submodule - parent, submodule = find_sub(model, submodule) - model = getattr(parent, submodule) - - wrapper = TrtWrappper(model, ckpt_path + ".plan", logger=logger, **args) - model._trt_wrapper = wrapper - model.forward = MethodType(trt_forward, model) - return orig_model + def wrap(model, path): + wrapper = TrtWrappper(model, path + ".plan", logger=logger, **args) + model._trt_wrapper = wrapper + model.forward = MethodType(trt_forward, model) + + def find_sub(parent, submodule): + idx = submodule.find(".") + # if there is "." in name, call recursively + if idx != -1: + parent_name = submodule[:idx] + parent = getattr(parent, parent_name) + submodule = submodule[idx + 1 :] + return find_sub(parent, submodule) + return parent, submodule + + if submodule is not None: + if isinstance(submodule, str): + submodule = [submodule] + for s in submodule: + parent, submodule = find_sub(model, s) + wrap(getattr(parent, submodule), ckpt_path + '.' + s) + else: + wrap(model, ckpt_path) + return model From a603f13640d95d49bbdf25e2f04076d90df77ced Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 22 Aug 2024 21:44:56 -0700 Subject: [PATCH 39/85] Added polygraphy to more places, torch-tensorrt option debugging Signed-off-by: Boris Fomitchev --- Dockerfile | 13 ++----------- docs/requirements.txt | 1 + monai/networks/trt_wrapper.py | 8 ++++++-- setup.cfg | 3 +++ 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/Dockerfile b/Dockerfile index 2580d80681..7d048ea628 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,8 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. -ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.03-py3 +# ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.03-py3 +ARG PYTORCH_IMAGE=gitlab-master.nvidia.com:5005/dl/dgx/pytorch:24.08-py3-stage FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" @@ -24,16 +25,6 @@ RUN if [[ $(uname -m) =~ "aarch64" ]]; then \ pip install numcodecs; \ fi -ARG TRT_URL=http://cuda-repo.nvidia.com/release-candidates/Libraries/TensorRT/v10.3/10.3.0.25-4abf3f29/12.5-r555/Ubuntu22_04-x64-manylinux_2_17/deb/ - -RUN rm -fr /tmp/trt && mkdir -p /tmp/trt && cd /tmp/trt && \ - curl ${TRT_URL} -o index.html && \ - for package in $(grep -o '[^ >"]*\.deb' index.html | uniq); do wget -nv ${TRT_URL}${package} & done && wait \ - && rm -f *-dev_* tensorrt_10* *-samples* \ - && dpkg -i *.deb \ - && apt-get --fix-broken install -y \ - && rm -rf index.html *.deb - WORKDIR /opt/monai # install full deps diff --git a/docs/requirements.txt b/docs/requirements.txt index ff94f7b6de..7307d8e5f9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -42,3 +42,4 @@ zarr huggingface_hub pyamg>=5.0.0 packaging +polygraphy diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index b8de3d4184..e11905e5d2 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -22,7 +22,7 @@ from types import MethodType from monai.apps.utils import get_logger -from monai.networks.utils import add_casts_around_norms, convert_to_onnx +from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript from monai.utils.module import optional_import polygraphy, polygraphy_imported = optional_import("polygraphy") @@ -433,8 +433,12 @@ def _build_and_save(self, model, input_example): enabled_precisions.append(torch.float16) elif self.precision == "bf16": enabled_precisions.append(torch.bfloat16) + inputs=list(input_example.values()) + ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True) engine_bytes = torch_tensorrt.convert_method_to_trt_engine(model, - input_example, + 'forward', + inputs=inputs, + ir='torchscript', enabled_precisions=enabled_precisions, **export_args) else: diff --git a/setup.cfg b/setup.cfg index 1ce4a3f34c..c97118d43a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -160,6 +160,9 @@ lpips = lpips==0.1.4 pynvml = nvidia-ml-py +polygraphy = + polygraphy + # # workaround https://github.com/Project-MONAI/MONAI/issues/5882 # MetricsReloaded = # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded From f5be0ccc14f7b9a0214fbd6f0a450453435bf679 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 22 Aug 2024 22:30:42 -0700 Subject: [PATCH 40/85] Renamed trt_wrap -> trt_compile Signed-off-by: Boris Fomitchev --- monai/handlers/trt_handler.py | 16 +++++-------- monai/networks/__init__.py | 2 +- .../{trt_wrapper.py => trt_compile.py} | 22 ++++++++--------- ...est_trt_wrapper.py => test_trt_compile.py} | 24 +++++++++---------- 4 files changed, 30 insertions(+), 34 deletions(-) rename monai/networks/{trt_wrapper.py => trt_compile.py} (97%) rename tests/{test_trt_wrapper.py => test_trt_compile.py} (79%) diff --git a/monai/handlers/trt_handler.py b/monai/handlers/trt_handler.py index d924b8582e..70d2379f1a 100644 --- a/monai/handlers/trt_handler.py +++ b/monai/handlers/trt_handler.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING from monai.config import IgniteInfo -from monai.networks import trt_wrap +from monai.networks import trt_compile from monai.utils import min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") @@ -34,10 +34,8 @@ class TrtHandler: Args: path: the file path of checkpoint, it should be a PyTorch `pth` file. - args: dict : unpacked and passed to TrtWrapper(). - submodules : Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder' - If None, TrtWrapper is applied to the whole model and returned. - Otherwise, submodules are replaced in-place with TrtWrappers. + args: passed to trt_compile(). + submodule : Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder' """ def __init__( @@ -45,14 +43,14 @@ def __init__( model, path, args=None, - submodules=None, + submodule=None, enabled=True ): self.model = model self.path = path self.args = args self.enabled = enabled - self.submodules = submodules or [""] + self.submodule = submodule def attach(self, engine: Engine) -> None: """ @@ -68,6 +66,4 @@ def __call__(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ if self.enabled: - for submodule in self.submodules: - trt_wrap(self.model, self.path, args=self.args, submodule=submodule, logger=self.logger) - self.logger.info(f"Created TRT wrapper for {self.path}.{submodule}") + trt_compile(self.model, self.path, args=self.args, submodule=self.submodule, logger=self.logger) diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index f2dea8d491..3d09cbb33c 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -34,4 +34,4 @@ train_mode, ) -from .trt_wrapper import trt_wrap +from .trt_compile import trt_compile diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_compile.py similarity index 97% rename from monai/networks/trt_wrapper.py rename to monai/networks/trt_compile.py index e11905e5d2..9ec76368d1 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_compile.py @@ -108,7 +108,7 @@ def __init__(self, plan_path, logger=None): Loads serialized engine, creates execution context and activates it """ self.plan_path = plan_path - self.logger = logger or get_logger("trt_wrapper") + self.logger = logger or get_logger("trt_compile") self.logger.info(f"Loading TensorRT engine: {self.plan_path}") self.engine = engine_from_bytes(bytes_from_path(self.plan_path)) self.tensors = OrderedDict() @@ -231,7 +231,7 @@ def __init__( model, plan_path, precision="fp16", - export_method="onnx", + method="onnx", input_names=None, output_names=None, export_args=None, @@ -251,7 +251,7 @@ def __init__( model: Model to "wrap". plan_path : Path where to save persistent serialized TRT engine. precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. - export_method: One of 'onnx'|'onnx_dynamo'|'torch_trt'. + method: One of 'onnx'|'onnx_dynamo'|'torch_trt'. input_names: Optional list of input names to use for export. output_names: Optional list of output names to use for export. export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. @@ -267,7 +267,7 @@ def __init__( """ self.plan_path = plan_path self.precision = precision - self.export_method = export_method + self.method = method self.output_names = output_names or [] self.profiles = input_profiles or [] self.dynamic_batchsize = dynamic_batchsize @@ -278,7 +278,7 @@ def __init__( self.fallback = fallback self.disabled = False - self.logger = logger or get_logger("trt_wrapper") + self.logger = logger or get_logger("trt_compile") # Normally we read input_names from forward() but can be overridden if input_names is None: @@ -427,22 +427,22 @@ def _build_and_save(self, model, input_example): export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) add_casts_around_norms(model) - if self.export_method == 'torch_trt': + if self.method == 'torch_trt': enabled_precisions = [torch.float32] if self.precision == "fp16": enabled_precisions.append(torch.float16) elif self.precision == "bf16": enabled_precisions.append(torch.bfloat16) - inputs=list(input_example.values()) + inputs = list(input_example.values()) ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True) - engine_bytes = torch_tensorrt.convert_method_to_trt_engine(model, + engine_bytes = torch_tensorrt.convert_method_to_trt_engine(ir_model, 'forward', inputs=inputs, ir='torchscript', enabled_precisions=enabled_precisions, **export_args) else: - if self.export_method == 'onnx_dynamo': + if self.method == 'onnx_dynamo': dynamo = True import torch_onnx torch_onnx.patch_torch() @@ -475,7 +475,7 @@ def trt_forward(self, *argv, **kwargs): return self._trt_wrapper.forward(self, argv, kwargs) -def trt_wrap(model, ckpt_path, args=None, submodule=None, logger=None): +def trt_compile(model, ckpt_path, args=None, submodule=None, logger=None): """ Instruments model or submodule with TrtWrappper and reppaces forward() with a hook. Args: @@ -490,7 +490,7 @@ def trt_wrap(model, ckpt_path, args=None, submodule=None, logger=None): """ default_args = { - "export_method": "onnx", + "method": "onnx", "precision": "fp16", "build_args": { "builder_optimization_level": 5, diff --git a/tests/test_trt_wrapper.py b/tests/test_trt_compile.py similarity index 79% rename from tests/test_trt_wrapper.py rename to tests/test_trt_compile.py index a5ae5bfc09..3df30da9b9 100644 --- a/tests/test_trt_wrapper.py +++ b/tests/test_trt_compile.py @@ -21,10 +21,10 @@ from monai.utils import optional_import from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows -trt_wrap, has_trtwrapper = optional_import( +trt_compile, has_trt = optional_import( "monai.networks", - name="trt_wrap", - descriptor="TRT wrapper is not available - check your installation!", + name="trt_compile", + descriptor="TRT compile is not available - check your installation!", ) TEST_CASE_1 = ["fp32"] @@ -34,7 +34,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick -class TestTRTWrapper(unittest.TestCase): +class TestTRTCompile(unittest.TestCase): def setUp(self): self.gpu_device = torch.cuda.current_device() @@ -45,7 +45,7 @@ def tearDown(self): torch.cuda.set_device(self.gpu_device) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - @unittest.skipUnless(has_trtwrapper, "TensorRT wrapper is required for convert!") + @unittest.skipUnless(has_trt, "TensorRT compile wrapper is required for convert!") def test_value(self, precision): model = UNet( spatial_dims=3, @@ -61,13 +61,13 @@ def test_value(self, precision): input_example = torch.randn(1, 1, 96, 96, 96).cuda() output_example = model(input_example) args: dict = {"builder_optimization_level": 1} - trt_wrap(model, - f"{tmpdir}/test_wrapper", - args={"precision": precision, - "build_args": args, - "dynamic_batchsize": [1, 4, 8] - } - ) + trt_compile(model, + f"{tmpdir}/test_trt_compile", + args={"precision": precision, + "build_args": args, + "dynamic_batchsize": [1, 4, 8] + } + ) self.assertIsNone(model._trt_wrapper.engine) trt_output = model(input_example) # Check that lazy TRT build succeeded From b96ebb475cf7a0d24ce72b017326d6098ffa66f8 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 22 Aug 2024 23:34:14 -0700 Subject: [PATCH 41/85] Reformatted for CI Signed-off-by: Boris Fomitchev --- monai/handlers/trt_handler.py | 9 +------ monai/networks/__init__.py | 3 +-- monai/networks/trt_compile.py | 48 ++++++++++++++++------------------- monai/networks/utils.py | 1 + tests/test_trt_compile.py | 16 +++++------- 5 files changed, 31 insertions(+), 46 deletions(-) diff --git a/monai/handlers/trt_handler.py b/monai/handlers/trt_handler.py index 70d2379f1a..9085f0bff1 100644 --- a/monai/handlers/trt_handler.py +++ b/monai/handlers/trt_handler.py @@ -38,14 +38,7 @@ class TrtHandler: submodule : Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder' """ - def __init__( - self, - model, - path, - args=None, - submodule=None, - enabled=True - ): + def __init__(self, model, path, args=None, submodule=None, enabled=True): self.model = model self.path = path self.args = args diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 3d09cbb33c..bf229f4cf5 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .trt_compile import trt_compile from .utils import ( add_casts_around_norms, convert_to_onnx, @@ -33,5 +34,3 @@ to_norm_affine, train_mode, ) - -from .trt_compile import trt_compile diff --git a/monai/networks/trt_compile.py b/monai/networks/trt_compile.py index 9ec76368d1..47391d4ddb 100644 --- a/monai/networks/trt_compile.py +++ b/monai/networks/trt_compile.py @@ -13,13 +13,13 @@ import inspect import os -from pathlib import Path import tempfile import threading from collections import OrderedDict +from pathlib import Path +from types import MethodType import torch -from types import MethodType from monai.apps.utils import get_logger from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript @@ -31,8 +31,8 @@ from polygraphy.backend.trt import ( CreateConfig, Profile, - engine_from_bytes, engine_bytes_from_network, + engine_from_bytes, network_from_onnx_path, ) @@ -288,11 +288,7 @@ def __init__( self.old_forward = model.forward # Force engine rebuild if older than the timestamp - if ( - timestamp is not None - and os.path.exists(self.plan_path) - and os.path.getmtime(self.plan_path) < timestamp - ): + if timestamp is not None and os.path.exists(self.plan_path) and os.path.getmtime(self.plan_path) < timestamp: os.remove(self.plan_path) def _inputs_to_dict(self, input_example): @@ -427,7 +423,7 @@ def _build_and_save(self, model, input_example): export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) add_casts_around_norms(model) - if self.method == 'torch_trt': + if self.method == "torch_trt": enabled_precisions = [torch.float32] if self.precision == "fp16": enabled_precisions.append(torch.float16) @@ -435,22 +431,25 @@ def _build_and_save(self, model, input_example): enabled_precisions.append(torch.bfloat16) inputs = list(input_example.values()) ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True) - engine_bytes = torch_tensorrt.convert_method_to_trt_engine(ir_model, - 'forward', - inputs=inputs, - ir='torchscript', - enabled_precisions=enabled_precisions, - **export_args) + engine_bytes = torch_tensorrt.convert_method_to_trt_engine( + ir_model, + "forward", + inputs=inputs, + ir="torchscript", + enabled_precisions=enabled_precisions, + **export_args, + ) else: - if self.method == 'onnx_dynamo': + if self.method == "onnx_dynamo": dynamo = True import torch_onnx + torch_onnx.patch_torch() else: dynamo = False # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: - onnx_path = Path(tmpdir) / 'model.onnx' + onnx_path = Path(tmpdir) / "model.onnx" self.logger.info(f"Exporting to {onnx_path}, export args: {export_args}") convert_to_onnx( model, @@ -464,7 +463,7 @@ def _build_and_save(self, model, input_example): self.logger.info("Export to ONNX successful.") engine_bytes = self._onnx_to_trt(str(onnx_path)) - open(self.plan_path, 'wb').write(engine_bytes) + open(self.plan_path, "wb").write(engine_bytes) def trt_forward(self, *argv, **kwargs): @@ -492,10 +491,7 @@ def trt_compile(model, ckpt_path, args=None, submodule=None, logger=None): default_args = { "method": "onnx", "precision": "fp16", - "build_args": { - "builder_optimization_level": 5, - "precision_constraints": "obey" - } + "build_args": {"builder_optimization_level": 5, "precision_constraints": "obey"}, } default_args.update(args or {}) @@ -506,9 +502,9 @@ def trt_compile(model, ckpt_path, args=None, submodule=None, logger=None): # it's also treated as dependency if os.path.exists(ckpt_path): timestamp = os.path.getmtime(ckpt_path) - if 'timestamp' in args: - timestamp = max(args['timestamp'], timestamp) - args['timestamp'] = timestamp + if "timestamp" in args: + timestamp = max(args["timestamp"], timestamp) + args["timestamp"] = timestamp def wrap(model, path): wrapper = TrtWrappper(model, path + ".plan", logger=logger, **args) @@ -530,7 +526,7 @@ def find_sub(parent, submodule): submodule = [submodule] for s in submodule: parent, submodule = find_sub(model, s) - wrap(getattr(parent, submodule), ckpt_path + '.' + s) + wrap(getattr(parent, submodule), ckpt_path + "." + s) else: wrap(model, ckpt_path) return model diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 7edcf300cb..e439d53884 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -690,6 +690,7 @@ def convert_to_onnx( if do_constant_folding and polygraphy_imported: from polygraphy.backend.onnx.loader import fold_constants + fold_constants(onnx_model, size_threshold=constant_size_threshold) if verify: diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 3df30da9b9..1fdaa18c20 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -22,9 +22,7 @@ from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows trt_compile, has_trt = optional_import( - "monai.networks", - name="trt_compile", - descriptor="TRT compile is not available - check your installation!", + "monai.networks", name="trt_compile", descriptor="TRT compile is not available - check your installation!" ) TEST_CASE_1 = ["fp32"] @@ -61,13 +59,11 @@ def test_value(self, precision): input_example = torch.randn(1, 1, 96, 96, 96).cuda() output_example = model(input_example) args: dict = {"builder_optimization_level": 1} - trt_compile(model, - f"{tmpdir}/test_trt_compile", - args={"precision": precision, - "build_args": args, - "dynamic_batchsize": [1, 4, 8] - } - ) + trt_compile( + model, + f"{tmpdir}/test_trt_compile", + args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]}, + ) self.assertIsNone(model._trt_wrapper.engine) trt_output = model(input_example) # Check that lazy TRT build succeeded From 85140e265299bd9059909bb1fd8faa18d28fbabb Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 23 Aug 2024 00:05:21 -0700 Subject: [PATCH 42/85] Fixed alias issue Signed-off-by: Boris Fomitchev --- monai/networks/__init__.py | 2 +- monai/networks/{trt_compile.py => trt_wrapper.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename monai/networks/{trt_compile.py => trt_wrapper.py} (100%) diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index bf229f4cf5..6c424ddfb9 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -11,7 +11,7 @@ from __future__ import annotations -from .trt_compile import trt_compile +from .trt_wrapper import trt_compile from .utils import ( add_casts_around_norms, convert_to_onnx, diff --git a/monai/networks/trt_compile.py b/monai/networks/trt_wrapper.py similarity index 100% rename from monai/networks/trt_compile.py rename to monai/networks/trt_wrapper.py From fa4c182b8e8f73fb8d7740314c5e35a18218ea14 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 23 Aug 2024 00:19:26 -0700 Subject: [PATCH 43/85] Fixed base in Dockerfile Signed-off-by: Boris Fomitchev --- Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 7d048ea628..7ad3527ad1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,8 +11,7 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. -# ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.03-py3 -ARG PYTORCH_IMAGE=gitlab-master.nvidia.com:5005/dl/dgx/pytorch:24.08-py3-stage +ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.07-py3 FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" From 78a3ef308da5a84db02d721cd9c67bcc582fc1c8 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 23 Aug 2024 01:43:38 -0700 Subject: [PATCH 44/85] Fixed CI test failures Signed-off-by: Boris Fomitchev --- monai/networks/trt_wrapper.py | 2 +- monai/networks/utils.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 47391d4ddb..1558b7cbef 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -367,7 +367,7 @@ def forward(self, model, argv, kwargs): self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...") else: raise e - return self.old_forward(model, *argv, **kwargs) + return self.old_forward(*argv, **kwargs) def _onnx_to_trt(self, onnx_path): """ diff --git a/monai/networks/utils.py b/monai/networks/utils.py index e439d53884..7fd07064c1 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -663,8 +663,10 @@ def convert_to_onnx( del kwargs["example_outputs"] mode_to_export = torch.jit.script(model, **kwargs) - if not isinstance(inputs, tuple): - inputs = (inputs,) + if torch.is_tensor(inputs) or isinstance(inputs, dict): + onnx_inputs = (inputs,) + else: + onnx_inputs = tuple(inputs) if filename is None: f = io.BytesIO() @@ -673,7 +675,7 @@ def convert_to_onnx( torch.onnx.export( mode_to_export, - inputs, + onnx_inputs, f=f, input_names=input_names, output_names=output_names, From 267c12546fc3fa23b490eb97690622bb8b2081d4 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 23 Aug 2024 12:34:15 -0700 Subject: [PATCH 45/85] Addressed code review comments Signed-off-by: Boris Fomitchev --- Dockerfile | 1 + monai/handlers/trt_handler.py | 21 ++++++------ monai/networks/trt_wrapper.py | 60 +++++++++++++++++++---------------- 3 files changed, 44 insertions(+), 38 deletions(-) diff --git a/Dockerfile b/Dockerfile index 7ad3527ad1..62abc268e8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -56,4 +56,5 @@ RUN apt-get update \ && rm -rf /var/lib/apt/lists/* # append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations ENV PATH=${PATH}:/opt/tools +ENV POLYGRAPHY_AUTOINSTALL_DEPS=1 WORKDIR /opt/monai diff --git a/monai/handlers/trt_handler.py b/monai/handlers/trt_handler.py index 9085f0bff1..0e36b59d8c 100644 --- a/monai/handlers/trt_handler.py +++ b/monai/handlers/trt_handler.py @@ -28,21 +28,21 @@ class TrtHandler: """ TrtHandler acts as an Ignite handler to apply TRT acceleration to the model. Usage example:: - handler = TrtHandler(model=model, path="/test/checkpoint.pt", args={"precision": "fp16"}) + handler = TrtHandler(model=model, base_path="/test/checkpoint.pt", args={"precision": "fp16"}) handler.attach(engine) engine.run() - - Args: - path: the file path of checkpoint, it should be a PyTorch `pth` file. - args: passed to trt_compile(). - submodule : Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder' """ - def __init__(self, model, path, args=None, submodule=None, enabled=True): + def __init__(self, model, base_path, args=None, submodule=None): + """ + Args: + base_path: TRT path basename. TRT plan(s) saved to "base_path[.submodule].plan" + args: passed to trt_compile(). See trt_compile() for details. + submodule : Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder' + """ self.model = model - self.path = path + self.base_path = base_path self.args = args - self.enabled = enabled self.submodule = submodule def attach(self, engine: Engine) -> None: @@ -58,5 +58,4 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - if self.enabled: - trt_compile(self.model, self.path, args=self.args, submodule=self.submodule, logger=self.logger) + trt_compile(self.model, self.base_path, args=self.args, submodule=self.submodule, logger=self.logger) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 1558b7cbef..9551ee03e2 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -18,6 +18,7 @@ from collections import OrderedDict from pathlib import Path from types import MethodType +from typing import Union, List, Dict, Any import torch @@ -59,8 +60,9 @@ def trt_to_torch_dtype_dict(): def get_dynamic_axes(profiles): """ - Given [[min,opt,max],...] list of profile dimensions, - this method calculates dynamic_axes to use in onnx.export() + This method calculates dynamic_axes to use in onnx.export(). + Args: + profiles: [[min,opt,max],...] list of profile dimensions """ dynamic_axes: dict[str, list[int]] = {} if not profiles: @@ -79,7 +81,9 @@ def get_dynamic_axes(profiles): def cuassert(cuda_ret): """ - Error reporting method for CUDA calls + Error reporting method for CUDA calls. + Args: + cuda_ret: CUDA return code. """ err = cuda_ret[0] if err != 0: @@ -106,6 +110,9 @@ class TRTEngine: def __init__(self, plan_path, logger=None): """ Loads serialized engine, creates execution context and activates it + Args: + plan_path: path to serialized TRT engine. + logger: optional logger object """ self.plan_path = plan_path self.logger = logger or get_logger("trt_compile") @@ -131,6 +138,8 @@ def __init__(self, plan_path, logger=None): def allocate_buffers(self, device): """ Allocates outputs to run TRT engine + Args: + device: GPU device to allocate memory on """ ctx = self.context @@ -141,20 +150,12 @@ def allocate_buffers(self, device): self.tensors[binding] = t ctx.set_tensor_address(binding, t.data_ptr()) - @staticmethod - def check_shape(shape, profile): - shape = list(shape) - minlist = profile[0] - maxlist = profile[2] - good = True - for i, s in enumerate(shape): - if s < minlist[i] or s > maxlist[i]: - good = False - return good - def set_inputs(self, feed_dict, stream): """ Sets input bindings for TRT engine according to feed_dict + Args: + feed_dict: a dictionary [str->Tensor] + stream: CUDA stream to use """ e = self.engine ctx = self.context @@ -166,10 +167,6 @@ def try_set_inputs(): if t is not None: t = t.contiguous() shape = t.shape - # TODO: port to new TRT10 API - # mincurmax = list(e.get_profile_shape(self.cur_profile, binding)) - # if not self.check_shape(shape, mincurmax): - # raise ShapeError(f"Input shape to be set is outside the bounds: {binding} -> {shape}") ctx.set_input_shape(binding, shape) ctx.set_tensor_address(binding, t.data_ptr()) @@ -190,7 +187,9 @@ def try_set_inputs(): def infer(self, stream, use_cuda_graph=False): """ Runs TRT engine. - Note use_cuda_graph requires all inputs to be the same GPU memory between calls. + Args: + stream: CUDA stream to run on + use_cuda_graph: use CUDA graph. Note: requires all inputs to be the same GPU memory between calls. """ if use_cuda_graph: if self.cuda_graph_instance is not None: @@ -474,12 +473,19 @@ def trt_forward(self, *argv, **kwargs): return self._trt_wrapper.forward(self, argv, kwargs) -def trt_compile(model, ckpt_path, args=None, submodule=None, logger=None): +def trt_compile( + model: torch.nn.Module, + base_path: str, + args: Dict[str, Any] | None = None, + submodule: Union[str, List[str]] | None = None, + logger: Any | None = None) -> torch.nn.Module: """ Instruments model or submodule with TrtWrappper and reppaces forward() with a hook. Args: model: module to patch with TrtWrappper(). - ckpt_path: path to associated checkpoint, or just basename for .plan + base_path: TRT plan(s) saved to "base_path[.submodule].plan" path. + If base_path points to existing file (e.g. associated checkpoint), + that file also becomes dependency - its mtime is added to args["timestamp"]. args: dict : unpacked and passed to TrtWrappper(). submodule : Hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder'] If None, TrtWrappper patch is applied to the whole model. @@ -488,7 +494,7 @@ def trt_compile(model, ckpt_path, args=None, submodule=None, logger=None): Always returns same model passed in as argument. This is for ease of use in configs. """ - default_args = { + default_args: Dict[str, Any] = { "method": "onnx", "precision": "fp16", "build_args": {"builder_optimization_level": 5, "precision_constraints": "obey"}, @@ -500,10 +506,10 @@ def trt_compile(model, ckpt_path, args=None, submodule=None, logger=None): if trt_imported and polygraphy_imported and torch.cuda.is_available(): # if "path" filename point to existing file (e.g. checkpoint) # it's also treated as dependency - if os.path.exists(ckpt_path): - timestamp = os.path.getmtime(ckpt_path) + if os.path.exists(base_path): + timestamp = int(os.path.getmtime(base_path)) if "timestamp" in args: - timestamp = max(args["timestamp"], timestamp) + timestamp = max(int(args["timestamp"]), timestamp) args["timestamp"] = timestamp def wrap(model, path): @@ -526,7 +532,7 @@ def find_sub(parent, submodule): submodule = [submodule] for s in submodule: parent, submodule = find_sub(model, s) - wrap(getattr(parent, submodule), ckpt_path + "." + s) + wrap(getattr(parent, s), base_path + "." + s) else: - wrap(model, ckpt_path) + wrap(model, base_path) return model From 9adc035cac7df697cdf2b01a8bf8c08f0cb83fee Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sun, 25 Aug 2024 22:47:05 -0700 Subject: [PATCH 46/85] Added dictionary return option Signed-off-by: Boris Fomitchev --- monai/networks/trt_wrapper.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 9551ee03e2..36d90e4fb6 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -251,8 +251,8 @@ def __init__( plan_path : Path where to save persistent serialized TRT engine. precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. method: One of 'onnx'|'onnx_dynamo'|'torch_trt'. - input_names: Optional list of input names to use for export. - output_names: Optional list of output names to use for export. + input_names: Optional list of input names. If None, will be read from the function signature. + output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. input_profiles: Optional list of profiles for TRT builder and ONNX export. @@ -356,10 +356,12 @@ def forward(self, model, argv, kwargs): # Need this to synchronize with Torch stream stream.wait_stream(torch.cuda.current_stream()) ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph) - ret = list(ret.values()) + # if output_names is not None, return dictionary + if self.output_names is None: + ret = list(ret.values()) + if len(ret) == 1: + ret = ret[0] - if len(ret) == 1: - ret = ret[0] return ret except Exception as e: if model is not None: @@ -422,6 +424,7 @@ def _build_and_save(self, model, input_example): export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) add_casts_around_norms(model) + if self.method == "torch_trt": enabled_precisions = [torch.float32] if self.precision == "fp16": @@ -449,7 +452,7 @@ def _build_and_save(self, model, input_example): # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: onnx_path = Path(tmpdir) / "model.onnx" - self.logger.info(f"Exporting to {onnx_path}, export args: {export_args}") + self.logger.info(f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}") convert_to_onnx( model, input_example, @@ -486,7 +489,7 @@ def trt_compile( base_path: TRT plan(s) saved to "base_path[.submodule].plan" path. If base_path points to existing file (e.g. associated checkpoint), that file also becomes dependency - its mtime is added to args["timestamp"]. - args: dict : unpacked and passed to TrtWrappper(). + args: dict : unpacked and passed to TrtWrappper() - see TrtWrapper above for details. submodule : Hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder'] If None, TrtWrappper patch is applied to the whole model. Otherwise, submodule (or list of) is being patched. From 7f1c0c194c394ec473c78079617616759d0868a1 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 26 Aug 2024 13:58:01 -0700 Subject: [PATCH 47/85] Fixed return_dict issue Signed-off-by: Boris Fomitchev --- monai/networks/trt_wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 36d90e4fb6..07e29ea559 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -267,6 +267,7 @@ def __init__( self.plan_path = plan_path self.precision = precision self.method = method + self.return_dict = (output_names is not None) self.output_names = output_names or [] self.profiles = input_profiles or [] self.dynamic_batchsize = dynamic_batchsize @@ -357,11 +358,10 @@ def forward(self, model, argv, kwargs): stream.wait_stream(torch.cuda.current_stream()) ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph) # if output_names is not None, return dictionary - if self.output_names is None: + if not self.return_dict: ret = list(ret.values()) if len(ret) == 1: ret = ret[0] - return ret except Exception as e: if model is not None: From a242a64008e1bce5ee9382ec5ac9b585d784e266 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 26 Aug 2024 16:41:41 -0700 Subject: [PATCH 48/85] Implemented https://github.com/Project-MONAI/MONAI/issues/8044 Signed-off-by: Boris Fomitchev --- monai/bundle/config_parser.py | 15 ++++++++++++++- monai/networks/trt_wrapper.py | 23 +++++++++++++---------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index a2ffeedc92..d23017e568 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -424,7 +424,20 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs files = files.split(",") for i in ensure_tuple(files): for k, v in (cls.load_config_file(i, **kwargs)).items(): - parser[k] = v + if k.startswith("+"): + id = k[1:] + if id in parser and v is not None: + if isinstance(v, dict) and isinstance(parser[id], dict): + parser[id].update(v) + elif isinstance(v, list) and isinstance(parser[id], list): + parser[id].extend(v) + else: + raise ValueError( + ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}.") + ) + else: + parser[k] = v + return parser.get() # type: ignore @classmethod diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index 07e29ea559..d62cde6e52 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -18,7 +18,7 @@ from collections import OrderedDict from pathlib import Path from types import MethodType -from typing import Union, List, Dict, Any +from typing import Any, Dict, List, Union import torch @@ -267,7 +267,7 @@ def __init__( self.plan_path = plan_path self.precision = precision self.method = method - self.return_dict = (output_names is not None) + self.return_dict = output_names is not None self.output_names = output_names or [] self.profiles = input_profiles or [] self.dynamic_batchsize = dynamic_batchsize @@ -452,7 +452,9 @@ def _build_and_save(self, model, input_example): # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: onnx_path = Path(tmpdir) / "model.onnx" - self.logger.info(f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}") + self.logger.info( + f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}" + ) convert_to_onnx( model, input_example, @@ -477,11 +479,12 @@ def trt_forward(self, *argv, **kwargs): def trt_compile( - model: torch.nn.Module, - base_path: str, - args: Dict[str, Any] | None = None, - submodule: Union[str, List[str]] | None = None, - logger: Any | None = None) -> torch.nn.Module: + model: torch.nn.Module, + base_path: str, + args: Dict[str, Any] | None = None, + submodule: Union[str, List[str]] | None = None, + logger: Any | None = None, +) -> torch.nn.Module: """ Instruments model or submodule with TrtWrappper and reppaces forward() with a hook. Args: @@ -534,8 +537,8 @@ def find_sub(parent, submodule): if isinstance(submodule, str): submodule = [submodule] for s in submodule: - parent, submodule = find_sub(model, s) - wrap(getattr(parent, s), base_path + "." + s) + parent, sub = find_sub(model, s) + wrap(getattr(parent, sub), base_path + "." + s) else: wrap(model, base_path) return model From 5afc91202a781f7a58fb4b0c5bfc500c6fb74b96 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 26 Aug 2024 23:10:04 -0700 Subject: [PATCH 49/85] Generalizing merge logic, adding test case and doc Signed-off-by: Boris Fomitchev --- docs/source/config_syntax.md | 26 ++++++++++++++++++++++++++ monai/bundle/config_parser.py | 19 ++++--------------- monai/bundle/scripts.py | 4 ++-- monai/bundle/utils.py | 30 +++++++++++++++++++++++++++++- tests/test_config_parser.py | 32 ++++++++++++++++++++++++++++++++ 5 files changed, 93 insertions(+), 18 deletions(-) diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index c932879b5a..bb29c453f7 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -16,6 +16,7 @@ Content: - [`$` to evaluate as Python expressions](#to-evaluate-as-python-expressions) - [`%` to textually replace configuration elements](#to-textually-replace-configuration-elements) - [`_target_` (`_disabled_`, `_desc_`, `_requires_`, `_mode_`) to instantiate a Python object](#instantiate-a-python-object) + - [`+` to alter semantics of merging config keys from multiple configuration files](#multiple-config-files) - [The command line interface](#the-command-line-interface) - [Recommendations](#recommendations) @@ -175,6 +176,31 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k - `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``, see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall). +## Multiple config files + +_Description:_ Multiple config files may be specified on the command line. +The content of those config files is being merged. When same keys are specifiled in more than one config file, +the value associated with the key is being overridden, in the order config files are specified. +If the desired behaviour is to merge values from both files, the key in second config file should be prefixed with '-'. +Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" list will me merged: +```json1 +{ + "amp": True + "imports": [ + "$import torch" + ], +} +``` + +```json2 +{ + "amp": False + "+imports": [ + "$from monai.networks import trt_compile" + ], +} +``` + ## The command line interface In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle. diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index d23017e568..1d9920a230 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -20,7 +20,7 @@ from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.bundle.reference_resolver import ReferenceResolver -from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY +from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY, merge_kv from monai.config import PathLike from monai.utils import ensure_tuple, look_up_option, optional_import from monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates @@ -423,20 +423,9 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs if isinstance(files, str) and not Path(files).is_file() and "," in files: files = files.split(",") for i in ensure_tuple(files): - for k, v in (cls.load_config_file(i, **kwargs)).items(): - if k.startswith("+"): - id = k[1:] - if id in parser and v is not None: - if isinstance(v, dict) and isinstance(parser[id], dict): - parser[id].update(v) - elif isinstance(v, list) and isinstance(parser[id], list): - parser[id].extend(v) - else: - raise ValueError( - ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}.") - ) - else: - parser[k] = v + config_dict = cls.load_config_file(i, **kwargs) + for k, v in config_dict.items(): + merge_kv(parser, k, v) return parser.get() # type: ignore diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 142a366669..f1d1286e4b 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -32,7 +32,7 @@ from monai.apps.utils import _basename, download_url, extractall, get_logger from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser -from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA +from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow from monai.config import IgniteInfo, PathLike from monai.data import load_net_with_metadata, save_net_with_metadata @@ -105,7 +105,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw if isinstance(v, dict) and isinstance(args_.get(k), dict): args_[k] = update_kwargs(args_[k], ignore_none, **v) else: - args_[k] = v + merge_kv(args_, k, v) return args_ diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py index 50d2608f4c..eb6445116b 100644 --- a/monai/bundle/utils.py +++ b/monai/bundle/utils.py @@ -21,12 +21,21 @@ yaml, _ = optional_import("yaml") -__all__ = ["ID_REF_KEY", "ID_SEP_KEY", "EXPR_KEY", "MACRO_KEY", "DEFAULT_MLFLOW_SETTINGS", "DEFAULT_EXP_MGMT_SETTINGS"] +__all__ = [ + "ID_REF_KEY", + "ID_SEP_KEY", + "EXPR_KEY", + "MACRO_KEY", + "MERGE_KEY", + "DEFAULT_MLFLOW_SETTINGS", + "DEFAULT_EXP_MGMT_SETTINGS", +] ID_REF_KEY = "@" # start of a reference to a ConfigItem ID_SEP_KEY = "::" # separator for the ID of a ConfigItem EXPR_KEY = "$" # start of a ConfigExpression MACRO_KEY = "%" # start of a macro of a config +MERGE_KEY = "+" # start of a macro of a config _conf_values = get_config_values() @@ -233,3 +242,22 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any parser.read_config(f=cdata) return parser + + +def merge_kv(args: dict | Any, k: str, v: Any) -> None: + """ + Update the `args` dict-like object with the key/value pair `k` and `v`. + """ + if k.startswith(MERGE_KEY): + id = k[1:] + if id in args and v is not None: + if isinstance(v, dict) and isinstance(args[id], dict): + args[id].update(v) + elif isinstance(v, list) and isinstance(args[id], list): + args[id].extend(v) + else: + raise ValueError(ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}.")) + else: + args[id] = v + else: + args[k] = v diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index cf1edc8f08..2b00c9f9d1 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -125,6 +125,22 @@ def __call__(self, a, b): [0, 4], ] +TEST_CASE_MERGE_JSON = ["""{"key1": [0], "key2": [0] }""", """{"key1": [1], "+key2": [4] }""", "json", [1], [0, 4]] + +TEST_CASE_MERGE_YAML = [ + """ + key1: 0 + key2: [0] + """, + """ + key1: 1 + +key2: [4] + """, + "yaml", + 1, + [0, 4], +] + class TestConfigParser(unittest.TestCase): @@ -357,6 +373,22 @@ def test_parse_json_warn(self, config_string, extension, expected_unique_val, ex self.assertEqual(parser.get_parsed_content("key#unique"), expected_unique_val) self.assertIn(parser.get_parsed_content("key#duplicate"), expected_duplicate_vals) + @parameterized.expand([TEST_CASE_MERGE_JSON, TEST_CASE_MERGE_YAML]) + @skipUnless(has_yaml, "Requires pyyaml") + def test_load_configs( + self, config_string, config_string2, extension, expected_overridden_val, expected_merged_vals + ): + with tempfile.TemporaryDirectory() as tempdir: + config_path1 = Path(tempdir) / f"config1.{extension}" + config_path2 = Path(tempdir) / f"config2.{extension}" + config_path1.write_text(config_string) + config_path2.write_text(config_string2) + + parser = ConfigParser.load_config_files([config_path1, config_path2]) + + self.assertEqual(parser["key1"], expected_overridden_val) + self.assertEqual(parser["key2"], expected_merged_vals) + if __name__ == "__main__": unittest.main() From ceff01843ea3f621ee10fc3359fa81fb9bd57d50 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 06:12:48 +0000 Subject: [PATCH 50/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/config_syntax.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index bb29c453f7..488ff825a3 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -16,7 +16,7 @@ Content: - [`$` to evaluate as Python expressions](#to-evaluate-as-python-expressions) - [`%` to textually replace configuration elements](#to-textually-replace-configuration-elements) - [`_target_` (`_disabled_`, `_desc_`, `_requires_`, `_mode_`) to instantiate a Python object](#instantiate-a-python-object) - - [`+` to alter semantics of merging config keys from multiple configuration files](#multiple-config-files) + - [`+` to alter semantics of merging config keys from multiple configuration files](#multiple-config-files) - [The command line interface](#the-command-line-interface) - [Recommendations](#recommendations) @@ -182,21 +182,21 @@ _Description:_ Multiple config files may be specified on the command line. The content of those config files is being merged. When same keys are specifiled in more than one config file, the value associated with the key is being overridden, in the order config files are specified. If the desired behaviour is to merge values from both files, the key in second config file should be prefixed with '-'. -Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" list will me merged: +Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" list will me merged: ```json1 { - "amp": True + "amp": True "imports": [ - "$import torch" + "$import torch" ], } ``` ```json2 { - "amp": False + "amp": False "+imports": [ - "$from monai.networks import trt_compile" + "$from monai.networks import trt_compile" ], } ``` From e294968f9d4e73a5ef8fa8400029214ea0e0df2c Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 27 Aug 2024 15:59:35 -0700 Subject: [PATCH 51/85] Addressing code review comments Signed-off-by: Boris Fomitchev --- docs/source/config_syntax.md | 8 +++++--- monai/bundle/utils.py | 10 ++++++++-- monai/networks/trt_wrapper.py | 2 +- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index bb29c453f7..10f445aa99 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -181,8 +181,10 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k _Description:_ Multiple config files may be specified on the command line. The content of those config files is being merged. When same keys are specifiled in more than one config file, the value associated with the key is being overridden, in the order config files are specified. -If the desired behaviour is to merge values from both files, the key in second config file should be prefixed with '-'. -Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" list will me merged: +If the desired behaviour is to merge values from both files, the key in second config file should be prefixed with `+`. +Both values associated with `+`-prefixed key pair must be of `dict` or `list` type. +`dict` values will be merged via update(), `list` values - concatenated. +Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" list will be merged: ```json1 { "amp": True @@ -196,7 +198,7 @@ Here's an example. In this case, "amp" value will be overridden by json2 config, { "amp": False "+imports": [ - "$from monai.networks import trt_compile" + "$from monai.networks import trt_compile" ], } ``` diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py index eb6445116b..53d619f234 100644 --- a/monai/bundle/utils.py +++ b/monai/bundle/utils.py @@ -13,6 +13,7 @@ import json import os +import warnings import zipfile from typing import Any @@ -35,7 +36,7 @@ ID_SEP_KEY = "::" # separator for the ID of a ConfigItem EXPR_KEY = "$" # start of a ConfigExpression MACRO_KEY = "%" # start of a macro of a config -MERGE_KEY = "+" # start of a macro of a config +MERGE_KEY = "+" # prefix indicating merge instead of override in case of multiple configs. _conf_values = get_config_values() @@ -249,8 +250,12 @@ def merge_kv(args: dict | Any, k: str, v: Any) -> None: Update the `args` dict-like object with the key/value pair `k` and `v`. """ if k.startswith(MERGE_KEY): + """ + Both values associated with `+`-prefixed key pair must be of `dict` or `list` type. + `dict` values will be merged, `list` values - concatenated. + """ id = k[1:] - if id in args and v is not None: + if id in args: if isinstance(v, dict) and isinstance(args[id], dict): args[id].update(v) elif isinstance(v, list) and isinstance(args[id], list): @@ -258,6 +263,7 @@ def merge_kv(args: dict | Any, k: str, v: Any) -> None: else: raise ValueError(ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}.")) else: + warnings.warn(f"Can't merge entry ['{k}'], '{id}' is not in target dict - copying instead.") args[id] = v else: args[k] = v diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_wrapper.py index d62cde6e52..8d295d31d5 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_wrapper.py @@ -486,7 +486,7 @@ def trt_compile( logger: Any | None = None, ) -> torch.nn.Module: """ - Instruments model or submodule with TrtWrappper and reppaces forward() with a hook. + Instruments model or submodule with TrtWrappper and replaces its forward() with TRT hook. Args: model: module to patch with TrtWrappper(). base_path: TRT plan(s) saved to "base_path[.submodule].plan" path. From b6d9179600c5d0bbe30dbdbcfe61868b91d4282d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 23:02:04 +0000 Subject: [PATCH 52/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/config_syntax.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index 50f637dc33..567ed87915 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -183,8 +183,8 @@ The content of those config files is being merged. When same keys are specifiled the value associated with the key is being overridden, in the order config files are specified. If the desired behaviour is to merge values from both files, the key in second config file should be prefixed with `+`. Both values associated with `+`-prefixed key pair must be of `dict` or `list` type. -`dict` values will be merged via update(), `list` values - concatenated. -Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" list will be merged: +`dict` values will be merged via update(), `list` values - concatenated. +Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" list will be merged: ```json1 { "amp": True From 652448a17716dc9b37c0bf17c607cdc83e72edeb Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 27 Aug 2024 18:39:43 -0700 Subject: [PATCH 53/85] doc build fixed Signed-off-by: Boris Fomitchev --- docs/source/config_syntax.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index 50f637dc33..0f8f5f1bf4 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -185,7 +185,7 @@ If the desired behaviour is to merge values from both files, the key in second c Both values associated with `+`-prefixed key pair must be of `dict` or `list` type. `dict` values will be merged via update(), `list` values - concatenated. Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" list will be merged: -```json1 +```json { "amp": True "imports": [ @@ -194,7 +194,7 @@ Here's an example. In this case, "amp" value will be overridden by json2 config, } ``` -```json2 +```json { "amp": False "+imports": [ From 5c4f63a03f82ae60b3122ba68a256fe034bbf7ac Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 27 Aug 2024 19:53:19 -0700 Subject: [PATCH 54/85] Fixed formatting Signed-off-by: Boris Fomitchev --- docs/source/config_syntax.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index f2e874d016..7b30fe31e5 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -185,7 +185,7 @@ If the desired behaviour is to merge values from both files, the key in second c Both values associated with `+`-prefixed key pair must be of `dict` or `list` type. `dict` values will be merged via update(), `list` values - concatenated. Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" list will be merged: -```json1 +```json { "amp": True "imports": [ From dd91183a918a73531ce52ac41dd26a8c2a891f3c Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 27 Aug 2024 21:09:16 -0700 Subject: [PATCH 55/85] Fixed formatting Signed-off-by: Boris Fomitchev --- docs/source/config_syntax.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index 7b30fe31e5..d871efe19b 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -187,7 +187,7 @@ Both values associated with `+`-prefixed key pair must be of `dict` or `list` ty Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" list will be merged: ```json { - "amp": True + "amp": "$True" "imports": [ "$import torch" ], @@ -196,7 +196,7 @@ Here's an example. In this case, "amp" value will be overridden by json2 config, ```json { - "amp": False + "amp": "$False" "+imports": [ "$from monai.networks import trt_compile" ], From c41cb5ac9c21dc187b0f5c1d717f3f957fe1459d Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 27 Aug 2024 21:20:19 -0700 Subject: [PATCH 56/85] Updated base container to 24.08 Signed-off-by: Boris Fomitchev --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 62abc268e8..e45932c6bb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. -ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.07-py3 +ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.08-py3 FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" From 7e440fc3313de29f396de8f7f2df59adfd712645 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 27 Aug 2024 23:03:28 -0700 Subject: [PATCH 57/85] Renaming trt_wrapper -> trt_compiler, adding TRT handler test Signed-off-by: Boris Fomitchev --- monai/networks/__init__.py | 2 +- .../{trt_wrapper.py => trt_compiler.py} | 22 +++++++++---------- tests/test_trt_compile.py | 21 ++++++++++++++++-- 3 files changed, 31 insertions(+), 14 deletions(-) rename monai/networks/{trt_wrapper.py => trt_compiler.py} (97%) diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 6c424ddfb9..5a240021d6 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -11,7 +11,7 @@ from __future__ import annotations -from .trt_wrapper import trt_compile +from .trt_compiler import trt_compile from .utils import ( add_casts_around_norms, convert_to_onnx, diff --git a/monai/networks/trt_wrapper.py b/monai/networks/trt_compiler.py similarity index 97% rename from monai/networks/trt_wrapper.py rename to monai/networks/trt_compiler.py index 8d295d31d5..66c6a8365e 100644 --- a/monai/networks/trt_wrapper.py +++ b/monai/networks/trt_compiler.py @@ -217,9 +217,9 @@ def infer(self, stream, use_cuda_graph=False): return self.tensors -class TrtWrappper: +class TrtCompiler: """ - This wrapper implements: + This class implements: - TRT lazy persistent export - Running TRT with optional fallback to Torch (for TRT engines with limited profiles) @@ -411,7 +411,7 @@ def _build_and_save(self, model, input_example): dbs = self.dynamic_batchsize if dbs: if len(self.profiles) > 0: - raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtWrappper!") + raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!") if len(dbs) != 3: raise ValueError("dynamic_batchsize has to have len ==3 ") profiles = {} @@ -473,9 +473,9 @@ def _build_and_save(self, model, input_example): def trt_forward(self, *argv, **kwargs): """ Patch function to replace original model's forward() with. - Redirects to TrtWrappper.forward() + Redirects to TrtCompiler.forward() """ - return self._trt_wrapper.forward(self, argv, kwargs) + return self._trt_compiler.forward(self, argv, kwargs) def trt_compile( @@ -486,15 +486,15 @@ def trt_compile( logger: Any | None = None, ) -> torch.nn.Module: """ - Instruments model or submodule with TrtWrappper and replaces its forward() with TRT hook. + Instruments model or submodule with TrtCompiler and replaces its forward() with TRT hook. Args: - model: module to patch with TrtWrappper(). + model: module to patch with TrtCompiler(). base_path: TRT plan(s) saved to "base_path[.submodule].plan" path. If base_path points to existing file (e.g. associated checkpoint), that file also becomes dependency - its mtime is added to args["timestamp"]. - args: dict : unpacked and passed to TrtWrappper() - see TrtWrapper above for details. + args: dict : unpacked and passed to TrtCompiler() - see TrtCompiler above for details. submodule : Hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder'] - If None, TrtWrappper patch is applied to the whole model. + If None, TrtCompiler patch is applied to the whole model. Otherwise, submodule (or list of) is being patched. Returns: Always returns same model passed in as argument. This is for ease of use in configs. @@ -519,8 +519,8 @@ def trt_compile( args["timestamp"] = timestamp def wrap(model, path): - wrapper = TrtWrappper(model, path + ".plan", logger=logger, **args) - model._trt_wrapper = wrapper + wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) + model._trt_compiler = wrapper model.forward = MethodType(trt_forward, model) def find_sub(parent, submodule): diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 1fdaa18c20..0b581494a6 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -15,8 +15,10 @@ import unittest import torch +from ignite.engine import Engine, Events from parameterized import parameterized +from monai.handlers import TrtHandler from monai.networks.nets import UNet from monai.utils import optional_import from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows @@ -42,6 +44,21 @@ def tearDown(self): if current_device != self.gpu_device: torch.cuda.set_device(self.gpu_device) + @unittest.skipUnless(has_trt, "TensorRT compile wrapper is required for convert!") + def test_handler(self): + net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) + data1 = net1.state_dict() + data1["0.weight"] = torch.tensor([0.1]) + data1["1.weight"] = torch.tensor([0.2]) + net1.load_state_dict(data1) + net1.cuda() + + with tempfile.TemporaryDirectory() as tempdir: + engine = Engine(lambda e, b: None) + TrtHandler(net1, tempdir + "/trt_handler").attach(engine) + engine.run([0] * 8, max_epochs=1) + self.assertIsNotNone(net1._trt_compiler) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @unittest.skipUnless(has_trt, "TensorRT compile wrapper is required for convert!") def test_value(self, precision): @@ -64,10 +81,10 @@ def test_value(self, precision): f"{tmpdir}/test_trt_compile", args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]}, ) - self.assertIsNone(model._trt_wrapper.engine) + self.assertIsNone(model._trt_compiler.engine) trt_output = model(input_example) # Check that lazy TRT build succeeded - self.assertIsNotNone(model._trt_wrapper.engine) + self.assertIsNotNone(model._trt_compiler.engine) torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) From 329d0243bd16edef84ba2f6e20839ac01029b061 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Aug 2024 06:04:09 +0000 Subject: [PATCH 58/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_trt_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 0b581494a6..2755aecc22 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -15,7 +15,7 @@ import unittest import torch -from ignite.engine import Engine, Events +from ignite.engine import Engine from parameterized import parameterized from monai.handlers import TrtHandler From 84de860fed2871305da344b32e53f3ebab402221 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 27 Aug 2024 23:20:46 -0700 Subject: [PATCH 59/85] fixing CI error Signed-off-by: Boris Fomitchev --- tests/test_trt_compile.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 0b581494a6..ad13c5a7e7 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -15,7 +15,6 @@ import unittest import torch -from ignite.engine import Engine, Events from parameterized import parameterized from monai.handlers import TrtHandler @@ -46,6 +45,8 @@ def tearDown(self): @unittest.skipUnless(has_trt, "TensorRT compile wrapper is required for convert!") def test_handler(self): + from ignite.engine import Engine + net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) data1 = net1.state_dict() data1["0.weight"] = torch.tensor([0.1]) From b84cec4f234ebfaa448d83fcb53e8d4a6d39829f Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 27 Aug 2024 23:56:37 -0700 Subject: [PATCH 60/85] Fixing min test error, addressing comments Signed-off-by: Boris Fomitchev --- docs/source/config_syntax.md | 7 ++++--- tests/test_trt_compile.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index d871efe19b..ca3c36e18b 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -182,9 +182,10 @@ _Description:_ Multiple config files may be specified on the command line. The content of those config files is being merged. When same keys are specifiled in more than one config file, the value associated with the key is being overridden, in the order config files are specified. If the desired behaviour is to merge values from both files, the key in second config file should be prefixed with `+`. -Both values associated with `+`-prefixed key pair must be of `dict` or `list` type. -`dict` values will be merged via update(), `list` values - concatenated. -Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" list will be merged: +The value types for the merged contents must match and be both of `dict` or both of `list` type. +`dict` values will be merged via update(), `list` values - concatenated via extend(). +Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" lists will be merged. +An error would be thrown if the value type in `"+imports"` is not `list`: ```json { "amp": "$True" diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 2755aecc22..ad13c5a7e7 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -15,7 +15,6 @@ import unittest import torch -from ignite.engine import Engine from parameterized import parameterized from monai.handlers import TrtHandler @@ -46,6 +45,8 @@ def tearDown(self): @unittest.skipUnless(has_trt, "TensorRT compile wrapper is required for convert!") def test_handler(self): + from ignite.engine import Engine + net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) data1 = net1.state_dict() data1["0.weight"] = torch.tensor([0.1]) From 9481d9f6a368c5cac836cca39e493bcc20a6532a Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 28 Aug 2024 00:59:38 -0700 Subject: [PATCH 61/85] optional propagation of dynamo arg fixed, onnx_graphsurgeon package added Signed-off-by: Boris Fomitchev --- monai/networks/trt_compiler.py | 9 +++++---- monai/networks/utils.py | 2 -- requirements-dev.txt | 1 + tests/min_tests.py | 1 + 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index 66c6a8365e..b6f50a52b0 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -251,6 +251,9 @@ def __init__( plan_path : Path where to save persistent serialized TRT engine. precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. method: One of 'onnx'|'onnx_dynamo'|'torch_trt'. + Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option. + 'onnx_dynamo' is using experimental 'dynamo' export() option, it may not work with AMP. + 'torch_trt' may not work correctly for nets with multiple inputs or dynamic batch size. input_names: Optional list of input names. If None, will be read from the function signature. output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. @@ -443,12 +446,11 @@ def _build_and_save(self, model, input_example): ) else: if self.method == "onnx_dynamo": - dynamo = True import torch_onnx torch_onnx.patch_torch() - else: - dynamo = False + export_args["dynamo"] = True + # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: onnx_path = Path(tmpdir) / "model.onnx" @@ -461,7 +463,6 @@ def _build_and_save(self, model, input_example): filename=str(onnx_path), input_names=self.input_names, output_names=self.output_names, - dynamo=dynamo, **export_args, ) self.logger.info("Export to ONNX successful.") diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 7fd07064c1..656096d384 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -638,7 +638,6 @@ def convert_to_onnx( use_trace: whether to use `torch.jit.trace` to export the torchscript model. do_constant_folding: passed to onnx.export(). If True, extra polygraphy folding pass is done. constant_size_threshold: passed to polygrapy conatant forling, default = 16M - dynamo: passed to onnx.export(). [When dynamo export API is finalized] kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export() else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details: https://pytorch.org/docs/master/generated/torch.jit.script.html. @@ -682,7 +681,6 @@ def convert_to_onnx( dynamic_axes=dynamic_axes, opset_version=opset_version, do_constant_folding=do_constant_folding, - # dynamo=dynamo, **torch_versioned_kwargs, ) if filename is None: diff --git a/requirements-dev.txt b/requirements-dev.txt index 97e4dff1fc..6d0ccd378a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -59,4 +59,5 @@ nvidia-ml-py huggingface_hub pyamg>=5.0.0 git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 +onnx_graphsurgeon polygraphy diff --git a/tests/min_tests.py b/tests/min_tests.py index f80d06f5d3..632355b5c6 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -186,6 +186,7 @@ def run_testsuit(): "test_torchvisiond", "test_transchex", "test_transformerblock", + "test_trt_compile", "test_unetr", "test_unetr_block", "test_vit", From 6a1158196eac16b4a44b8563e955c02d646ba277 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 28 Aug 2024 17:46:59 +0800 Subject: [PATCH 62/85] add vista test cases Signed-off-by: Yiheng Wang --- tests/test_trt_compile.py | 61 ++++++++++++++++++++++++++++++++++----- 1 file changed, 53 insertions(+), 8 deletions(-) diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index ad13c5a7e7..dcaae9bad4 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -18,13 +18,14 @@ from parameterized import parameterized from monai.handlers import TrtHandler -from monai.networks.nets import UNet +from monai.networks import trt_compile +from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 from monai.utils import optional_import from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows -trt_compile, has_trt = optional_import( - "monai.networks", name="trt_compile", descriptor="TRT compile is not available - check your installation!" -) +trt, trt_imported = optional_import("tensorrt") +polygraphy, polygraphy_imported = optional_import("polygraphy") +build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") TEST_CASE_1 = ["fp32"] TEST_CASE_2 = ["fp16"] @@ -43,7 +44,8 @@ def tearDown(self): if current_device != self.gpu_device: torch.cuda.set_device(self.gpu_device) - @unittest.skipUnless(has_trt, "TensorRT compile wrapper is required for convert!") + @unittest.skipUnless(trt_imported, "tensorrt is required") + @unittest.skipUnless(polygraphy_imported, "polygraphy is required") def test_handler(self): from ignite.engine import Engine @@ -61,8 +63,9 @@ def test_handler(self): self.assertIsNotNone(net1._trt_compiler) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - @unittest.skipUnless(has_trt, "TensorRT compile wrapper is required for convert!") - def test_value(self, precision): + @unittest.skipUnless(trt_imported, "Requires tensorrt") + @unittest.skipUnless(polygraphy_imported, "polygraphy is required") + def test_unet_value(self, precision): model = UNet( spatial_dims=3, in_channels=1, @@ -79,7 +82,7 @@ def test_value(self, precision): args: dict = {"builder_optimization_level": 1} trt_compile( model, - f"{tmpdir}/test_trt_compile", + f"{tmpdir}/test_unet_trt_compile", args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]}, ) self.assertIsNone(model._trt_compiler.engine) @@ -88,6 +91,48 @@ def test_value(self, precision): self.assertIsNotNone(model._trt_compiler.engine) torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @unittest.skipUnless(trt_imported, "Requires tensorrt") + @unittest.skipUnless(polygraphy_imported, "polygraphy is required") + @unittest.skipUnless(has_sam, "Requires SAM installation") + def test_cell_sam_wrapper_value(self, precision): + model = cell_sam_wrapper.CellSamWrapper(checkpoint=None).to("cuda") + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + model.eval() + input_example = torch.randn(1, 3, 128, 128).to("cuda") + output_example = model(input_example) + trt_compile( + model, + f"{tmpdir}/test_cell_sam_wrapper_trt_compile", + args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @unittest.skipUnless(trt_imported, "Requires tensorrt") + @unittest.skipUnless(polygraphy_imported, "polygraphy is required") + def test_vista3d(self, precision): + model = vista3d132(in_channels=1).to("cuda") + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + model.eval() + input_example = torch.randn(1, 1, 64, 64, 64).to("cuda") + output_example = model(input_example) + trt_compile( + model, + f"{tmpdir}/test_vista3d_trt_compile", + args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, + submodule=["image_encoder.encoder", "class_head"], + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + if __name__ == "__main__": unittest.main() From 792a721d1d44f9844ab2e540540d14122ae817e6 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 28 Aug 2024 17:26:16 -0700 Subject: [PATCH 63/85] Code review input addressed Signed-off-by: Boris Fomitchev --- docs/source/config_syntax.md | 17 +++++++++++++++-- monai/networks/trt_compiler.py | 31 +++++++++++++++---------------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index ca3c36e18b..cc38903ced 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -184,23 +184,36 @@ the value associated with the key is being overridden, in the order config files If the desired behaviour is to merge values from both files, the key in second config file should be prefixed with `+`. The value types for the merged contents must match and be both of `dict` or both of `list` type. `dict` values will be merged via update(), `list` values - concatenated via extend(). -Here's an example. In this case, "amp" value will be overridden by json2 config, and "imports" lists will be merged. -An error would be thrown if the value type in `"+imports"` is not `list`: +Here's an example. In this case, "amp" value will be overridden by extra_config.json. +`imports` and `preprocessing#transforms` lists will be merged. An error would be thrown if the value type in `"+imports"` is not `list`: + +config.json: ```json { "amp": "$True" "imports": [ "$import torch" ], + "preprocessing": { + "_target_": "Compose", + "transforms": [ + "$@t1", + "$@t2" + ] + }, } ``` +extra_config.json: ```json { "amp": "$False" "+imports": [ "$from monai.networks import trt_compile" ], + "+preprocessing#transforms": [ + "$@t3" + ] } ``` diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index b6f50a52b0..d1a6c8ed5e 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -250,10 +250,9 @@ def __init__( model: Model to "wrap". plan_path : Path where to save persistent serialized TRT engine. precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. - method: One of 'onnx'|'onnx_dynamo'|'torch_trt'. + method: One of 'onnx'|'torch_trt'. Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option. - 'onnx_dynamo' is using experimental 'dynamo' export() option, it may not work with AMP. - 'torch_trt' may not work correctly for nets with multiple inputs or dynamic batch size. + 'torch_trt' may not work for nets with multiple inputs or dynamic batch size. AMP must be off for it. input_names: Optional list of input names. If None, will be read from the function signature. output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. @@ -445,12 +444,6 @@ def _build_and_save(self, model, input_example): **export_args, ) else: - if self.method == "onnx_dynamo": - import torch_onnx - - torch_onnx.patch_torch() - export_args["dynamo"] = True - # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: onnx_path = Path(tmpdir) / "model.onnx" @@ -487,16 +480,18 @@ def trt_compile( logger: Any | None = None, ) -> torch.nn.Module: """ - Instruments model or submodule with TrtCompiler and replaces its forward() with TRT hook. + Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook. Args: - model: module to patch with TrtCompiler(). - base_path: TRT plan(s) saved to "base_path[.submodule].plan" path. - If base_path points to existing file (e.g. associated checkpoint), - that file also becomes dependency - its mtime is added to args["timestamp"]. - args: dict : unpacked and passed to TrtCompiler() - see TrtCompiler above for details. - submodule : Hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder'] + model: module to patch with TrtCompiler object. + base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path. + dirname(base_path) must exist, base_path does not have to. + If base_path does point to existing file (e.g. associated checkpoint), + that file becomes a dependency - its mtime is added to args["timestamp"]. + args: Optional dict : unpacked and passed to TrtCompiler() - see TrtCompiler above for details. + submodule: Optional hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder'] If None, TrtCompiler patch is applied to the whole model. Otherwise, submodule (or list of) is being patched. + logger: Optional logger for diagnostics. Returns: Always returns same model passed in as argument. This is for ease of use in configs. """ @@ -542,4 +537,8 @@ def find_sub(parent, submodule): wrap(getattr(parent, sub), base_path + "." + s) else: wrap(model, base_path) + else: + logger = logger or get_logger("trt_compile") + logger.warning("TensorRT and/or polygraphy packages are not available! trt_compile() has no effect.") + return model From 73ac717596b1d9561919c416f39a5733aba43e65 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 28 Aug 2024 20:03:27 -0700 Subject: [PATCH 64/85] Fixed torch-tensorrt path of trt_compile, added test Signed-off-by: Boris Fomitchev --- monai/networks/trt_compiler.py | 65 ++++++++++++++++++++++------------ monai/networks/utils.py | 40 ++++++++++++--------- tests/test_trt_compile.py | 11 +++--- 3 files changed, 73 insertions(+), 43 deletions(-) diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index d1a6c8ed5e..a9dd0d9e9b 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -23,7 +23,7 @@ import torch from monai.apps.utils import get_logger -from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript +from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript, get_profile_shapes from monai.utils.module import optional_import polygraphy, polygraphy_imported = optional_import("polygraphy") @@ -252,7 +252,7 @@ def __init__( precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. method: One of 'onnx'|'torch_trt'. Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option. - 'torch_trt' may not work for nets with multiple inputs or dynamic batch size. AMP must be off for it. + 'torch_trt' may not work for some nets. Also AMP must be turned off for it to work. input_names: Optional list of input names. If None, will be read from the function signature. output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. @@ -266,6 +266,14 @@ def __init__( timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes). fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile). """ + + method_vals = ["onnx", "torch_trt"] + if method not in method_vals: + raise ValueError(f"trt_compile(): 'method' should be one of {method_vals}, got: {method}.") + precision_vals = ["fp32", "tf32", "fp16", "bf16"] + if precision not in precision_vals: + raise ValueError(f"trt_compile(): 'precision' should be one of {precision_vals}, got: {precision}.") + self.plan_path = plan_path self.precision = precision self.method = method @@ -321,10 +329,6 @@ def forward(self, model, argv, kwargs): Returns: Passing through wrapped module's forward() return value(s) """ - if len(argv) > 0: - kwargs.update(self._inputs_to_dict(argv)) - argv = () - if self.engine is None and not self.disabled: # Restore original forward for export new_forward = model.forward @@ -332,7 +336,11 @@ def forward(self, model, argv, kwargs): try: self._load_engine() if self.engine is None: - self._build_and_save(model, kwargs) + build_args = kwargs.copy() + if len(argv) > 0: + build_args.update(self._inputs_to_dict(argv)) + self._build_and_save(model, build_args) + # This will reassign input_names from the engine self._load_engine() except Exception as e: if self.fallback: @@ -349,6 +357,10 @@ def forward(self, model, argv, kwargs): model.forward = new_forward # Run the engine try: + if len(argv) > 0: + kwargs.update(self._inputs_to_dict(argv)) + argv = () + if self.engine is not None: # forward_trt is not thread safe as we do not use per-thread execution contexts with lock_sm: @@ -410,20 +422,6 @@ def _build_and_save(self, model, input_example): return export_args = self.export_args - dbs = self.dynamic_batchsize - if dbs: - if len(self.profiles) > 0: - raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!") - if len(dbs) != 3: - raise ValueError("dynamic_batchsize has to have len ==3 ") - profiles = {} - for id, val in input_example.items(): - sh = val.shape[1:] - profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] - self.profiles = [profiles] - - if len(self.profiles) > 0: - export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) add_casts_around_norms(model) @@ -435,15 +433,38 @@ def _build_and_save(self, model, input_example): enabled_precisions.append(torch.bfloat16) inputs = list(input_example.values()) ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True) + + def get_torch_trt_input(input_shape, dynamic_batchsize): + min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) + return torch_tensorrt.Input( + min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape + ) + + tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs] engine_bytes = torch_tensorrt.convert_method_to_trt_engine( ir_model, "forward", - inputs=inputs, + inputs=tt_inputs, ir="torchscript", enabled_precisions=enabled_precisions, **export_args, ) else: + dbs = self.dynamic_batchsize + if dbs: + if len(self.profiles) > 0: + raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!") + if len(dbs) != 3: + raise ValueError("dynamic_batchsize has to have len ==3 ") + profiles = {} + for id, val in input_example.items(): + sh = val.shape[1:] + profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] + self.profiles = [profiles] + + if len(self.profiles) > 0: + export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) + # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: onnx_path = Path(tmpdir) / "model.onnx" diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 656096d384..c72a16d4c7 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -37,6 +37,7 @@ onnxreference, _ = optional_import("onnx.reference") onnxruntime, _ = optional_import("onnxruntime") polygraphy, polygraphy_imported = optional_import("polygraphy") +torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0") __all__ = [ "one_hot", @@ -62,6 +63,7 @@ "look_up_named_module", "set_named_module", "has_nvfuser_instance_norm", + "get_profile_shapes", ] logger = get_logger(module_name=__name__) @@ -69,6 +71,26 @@ _has_nvfuser = None +def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None): + """ + Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize. + """ + + def scale_batch_size(input_shape: Sequence[int], scale_num: int): + scale_shape = [*input_shape] + scale_shape[0] = scale_num + return scale_shape + + # Use the dynamic batchsize range to generate the min, opt and max model input shape + if dynamic_batchsize: + min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0]) + opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1]) + max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2]) + else: + min_input_shape = opt_input_shape = max_input_shape = input_shape + return min_input_shape, opt_input_shape, max_input_shape + + def has_nvfuser_instance_norm(): """whether the current environment has InstanceNorm3dNVFuser https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16 @@ -827,7 +849,6 @@ def _onnx_trt_compile( """ trt, _ = optional_import("tensorrt", "8.5.3") - torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0") input_shapes = (min_shape, opt_shape, max_shape) # default to an empty list to fit the `torch_tensorrt.ts.embed_engine_in_new_module` function. @@ -929,8 +950,6 @@ def convert_to_trt( to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py. """ - torch_tensorrt, _ = optional_import("torch_tensorrt", version="1.4.0") - if not torch.cuda.is_available(): raise Exception("Cannot find any GPU devices.") @@ -948,21 +967,9 @@ def convert_to_trt( convert_precision = torch.float32 if precision == "fp32" else torch.half inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)] - def scale_batch_size(input_shape: Sequence[int], scale_num: int): - scale_shape = [*input_shape] - scale_shape[0] *= scale_num - return scale_shape - - # Use the dynamic batchsize range to generate the min, opt and max model input shape - if dynamic_batchsize: - min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0]) - opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1]) - max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2]) - else: - min_input_shape = opt_input_shape = max_input_shape = input_shape - # convert the torch model to a TorchScript model on target device model = model.eval().to(target_device) + min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) if use_onnx: # set the batch dim as dynamic @@ -971,7 +978,6 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): ir_model = convert_to_onnx( model, inputs, onnx_input_names, onnx_output_names, use_trace=use_trace, dynamic_axes=dynamic_axes ) - # convert the model through the ONNX-TensorRT way trt_model = _onnx_trt_compile( ir_model, diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index ad13c5a7e7..20b84562c8 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -33,6 +33,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick +@unittest.skipUnless(has_trt, "TensorRT compile wrapper is required for convert!") class TestTRTCompile(unittest.TestCase): def setUp(self): @@ -43,7 +44,6 @@ def tearDown(self): if current_device != self.gpu_device: torch.cuda.set_device(self.gpu_device) - @unittest.skipUnless(has_trt, "TensorRT compile wrapper is required for convert!") def test_handler(self): from ignite.engine import Engine @@ -56,12 +56,15 @@ def test_handler(self): with tempfile.TemporaryDirectory() as tempdir: engine = Engine(lambda e, b: None) - TrtHandler(net1, tempdir + "/trt_handler").attach(engine) + args = {"method": "torch_trt"} + TrtHandler(net1, tempdir + "/trt_handler", args=args).attach(engine) engine.run([0] * 8, max_epochs=1) self.assertIsNotNone(net1._trt_compiler) + self.assertIsNone(net1._trt_compiler.engine) + net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda")) + self.assertIsNotNone(net1._trt_compiler.engine) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - @unittest.skipUnless(has_trt, "TensorRT compile wrapper is required for convert!") def test_value(self, precision): model = UNet( spatial_dims=3, @@ -74,7 +77,7 @@ def test_value(self, precision): ).cuda() with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: model.eval() - input_example = torch.randn(1, 1, 96, 96, 96).cuda() + input_example = torch.randn(2, 1, 96, 96, 96).cuda() output_example = model(input_example) args: dict = {"builder_optimization_level": 1} trt_compile( From 1e7e76d17dc93cafd4ed303c1c3038bec3c9c49a Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 28 Aug 2024 20:44:33 -0700 Subject: [PATCH 65/85] Fixing tests Signed-off-by: Boris Fomitchev --- tests/test_trt_compile.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 174af80a69..21125d203f 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -119,16 +119,20 @@ def test_vista3d(self, precision): model.eval() input_example = torch.randn(1, 1, 64, 64, 64).to("cuda") output_example = model(input_example) - trt_compile( + model = trt_compile( model, f"{tmpdir}/test_vista3d_trt_compile", args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, submodule=["image_encoder.encoder", "class_head"], ) - self.assertIsNone(model._trt_compiler.engine) - trt_output = model(input_example) + self.assertIsNotNone(model.image_encoder.encoder._trt_compiler) + self.assertIsNotNone(model.class_head._trt_compiler) + trt_output = model.forward(input_example) # Check that lazy TRT build succeeded - self.assertIsNotNone(model._trt_compiler.engine) + # TODO: set up input_example in such a way that image_encoder.encoder and class_head are called + # and uncomment the asserts below + # self.assertIsNotNone(model.image_encoder.encoder._trt_compiler.engine) + # self.assertIsNotNone(model.class_head._trt_compiler.engine) torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) From 47e676e6aade83e01ab1598edb78263f6fade5d5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Aug 2024 03:44:57 +0000 Subject: [PATCH 66/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/config_syntax.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index cc38903ced..742841acca 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -199,7 +199,7 @@ config.json: "transforms": [ "$@t1", "$@t2" - ] + ] }, } ``` From c126e6750c8351e99c3cf36c6cbb810dc6b02bf7 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 2 Sep 2024 20:26:50 -0700 Subject: [PATCH 67/85] Fixing TRT 8.x compatibility Signed-off-by: Boris Fomitchev --- monai/networks/trt_compiler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index a9dd0d9e9b..24d16df842 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -403,8 +403,10 @@ def _onnx_to_trt(self, onnx_path): build_args = self.build_args.copy() build_args["tf32"] = self.precision != "fp32" - build_args["fp16"] = self.precision == "fp16" - build_args["bf16"] = self.precision == "bf16" + if self.precision == "fp16": + build_args["fp16"] = True + elif self.precision == "bf16": + build_args["bf16"] = True self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}") network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) From 5645157568ffaea6035f8bdf37067c7d82013c6a Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 2 Sep 2024 22:06:51 -0700 Subject: [PATCH 68/85] Improved diagnostic, skip trt test if < 10.3 Signed-off-by: Boris Fomitchev --- monai/networks/trt_compiler.py | 2 ++ tests/test_trt_compile.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index 24d16df842..00d2eb61af 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -342,6 +342,7 @@ def forward(self, model, argv, kwargs): self._build_and_save(model, build_args) # This will reassign input_names from the engine self._load_engine() + assert self.engine is not None except Exception as e: if self.fallback: self.logger.info(f"Failed to build engine: {e}") @@ -504,6 +505,7 @@ def trt_compile( ) -> torch.nn.Module: """ Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook. + Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x Args: model: module to patch with TrtCompiler object. base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path. diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 21125d203f..2f9db8f0c2 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -20,10 +20,10 @@ from monai.handlers import TrtHandler from monai.networks import trt_compile from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 -from monai.utils import optional_import +from monai.utils import min_version, optional_import from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows -trt, trt_imported = optional_import("tensorrt") +trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) polygraphy, polygraphy_imported = optional_import("polygraphy") build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") From 72a4c3dded1834fd4205c834ed18b0e43a9de774 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 2 Oct 2024 17:21:22 -0700 Subject: [PATCH 69/85] trt_compile post-fixes Signed-off-by: Boris Fomitchev --- monai/networks/trt_compiler.py | 60 ++++++++++++++++++++++++---------- monai/networks/utils.py | 9 +++-- tests/test_trt_compile.py | 2 ++ 3 files changed, 51 insertions(+), 20 deletions(-) diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index 00d2eb61af..35bcf3de81 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -134,6 +134,7 @@ def __init__(self, plan_path, logger=None): self.output_names.append(binding) dtype = dtype_dict[self.engine.get_tensor_dtype(binding)] self.dtypes.append(dtype) + self.logger.info(f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}") def allocate_buffers(self, device): """ @@ -180,7 +181,8 @@ def try_set_inputs(): raise self.cur_profile = next_profile ctx.set_optimization_profile_async(self.cur_profile, stream) - + except Exception: + raise left = ctx.infer_shapes() assert len(left) == 0 @@ -216,6 +218,17 @@ def infer(self, stream, use_cuda_graph=False): return self.tensors +def remove_non_tensors(input_example): + # + # TODO : see if we can instantiate wrappers to handle non-default non-tensors + # + non_tensors = {} + for k, v in input_example.items(): + if not torch.is_tensor(v): + # print(f"Removing non-tensor input: {k} ({type(v)})") + non_tensors[k] = v + for key in non_tensors.keys(): + input_example.pop(key) class TrtCompiler: """ @@ -240,6 +253,7 @@ def __init__( use_cuda_graph=False, timestamp=None, fallback=False, + forward_override=None, logger=None, ): """ @@ -287,7 +301,7 @@ def __init__( self.use_cuda_graph = use_cuda_graph self.fallback = fallback self.disabled = False - + self.logger = logger or get_logger("trt_compile") # Normally we read input_names from forward() but can be overridden @@ -315,8 +329,9 @@ def _load_engine(self): try: self.engine = TRTEngine(self.plan_path, self.logger) self.input_names = self.engine.input_names + self.logger.info(f"Engine loaded, inputs:{self.input_names}") except Exception as e: - self.logger.debug(f"Exception while loading the engine:\n{e}") + self.logger.info(f"Exception while loading the engine:\n{e}") def forward(self, model, argv, kwargs): """ @@ -339,8 +354,9 @@ def forward(self, model, argv, kwargs): build_args = kwargs.copy() if len(argv) > 0: build_args.update(self._inputs_to_dict(argv)) - self._build_and_save(model, build_args) - # This will reassign input_names from the engine + with torch.no_grad(): + self._build_and_save(model, build_args) + # This will reassign input_names from the engine self._load_engine() assert self.engine is not None except Exception as e: @@ -355,14 +371,15 @@ def forward(self, model, argv, kwargs): del param # Call empty_cache to release GPU memory torch.cuda.empty_cache() + # restore TRT hook model.forward = new_forward # Run the engine try: - if len(argv) > 0: - kwargs.update(self._inputs_to_dict(argv)) - argv = () - if self.engine is not None: + if len(argv) > 0: + kwargs.update(self._inputs_to_dict(argv)) + argv = () + remove_non_tensors(kwargs) # forward_trt is not thread safe as we do not use per-thread execution contexts with lock_sm: device = torch.cuda.current_device() @@ -379,7 +396,7 @@ def forward(self, model, argv, kwargs): ret = ret[0] return ret except Exception as e: - if model is not None: + if self.fallback: self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...") else: raise e @@ -420,12 +437,15 @@ def _build_and_save(self, model, input_example): Args: input_example: passed to onnx.export() """ - + if self.engine is not None: return export_args = self.export_args + remove_non_tensors(input_example) + + engine_bytes = None add_casts_around_norms(model) if self.method == "torch_trt": @@ -465,27 +485,29 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] self.profiles = [profiles] + dynamic_axes = get_dynamic_axes(self.profiles) + if len(self.profiles) > 0: - export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) + export_args.update({"dynamic_axes": dynamic_axes}) # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: - onnx_path = Path(tmpdir) / "model.onnx" + onnx_path = str(Path(tmpdir) / "model.onnx") self.logger.info( - f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}" + f"Exporting to {onnx_path}:\ninputs={list(input_example.keys())}\toutput_names={self.output_names}\n\texport args: {export_args}" ) convert_to_onnx( model, input_example, - filename=str(onnx_path), + filename=onnx_path, input_names=self.input_names, output_names=self.output_names, **export_args, ) self.logger.info("Export to ONNX successful.") - engine_bytes = self._onnx_to_trt(str(onnx_path)) - - open(self.plan_path, "wb").write(engine_bytes) + engine_bytes = self._onnx_to_trt(onnx_path) + if engine_bytes: + open(self.plan_path, "wb").write(engine_bytes) def trt_forward(self, *argv, **kwargs): @@ -540,6 +562,8 @@ def trt_compile( args["timestamp"] = timestamp def wrap(model, path): + if not hasattr(model, "_trt_compiler"): + model.orig_forward = model.forward wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) model._trt_compiler = wrapper model.forward = MethodType(trt_forward, model) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index d0150b4e5b..d7c0c0df57 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -631,7 +631,6 @@ def convert_to_onnx( use_trace: bool = True, do_constant_folding: bool = True, constant_size_threshold: int = 16 * 1024 * 1024 * 1024, - dynamo=False, **kwargs, ): """ @@ -672,6 +671,9 @@ def convert_to_onnx( # let torch.onnx.export to trace the model. mode_to_export = model torch_versioned_kwargs = kwargs + if "dynamo" in kwargs: + torch_versioned_kwargs["verify"] = verify + verify = False else: if not pytorch_after(1, 10): if "example_outputs" not in kwargs: @@ -693,7 +695,7 @@ def convert_to_onnx( f = io.BytesIO() else: f = filename - + print(f"torch_versioned_kwargs={torch_versioned_kwargs}") torch.onnx.export( mode_to_export, onnx_inputs, @@ -716,6 +718,9 @@ def convert_to_onnx( fold_constants(onnx_model, size_threshold=constant_size_threshold) if verify: + if isinstance(inputs, dict): + inputs = list(inputs.values()) + if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 2f9db8f0c2..8007148888 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -24,6 +24,7 @@ from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) +torch_tensorrt, torch_trt_imported = optional_import("torch_tensorrt") polygraphy, polygraphy_imported = optional_import("polygraphy") build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") @@ -46,6 +47,7 @@ def tearDown(self): if current_device != self.gpu_device: torch.cuda.set_device(self.gpu_device) + @unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required") def test_handler(self): from ignite.engine import Engine From bf61b48c5e311e71394fa3ba950f90c8db12263b Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 9 Oct 2024 14:54:46 -0700 Subject: [PATCH 70/85] exporting controlnet Signed-off-by: Boris Fomitchev --- monai/networks/blocks/spatialattention.py | 24 +++++------------------ monai/networks/utils.py | 2 +- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 665442b55e..1c359dba9d 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -19,7 +19,6 @@ from monai.networks.blocks import SABlock from monai.utils import optional_import -Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") class SpatialAttentionBlock(nn.Module): @@ -71,27 +70,14 @@ def __init__( use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) - + def forward(self, x: torch.Tensor): residual = x - - if self.spatial_dims == 1: - h = x.shape[2] - rearrange_input = Rearrange("b c h -> b h c") - rearrange_output = Rearrange("b h c -> b c h", h=h) - if self.spatial_dims == 2: - h, w = x.shape[2], x.shape[3] - rearrange_input = Rearrange("b c h w -> b (h w) c") - rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w) - else: - h, w, d = x.shape[2], x.shape[3], x.shape[4] - rearrange_input = Rearrange("b c h w d -> b (h w d) c") - rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d) - + shape = x.shape x = self.norm(x) - x = rearrange_input(x) # B x (x_dim * y_dim [ * z_dim]) x C - + x = x.reshape(shape[0], shape[1], -1).transpose(1,2) x = self.attn(x) - x = rearrange_output(x) # B x x C x x_dim * y_dim * [z_dim] + x = x.transpose(1,2) + x = x.reshape(shape) x = x + residual return x diff --git a/monai/networks/utils.py b/monai/networks/utils.py index d7c0c0df57..55deb79434 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -671,7 +671,7 @@ def convert_to_onnx( # let torch.onnx.export to trace the model. mode_to_export = model torch_versioned_kwargs = kwargs - if "dynamo" in kwargs: + if "dynamo" in kwargs and kwargs["dynamo"] and verify: torch_versioned_kwargs["verify"] = verify verify = False else: From 6297b4513311b2949af7bfabd624314d1e0c520f Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Oct 2024 00:32:30 -0700 Subject: [PATCH 71/85] Working controlnet TRT Signed-off-by: Boris Fomitchev --- .../maisi/networks/controlnet_maisi.py | 2 +- monai/networks/blocks/spatialattention.py | 5 ++- monai/networks/trt_compiler.py | 36 +++++++++++++------ 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 269086d971..384f7970e4 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -119,7 +119,7 @@ def forward( down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] mid_block_res_sample *= conditioning_scale - return down_block_res_samples, mid_block_res_sample + return [*down_block_res_samples, mid_block_res_sample] def _prepare_time_and_class_embedding(self, x, timesteps, class_labels): # 1. time diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 1c359dba9d..c6ce8487e0 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -75,9 +75,8 @@ def forward(self, x: torch.Tensor): residual = x shape = x.shape x = self.norm(x) - x = x.reshape(shape[0], shape[1], -1).transpose(1,2) + x = x.reshape(*shape[:2], -1).transpose(1,2) # "b c h w d -> b (h w d) c" x = self.attn(x) - x = x.transpose(1,2) - x = x.reshape(shape) + x = x.transpose(1,2).reshape(shape) # "b (h w d) c -> b c h w d" x = x + residual return x diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index 35bcf3de81..5141a9e90f 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -164,7 +164,8 @@ def set_inputs(self, feed_dict, stream): last_profile = self.cur_profile def try_set_inputs(): - for binding, t in feed_dict.items(): + for binding in self.input_names: + t = feed_dict[binding] if t is not None: t = t.contiguous() shape = t.shape @@ -218,17 +219,24 @@ def infer(self, stream, use_cuda_graph=False): return self.tensors -def remove_non_tensors(input_example): +def remove_non_tensors(input_example, remove_constants=True): # # TODO : see if we can instantiate wrappers to handle non-default non-tensors # non_tensors = {} for k, v in input_example.items(): - if not torch.is_tensor(v): - # print(f"Removing non-tensor input: {k} ({type(v)})") + if v is None: non_tensors[k] = v + elif not torch.is_tensor(v): + if remove_constants: + non_tensors[k] = v + else: + input_example[k] = torch.tensor(v) + for key in non_tensors.keys(): + # print(f"Removing non-tensor input: {key})") input_example.pop(key) + return non_tensors class TrtCompiler: """ @@ -303,11 +311,16 @@ def __init__( self.disabled = False self.logger = logger or get_logger("trt_compile") - + self.argspec = inspect.getfullargspec(model.forward) # Normally we read input_names from forward() but can be overridden if input_names is None: - argspec = inspect.getfullargspec(model.forward) - input_names = argspec.args[1:] + input_names = self.argspec.args[1:] + self.defaults = {} + for i in range(len(self.argspec.defaults)): + d = self.argspec.defaults[-i-1] + if d is not None: + self.defaults[self.argspec.args[-i-1]] = torch.tensor(d).cuda() + self.input_names = input_names self.old_forward = model.forward @@ -376,15 +389,16 @@ def forward(self, model, argv, kwargs): # Run the engine try: if self.engine is not None: + args = self.defaults + args.update(kwargs) if len(argv) > 0: - kwargs.update(self._inputs_to_dict(argv)) - argv = () - remove_non_tensors(kwargs) + args.update(self._inputs_to_dict(argv)) + remove_non_tensors(args, remove_constants=False) # forward_trt is not thread safe as we do not use per-thread execution contexts with lock_sm: device = torch.cuda.current_device() stream = torch.cuda.Stream(device=device) - self.engine.set_inputs(kwargs, stream.cuda_stream) + self.engine.set_inputs(args, stream.cuda_stream) self.engine.allocate_buffers(device=device) # Need this to synchronize with Torch stream stream.wait_stream(torch.cuda.current_stream()) From f00fea417f67ef4062cef59346053ed61006be7d Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Oct 2024 00:36:50 -0700 Subject: [PATCH 72/85] Reformat Signed-off-by: Boris Fomitchev --- monai/networks/blocks/spatialattention.py | 7 +++---- monai/networks/trt_compiler.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index c6ce8487e0..903d11942d 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -20,7 +20,6 @@ from monai.utils import optional_import - class SpatialAttentionBlock(nn.Module): """Perform spatial self-attention on the input tensor. @@ -70,13 +69,13 @@ def __init__( use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) - + def forward(self, x: torch.Tensor): residual = x shape = x.shape x = self.norm(x) - x = x.reshape(*shape[:2], -1).transpose(1,2) # "b c h w d -> b (h w d) c" + x = x.reshape(*shape[:2], -1).transpose(1, 2) # "b c h w d -> b (h w d) c" x = self.attn(x) - x = x.transpose(1,2).reshape(shape) # "b (h w d) c -> b c h w d" + x = x.transpose(1, 2).reshape(shape) # "b (h w d) c -> b c h w d" x = x + residual return x diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index 5141a9e90f..3535ae5c7a 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -134,7 +134,9 @@ def __init__(self, plan_path, logger=None): self.output_names.append(binding) dtype = dtype_dict[self.engine.get_tensor_dtype(binding)] self.dtypes.append(dtype) - self.logger.info(f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}") + self.logger.info( + f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}" + ) def allocate_buffers(self, device): """ @@ -219,6 +221,7 @@ def infer(self, stream, use_cuda_graph=False): return self.tensors + def remove_non_tensors(input_example, remove_constants=True): # # TODO : see if we can instantiate wrappers to handle non-default non-tensors @@ -238,6 +241,7 @@ def remove_non_tensors(input_example, remove_constants=True): input_example.pop(key) return non_tensors + class TrtCompiler: """ This class implements: @@ -309,7 +313,7 @@ def __init__( self.use_cuda_graph = use_cuda_graph self.fallback = fallback self.disabled = False - + self.logger = logger or get_logger("trt_compile") self.argspec = inspect.getfullargspec(model.forward) # Normally we read input_names from forward() but can be overridden @@ -317,9 +321,9 @@ def __init__( input_names = self.argspec.args[1:] self.defaults = {} for i in range(len(self.argspec.defaults)): - d = self.argspec.defaults[-i-1] + d = self.argspec.defaults[-i - 1] if d is not None: - self.defaults[self.argspec.args[-i-1]] = torch.tensor(d).cuda() + self.defaults[self.argspec.args[-i - 1]] = torch.tensor(d).cuda() self.input_names = input_names self.old_forward = model.forward @@ -451,7 +455,7 @@ def _build_and_save(self, model, input_example): Args: input_example: passed to onnx.export() """ - + if self.engine is not None: return @@ -500,7 +504,7 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): self.profiles = [profiles] dynamic_axes = get_dynamic_axes(self.profiles) - + if len(self.profiles) > 0: export_args.update({"dynamic_axes": dynamic_axes}) From 57fbcf046c27293431791a429b4a81e6f2b59e5c Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 11 Oct 2024 23:27:21 -0700 Subject: [PATCH 73/85] Working TRT for MAISI Signed-off-by: Boris Fomitchev --- monai/networks/nets/vista3d.py | 14 +++++---- monai/networks/schedulers/ddim.py | 2 +- monai/networks/schedulers/ddpm.py | 5 ++-- monai/networks/trt_compiler.py | 48 +++++++++++++++++++------------ 4 files changed, 43 insertions(+), 26 deletions(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 4215a9a594..12f29367da 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -13,7 +13,7 @@ import math from typing import Any, Callable, Optional, Sequence, Tuple - +import time import numpy as np import torch import torch.nn.functional as F @@ -433,9 +433,13 @@ def forward( if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None: out, out_auto = self.image_embeddings, None else: + torch.cuda.synchronize() + t0 = time.time() out, out_auto = self.image_encoder( input_images, with_point=point_coords is not None, with_label=class_vector is not None ) + torch.cuda.synchronize() + print("Encoder time : ", time.time()-t0, input_images.shape) # release memory input_images = None # type: ignore @@ -639,10 +643,10 @@ def forward(self, src: torch.Tensor, class_vector: torch.Tensor): if self.use_mlp: class_embedding = self.mlp(class_embedding) # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension. - masks = [] - for i in range(b): - mask = class_embedding @ src[[i]].view(1, c, h * w * d) - masks.append(mask.view(-1, 1, h, w, d)) + masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d) + masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1) + + return masks_embedding, class_embedding return torch.cat(masks, 1), class_embedding diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index 2a0121d063..50a680336d 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -220,7 +220,7 @@ def step( if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu") - noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator, device=device) variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise pred_prev_sample = pred_prev_sample + variance diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py index 93ad833031..3eba12ee72 100644 --- a/monai/networks/schedulers/ddpm.py +++ b/monai/networks/schedulers/ddpm.py @@ -241,8 +241,9 @@ def step( variance = 0 if timestep > 0: noise = torch.randn( - model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator - ).to(model_output.device) + model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator, + device=model_output.device + ) variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise pred_prev_sample = pred_prev_sample + variance diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index 3535ae5c7a..b0802a5341 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -241,6 +241,18 @@ def remove_non_tensors(input_example, remove_constants=True): input_example.pop(key) return non_tensors +def unroll_input(input_names, input_example): + # Simulate list/tuple unrolling during ONNX export + unrolled_input={} + for name in input_names: + val = input_example[name] + if val is not None: + if isinstance(val, list | tuple): + for i in range(len(val)): + unrolled_input[f"{name}_{i}"] = val[i] + else: + unrolled_input[name] = val + return unrolled_input class TrtCompiler: """ @@ -320,10 +332,12 @@ def __init__( if input_names is None: input_names = self.argspec.args[1:] self.defaults = {} - for i in range(len(self.argspec.defaults)): - d = self.argspec.defaults[-i - 1] - if d is not None: - self.defaults[self.argspec.args[-i - 1]] = torch.tensor(d).cuda() + if self.argspec.defaults is not None: + for i in range(len(self.argspec.defaults)): + d = self.argspec.defaults[-i - 1] + if d is not None: + d = torch.tensor(d).cuda() + self.defaults[self.argspec.args[-i - 1]] = d self.input_names = input_names self.old_forward = model.forward @@ -345,8 +359,7 @@ def _load_engine(self): """ try: self.engine = TRTEngine(self.plan_path, self.logger) - self.input_names = self.engine.input_names - self.logger.info(f"Engine loaded, inputs:{self.input_names}") + self.logger.info(f"Engine loaded, inputs:{self.engine.input_names}") except Exception as e: self.logger.info(f"Exception while loading the engine:\n{e}") @@ -361,6 +374,11 @@ def forward(self, model, argv, kwargs): Returns: Passing through wrapped module's forward() return value(s) """ + args = self.defaults + args.update(kwargs) + if len(argv) > 0: + args.update(self._inputs_to_dict(argv)) + if self.engine is None and not self.disabled: # Restore original forward for export new_forward = model.forward @@ -368,9 +386,7 @@ def forward(self, model, argv, kwargs): try: self._load_engine() if self.engine is None: - build_args = kwargs.copy() - if len(argv) > 0: - build_args.update(self._inputs_to_dict(argv)) + build_args = args.copy() with torch.no_grad(): self._build_and_save(model, build_args) # This will reassign input_names from the engine @@ -393,16 +409,11 @@ def forward(self, model, argv, kwargs): # Run the engine try: if self.engine is not None: - args = self.defaults - args.update(kwargs) - if len(argv) > 0: - args.update(self._inputs_to_dict(argv)) - remove_non_tensors(args, remove_constants=False) # forward_trt is not thread safe as we do not use per-thread execution contexts with lock_sm: device = torch.cuda.current_device() stream = torch.cuda.Stream(device=device) - self.engine.set_inputs(args, stream.cuda_stream) + self.engine.set_inputs(unroll_input(self.input_names, args), stream.cuda_stream) self.engine.allocate_buffers(device=device) # Need this to synchronize with Torch stream stream.wait_stream(torch.cuda.current_stream()) @@ -448,7 +459,7 @@ def _onnx_to_trt(self, onnx_path): network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) - def _build_and_save(self, model, input_example): + def _build_and_save(self, model, input_example): """ If TRT engine is not ready, exports model to ONNX, builds TRT engine and saves serialized TRT engine to the disk. @@ -461,7 +472,7 @@ def _build_and_save(self, model, input_example): export_args = self.export_args - remove_non_tensors(input_example) + # remove_non_tensors(input_example) engine_bytes = None add_casts_around_norms(model) @@ -514,11 +525,12 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): self.logger.info( f"Exporting to {onnx_path}:\ninputs={list(input_example.keys())}\toutput_names={self.output_names}\n\texport args: {export_args}" ) + input_names = list(unroll_input(self.input_names, input_example).keys()) convert_to_onnx( model, input_example, filename=onnx_path, - input_names=self.input_names, + input_names=input_names, output_names=self.output_names, **export_args, ) From a004bc5ebfe22a6ef75c91277c9bc298bc8c14d3 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 15 Oct 2024 17:28:53 -0700 Subject: [PATCH 74/85] Working dynamic batch with sequences Signed-off-by: Boris Fomitchev --- monai/networks/trt_compiler.py | 47 ++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index b0802a5341..51fc02d812 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -437,16 +437,11 @@ def _onnx_to_trt(self, onnx_path): """ profiles = [] - if self.profiles: - for input_profile in self.profiles: - if isinstance(input_profile, Profile): - profiles.append(input_profile) - else: - p = Profile() - for name, dims in input_profile.items(): - assert len(dims) == 3 - p.add(name, min=dims[0], opt=dims[1], max=dims[2]) - profiles.append(p) + for profile in self.profiles: + p = Profile() + for id, val in profile.items(): + p.add(id, min=val[0], opt=val[1], max=val[2]) + profiles.append(p) build_args = self.build_args.copy() build_args["tf32"] = self.precision != "fp32" @@ -508,29 +503,37 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!") if len(dbs) != 3: raise ValueError("dynamic_batchsize has to have len ==3 ") - profiles = {} + profile = {} for id, val in input_example.items(): - sh = val.shape[1:] - profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] - self.profiles = [profiles] - - dynamic_axes = get_dynamic_axes(self.profiles) - - if len(self.profiles) > 0: - export_args.update({"dynamic_axes": dynamic_axes}) + def add_profile(id, val): + sh = val.shape + if len(sh) > 0: + sh = sh[1:] + profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] + if isinstance(val, list | tuple): + for i in range(len(val)): + add_profile(f"{id}_{i}", val[i]) + elif isinstance(val, torch.Tensor): + add_profile(id, val) + self.profiles = [profile] + + self.dynamic_axes = get_dynamic_axes(self.profiles) + + if len(self.dynamic_axes) > 0: + export_args.update({"dynamic_axes": self.dynamic_axes}) # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: + unrolled_input = unroll_input(self.input_names, input_example) onnx_path = str(Path(tmpdir) / "model.onnx") self.logger.info( - f"Exporting to {onnx_path}:\ninputs={list(input_example.keys())}\toutput_names={self.output_names}\n\texport args: {export_args}" + f"Exporting to {onnx_path}:\nunrolled_inputs={list(unrolled_input.keys())}\noutput_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}" ) - input_names = list(unroll_input(self.input_names, input_example).keys()) convert_to_onnx( model, input_example, filename=onnx_path, - input_names=input_names, + input_names=list(unrolled_input.keys()), output_names=self.output_names, **export_args, ) From adf9bc9bf20ced8b0b2121bb8e88597d0c52becd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Oct 2024 01:22:03 +0000 Subject: [PATCH 75/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/blocks/spatialattention.py | 1 - monai/networks/nets/vista3d.py | 4 ++-- monai/networks/trt_compiler.py | 8 ++++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 903d11942d..60a89a7840 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -17,7 +17,6 @@ import torch.nn as nn from monai.networks.blocks import SABlock -from monai.utils import optional_import class SpatialAttentionBlock(nn.Module): diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 12f29367da..7778797af6 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -439,7 +439,7 @@ def forward( input_images, with_point=point_coords is not None, with_label=class_vector is not None ) torch.cuda.synchronize() - print("Encoder time : ", time.time()-t0, input_images.shape) + print("Encoder time : ", time.time()-t0, input_images.shape) # release memory input_images = None # type: ignore @@ -645,7 +645,7 @@ def forward(self, src: torch.Tensor, class_vector: torch.Tensor): # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension. masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d) masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1) - + return masks_embedding, class_embedding return torch.cat(masks, 1), class_embedding diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index 51fc02d812..9cda55edde 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -454,7 +454,7 @@ def _onnx_to_trt(self, onnx_path): network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) - def _build_and_save(self, model, input_example): + def _build_and_save(self, model, input_example): """ If TRT engine is not ready, exports model to ONNX, builds TRT engine and saves serialized TRT engine to the disk. @@ -506,7 +506,7 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): profile = {} for id, val in input_example.items(): def add_profile(id, val): - sh = val.shape + sh = val.shape if len(sh) > 0: sh = sh[1:] profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] @@ -516,9 +516,9 @@ def add_profile(id, val): elif isinstance(val, torch.Tensor): add_profile(id, val) self.profiles = [profile] - + self.dynamic_axes = get_dynamic_axes(self.profiles) - + if len(self.dynamic_axes) > 0: export_args.update({"dynamic_axes": self.dynamic_axes}) From 4002d9dc5e179aadd3a4e454885203b7c7232013 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 15 Oct 2024 18:49:37 -0700 Subject: [PATCH 76/85] Merge fixed and style Signed-off-by: Boris Fomitchev --- monai/networks/nets/vista3d.py | 9 ++++----- monai/networks/trt_compiler.py | 18 +++++++++++------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 12f29367da..790b3df141 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -12,8 +12,9 @@ from __future__ import annotations import math -from typing import Any, Callable, Optional, Sequence, Tuple import time +from typing import Any, Callable, Optional, Sequence, Tuple + import numpy as np import torch import torch.nn.functional as F @@ -439,7 +440,7 @@ def forward( input_images, with_point=point_coords is not None, with_label=class_vector is not None ) torch.cuda.synchronize() - print("Encoder time : ", time.time()-t0, input_images.shape) + print("Encoder time : ", time.time() - t0, input_images.shape) # release memory input_images = None # type: ignore @@ -645,10 +646,8 @@ def forward(self, src: torch.Tensor, class_vector: torch.Tensor): # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension. masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d) masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1) - - return masks_embedding, class_embedding - return torch.cat(masks, 1), class_embedding + return masks_embedding, class_embedding class TwoWayTransformer(nn.Module): diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index 51fc02d812..60b8c41827 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -241,9 +241,10 @@ def remove_non_tensors(input_example, remove_constants=True): input_example.pop(key) return non_tensors + def unroll_input(input_names, input_example): # Simulate list/tuple unrolling during ONNX export - unrolled_input={} + unrolled_input = {} for name in input_names: val = input_example[name] if val is not None: @@ -254,6 +255,7 @@ def unroll_input(input_names, input_example): unrolled_input[name] = val return unrolled_input + class TrtCompiler: """ This class implements: @@ -454,7 +456,7 @@ def _onnx_to_trt(self, onnx_path): network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) - def _build_and_save(self, model, input_example): + def _build_and_save(self, model, input_example): """ If TRT engine is not ready, exports model to ONNX, builds TRT engine and saves serialized TRT engine to the disk. @@ -505,20 +507,22 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): raise ValueError("dynamic_batchsize has to have len ==3 ") profile = {} for id, val in input_example.items(): + def add_profile(id, val): - sh = val.shape + sh = val.shape if len(sh) > 0: - sh = sh[1:] - profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] + sh = sh[1:] + profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] + if isinstance(val, list | tuple): for i in range(len(val)): add_profile(f"{id}_{i}", val[i]) elif isinstance(val, torch.Tensor): add_profile(id, val) self.profiles = [profile] - + self.dynamic_axes = get_dynamic_axes(self.profiles) - + if len(self.dynamic_axes) > 0: export_args.update({"dynamic_axes": self.dynamic_axes}) From 43ea6a0f6246627247b948eba88d527bb7753577 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 21 Oct 2024 15:14:52 -0700 Subject: [PATCH 77/85] Added output_lists option Signed-off-by: Boris Fomitchev --- .../maisi/networks/controlnet_maisi.py | 2 +- monai/networks/trt_compiler.py | 58 ++++++++++++++++++- 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 384f7970e4..269086d971 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -119,7 +119,7 @@ def forward( down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] mid_block_res_sample *= conditioning_scale - return [*down_block_res_samples, mid_block_res_sample] + return down_block_res_samples, mid_block_res_sample def _prepare_time_and_class_embedding(self, x, timesteps, class_labels): # 1. time diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index 60b8c41827..146bf8203d 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -18,7 +18,7 @@ from collections import OrderedDict from pathlib import Path from types import MethodType -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Tuple, Union import torch @@ -256,6 +256,55 @@ def unroll_input(input_names, input_example): return unrolled_input +def parse_groups( + ret: List[torch.Tensor], output_lists: List[int] +) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]: + """ + Implements parsing of 'output_lists' arg of trt_compile(). + Args: + ret: ungrouped list of Tensors + + output_lists=[[group_n] | [], ...] + [] or group_n == 0 : next output from ret is a scalar + group_n > 0 : next output from ret is a list of group_n length + group_n == -1: next output is a dynamic list. This entry can be at any + position in output_lists, but can appear only once. + Returns: + Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists + + """ + groups = [] + cur = 0 + for l in range(len(output_lists)): + gl = output_lists[l] + assert len(gl) == 0 or len(gl) == 1 + if len(gl) == 0 or gl[0] == 0: + groups.append(ret[cur]) + cur = cur + 1 + elif gl[0] > 0: + groups.append(ret[cur : cur + gl[0]]) + cur = cur + gl[0] + elif gl[0] == -1: + rev_groups = [] + rcur = len(ret) + for rl in range(len(output_lists) - 1, l, -1): + rgl = output_lists[rl] + assert len(rgl) == 0 or len(rgl) == 1 + if len(rgl) == 0 or rgl[0] == 0: + rcur = rcur - 1 + rev_groups.append(ret[rcur]) + elif rgl[0] > 0: + rcur = rcur - rgl[0] + rev_groups.append(ret[rcur : rcur + rgl[0]]) + else: + raise ValueError("Two -1 lists in output") + groups.append(ret[cur:rcur]) + rev_groups.reverse() + groups.extend(rev_groups) + break + return tuple(groups) + + class TrtCompiler: """ This class implements: @@ -272,6 +321,7 @@ def __init__( method="onnx", input_names=None, output_names=None, + output_lists=None, export_args=None, build_args=None, input_profiles=None, @@ -295,6 +345,7 @@ def __init__( 'torch_trt' may not work for some nets. Also AMP must be turned off for it to work. input_names: Optional list of input names. If None, will be read from the function signature. output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. + output_lists: Optional list of output lists. export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. input_profiles: Optional list of profiles for TRT builder and ONNX export. @@ -319,6 +370,7 @@ def __init__( self.method = method self.return_dict = output_names is not None self.output_names = output_names or [] + self.output_lists = output_lists or [] self.profiles = input_profiles or [] self.dynamic_batchsize = dynamic_batchsize self.export_args = export_args or {} @@ -423,7 +475,9 @@ def forward(self, model, argv, kwargs): # if output_names is not None, return dictionary if not self.return_dict: ret = list(ret.values()) - if len(ret) == 1: + if self.output_lists: + ret = parse_groups(ret, self.output_lists) + elif len(ret) == 1: ret = ret[0] return ret except Exception as e: From c1791f6118d5427c9c572e1deadb52d518067b73 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 30 Oct 2024 16:11:11 -0700 Subject: [PATCH 78/85] Bugfix for multiple initialization Signed-off-by: Boris Fomitchev --- monai/networks/trt_compiler.py | 55 ++++++++++++---------------------- 1 file changed, 19 insertions(+), 36 deletions(-) diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index 146bf8203d..19c4281036 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -222,26 +222,6 @@ def infer(self, stream, use_cuda_graph=False): return self.tensors -def remove_non_tensors(input_example, remove_constants=True): - # - # TODO : see if we can instantiate wrappers to handle non-default non-tensors - # - non_tensors = {} - for k, v in input_example.items(): - if v is None: - non_tensors[k] = v - elif not torch.is_tensor(v): - if remove_constants: - non_tensors[k] = v - else: - input_example[k] = torch.tensor(v) - - for key in non_tensors.keys(): - # print(f"Removing non-tensor input: {key})") - input_example.pop(key) - return non_tensors - - def unroll_input(input_names, input_example): # Simulate list/tuple unrolling during ONNX export unrolled_input = {} @@ -261,14 +241,17 @@ def parse_groups( ) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]: """ Implements parsing of 'output_lists' arg of trt_compile(). - Args: - ret: ungrouped list of Tensors - output_lists=[[group_n] | [], ...] - [] or group_n == 0 : next output from ret is a scalar - group_n > 0 : next output from ret is a list of group_n length - group_n == -1: next output is a dynamic list. This entry can be at any - position in output_lists, but can appear only once. + Args: + ret: plain list of Tensors + + output_lists: list of output group sizes: to form some Lists/Tuples out of 'ret' List, this will be a list + of group dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list. + Format: [[group_n] | [], ...] + [] or group_n == 0 : next output from ret is a scalar + group_n > 0 : next output from ret is a list of group_n length + group_n == -1: next output is a dynamic list. This entry can be at any + position in output_lists, but can appear only once. Returns: Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists @@ -345,7 +328,8 @@ def __init__( 'torch_trt' may not work for some nets. Also AMP must be turned off for it to work. input_names: Optional list of input names. If None, will be read from the function signature. output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. - output_lists: Optional list of output lists. + output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list + of their dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list. export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. input_profiles: Optional list of profiles for TRT builder and ONNX export. @@ -467,6 +451,7 @@ def forward(self, model, argv, kwargs): with lock_sm: device = torch.cuda.current_device() stream = torch.cuda.Stream(device=device) + breakpoint() self.engine.set_inputs(unroll_input(self.input_names, args), stream.cuda_stream) self.engine.allocate_buffers(device=device) # Need this to synchronize with Torch stream @@ -522,9 +507,6 @@ def _build_and_save(self, model, input_example): return export_args = self.export_args - - # remove_non_tensors(input_example) - engine_bytes = None add_casts_around_norms(model) @@ -547,7 +529,7 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): engine_bytes = torch_tensorrt.convert_method_to_trt_engine( ir_model, "forward", - inputs=tt_inputs, + arg_inputs=tt_inputs, ir="torchscript", enabled_precisions=enabled_precisions, **export_args, @@ -585,7 +567,8 @@ def add_profile(id, val): unrolled_input = unroll_input(self.input_names, input_example) onnx_path = str(Path(tmpdir) / "model.onnx") self.logger.info( - f"Exporting to {onnx_path}:\nunrolled_inputs={list(unrolled_input.keys())}\noutput_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}" + f"Exporting to {onnx_path}:\nunrolled_inputs={list(unrolled_input.keys())}\n" + + f"output_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}" ) convert_to_onnx( model, @@ -655,9 +638,9 @@ def trt_compile( def wrap(model, path): if not hasattr(model, "_trt_compiler"): model.orig_forward = model.forward - wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) - model._trt_compiler = wrapper - model.forward = MethodType(trt_forward, model) + wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) + model._trt_compiler = wrapper + model.forward = MethodType(trt_forward, model) def find_sub(parent, submodule): idx = submodule.find(".") From 14912d9306b2a99c4bedb4559e304bd0710dfc04 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 30 Oct 2024 23:40:21 +0000 Subject: [PATCH 79/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/trt_compiler.py | 4 ++-- tests/test_trt_compile.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index db9c0bf780..fb9b9dcdd1 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -328,8 +328,8 @@ def __init__( 'torch_trt' may not work for some nets. Also AMP must be turned off for it to work. input_names: Optional list of input names. If None, will be read from the function signature. output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. - output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list - of their dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list. + output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list + of their dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list. export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. input_profiles: Optional list of profiles for TRT builder and ONNX export. diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 3cb9af621d..8007148888 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -21,7 +21,7 @@ from monai.networks import trt_compile from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 from monai.utils import min_version, optional_import -from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows +from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) torch_tensorrt, torch_trt_imported = optional_import("torch_tensorrt") From 30b8bcfee7ec6d0fdfc03ec0651ee02fee0fa62b Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 30 Oct 2024 23:29:58 -0700 Subject: [PATCH 80/85] Adding Torch patch Signed-off-by: Boris Fomitchev --- Dockerfile | 6 ++++++ monai/torch.patch | 10 ++++++++++ 2 files changed, 16 insertions(+) create mode 100644 monai/torch.patch diff --git a/Dockerfile b/Dockerfile index 5fcfcf274d..d538fd3145 100644 --- a/Dockerfile +++ b/Dockerfile @@ -41,6 +41,10 @@ RUN cp /tmp/requirements.txt /tmp/req.bak \ COPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md versioneer.py setup.py setup.cfg runtests.sh MANIFEST.in ./ COPY tests ./tests COPY monai ./monai + +# TODO: remove this line and torch.patch for 24.11 +RUN patch -R -d /usr/local/lib/python3.10/dist-packages/torch/onnx/ < ./monai/torch.patch + RUN BUILD_MONAI=1 FORCE_CUDA=1 python setup.py develop \ && rm -rf build __pycache__ @@ -57,4 +61,6 @@ RUN apt-get update \ # append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations ENV PATH=${PATH}:/opt/tools ENV POLYGRAPHY_AUTOINSTALL_DEPS=1 + + WORKDIR /opt/monai diff --git a/monai/torch.patch b/monai/torch.patch new file mode 100644 index 0000000000..607c585b70 --- /dev/null +++ b/monai/torch.patch @@ -0,0 +1,10 @@ +--- /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py 2024-10-31 06:09:21.139938791 +0000 ++++ /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.bak 2024-10-31 06:01:50.207462739 +0000 +@@ -150,6 +150,7 @@ + ), "is_causal and attn_mask cannot be set at the same time" + assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" + ++ scale = symbolic_helper._maybe_get_const(scale, "f") + if symbolic_helper._is_none(scale): + scale = _attention_scale(g, query) + From 8baaa74e28b0b5c6d0383ff7da5ff7fe9e4b9b85 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 06:32:34 +0000 Subject: [PATCH 81/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/torch.patch | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/torch.patch b/monai/torch.patch index 607c585b70..e53980968b 100644 --- a/monai/torch.patch +++ b/monai/torch.patch @@ -3,8 +3,7 @@ @@ -150,6 +150,7 @@ ), "is_causal and attn_mask cannot be set at the same time" assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" - + + scale = symbolic_helper._maybe_get_const(scale, "f") if symbolic_helper._is_none(scale): scale = _attention_scale(g, query) - From 214def9f7a5d12b8aca15f6380c4175d6a416f68 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 31 Oct 2024 17:25:16 -0700 Subject: [PATCH 82/85] Fixing torch_trt compile and test case Signed-off-by: Boris Fomitchev --- monai/networks/trt_compiler.py | 10 ++-------- tests/test_trt_compile.py | 31 ++----------------------------- 2 files changed, 4 insertions(+), 37 deletions(-) diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index fb9b9dcdd1..c5cc43df58 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -23,7 +23,7 @@ import torch from monai.apps.utils import get_logger -from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript, get_profile_shapes +from monai.networks.utils import add_casts_around_norms, convert_to_onnx, get_profile_shapes from monai.utils.module import optional_import polygraphy, polygraphy_imported = optional_import("polygraphy") @@ -517,7 +517,6 @@ def _build_and_save(self, model, input_example): elif self.precision == "bf16": enabled_precisions.append(torch.bfloat16) inputs = list(input_example.values()) - ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True) def get_torch_trt_input(input_shape, dynamic_batchsize): min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) @@ -527,12 +526,7 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs] engine_bytes = torch_tensorrt.convert_method_to_trt_engine( - ir_model, - "forward", - arg_inputs=tt_inputs, - ir="torchscript", - enabled_precisions=enabled_precisions, - **export_args, + model, "forward", arg_inputs=tt_inputs, enabled_precisions=enabled_precisions, **export_args ) else: dbs = self.dynamic_batchsize diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 8007148888..2a15d5e697 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -68,33 +68,6 @@ def test_handler(self): net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda")) self.assertIsNotNone(net1._trt_compiler.engine) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_unet_value(self, precision): - model = UNet( - spatial_dims=3, - in_channels=1, - out_channels=2, - channels=(2, 2, 4, 8, 4), - strides=(2, 2, 2, 2), - num_res_units=2, - norm="batch", - ).cuda() - with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: - model.eval() - input_example = torch.randn(2, 1, 96, 96, 96).cuda() - output_example = model(input_example) - args: dict = {"builder_optimization_level": 1} - trt_compile( - model, - f"{tmpdir}/test_unet_trt_compile", - args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]}, - ) - self.assertIsNone(model._trt_compiler.engine) - trt_output = model(input_example) - # Check that lazy TRT build succeeded - self.assertIsNotNone(model._trt_compiler.engine) - torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @unittest.skipUnless(has_sam, "Requires SAM installation") def test_cell_sam_wrapper_value(self, precision): @@ -106,7 +79,7 @@ def test_cell_sam_wrapper_value(self, precision): trt_compile( model, f"{tmpdir}/test_cell_sam_wrapper_trt_compile", - args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, + args={"precision": precision}, ) self.assertIsNone(model._trt_compiler.engine) trt_output = model(input_example) @@ -124,7 +97,7 @@ def test_vista3d(self, precision): model = trt_compile( model, f"{tmpdir}/test_vista3d_trt_compile", - args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, + args={"precision": precision, "dynamic_batchsize": [1, 2, 4]}, submodule=["image_encoder.encoder", "class_head"], ) self.assertIsNotNone(model.image_encoder.encoder._trt_compiler) From 2452c975e4e4f3e7d9177e57c1009cf155daf04a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Nov 2024 00:25:55 +0000 Subject: [PATCH 83/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_trt_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 2a15d5e697..2bd2a5cd89 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -19,7 +19,7 @@ from monai.handlers import TrtHandler from monai.networks import trt_compile -from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 +from monai.networks.nets import cell_sam_wrapper, vista3d132 from monai.utils import min_version, optional_import from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows From d07a57f5ed8df4de9d26c24726dbf7c8bdfe88c2 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 1 Nov 2024 18:12:24 -0700 Subject: [PATCH 84/85] Added rename table for TRT engine, test for output lists Signed-off-by: Boris Fomitchev --- monai/networks/trt_compiler.py | 25 +++++++++++++++++------ monai/networks/utils.py | 2 +- tests/test_trt_compile.py | 37 +++++++++++++++++++++++++++++----- 3 files changed, 52 insertions(+), 12 deletions(-) diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index c5cc43df58..3eddd85664 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -167,7 +167,7 @@ def set_inputs(self, feed_dict, stream): def try_set_inputs(): for binding in self.input_names: - t = feed_dict[binding] + t = feed_dict.get(self.input_table[binding], None) if t is not None: t = t.contiguous() shape = t.shape @@ -222,6 +222,10 @@ def infer(self, stream, use_cuda_graph=False): return self.tensors +def make_tensor(d): + return d if isinstance(d, torch.Tensor) else torch.tensor(d).cuda() + + def unroll_input(input_names, input_example): # Simulate list/tuple unrolling during ONNX export unrolled_input = {} @@ -230,9 +234,9 @@ def unroll_input(input_names, input_example): if val is not None: if isinstance(val, list | tuple): for i in range(len(val)): - unrolled_input[f"{name}_{i}"] = val[i] + unrolled_input[f"{name}_{i}"] = make_tensor(val[i]) else: - unrolled_input[name] = val + unrolled_input[name] = make_tensor(val) return unrolled_input @@ -375,8 +379,8 @@ def __init__( for i in range(len(self.argspec.defaults)): d = self.argspec.defaults[-i - 1] if d is not None: - d = torch.tensor(d).cuda() - self.defaults[self.argspec.args[-i - 1]] = d + d = make_tensor(d) + self.defaults[self.argspec.args[-i - 1]] = d self.input_names = input_names self.old_forward = model.forward @@ -398,7 +402,16 @@ def _load_engine(self): """ try: self.engine = TRTEngine(self.plan_path, self.logger) - self.logger.info(f"Engine loaded, inputs:{self.engine.input_names}") + # Make sure we have names correct + input_table = {} + for name in self.engine.input_names: + if name.startswith("__") and name not in self.input_names: + orig_name = name[2:] + else: + orig_name = name + input_table[name] = orig_name + self.engine.input_table = input_table + self.logger.info(f"Engine loaded, inputs:{self.engine.input_table}") except Exception as e: self.logger.info(f"Exception while loading the engine:\n{e}") diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 7b1775923a..05627f9c00 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -703,7 +703,7 @@ def convert_to_onnx( onnx_inputs, f=f, input_names=input_names, - output_names=output_names, + output_names=output_names or None, dynamic_axes=dynamic_axes, opset_version=opset_version, do_constant_folding=do_constant_folding, diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 2a15d5e697..9847326388 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -13,6 +13,7 @@ import tempfile import unittest +from typing import List import torch from parameterized import parameterized @@ -32,6 +33,19 @@ TEST_CASE_2 = ["fp16"] +class ListAdd(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: List[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = float(0.1)): + y1 = y.clone() + x1 = x.copy() + z1 = z + y + for xi in x: + y1 = y1 + xi + bs + return x1, [y1, z1], y1 + z1 + + @skip_if_windows @skip_if_no_cuda @skip_if_quick @@ -68,6 +82,23 @@ def test_handler(self): net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda")) self.assertIsNotNone(net1._trt_compiler.engine) + def test_lists(self): + model = ListAdd().cuda() + + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + args = {"output_lists": [[-1], [2], []], "export_args": {"dynamo": False, "verbose": True}} + x = torch.randn(1, 16).to("cuda") + y = torch.randn(1, 16).to("cuda") + z = torch.randn(1, 16).to("cuda") + input_example = ([x, y, z], y.clone(), z.clone()) + output_example = model(*input_example) + trt_compile(model, f"{tmpdir}/test_lists", args=args) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(*input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @unittest.skipUnless(has_sam, "Requires SAM installation") def test_cell_sam_wrapper_value(self, precision): @@ -76,11 +107,7 @@ def test_cell_sam_wrapper_value(self, precision): model.eval() input_example = torch.randn(1, 3, 128, 128).to("cuda") output_example = model(input_example) - trt_compile( - model, - f"{tmpdir}/test_cell_sam_wrapper_trt_compile", - args={"precision": precision}, - ) + trt_compile(model, f"{tmpdir}/test_cell_sam_wrapper_trt_compile", args={"precision": precision}) self.assertIsNone(model._trt_compiler.engine) trt_output = model(input_example) # Check that lazy TRT build succeeded From 6eeed4b25735b8e52fc8fa142aed54155b6948a4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 2 Nov 2024 01:13:10 +0000 Subject: [PATCH 85/85] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_trt_compile.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 98036a3c61..8231387601 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -13,7 +13,6 @@ import tempfile import unittest -from typing import List import torch from parameterized import parameterized @@ -37,7 +36,7 @@ class ListAdd(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, x: List[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = float(0.1)): + def forward(self, x: list[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = float(0.1)): y1 = y.clone() x1 = x.copy() z1 = z + y