
Feat: improve Brevitas compatibility with torch.compile #785

Closed
wants to merge 16 commits

Conversation

nickfraser (Collaborator)

Crashes occasionally occur when using BREVITAS_JIT=1 together with torch.compile. Ideally, the two should interoperate without crashing until TorchScript is deprecated by upstream PyTorch.

Giuseppe5 marked this pull request as ready for review on August 15, 2024
Giuseppe5 (Collaborator) commented Aug 18, 2024

Current findings from applying compile to quantized models (i.e., models ready for inference):

  • The current structure of QuantTensor is not amenable to compile.
  • Fullgraph compilation seems achievable for layerwise-only quantization with some modifications to the codebase (e.g., removing global variables, adding a zero_zero_point flag to avoid data-dependent control flow, marking QuantTensor with allow_in_graph), but the result seems brittle.
  • Inheriting from another NamedTuple (IntQuantTensorBase or FloatQuantTensorBase) breaks the is_namedtuple check within PyTorch, which could be patched (see the first sketch after this list).
  • Defining a __new__ method (as we do for type checking) is also not supported: at compile time the class instance appears to be passed multiple times within the call to __new__ (see the second sketch after this list).
  • The above problems are currently being investigated (hopefully) here: Cannot override __add__ in NamedTuple with __new__ + torch.compile pytorch/pytorch#133762
  • Even if that were solved, compile does not seem to like torch.bool dtypes for constant values (e.g., zero_zero_point, training, or signed) and prefers plain bool. A torch.bool dtype causes the NamedTuple to decay to a plain tuple, e.g. when adding two QuantTensors.
  • Calling .item() to extract a bool from a torch.bool tensor is partially supported with certain compile flags (see the third sketch after this list).
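
For reference, a minimal sketch (illustrative class names, not actual Brevitas code) of why the inheritance pattern defeats namedtuple detection: PyTorch's structural check (e.g., in torch.utils._pytree) expects a namedtuple class to derive directly from tuple, which a subclass of a NamedTuple base no longer does.

```python
from typing import NamedTuple

import torch


class IntQuantTensorBase(NamedTuple):
    value: torch.Tensor
    scale: torch.Tensor


# Mirrors the pattern described above: a class deriving from a NamedTuple
# base rather than directly from tuple.
class IntQuantTensor(IntQuantTensorBase):
    pass


def is_namedtuple_like(obj) -> bool:
    # Simplified form of the structural check used by torch.utils._pytree:
    # a namedtuple must inherit directly from tuple and define _fields.
    typ = type(obj)
    return typ.__bases__ == (tuple,) and hasattr(typ, "_fields")


print(is_namedtuple_like(IntQuantTensorBase(torch.ones(1), torch.ones(1))))  # True
print(is_namedtuple_like(IntQuantTensor(torch.ones(1), torch.ones(1))))  # False
```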
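
A sketch of the __new__ failure mode, mirroring the pattern reported in the linked issue (untested as a compile repro; it runs fine eagerly):

```python
from typing import NamedTuple

import torch


class QTBase(NamedTuple):
    value: torch.Tensor


class QT(QTBase):
    # Brevitas-style: __new__ defined on a NamedTuple subclass to type-check
    # the inputs. This is the pattern reported as breaking under compile.
    def __new__(cls, value):
        assert isinstance(value, torch.Tensor)
        return super().__new__(cls, value)

    def __add__(self, other):
        return QT(self.value + other.value)


@torch.compile(fullgraph=True)
def add_qt(a, b):
    return a + b


eager = QT(torch.ones(2)) + QT(torch.ones(2))  # works eagerly

# Per pytorch/pytorch#133762, the equivalent compiled call crashes:
# add_qt(QT(torch.ones(2)), QT(torch.ones(2)))
```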
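
On the .item() point, the relevant flag appears to be torch._dynamo.config.capture_scalar_outputs; a minimal sketch, assuming a 0-dim tensor carrying per-tensor metadata:

```python
import torch

# Without this flag, Tensor.item() triggers a graph break (an error under
# fullgraph=True); with it, the scalar becomes an unbacked symbol in the graph.
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile(fullgraph=True)
def rescale(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Using the extracted scalar in arithmetic compiles; branching on a
    # torch.bool .item() can still raise data-dependent guard errors, which
    # is why this is only partially supported.
    return x * scale.item()


print(rescale(torch.ones(2), torch.tensor(0.5)))
```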

Given the above, my suggestions for this PR:

  • Most likely, this should land post-release.
  • It might be worth getting rid of the training-vs-inference behavior for QuantTensor, to avoid "data-dependent" checks.
  • It might be worth switching to supporting compile while simultaneously dropping support for PyTorch versions below 2.0, to avoid lots of import checks.
  • It might be worth considering a switch to a Tensor subclass, which might be more amenable to compile (see the sketch after this list). This would also mean dropping support for any PyTorch version below 2.0.
  • Alternatively, we could support compile only in export_mode, where we never instantiate QuantTensor at the proxy level and deal only with dequantized values. Doing this in export mode ensures we don't need to propagate QuantTensors, since everything is cached (e.g., metadata for bias quantization).
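
On the Tensor-subclass option, a minimal sketch (hypothetical, not a proposed implementation) of the __torch_function__-based direction; as noted above, this would require PyTorch >= 2.0:

```python
import torch


class QuantTensorSketch(torch.Tensor):
    """Hypothetical sketch: quantization metadata carried on a Tensor subclass."""

    @staticmethod
    def __new__(cls, value: torch.Tensor, scale: torch.Tensor):
        # Wrap the (dequantized) value; metadata rides along as an attribute.
        out = value.as_subclass(cls)
        out.scale = scale
        return out

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Fall back to plain Tensor behaviour; a real implementation would
        # propagate scale/zero-point through the ops it cares about.
        return super().__torch_function__(func, types, args, kwargs or {})


qt = QuantTensorSketch(torch.randn(4), torch.tensor(0.05))
out = torch.relu(qt)  # dispatches through __torch_function__
print(type(out), qt.scale)
```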

nickfraser closed this on Sep 12, 2024
Labels: next release (PRs which should be merged for the next release)