Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve HF device handling #359

Merged
merged 10 commits into from
Nov 13, 2023
Merged

Improve HF device handling #359

merged 10 commits into from
Nov 13, 2023

Conversation

rmitsch
Copy link
Collaborator

@rmitsch rmitsch commented Nov 9, 2023

Description

Improve HF device handling:

  • Consider device/device_map conflicts.
  • Always move inputs to model device for models using AutoModelForCausalLM.

Context/motivation: #324 (reply in thread), #324. This PR should help with both.

Corresponding documentation PR

-

Types of change

Checklist

  • I confirm that I have the right to submit this contribution under the project's MIT license.
  • I ran all tests in tests and usage_examples/tests, and all new and existing tests passed. This includes
    • all external tests (i. e. pytest ran with --external)
    • all tests requiring a GPU (i. e. pytest ran with --gpu)
  • My changes don't require a change to the documentation, or if they do, I've added all required information.

@rmitsch rmitsch added bug Something isn't working feat/model Feature: models labels Nov 9, 2023
@rmitsch rmitsch marked this pull request as draft November 9, 2023 16:08
@rmitsch rmitsch added the Test GPU Run GPU tests label Nov 10, 2023
@rmitsch rmitsch marked this pull request as ready for review November 10, 2023 09:28
@adrianeboyd
Copy link
Contributor

I do think it's a good idea to test with accelerate, at least in the future. It looks like that torch bug will be fixed in torch v2.1.1.

@rmitsch
Copy link
Collaborator Author

rmitsch commented Nov 13, 2023

I do think it's a good idea to test with accelerate, at least in the future. It looks like that torch bug will be fixed in torch v2.1.1.

I'll do that in a follow-up PR as I want this to be included in the upcoming v0.6.3 release (release today or tomorrow).

@svlandeg
Copy link
Member

There's no documentation update necessary after this?

@rmitsch
Copy link
Collaborator Author

rmitsch commented Nov 13, 2023

There's no documentation update necessary after this?

I wouldn't think so - this is part bugfix and part the ability to process torch_dtype in the config, which should have been there anyway, as we claim all HF args can be set and are passed on to the HF model. I'd say describing that in the release notes should be sufficient, but we can make this explicit in the docs as well.

@rmitsch rmitsch merged commit 7687d44 into main Nov 13, 2023
11 checks passed
@svlandeg svlandeg deleted the fix/hf-device-handling branch November 13, 2023 16:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working feat/model Feature: models Test GPU Run GPU tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants