From 116c5c2f84ccb16f5f2e93e7f2ba49d2997d012c Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Tue, 12 Nov 2024 16:14:32 -0800
Subject: [PATCH 1/2] Only set up during the first sample

---
 torchchat/generate.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/torchchat/generate.py b/torchchat/generate.py
index dd423b58a..172eaebbe 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -591,6 +591,7 @@ def generate(
             Dict[str, Any]
         ] = None,  # List of Image prompt tensors for multimodal models
         start_pos: int = 0,
+        skip_cache_setup: bool = False,
         draft_model: Model,
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
@@ -613,7 +614,7 @@ def generate(
         prompt_length = prompt.size(0)
         max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
-        if start_pos == 0:
+        if start_pos == 0 and not skip_cache_setup:
             model = model.to(device=device)
             with torch.device(device):
                 if (
@@ -1020,6 +1021,7 @@ def chat(
             )
         for i in range(num_samples):
             device_sync(device=self.builder_args.device)
+            is_first_sample: bool = i == 0
             if generator_args.chat_mode:
                 prompt = input("User: ")
                 if prompt == "/bye":
@@ -1045,7 +1047,7 @@ def chat(
                         ]
                     )
                     self.system_prompt = None
-                elif i == 0:
+                elif is_first_sample:
                     encoded = self.chat_formatter.encode_dialog_prompt(
                         [{"role": "user", "content": prompt}]
                     )
@@ -1116,6 +1118,7 @@ def callback(x, *, done_generating=False):
                 top_k=generator_args.top_k,
                 sequential_prefill=generator_args.sequential_prefill,
                 start_pos=start_pos,
+                skip_cache_setup=not is_first_sample,
                 max_seq_length=max_seq_length,
             )
             for token_tensor, metrics in generator_func:
@@ -1125,7 +1128,7 @@ def callback(x, *, done_generating=False):
                 if metrics is not None:
                     aggregate_metrics.update(metrics)
                 yield token_tensor, metrics
-            jit_compile = (i == 0) and (
+            jit_compile = is_first_sample and (
                 generator_args.compile or generator_args.compile_prefill
             )
             compilation_time = time.perf_counter() - t0

From 0163f61402f45510a7e88278a3490bc15d5b3fc9 Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Tue, 12 Nov 2024 18:54:45 -0800
Subject: [PATCH 2/2] Cleaner

---
 torchchat/generate.py | 43 ++++++++++++++++++++++---------------------
 1 file changed, 22 insertions(+), 21 deletions(-)

diff --git a/torchchat/generate.py b/torchchat/generate.py
index 172eaebbe..d6dc54281 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -614,27 +614,28 @@ def generate(
         prompt_length = prompt.size(0)
         max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
-        if start_pos == 0 and not skip_cache_setup:
-            model = model.to(device=device)
-            with torch.device(device):
-                if (
-                    self.is_torchtune_model
-                    or self.model.config.model_type == ModelType.Flamingo
-                ):
-                    # 6404 is one-gpu affordable max_seq_length for single image input
-                    model.setup_caches(
-                        batch_size=1,
-                        dtype=self.dtype,
-                        encoder_max_seq_len=6404,
-                        decoder_max_seq_len=max_seq_length,
-                    )
-                else:
-                    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-                if is_speculative and draft_model is not model:
-                    draft_model.setup_caches(
-                        max_batch_size=1,
-                        max_seq_length=max_seq_length,
-                    )
+        if start_pos == 0:
+            if not skip_cache_setup:
+                model = model.to(device=device)
+                with torch.device(device):
+                    if (
+                        self.is_torchtune_model
+                        or self.model.config.model_type == ModelType.Flamingo
+                    ):
+                        # 6404 is one-gpu affordable max_seq_length for single image input
+                        model.setup_caches(
+                            batch_size=1,
+                            dtype=self.dtype,
+                            encoder_max_seq_len=6404,
+                            decoder_max_seq_len=max_seq_length,
+                        )
+                    else:
+                        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+                    if is_speculative and draft_model is not model:
+                        draft_model.setup_caches(
+                            max_batch_size=1,
+                            max_seq_length=max_seq_length,
+                        )
             if model.config.model_type == ModelType.Flamingo:
                 model.reset_caches()
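
Taken together, the two commits thread a skip_cache_setup flag from chat() into generate() so that the model-to-device move and the KV-cache allocation happen only for the first sample, and later samples reuse the caches that are already in place. The following is a minimal sketch of that pattern, not the torchchat API; SimpleGenerator and run_samples are hypothetical names used only for illustration.

import torch


class SimpleGenerator:
    """Hypothetical wrapper showing one-time KV-cache setup across samples."""

    def __init__(self, model: torch.nn.Module, device: str = "cpu"):
        self.model = model
        self.device = device

    def generate(
        self,
        prompt: torch.Tensor,
        *,
        start_pos: int = 0,
        skip_cache_setup: bool = False,
    ) -> torch.Tensor:
        # Mirrors the patch: only pay for the device move / cache allocation once.
        if start_pos == 0 and not skip_cache_setup:
            self.model = self.model.to(device=self.device)
            # In torchchat this is where model.setup_caches(...) is called.
        # Prefill and decode would follow; returning the prompt keeps the sketch runnable.
        return prompt

    def run_samples(self, prompt: torch.Tensor, num_samples: int) -> None:
        for i in range(num_samples):
            is_first_sample = i == 0
            # Samples after the first reuse the caches allocated on the first pass.
            self.generate(prompt, skip_cache_setup=not is_first_sample)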