harmonize changes with huggingface/transformers#33122
sayakpaul committed Sep 3, 2024
1 parent abc8607 commit 31725aa
Showing 7 changed files with 193 additions and 55 deletions.
54 changes: 40 additions & 14 deletions src/diffusers/models/modeling_utils.py
@@ -47,6 +47,7 @@
deprecate,
is_accelerate_available,
is_bitsandbytes_available,
is_bitsandbytes_version,
is_torch_version,
logging,
)
@@ -976,27 +977,52 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

return model

# Taken from `transformers`.
# Adapted from `transformers`.
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
# Checks if the model has been loaded in 8-bit
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
raise ValueError(
"Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
" model has already been set to the correct devices and cast to the correct `dtype`."
)
else:
return super().cuda(*args, **kwargs)
if getattr(self, "is_loaded_in_8bit", False):
raise ValueError(
"Calling `cuda()` is not supported for `8-bit` quantized models. "
" Please use the model as it is, since the model has already been set to the correct devices."
)
elif is_bitsandbytes_version("<", "0.43.2"):
raise ValueError(
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
)
return super().cuda(*args, **kwargs)

# Taken from `transformers`.
# Adapted from `transformers`.
@wraps(torch.nn.Module.to)
def to(self, *args, **kwargs):
# Checks if the model has been loaded in 8-bit
dtype_present_in_args = "dtype" in kwargs

if not dtype_present_in_args:
for arg in args:
if isinstance(arg, torch.dtype):
dtype_present_in_args = True
break

# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
raise ValueError(
"`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
" model has already been set to the correct devices and cast to the correct `dtype`."
)
if dtype_present_in_args:
raise ValueError(
"You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
" desired `dtype` by passing the correct `torch_dtype` argument."
)

if getattr(self, "is_loaded_in_8bit", False):
raise ValueError(
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
" model has already been set to the correct devices and casted to the correct `dtype`."
)
elif is_bitsandbytes_version("<", "0.43.2"):
raise ValueError(
"Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
)
return super().to(*args, **kwargs)

# Taken from `transformers`.
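Note: the net effect of the `modeling_utils.py` change is that dtype casts are still rejected for any bitsandbytes model, `.cuda()`/`.to()` stay blocked for 8-bit models, and 4-bit models can now be moved between devices when bitsandbytes >= 0.43.2 is installed. A minimal sketch of how these paths are hit, assuming the bnb quantization config API from this PR branch (the checkpoint id and config values are illustrative, not part of this commit):

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Load a 4-bit quantized model (placeholder checkpoint; any ModelMixin subclass behaves the same way).
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
)

# model.to(torch.bfloat16)  -> would raise: casting a bnb model to a new dtype is unsupported
model.to("cpu")  # allowed for 4-bit models when bitsandbytes >= 0.43.2, otherwise raises ValueError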
21 changes: 12 additions & 9 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -56,6 +56,7 @@
is_accelerate_version,
is_torch_npu_available,
is_torch_version,
is_transformers_version,
logging,
numpy_to_pil,
)
@@ -428,19 +429,23 @@ def module_is_offloaded(module):
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
for module in modules:
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
bit_map = {is_loaded_in_4bit_bnb: "4bit", is_loaded_in_8bit_bnb: "8bit"}
precision = None
precision = "4bit" if is_loaded_in_4bit_bnb else "8bit"

if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
precision = bit_map[True]
logger.warning(
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision. In most cases, it is recommended to not change the precision."
)

if (is_loaded_in_4bit_bnb or is_loaded_in_4bit_bnb) and device is not None:
precision = bit_map[True]
if is_loaded_in_8bit_bnb and device is not None:
logger.warning(
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}. In most cases, it is recommended to not change the device."
)

# This can happen for `transformer` models. CPU placement was added in
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
module.to(device=device)
else:
module.to(device, dtype)

@@ -449,6 +454,7 @@ def module_is_offloaded(module):
and str(device) in ["cpu"]
and not silence_dtype_warnings
and not is_offloaded
and not is_loaded_in_4bit_bnb
):
logger.warning(
"Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
@@ -1023,16 +1029,13 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
if model is not None and isinstance(model, torch.nn.Module):
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(model)

bit_map = {is_loaded_in_4bit_bnb: "4bit", is_loaded_in_8bit_bnb: "8bit"}

if not isinstance(model, torch.nn.Module):
continue

# This is because the model would already be placed on a CUDA device.
if is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb:
precision = bit_map[True]
if is_loaded_in_8bit_bnb: # is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb:
logger.info(
f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` {precision}."
f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
)
continue

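Note: at the pipeline level, the `.to()` guard above now warns and skips the dtype cast for any bnb module, skips device moves for 8-bit modules, and only forwards the device to a 4-bit module when transformers > 4.44.0 is installed (the version that added CPU placement in huggingface/transformers#33122). A rough usage sketch, with the checkpoint name purely illustrative:

from diffusers import DiffusionPipeline

# Pipeline whose transformer component was loaded in 4-bit with bitsandbytes (illustrative id).
pipe = DiffusionPipeline.from_pretrained("user/sd3-medium-4bit-transformer")

# The 4-bit transformer is moved with `device=` only (no dtype cast) when transformers > 4.44.0;
# with older transformers the plain `.to(device, dtype)` path runs instead.
pipe.to("cuda")

# Moving back to CPU no longer triggers the float16-on-CPU warning for 4-bit components.
pipe.to("cpu")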
20 changes: 15 additions & 5 deletions src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
@@ -32,6 +32,7 @@
is_accelerate_available,
is_accelerate_version,
is_bitsandbytes_available,
is_bitsandbytes_version,
is_torch_available,
logging,
)
@@ -72,7 +73,7 @@ def validate_environment(self, *args, **kwargs):
raise ImportError(
"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
)
if not is_bitsandbytes_available():
if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"):
raise ImportError(
"Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
)
@@ -319,9 +320,18 @@ def is_trainable(self) -> bool:
def _dequantize(self, model):
from .utils import dequantize_and_replace

is_model_on_cpu = model.device.type == "cpu"
if is_model_on_cpu:
logger.info(
"Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
)
model.to(torch.cuda.current_device())

model = dequantize_and_replace(
model, self.modules_to_not_convert, quantization_config=self.quantization_config
)
if is_model_on_cpu:
model.to("cpu")
return model


@@ -348,17 +358,17 @@ def __init__(self, quantization_config, **kwargs):
if self.quantization_config.llm_int8_skip_modules is not None:
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4bit->8bit
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4-bit->8-bit
def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"):
raise ImportError(
"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
)
if not is_bitsandbytes_available():
if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"):
raise ImportError(
"Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
"Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
)

if kwargs.get("from_flax", False):
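Note: the `_dequantize` change handles models that have ended up on CPU (typically via `enable_model_cpu_offload()`): they are moved to the current CUDA device before `dequantize_and_replace`, then moved back. A hedged, user-level sketch; the `pipe` object and the public `dequantize()` entry point are assumptions for illustration, not part of this diff:

pipe.enable_model_cpu_offload()
# ... run inference; the offload hooks leave the quantized transformer on CPU between calls ...

# Dequantization now works even though the module sits on CPU: internally it is moved to
# torch.cuda.current_device(), dequantize_and_replace() runs there, then it is moved back to CPU.
pipe.transformer.dequantize()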
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -62,6 +62,7 @@
is_accelerate_available,
is_accelerate_version,
is_bitsandbytes_available,
is_bitsandbytes_version,
is_bs4_available,
is_flax_available,
is_ftfy_available,
14 changes: 14 additions & 0 deletions src/diffusers/utils/import_utils.py
@@ -740,6 +740,20 @@ def is_peft_version(operation: str, version: str):
return compare_versions(parse(_peft_version), operation, version)


def is_bitsandbytes_version(operation: str, version: str):
"""
Args:
Compares the current bitsandbytes version to a given reference with an operation.
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _bitsandbytes_version:
return False
return compare_versions(parse(_bitsandbytes_version), operation, version)


def is_k_diffusion_version(operation: str, version: str):
"""
Args:
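Note: `is_bitsandbytes_version` mirrors the other `is_*_version` helpers and returns `False` when bitsandbytes is not installed, so callers can use it directly in guards. A small usage sketch (the error message is illustrative):

from diffusers.utils import is_bitsandbytes_version

if is_bitsandbytes_version("<", "0.43.2"):
    raise ValueError("Moving 4-bit quantized models requires bitsandbytes >= 0.43.2.")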
26 changes: 26 additions & 0 deletions src/diffusers/utils/testing_utils.py
@@ -1,5 +1,6 @@
import functools
import importlib
import importlib.metadata
import inspect
import io
import logging
@@ -404,6 +405,31 @@ def decorator(test_case):
return decorator


def require_bitsandbytes_version_greater(bnb_version):
def decorator(test_case):
correct_bnb_version = is_bitsandbytes_available() and version.parse(
version.parse(importlib.metadata.version("bitsandbytes")).base_version
) > version.parse(bnb_version)
return unittest.skipUnless(
correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}."
)(test_case)

return decorator


def require_transformers_version_greater(transformers_version):
def decorator(test_case):
correct_transformers_version = is_transformers_available() and version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse(transformers_version)
return unittest.skipUnless(
correct_transformers_version,
f"test requires transformers backend with the version greater than {transformers_version}",
)(test_case)

return decorator


def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
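Note: the two new decorators skip a test unless the installed bitsandbytes / transformers version is strictly greater than the given one. A hedged sketch of how they would be combined in a test module (the test class and body are illustrative, not from this commit):

import unittest

from diffusers.utils.testing_utils import (
    require_bitsandbytes_version_greater,
    require_transformers_version_greater,
)


@require_bitsandbytes_version_greater("0.43.2")
@require_transformers_version_greater("4.44.0")
class Bnb4BitCPUPlacementTests(unittest.TestCase):
    def test_moving_4bit_model_to_cpu(self):
        ...  # load a 4-bit model and assert that .to("cpu") succeeds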