Speculative sampling #1410
Conversation
fix update past key values for target model
Check out this pull request on ReviewNB. See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB
@haim-barad please check the CI failures, as it seems some of them are quite easy to fix
Line #31. res = target_model(x_draft_target_input, attention_mask=torch.ones(x_draft_target_input.size(), dtype=torch.long), use_cache=False)

Why use_cache=False? You should use target_past_kv for the target_model inference based on x_draft_target_input:

if target_past_kv is None:
    x_draft_target_input = torch.cat((x, x_draft), dim=1)
else:
    x_draft_target_input = x_draft
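For context, a minimal sketch of what this suggested change could look like end to end. The function wrapper and the Hugging Face-style `past_key_values` API are assumptions of mine, not the notebook's actual code:

```python
import torch

def target_forward(target_model, x, x_draft, target_past_kv=None):
    """One target-model pass over the draft tokens, reusing the KV cache.

    Sketch only: `x` is the accepted prefix, `x_draft` holds the K candidate
    tokens from the draft model; a Hugging Face-style `past_key_values`
    interface is assumed.
    """
    if target_past_kv is None:
        # No cache yet: the target model must see the full prefix + draft.
        x_draft_target_input = torch.cat((x, x_draft), dim=1)
    else:
        # Prefix already cached: feed only the new draft tokens.
        x_draft_target_input = x_draft
    res = target_model(
        x_draft_target_input,
        past_key_values=target_past_kv,
        use_cache=True,  # keep the cache instead of use_cache=False
    )
    return res.logits, res.past_key_values
```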
Line #46. if np.random.random() < min(1, (q_item / p_item)): # accepted

q_item is a raw logit, while p_item is a probability after softmax. You should apply softmax to q first.
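In code, the fix could look like this (a hedged sketch; the helper name is mine, while `q_item`/`p_item` mirror the notebook's names):

```python
import numpy as np
import torch

def accept_draft_token(q_logits, p_probs, token_id):
    """Acceptance test for one draft token, with the missing softmax applied.

    `q_logits` (target model output) are raw logits, while `p_probs`
    (draft model) are already probabilities, so q must be normalized
    before forming the ratio min(1, q/p).
    """
    q_probs = torch.softmax(q_logits, dim=-1)  # the missing softmax
    q_item = q_probs[token_id].item()
    p_item = p_probs[token_id].item()
    return np.random.random() < min(1.0, q_item / p_item)  # accepted?
```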
Line #58. target_past_kv = target_new_past_kv

You should update target_past_kv based on the number of accepted tokens.
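One possible way to do that, assuming the legacy Hugging Face tuple layout for `past_key_values` (layers × (key, value), each of shape [batch, num_heads, seq_len, head_dim]); the helper and the `prefix_len`/`n_accepted` names are illustrative, not from the notebook:

```python
def trim_target_past_kv(past_kv, n_kept):
    """Cut the target model's KV cache back to the accepted prefix.

    After some draft tokens are rejected, the cache must only retain
    entries for tokens that were actually kept.
    """
    return tuple(
        (k[:, :, :n_kept, :], v[:, :, :n_kept, :]) for k, v in past_kv
    )

# Instead of `target_past_kv = target_new_past_kv`, something like:
# target_past_kv = trim_target_past_kv(target_new_past_kv, prefix_len + n_accepted)
```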
The Code_check failure is in another notebook. The spelling errors ignore the tag (as noted in your instructions). Treon (Ubuntu, for 3.8 and 3.9) fails due to a GPU timeout... but I clearly have CPU as the default device.
@haim-barad, for spell check you need to add the unknown words reported by the tool to the vocabulary. You can find more info about that in the contributing guide.
I had added the required entries, and now I see the latest push created some conflicts. I apologize for this, but can someone with permissions do a merge?
The code check is failing for 263-latent-consistency-models-image-generation, not this notebook.
Both code_check and docker_treon are failing on other notebooks, not because of this PR. Please approve. Thanks. |
There's still a small link to fix. I see this was renumbered to 266. There's an "Open in Colab" button in the readme; the link needs to be 266 instead of 265. Can one of you fix it to avoid the review process?
Works with DollyV2 on machines with at least 64GB of local memory. By default, we use GPT2 for smaller machines.
Leveraging speculative sampling with KV caching, this notebook's code generates text with both standard autoregressive sampling and speculative sampling, and compares the time needed to generate N tokens. By default, N is set to 100. Another parameter, K, defaults to 5 and controls the number of candidate tokens generated by the smaller draft model per iteration.
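As a reference for how those pieces fit together, here is a minimal sketch of the speculative sampling loop (no KV caching, batch size 1, Hugging Face-style causal LMs assumed; the notebook's actual implementation adds caching and timing):

```python
import torch

@torch.no_grad()
def speculative_sample(target_model, draft_model, x, N=100, K=5):
    """Generate roughly N new tokens from `x` using K draft candidates per step."""
    n_start = x.shape[1]
    while x.shape[1] < n_start + N:
        # 1. Draft model proposes K candidate tokens autoregressively.
        x_draft = x
        for _ in range(K):
            p = torch.softmax(draft_model(x_draft).logits[:, -1, :], dim=-1)
            x_draft = torch.cat((x_draft, torch.multinomial(p, 1)), dim=1)
        # 2. Score prefix + candidates with both models in a single pass each.
        q_all = torch.softmax(target_model(x_draft).logits, dim=-1)
        p_all = torch.softmax(draft_model(x_draft).logits, dim=-1)
        # 3. Accept candidate i with probability min(1, q_i / p_i).
        L, accepted = x.shape[1], 0
        for i in range(K):
            tok = x_draft[0, L + i]
            q_i = q_all[0, L + i - 1, tok].item()
            p_i = p_all[0, L + i - 1, tok].item()
            if torch.rand(1).item() < min(1.0, q_i / p_i):
                accepted += 1
            else:
                # Rejected: resample this position from max(0, q - p), renormalized
                # (the q == p edge case is ignored in this sketch).
                diff = torch.clamp(q_all[0, L + i - 1] - p_all[0, L + i - 1], min=0)
                tok = torch.multinomial(diff / diff.sum(), 1).view(1, 1)
                break
        else:
            # All K accepted: draw one bonus token from the target distribution.
            tok = torch.multinomial(q_all[0, -1], 1).view(1, 1)
        x = torch.cat((x_draft[:, : L + accepted], tok), dim=1)
    return x
```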
This code uses Gradio to provide an interface that labels the models used and handles the input.
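A rough sketch of how such a Gradio interface could be wired up (the callback and labels here are illustrative placeholders, not the notebook's actual code):

```python
import gradio as gr

def generate_both(prompt):
    # Placeholder callback: in the notebook, this would run both the
    # autoregressive and speculative loops and return their outputs/timings.
    return (f"[autoregressive output for: {prompt}]",
            f"[speculative output for: {prompt}]")

demo = gr.Interface(
    fn=generate_both,
    inputs=gr.Textbox(label="Prompt"),
    outputs=[
        gr.Textbox(label="Autoregressive (target model only)"),
        gr.Textbox(label="Speculative (draft + target)"),
    ],
)
demo.launch()
```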