TRT support for MAISI #8153

Open: borisfom wants to merge 121 commits into base: dev. Changes shown from 108 commits.

Commits (121)
054a2b8
Added TRTWrapper
borisfom Aug 5, 2024
3ab9c83
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
4ec1d3b
Merge branch 'dev' into trt-wrappers
KumoLiu Aug 5, 2024
fe71030
Addressing code review comments, adding docstrings, cleanup
borisfom Aug 5, 2024
6a9727f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
29d9725
Added TRT 10.3RC to Dockerfile
borisfom Aug 6, 2024
5b8b4f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
f31d6dd
Workaround for format check
borisfom Aug 6, 2024
9303c32
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
c1d0b19
More format check workarounds
borisfom Aug 6, 2024
63c4b70
More format check workarounds
borisfom Aug 6, 2024
9a3d6a6
More format check workarounds
borisfom Aug 6, 2024
8bf0300
Using optional exports for trt_utils
borisfom Aug 6, 2024
c03e49b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
39c94c2
Fixing lint errors
borisfom Aug 6, 2024
35dffcc
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 6, 2024
9d867a7
Format fixed
borisfom Aug 6, 2024
6e2733a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
848a42d
Fixing flake errors
borisfom Aug 6, 2024
9ade6af
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 6, 2024
cf2c3b1
Fixing CI
borisfom Aug 6, 2024
e8b51f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
ddb5bc8
Fixed mypy, Engine refactor
borisfom Aug 7, 2024
79014d7
Merge branch 'dev' into trt-wrappers
yiheng-wang-nv Aug 7, 2024
511081f
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 7, 2024
b188237
Merged cast_utils, copyrights fixed.
borisfom Aug 8, 2024
60cdd74
Added unit test
borisfom Aug 8, 2024
778a44a
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 9, 2024
0ab5d26
TRTWrapper moved to networks
borisfom Aug 9, 2024
a948bfb
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 9, 2024
3a72c76
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 9, 2024
7d449f5
Refactored TRTWrapper args
borisfom Aug 10, 2024
6846fd4
Added docstring for precision
borisfom Aug 10, 2024
d598590
Fixed comments, reordered args
borisfom Aug 11, 2024
9109d3f
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 12, 2024
517c111
Reduced test assert accuracy
borisfom Aug 12, 2024
4739756
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 14, 2024
ed0d93d
Addressing code review comments
borisfom Aug 14, 2024
2ec8e53
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 15, 2024
fdcf118
Added Torch-TRT option, cleaned up engine save method
borisfom Aug 15, 2024
1009dc5
Added trt_wrap adapter
borisfom Aug 16, 2024
763f769
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 16, 2024
fd679c0
Refined trt_wrap
borisfom Aug 16, 2024
dc13b52
Used tempdir for ONNX
borisfom Aug 17, 2024
779de92
Refactored trt wrapper, added trt handler
borisfom Aug 18, 2024
6504dc9
Adjusted refactor for use in config
borisfom Aug 18, 2024
c1be72c
Added fold constant threshold param
borisfom Aug 20, 2024
0f16b8b
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 20, 2024
5c495b6
Logger refactoring
borisfom Aug 20, 2024
5d1ebc2
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 20, 2024
48b85ce
Addressing code review comments
borisfom Aug 22, 2024
1244c49
Added multiple submodules option to trt_wrap
borisfom Aug 22, 2024
a603f13
Added polygraphy to more places, torch-tensorrt option debugging
borisfom Aug 23, 2024
f5be0cc
Renamed trt_wrap -> trt_compile
borisfom Aug 23, 2024
b96ebb4
Reformatted for CI
borisfom Aug 23, 2024
73be701
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 23, 2024
85140e2
Fixed alias issue
borisfom Aug 23, 2024
fa4c182
Fixed base in Dockerfile
borisfom Aug 23, 2024
78a3ef3
Fixed CI test failures
borisfom Aug 23, 2024
267c125
Addressed code review comments
borisfom Aug 23, 2024
9adc035
Added dictionary return option
borisfom Aug 26, 2024
a017fcd
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 26, 2024
7f1c0c1
Fixed return_dict issue
borisfom Aug 26, 2024
a242a64
Implemented https://github.com/Project-MONAI/MONAI/issues/8044
borisfom Aug 26, 2024
5afc912
Generalizing merge logic, adding test case and doc
borisfom Aug 27, 2024
ceff018
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
e294968
Addressing code review comments
borisfom Aug 27, 2024
55cf7fa
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 27, 2024
b6d9179
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
652448a
doc build fixed
borisfom Aug 28, 2024
b793eb2
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 28, 2024
5c4f63a
Fixed formatting
borisfom Aug 28, 2024
dd91183
Fixed formatting
borisfom Aug 28, 2024
c41cb5a
Updated base container to 24.08
borisfom Aug 28, 2024
7e440fc
Renaming trt_wrapper -> trt_compiler, adding TRT handler test
borisfom Aug 28, 2024
329d024
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2024
84de860
fixing CI error
borisfom Aug 28, 2024
875d1a8
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 28, 2024
b84cec4
Fixing min test error, addressing comments
borisfom Aug 28, 2024
9481d9f
optional propagation of dynamo arg fixed, onnx_graphsurgeon package a…
borisfom Aug 28, 2024
6a11581
add vista test cases
yiheng-wang-nv Aug 28, 2024
6e8bd6b
Merge branch 'dev' into trt-wrappers
yiheng-wang-nv Aug 28, 2024
3d221cb
Merge branch 'dev' into trt-wrappers
KumoLiu Aug 28, 2024
792a721
Code review input addressed
borisfom Aug 29, 2024
73ac717
Fixed torch-tensorrt path of trt_compile, added test
borisfom Aug 29, 2024
ea879f2
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 29, 2024
1e7e76d
Fixing tests
borisfom Aug 29, 2024
47e676e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2024
1cb49c3
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 29, 2024
dd4d2d6
Merge branch 'dev' into trt-wrappers
binliunls Aug 29, 2024
6b47a8b
Merge branch 'dev' into trt-wrappers
KumoLiu Aug 31, 2024
80d3928
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Sep 3, 2024
47823d1
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Sep 3, 2024
c126e67
Fixing TRT 8.x compatibility
borisfom Sep 3, 2024
5645157
Improved diagnostic, skip trt test if < 10.3
borisfom Sep 3, 2024
9e98f66
Merge branch 'dev' into trt_compiler_fixes
KumoLiu Sep 4, 2024
72a4c3d
trt_compile post-fixes
borisfom Oct 3, 2024
b5f8ff2
Merge remote-tracking branch 'origin/dev' into trt_compiler_fixes
borisfom Oct 3, 2024
bf61b48
exporting controlnet
borisfom Oct 9, 2024
3d16f86
Merge branch 'trt_compiler_fixes' of github.com:borisfom/MONAI into t…
borisfom Oct 9, 2024
6297b45
Working controlnet TRT
borisfom Oct 10, 2024
f00fea4
Reformat
borisfom Oct 10, 2024
57fbcf0
Working TRT for MAISI
borisfom Oct 12, 2024
a004bc5
Working dynamic batch with sequences
borisfom Oct 16, 2024
cee7299
Merge remote-tracking branch 'origin/dev' into maisi-trt
borisfom Oct 16, 2024
adf9bc9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 16, 2024
4002d9d
Merge fixed and style
borisfom Oct 16, 2024
d8407c9
Merge branch 'maisi-trt' of github.com:borisfom/MONAI into maisi-trt
borisfom Oct 16, 2024
a37fd53
Merge remote-tracking branch 'origin/dev' into maisi-trt
borisfom Oct 19, 2024
43ea6a0
Added output_lists option
borisfom Oct 21, 2024
c1791f6
Bugfix for multiple initialization
borisfom Oct 30, 2024
ce73be5
Merge remote-tracking branch 'origin/dev' into maisi-trt
borisfom Oct 30, 2024
14912d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
30b8bcf
Adding Torch patch
borisfom Oct 31, 2024
7adf804
Merge branch 'maisi-trt' of github.com:borisfom/MONAI into maisi-trt
borisfom Oct 31, 2024
8baaa74
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
214def9
Fixing torch_trt compile and test case
borisfom Nov 1, 2024
2452c97
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
d07a57f
Added rename table for TRT engine, test for output lists
borisfom Nov 2, 2024
c6e11bb
Merge branch 'maisi-trt' of github.com:borisfom/MONAI into maisi-trt
borisfom Nov 2, 2024
6eeed4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2024
2 changes: 1 addition & 1 deletion Dockerfile
@@ -11,7 +11,7 @@

 # To build with a different base image
 # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag.
-ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.08-py3
+ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.09-py3
Contributor:
May need more tests for this base image update.

Contributor (Author):
Well, it does not make a real difference (the patch I mentioned in the description is needed for 24.09 anyway), so I may revert this one for now, too. 24.10 (and 2.5.0) won't require the exporter patch.

Contributor:
Does this mean we'll need to update to version 24.10 once it's released, since 24.09 still doesn't meet the requirements and MAISI still lacks TRT support?
I tried updating the base image and triggering more tests in PR #8164, which showed the error below:
#8164 (comment)

Contributor (Author):
Yes, I believe it's better to skip 24.09 as it still requires a patch.

FROM ${PYTORCH_IMAGE}

LABEL maintainer="[email protected]"
2 changes: 1 addition & 1 deletion monai/apps/generation/maisi/networks/controlnet_maisi.py
@@ -119,7 +119,7 @@ def forward(
         down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples]
         mid_block_res_sample *= conditioning_scale

-        return down_block_res_samples, mid_block_res_sample
+        return [*down_block_res_samples, mid_block_res_sample]
KumoLiu marked this conversation as resolved.

def _prepare_time_and_class_embedding(self, x, timesteps, class_labels):
# 1. time
23 changes: 3 additions & 20 deletions monai/networks/blocks/spatialattention.py
@@ -17,9 +17,6 @@
 import torch.nn as nn

 from monai.networks.blocks import SABlock
-from monai.utils import optional_import
-
-Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")


 class SpatialAttentionBlock(nn.Module):
@@ -74,24 +71,10 @@ def __init__(

     def forward(self, x: torch.Tensor):
         residual = x
-
-        if self.spatial_dims == 1:
-            h = x.shape[2]
-            rearrange_input = Rearrange("b c h -> b h c")
-            rearrange_output = Rearrange("b h c -> b c h", h=h)
-        if self.spatial_dims == 2:
-            h, w = x.shape[2], x.shape[3]
-            rearrange_input = Rearrange("b c h w -> b (h w) c")
-            rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w)
-        else:
-            h, w, d = x.shape[2], x.shape[3], x.shape[4]
-            rearrange_input = Rearrange("b c h w d -> b (h w d) c")
-            rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d)
-
+        shape = x.shape
         x = self.norm(x)
-        x = rearrange_input(x)  # B x (x_dim * y_dim [ * z_dim]) x C
-
+        x = x.reshape(*shape[:2], -1).transpose(1, 2)  # "b c h w d -> b (h w d) c"
         x = self.attn(x)
-        x = rearrange_output(x)  # B x x C x x_dim * y_dim * [z_dim]
+        x = x.transpose(1, 2).reshape(shape)  # "b (h w d) c -> b c h w d"
         x = x + residual
         return x
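For reference, a minimal check that the new reshape/transpose pair matches the removed Rearrange-based path for the 3D case (a sketch, assuming einops is installed):

    import torch
    from einops.layers.torch import Rearrange

    b, c, h, w, d = 2, 8, 4, 4, 4
    x = torch.randn(b, c, h, w, d)

    # Removed path: einops layers, which needed an optional import
    to_tokens = Rearrange("b c h w d -> b (h w d) c")
    to_image = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d)

    # New path: plain reshape/transpose, which traces cleanly for ONNX/TRT export
    shape = x.shape
    tokens = x.reshape(*shape[:2], -1).transpose(1, 2)
    assert torch.equal(tokens, to_tokens(x))
    assert torch.equal(tokens.transpose(1, 2).reshape(shape), x)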
9 changes: 3 additions & 6 deletions monai/networks/nets/vista3d.py
@@ -639,12 +639,9 @@ def forward(self, src: torch.Tensor, class_vector: torch.Tensor):
         if self.use_mlp:
             class_embedding = self.mlp(class_embedding)
         # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension.
-        masks = []
-        for i in range(b):
-            mask = class_embedding @ src[[i]].view(1, c, h * w * d)
-            masks.append(mask.view(-1, 1, h, w, d))
-
-        return torch.cat(masks, 1), class_embedding
+        masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d)
+        masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1)
+        return masks_embedding, class_embedding


class TwoWayTransformer(nn.Module):
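A quick equivalence check for this rewrite (a sketch; the shapes are assumed from the surrounding code, with class_embedding of shape (n, 1, c) and src of shape (b, c, h, w, d)):

    import torch

    n, b, c, h, w, d = 3, 2, 8, 4, 4, 4
    class_embedding = torch.randn(n, 1, c)
    src = torch.randn(b, c, h, w, d)

    # Old path: one matmul per batch element, concatenated afterwards
    masks = []
    for i in range(b):
        mask = class_embedding @ src[[i]].view(1, c, h * w * d)
        masks.append(mask.view(-1, 1, h, w, d))
    old = torch.cat(masks, 1)

    # New path: a single batched matmul, which exports to ONNX/TRT without a Python loop
    new = (class_embedding.squeeze() @ src.view(b, c, h * w * d)).view(b, -1, h, w, d).transpose(0, 1)

    assert torch.allclose(old, new, atol=1e-5)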
143 changes: 102 additions & 41 deletions monai/networks/trt_compiler.py
@@ -134,6 +134,9 @@ def __init__(self, plan_path, logger=None):
self.output_names.append(binding)
dtype = dtype_dict[self.engine.get_tensor_dtype(binding)]
self.dtypes.append(dtype)
self.logger.info(
f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}"
)

def allocate_buffers(self, device):
"""
@@ -163,7 +166,8 @@ def set_inputs(self, feed_dict, stream):
last_profile = self.cur_profile

def try_set_inputs():
for binding, t in feed_dict.items():
for binding in self.input_names:
t = feed_dict[binding]
if t is not None:
t = t.contiguous()
shape = t.shape
@@ -180,7 +184,8 @@
raise
self.cur_profile = next_profile
ctx.set_optimization_profile_async(self.cur_profile, stream)

except Exception:
raise
Contributor:
Could you please add more info to explain this exception?
Thanks,
Bin

Contributor (Author):
This would be the exception raised when trying to set input shapes for which the engine was not built. Previously I had logic there that would rotate the TRT optimization profile index on such an exception; we do not use multiple profiles with MONAI, so I should probably simplify the code.
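For context, a simplified sketch of the rotation logic being discussed (illustrative only; names follow the surrounding code):

    def set_inputs(self, feed_dict, stream):
        last_profile = self.cur_profile
        while True:
            try:
                try_set_inputs()  # raises if a shape falls outside the active profile
                break
            except Exception:
                # Rotate to the next optimization profile and retry;
                # give up once every profile has been tried.
                next_profile = (self.cur_profile + 1) % self.engine.num_optimization_profiles
                if next_profile == last_profile:
                    raise
                self.cur_profile = next_profile
                ctx.set_optimization_profile_async(self.cur_profile, stream)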

left = ctx.infer_shapes()
assert len(left) == 0

@@ -217,6 +222,40 @@ def infer(self, stream, use_cuda_graph=False):
return self.tensors


def remove_non_tensors(input_example, remove_constants=True):
#
# TODO : see if we can instantiate wrappers to handle non-default non-tensors
#
non_tensors = {}
for k, v in input_example.items():
if v is None:
non_tensors[k] = v
elif not torch.is_tensor(v):
if remove_constants:
non_tensors[k] = v
else:
input_example[k] = torch.tensor(v)

for key in non_tensors.keys():
# print(f"Removing non-tensor input: {key})")
input_example.pop(key)
return non_tensors


def unroll_input(input_names, input_example):
# Simulate list/tuple unrolling during ONNX export
unrolled_input = {}
for name in input_names:
val = input_example[name]
Contributor:
I think input_example.get(name, None) is a better choice here, in case there are any illegal keys.
Thanks,
Bin

Contributor (Author):
Yes, we can look more into making this robust for the odd cases (a possible variant is sketched after the function below).

if val is not None:
if isinstance(val, list | tuple):
for i in range(len(val)):
unrolled_input[f"{name}_{i}"] = val[i]
else:
unrolled_input[name] = val
return unrolled_input
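A possible hardening along the lines discussed above (a sketch, not part of this PR):

    def unroll_input_tolerant(input_names, input_example):
        # Same unrolling as unroll_input, but tolerant of missing keys
        unrolled = {}
        for name in input_names:
            val = input_example.get(name)  # None for absent or explicitly-None inputs
            if isinstance(val, (list, tuple)):
                for i, v in enumerate(val):
                    unrolled[f"{name}_{i}"] = v
            elif val is not None:
                unrolled[name] = val
        return unrolled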


class TrtCompiler:
"""
This class implements:
@@ -240,6 +279,7 @@
use_cuda_graph=False,
timestamp=None,
fallback=False,
forward_override=None,
logger=None,
):
"""
@@ -289,11 +329,18 @@ def __init__(
self.disabled = False

self.logger = logger or get_logger("trt_compile")

self.argspec = inspect.getfullargspec(model.forward)
# Normally we read input_names from forward() but can be overridden
if input_names is None:
argspec = inspect.getfullargspec(model.forward)
input_names = argspec.args[1:]
input_names = self.argspec.args[1:]
self.defaults = {}
if self.argspec.defaults is not None:
for i in range(len(self.argspec.defaults)):
d = self.argspec.defaults[-i - 1]
if d is not None:
d = torch.tensor(d).cuda()
self.defaults[self.argspec.args[-i - 1]] = d

self.input_names = input_names
self.old_forward = model.forward

@@ -314,9 +361,9 @@ def _load_engine(self):
"""
try:
self.engine = TRTEngine(self.plan_path, self.logger)
self.input_names = self.engine.input_names
self.logger.info(f"Engine loaded, inputs:{self.engine.input_names}")
except Exception as e:
self.logger.debug(f"Exception while loading the engine:\n{e}")
self.logger.info(f"Exception while loading the engine:\n{e}")

def forward(self, model, argv, kwargs):
"""
Expand All @@ -329,18 +376,22 @@ def forward(self, model, argv, kwargs):
Returns: Passing through wrapped module's forward() return value(s)

"""
args = self.defaults
args.update(kwargs)
if len(argv) > 0:
args.update(self._inputs_to_dict(argv))

if self.engine is None and not self.disabled:
# Restore original forward for export
new_forward = model.forward
model.forward = self.old_forward
try:
self._load_engine()
if self.engine is None:
build_args = kwargs.copy()
if len(argv) > 0:
build_args.update(self._inputs_to_dict(argv))
self._build_and_save(model, build_args)
# This will reassign input_names from the engine
build_args = args.copy()
with torch.no_grad():
Contributor:
May I ask the reason for adding the torch.no_grad() here? Was it causing some issues in the previous version?
Thanks,
Bin

Contributor (Author):
Yes, there were some issues with export. As TRT is inference-only, it makes sense to do the whole export under torch.no_grad(); this is the recommended way.

self._build_and_save(model, build_args)
# This will reassign input_names from the engine
self._load_engine()
assert self.engine is not None
except Exception as e:
@@ -355,19 +406,16 @@
del param
# Call empty_cache to release GPU memory
torch.cuda.empty_cache()
# restore TRT hook
model.forward = new_forward
# Run the engine
try:
if len(argv) > 0:
kwargs.update(self._inputs_to_dict(argv))
argv = ()

if self.engine is not None:
# forward_trt is not thread safe as we do not use per-thread execution contexts
with lock_sm:
device = torch.cuda.current_device()
stream = torch.cuda.Stream(device=device)
self.engine.set_inputs(kwargs, stream.cuda_stream)
self.engine.set_inputs(unroll_input(self.input_names, args), stream.cuda_stream)
self.engine.allocate_buffers(device=device)
# Need this to synchronize with Torch stream
stream.wait_stream(torch.cuda.current_stream())
@@ -379,7 +427,7 @@
ret = ret[0]
return ret
except Exception as e:
if model is not None:
if self.fallback:
self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...")
else:
raise e
@@ -391,16 +439,11 @@ def _onnx_to_trt(self, onnx_path):
"""

profiles = []
if self.profiles:
for input_profile in self.profiles:
if isinstance(input_profile, Profile):
profiles.append(input_profile)
else:
p = Profile()
for name, dims in input_profile.items():
assert len(dims) == 3
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
profiles.append(p)
for profile in self.profiles:
p = Profile()
for id, val in profile.items():
p.add(id, min=val[0], opt=val[1], max=val[2])
profiles.append(p)

build_args = self.build_args.copy()
build_args["tf32"] = self.precision != "fp32"
@@ -426,6 +469,9 @@ def _build_and_save(self, model, input_example):

export_args = self.export_args

# remove_non_tensors(input_example)

engine_bytes = None
add_casts_around_norms(model)

if self.method == "torch_trt":
@@ -459,33 +505,46 @@ def get_torch_trt_input(input_shape, dynamic_batchsize):
raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!")
if len(dbs) != 3:
raise ValueError("dynamic_batchsize has to have len ==3 ")
profiles = {}
profile = {}
for id, val in input_example.items():
sh = val.shape[1:]
profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]
self.profiles = [profiles]

if len(self.profiles) > 0:
export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)})
def add_profile(id, val):
sh = val.shape
if len(sh) > 0:
sh = sh[1:]
profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]

if isinstance(val, list | tuple):
for i in range(len(val)):
add_profile(f"{id}_{i}", val[i])
elif isinstance(val, torch.Tensor):
add_profile(id, val)
self.profiles = [profile]

self.dynamic_axes = get_dynamic_axes(self.profiles)

if len(self.dynamic_axes) > 0:
export_args.update({"dynamic_axes": self.dynamic_axes})

# Use temporary directory for easy cleanup in case of external weights
with tempfile.TemporaryDirectory() as tmpdir:
onnx_path = Path(tmpdir) / "model.onnx"
unrolled_input = unroll_input(self.input_names, input_example)
onnx_path = str(Path(tmpdir) / "model.onnx")
self.logger.info(
f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}"
f"Exporting to {onnx_path}:\nunrolled_inputs={list(unrolled_input.keys())}\noutput_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}"
)
convert_to_onnx(
model,
input_example,
filename=str(onnx_path),
input_names=self.input_names,
filename=onnx_path,
input_names=list(unrolled_input.keys()),
output_names=self.output_names,
**export_args,
)
self.logger.info("Export to ONNX successful.")
engine_bytes = self._onnx_to_trt(str(onnx_path))

open(self.plan_path, "wb").write(engine_bytes)
engine_bytes = self._onnx_to_trt(onnx_path)
if engine_bytes:
open(self.plan_path, "wb").write(engine_bytes)


def trt_forward(self, *argv, **kwargs):
@@ -540,6 +599,8 @@ def trt_compile(
args["timestamp"] = timestamp

def wrap(model, path):
if not hasattr(model, "_trt_compiler"):
model.orig_forward = model.forward
wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args)
model._trt_compiler = wrapper
model.forward = MethodType(trt_forward, model)
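A usage sketch for the resulting entry point (the network, cache path, and argument values are illustrative, not taken from this PR):

    import torch
    from monai.networks.nets import UNet
    from monai.networks.trt_compiler import trt_compile

    model = UNet(
        spatial_dims=3, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2)
    ).cuda().eval()

    # Wraps model.forward; the engine is built lazily on first call and cached
    # at "/models/unet.plan". dynamic_batchsize is [min, opt, max].
    trt_compile(
        model, "/models/unet",
        args={"precision": "fp16", "dynamic_batchsize": [1, 2, 4], "fallback": True},
    )

    with torch.no_grad():
        y = model(torch.randn(2, 1, 64, 64, 64, device="cuda"))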
9 changes: 7 additions & 2 deletions monai/networks/utils.py
@@ -631,7 +631,6 @@ def convert_to_onnx(
use_trace: bool = True,
do_constant_folding: bool = True,
constant_size_threshold: int = 16 * 1024 * 1024 * 1024,
dynamo=False,
**kwargs,
):
"""
@@ -672,6 +671,9 @@
# let torch.onnx.export to trace the model.
mode_to_export = model
torch_versioned_kwargs = kwargs
if "dynamo" in kwargs and kwargs["dynamo"] and verify:
torch_versioned_kwargs["verify"] = verify
verify = False
else:
if not pytorch_after(1, 10):
if "example_outputs" not in kwargs:
@@ -693,7 +695,7 @@
f = io.BytesIO()
Contributor:
Could you please also modify this part based on the latest API from torch.onnx.export? Thanks!
#8149 (comment)

else:
f = filename

print(f"torch_versioned_kwargs={torch_versioned_kwargs}")
torch.onnx.export(
mode_to_export,
onnx_inputs,
@@ -716,6 +718,9 @@
fold_constants(onnx_model, size_threshold=constant_size_threshold)

if verify:
if isinstance(inputs, dict):
inputs = list(inputs.values())

if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

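A minimal usage sketch for the verify path of convert_to_onnx (the toy network and tolerances are placeholders):

    import torch
    from monai.networks.utils import convert_to_onnx

    net = torch.nn.Sequential(torch.nn.Conv2d(1, 4, 3), torch.nn.ReLU()).eval()
    inputs = [torch.randn(1, 1, 32, 32)]

    # Exports via tracing; with verify=True the exported model's outputs are
    # compared against PyTorch outputs within rtol/atol.
    onnx_model = convert_to_onnx(
        net, inputs, input_names=["x"], output_names=["y"],
        filename="net.onnx", verify=True, rtol=1e-3, atol=1e-4,
    )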
2 changes: 2 additions & 0 deletions tests/test_trt_compile.py
@@ -24,6 +24,7 @@
from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows

trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version)
torch_tensorrt, torch_trt_imported = optional_import("torch_tensorrt")
polygraphy, polygraphy_imported = optional_import("polygraphy")
build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")

@@ -46,6 +47,7 @@ def tearDown(self):
if current_device != self.gpu_device:
torch.cuda.set_device(self.gpu_device)

@unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required")
def test_handler(self):
from ignite.engine import Engine
