Update Caching logic to only trigger on the first inference sample #1369
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1369
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 7047d79 with merge base 93f713f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
    max_batch_size=1,
    max_seq_length=max_seq_length,
)
if not skip_cache_setup:
Only change in this block: rest is whitespace
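For context, a rough sketch of the guarded block this diff introduces; the helper function and its name are assumptions for illustration, and only the `skip_cache_setup` guard and the `setup_caches` arguments come from the diff itself:

```python
def maybe_setup_caches(model, max_seq_length: int, skip_cache_setup: bool) -> None:
    """Build the KV caches only when the caller has not already done so."""
    if not skip_cache_setup:
        # Mirrors the guarded call in the diff: batch size 1, fixed sequence length.
        model.setup_caches(
            max_batch_size=1,
            max_seq_length=max_seq_length,
        )
```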
Is there any way to directly tell the cache status from the model, instead of forwarding a new attribute from outside?
Not off the top of my head, but it's definitely worth baking into our model abstraction in the future.
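One possible shape for the model-side approach discussed above; the wrapper class and the `_caches_ready` flag are purely illustrative and not an existing torchchat or torchtune API:

```python
class CacheAwareModel:
    """Illustrative wrapper that tracks whether KV caches were already built,
    so callers do not need to forward an extra flag."""

    def __init__(self, model):
        self.model = model
        self._caches_ready = False

    def setup_caches(self, max_batch_size: int, max_seq_length: int) -> None:
        # Silently skip redundant setup instead of re-triggering it (and its warnings).
        if self._caches_ready:
            return
        self.model.setup_caches(
            max_batch_size=max_batch_size,
            max_seq_length=max_seq_length,
        )
        self._caches_ready = True
```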
@@ -591,6 +591,7 @@ def generate(
        Dict[str, Any]
    ] = None,  # List of Image prompt tensors for multimodal models
    start_pos: int = 0,
    skip_cache_setup: bool = False,
I'm OK with it for now, but introducing new inputs into the generate function might trigger my nightmare 😣, taking it farther away from our target.
I would like to delegate this to the model side to suppress the warning msgs.
I agree, it's not great. Luckily it's lightweight, so we can abstract it away easily later on.
When the model cache is already set up, there is no need to call setup_caches each time a sample is passed in. This is normally fine, but torchtune is noisy (as it should be) when setup_caches is called unnecessarily.
This PR simply adds a check so cache setup only runs for the first sample.
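A rough sketch of the calling pattern this enables; the driver function, `generator`, and `prompts` are illustrative assumptions, and only the `skip_cache_setup` keyword comes from this PR:

```python
def run_samples(generator, prompts, start_pos: int = 0):
    """Illustrative driver: only the first sample triggers cache setup."""
    outputs = []
    for i, prompt in enumerate(prompts):
        outputs.append(
            generator.generate(
                prompt,
                start_pos=start_pos,
                skip_cache_setup=(i > 0),  # skip for every sample after the first
            )
        )
    return outputs
```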
Warnings that no longer appear after the fix
Generation after the fix (no warning)