-
Notifications
You must be signed in to change notification settings - Fork 5.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Quantization] Add quantization support for bitsandbytes
#9213
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this ! I see that you used a lot of things from transformers. Do you think it is possible to import these (or inherit) from transformers ? This will help reducing the maintenance. I'm fine also doing that since there are not too many follow-up PR after a quantizer has been added. About the HfQuantizer
class, there are a lot of methods that were created to fit transformers structure. I'm not sure we will need eveyone of these methods in diffusers. Ofc, we can still do a follow-up PR to clean up.
@SunMarc I am guilty as charged but we don’t have transformers as a hard dependency for loading models in Diffusers. Pinging @DN6 to seek his opinion. Update: Chatted with @DN6 as well. We think it's better to redefine inside |
@SunMarc I think this PR is ready for another review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this @sayakpaul !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it makes sense to have this as a separate PR to add a base class because it's hard to understand what methods are needed - we should only introduce a minimum base class and gradually add functionalities as needed
can we have a PR with a minimum example working?
Okay, so, do you want me to add everything needed for bitsandbytes integration in this PR? But do note that this won’t be very different from what we have in transformers. |
@sayakpaul
|
sometimes we can make a feature branch where a bunch of PRs can be merged into before one big honkin' PR is pushed to main at the end. and the pieces are all individually reviewed and can be tested. is this a viable approach for including quantisation? |
Okay I will update this branch. @yiyixuxu |
cc @MekkCyber for visibility |
Just a few considerations for the quantization design. I would say the initial design should start loading/inference at just the model level and then proceed to add functionality (pipeline level loading etc). The feature needs to perform the following functions
At the moment, the most common ask seems to be the ability to load models into GPU using the FP8 dtype and run inference in a supported dtype by dynamically upcasting the necessary layers. NF4 is another format that's gaining attention. So perhaps we should focus on this first. This mostly applies to the DiT models but large models like CogVideo might also benefit with this approach. Some example quantized versions of models that have been doing the rounds
To cover these initial cases, we can rely on Quanto (FP8) and BitsandBytes (NF4). Example API: from diffusers import FluxPipeline, FluxTransformer2DModel, DiffusersQuantoConfig
# Load model in FP8 with Quanto and perform compute in configured dtype.
quantization_config = DiffusersQuantoConfig(weights="float8", compute_dtype=torch.bfloat16)
FluxTransformer2DModel.from_pretrained("<either diffusers format or quanto format weights>", quantization_config=quantization_config)
pipe = FluxPipeline.from_pretrained("...", transformer=transformer) The quantization config should probably take the following arguments
I think initially we can rely on the dynamic upcasting operations performed by Quanto and BnB under the hood to start and then expand on them if needed. Some other considerations
|
This PR will be at the model-level itself. And we should not add multiple backends in a single PR. This PR aims to add Concretely, I would like to stick to the outline of the changes laid out in #9174 (along with anything related) for this PR.
I won't advocate doing all of that in a single PR because it makes things very hard to review. We would rather want to move faster with something more minimal, confirming their effectiveness.
Well, note that if the underlying LoRA wasn't trained with the base quantization precision, it might not perform as expected.
Please note that |
Sounds good to me. For this PR lets do
|
Very insightful comments, @yiyixuxu! I think I have resolved them all. LMK. |
} | ||
|
||
|
||
class DiffusersAutoQuantizationConfig: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see this is similar to transformers, but I think the DiffusersAutoQuantConfig class is probably not needed.
This is just a simple mapping to a specific quantization config object. The from_pretrained
method in the AutoQuantizer is just wrapping the AutoConfig from_pretrained
.
I think we can just move these methods/logic directly into the AutoQuantizer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is not a must-have, could do this in a follow-up PR.
Hi folks! Thanks for working on this. I was able to run the following script on this branch and generate images on my 8 gigs VRAM laptop from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
import torch
import gc
def flush():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
def bytes_to_giga_bytes(bytes):
return bytes / 1024 / 1024 / 1024
flush()
ckpt_id = "black-forest-labs/FLUX.1-dev"
ckpt_4bit_id = "sayakpaul/flux.1-dev-nf4-pkg"
prompt = "a billboard on highway with 'FLUX under 8' written on it"
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
ckpt_4bit_id,
subfolder="text_encoder_2",
)
pipeline = FluxPipeline.from_pretrained(
ckpt_id,
text_encoder_2=text_encoder_2_4bit,
transformer=None,
vae=None,
torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()
with torch.no_grad():
print("Encoding prompts.")
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
prompt=prompt, prompt_2=None, max_sequence_length=256
)
pipeline = pipeline.to("cpu")
del pipeline
flush()
transformer_4bit = FluxTransformer2DModel.from_pretrained(ckpt_4bit_id, subfolder="transformer")
pipeline = FluxPipeline.from_pretrained(
ckpt_id,
text_encoder=None,
text_encoder_2=None,
tokenizer=None,
tokenizer_2=None,
transformer=transformer_4bit,
torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()
print("Running denoising.")
height, width = 512, 768
images = pipeline(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
num_inference_steps=50,
guidance_scale=5.5,
height=height,
width=width,
output_type="pil",
).images
images[0].save("output.png") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's merge this!
I asked @DN6 to open a follow-up PR for this #9213 (comment),
PR merge contingent on #9720. |
|
||
|
||
@dataclass | ||
class BitsAndBytesConfig(QuantizationConfigMixin): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something to consider. Let's assume you want to use a quantized transformer model in your code. With this naming, you would always need to set up imports in the following way.
from transformers import BitsAndBytesConfig
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
Not a huge issue. Just giving a heads up incase you want to consider renaming the config to something like DiffusersBitsAndBytesConfig
set_module_kwargs["dtype"] = dtype | ||
|
||
# bnb params are flattened. | ||
if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this situation, aren't we skipping parameter shape checks for bnb loaded weights entirely? What happens when one attempts to load bnb weights but the flattened shape is incorrect?
Perhaps we add a check_quantized_param_shape
method to the DiffusersQuantizer base class. And in the BnBQuantizer we can check if the shape matches the rule here:
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L816
if not is_quantized or ( | ||
not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device) | ||
): | ||
if accepts_dtype: | ||
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs) | ||
else: | ||
set_module_tensor_to_device(model, param_name, device, value=param) | ||
else: | ||
set_module_tensor_to_device(model, param_name, device, value=param) | ||
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small nit. IMO this is a bit more readable
if is_quantized or hf_quantizer.check_quantized_param(
model, param, param_name, state_dict, param_device=device
):
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
else:
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
"""adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization""" | ||
return max_memory | ||
|
||
def check_quantized_param( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO check_is_quantized_param
or check_if_quantized_param
more explicitly conveys what this method does.
|
||
|
||
class BnB4BitBasicTests(Base4bitTests): | ||
def setUp(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would clear cache on setup as well.
It would be useful to rename
|
Yeah I think the documentation should reflect this. I guess this is safe to do @SunMarc? |
Yeah we should do that, would you like to update this @Ednaordinary ? We should also do it in transformers when it gets merged. |
Sure, @SunMarc. I'll make a PR when I'm able. Should I refactor the parameter name and include a deprecation notice, or just include a note in the docs? |
What does this PR do?
Come back later.
bitsandbytes
)bitsandbytes
)bitsandbytes
from_pretrained()
at theModelMixin
level and related changessave_pretrained()
Notes
QuantizationLoaderMixin
in [Quantization] bring quantization to diffusers core #9174, I realized that is not an approach we can take because loading and saving a quantized model is very much baked into the arguments ofModelMixin.save_pretrained()
andModelMixin.from_pretrained()
. It is deeply entangled.device_map
, because for a pipeline, multiple device_maps can get ugly. This will be dealt with in a follow-up PR by @SunMarc and myself.No-frills code snippets
Serialization
Serialized checkpoint: https://huggingface.co/sayakpaul/flux.1-dev-nf4-with-bnb-integration.
NF4 checkpoints of Flux transformer and T5: https://huggingface.co/sayakpaul/flux.1-dev-nf4-pkg (has Colab Notebooks, too).
Inference