diff --git a/src/brevitas_examples/llm/llm_quant/run_utils.py b/src/brevitas_examples/llm/llm_quant/run_utils.py
index b0a2e67f2..bdc1b3a1e 100644
--- a/src/brevitas_examples/llm/llm_quant/run_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/run_utils.py
@@ -47,7 +47,7 @@ def get_fx(model):
     return model


-def modify_dataloader(model_name_or_path, data):
+def modify_dataloader(model_name_or_path, data, dtype):
     config = AutoConfig.from_pretrained(model_name_or_path)

     normalized_config_class = NormalizedConfigManager.get_normalized_config_class(config.model_type)
@@ -59,8 +59,8 @@ def modify_dataloader(model_name_or_path, data):

     for sample in data:
         sample["past_key_values"] = tuple((
-            torch.zeros(1, num_heads, 0, head_dim, device=sample["input_ids"].device),
-            torch.zeros(1, num_heads, 0, head_dim, device=sample["input_ids"].device),
+            torch.zeros(1, num_heads, 0, head_dim, device=sample["input_ids"].device, dtype=dtype),
+            torch.zeros(1, num_heads, 0, head_dim, device=sample["input_ids"].device, dtype=dtype),
         ) for _ in range(num_layers))
     return data

diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index e09a993b1..dd8272d60 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -295,12 +295,15 @@ def main():

     if args.weight_equalization or args.act_equalization == 'fx':
         model = get_fx(model)
-        calibration_loader = modify_dataloader(args.model, calibration_loader)
-        val_data = modify_dataloader(args.model, val_data)
+        calibration_loader = modify_dataloader(args.model, calibration_loader, dtype=dtype)
+        val_data = modify_dataloader(args.model, val_data, dtype=dtype)

     if args.weight_equalization:
         print("Apply weight equalization...")
+        # In case of float16 model, we need to offload to account for missing ops
+        model = offload_model(model)
         apply_weight_equalization(model)
+        remove_hooks(model)
         print("Weight equalization applied.")

     if args.act_equalization is not None: