diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 973e0ba9a..21a535601 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -163,8 +163,6 @@ def main(
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
     quantization: Optional[str] = None,
-    calibration_limit: int = 10,
-    calibration_seq_length: int = 256,
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool=False,
@@ -254,16 +252,16 @@ def main(
             quant_dtype = getattr(torch, quant_dtype, torch.uint8)
             model=model.to(device)
             # get calibration data
-            insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size)
+            insert_awq_observer_(model, 1, 256, quant_dtype=quant_dtype, group_size=group_size)
             TransformerEvalWrapper(
                 model=model.to(device),
                 tokenizer=tokenizer,
-                max_seq_length=calibration_seq_length,
+                max_seq_length=256,
                 input_prep_func=prepare_inputs_for_model,
                 device=device,
             ).run_eval(
                 tasks=['wikitext'],
-                limit=calibration_limit,
+                limit=1,
             )
             is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
             use_hqq = "hqq" in quantization
@@ -477,8 +475,6 @@ def callback(x):
             +'embed-int8wo'
         )
     )
-    parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
-    parser.add_argument("--calibration_seq_length", type=int, default=256, help="Sequence length for calibration")
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
     parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
@@ -494,5 +490,5 @@ def callback(x):
     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
-        args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_seq_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
+        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
     )
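For reference, a minimal sketch of the AWQ calibration flow these hunks hard-code (calibration_limit=1, calibration_seq_length=256), using only the insert_awq_observer_ signature and AWQObservedLinear filter visible in the diff. The toy model, group size, and random calibration batch are assumptions for illustration; generate.py calibrates the real checkpoint over wikitext via TransformerEvalWrapper.

# Sketch only: assumes torchao.prototype.awq as imported in generate.py; the toy model and
# random calibration batch are stand-ins, not the script's actual calibration data.
import torch
import torch.nn as nn
from torchao.prototype.awq import AWQObservedLinear, insert_awq_observer_

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))

calibration_limit = 1          # number of calibration examples (value hard-coded by this diff)
calibration_seq_length = 256   # calibration sequence length (value hard-coded by this diff)

# Replace eligible nn.Linear modules with observed variants that record activation statistics.
insert_awq_observer_(
    model,
    calibration_limit,
    calibration_seq_length,
    quant_dtype=torch.uint8,   # generate.py derives this from the awq-<dtype>-<groupsize> flag
    group_size=64,             # assumed group size for this sketch
)

# Run calibration data through the observed model so the observers collect statistics;
# generate.py does this step with TransformerEvalWrapper(...).run_eval(tasks=['wikitext'], limit=1).
with torch.no_grad():
    model(torch.randn(calibration_limit, calibration_seq_length, 128))

# Only the observed linears are quantized afterwards, selected with the same filter as the diff.
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)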