Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Generation refactor #1425

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

mattdangerw
Copy link
Member

@mattdangerw mattdangerw commented Feb 6, 2024

  • Move the compiled while loop to the task base class.
  • Move as much common generative code to the task base classes.
  • Expose separate prefill() and decode() functions, which can be overridden from subclass.

This will preserve all high level usages (generate(), compile(sampler="top-k"), etc), and the way to subclass a Sampler. However it will break compat on the way you have to call a sampler--that's kinda the point of the pr. Should be an improvement overall, but definitely a friction there.

Will continue to flesh this out and add a colab demo.

@mattdangerw mattdangerw force-pushed the generation-refactor branch 5 times, most recently from 5dce72d to f96b5f2 Compare February 14, 2024 03:12
mattdangerw added a commit to mattdangerw/keras-hub that referenced this pull request Feb 16, 2024
We will update our samplers in the near future to push the backend
specific compilation details out: keras-team#1425

Also in general, we want our documentation to reflect the main usage of
our classes, which is using them with Seq2SeqLM and CausalLM classes.

So with that in mind, this updates our sampler docs to show the
practical usage of the sampling classes with our modeling classes. For
the base class, we show the main use case of overriding the
`get_next_token()` function.
mattdangerw added a commit to mattdangerw/keras-hub that referenced this pull request Feb 17, 2024
We will update our samplers in the near future to push the backend
specific compilation details out: keras-team#1425

Also in general, we want our documentation to reflect the main usage of
our classes, which is using them with Seq2SeqLM and CausalLM classes.

So with that in mind, this updates our sampler docs to show the
practical usage of the sampling classes with our modeling classes. For
the base class, we show the main use case of overriding the
`get_next_token()` function.
mattdangerw added a commit to mattdangerw/keras-hub that referenced this pull request Feb 20, 2024
We will update our samplers in the near future to push the backend
specific compilation details out: keras-team#1425

Also in general, we want our documentation to reflect the main usage of
our classes, which is using them with Seq2SeqLM and CausalLM classes.

So with that in mind, this updates our sampler docs to show the
practical usage of the sampling classes with our modeling classes. For
the base class, we show the main use case of overriding the
`get_next_token()` function.
mattdangerw added a commit that referenced this pull request Feb 20, 2024
We will update our samplers in the near future to push the backend
specific compilation details out: #1425

Also in general, we want our documentation to reflect the main usage of
our classes, which is using them with Seq2SeqLM and CausalLM classes.

So with that in mind, this updates our sampler docs to show the
practical usage of the sampling classes with our modeling classes. For
the base class, we show the main use case of overriding the
`get_next_token()` function.
abuelnasr0 pushed a commit to abuelnasr0/keras-nlp that referenced this pull request Apr 2, 2024
We will update our samplers in the near future to push the backend
specific compilation details out: keras-team#1425

Also in general, we want our documentation to reflect the main usage of
our classes, which is using them with Seq2SeqLM and CausalLM classes.

So with that in mind, this updates our sampler docs to show the
practical usage of the sampling classes with our modeling classes. For
the base class, we show the main use case of overriding the
`get_next_token()` function.
@@ -373,11 +543,11 @@ def postprocess(x):
inputs = inputs.prefetch(tf.data.AUTOTUNE)
else:
# Fast path for non-dataset, single-batch input.
inputs = [preprocess(x) for x in inputs]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is for list inputs correct?

@mattdangerw mattdangerw changed the title Generation refactor [DRAFT] Generation refactor Aug 15, 2024
Tusespifump1o added a commit to Tusespifump1o/keras-nlp that referenced this pull request Aug 26, 2024
We will update our samplers in the near future to push the backend
specific compilation details out: keras-team/keras-hub#1425

Also in general, we want our documentation to reflect the main usage of
our classes, which is using them with Seq2SeqLM and CausalLM classes.

So with that in mind, this updates our sampler docs to show the
practical usage of the sampling classes with our modeling classes. For
the base class, we show the main use case of overriding the
`get_next_token()` function.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants