Update Caching logic to only trigger on the first inference sample #1369

Merged 3 commits on Nov 13, 2024
48 changes: 26 additions & 22 deletions torchchat/generate.py
@@ -591,6 +591,7 @@ def generate(
Dict[str, Any]
] = None, # List of Image prompt tensors for multimodal models
start_pos: int = 0,
skip_cache_setup: bool = False,

@Gasoonjia (Contributor) commented on Nov 13, 2024:
I'm ok with this for now, but introducing new inputs into the generate function might trigger my nightmare 😣, moving it farther away from our target.
I would like to delegate this to the model side to suppress the warning messages.

Contributor Author:
I agree, it's not great. Luckily it's lightweight, so we can abstract it away easily later on.
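
For illustration, a minimal sketch of what that later abstraction could look like; this is hypothetical and not part of this PR, and GenerationOptions is an invented name that simply groups the per-call flags (start_pos, skip_cache_setup, ...) so generate() keeps a stable signature:

# Hypothetical sketch only -- not part of this PR.
from dataclasses import dataclass
from typing import Optional

@dataclass
class GenerationOptions:  # invented name, not a torchchat class
    start_pos: int = 0
    skip_cache_setup: bool = False
    sequential_prefill: bool = True

def generate(model, encoded, options: Optional[GenerationOptions] = None):
    opts = options or GenerationOptions()
    # Caches are only (re)built on the first sample of a fresh sequence,
    # mirroring the guard introduced in the diff below.
    if opts.start_pos == 0 and not opts.skip_cache_setup:
        pass  # model.setup_caches(...) as in the diff below
    # ... rest of the generation loop ...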

draft_model: Model,
speculate_k: Optional[int] = 8,
sequential_prefill=True,
@@ -614,26 +615,27 @@ def generate(
max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
# set up caches only if first inference
if start_pos == 0:
model = model.to(device=device)
with torch.device(device):
if (
self.is_torchtune_model
or self.model.config.model_type == ModelType.Flamingo
):
# 6404 is one-gpu affordable max_seq_length for single image input
model.setup_caches(
batch_size=1,
dtype=self.dtype,
encoder_max_seq_len=6404,
decoder_max_seq_len=max_seq_length,
)
else:
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if is_speculative and draft_model is not model:
draft_model.setup_caches(
max_batch_size=1,
max_seq_length=max_seq_length,
)
if not skip_cache_setup:

Contributor Author:
This is the only change in this block; the rest is whitespace.

Contributor:
Is there any way to tell the cache status directly from the model, instead of forwarding a new attribute from outside?

Contributor Author:
Not off the top of my head, but it's definitely worth baking into our model abstraction in the future.
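
To make the suggestion concrete, a hedged sketch of the model-side alternative; none of this exists in torchchat today, and CacheAwareModel / caches_are_setup are invented names:

# Hypothetical sketch -- torchchat does not expose this API today. The idea is
# that the model records whether its KV caches were already built, so callers
# can ask it instead of threading a skip_cache_setup flag through generate().
import torch

class CacheAwareModel(torch.nn.Module):  # invented base/wrapper class
    def __init__(self):
        super().__init__()
        self._caches_setup = False  # set by setup_caches()

    def setup_caches(self, max_batch_size: int, max_seq_length: int) -> None:
        # ... allocate KV caches exactly as today ...
        self._caches_setup = True

    def caches_are_setup(self) -> bool:  # invented accessor
        return self._caches_setup

# generate() could then gate cache setup without any new parameter:
#     if start_pos == 0 and not model.caches_are_setup():
#         model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)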

model = model.to(device=device)
with torch.device(device):
if (
self.is_torchtune_model
or self.model.config.model_type == ModelType.Flamingo
):
# 6404 is one-gpu affordable max_seq_length for single image input
model.setup_caches(
batch_size=1,
dtype=self.dtype,
encoder_max_seq_len=6404,
decoder_max_seq_len=max_seq_length,
)
else:
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if is_speculative and draft_model is not model:
draft_model.setup_caches(
max_batch_size=1,
max_seq_length=max_seq_length,
)
if model.config.model_type == ModelType.Flamingo:
model.reset_caches()

@@ -1013,6 +1015,7 @@ def chat(
)
for i in range(num_samples):
device_sync(device=self.builder_args.device)
is_first_sample: bool = i == 0
if generator_args.chat_mode:
prompt = input("User: ")
if prompt == "/bye":
@@ -1038,7 +1041,7 @@
]
)
self.system_prompt = None
elif i == 0:
elif is_first_sample:
encoded = self.chat_formatter.encode_dialog_prompt(
[{"role": "user", "content": prompt}]
)
@@ -1107,6 +1110,7 @@ def callback(x, *, done_generating=False):
top_k=generator_args.top_k,
sequential_prefill=generator_args.sequential_prefill,
start_pos=start_pos,
skip_cache_setup=not is_first_sample,
max_seq_length=max_seq_length,
)
for token_tensor, metrics in generator_func:
@@ -1116,7 +1120,7 @@
if metrics is not None:
aggregate_metrics.update(metrics)
yield token_tensor, metrics
jit_compile = (i == 0) and (
jit_compile = is_first_sample and (
generator_args.compile or generator_args.compile_prefill
)
compilation_time = time.perf_counter() - t0
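
Taken together, the chat()-side hunks amount to the caller pattern below; this is a simplified, self-contained sketch of this PR's change (run_samples and the trimmed argument list are illustrative, not torchchat code):

# Simplified sketch of the caller pattern: caches are set up inside generate()
# only for the first sample; later samples pass skip_cache_setup=True.
from typing import Any, Callable, Iterator, Tuple

def run_samples(
    generate: Callable[..., Iterator[Tuple[Any, Any]]],  # stands in for Generator.generate
    num_samples: int,
    start_pos: int = 0,
) -> None:
    for i in range(num_samples):
        is_first_sample: bool = i == 0
        for token_tensor, metrics in generate(
            start_pos=start_pos,
            skip_cache_setup=not is_first_sample,  # only sample 0 builds caches
        ):
            pass  # stream tokens / aggregate metrics, as chat() does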