-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add generation caching in TextEnvironment and fix bugs in TextEnvironment #2556
base: main
Are you sure you want to change the base?
Add generation caching in TextEnvironment and fix bugs in TextEnvironment #2556
Conversation
I would be very grateful for a review by: |
6a87c8d
to
3f57ee9
Compare
3f57ee9
to
ede7e81
Compare
I was unable to execute the pre-commit hook, so I manually ran the linter. |
Thanks for the PR! |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Just to be sure, as I'm unfamiliar with their implementation: The trl Trainers like PPO should not try to back propagate through the generated tokens, right? |
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
The CI failing for Python 3.9 seems unrelated to this PR. |
Yes that's correct. The backprop is done on the output of a forward pass |
@qgallouedec Could you run the precommit to fix the linting issues? I haven't gotten it to work. |
I'm still working on adding some more tests and cleaning up the code a bit. |
This PR mainly affects the TextEnvironment class and adds caching in between generation calls, in order to not have to recompute all previous activations when generating the next segment. This is mainly intended for use cases where many tool calls are performed sequentially and thus the activations for the (possibly quite large) system prompt would have to be calculated at each step. For stability, caching is optional.
Bug fixes:
This issue also addresses two bugs I encountered:
I fixed the bug and also added a check at generation time to ensure, that the padded inputs also do not exceed max length.
RE testing:
I only made sure, that the tests in tests/test_environments.py were completing.
Using
make test
some tests were failing and the tests were taking a long time to run. However, the only tests, which call TextEnvironment seem to be in test_environments.py, so the rest should be unaffected as far as I know. Nevertheless, I would be grateful, if somebody else could run all the tests before merging. I suspect, that my environment may not be ideally configured. Is testing automated via a CI?