Skip to content

Commit

Permalink
quant: add tensor parallel support for bitsandbytes (#1052)
Browse files Browse the repository at this point in the history
* quant: add TP support for bitsandbytes

* fix inplace_all_reduce op registrations
  • Loading branch information
AlpinDale authored Dec 27, 2024
1 parent a985143 commit b3f9ab3
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 14 deletions.
6 changes: 0 additions & 6 deletions aphrodite/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,12 +507,6 @@ def verify_with_parallel_config(
"Pipeline parallelism is only supported for the following "
f" architectures: {_PP_SUPPORTED_MODELS}.")

if self.quantization == "bitsandbytes" and (
parallel_config.tensor_parallel_size > 1
or parallel_config.pipeline_parallel_size > 1):
raise ValueError(
"BitsAndBytes quantization with TP/PP is not supported yet.")

if self.quantization == "bitsandbytes" and self.enforce_eager is False:
logger.warning("CUDA graph is not supported on BitAndBytes yet, "
"fallback to the eager mode.")
Expand Down
5 changes: 3 additions & 2 deletions aphrodite/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def _register_group(group: "GroupCoordinator") -> None:
# looks like Python 3.8 does not understand `ReferenceType`
_groups[group.unique_name] = weakref.ref(group) # type: ignore

@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
@torch.library.custom_op("aphrodite::inplace_all_reduce",
mutates_args=["tensor"])
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
Expand All @@ -101,7 +102,7 @@ def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
@inplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> None:
return
@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
@torch.library.custom_op("aphrodite::outplace_all_reduce", mutates_args=[])
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
Expand Down
19 changes: 14 additions & 5 deletions aphrodite/modeling/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,11 @@ def weight_loader(self,
loaded_weight.shape[output_dim], tp_rank, tp_size)
else:
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
Expand Down Expand Up @@ -957,8 +960,11 @@ def weight_loader(self,
else:
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
Expand Down Expand Up @@ -1066,6 +1072,7 @@ def __init__(self,
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
Expand All @@ -1080,7 +1087,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

param_data = param.data
if input_dim is not None:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if input_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[input_dim]
if self.quant_config is None:
start_idx = get_current_tp_rank_partition_offset(
Expand Down
37 changes: 36 additions & 1 deletion aphrodite/modeling/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
LoRAConfig, ModelConfig, MultiModalConfig,
ParallelConfig, SchedulerConfig)
from aphrodite.common.utils import is_pin_memory_available, tensor_progress_bar
from aphrodite.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from aphrodite.modeling.model_loader.tensorizer import (
TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
serialize_aphrodite_model, tensorizer_weights_iterator)
Expand Down Expand Up @@ -665,6 +667,8 @@ def save_model(
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""

# TODO: these module names are for Llama only,
# change so that it works with other models as well
default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj"
Expand Down Expand Up @@ -881,13 +885,38 @@ def _parse_quant_state(param_name: str,
def _unquantized_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
from bitsandbytes.functional import quantize_4bit
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name
for target_module in self.target_modules):
weight_name = weight_name.replace(".weight", ".qweight")
# weight partitions of different modules occur at
# different dimensions
# TODO: these module names are for Llama only,
# change so that it works with other models as well
if 'down_proj' in weight_name or 'o_proj' in weight_name:
total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[...,
start_index:end_index]
else:
total_size = weight_tensor.size(0)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[start_index:end_index,
...]
# bitsandbytes requires data in GPU
loaded_weight = weight_tensor.cuda().data
if weight_sub_tensor.is_cuda:
loaded_weight = weight_sub_tensor
else:
loaded_weight = weight_sub_tensor.cuda()
# remove the following after the issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
if loaded_weight.is_contiguous() is False:
loaded_weight = loaded_weight.contiguous()
with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit(
loaded_weight,
Expand Down Expand Up @@ -924,6 +953,12 @@ def _load_weights(self, model_config: ModelConfig,
raise ValueError(
f"BitsAndBytes loader does not support {quant_method} "
"quantization")
# The quant_states in pre_quantized models cannot work with a split
# weight tensor. So TP does not work with pre_quantized bnb models.
if pre_quant and get_tensor_model_parallel_world_size() > 1:
raise ValueError(
"Prequant BitsAndBytes models with TP is not supported."
"Please try with PP.")
load_8bit = False
if pre_quant:
load_8bit = quant_config.get('load_in_8bit', False)
Expand Down

0 comments on commit b3f9ab3

Please sign in to comment.