diff --git a/Dockerfile b/Dockerfile index e45932c6bb..d538fd3145 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. -ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.08-py3 +ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.10-py3 FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" @@ -41,6 +41,10 @@ RUN cp /tmp/requirements.txt /tmp/req.bak \ COPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md versioneer.py setup.py setup.cfg runtests.sh MANIFEST.in ./ COPY tests ./tests COPY monai ./monai + +# TODO: remove this line and torch.patch for 24.11 +RUN patch -R -d /usr/local/lib/python3.10/dist-packages/torch/onnx/ < ./monai/torch.patch + RUN BUILD_MONAI=1 FORCE_CUDA=1 python setup.py develop \ && rm -rf build __pycache__ @@ -57,4 +61,6 @@ RUN apt-get update \ # append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations ENV PATH=${PATH}:/opt/tools ENV POLYGRAPHY_AUTOINSTALL_DEPS=1 + + WORKDIR /opt/monai diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 6313b7812d..6ecb664b85 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -641,7 +641,6 @@ def forward(self, src: torch.Tensor, class_vector: torch.Tensor): # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension. masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d) masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1) - return masks_embedding, class_embedding diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index a360f63dbd..3eddd85664 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -18,12 +18,12 @@ from collections import OrderedDict from pathlib import Path from types import MethodType -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Tuple, Union import torch from monai.apps.utils import get_logger -from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript, get_profile_shapes +from monai.networks.utils import add_casts_around_norms, convert_to_onnx, get_profile_shapes from monai.utils.module import optional_import polygraphy, polygraphy_imported = optional_import("polygraphy") @@ -134,6 +134,9 @@ def __init__(self, plan_path, logger=None): self.output_names.append(binding) dtype = dtype_dict[self.engine.get_tensor_dtype(binding)] self.dtypes.append(dtype) + self.logger.info( + f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}" + ) def allocate_buffers(self, device): """ @@ -163,7 +166,8 @@ def set_inputs(self, feed_dict, stream): last_profile = self.cur_profile def try_set_inputs(): - for binding, t in feed_dict.items(): + for binding in self.input_names: + t = feed_dict.get(self.input_table[binding], None) if t is not None: t = t.contiguous() shape = t.shape @@ -180,7 +184,8 @@ def try_set_inputs(): raise self.cur_profile = next_profile ctx.set_optimization_profile_async(self.cur_profile, stream) - + except Exception: + raise left = ctx.infer_shapes() assert len(left) == 0 @@ -217,6 +222,76 @@ def infer(self, stream, use_cuda_graph=False): return self.tensors +def make_tensor(d): + return d if isinstance(d, torch.Tensor) else torch.tensor(d).cuda() + + +def unroll_input(input_names, input_example): + # Simulate list/tuple unrolling during ONNX export + unrolled_input = {} + for name in input_names: + val = input_example[name] + if val is not None: + if isinstance(val, list | tuple): + for i in range(len(val)): + unrolled_input[f"{name}_{i}"] = make_tensor(val[i]) + else: + unrolled_input[name] = make_tensor(val) + return unrolled_input + + +def parse_groups( + ret: List[torch.Tensor], output_lists: List[int] +) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]: + """ + Implements parsing of 'output_lists' arg of trt_compile(). + + Args: + ret: plain list of Tensors + + output_lists: list of output group sizes: to form some Lists/Tuples out of 'ret' List, this will be a list + of group dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list. + Format: [[group_n] | [], ...] + [] or group_n == 0 : next output from ret is a scalar + group_n > 0 : next output from ret is a list of group_n length + group_n == -1: next output is a dynamic list. This entry can be at any + position in output_lists, but can appear only once. + Returns: + Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists + + """ + groups = [] + cur = 0 + for l in range(len(output_lists)): + gl = output_lists[l] + assert len(gl) == 0 or len(gl) == 1 + if len(gl) == 0 or gl[0] == 0: + groups.append(ret[cur]) + cur = cur + 1 + elif gl[0] > 0: + groups.append(ret[cur : cur + gl[0]]) + cur = cur + gl[0] + elif gl[0] == -1: + rev_groups = [] + rcur = len(ret) + for rl in range(len(output_lists) - 1, l, -1): + rgl = output_lists[rl] + assert len(rgl) == 0 or len(rgl) == 1 + if len(rgl) == 0 or rgl[0] == 0: + rcur = rcur - 1 + rev_groups.append(ret[rcur]) + elif rgl[0] > 0: + rcur = rcur - rgl[0] + rev_groups.append(ret[rcur : rcur + rgl[0]]) + else: + raise ValueError("Two -1 lists in output") + groups.append(ret[cur:rcur]) + rev_groups.reverse() + groups.extend(rev_groups) + break + return tuple(groups) + + class TrtCompiler: """ This class implements: @@ -233,6 +308,7 @@ def __init__( method="onnx", input_names=None, output_names=None, + output_lists=None, export_args=None, build_args=None, input_profiles=None, @@ -240,6 +316,7 @@ def __init__( use_cuda_graph=False, timestamp=None, fallback=False, + forward_override=None, logger=None, ): """ @@ -255,6 +332,8 @@ def __init__( 'torch_trt' may not work for some nets. Also AMP must be turned off for it to work. input_names: Optional list of input names. If None, will be read from the function signature. output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. + output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list + of their dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list. export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. input_profiles: Optional list of profiles for TRT builder and ONNX export. @@ -279,6 +358,7 @@ def __init__( self.method = method self.return_dict = output_names is not None self.output_names = output_names or [] + self.output_lists = output_lists or [] self.profiles = input_profiles or [] self.dynamic_batchsize = dynamic_batchsize self.export_args = export_args or {} @@ -289,11 +369,19 @@ def __init__( self.disabled = False self.logger = logger or get_logger("monai.networks.trt_compiler") + self.argspec = inspect.getfullargspec(model.forward) # Normally we read input_names from forward() but can be overridden if input_names is None: - argspec = inspect.getfullargspec(model.forward) - input_names = argspec.args[1:] + input_names = self.argspec.args[1:] + self.defaults = {} + if self.argspec.defaults is not None: + for i in range(len(self.argspec.defaults)): + d = self.argspec.defaults[-i - 1] + if d is not None: + d = make_tensor(d) + self.defaults[self.argspec.args[-i - 1]] = d + self.input_names = input_names self.old_forward = model.forward @@ -314,9 +402,18 @@ def _load_engine(self): """ try: self.engine = TRTEngine(self.plan_path, self.logger) - self.input_names = self.engine.input_names + # Make sure we have names correct + input_table = {} + for name in self.engine.input_names: + if name.startswith("__") and name not in self.input_names: + orig_name = name[2:] + else: + orig_name = name + input_table[name] = orig_name + self.engine.input_table = input_table + self.logger.info(f"Engine loaded, inputs:{self.engine.input_table}") except Exception as e: - self.logger.debug(f"Exception while loading the engine:\n{e}") + self.logger.info(f"Exception while loading the engine:\n{e}") def forward(self, model, argv, kwargs): """ @@ -329,6 +426,11 @@ def forward(self, model, argv, kwargs): Returns: Passing through wrapped module's forward() return value(s) """ + args = self.defaults + args.update(kwargs) + if len(argv) > 0: + args.update(self._inputs_to_dict(argv)) + if self.engine is None and not self.disabled: # Restore original forward for export new_forward = model.forward @@ -336,11 +438,10 @@ def forward(self, model, argv, kwargs): try: self._load_engine() if self.engine is None: - build_args = kwargs.copy() - if len(argv) > 0: - build_args.update(self._inputs_to_dict(argv)) - self._build_and_save(model, build_args) - # This will reassign input_names from the engine + build_args = args.copy() + with torch.no_grad(): + self._build_and_save(model, build_args) + # This will reassign input_names from the engine self._load_engine() assert self.engine is not None except Exception as e: @@ -355,19 +456,16 @@ def forward(self, model, argv, kwargs): del param # Call empty_cache to release GPU memory torch.cuda.empty_cache() + # restore TRT hook model.forward = new_forward # Run the engine try: - if len(argv) > 0: - kwargs.update(self._inputs_to_dict(argv)) - argv = () - if self.engine is not None: # forward_trt is not thread safe as we do not use per-thread execution contexts with lock_sm: device = torch.cuda.current_device() stream = torch.cuda.Stream(device=device) - self.engine.set_inputs(kwargs, stream.cuda_stream) + self.engine.set_inputs(unroll_input(self.input_names, args), stream.cuda_stream) self.engine.allocate_buffers(device=device) # Need this to synchronize with Torch stream stream.wait_stream(torch.cuda.current_stream()) @@ -375,11 +473,13 @@ def forward(self, model, argv, kwargs): # if output_names is not None, return dictionary if not self.return_dict: ret = list(ret.values()) - if len(ret) == 1: + if self.output_lists: + ret = parse_groups(ret, self.output_lists) + elif len(ret) == 1: ret = ret[0] return ret except Exception as e: - if model is not None: + if self.fallback: self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...") else: raise e @@ -391,16 +491,11 @@ def _onnx_to_trt(self, onnx_path): """ profiles = [] - if self.profiles: - for input_profile in self.profiles: - if isinstance(input_profile, Profile): - profiles.append(input_profile) - else: - p = Profile() - for name, dims in input_profile.items(): - assert len(dims) == 3 - p.add(name, min=dims[0], opt=dims[1], max=dims[2]) - profiles.append(p) + for profile in self.profiles: + p = Profile() + for id, val in profile.items(): + p.add(id, min=val[0], opt=val[1], max=val[2]) + profiles.append(p) build_args = self.build_args.copy() build_args["tf32"] = self.precision != "fp32" @@ -425,7 +520,7 @@ def _build_and_save(self, model, input_example): return export_args = self.export_args - + engine_bytes = None add_casts_around_norms(model) if self.method == "torch_trt": @@ -435,7 +530,6 @@ def _build_and_save(self, model, input_example): elif self.precision == "bf16": enabled_precisions.append(torch.bfloat16) inputs = list(input_example.values()) - ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True) def get_torch_trt_input(input_shape, dynamic_batchsize): min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) @@ -445,12 +539,7 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs] engine_bytes = torch_tensorrt.convert_method_to_trt_engine( - ir_model, - "forward", - inputs=tt_inputs, - ir="torchscript", - enabled_precisions=enabled_precisions, - **export_args, + model, "forward", arg_inputs=tt_inputs, enabled_precisions=enabled_precisions, **export_args ) else: dbs = self.dynamic_batchsize @@ -459,33 +548,47 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!") if len(dbs) != 3: raise ValueError("dynamic_batchsize has to have len ==3 ") - profiles = {} + profile = {} for id, val in input_example.items(): - sh = val.shape[1:] - profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] - self.profiles = [profiles] - if len(self.profiles) > 0: - export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) + def add_profile(id, val): + sh = val.shape + if len(sh) > 0: + sh = sh[1:] + profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] + + if isinstance(val, list | tuple): + for i in range(len(val)): + add_profile(f"{id}_{i}", val[i]) + elif isinstance(val, torch.Tensor): + add_profile(id, val) + self.profiles = [profile] + + self.dynamic_axes = get_dynamic_axes(self.profiles) + + if len(self.dynamic_axes) > 0: + export_args.update({"dynamic_axes": self.dynamic_axes}) # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: - onnx_path = Path(tmpdir) / "model.onnx" + unrolled_input = unroll_input(self.input_names, input_example) + onnx_path = str(Path(tmpdir) / "model.onnx") self.logger.info( - f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}" + f"Exporting to {onnx_path}:\nunrolled_inputs={list(unrolled_input.keys())}\n" + + f"output_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}" ) convert_to_onnx( model, input_example, - filename=str(onnx_path), - input_names=self.input_names, + filename=onnx_path, + input_names=list(unrolled_input.keys()), output_names=self.output_names, **export_args, ) self.logger.info("Export to ONNX successful.") - engine_bytes = self._onnx_to_trt(str(onnx_path)) - - open(self.plan_path, "wb").write(engine_bytes) + engine_bytes = self._onnx_to_trt(onnx_path) + if engine_bytes: + open(self.plan_path, "wb").write(engine_bytes) def trt_forward(self, *argv, **kwargs): @@ -540,9 +643,11 @@ def trt_compile( args["timestamp"] = timestamp def wrap(model, path): - wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) - model._trt_compiler = wrapper - model.forward = MethodType(trt_forward, model) + if not hasattr(model, "_trt_compiler"): + model.orig_forward = model.forward + wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) + model._trt_compiler = wrapper + model.forward = MethodType(trt_forward, model) def find_sub(parent, submodule): idx = submodule.find(".") diff --git a/monai/networks/utils.py b/monai/networks/utils.py index cfad0364c3..05627f9c00 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -632,7 +632,6 @@ def convert_to_onnx( use_trace: bool = True, do_constant_folding: bool = True, constant_size_threshold: int = 16 * 1024 * 1024 * 1024, - dynamo=False, **kwargs, ): """ @@ -673,6 +672,9 @@ def convert_to_onnx( # let torch.onnx.export to trace the model. mode_to_export = model torch_versioned_kwargs = kwargs + if "dynamo" in kwargs and kwargs["dynamo"] and verify: + torch_versioned_kwargs["verify"] = verify + verify = False else: if not pytorch_after(1, 10): if "example_outputs" not in kwargs: @@ -695,13 +697,13 @@ def convert_to_onnx( f = temp_file.name else: f = filename - + print(f"torch_versioned_kwargs={torch_versioned_kwargs}") torch.onnx.export( mode_to_export, onnx_inputs, f=f, input_names=input_names, - output_names=output_names, + output_names=output_names or None, dynamic_axes=dynamic_axes, opset_version=opset_version, do_constant_folding=do_constant_folding, @@ -715,6 +717,9 @@ def convert_to_onnx( fold_constants(onnx_model, size_threshold=constant_size_threshold) if verify: + if isinstance(inputs, dict): + inputs = list(inputs.values()) + if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/monai/torch.patch b/monai/torch.patch new file mode 100644 index 0000000000..e53980968b --- /dev/null +++ b/monai/torch.patch @@ -0,0 +1,9 @@ +--- /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py 2024-10-31 06:09:21.139938791 +0000 ++++ /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.bak 2024-10-31 06:01:50.207462739 +0000 +@@ -150,6 +150,7 @@ + ), "is_causal and attn_mask cannot be set at the same time" + assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" + ++ scale = symbolic_helper._maybe_get_const(scale, "f") + if symbolic_helper._is_none(scale): + scale = _attention_scale(g, query) diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 5a56f0e4a2..8231387601 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -19,11 +19,12 @@ from monai.handlers import TrtHandler from monai.networks import trt_compile -from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 +from monai.networks.nets import cell_sam_wrapper, vista3d132 from monai.utils import min_version, optional_import -from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows +from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) +torch_tensorrt, torch_trt_imported = optional_import("torch_tensorrt") polygraphy, polygraphy_imported = optional_import("polygraphy") build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") @@ -31,6 +32,19 @@ TEST_CASE_2 = ["fp16"] +class ListAdd(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: list[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = float(0.1)): + y1 = y.clone() + x1 = x.copy() + z1 = z + y + for xi in x: + y1 = y1 + xi + bs + return x1, [y1, z1], y1 + z1 + + @skip_if_windows @skip_if_no_cuda @skip_if_quick @@ -46,7 +60,7 @@ def tearDown(self): if current_device != self.gpu_device: torch.cuda.set_device(self.gpu_device) - @SkipIfAtLeastPyTorchVersion((2, 5, 0)) + @unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required") def test_handler(self): from ignite.engine import Engine @@ -67,29 +81,19 @@ def test_handler(self): net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda")) self.assertIsNotNone(net1._trt_compiler.engine) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_unet_value(self, precision): - model = UNet( - spatial_dims=3, - in_channels=1, - out_channels=2, - channels=(2, 2, 4, 8, 4), - strides=(2, 2, 2, 2), - num_res_units=2, - norm="batch", - ).cuda() + def test_lists(self): + model = ListAdd().cuda() + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: - model.eval() - input_example = torch.randn(2, 1, 96, 96, 96).cuda() - output_example = model(input_example) - args: dict = {"builder_optimization_level": 1} - trt_compile( - model, - f"{tmpdir}/test_unet_trt_compile", - args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]}, - ) + args = {"output_lists": [[-1], [2], []], "export_args": {"dynamo": False, "verbose": True}} + x = torch.randn(1, 16).to("cuda") + y = torch.randn(1, 16).to("cuda") + z = torch.randn(1, 16).to("cuda") + input_example = ([x, y, z], y.clone(), z.clone()) + output_example = model(*input_example) + trt_compile(model, f"{tmpdir}/test_lists", args=args) self.assertIsNone(model._trt_compiler.engine) - trt_output = model(input_example) + trt_output = model(*input_example) # Check that lazy TRT build succeeded self.assertIsNotNone(model._trt_compiler.engine) torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) @@ -102,11 +106,7 @@ def test_cell_sam_wrapper_value(self, precision): model.eval() input_example = torch.randn(1, 3, 128, 128).to("cuda") output_example = model(input_example) - trt_compile( - model, - f"{tmpdir}/test_cell_sam_wrapper_trt_compile", - args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, - ) + trt_compile(model, f"{tmpdir}/test_cell_sam_wrapper_trt_compile", args={"precision": precision}) self.assertIsNone(model._trt_compiler.engine) trt_output = model(input_example) # Check that lazy TRT build succeeded @@ -123,7 +123,7 @@ def test_vista3d(self, precision): model = trt_compile( model, f"{tmpdir}/test_vista3d_trt_compile", - args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, + args={"precision": precision, "dynamic_batchsize": [1, 2, 4]}, submodule=["image_encoder.encoder", "class_head"], ) self.assertIsNotNone(model.image_encoder.encoder._trt_compiler)