NotImplementedError: aten.linear.default not implemented when using MXTensor #796
Does this look like a correct implementation?

    @implements([aten.linear.default])
    def mx_mm(aten_op, args, kwargs=None):
        a = args[0]
        b = args[1]
        # aten.linear.default takes an optional bias as the third argument
        if len(args) > 2:
            c = args[2]
        else:
            c = None
        assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
        # dequantize both operands back to their original dtype and run the
        # op in high precision
        a_hp = a.to_dtype(a._orig_dtype)
        b_hp = b.to_dtype(b._orig_dtype)
        res = aten_op(a_hp, b_hp, c)
        return res
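For reference, the NotImplementedError in the title typically comes from an @implements-style dispatch table; here is a rough, hypothetical sketch of how such a registry routes __torch_dispatch__ calls (not torchao's exact code):

    # Hypothetical sketch of an @implements-style registry: ops are looked up
    # in a table inside __torch_dispatch__, and anything missing from the
    # table raises the NotImplementedError seen in this issue.
    MX_OPS_TABLE = {}

    def implements(aten_ops):
        def decorator(fn):
            for op in aten_ops:
                MX_OPS_TABLE[op] = fn
            return fn
        return decorator

    # inside the MXTensor subclass:
    # @classmethod
    # def __torch_dispatch__(cls, func, types, args, kwargs=None):
    #     if func in MX_OPS_TABLE:
    #         return MX_OPS_TABLE[func](func, args, kwargs)
    #     raise NotImplementedError(f"{func} not implemented")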
Would you have a repro of what specifically is not working for you? We do have overrides for aten.mm and aten.addmm.
@vkuzo I did see the aten.mm and aten.addmm overrides. I don't know what decides which of aten.linear, aten.mm, or aten.addmm gets dispatched, but in my case F.linear ended up at aten.linear.default, so an override for it seems necessary.
Agreed. Could you share a repro so we can dig into why you aren't hitting the mm/addmm functions? Adding a linear implementation sounds potentially reasonable, I just wanted to understand in more detail what exactly you are doing to hit this condition.
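One way to answer the "which aten op gets called" question empirically is a tiny wrapper subclass that logs whatever ops reach __torch_dispatch__. A minimal sketch (not torchao code, names are made up):

    import torch
    import torch.nn.functional as F

    class LoggingTensor(torch.Tensor):
        # wrapper subclass that prints every aten op it receives, then
        # unwraps its arguments and runs the op on plain tensors
        @staticmethod
        def __new__(cls, elem):
            return torch.Tensor._make_wrapper_subclass(
                cls, elem.shape, dtype=elem.dtype, device=elem.device
            )

        def __init__(self, elem):
            self.elem = elem

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            print(f"dispatched to: {func}")  # e.g. aten.linear.default, aten.addmm.default, ...
            kwargs = kwargs or {}
            new_args = [a.elem if isinstance(a, LoggingTensor) else a for a in args]
            return func(*new_args, **kwargs)

    x = LoggingTensor(torch.randn(4, 8))
    w = LoggingTensor(torch.randn(16, 8))
    y = F.linear(x, w)  # prints whichever aten overload(s) F.linear hits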
@vkuzo I was using MXLinear to quantize the inference of Llama3.1-8B. Maybe the reason is that I was calling F.linear directly on the quantized tensors. I eventually decided to avoid running any operation in MX format after all, so I no longer have the exact code that triggered the error, but it looked roughly like this:

    ...
    # Quantizing weights:
    orig_dtype = linear_layer.weight.data.dtype
    weight_float = linear_layer.weight.data.float()
    weight_q = MXTensor.to_mx(weight_float, self.quant_dtype, self.group_size)
    linear_layer.weight.data = weight_q.to_dtype(orig_dtype)
    ...

    def forward(self, x):
        # Quantizing activations
        x_float = x.float()  # MXTensor only accepts float32 and bfloat16
        x_q = MXTensor.to_mx(x_float, self.elem_dtype, self.block_size)
        x_q = x_q.to_dtype(self.weight.dtype)
        y = F.linear(x_q, self.weight, self.bias)
        y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size)
        return y
I see, thanks for that context. Adding an override for aten.linear.default sounds reasonable to me; feel free to open a PR for it.
@vkuzo Created the PR 👍
Hey, I'm using the MX datatypes. It seems like the aten.linear.default function has not been implemented, which causes the linear layers in the attention blocks not to work with the MX datatypes.
Can you please implement this function in mx_ops.py?
Thanks!
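As a repro, something along these lines should trigger the reported error, assuming the prototype MX import path and a float8 element dtype (exact module path and arguments may differ between torchao versions):

    import torch
    import torch.nn.functional as F
    from torchao.prototype.mx_formats.mx_tensor import MXTensor  # path may vary by version

    block_size = 32
    x = MXTensor.to_mx(torch.randn(16, 32, dtype=torch.bfloat16), torch.float8_e4m3fn, block_size)
    w = MXTensor.to_mx(torch.randn(64, 32, dtype=torch.bfloat16), torch.float8_e4m3fn, block_size)

    # F.linear on MXTensor inputs dispatches to aten.linear.default, which is
    # where the NotImplementedError is raised
    y = F.linear(x, w)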