Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
metascroy committed May 9, 2024
1 parent cf8f910 commit 4a7a40c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 33 deletions.
21 changes: 9 additions & 12 deletions _custom_linear/_custom_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,24 @@
import torch.nn as nn
from typing import Optional
torch.ops.load_library("_custom_linear/build/libcustom_linear.dylib")
from .quantize import group_quantize_tensor_symmetric, convert_to_qc4w

class _CustomLinear(nn.Module):
def _prepare(self) -> None:
self.weight.requires_grad = False
if self.bias:
self.bias.requires_grad = False

# self.packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(self.weight, self.bias)
self.packed_weight_bias = torch.ops.torchchat.prepack.default(self.weight, self.bias, None, None)

def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> None:
super().__init__()
self.weight = weight
self.bias = bias
self._prepare()
assert bias is None

self.group_size = 8
w_int, s, z = group_quantize_tensor_symmetric(self.weight, self.group_size, torch.float32)
w_packed = convert_to_qc4w(w_int)
self.prepacked = torch.ops.torchchat.prepack.default(w_packed, s)

def forward(self, x):
if x.dtype != torch.float32:
raise RuntimeError(f"x has dtype {x.dtype}, expected float32")
# return torch.ops.prepacked.linear_clamp_run(x, self.packed_weight_bias)
return torch.ops.torchchat.run.default(x, self.packed_weight_bias)
assert x.shape[0] == 1
return torch.ops.torchchat.run.default(self.prepacked, x.squeeze(0)).unsqueeze(0)

def _replace_linear_with_custom_linear(module: nn.Module):
for name, child in module.named_children():
Expand Down
21 changes: 0 additions & 21 deletions _custom_linear/custom_linear.h

This file was deleted.

0 comments on commit 4a7a40c

Please sign in to comment.