
Feat: Support for Groupwise (MX) quantization #971

Merged
Giuseppe5 merged 37 commits into Xilinx:dev on Aug 20, 2024

Conversation

@Giuseppe5 (Collaborator) commented Jun 17, 2024

This implements:

  • New GroupwiseQuantTensor for Int and Float
  • Relevant Proxy classes
  • MX Float based quantizers
  • One notebook to test instantiation and execution
  • Default MXInt quantizers

Missing:

  • Export

Features

Dynamic expansion of contracted groupwise tensors

Compared to the Int/Float QuantTensor, the main difference in their groupwise equivalents is that value, scale, and zero_point are no longer direct attributes but properties. The new underlying attributes are value_, scale_, and zero_point_.

The reason for this is shaping. When quantizing a tensor of shape [O, I], where O is the output channel dimension and I is the input channel dimension, with group size k, groupwise quantization is normally represented as follows:

  • Tensor with shape [O, I/k, k]
  • Scales with shape [O, I/k, 1]
  • Zero point with the same shape as the scales

The alternative to this representation is to have all three tensors with shape [O, I], at the cost of a massive increase in memory utilization, especially with QAT + gradients: the contracted scales have O × I/k elements instead of O × I, a k-fold saving (32× for the MX group size of 32).

The underscored attributes hold the compressed shapes, while the properties (non-underscored names) dynamically compute the expanded version of the tensor. This means:

quant_tensor.scale_.shape
# This will print [O, I/k, 1]
quant_tensor.scale.shape
# This will print [O, I]

Internally, quantization happens with groupwise shaping (i.e., contracted). For this reason, a preliminary view is applied to the tensors before everything goes into tensor_quant.
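
To make the contracted vs. expanded distinction concrete, here is a minimal, self-contained PyTorch sketch of the reshaping logic described above (illustrative only, not the Brevitas implementation):

import torch

O, I, k = 8, 64, 16  # output channels, input channels, group size

x = torch.randn(O, I)

# Preliminary view: contract [O, I] into [O, I/k, k] so that each group of
# k elements along the input dimension shares one scale.
x_grouped = x.view(O, I // k, k)

# One scale per group, kept in the compressed shape [O, I/k, 1];
# this plays the role of the underscored attribute scale_.
scale_ = x_grouped.abs().amax(dim=-1, keepdim=True)

# The non-underscored property dynamically expands back to [O, I].
scale = scale_.expand(O, I // k, k).reshape(O, I)

print(scale_.shape)  # torch.Size([8, 4, 1])
print(scale.shape)   # torch.Size([8, 64])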

Deprecation of scaling_per_output_channel

Another important change in this PR is the deprecation (i.e., still usable but no longer recommended) of the flag scaling_per_output_channel in favor of a ternary flag scaling_per_output that can be TENSOR/CHANNEL/GROUP.

Lots of work has gone into maintaining backward compatibility with the existing binary flag, so that everything still works as intended.
One thing that is still up for discussion is how to handle shared quantizers for groupwise quantization.
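
A hedged sketch of how the new ternary flag might be set on a quantizer; the enum import path and the group_size attribute name are assumptions based on this description, not verified against the source:

from brevitas.inject.enum import ScalingPerOutputType  # assumed import path
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat

class Int8WeightGroupwiseFloat(Int8WeightPerTensorFloat):
    # Replaces the deprecated binary flag scaling_per_output_channel
    scaling_per_output = ScalingPerOutputType.GROUP
    group_size = 32  # assumed attribute name for the group size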

Automatic OCP definition

When instantiating an OCP FP quantizer, the NaN/inf encodings are automatically defined through dependency injection based on the bit widths. This avoids having to manually define all the minifloat quantizers needed for MX (e.g., MX FP8, FP6, etc.).

Groupwise Default quantizers

Lots of possible quantizers could be defined by default. For integer (not much of a problem), we have groupwise with float scaling (defined in the examples, if I'm not mistaken) and MX (groupwise with a Po2 scale) defined in the Brevitas source.

For float it gets a bit more complicated with inf/NaN. The solution adopted is the following:

  • Float scale will use a non-standard floating-point representation, meaning no inf/NaN representations for any bit width. Defined in the examples.
  • Po2 scale will use the OCP standard for FP8/FP6/FP4, so that MX Float compliant quantizers are available out of the box. Defined in the Brevitas source.

I only created quantizers with bit width equal to 8 (e4m3). Overriding bit_width, mantissa_bit_width, and exponent_bit_width will produce OCP-compliant MX quantizers (thanks to the change described in Automatic OCP definition), as in the sketch below.
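
For example, a hedged sketch of deriving an OCP-compliant MX FP6 weight quantizer from the default e4m3 one; the module path and class name are assumptions based on the Brevitas source layout, not verified:

# Module path and class names below are assumptions.
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight

class MXFloat6e3m2Weight(MXFloat8e4m3Weight):
    # Overriding the bit widths yields an OCP-compliant MX FP6 quantizer,
    # thanks to the automatic NaN/inf definition described above.
    bit_width = 6
    exponent_bit_width = 3
    mantissa_bit_width = 2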

Example changes

The flags for the OCP/FNUZ standards are no longer separate. The user should now pass float_e4m3 for generic FP8 with no standard, float_ocp_e4m3 for OCP, and float_fnuz_e4m3 for FNUZ.
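
An illustrative summary of the naming scheme above; how the example scripts actually consume these strings is not shown here and should be treated as an assumption:

# Illustrative only: the mapping implied by the format strings above.
FP8_E4M3_FORMATS = {
    'float_e4m3': 'generic FP8 e4m3, no standard applied',
    'float_ocp_e4m3': 'OCP-compliant FP8 e4m3',
    'float_fnuz_e4m3': 'FNUZ-compliant FP8 e4m3',
}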

@Giuseppe5 Giuseppe5 self-assigned this Aug 14, 2024
@nickfraser nickfraser added the next release label (PRs which should be merged for the next release) Aug 14, 2024
@nickfraser (Collaborator) left a comment

Maybe worth double-checking all of the <Datatype>QuantTensor instantiations where the order of the arguments really matters. I found a few errors that are likely due to some original issue + copy/paste, so they may also exist outside your specific changes.

I think it's worth double-checking as many as you can while you're fixing the ones I found.

Review comments (all resolved):

  • src/brevitas/proxy/groupwise_int_parameter_quant.py (outdated)
  • src/brevitas/proxy/float_runtime_quant.py
  • src/brevitas/quant_tensor/groupwise_int_quant_tensor.py (outdated)
  • src/brevitas/graph/quantize.py (outdated)
@nickfraser (Collaborator) left a comment

LGTM!

@Giuseppe5 Giuseppe5 merged commit f1655b2 into Xilinx:dev Aug 20, 2024
22 checks passed