Commit: Add int8 Woq for CPU
update int4 weight dim

Add CPU profiling
yanbing-j committed Oct 8, 2024
1 parent 222ec25 commit f7f8298
Showing 3 changed files with 49 additions and 5 deletions.
mixtral-moe/generate.py (14 changes: 11 additions & 3 deletions)
@@ -12,6 +12,7 @@
import torch
import torch._dynamo.config
import torch._inductor.config
+torch._inductor.config.cpp.enable_kernel_profile = True

def device_sync(device):
if "cuda" in device:
@@ -132,7 +133,7 @@ def encode_tokens(tokenizer, string, bos=True, device='cuda'):
tokens = tokenizer.encode(string)
if bos:
tokens = [tokenizer.bos_id()] + tokens
-return torch.tensor(tokens, dtype=torch.int, device=device)
+return torch.tensor(tokens, dtype=torch.int, device=args.device)

def _load_model(checkpoint_path, device, precision, use_tp):
with torch.device('meta'):
@@ -248,8 +249,13 @@ def callback(x):
if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
prof = contextlib.nullcontext()
else:
-torch.profiler._utils._init_for_cuda_graphs()
-prof = torch.profiler.profile()
+if device == 'cuda':
+    torch.profiler._utils._init_for_cuda_graphs()
+    prof = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], use_cuda=True)
+    profile_sort = 'self_cuda_time_total'
+elif device == 'cpu':
+    prof = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU])
+    profile_sort = 'self_cpu_time_total'
with prof:
y = generate(
model,
@@ -263,6 +269,8 @@ def callback(x):
if i == -1:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
continue
+if hasattr(prof, "key_averages"):
+    print(prof.key_averages().table(sort_by=profile_sort, row_limit=-1))
if hasattr(prof, "export_chrome_trace"):
if use_tp:
prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
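The new CPU branch collects a profiler table sorted by self CPU time, and setting torch._inductor.config.cpp.enable_kernel_profile = True is meant to make the C++ kernels generated by Inductor visible in that table. A minimal standalone sketch of the same profiling pattern (the toy model and input below are placeholders, not code from this repository):

import torch
from torch.profiler import profile, ProfilerActivity

# Placeholder workload; the real script profiles Mixtral-MoE token generation.
model = torch.nn.Linear(1024, 1024)
x = torch.randn(8, 1024)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with torch.no_grad():
        model(x)

# Same reporting call as the new code path, sorted by self CPU time.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))

The real script wraps generate(...) in the same context and prints the table only when prof actually exposes key_averages, i.e. when profiling was requested.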
mixtral-moe/quantize.py (21 changes: 20 additions & 1 deletion)
@@ -98,6 +98,20 @@ def convert_for_runtime(self):
return self.mod


+# TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+# https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+def linear_forward_int8(x, weight_int8pack, scales, out_features):
+    if x.is_cuda:
+        return F.linear(x, weight_int8pack.to(dtype=x.dtype)) * scales
+
+    origin_x_size = x.size()
+    x = x.reshape(-1, origin_x_size[-1])
+    c = torch.ops.aten._weight_int8pack_mm(x, weight_int8pack, scales)
+    new_shape = origin_x_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c


class WeightOnlyBit8Linear(torch.nn.Module):
__constants__ = ['in_features', 'out_features']
in_features: int
@@ -115,7 +129,12 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

def forward(self, input: torch.Tensor) -> torch.Tensor:
-    return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+    # return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+    # TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+    # https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+    return linear_forward_int8(
+        input,
+        self.weight, self.scales, self.out_features)


class ConditionalFeedForwardBit8(nn.Module):
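On CUDA, linear_forward_int8 keeps the original dequantize-on-the-fly path: the int8 weight is cast back to the activation dtype and the result is rescaled per output channel, so F.linear(x, w_int8.to(x.dtype)) * scales approximates the full-precision linear. A rough sketch of that identity, assuming simple symmetric per-channel quantization (the helper below is illustrative, not part of this repository):

import torch
import torch.nn.functional as F

def quantize_per_channel_int8(w):
    # Symmetric quantization with one scale per output channel (row of w).
    scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    w_int8 = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)
    return w_int8, scales.squeeze(1)

w = torch.randn(256, 512)
x = torch.randn(4, 512)
w_int8, scales = quantize_per_channel_int8(w)

ref = F.linear(x, w)                              # full-precision linear
woq = F.linear(x, w_int8.to(x.dtype)) * scales    # weight-only int8 path
print((ref - woq).abs().max())                    # small quantization error

On CPU the same function instead flattens the activation to 2D and hands the int8 weight and per-channel scales to torch.ops.aten._weight_int8pack_mm, which applies the scales inside the kernel, then restores the original leading dimensions.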
quantize.py (19 changes: 18 additions & 1 deletion)
@@ -335,6 +335,18 @@ def convert_for_runtime(self):
replace_linear_weight_only_int8_per_channel(self.mod)
return self.mod

+# TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+# https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+def linear_forward_int8(x, weight_int8pack, scales, out_features):
+    if x.is_cuda:
+        return F.linear(x, weight_int8pack.to(dtype=x.dtype)) * scales
+
+    origin_x_size = x.size()
+    x = x.reshape(-1, origin_x_size[-1])
+    c = torch.ops.aten._weight_int8pack_mm(x, weight_int8pack, scales)
+    new_shape = origin_x_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c

class WeightOnlyInt8Linear(torch.nn.Module):
__constants__ = ['in_features', 'out_features']
@@ -352,7 +364,12 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

def forward(self, input: torch.Tensor) -> torch.Tensor:
-    return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+    # return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+    # TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+    # https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+    return linear_forward_int8(
+        input,
+        self.weight, self.scales, self.out_features)

##### weight only int4 per channel groupwise quantized code ######

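A quick way to sanity-check the CPU fast path is to compare torch.ops.aten._weight_int8pack_mm against the dequantize-and-rescale fallback used on CUDA. Availability and the exact dtype/shape requirements of this op depend on the PyTorch build (the TODO above points at pytorch/pytorch#120985), so the comparison below is a hedged sketch rather than a guaranteed recipe:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 512, dtype=torch.bfloat16)                    # 2D CPU activations
w_int8 = torch.randint(-128, 128, (256, 512), dtype=torch.int8)  # [out_features, in_features]
scales = torch.rand(256, dtype=torch.bfloat16) + 0.5             # one scale per output channel

# Fallback path (what forward() does on CUDA): dequantize, matmul, rescale.
ref = F.linear(x, w_int8.to(dtype=x.dtype)) * scales

# CPU fast path added by this commit; only runs if the op exists in this build.
if hasattr(torch.ops.aten, "_weight_int8pack_mm"):
    out = torch.ops.aten._weight_int8pack_mm(x, w_int8, scales)
    print((out - ref).abs().max())  # expected to be near zero up to bf16 rounding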
