Commit

Update
yanbing-j committed Sep 18, 2024
1 parent 45e75e7 commit acdc197
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions quantize.py
@@ -402,7 +402,14 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
         assert inner_k_tiles in [2, 4, 8]
 
     @torch.no_grad()
-    def create_quantized_state_dict(self):
+    def create_quantized_state_dict(self, use_cuda = True):
+        device="cpu"
+        if use_cuda:
+            if torch.cuda.is_available():
+                device="cuda"
+            else:
+                print(f"Warning: CUDA not available, running CPU")
+
         cur_state_dict = self.mod.state_dict()
         for fqn, mod in self.mod.named_modules():
             if isinstance(mod, torch.nn.Linear):
@@ -425,7 +432,7 @@ def create_quantized_state_dict(self):
                             "and that groupsize and inner_k_tiles*16 evenly divide into it")
                         continue
                 weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
-                    weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles
+                    weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
                 )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')
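
For readers following along, here is a minimal, self-contained sketch of the device-selection fallback this commit introduces. The helper name pick_device is hypothetical and written only for illustration; in the actual commit the same logic is inlined at the top of create_quantized_state_dict(self, use_cuda=True) before the int4 packing step.

import torch

def pick_device(use_cuda: bool = True) -> str:
    # Mirrors the fallback added in the diff above: prefer CUDA when it is
    # requested and available, otherwise warn and stay on the CPU.
    if use_cuda:
        if torch.cuda.is_available():
            return "cuda"
        print("Warning: CUDA not available, running on CPU")
    return "cpu"

# The bf16 weight is then moved to the chosen device before packing,
# analogous to weight.to(torch.bfloat16).to(device=device) in the diff.
weight = torch.randn(4096, 4096)
weight_bf16 = weight.to(torch.bfloat16).to(device=pick_device(use_cuda=True))
print(weight_bf16.device, weight_bf16.dtype)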
