-
Notifications
You must be signed in to change notification settings - Fork 27.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
remove to restriction for 4-bit model #33122
Merged
Merged
Changes from 1 commit
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
08f9c93
remove to restiction for 4-bit model
SunMarc bb12e88
Update src/transformers/modeling_utils.py
SunMarc d064b48
bitsandbytes: prevent dtype casting while allowing device movement wi…
matthewdouglas 22f6088
quality fix
matthewdouglas 462ac2c
Improve warning message for .to() and .cuda() on bnb quantized models
matthewdouglas File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2861,46 +2861,56 @@ def get_memory_footprint(self, return_buffers=True): | |
def cuda(self, *args, **kwargs): | ||
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: | ||
raise ValueError("`.cuda` is not supported for HQQ-quantized models.") | ||
# Checks if the model has been loaded in 8-bit | ||
# Checks if the model has been loaded in 4-bit or 8-bit with BNB | ||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: | ||
raise ValueError( | ||
"Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the" | ||
" model has already been set to the correct devices and casted to the correct `dtype`." | ||
) | ||
if getattr(self, "is_loaded_in_8bit", False): | ||
raise ValueError( | ||
"Calling `cuda()` is not supported for `8-bit` quantized models. Please use the model as it is, since the" | ||
" model has already been set to the correct devices and casted to the correct `dtype`." | ||
) | ||
elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"): | ||
raise ValueError( | ||
"Calling `cuda()` is not supported for `4-bit` quantized models. Please use the model as it is, since the" | ||
" model has already been set to the correct devices and casted to the correct `dtype`. " | ||
"However, if you still want to move the model, you need to install bitsandbytes >= 0.43.2 " | ||
) | ||
else: | ||
return super().cuda(*args, **kwargs) | ||
|
||
@wraps(torch.nn.Module.to) | ||
def to(self, *args, **kwargs): | ||
# For BNB/GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours. | ||
# the correct API should be to load the model with the desired dtype directly through `from_pretrained`. | ||
dtype_present_in_args = "dtype" in kwargs | ||
|
||
if not dtype_present_in_args: | ||
for arg in args: | ||
if isinstance(arg, torch.dtype): | ||
dtype_present_in_args = True | ||
break | ||
|
||
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: | ||
raise ValueError("`.to` is not supported for HQQ-quantized models.") | ||
# Checks if the model has been loaded in 8-bit | ||
# Checks if the model has been loaded in 4-bit or 8-bit with BNB | ||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: | ||
if getattr(self, "is_loaded_in_4bit", False): | ||
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.0"): | ||
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @SunMarc I've bumped this to 0.43.2 since that's when bitsandbytes-foundation/bitsandbytes#1279 was landed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice, thanks for updating the PR ! |
||
raise ValueError( | ||
"`.to` is not supported for `4-bit`. Please use the model as it is, since the" | ||
" model has already been set to the correct devices and casted to the correct `dtype`. " | ||
"However, if you still want to move the model, you need to install bitsandbytes >= 0.43.0 " | ||
"However, if you still want to move the model, you need to install bitsandbytes >= 0.43.2 " | ||
) | ||
elif dtype_present_in_args: | ||
raise ValueError( | ||
"You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the" | ||
" desired `dtype` by passing the correct `torch_dtype` argument." | ||
) | ||
else: | ||
raise ValueError( | ||
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" | ||
" model has already been set to the correct devices and casted to the correct `dtype`." | ||
) | ||
elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ: | ||
# For GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours. | ||
# the correct API should be to load the model with the desired dtype directly through `from_pretrained`. | ||
dtype_present_in_args = False | ||
|
||
if "dtype" not in kwargs: | ||
for arg in args: | ||
if isinstance(arg, torch.dtype): | ||
dtype_present_in_args = True | ||
break | ||
else: | ||
dtype_present_in_args = True | ||
|
||
if dtype_present_in_args: | ||
raise ValueError( | ||
"You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The warning isn't super clear to me in terms of what the user should or should not do; should they install the new version or should they just let the model there? I'd try to clarify this a bit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good feedback, thanks! Updated. I think in most cases the user would be using
.cuda()
without realizing it is already on a GPU so I put the currentmodel.device
in the message. That should help inform on whether they really meant to move it somewhere else and need to upgrade.