In the current implementation of the `MXLinear` layer, there is only a single MX quantization step of the output gradient (in `NoopFwToMXBw`).
However, following the MX microscaling paper, there should be four quantizations in the backward pass: two of the output gradient (along two different axes), one of the activations, and one of the weights (along axes different from the ones used in the forward pass).
Why does it matter: even though not officially confirmed by hardware vendors, it is clear that MX matmuls can only be fully accelerated if the quantization axis corresponds to the reduction axis for both operands. Hence, running the MX backward pass on next-gen hardware will require the four quantization steps listed above. Changing the quantization axis changes the quantization error, meaning that the current implementation potentially does not give a full picture of what MX training will look like on real hardware.
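To make the axis requirement concrete, here is a minimal shape walk-through of the three GEMMs of a plain linear layer (illustrative PyTorch only; the tensor names are mine, not torchao code):

```python
import torch

# Illustrative only: the three GEMMs of y = x @ w.t() and the reduction axis of
# each operand, i.e. the axis the MX block scaling has to follow for the matmul
# to be fully accelerated.
x = torch.randn(16, 32)       # (batch, in_features)
w = torch.randn(64, 32)       # (out_features, in_features)
grad_y = torch.randn(16, 64)  # (batch, out_features)

y = x @ w.t()        # reduce over in_features  -> quantize x along dim -1, w along dim -1
grad_x = grad_y @ w  # reduce over out_features -> quantize grad_y along dim -1, w along dim 0
grad_w = grad_y.t() @ x  # reduce over the batch dim -> quantize grad_y along dim 0, x along dim 0
# Backward pass total: grad_y quantized twice (two different axes), x and w once
# each, along axes different from the ones used in the forward pass.
```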
Potential fix: I believe we need a full implementation of the forward and backward passes of the `blockwise_quantize_linear` function, manually handling the backward-pass quantization steps.
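As a rough sketch of what I have in mind (assuming a recent PyTorch with float8 dtypes; the `mx_quantize` helper and the class name below are hypothetical names I made up for this example, not torchao APIs, and a real implementation would reuse torchao's existing MX casting code):

```python
import torch

def mx_quantize(t: torch.Tensor, axis: int, block_size: int = 32) -> torch.Tensor:
    # Toy stand-in for a blockwise MX quantize-dequantize along `axis` (NOT a
    # torchao API): one shared power-of-two scale per block, values round-tripped
    # through float8_e4m3fn. Assumes t.shape[axis] is a multiple of block_size.
    t_moved = t.movedim(axis, -1)
    blocks = t_moved.reshape(*t_moved.shape[:-1], -1, block_size)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=2.0**-126)
    scale = torch.exp2(torch.floor(torch.log2(amax)))
    q = (blocks / scale).to(torch.float8_e4m3fn).to(blocks.dtype) * scale
    return q.reshape(t_moved.shape).movedim(-1, axis)

class BlockwiseQuantizeLinear(torch.autograd.Function):
    """Hypothetical fwd+bwd of an MX linear layer with all quantization steps explicit."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        # Forward GEMM: reduction over in_features, so both operands are
        # quantized along their last axis.
        return mx_quantize(x, axis=-1) @ mx_quantize(w, axis=-1).t()

    @staticmethod
    def backward(ctx, grad_y):
        x, w = ctx.saved_tensors
        # grad_x GEMM: reduction over out_features.
        grad_x = mx_quantize(grad_y, axis=-1) @ mx_quantize(w, axis=0)
        # grad_w GEMM: reduction over the batch/token axis.
        grad_w = mx_quantize(grad_y, axis=0).t() @ mx_quantize(x, axis=0)
        return grad_x, grad_w

# Usage: shapes chosen so every quantized axis is a multiple of the block size.
x = torch.randn(64, 128, requires_grad=True)
w = torch.randn(256, 128, requires_grad=True)
BlockwiseQuantizeLinear.apply(x, w).sum().backward()
```

The key point is only that each of the three GEMMs gets its own pair of quantizations along the matching reduction axes; the actual low-bit formats, scale computation and kernels are out of scope here.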
Hi @balancap, #932 is also related. This is a known problem and I agree with your issue summary. I also agree that we'll need a `torch.autograd.Function` which specifies how the quantization is done for all three gemms in the fwd/bwd passes.
As soon as the first hardware vendor releases the official specs, we on the torchao core team are going to sprint on updating this code to match the final hardware support. We would welcome contributions if anyone is looking to do this update earlier!