TRT support for MAISI #8153
base: dev
Changes from 108 commits
Dockerfile

@@ -11,7 +11,7 @@
 # To build with a different base image
 # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag.
-ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.08-py3
+ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.09-py3
 FROM ${PYTORCH_IMAGE}

 LABEL maintainer="[email protected]"
monai/networks/trt_compiler.py

@@ -134,6 +134,9 @@ def __init__(self, plan_path, logger=None):
                 self.output_names.append(binding)
                 dtype = dtype_dict[self.engine.get_tensor_dtype(binding)]
                 self.dtypes.append(dtype)
+        self.logger.info(
+            f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}"
+        )

     def allocate_buffers(self, device):
         """
@@ -163,7 +166,8 @@ def set_inputs(self, feed_dict, stream):
         last_profile = self.cur_profile

         def try_set_inputs():
-            for binding, t in feed_dict.items():
+            for binding in self.input_names:
+                t = feed_dict[binding]
                 if t is not None:
                     t = t.contiguous()
                     shape = t.shape
@@ -180,7 +184,8 @@ def try_set_inputs():
                         raise
                     self.cur_profile = next_profile
                     ctx.set_optimization_profile_async(self.cur_profile, stream)
-
+            except Exception:
+                raise

         left = ctx.infer_shapes()
         assert len(left) == 0

Review comment: Could you please add more info to explain this exception?

Author reply: This would be an exception from trying to set input shapes for which the engine was not built. Previously I had logic there that would rotate the TRT optimization profile index on such an exception; we do not use multiple profiles with MONAI, so I should probably simplify the code.
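Following up on that review exchange, a minimal sketch of a more descriptive re-raise; `try_set_inputs` and `feed_dict` are the names from the diff above, and the message text is purely illustrative:

```python
try:
    try_set_inputs()
except Exception as e:
    # Report which bindings were being set when shape assignment failed.
    raise RuntimeError(f"Setting TRT inputs failed for bindings {list(feed_dict.keys())}") from e
```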
@@ -217,6 +222,40 @@ def infer(self, stream, use_cuda_graph=False):
         return self.tensors


+def remove_non_tensors(input_example, remove_constants=True):
+    #
+    # TODO : see if we can instantiate wrappers to handle non-default non-tensors
+    #
+    non_tensors = {}
+    for k, v in input_example.items():
+        if v is None:
+            non_tensors[k] = v
+        elif not torch.is_tensor(v):
+            if remove_constants:
+                non_tensors[k] = v
+            else:
+                input_example[k] = torch.tensor(v)
+
+    for key in non_tensors.keys():
+        # print(f"Removing non-tensor input: {key})")
+        input_example.pop(key)
+    return non_tensors
+
+
+def unroll_input(input_names, input_example):
+    # Simulate list/tuple unrolling during ONNX export
+    unrolled_input = {}
+    for name in input_names:
+        val = input_example[name]
+        if val is not None:
+            if isinstance(val, list | tuple):
+                for i in range(len(val)):
+                    unrolled_input[f"{name}_{i}"] = val[i]
+            else:
+                unrolled_input[name] = val
+    return unrolled_input

Review comment: I think [...] Thanks.

Author reply: Yes, we can look more into making this robust for the odd cases.
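To make the unrolling concrete, a small usage sketch with hypothetical inputs; a list-valued argument is flattened into indexed names, mirroring what `torch.onnx.export` does while tracing:

```python
import torch

example = {"x": torch.ones(1, 3), "scales": [torch.tensor(0.5), torch.tensor(2.0)]}
unrolled = unroll_input(["x", "scales"], example)
# unrolled keys: "x", "scales_0", "scales_1"
```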
 class TrtCompiler:
     """
     This class implements:

@@ -240,6 +279,7 @@ def __init__(
         use_cuda_graph=False,
         timestamp=None,
         fallback=False,
+        forward_override=None,
         logger=None,
     ):
         """
@@ -289,11 +329,18 @@ def __init__( | |
self.disabled = False | ||
|
||
self.logger = logger or get_logger("trt_compile") | ||
|
||
self.argspec = inspect.getfullargspec(model.forward) | ||
# Normally we read input_names from forward() but can be overridden | ||
if input_names is None: | ||
argspec = inspect.getfullargspec(model.forward) | ||
input_names = argspec.args[1:] | ||
input_names = self.argspec.args[1:] | ||
self.defaults = {} | ||
if self.argspec.defaults is not None: | ||
for i in range(len(self.argspec.defaults)): | ||
d = self.argspec.defaults[-i - 1] | ||
if d is not None: | ||
d = torch.tensor(d).cuda() | ||
self.defaults[self.argspec.args[-i - 1]] = d | ||
|
||
self.input_names = input_names | ||
self.old_forward = model.forward | ||
|
||
|
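The defaults loop above pairs trailing default values with trailing argument names from `inspect.getfullargspec`. A standalone sketch of the same idea, with a hypothetical `forward` signature:

```python
import inspect

def forward(self, x, scale=1.0, bias=None):
    ...

spec = inspect.getfullargspec(forward)
# spec.args == ['self', 'x', 'scale', 'bias']; spec.defaults == (1.0, None)
defaults = {spec.args[-i - 1]: spec.defaults[-i - 1] for i in range(len(spec.defaults))}
# defaults == {'bias': None, 'scale': 1.0}
```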
@@ -314,9 +361,9 @@ def _load_engine(self):
         """
         try:
             self.engine = TRTEngine(self.plan_path, self.logger)
-            self.logger.info(f"Engine loaded, inputs:{self.engine.input_names}")
+            self.input_names = self.engine.input_names
         except Exception as e:
-            self.logger.debug(f"Exception while loading the engine:\n{e}")
+            self.logger.info(f"Exception while loading the engine:\n{e}")

     def forward(self, model, argv, kwargs):
         """
@@ -329,18 +376,22 @@ def forward(self, model, argv, kwargs):
         Returns: Passing through wrapped module's forward() return value(s)

         """
+        args = self.defaults
+        args.update(kwargs)
+        if len(argv) > 0:
+            args.update(self._inputs_to_dict(argv))
+
         if self.engine is None and not self.disabled:
             # Restore original forward for export
             new_forward = model.forward
             model.forward = self.old_forward
             try:
                 self._load_engine()
                 if self.engine is None:
-                    build_args = kwargs.copy()
-                    if len(argv) > 0:
-                        build_args.update(self._inputs_to_dict(argv))
-                    self._build_and_save(model, build_args)
+                    build_args = args.copy()
+                    with torch.no_grad():
+                        self._build_and_save(model, build_args)
                     # This will reassign input_names from the engine
                     self._load_engine()
                 assert self.engine is not None
             except Exception as e:

Review comment: May I ask the reason for adding the torch.no_grad()? Thanks.

Author reply: Yes, there were some issues with export. As TRT is inference-only, it makes sense to do the whole export with torch.no_grad(); this is the recommended way.
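A minimal, self-contained illustration of that recommendation, using a toy model rather than the PR's code: since the resulting TensorRT engine is inference-only, the ONNX export that feeds it can run entirely under `torch.no_grad()`.

```python
import torch

model = torch.nn.Linear(4, 2).eval()
example = torch.randn(1, 4)
with torch.no_grad():
    # Nothing autograd-related is traced into the exported graph.
    torch.onnx.export(model, (example,), "model.onnx", input_names=["x"], output_names=["y"])
```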
@@ -355,19 +406,16 @@ def forward(self, model, argv, kwargs):
                         del param
                     # Call empty_cache to release GPU memory
                     torch.cuda.empty_cache()
+                # restore TRT hook
                 model.forward = new_forward
         # Run the engine
         try:
-            if len(argv) > 0:
-                kwargs.update(self._inputs_to_dict(argv))
-                argv = ()
-
             if self.engine is not None:
                 # forward_trt is not thread safe as we do not use per-thread execution contexts
                 with lock_sm:
                     device = torch.cuda.current_device()
                     stream = torch.cuda.Stream(device=device)
-                    self.engine.set_inputs(kwargs, stream.cuda_stream)
+                    self.engine.set_inputs(unroll_input(self.input_names, args), stream.cuda_stream)
                     self.engine.allocate_buffers(device=device)
                     # Need this to synchronize with Torch stream
                     stream.wait_stream(torch.cuda.current_stream())
@@ -379,7 +427,7 @@ def forward(self, model, argv, kwargs):
                     ret = ret[0]
                 return ret
         except Exception as e:
-            if model is not None:
+            if self.fallback:
                 self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...")
             else:
                 raise e
@@ -391,16 +439,11 @@ def _onnx_to_trt(self, onnx_path):
         """

         profiles = []
-        if self.profiles:
-            for input_profile in self.profiles:
-                if isinstance(input_profile, Profile):
-                    profiles.append(input_profile)
-                else:
-                    p = Profile()
-                    for name, dims in input_profile.items():
-                        assert len(dims) == 3
-                        p.add(name, min=dims[0], opt=dims[1], max=dims[2])
-                    profiles.append(p)
+        for profile in self.profiles:
+            p = Profile()
+            for id, val in profile.items():
+                p.add(id, min=val[0], opt=val[1], max=val[2])
+            profiles.append(p)

         build_args = self.build_args.copy()
         build_args["tf32"] = self.precision != "fp32"
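For context, `self.profiles` is a list of dicts mapping each (unrolled) input name to `[min, opt, max]` shape triples, which the loop above converts to polygraphy `Profile` objects. A sketch with a hypothetical input name and shapes:

```python
from polygraphy.backend.trt import Profile

profiles_spec = [{"x": [[1, 4, 64, 64], [2, 4, 64, 64], [4, 4, 64, 64]]}]
profiles = []
for profile in profiles_spec:
    p = Profile()
    for name, val in profile.items():
        p.add(name, min=val[0], opt=val[1], max=val[2])
    profiles.append(p)
```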
@@ -426,6 +469,9 @@ def _build_and_save(self, model, input_example):

         export_args = self.export_args

+        # remove_non_tensors(input_example)
+
         engine_bytes = None
+        add_casts_around_norms(model)

         if self.method == "torch_trt":
@@ -459,33 +505,46 @@ def get_torch_trt_input(input_shape, dynamic_batchsize):
                 raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!")
             if len(dbs) != 3:
                 raise ValueError("dynamic_batchsize has to have len ==3 ")
-            profiles = {}
+            profile = {}
             for id, val in input_example.items():
-                sh = val.shape[1:]
-                profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]
-            self.profiles = [profiles]
-
-        if len(self.profiles) > 0:
-            export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)})
+
+                def add_profile(id, val):
+                    sh = val.shape
+                    if len(sh) > 0:
+                        sh = sh[1:]
+                        profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]
+
+                if isinstance(val, list | tuple):
+                    for i in range(len(val)):
+                        add_profile(f"{id}_{i}", val[i])
+                elif isinstance(val, torch.Tensor):
+                    add_profile(id, val)
+            self.profiles = [profile]
+
+        self.dynamic_axes = get_dynamic_axes(self.profiles)
+
+        if len(self.dynamic_axes) > 0:
+            export_args.update({"dynamic_axes": self.dynamic_axes})
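To make the `dynamic_batchsize` path concrete: with `dbs = [1, 2, 4]` and a tensor input of per-sample shape `(4, 64)` (values here are illustrative), only the batch dimension varies across the min/opt/max shapes:

```python
dbs = [1, 2, 4]
sh = (4, 64)  # input shape with the batch dimension stripped
profile_entry = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]
# profile_entry == [[1, 4, 64], [2, 4, 64], [4, 4, 64]]
```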
         # Use temporary directory for easy cleanup in case of external weights
         with tempfile.TemporaryDirectory() as tmpdir:
-            onnx_path = Path(tmpdir) / "model.onnx"
+            unrolled_input = unroll_input(self.input_names, input_example)
+            onnx_path = str(Path(tmpdir) / "model.onnx")
             self.logger.info(
-                f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}"
+                f"Exporting to {onnx_path}:\nunrolled_inputs={list(unrolled_input.keys())}\noutput_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}"
             )
             convert_to_onnx(
                 model,
                 input_example,
-                filename=str(onnx_path),
-                input_names=self.input_names,
+                filename=onnx_path,
+                input_names=list(unrolled_input.keys()),
                 output_names=self.output_names,
                 **export_args,
             )
             self.logger.info("Export to ONNX successful.")
-            engine_bytes = self._onnx_to_trt(str(onnx_path))
-
-        open(self.plan_path, "wb").write(engine_bytes)
+            engine_bytes = self._onnx_to_trt(onnx_path)
+            if engine_bytes:
+                open(self.plan_path, "wb").write(engine_bytes)
||
def trt_forward(self, *argv, **kwargs): | ||
|
@@ -540,6 +599,8 @@ def trt_compile( | |
args["timestamp"] = timestamp | ||
|
||
def wrap(model, path): | ||
if not hasattr(model, "_trt_compiler"): | ||
model.orig_forward = model.forward | ||
wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) | ||
model._trt_compiler = wrapper | ||
model.forward = MethodType(trt_forward, model) | ||
|
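For reference, a minimal usage sketch of the `trt_compile` entry point that this `wrap` helper serves; the model class and plan path here are hypothetical:

```python
import torch
from monai.networks.trt_compiler import trt_compile

model = MyUNet().cuda().eval()  # hypothetical module
# Hooks model.forward; the engine is built lazily on the first call
# and cached as "/models/my_unet.plan" for later runs.
trt_compile(model, "/models/my_unet", args={"precision": "fp16", "fallback": True})
out = model(torch.randn(1, 1, 96, 96, 96, device="cuda"))
```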
monai/networks/utils.py

@@ -631,7 +631,6 @@ def convert_to_onnx(
     use_trace: bool = True,
     do_constant_folding: bool = True,
     constant_size_threshold: int = 16 * 1024 * 1024 * 1024,
-    dynamo=False,
     **kwargs,
 ):
     """
|
@@ -672,6 +671,9 @@ def convert_to_onnx( | |
# let torch.onnx.export to trace the model. | ||
mode_to_export = model | ||
torch_versioned_kwargs = kwargs | ||
if "dynamo" in kwargs and kwargs["dynamo"] and verify: | ||
torch_versioned_kwargs["verify"] = verify | ||
verify = False | ||
else: | ||
if not pytorch_after(1, 10): | ||
if "example_outputs" not in kwargs: | ||
|
@@ -693,7 +695,7 @@ def convert_to_onnx(
         f = io.BytesIO()
     else:
         f = filename
-
+    print(f"torch_versioned_kwargs={torch_versioned_kwargs}")
     torch.onnx.export(
         mode_to_export,
         onnx_inputs,

Review comment: Could you please also modify this part based on the latest API from [...]
@@ -716,6 +718,9 @@ def convert_to_onnx(
         fold_constants(onnx_model, size_threshold=constant_size_threshold)

     if verify:
+        if isinstance(inputs, dict):
+            inputs = list(inputs.values())
+
         if device is None:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
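A hedged usage sketch of `convert_to_onnx` with the new dynamo/verify interaction shown above: when dynamo export and verification are both requested, `verify` is forwarded to `torch.onnx.export` and the function's own check is skipped. Model and shapes are hypothetical:

```python
import torch
from monai.networks.utils import convert_to_onnx

model = torch.nn.Linear(4, 2).eval()
onnx_model = convert_to_onnx(
    model,
    inputs=[torch.randn(1, 4)],
    input_names=["x"],
    output_names=["y"],
    verify=True,
    dynamo=True,  # forwarded to torch.onnx.export via **kwargs
)
```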
Review comment: May need more tests for this base image update.

Author reply: Well, it does not make a real difference (the patch I mentioned in the description is needed for 24.09 anyway), so I may revert this one for now, too. 24.10 (and 2.5.0) won't require the exporter patch.

Review comment: Does this mean we'll need to update to version 24.10 once it's released, since 24.09 still doesn't meet the requirements and MAISI still lacks TRT support? I tried to update the base image and trigger more tests in PR #8164, which showed the error below:
#8164 (comment)

Author reply: Yes, I believe it's better to skip 24.09, as it still requires a patch.