diff --git a/deepspeed/linear/__init__.py b/deepspeed/linear/__init__.py
new file mode 100644
index 000000000000..a27f1c3eaee7
--- /dev/null
+++ b/deepspeed/linear/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .optimized_linear import OptimizedLinear
+from .config import LoRAConfig, QuantizationConfig
diff --git a/deepspeed/linear/config.py b/deepspeed/linear/config.py
new file mode 100644
index 000000000000..ae9050a3c92b
--- /dev/null
+++ b/deepspeed/linear/config.py
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from dataclasses import dataclass
+
+
+@dataclass
+class LoRAConfig:
+    """
+    Configuration settings for LoRAOptimizedLinear.
+
+    Attributes:
+        lora_r (int): LoRA attention dimension, also known as the rank. Default is 64.
+        lora_alpha (float): LoRA scaling factor, default is 16.
+        base_weight_sharding (int): The degree to which the base weights are sharded;
+            should typically be set to the data-parallel world size to maximize the memory
+            reduction benefits. Defaults to 1, which means this feature is disabled.
+    """
+    lora_r: int = 64
+    lora_alpha: float = 16.
+    base_weight_sharding: int = 1
+
+
+@dataclass
+class QuantizationConfig:
+    """
+    Configuration settings for quantization for LoRAOptimizedLinear, QuantizedLinear,
+    and QuantizedParameter.
+
+    Attributes:
+        q_bits (int): The number of bits used for quantization. Default is 8.
+        mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3.
+        group_size (int): The size of the group used for quantization. Default is 512.
+    """
+    q_bits: int = 8
+    mantissa_bits: int = 3
+    group_size: int = 512
diff --git a/deepspeed/linear/optimized_linear.py b/deepspeed/linear/optimized_linear.py
new file mode 100644
index 000000000000..138bd493ffc7
--- /dev/null
+++ b/deepspeed/linear/optimized_linear.py
@@ -0,0 +1,150 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+from dataclasses import is_dataclass
+from deepspeed.accelerator import get_accelerator
+import deepspeed.comm as dist
+
+from .config import LoRAConfig, QuantizationConfig
+from .quantization import QuantizedParameter, QuantizedLinear
+
+
+class OptimizedLinear(nn.Module):
+    """
+    Optimized version of nn.Linear that adds features such as:
+      * LoRA w. base weight sharding
+      * FP [6,8,12] quantization
+
+    Arguments:
+        input_dim: Required: size of each input sample
+        output_dim: Required: size of each output sample
+        bias: Optional: If set to False, the layer will not learn an additive bias. Default: False
+        lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree
+        quantization_config: Optional: QuantizationConfig defining quantization features
+        dtype: Optional: parameter dtype, only supports bfloat16 currently
+
+    Returns:
+        Returns a new nn.Module depending on the input config. Either native
+        torch.nn.Linear, QuantizedLinear, or the full-featured DSOptimizedLinear.
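+
+    Example (illustrative; mirrors the configuration exercised in
+    tests/unit/linear/test_linear.py and assumes an accelerator with bfloat16 support):
+
+        from deepspeed.linear import OptimizedLinear, LoRAConfig, QuantizationConfig
+
+        layer = OptimizedLinear(input_dim=64,
+                                output_dim=64,
+                                lora_config=LoRAConfig(lora_r=16, lora_alpha=16),
+                                quantization_config=QuantizationConfig(q_bits=6),
+                                dtype=torch.bfloat16)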
+ """ + + def __new__(self, + input_dim: int, + output_dim: int, + bias: bool = False, + lora_config: LoRAConfig = None, + quantization_config: QuantizationConfig = None, + dtype=torch.bfloat16): + + if quantization_config is not None and not is_dataclass(quantization_config): + raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}") + if lora_config is not None and not is_dataclass(lora_config): + raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}") + if lora_config is None and quantization_config is None: + # Everything disabled, fall back to normal nn.Linear + self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype) + + elif lora_config: + # lora enabled, quantization may or may not be + self = LoRAOptimizedLinear(input_dim=input_dim, + output_dim=output_dim, + bias=bias, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=dtype) + + elif quantization_config: + # only quantization enabled, no lora + self = QuantizedLinear(input_dim=input_dim, + output_dim=output_dim, + bias=bias, + quantization_config=quantization_config, + dtype=dtype) + return self + + +class LoRAOptimizedLinear(nn.Module): + + def __init__(self, + input_dim: int, + output_dim: int, + bias: bool = False, + lora_config: LoRAConfig = None, + quantization_config: QuantizationConfig = None, + device=None, + dtype=torch.bfloat16): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.bias = bias + self.lora_config = lora_config + self.quantization_config = quantization_config + device = get_accelerator().current_device() if device is None else device + assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config" + + self.zero_shards = self.lora_config.base_weight_sharding + self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards) + w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype)) + torch.nn.init.xavier_uniform_(w) + + if self.quantization_config is not None: + assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization" + self.base_weight = QuantizedParameter(w, quantization_config=quantization_config) + else: + self.base_weight = w + + self.base_weight.requires_grad = False + + # Use RS lora for now. + self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r) + # Keeping lora weights in bf16 precision for ease of training. + self.lora_weight_1 = nn.Linear(self.input_dim, + self.lora_config.lora_r, + bias=self.bias, + device=device, + dtype=dtype) + self.lora_weight_2 = nn.Linear(self.lora_config.lora_r, + self.output_dim, + bias=self.bias, + device=device, + dtype=dtype) + self.lora_weight_1.weight.requires_grad = True + self.lora_weight_2.weight.requires_grad = True + + def full_weight(self): + # This assumes weights are evenly sharded across gpus. which might not be correct. + # in that case, we should flatten before all_gather. 
+        local_weight = self.base_weight.dequantized() if isinstance(self.base_weight,
+                                                                    QuantizedParameter) else self.base_weight
+        tensor_list = [
+            torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype)
+            for _ in range(self.zero_shards)
+        ]
+        dist.all_gather(tensor_list, local_weight)
+        weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1))
+        return weight
+
+    def linear_without_F_linear(self, input, weight):
+        output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
+        output = output.view(*input.shape[:-1], weight.shape[1])
+        return output
+
+    def forward(self, input_tensor):
+        # Gather the sharded base weight
+        if self.zero_shards > 1:
+            with torch.no_grad():
+                base_weight = self.full_weight()
+        elif self.quantization_config:
+            base_weight = self.base_weight.dequantized()
+        else:
+            base_weight = self.base_weight
+
+        base_weight_output = F.linear(input_tensor, base_weight)
+        lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
+        return base_weight_output + self.lora_scaling_factor * lora_output
diff --git a/deepspeed/linear/quantization.py b/deepspeed/linear/quantization.py
new file mode 100644
index 000000000000..f5343af45fb8
--- /dev/null
+++ b/deepspeed/linear/quantization.py
@@ -0,0 +1,137 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional
+
+from deepspeed.accelerator import get_accelerator
+from deepspeed.ops.fp_quantizer import Quantizer, FP_Quantize
+from .config import QuantizationConfig
+
+
+class QuantizedParameter(nn.Parameter):
+    """
+    Quantized parameter class that implements weight quantization. Weights
+    are stored in quantized form on GPUs, and can be dequantized on-the-fly when
+    needed by the model. The weights are actually quantized during any `.to(device)`.
+
+    Arguments:
+        data (Tensor): parameter tensor.
+        requires_grad (bool, optional): if the parameter requires gradient. Defaults
+            to False and is not supported to be True. Argument provided only for interface
+            compatibility with torch.nn.Parameter.
+        quantization_config (QuantizationConfig, optional): defaults to a new QuantizationConfig() when not given.
+        quantizer (Quantizer, optional): Defaults to FP_Quantize but can be any quantizer
+            that implements deepspeed.ops.fp_quantizer.Quantizer. This argument is also
+            stashed in the Parameter itself because some models
+            may clone the Parameter by passing its __dict__ attribute.
+            For an example, see tests/unit/linear/test_quant_param.py::TestQuantParam::test_hf_clone
+    """
+
+    def __new__(
+            cls,
+            data: Optional[torch.Tensor] = None,
+            requires_grad: bool = False,  # quantized weights must be frozen
+            quantization_config: QuantizationConfig = None,
+            quantizer: Quantizer = None,
+    ):
+        if requires_grad:
+            raise ValueError(f"requires_grad=True is not supported with QuantizedParameter")
+        if data is None:
+            data = torch.empty(0)
+        self = torch.Tensor._make_subclass(cls, data, requires_grad)
+        self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config
+        if quantizer is not None:
+            self.quantizer = quantizer
+        else:
+            # if FPQuantizerBuilder is not compatible in this env this init will fail
+            self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size)
+        self._ensure_quantized(self)
+        return self
+
+    def _ensure_quantized(self, tensor: torch.Tensor):
+        # If the tensor is on the accelerator and is not quantized, then quantize it in-place.
+        if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.int8:
+            with get_accelerator().stream(get_accelerator().current_stream(tensor.device)):
+                tensor.data = self.quantizer.quantize(tensor.data,
+                                                      q_bits=self.quantization_config.q_bits,
+                                                      q_mantisa_bits=self.quantization_config.mantissa_bits)
+            assert tensor.dtype == torch.int8
+
+    def dequantized(self) -> torch.Tensor:
+        """
+        Return a tensor containing the dequantized weights of this parameter.
+        """
+        if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.int8:
+            with get_accelerator().stream(get_accelerator().current_stream(self.data.device)):
+                return self.quantizer.dequantize(self.data,
+                                                 q_bits=self.quantization_config.q_bits,
+                                                 q_mantisa_bits=self.quantization_config.mantissa_bits)
+        return self.data
+
+    def __getstate__(self):
+        state = self.__dict__
+        state["data"] = self.data
+        state["quantization_config"] = self.quantization_config
+        state["requires_grad"] = self.requires_grad
+        return state
+
+    def __setstate__(self, state):
+        self.quantizer = state["quantizer"]
+        self.quantization_config = state["quantization_config"]
+        self.data = state["data"]
+        self.requires_grad = state["requires_grad"]
+
+    def __deepcopy__(self, memo):
+        new_instance = type(self).__new__(type(self))
+        state = self.__getstate__()
+        new_instance.__setstate__(state)
+        new_instance.quantizer = copy.deepcopy(state["quantizer"])
+        new_instance.quantization_config = copy.deepcopy(state["quantization_config"])
+        new_instance.data = copy.deepcopy(state["data"])
+        return new_instance
+
+    def __copy__(self):
+        new_instance = type(self).__new__(type(self))
+        state = self.__getstate__()
+        new_instance.__setstate__(state)
+        return new_instance
+
+    def cuda(self, device=None, non_blocking=False):
+        return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
+
+    def to(self, *args, **kwargs):
+        """
+        Move the parameter to the given device. Then, if the device is a cuda device,
+        quantize it.
+        """
+        tensor = super().to(*args, **kwargs)
+        self._ensure_quantized(tensor)
+        return tensor
+
+
+class QuantizedLinear(nn.Linear):
+    """
+    Linear layer that implements weight quantization. Parameters
+    are stored via `QuantizedParameter` and are dequantized on-the-fly during any
+    forward pass.
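+
+    Example (illustrative sketch; OptimizedLinear builds this layer automatically when only a
+    quantization_config is supplied, and the weight is quantized once moved to the accelerator):
+
+        layer = QuantizedLinear(input_dim=64, output_dim=64,
+                                quantization_config=QuantizationConfig(q_bits=8))
+        layer = layer.to(get_accelerator().current_device_name())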
+ """ + + def __init__(self, + input_dim: int, + output_dim: int, + bias: bool = False, + quantization_config: QuantizationConfig = None, + dtype=torch.bfloat16): + super().__init__(input_dim, output_dim, bias=bias, dtype=dtype) + assert dtype == torch.bfloat16, "currently only supports bfloat16 dtype" + self.weight = QuantizedParameter(self.weight.data, quantization_config=quantization_config) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.dequantized(), self.bias) diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py index 5575f3567185..995bbae4aeaf 100644 --- a/deepspeed/ops/fp_quantizer/__init__.py +++ b/deepspeed/ops/fp_quantizer/__init__.py @@ -3,4 +3,4 @@ # DeepSpeed Team -from .quantize import FP_Quantize +from .quantize import FP_Quantize, Quantizer diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py index 0d4bf7bc6db1..f8435bda16c1 100644 --- a/deepspeed/ops/fp_quantizer/quantize.py +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -4,20 +4,47 @@ # DeepSpeed Team import torch +import abc +from abc import ABC from deepspeed.ops.op_builder import FPQuantizerBuilder fp_quant_module = None -class FP_Quantize: +class Quantizer(ABC): + """ + Abstract Quantizer class that implmenents quantize/dequantize methods. + + Arguments: + group_size (int, optional): number of values or elements that are grouped + together for the quantization process. + """ + + def __init__(self, group_size=512) -> None: + self.group_size = group_size + + @abc.abstractmethod + def quantize(self, + input, + q_bits=8, + q_mantisa_bits=3, + stochastic_mode=False, + return_meta_tensor=False) -> torch.Tensor: + ... + + @abc.abstractmethod + def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor: + ... + + +class FP_Quantize(Quantizer): def __init__(self, group_size=512) -> None: global fp_quant_module + super().__init__(group_size=group_size) if fp_quant_module is None: fp_quant_module = FPQuantizerBuilder().load() - - self.group_size = group_size self.orig_dtype = None def quantize(self, diff --git a/tests/unit/linear/test_linear.py b/tests/unit/linear/test_linear.py new file mode 100644 index 000000000000..ccd26b4cd726 --- /dev/null +++ b/tests/unit/linear/test_linear.py @@ -0,0 +1,128 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+import deepspeed
+import deepspeed.comm as dist
+
+from deepspeed.accelerator import get_accelerator
+from deepspeed.linear import OptimizedLinear, LoRAConfig, QuantizationConfig
+from unit.common import DistributedTest
+
+from deepspeed.ops.op_builder import FPQuantizerBuilder
+
+if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]:
+    pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True)
+
+
+class TestBasicLinear(DistributedTest):
+    world_size = 2
+
+    def test(self):
+        lora_config = None
+        quantization_config = None
+
+        input_features = 64  # Number of input features
+        output_features = 64  # Number of output features
+        batch_size = 1  # Number of samples in a batch
+
+        linear_layer = OptimizedLinear(input_dim=input_features,
+                                       output_dim=output_features,
+                                       lora_config=lora_config,
+                                       quantization_config=quantization_config,
+                                       dtype=torch.bfloat16)
+
+        dummy_input = torch.rand(batch_size, input_features, dtype=torch.bfloat16)
+        output = linear_layer(dummy_input)
+        assert output.shape == (batch_size, output_features)
+
+
+@pytest.mark.parametrize("base_weight_sharding", [1, 2])
+class TestLoRALinear(DistributedTest):
+    world_size = 2
+
+    def test(self, base_weight_sharding):
+        rank = dist.get_rank()
+        lora_config = None
+        quantization_config = None
+
+        input_features = 64  # Number of input features
+        output_features = 64  # Number of output features
+        batch_size = 5  # Number of samples in a batch
+
+        lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding)
+
+        linear_layer = OptimizedLinear(input_dim=input_features,
+                                       output_dim=output_features,
+                                       lora_config=lora_config,
+                                       quantization_config=quantization_config,
+                                       dtype=torch.bfloat16)
+        device = get_accelerator().current_device_name()
+        linear_layer = linear_layer.to(device)
+        if rank == 0:
+            for n, p in linear_layer.named_parameters():
+                print(f"{n}, {p.shape}")
+
+        dummy_input = torch.rand(batch_size, input_features, device=device, dtype=torch.bfloat16)
+
+        output = linear_layer(dummy_input)
+        assert output.shape == (batch_size, output_features)
+
+
+@pytest.mark.parametrize("q_bits", [8, 6])
+class TestQuantLinear(DistributedTest):
+    world_size = 2
+
+    def test(self, q_bits):
+        rank = dist.get_rank()
+        lora_config = None
+
+        input_features = 64  # Number of input features
+        output_features = 64  # Number of output features
+        batch_size = 5  # Number of samples in a batch
+
+        lora_config = None
+        quantization_config = QuantizationConfig(q_bits=q_bits)
+
+        linear_layer = OptimizedLinear(input_dim=input_features,
+                                       output_dim=output_features,
+                                       lora_config=lora_config,
+                                       quantization_config=quantization_config,
+                                       dtype=torch.bfloat16)
+        device = get_accelerator().current_device_name()
+        linear_layer = linear_layer.to(device)
+        dummy_input = torch.rand([batch_size, input_features], device=device, dtype=torch.bfloat16)
+
+        output = linear_layer(dummy_input)
+        assert output.shape == (batch_size, output_features)
+
+
+@pytest.mark.parametrize("base_weight_sharding", [1, 2], ids=['bws1', 'bws2'])
+@pytest.mark.parametrize("q_bits", [8, 6], ids=['qbit8', 'qbit6'])
+class TestOptimizedLinear(DistributedTest):
+    world_size = 2
+
+    def test(self, base_weight_sharding, q_bits):
+        rank = dist.get_rank()
+        lora_config = None
+
+        input_features = 64  # Number of input features
+        output_features = 64  # Number of output features
+        batch_size = 5  # Number of samples in a batch
+
+        lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding)
+        quantization_config = QuantizationConfig(q_bits=q_bits)
+
+        linear_layer = OptimizedLinear(input_dim=input_features,
+                                       output_dim=output_features,
+                                       lora_config=lora_config,
+                                       quantization_config=quantization_config,
+                                       dtype=torch.bfloat16)
+        device = get_accelerator().current_device_name()
+        linear_layer = linear_layer.to(device)
+        dummy_input = torch.rand([batch_size, input_features], device=device, dtype=torch.bfloat16)
+        output = linear_layer(dummy_input)
+        assert output.shape == (batch_size, output_features)
diff --git a/tests/unit/linear/test_quant_param.py b/tests/unit/linear/test_quant_param.py
new file mode 100644
index 000000000000..9479b3cba8a0
--- /dev/null
+++ b/tests/unit/linear/test_quant_param.py
@@ -0,0 +1,58 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+import deepspeed
+
+from deepspeed.accelerator import get_accelerator
+from deepspeed.linear.quantization import QuantizedParameter
+from deepspeed.linear.config import QuantizationConfig
+
+from deepspeed.ops.op_builder import FPQuantizerBuilder
+
+from unit.common import DistributedTest
+
+if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]:
+    pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True)
+
+
+class TestQuantParam(DistributedTest):
+    world_size = 1
+
+    @pytest.mark.parametrize('dtype', [torch.half, torch.float])
+    def test_unsupported_dtypes(self, dtype):
+        device = get_accelerator().current_device_name()
+        data = torch.rand(5, 5, device='cpu', dtype=dtype)
+        qp = QuantizedParameter(data)
+        with pytest.raises(AssertionError):
+            qp.to(device)
+
+    def test_requires_grad(self):
+        data = torch.rand(5, 5, dtype=torch.bfloat16)
+        with pytest.raises(ValueError):
+            QuantizedParameter(data, requires_grad=True)
+
+    def test_move_to_accelerator(self):
+        device = get_accelerator().current_device()
+        data = torch.rand(5, 5, device='cpu', dtype=torch.bfloat16)
+        qp = QuantizedParameter(data)
+        assert qp.device == torch.device('cpu')
+        qp = qp.to(get_accelerator().current_device_name())
+        assert qp.device == torch.device(device)
+        assert qp.dtype == torch.int8
+
+    def test_hf_clone(self):
+        device = get_accelerator().current_device_name()
+        data = torch.rand(5, 5, device=device, dtype=torch.bfloat16)
+
+        quantization_config = QuantizationConfig(q_bits=6)
+        qp = QuantizedParameter(data, quantization_config=quantization_config)
+
+        # should be able to clone parameter via dict, HF expects this to work
+        qp_copy = QuantizedParameter(qp.data, **qp.__dict__)
+
+        assert all(qp.data == qp_copy.data)
+        assert qp.quantization_config == qp_copy.quantization_config
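
Illustrative usage of the QuantizedParameter API introduced above (a sketch based on the tests in this patch, not part of the diff; assumes an accelerator with bfloat16 support and a compatible FPQuantizer op):

    import torch
    from deepspeed.accelerator import get_accelerator
    from deepspeed.linear.quantization import QuantizedParameter
    from deepspeed.linear.config import QuantizationConfig

    data = torch.rand(5, 5, dtype=torch.bfloat16)          # bf16 master weights on CPU
    qp = QuantizedParameter(data, quantization_config=QuantizationConfig(q_bits=6))
    qp = qp.to(get_accelerator().current_device_name())    # moving to the accelerator quantizes in place to int8
    full_precision = qp.dequantized()                       # dequantized bf16 tensor for compute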