Skip to content

Commit

Permalink
Fix (examples/llm): offload for weight equalization
Browse files (browse the repository at this point in the history)
  • Loading branch information
Giuseppe5 committed Feb 12, 2024
1 parent 007c0dc commit 6ce11e2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/brevitas_examples/llm/llm_quant/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_fx(model):
return model


def modify_dataloader(model_name_or_path, data):
def modify_dataloader(model_name_or_path, data, dtype):
config = AutoConfig.from_pretrained(model_name_or_path)

normalized_config_class = NormalizedConfigManager.get_normalized_config_class(config.model_type)
Expand All @@ -59,8 +59,8 @@ def modify_dataloader(model_name_or_path, data):

for sample in data:
sample["past_key_values"] = tuple((
torch.zeros(1, num_heads, 0, head_dim, device=sample["input_ids"].device),
torch.zeros(1, num_heads, 0, head_dim, device=sample["input_ids"].device),
torch.zeros(1, num_heads, 0, head_dim, device=sample["input_ids"].device, dtype=dtype),
torch.zeros(1, num_heads, 0, head_dim, device=sample["input_ids"].device, dtype=dtype),
) for _ in range(num_layers))
return data

Expand Down
7 changes: 5 additions & 2 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,15 @@ def main():

if args.weight_equalization or args.act_equalization == 'fx':
model = get_fx(model)
calibration_loader = modify_dataloader(args.model, calibration_loader)
val_data = modify_dataloader(args.model, val_data)
calibration_loader = modify_dataloader(args.model, calibration_loader, dtype=dtype)
val_data = modify_dataloader(args.model, val_data, dtype=dtype)

if args.weight_equalization:
print("Apply weight equalization...")
# In the case of a float16 model, we need to offload to account for missing ops
model = offload_model(model)
apply_weight_equalization(model)
remove_hooks(model)
print("Weight equalization applied.")

if args.act_equalization is not None:
Expand Down

0 comments on commit 6ce11e2

Please sign in to comment.