NotImplementedError: aten.linear.default not implemented when using MXTensor #796

Open
Ali-Flt opened this issue Sep 3, 2024 · 7 comments · May be fixed by #806

Ali-Flt commented Sep 3, 2024

Hey, I'm using the MX datatypes. It seems like aten.linear.default has not been implemented, which causes the linear layers inside the attention layers to fail with the MX datatypes.

Can you please implement this function in mx_ops.py?
Thanks!

Ali-Flt commented Sep 3, 2024

Does this look like a correct implementation?

@implements([aten.linear.default])
def mx_linear(aten_op, args, kwargs=None):
    a = args[0]  # input activation
    b = args[1]  # weight
    c = args[2] if len(args) > 2 else None  # optional bias
    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
    # dequantize both operands back to their original high-precision dtype
    # and run the op there
    a_hp = a.to_dtype(a._orig_dtype)
    b_hp = b.to_dtype(b._orig_dtype)
    res = aten_op(a_hp, b_hp, c)
    return res
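
For reference, this follows the same pattern as the existing aten.mm / aten.addmm overrides in mx_ops.py: dequantize both operands back to the original high-precision dtype and let the op run there.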

msaroufim added the mx label Sep 3, 2024

vkuzo commented Sep 3, 2024

Would you have a repro of what specifically is not working for you?

We do have overrides for aten.mm and aten.addmm here: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/mx_ops.py; they are called into from the __torch_dispatch__ extension point.
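
Roughly, that extension point works like this (a simplified sketch, not the actual torchao code; the names here are illustrative):

import torch

MX_OPS_TABLE = {}  # maps aten ops to override functions

def implements(aten_ops):
    """Register an override for one or more aten ops."""
    def decorator(fn):
        for op in aten_ops:
            MX_OPS_TABLE[op] = fn
        return fn
    return decorator

class MXTensorSketch(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # every aten op on this subclass lands here; ops without a
        # registered override fail loudly, which is where the
        # NotImplementedError in this issue comes from
        if func in MX_OPS_TABLE:
            return MX_OPS_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"{func} not implemented")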

Ali-Flt commented Sep 4, 2024

@vkuzo I did see the aten.mm and aten.addmm implementations, but for some reason in my case, when F.linear() is called in MXLinear, aten.linear.default is used instead of aten.addmm.

I don't know what decides whether aten.addmm or aten.linear gets called, but nonetheless, adding an aten.linear.default implementation should be an easy fix, right?

vkuzo commented Sep 4, 2024

I don't know what decides whether aten.addmm or aten.linear gets called

Agreed. Could you share a repro so we can dig into why you aren't hitting the mm/addmm functions? Adding a linear implementation sounds potentially reasonable; I just want to understand in more detail what exactly you are doing to hit this condition.
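
One quick way to see which aten op you are actually hitting is to run the call under a logging dispatch mode (rough sketch; whether you see aten.linear.default or aten.addmm.default depends on the autograd / no_grad context and the torch version):

import torch
import torch.nn.functional as F
from torch.utils._python_dispatch import TorchDispatchMode

class OpLogger(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # e.g. aten.linear.default or aten.addmm.default
        return func(*args, **(kwargs or {}))

x = torch.randn(4, 16)
w = torch.randn(8, 16)
b = torch.randn(8)

with torch.no_grad(), OpLogger():
    F.linear(x, w, b)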

Ali-Flt commented Sep 4, 2024

@vkuzo I was using MXLinear to quantize inference of Llama3.1-8B. Maybe the reason is that I was calling F.linear inside a torch.no_grad() context?

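Roughly, the failing path looked something like this (a sketch from memory, not the exact code; the import path and element dtype are my best guesses):

import torch
import torch.nn.functional as F
from torchao.prototype.mx_formats.mx_tensor import MXTensor

elem_dtype = torch.float8_e4m3fn  # best guess at the element dtype I used
block_size = 32

w = torch.randn(16, 16, dtype=torch.bfloat16)
x = torch.randn(4, 16, dtype=torch.bfloat16)

w_mx = MXTensor.to_mx(w, elem_dtype, block_size)
x_mx = MXTensor.to_mx(x, elem_dtype, block_size)

with torch.no_grad():
    # this ended up at aten.linear.default, which has no MX override,
    # so __torch_dispatch__ raised NotImplementedError
    y = F.linear(x_mx, w_mx)
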
I ultimately decided to avoid running any operation in MX format, so I no longer have the code that triggered the error.
I now do the quantization this way (note that I don't care about the memory savings of quantization; I just want to incorporate the quantization error, which is why I bring the weights and activations back to their original dtypes):

...
# quantize the weights (round-trip through MX, then back to the original dtype)
orig_dtype = linear_layer.weight.data.dtype
weight_float = linear_layer.weight.data.float()
weight_q = MXTensor.to_mx(weight_float, self.quant_dtype, self.group_size)
linear_layer.weight.data = weight_q.to_dtype(orig_dtype)
...
    def forward(self, x):
        # quantize the activations
        x_float = x.float()  # MXTensor only accepts float32 and bfloat16
        x_q = MXTensor.to_mx(x_float, self.elem_dtype, self.block_size)
        x_q = x_q.to_dtype(self.weight.dtype)  # back to the original dtype
        y = F.linear(x_q, self.weight, self.bias)
        y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size)
        return y
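
And for completeness, the swap itself is just a recursive module replacement, roughly like this (sketch; make_quant_linear stands in for whatever constructs the wrapper module above from an existing nn.Linear):

import torch.nn as nn

def swap_linears(module, make_quant_linear):
    # recursively replace every nn.Linear child with the fake-quant wrapper
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, make_quant_linear(child))
        else:
            swap_linears(child, make_quant_linear)
    return module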

vkuzo commented Sep 4, 2024

I see, thanks for that context. Adding an override for linear sgtm; let me know if you are interested in putting up a PR, otherwise we can take care of it. Thanks for the report!

Ali-Flt linked a pull request Sep 4, 2024 that will close this issue

Ali-Flt commented Sep 4, 2024

@vkuzo Created the PR 👍
