From 31725aa2376009a2ce4e056ef217172ebba6dde8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 3 Sep 2024 09:57:08 +0530 Subject: [PATCH] harmonize changes with https://github.com/huggingface/transformers/pull/33122 --- src/diffusers/models/modeling_utils.py | 54 ++++++--- src/diffusers/pipelines/pipeline_utils.py | 21 ++-- .../quantizers/bitsandbytes/bnb_quantizer.py | 20 +++- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 14 +++ src/diffusers/utils/testing_utils.py | 26 ++++ tests/quantization/bnb/test_4bit.py | 112 +++++++++++++----- 7 files changed, 193 insertions(+), 55 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 23546be09c00..1cb6baae7256 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -47,6 +47,7 @@ deprecate, is_accelerate_available, is_bitsandbytes_available, + is_bitsandbytes_version, is_torch_version, logging, ) @@ -976,27 +977,52 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return model - # Taken from `transformers`. + # Adapted from `transformers`. @wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): - # Checks if the model has been loaded in 8-bit + # Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: - raise ValueError( - "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the" - " model has already been set to the correct devices and cast to the correct `dtype`." - ) - else: - return super().cuda(*args, **kwargs) + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "Calling `cuda()` is not supported for `8-bit` quantized models. " + " Please use the model as it is, since the model has already been set to the correct devices." + ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + return super().cuda(*args, **kwargs) - # Taken from `transformers`. + # Adapted from `transformers`. @wraps(torch.nn.Module.to) def to(self, *args, **kwargs): - # Checks if the model has been loaded in 8-bit + dtype_present_in_args = "dtype" in kwargs + + if not dtype_present_in_args: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + + # Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: - raise ValueError( - "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the" - " model has already been set to the correct devices and cast to the correct `dtype`." - ) + if dtype_present_in_args: + raise ValueError( + "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the" + " desired `dtype` by passing the correct `torch_dtype` argument." + ) + + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." 
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
         return super().to(*args, **kwargs)
 
     # Taken from `transformers`.
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index db7953feb569..8537a6a57cd2 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -56,6 +56,7 @@
     is_accelerate_version,
     is_torch_npu_available,
     is_torch_version,
+    is_transformers_version,
     logging,
     numpy_to_pil,
 )
@@ -428,19 +429,22 @@ def module_is_offloaded(module):
         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
         for module in modules:
             _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
-            bit_map = {is_loaded_in_4bit_bnb: "4bit", is_loaded_in_8bit_bnb: "8bit"}
+            precision = "4bit" if is_loaded_in_4bit_bnb else "8bit"
 
             if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
-                precision = bit_map[True]
                 logger.warning(
                     f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision. In most cases, it is recommended to not change the precision."
                 )
 
-            if (is_loaded_in_4bit_bnb or is_loaded_in_4bit_bnb) and device is not None:
-                precision = bit_map[True]
+            if is_loaded_in_8bit_bnb and device is not None:
                 logger.warning(
                     f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}. In most cases, it is recommended to not change the device."
                 )
+
+            # This can happen for `transformers` models: their 4-bit CPU placement was only added in
+            # https://github.com/huggingface/transformers/pull/33122, so we guard it by the installed version.
+            if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
+                module.to(device=device)
             else:
                 module.to(device, dtype)
 
@@ -449,6 +454,7 @@ def module_is_offloaded(module):
             and str(device) in ["cpu"]
             and not silence_dtype_warnings
             and not is_offloaded
+            and not is_loaded_in_4bit_bnb
         ):
             logger.warning(
                 "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
@@ -1023,16 +1029,13 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
             if model is not None and isinstance(model, torch.nn.Module):
                 _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(model)
-                bit_map = {is_loaded_in_4bit_bnb: "4bit", is_loaded_in_8bit_bnb: "8bit"}
-
             if not isinstance(model, torch.nn.Module):
                 continue
 
             # This is because the model would already be placed on a CUDA device.
-            if is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb:
-                precision = bit_map[True]
+            if is_loaded_in_8bit_bnb:
                 logger.info(
-                    f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` {precision}."
+                    f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
                 )
                 continue
 
diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
index 5854c0f84a21..a78e407a02e0 100644
--- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
+++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
@@ -32,6 +32,7 @@
     is_accelerate_available,
     is_accelerate_version,
     is_bitsandbytes_available,
+    is_bitsandbytes_version,
     is_torch_available,
     logging,
 )
@@ -72,7 +73,7 @@ def validate_environment(self, *args, **kwargs):
             raise ImportError(
                 "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
             )
-        if not is_bitsandbytes_available():
+        if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"):
             raise ImportError(
                 "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
             )
@@ -319,9 +320,18 @@ def is_trainable(self) -> bool:
     def _dequantize(self, model):
         from .utils import dequantize_and_replace
 
+        is_model_on_cpu = model.device.type == "cpu"
+        if is_model_on_cpu:
+            logger.info(
+                "Model was found to be on CPU (this can happen after `enable_model_cpu_offload()`). Moving it to GPU for dequantization and back to CPU afterwards to preserve the previous device."
+            )
+            model.to(torch.cuda.current_device())
+
         model = dequantize_and_replace(
             model, self.modules_to_not_convert, quantization_config=self.quantization_config
         )
+        if is_model_on_cpu:
+            model.to("cpu")
         return model
 
@@ -348,17 +358,17 @@ def __init__(self, quantization_config, **kwargs):
         if self.quantization_config.llm_int8_skip_modules is not None:
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
 
-    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4bit->8bit
+    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4-bit->8-bit
     def validate_environment(self, *args, **kwargs):
         if not torch.cuda.is_available():
             raise RuntimeError("No GPU found. A GPU is needed for quantization.")
         if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"):
             raise ImportError(
-                "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
+                "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
             )
-        if not is_bitsandbytes_available():
+        if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"):
             raise ImportError(
-                "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
+                "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
             )
 
         if kwargs.get("from_flax", False):
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 8bdbb3d62767..c8f64adf3e8a 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -62,6 +62,7 @@
     is_accelerate_available,
     is_accelerate_version,
     is_bitsandbytes_available,
+    is_bitsandbytes_version,
     is_bs4_available,
     is_flax_available,
     is_ftfy_available,
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 34cc5fcc8605..8b81b19b8a52 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -740,6 +740,20 @@ def is_peft_version(operation: str, version: str):
     return compare_versions(parse(_peft_version), operation, version)
 
 
+def is_bitsandbytes_version(operation: str, version: str):
+    """
+    Args:
+    Compares the current bitsandbytes version to a given reference with an operation.
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _bitsandbytes_version:
+        return False
+    return compare_versions(parse(_bitsandbytes_version), operation, version)
+
+
 def is_k_diffusion_version(operation: str, version: str):
     """
     Args:
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 76f1ba055f4d..1eb35a9c392e 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -1,5 +1,6 @@
 import functools
 import importlib
+import importlib.metadata
 import inspect
 import io
 import logging
@@ -404,6 +405,31 @@ def decorator(test_case):
     return decorator
 
 
+def require_bitsandbytes_version_greater(bnb_version):
+    def decorator(test_case):
+        correct_bnb_version = is_bitsandbytes_available() and version.parse(
+            version.parse(importlib.metadata.version("bitsandbytes")).base_version
+        ) > version.parse(bnb_version)
+        return unittest.skipUnless(
+            correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}."
+ )(test_case) + + return decorator + + +def require_transformers_version_greater(transformers_version): + def decorator(test_case): + correct_transformers_version = is_transformers_available() and version.parse( + version.parse(importlib.metadata.version("transformers")).base_version + ) > version.parse(transformers_version) + return unittest.skipUnless( + correct_transformers_version, + f"test requires transformers backend with the version greater than {transformers_version}", + )(test_case) + + return decorator + + def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 83ff34c9db78..6a6e374ffebe 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -18,17 +18,17 @@ import numpy as np -from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel -from diffusers.utils import logging +from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel from diffusers.utils.testing_utils import ( - CaptureLogger, is_bitsandbytes_available, is_torch_available, + is_transformers_available, load_pt, require_accelerate, - require_bitsandbytes, + require_bitsandbytes_version_greater, require_torch, require_torch_gpu, + require_transformers_version_greater, slow, torch_device, ) @@ -41,6 +41,9 @@ def get_some_linear_layer(model): return NotImplementedError("Don't know what layer to retrieve here.") +if is_transformers_available(): + from transformers import T5EncoderModel + if is_torch_available(): import torch @@ -49,7 +52,7 @@ def get_some_linear_layer(model): import bitsandbytes as bnb -@require_bitsandbytes +@require_bitsandbytes_version_greater("0.43.2") @require_accelerate @require_torch @require_torch_gpu @@ -167,33 +170,46 @@ def test_linear_are_4bit(self): # 4-bit parameters are packed in uint8 variables self.assertTrue(module.weight.dtype == torch.uint8) + def test_device_assignment(self): + mem_before = self.model_4bit.get_memory_footprint() + + # Move to CPU + self.model_4bit.to("cpu") + self.assertEqual(self.model_4bit.device.type, "cpu") + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + + # Move back to CUDA device + for device in [0, "cuda", "cuda:0", "call()"]: + if device == "call()": + self.model_4bit.cuda(0) + else: + self.model_4bit.to(device) + self.assertEqual(self.model_4bit.device, torch.device(0)) + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + self.model_4bit.to("cpu") + def test_device_and_dtype_assignment(self): r""" - Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. + Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error. Checks also if other models are casted correctly. 
""" with self.assertRaises(ValueError): - # Tries with `str` - self.model_4bit.to("cpu") - - with self.assertRaises(ValueError): - # Tries with a `dtype`` + # Tries with a `dtype` self.model_4bit.to(torch.float16) with self.assertRaises(ValueError): - # Tries with a `device` - self.model_4bit.to(torch.device("cuda:0")) + # Tries with a `device` and `dtype` + self.model_4bit.to(device="cuda:0", dtype=torch.float16) with self.assertRaises(ValueError): - # Tries with a `device` + # Tries with a cast self.model_4bit.float() with self.assertRaises(ValueError): - # Tries with a `device` + # Tries with a cast self.model_4bit.half() # Test if we did not break anything - self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) input_dict_for_transformer = self.get_dummy_inputs() model_inputs = { @@ -214,6 +230,9 @@ def test_device_and_dtype_assignment(self): # Check this does not throw an error _ = self.model_fp16.float() + # Check that this does not throw an error + _ = self.model_fp16.cuda() + def test_bnb_4bit_wrong_config(self): r""" Test whether creating a bnb config with unsupported values leads to errors. @@ -221,18 +240,8 @@ def test_bnb_4bit_wrong_config(self): with self.assertRaises(ValueError): _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add") - def test_model_cpu_offload_raises_warning(self): - pipeline_4bit = DiffusionPipeline.from_pretrained( - self.model_name, transformer=self.model_4bit, torch_dtype=torch.float16 - ) - logger = logging.get_logger("diffusers.pipelines.pipeline_utils") - logger.setLevel(30) - with CaptureLogger(logger) as cap_logger: - pipeline_4bit.enable_model_cpu_offload() - - self.assertTrue("The module 'SD3Transformer2DModel' has been loaded in `bitsandbytes` 4bit" in cap_logger.out) - +@require_transformers_version_greater("4.44.0") class SlowBnb4BitTests(Base4bitTests): def setUp(self) -> None: nf4_config = BitsAndBytesConfig( @@ -281,6 +290,55 @@ def test_generate_quality_dequantize(self): out_slice = output[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.1216, 0.1387, 0.1584, 0.1152, 0.1318, 0.1282, 0.1062, 0.1226, 0.1228]) + self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + + # Since we offloaded the `pipeline_4bit.transformer` to CPU (result of `enable_model_cpu_offload()), check + # the following. + self.assertTrue(self.pipeline_4bit.transformer.device.type == "cpu") + # calling it again shouldn't be a problem + _ = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=2, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + +@require_transformers_version_greater("4.44.0") +class SlowBnb4BitFluxTests(Base4bitTests): + def setUp(self) -> None: + # TODO: Copy sayakpaul/flux.1-dev-nf4-pkg to testing repo. + model_id = "sayakpaul/flux.1-dev-nf4-pkg" + t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer") + self.pipeline_4bit = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + text_encoder_2=t5_4bit, + transformer=transformer_4bit, + torch_dtype=torch.float16, + ) + self.pipeline_4bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_4bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + # keep the resolution and max tokens to a lower number for faster execution. 
+ output = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + height=256, + width=256, + max_sequence_length=64, + output_type="np", + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0583, 0.0586, 0.0632, 0.0815, 0.0813, 0.0947, 0.1040, 0.1145, 0.1265]) self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4))
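Usage sketch (illustrative, not part of the patch above): the modeling_utils.py changes let a bitsandbytes 4-bit model be moved between devices once bitsandbytes >= 0.43.2 is installed, while casting it to a new dtype is still rejected. The SD3 checkpoint below is only an example; any diffusers model loaded with a `quantization_config` behaves the same way.

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# NF4 4-bit quantization through the bitsandbytes backend.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=nf4_config,
)

# With bitsandbytes >= 0.43.2 (and this patch), 4-bit models can be moved across devices.
model_4bit.to("cpu")
model_4bit.to("cuda:0")
model_4bit.cuda(0)

# Casting a bitsandbytes-quantized model to a new dtype still raises a ValueError.
try:
    model_4bit.to(torch.float16)
except ValueError as err:
    print(err)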
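Along the same lines, a condensed sketch of what SlowBnb4BitFluxTests exercises: pre-quantized NF4 checkpoints for the T5 text encoder and the Flux transformer, run under `enable_model_cpu_offload()`, which relies on the 4-bit CPU placement that transformers > 4.44.0 provides. The prompt, step count, and output path are illustrative.

import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel

# Pre-quantized NF4 checkpoints, as used in SlowBnb4BitFluxTests.
model_id = "sayakpaul/flux.1-dev-nf4-pkg"
text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder_2=text_encoder_2,
    transformer=transformer,
    torch_dtype=torch.float16,
)
# Needs transformers > 4.44.0 so the offload hooks can place the 4-bit T5 encoder on CPU.
pipeline.enable_model_cpu_offload()

image = pipeline(
    "a photo of a dog wearing sunglasses",  # illustrative prompt
    num_inference_steps=4,
    height=256,
    width=256,
    max_sequence_length=64,
    generator=torch.manual_seed(0),
).images[0]
image.save("flux_nf4.png")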
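Finally, a small sketch of how the new version helpers fit together; the guard in modeling_utils.py and the test gating in tests/quantization/bnb/test_4bit.py both build on them. The helper function below is hypothetical and only mirrors the guard used in `ModelMixin.to()`/`cuda()`.

import unittest

from diffusers.utils import is_bitsandbytes_version
from diffusers.utils.testing_utils import (
    require_bitsandbytes_version_greater,
    require_transformers_version_greater,
)


def supports_4bit_device_movement() -> bool:
    # 4-bit bnb models can only be moved across devices with bitsandbytes >= 0.43.2.
    return is_bitsandbytes_version(">=", "0.43.2")


@require_bitsandbytes_version_greater("0.43.2")
@require_transformers_version_greater("4.44.0")
class MyBnb4BitTests(unittest.TestCase):
    def test_placeholder(self):
        self.assertTrue(supports_4bit_device_movement())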