diff --git a/torchchat/generate.py b/torchchat/generate.py
index fcae18d87..4a67195fb 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -591,6 +591,7 @@ def generate(
             Dict[str, Any]
         ] = None,  # List of Image prompt tensors for multimodal models
         start_pos: int = 0,
+        skip_cache_setup: bool = False,
         draft_model: Model,
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
@@ -614,26 +615,27 @@ def generate(
         max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
-            model = model.to(device=device)
-            with torch.device(device):
-                if (
-                    self.is_torchtune_model
-                    or self.model.config.model_type == ModelType.Flamingo
-                ):
-                    # 6404 is one-gpu affordable max_seq_length for single image input
-                    model.setup_caches(
-                        batch_size=1,
-                        dtype=self.dtype,
-                        encoder_max_seq_len=6404,
-                        decoder_max_seq_len=max_seq_length,
-                    )
-                else:
-                    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-                if is_speculative and draft_model is not model:
-                    draft_model.setup_caches(
-                        max_batch_size=1,
-                        max_seq_length=max_seq_length,
-                    )
+            if not skip_cache_setup:
+                model = model.to(device=device)
+                with torch.device(device):
+                    if (
+                        self.is_torchtune_model
+                        or self.model.config.model_type == ModelType.Flamingo
+                    ):
+                        # 6404 is one-gpu affordable max_seq_length for single image input
+                        model.setup_caches(
+                            batch_size=1,
+                            dtype=self.dtype,
+                            encoder_max_seq_len=6404,
+                            decoder_max_seq_len=max_seq_length,
+                        )
+                    else:
+                        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+                    if is_speculative and draft_model is not model:
+                        draft_model.setup_caches(
+                            max_batch_size=1,
+                            max_seq_length=max_seq_length,
+                        )

         if model.config.model_type == ModelType.Flamingo:
             model.reset_caches()
@@ -1013,6 +1015,7 @@ def chat(
         )
         for i in range(num_samples):
             device_sync(device=self.builder_args.device)
+            is_first_sample: bool = i == 0
             if generator_args.chat_mode:
                 prompt = input("User: ")
                 if prompt == "/bye":
@@ -1038,7 +1041,7 @@ def chat(
                         ]
                     )
                     self.system_prompt = None
-                elif i == 0:
+                elif is_first_sample:
                     encoded = self.chat_formatter.encode_dialog_prompt(
                         [{"role": "user", "content": prompt}]
                     )
@@ -1107,6 +1110,7 @@ def callback(x, *, done_generating=False):
                 top_k=generator_args.top_k,
                 sequential_prefill=generator_args.sequential_prefill,
                 start_pos=start_pos,
+                skip_cache_setup=not is_first_sample,
                 max_seq_length=max_seq_length,
             )
             for token_tensor, metrics in generator_func:
@@ -1116,7 +1120,7 @@ def callback(x, *, done_generating=False):
                 if metrics is not None:
                     aggregate_metrics.update(metrics)
                 yield token_tensor, metrics
-            jit_compile = (i == 0) and (
+            jit_compile = is_first_sample and (
                 generator_args.compile or generator_args.compile_prefill
            )
            compilation_time = time.perf_counter() - t0
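# ---------------------------------------------------------------------------
# Not part of the patch: a minimal, self-contained sketch of the call pattern
# the diff introduces. Cache setup runs only for the first sample; subsequent
# samples pass skip_cache_setup=not is_first_sample and reuse the existing
# caches. TinyModel/generate below are placeholders, not the torchchat API.
# ---------------------------------------------------------------------------

class TinyModel:
    def __init__(self) -> None:
        self.caches_ready = False
        self.setup_calls = 0

    def setup_caches(self, max_batch_size: int, max_seq_length: int) -> None:
        # Stands in for the expensive KV-cache allocation; should run once.
        self.caches_ready = True
        self.setup_calls += 1


def generate(model: TinyModel, start_pos: int = 0, skip_cache_setup: bool = False):
    # Mirrors the patched guard: set up caches only on the first inference
    # and only when the caller has not asked to skip setup.
    if start_pos == 0 and not skip_cache_setup:
        model.setup_caches(max_batch_size=1, max_seq_length=2048)
    assert model.caches_ready, "caches must exist before decoding"
    yield from range(3)  # stand-in for generated tokens


if __name__ == "__main__":
    model = TinyModel()
    num_samples = 4
    for i in range(num_samples):
        is_first_sample = i == 0
        tokens = list(
            generate(model, start_pos=0, skip_cache_setup=not is_first_sample)
        )
    # Setup ran exactly once, for the first sample only.
    assert model.setup_calls == 1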