Bug: MXLinear backward pass implementation #1501

Open

balancap opened this issue Jan 6, 2025 · 1 comment

balancap commented Jan 6, 2025

In the current implementation of the MXLinear layer:

def forward(self, x):
    x_mx = MXTensor.to_mx(x, self.elem_dtype, self.block_size)
    w_mx = MXTensor.to_mx(self.weight, self.elem_dtype, self.block_size)
    y = F.linear(x_mx, w_mx, self.bias)
    y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size)
    return y

there is only a single MX quantization step in the backward pass, applied to the output gradient (in NoopFwToMXBw).

However, following the MX microscaling paper, there should be four quantizations happening in the backward pass: two of the output gradient (along two different axes), one of the activation, and one of the weights (both different from the forward-pass quantizations).
[Figure: forward/backward quantization steps from the MX microscaling paper (microscaling-fwd-bwd)]

Why it matters: even though not officially confirmed by hardware vendors, it is clear that MX matmuls can only be fully optimized when the quantization axis corresponds to the reduction axis for both operands. Hence, running the MX backward pass on next-gen hardware will require the four quantization steps described above. Changing the axis of the MX quantization results in a different quantization error, meaning that the current implementation potentially does not give a full picture of what MX training will look like on real hardware.
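To illustrate the axis dependence, here is a minimal sketch (my own illustration, not part of the issue). It assumes that MXTensor.to_mx places its scaling blocks along the last dimension, that torch.float8_e4m3fn is an accepted elem_dtype, and that MXTensor.to_dtype dequantizes back to high precision; quantizing the same weight with blocks along one axis or the other already gives a different round-trip error:

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

# Sketch only: compare the round-trip error of MX quantization when the
# 32-element blocks run along rows vs. along columns of the same weight.
w = torch.randn(256, 256)
elem_dtype, block_size = torch.float8_e4m3fn, 32

# Blocks along the last axis (the layout used by the forward gemm).
w_rows = MXTensor.to_mx(w, elem_dtype, block_size).to_dtype(torch.float32)

# Blocks along the other axis (the layout the grad_input gemm would need),
# obtained here by transposing before and after quantization.
w_cols = MXTensor.to_mx(w.t().contiguous(), elem_dtype, block_size).to_dtype(torch.float32).t()

print((w - w_rows).abs().mean())  # error with row-wise blocks
print((w - w_cols).abs().mean())  # error with column-wise blocks, generally different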

Potential fix: I believe we need a full forward+backward implementation of a blockwise_quantize_linear function that manually handles the backward-pass quantization steps (see the sketch below).
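As a concrete starting point, here is a rough sketch of such a function (my own illustration, not existing torchao code; the class name MXBlockwiseLinear is hypothetical). It assumes a 2D input, that MXTensor.to_mx quantizes along the last dimension, and that F.linear accepts MX operands as in the forward snippet above, so that each of the three gemms gets both operands quantized along its reduction axis:

import torch
import torch.nn.functional as F
from torchao.prototype.mx_formats.mx_tensor import MXTensor


class MXBlockwiseLinear(torch.autograd.Function):
    """Sketch: fwd gemm + 2 bwd gemms, each with MX quantization along the reduction axis."""

    @staticmethod
    def forward(ctx, x, weight, bias, elem_dtype, block_size):
        # x: (N, K_in), weight: (K_out, K_in); reduction axis of the fwd gemm is K_in.
        ctx.save_for_backward(x, weight)
        ctx.elem_dtype, ctx.block_size = elem_dtype, block_size
        x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
        w_mx = MXTensor.to_mx(weight, elem_dtype, block_size)
        return F.linear(x_mx, w_mx, bias)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        elem_dtype, block_size = ctx.elem_dtype, ctx.block_size

        # grad_input = grad_out @ weight, reduction over K_out:
        # quantization 1 of grad_out (blocks along K_out) plus a re-quantization
        # of the weight with blocks along K_out (different from the fwd one).
        go_mx = MXTensor.to_mx(grad_out, elem_dtype, block_size)
        w_t_mx = MXTensor.to_mx(weight.t().contiguous(), elem_dtype, block_size)
        grad_input = F.linear(go_mx, w_t_mx)

        # grad_weight = grad_out.t() @ x, reduction over the batch axis N:
        # quantization 2 of grad_out (blocks along N) plus a re-quantization of
        # the activation with blocks along N (different from the fwd one).
        go_t_mx = MXTensor.to_mx(grad_out.t().contiguous(), elem_dtype, block_size)
        x_t_mx = MXTensor.to_mx(x.t().contiguous(), elem_dtype, block_size)
        grad_weight = F.linear(go_t_mx, x_t_mx)

        grad_bias = grad_out.sum(dim=0) if ctx.needs_input_grad[2] else None
        return grad_input, grad_weight, grad_bias, None, None

MXLinear.forward would then reduce to return MXBlockwiseLinear.apply(x, self.weight, self.bias, self.elem_dtype, self.block_size), removing the NoopFwToMXBw step entirely.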


vkuzo commented Jan 7, 2025

Hi @balancap , #932 is also related. This is a known problem and I agree with your issue summary. I also agree that we'll have to have a torch.autograd.Function which specifies how the quantization is done for all three gemms in fwd/bwd.

As soon as the first hardware vendor releases the official specs, we on the torchao core team are going to sprint on updating this code to match the final hardware support. We would welcome contributions if anyone is looking to do this update earlier!

vkuzo added the mx label on Jan 7, 2025