TRT support for MAISI #8153

Status: Open — wants to merge 121 commits into base: dev
Conversation

@borisfom (Contributor) commented Oct 16, 2024

Description

Added trt_compile() support for lists and tuples in the arguments of forward(), which is needed for MAISI.
Support for grouping return results is not added yet; MAISI worked with an explicit workaround that unrolls the return results.
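The list/tuple unrolling described above can be sketched in plain Python (a minimal illustration with hypothetical names and string stand-ins for tensors; not the actual trt_compile() implementation):

```python
# Sketch: flatten list/tuple forward() arguments into individually named
# inputs, the way an ONNX exporter needs them. Naming scheme is assumed.
def unroll_inputs(input_names, input_example):
    unrolled = {}
    for name in input_names:
        val = input_example[name]
        if isinstance(val, (list, tuple)):
            # a list/tuple argument becomes name_0, name_1, ... flat inputs
            for i, item in enumerate(val):
                unrolled[f"{name}_{i}"] = item
        else:
            unrolled[name] = val
    return unrolled

example = {"x": "tensor_x", "skips": ["s0", "s1", "s2"]}
print(unroll_inputs(["x", "skips"], example))
# {'x': 'tensor_x', 'skips_0': 's0', 'skips_1': 's1', 'skips_2': 's2'}
```

A matching regrouping step on the output side is what the "explicit workaround unrolling the return results" avoids for now.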

Notes

To successfully export MAISI, either the latest Torch nightly is needed, or this patch needs to be applied to a 24.09-based container:

--- /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.bak     2024-10-09 01:38:04.920316673 +0000                                                   
+++ /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py      2024-10-09 01:38:25.228053951 +0000                                                   
@@ -148,7 +148,6 @@                                                                                                                                                   
         is_causal and symbolic_helper._is_none(attn_mask)                                                                                                            
     ), "is_causal and attn_mask cannot be set at the same time"                                                                                                      
                                                                                                                                                                      
-    scale = symbolic_helper._maybe_get_const(scale, "f")                                                                                                             
     if symbolic_helper._is_none(scale):                                                                                                                              
         scale = _attention_scale(g, query)                                                                                                                           

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).

borisfom and others added 30 commits August 4, 2024 23:17
Signed-off-by: Boris Fomitchev <[email protected]>
@borisfom (Contributor, Author) commented:

Also, I did not do any results verification. If any results depend on Meta tensors operation, that part may be lost. Please check!

monai/apps/generation/maisi/networks/controlnet_maisi.py (outdated; resolved)
Dockerfile (outdated)
@@ -11,7 +11,7 @@

# To build with a different base image
# please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag.
-ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.08-py3
+ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.09-py3
Contributor:
May need more test for this base image update.

Contributor (author):

Well, it does not make a real difference (the patch I mentioned in the description is needed for 24.09 anyway), so I may revert this one for now, too. 24.10 (and 2.5.0) won't require the exporter patch.

Contributor:

Does this mean we'll need to update to version 24.10 once it's released, since 24.09 still doesn't meet the requirements and MAISI still lacks TRT support?
I tried to update the base image and trigger more tests in PR #8164, which showed the error below:
#8164 (comment)

Contributor (author):

Yes, I believe it's better to skip 24.09, as it still requires a patch.

@@ -693,7 +695,7 @@ def convert_to_onnx(
f = io.BytesIO()
Contributor:

Could you please also modify this part based on the latest API of torch.onnx.export? Thanks!
#8149 (comment)

@KumoLiu (Contributor) commented Oct 22, 2024:

Hi @binliunls, please also review the trt related parts in this PR, thanks.

@KumoLiu mentioned this pull request Oct 22, 2024 (7 tasks)
@@ -255,6 +345,7 @@ def __init__(
'torch_trt' may not work for some nets. Also AMP must be turned off for it to work.
input_names: Optional list of input names. If None, will be read from the function signature.
output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary.
output_lists: Optional list of output lists.
Contributor:

Could you please add more details about this parameter and its relation to output_names? Right now its meaning is hard to understand.

Thanks,
Bin
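For context, one plausible reading of the parameter (a hypothetical sketch only, not the PR's confirmed semantics): output_names label the flat engine outputs, while output_lists could record how many consecutive flat outputs should be regrouped into each returned list. In pure Python:

```python
# Hypothetical sketch: regroup flat engine outputs back into lists.
# 'output_lists' is assumed here to give the length of each returned
# list (0 meaning a plain, un-grouped output). Not the actual API.
def regroup_outputs(flat_outputs, output_lists):
    results, i = [], 0
    for n in output_lists:
        if n == 0:
            results.append(flat_outputs[i])
            i += 1
        else:
            results.append(flat_outputs[i : i + n])
            i += n
    return results

print(regroup_outputs(["a", "b", "c", "d"], [0, 3]))
# ['a', ['b', 'c', 'd']]
```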

@@ -233,13 +321,15 @@ def __init__(
method="onnx",
input_names=None,
output_names=None,
output_lists=None,
Contributor:

Could you please also add a simple unit test case to showcase how to use this parameter?

Thanks,
Bin

self._build_and_save(model, build_args)
# This will reassign input_names from the engine
build_args = args.copy()
with torch.no_grad():
Contributor:

May I ask the reason for adding torch.no_grad() here? Did it cause some issues in the previous version?

Thanks,
Bin

Contributor (author):

Yes, there were some issues with export. Since TRT is inference-only, it makes sense to do the whole export under torch.no_grad(); this is the recommended way.

@@ -180,7 +184,8 @@ def try_set_inputs():
raise
self.cur_profile = next_profile
ctx.set_optimization_profile_async(self.cur_profile, stream)

except Exception:
raise
Contributor:

Could you please add more info to explain this exception?
Thanks,
Bin

Contributor (author):

This exception would be raised when trying to set input shapes for which the engine was not built. Previously I had logic there that rotated the TRT optimization profile index on such an exception; we do not use multiple profiles with MONAI, so I should probably simplify the code.
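The profile-rotation fallback described above can be illustrated in plain Python (a stand-in engine object, no TensorRT; the class and helper names are hypothetical):

```python
# Illustrative sketch of rotating through optimization profiles until one
# covers the requested input shape. FakeEngine stands in for a TRT engine.
class FakeEngine:
    def __init__(self, profiles):
        # each profile is a (min_shape, max_shape) pair
        self.profiles = profiles

    def fits(self, profile_idx, shape):
        lo, hi = self.profiles[profile_idx]
        return all(a <= s <= b for a, s, b in zip(lo, shape, hi))

def pick_profile(engine, cur_profile, shape):
    """Try the current profile first, then rotate through the others."""
    n = len(engine.profiles)
    for k in range(n):
        idx = (cur_profile + k) % n
        if engine.fits(idx, shape):
            return idx
    raise RuntimeError(f"No optimization profile covers shape {shape}")

eng = FakeEngine([((1, 64), (4, 256)), ((1, 256), (4, 1024))])
print(pick_profile(eng, 0, (2, 512)))  # falls through to profile 1
```

With a single profile, as noted above, the loop degenerates to one attempt, which is why the rotation logic can be simplified away.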

# Simulate list/tuple unrolling during ONNX export
unrolled_input = {}
for name in input_names:
val = input_example[name]
Contributor:

I think input_example.get(name, None) is a better choice here, in case there are any illegal keys.

Thanks,
Bin
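The suggested defensive lookup could look like this (a minimal sketch with a hypothetical helper name, skipping missing keys rather than raising):

```python
# Sketch of the suggested change: use dict.get(name, None) so that
# illegal or missing keys are tolerated instead of raising KeyError.
def unroll_input(input_names, input_example):
    unrolled = {}
    for name in input_names:
        val = input_example.get(name, None)
        if val is None:
            continue  # skip names absent from the example
        unrolled[name] = val
    return unrolled

print(unroll_input(["x", "bogus"], {"x": 1}))
# {'x': 1}
```

One caveat with this pattern: silently skipping a misspelled input name can hide bugs, so a warning on the skip may be preferable.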

Contributor (author):

Yes, we can look more into making this robust for the odd cases.

@@ -41,6 +41,10 @@ RUN cp /tmp/requirements.txt /tmp/req.bak \
COPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md versioneer.py setup.py setup.cfg runtests.sh MANIFEST.in ./
COPY tests ./tests
COPY monai ./monai

# TODO: remove this line and torch.patch for 24.11
RUN patch -R -d /usr/local/lib/python3.10/dist-packages/torch/onnx/ < ./monai/torch.patch
Contributor:

Seems the patch is not included in 24.10, right?

Contributor (author):

Correct, the proper fix is not included in 24.10, so we have to patch.

4 participants