diff --git a/python/turbine_models/custom_models/llm_app.py b/python/turbine_models/custom_models/llm_app.py index bfbd090cc..33ba2e1c9 100644 --- a/python/turbine_models/custom_models/llm_app.py +++ b/python/turbine_models/custom_models/llm_app.py @@ -92,7 +92,7 @@ def generate(self, input_ids): # Because we have stored the res in KV-cache. token_len = input_ids.shape[-1] if self.init_cache: - input_ids = input_ids[:, self.prev_token_len:] + input_ids = input_ids[:, self.prev_token_len :] inputs = [ireert.asdevicearray(self.runner.config.device, input_ids)] if self.first_input or not self.init_cache: s = time.time()