[Unity] Split DecomposeOpsForTraining into two steps #15954
Conversation
Prior to this commit, the `DecomposeOpsForTraining` transform directly replaced `relax.nn.batch_norm` with more primitive relax operations. This required the decomposed form of `relax.nn.batch_norm` to be duplicated in `DecomposeOpsForInference`. This commit refactors the pass to occur in two steps: first apply training-specific mutations, then decompose. Having a distinct `DecomposeOps` pass also provides a single location for operator decomposition, which may be migrated into the operator definitions in the future, similar to `FLegalize`.
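For readers skimming the diff, the shape of the refactor is roughly the following. This is a minimal C++ sketch, assuming hypothetical sub-pass names `MutateOpsForTraining` and `DecomposeOps`; the actual internal names are not shown in this thread.

```cpp
#include <tvm/ir/transform.h>

using tvm::transform::Pass;
using tvm::transform::Sequential;

// Hypothetical sub-passes standing in for the two refactored steps.
Pass MutateOpsForTraining();  // step 1: training-specific rewrites (e.g. batch_norm)
Pass DecomposeOps();          // step 2: shared decomposition into primitive relax ops

// The user-facing pass can then be expressed as a composition of the two
// steps, letting DecomposeOpsForInference reuse the same DecomposeOps step.
Pass DecomposeOpsForTraining() {
  return Sequential({MutateOpsForTraining(), DecomposeOps()},
                    "DecomposeOpsForTraining");
}
```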
Force-pushed from a5659ab to 46932df.
Rebased onto …
Sorry for the late review, Eric! Overall, LGTM. One question.
src/ir/transform.cc (Outdated)

```diff
@@ -531,6 +531,23 @@ Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassCont
   return ModulePass(pass_func, pass_info);
 }
 
+Pass ApplyPassToFunction(Pass pass, String func_name, bool error_if_function_missing) {
```
This function seems useful. Can we take a `List[String]` instead of a single function name, so that we can generalize a little further?
I think that would work, but only in the case where `error_if_function_missing = false`. If there's a `List[String]`, then it wouldn't be clear to a reader whether an error is raised when all of the listed functions are absent, or when any of the listed functions are absent.
I've been picturing this utility as being useful for hand-crafted optimization pipelines, where some optimizations should only be applied to specific functions. In those cases, raising an error when an expected function is missing would allow for earlier detection of an invalid optimization pipeline.
As a third option, what if the `func_name` parameter is replaced with a `func_name_regex` parameter? That way, it could handle both the case of a single function and the case of a list of functions, and could clearly state when an error occurs by renaming the parameter to `error_if_no_function_matches_regex`.
And the more I think on it, the more I like the regex option. Implemented, and going through CI now.
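For reference, here is a sketch of how the regex-based utility might look at a call site. The exact declared signature is an assumption based on the parameter names discussed above, and the pass pipeline is illustrative.

```cpp
#include <tvm/ir/transform.h>

using tvm::transform::ApplyPassToFunction;
using tvm::transform::Pass;

// Restrict an optimization to functions whose names match the regex.
// With error_if_no_function_matches_regex = true, a typo in a hand-crafted
// pipeline fails early instead of silently becoming a no-op.
Pass OptimizeSelectedFunctions(Pass opt) {
  return ApplyPassToFunction(opt, /*func_name_regex=*/"main|fused_.*",
                             /*error_if_no_function_matches_regex=*/true);
}
```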
No problem, as I've been slow on responding as well.
Force-pushed from c95d45f to 45eeb8c.
This LGTM, thanks @Lunderberg. @sunggg please let @Lunderberg know if you have any follow ups or if we can merge. Thanks!
Regex sounds cool! Thanks @Lunderberg for the update and thanks @csullivan for the reminder!
Quick note: this PR depends on `std::regex`, and we know that there can be a symbol issue with pytorch (see #16249). We may need to revert the usage or change to prefix matching for now so we can unblock the unity-to-main transition. For future usage of regex, let us consider simple patterns and prefixes for now.
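The prefix-matching fallback mentioned above is easy to implement without `<regex>`; a minimal sketch:

```cpp
#include <string>

// Returns true if `name` starts with `prefix`; a simple stand-in for
// regex matching that avoids the <regex> header entirely.
inline bool MatchesPrefix(const std::string& name, const std::string& prefix) {
  return name.compare(0, prefix.size(), prefix) == 0;
}
```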
This is a reapplication of apache#15954, after resolving the breakages that required reverting in apache#16442. The regex matching is now implemented without the `#include <regex>` from the C++ stdlib, to avoid ABI incompatibility with pytorch.
* [Support] Add PackedFunc "tvm.support.regex_match"

This function should be used instead of `std::regex` within C++ call sites, to avoid ABI incompatibilities with pytorch. Currently, the pytorch wheels available through pip install use the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to use the pre-C++11 ABI, this would cause breakages with dynamically-linked LLVM environments.

Use of the `<regex>` header in TVM should be avoided, as its implementation is not supported by gcc's dual ABI. This ABI incompatibility results in runtime errors either when `std::regex` is called from TVM, or when `std::regex` is called from pytorch, depending on which library was loaded first. This restriction can be removed when a version of pytorch compiled using `-DUSE_CXX11_ABI=1` is available from PyPI.

[0] pytorch/pytorch#51039

* [Redo][Unity] Split DecomposeOpsForTraining into two steps

This is a reapplication of #15954, after resolving the breakages that required reverting in #16442.
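A plausible way for a C++ call site to use the registered PackedFunc instead of including `<regex>` directly; the argument order `(pattern, match_against)` is an assumption, as only the PackedFunc name is given above.

```cpp
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>

#include <string>

// Looks up the globally registered regex helper and forwards to it, keeping
// all <regex> usage behind a single PackedFunc boundary.
bool RegexMatch(const std::string& pattern, const std::string& match_against) {
  const tvm::runtime::PackedFunc* func =
      tvm::runtime::Registry::Get("tvm.support.regex_match");
  ICHECK(func != nullptr) << "Could not find tvm.support.regex_match";
  return (*func)(pattern, match_against);
}
```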