Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Split DecomposeOpsForTraining into two steps #15954

Merged
merged 2 commits into from
Jan 16, 2024

Conversation

Lunderberg
Copy link
Contributor

Prior to this commit, the DecomposeOpsForTraining transform directly replaced relax.nn.batch_norm into more primitive relax operations. This required the decomposed form of relax.nn.batch_norm to be duplicated with DecomposeOpsForInference. This commit refactors the pass to occur in two steps, first to apply training-specific mutations, and then to decompose.

Having a clear DecomposeOps pass also has a clear single location for operator decomposition, which may be migrated into the operator definition in the future, similar to FLegalize.

@Lunderberg
Copy link
Contributor Author

@sunggg This is what I was referring to by a different decomposition being applied for training and for inference in #15842. This PR extracts that difference out into a separate pass, such that both training and inference can then share the same underlying definition of batch_norm.

Prior to this commit, the `DecomposeOpsForTraining` transform directly
replaced `relax.nn.batch_norm` into more primitive relax operations.
This required the decomposed form of `relax.nn.batch_norm` to be
duplicated with `DecomposeOpsForInference`.  This commit refactors the
pass to occur in two steps, first to apply training-specific
mutations, and then to decompose.

Having a clear `DecomposeOps` pass also has a clear single location
for operator decomposition, which may be migrated into the operator
definition in the future, similar to `FLegalize`.
@Lunderberg
Copy link
Contributor Author

Rebased onto unity to re-run CI, as it has been long enough that discrepancies could have arisen since the previous run.

Copy link
Contributor

@sunggg sunggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, for the late review, Eric! Overall, LGTM. One question.

@@ -531,6 +531,23 @@ Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassCont
return ModulePass(pass_func, pass_info);
}

Pass ApplyPassToFunction(Pass pass, String func_name, bool error_if_function_missing) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function seems useful. Can we get the List[String] instead of single function name so that we can generalize a little further?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that would work, but only in the case where error_if_function_missing = false. If there's a List[String], then it wouldn't be clear to a reader whether an error is raised when all of the listed functions are absent, or when any of the listed functions are absent.

I've been picturing this utility as being useful for hand-crafted optimization pipelines, where some optimizations should only be applied to specific functions. In those cases, raising an error when an expected function is missing would allow for earlier detection of an invalid optimization pipeline.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a third option, what if the func_name parameter is replaced with a func_name_regex parameter? That way, it could handle the cases of a single function and of a list of functions, and could clearly state when an error occurs by renaming the parameter to error_if_no_function_matches_regex.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And the more I think on it, the more I like the regex option. Implemented, and going through CI now.

@Lunderberg
Copy link
Contributor Author

Sorry, for the late review

No problem, as I've been slow on responding as well.

Copy link
Contributor

@csullivan csullivan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This LGTM, thanks @Lunderberg. @sunggg please let @Lunderberg know if you have any follow ups or if we can merge. Thanks!

Copy link
Contributor

@sunggg sunggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regex sounds cool! Thanks @Lunderberg for the update and thanks @csullivan for the reminder!

@Lunderberg Lunderberg merged commit a2a1b53 into apache:unity Jan 16, 2024
17 checks passed
@Lunderberg Lunderberg deleted the unity_decompose_ops branch January 16, 2024 14:08
@tqchen
Copy link
Member

tqchen commented Jan 16, 2024

Quick note. this PR depends on std::regex and we know that there can be a symbol issue with pytorch, see #16249 we may need to revert the usage or change to prefix matching for now so we can unblock unity main transition.

For future usage of regex, let us consider simple pattern and prefix for now.

Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Jan 24, 2024
This is a reapplication of apache#15954,
after resolving the breakages that required reverting in
apache#16442.  The regex matching is now
implemented without the `#include <regex>` from the C++ stdlib, to
avoid ABI incompatibility with pytorch.

Prior to this commit, the `DecomposeOpsForTraining` transform directly
replaced `relax.nn.batch_norm` into more primitive relax operations.
This required the decomposed form of `relax.nn.batch_norm` to be
duplicated with `DecomposeOpsForInference`.  This commit refactors the
pass to occur in two steps, first to apply training-specific
mutations, and then to decompose.

Having a clear `DecomposeOps` pass also has a clear single location
for operator decomposition, which may be migrated into the operator
definition in the future, similar to `FLegalize`.
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Jan 24, 2024
This is a reapplication of apache#15954,
after resolving the breakages that required reverting in
apache#16442.  The regex matching is now
implemented without the `#include <regex>` from the C++ stdlib, to
avoid ABI incompatibility with pytorch.

Prior to this commit, the `DecomposeOpsForTraining` transform directly
replaced `relax.nn.batch_norm` into more primitive relax operations.
This required the decomposed form of `relax.nn.batch_norm` to be
duplicated with `DecomposeOpsForInference`.  This commit refactors the
pass to occur in two steps, first to apply training-specific
mutations, and then to decompose.

Having a clear `DecomposeOps` pass also has a clear single location
for operator decomposition, which may be migrated into the operator
definition in the future, similar to `FLegalize`.
Lunderberg added a commit that referenced this pull request Feb 6, 2024
* [Support] Add PackedFunc "tvm.support.regex_match"

This function should be used instead of `std::regex` within C++ call
sites, to avoid ABI incompatibilities with pytorch.

Currently, the pytorch wheels available through pip install use the
pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to user
the pre-C++11 ABI, this would cause breakages with dynamically-linked
LLVM environments.

Use of the `<regex>` header in TVM should be avoided, as its
implementation is not supported by gcc's dual ABI. This ABI
incompatibility results in runtime errors either when `std::regex` is
called from TVM, or when `std::regex` is called from pytorch,
depending on which library was loaded first.  This restriction can be
removed when a version of pytorch compiled using `-DUSE_CXX11_ABI=1`
is available from PyPI.

[0] pytorch/pytorch#51039

* [Redo][Unity] Split DecomposeOpsForTraining into two steps

This is a reapplication of #15954,
after resolving the breakages that required reverting in
#16442.  The regex matching is now
implemented without the `#include <regex>` from the C++ stdlib, to
avoid ABI incompatibility with pytorch.

Prior to this commit, the `DecomposeOpsForTraining` transform directly
replaced `relax.nn.batch_norm` into more primitive relax operations.
This required the decomposed form of `relax.nn.batch_norm` to be
duplicated with `DecomposeOpsForInference`.  This commit refactors the
pass to occur in two steps, first to apply training-specific
mutations, and then to decompose.

Having a clear `DecomposeOps` pass also has a clear single location
for operator decomposition, which may be migrated into the operator
definition in the future, similar to `FLegalize`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants