forked from microsoft/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
OptimizedLinear implementation (microsoft#5355)
Optimized version of `nn.Linear` that adds features such as: * LoRA w. base weight sharding * FP [6,8,12] quantization Depends on microsoft#5336 being merged first Co-authored-by: @rajhans Co-authored-by: @aurickq --------- Co-authored-by: Rajhans Samdani <[email protected]> Co-authored-by: Jeff Rasley <[email protected]>
- Loading branch information
Showing
8 changed files
with
550 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
from .optimized_linear import OptimizedLinear | ||
from .config import LoRAConfig, QuantizationConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass | ||
class LoRAConfig: | ||
""" | ||
Configuration settings for LoRAOptimizedLinear. | ||
Attributes: | ||
lora_r (int): LoRA attention dimension, also know as the rank. Defaults is 64. | ||
lora_alpha (float): LoRA scaling factor, default is 16. | ||
base_weight_sharding (int): The degree to which the base weights are sharded, | ||
should typically be set to the data-parallel world size to maximize the memory | ||
reduction benefits. Defaults to 1, which means this feature is disabled. | ||
""" | ||
lora_r: int = 64 | ||
lora_alpha: float = 16. | ||
base_weight_sharding: int = 1 | ||
|
||
|
||
@dataclass | ||
class QuantizationConfig: | ||
""" | ||
Configuration settings for quantization for LoRAOptimizedLinear, QuantizedLinear, | ||
and QuantizedParameter | ||
Attributes: | ||
q_bits (int): The number of bits used for quantization. Default is 8. | ||
mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3. | ||
group_size (int): The size of the group used for quantization. Default is 512. | ||
""" | ||
q_bits: int = 8 | ||
mantissa_bits: int = 3 | ||
group_size: int = 512 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
import torch | ||
import math | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from dataclasses import is_dataclass | ||
from deepspeed.accelerator import get_accelerator | ||
import deepspeed.comm as dist | ||
|
||
from .config import LoRAConfig, QuantizationConfig | ||
from .quantization import QuantizedParameter, QuantizedLinear | ||
|
||
|
||
class OptimizedLinear(nn.Module): | ||
""" | ||
Optimized version of nn.Linear that adds features such as: | ||
* LoRA w. base weight sharding | ||
* FP [6,8,12] quantization | ||
Arguments: | ||
input_dim: Required: size of each input sample | ||
output_dim: Required: size of each output sample | ||
bias: Optional: If set to False, the layer will not learn an additive bias. Default: False | ||
lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree | ||
quantization_config: Optional: QuantizationConfig defining quantization features | ||
dtype: Optional: parameter dtype, only supports bfloat16 currently | ||
Returns: | ||
Returns a new nn.Module depending on the input config. Either native | ||
torch.nn.Linear, QuantizedLinear, or the full-featured DSOptimizedLinear. | ||
""" | ||
|
||
def __new__(self, | ||
input_dim: int, | ||
output_dim: int, | ||
bias: bool = False, | ||
lora_config: LoRAConfig = None, | ||
quantization_config: QuantizationConfig = None, | ||
dtype=torch.bfloat16): | ||
|
||
if quantization_config is not None and not is_dataclass(quantization_config): | ||
raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}") | ||
if lora_config is not None and not is_dataclass(lora_config): | ||
raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}") | ||
if lora_config is None and quantization_config is None: | ||
# Everything disabled, fall back to normal nn.Linear | ||
self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype) | ||
|
||
elif lora_config: | ||
# lora enabled, quantization may or may not be | ||
self = LoRAOptimizedLinear(input_dim=input_dim, | ||
output_dim=output_dim, | ||
bias=bias, | ||
lora_config=lora_config, | ||
quantization_config=quantization_config, | ||
dtype=dtype) | ||
|
||
elif quantization_config: | ||
# only quantization enabled, no lora | ||
self = QuantizedLinear(input_dim=input_dim, | ||
output_dim=output_dim, | ||
bias=bias, | ||
quantization_config=quantization_config, | ||
dtype=dtype) | ||
return self | ||
|
||
|
||
class LoRAOptimizedLinear(nn.Module): | ||
|
||
def __init__(self, | ||
input_dim: int, | ||
output_dim: int, | ||
bias: bool = False, | ||
lora_config: LoRAConfig = None, | ||
quantization_config: QuantizationConfig = None, | ||
device=None, | ||
dtype=torch.bfloat16): | ||
super().__init__() | ||
self.input_dim = input_dim | ||
self.output_dim = output_dim | ||
self.bias = bias | ||
self.lora_config = lora_config | ||
self.quantization_config = quantization_config | ||
device = get_accelerator().current_device() if device is None else device | ||
assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config" | ||
|
||
self.zero_shards = self.lora_config.base_weight_sharding | ||
self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards) | ||
w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype)) | ||
torch.nn.init.xavier_uniform_(w) | ||
|
||
if self.quantization_config is not None: | ||
assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization" | ||
self.base_weight = QuantizedParameter(w, quantization_config=quantization_config) | ||
else: | ||
self.base_weight = w | ||
|
||
self.base_weight.requires_grad = False | ||
|
||
# Use RS lora for now. | ||
self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r) | ||
# Keeping lora weights in bf16 precision for ease of training. | ||
self.lora_weight_1 = nn.Linear(self.input_dim, | ||
self.lora_config.lora_r, | ||
bias=self.bias, | ||
device=device, | ||
dtype=dtype) | ||
self.lora_weight_2 = nn.Linear(self.lora_config.lora_r, | ||
self.output_dim, | ||
bias=self.bias, | ||
device=device, | ||
dtype=dtype) | ||
self.lora_weight_1.weight.requires_grad = True | ||
self.lora_weight_2.weight.requires_grad = True | ||
|
||
def full_weight(self): | ||
# This assumes weights are evenly sharded across gpus. which might not be correct. | ||
# in that case, we should flatten before all_gather. | ||
local_weight = self.base_weight.dequantized() if isinstance(self.base_weight, | ||
QuantizedParameter) else self.base_weight | ||
tensor_list = [ | ||
torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype) | ||
for _ in range(self.zero_shards) | ||
] | ||
dist.all_gather(tensor_list, local_weight) | ||
weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1)) | ||
return weight | ||
|
||
def linear_without_F_linear(self, input, weight): | ||
output = torch.mm(input.reshape(-1, input.shape[-1]), weight) | ||
output = output.view(*input.shape[:-1], weight.shape[1]) | ||
return output | ||
|
||
def forward(self, input_tensor): | ||
# Gather the sharded base weight | ||
if self.zero_shards > 1: | ||
with torch.no_grad(): | ||
base_weight = self.full_weight() | ||
elif self.quantization_config: | ||
base_weight = self.base_weight.dequantized() | ||
else: | ||
base_weight = self.base_weight | ||
|
||
base_weight_output = F.linear(input_tensor, base_weight) | ||
lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor)) | ||
return base_weight_output + self.lora_scaling_factor * lora_output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
import copy | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from typing import Optional | ||
|
||
from deepspeed.accelerator import get_accelerator | ||
from deepspeed.ops.fp_quantizer import Quantizer, FP_Quantize | ||
from .config import QuantizationConfig | ||
|
||
|
||
class QuantizedParameter(nn.Parameter): | ||
""" | ||
Quantized parameter class that implements weight quantization. Weights | ||
are stored in quantized form on GPUs, and can be dequantized on-the-fly when | ||
needed by the model. The weights are actually quantized during any `.to(device)`. | ||
Arguments: | ||
data (Tensor): parameter tensor. | ||
requires_grad (bool, optional): if the parameter requires gradient. Defaults | ||
to False and is not supported to be True. Argument provided only for interface | ||
compatibility with torch.nn.Parameter. | ||
quantization_config (QuantizationConfig, optional): | ||
quantizer (Quantizer, optional): Defaults to FP_Quantize but can be any quantizer | ||
that implements deepspeed.ops.fp_quantizer.Quantizer. This argument is also | ||
required since the quantizer is stashed in the Parameter itself, some models | ||
may clone the Parameter by passing an attribute __dict__. For an example, see | ||
tests/unit/linear/test_quant_param.py::TestQuantParam::test_hf_clone | ||
""" | ||
|
||
def __new__( | ||
cls, | ||
data: Optional[torch.Tensor] = None, | ||
requires_grad: bool = False, # quantized weights must be frozen | ||
quantization_config: QuantizationConfig = None, | ||
quantizer: Quantizer = None, | ||
): | ||
if requires_grad: | ||
raise ValueError(f"requires_grad=True is not supported with QuantizedParameter") | ||
if data is None: | ||
data = torch.empty(0) | ||
self = torch.Tensor._make_subclass(cls, data, requires_grad) | ||
self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config | ||
if quantizer is not None: | ||
self.quantizer = quantizer | ||
else: | ||
# if FPQuantizerBuilder is not compatible in this env this init will fail | ||
self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size) | ||
self._ensure_quantized(self) | ||
return self | ||
|
||
def _ensure_quantized(self, tensor: torch.Tensor): | ||
# If the tensor is on the accelerator and is not quantized, then quantize it in-place. | ||
if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.int8: | ||
with get_accelerator().stream(get_accelerator().current_stream(tensor.device)): | ||
tensor.data = self.quantizer.quantize(tensor.data, | ||
q_bits=self.quantization_config.q_bits, | ||
q_mantisa_bits=self.quantization_config.mantissa_bits) | ||
assert tensor.dtype == torch.int8 | ||
|
||
def dequantized(self) -> torch.Tensor: | ||
""" | ||
Return a tensor containing the dequantized weights of this parameter. | ||
""" | ||
if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.int8: | ||
with get_accelerator().stream(get_accelerator().current_stream(self.data.device)): | ||
return self.quantizer.dequantize(self.data, | ||
q_bits=self.quantization_config.q_bits, | ||
q_mantisa_bits=self.quantization_config.mantissa_bits) | ||
return self.data | ||
|
||
def __getstate__(self): | ||
state = self.__dict__ | ||
state["data"] = self.data | ||
state["quantization_config"] = self.quantization_config | ||
state["requires_grad"] = self.requires_grad | ||
return state | ||
|
||
def __setstate__(self, state): | ||
self.quantizer = state["quantizer"] | ||
self.quantization_config = state["quantization_config"] | ||
self.data = state["data"] | ||
self.requires_grad = state["requires_grad"] | ||
|
||
def __deepcopy__(self, memo): | ||
new_instance = type(self).__new__(type(self)) | ||
state = self.__getstate__() | ||
new_instance.__setstate__(state) | ||
new_instance.quantizer = copy.deepcopy(state["quantizer"]) | ||
new_instance.quantization_config = copy.deepcopy(state["quantization_config"]) | ||
new_instance.data = copy.deepcopy(state["data"]) | ||
return new_instance | ||
|
||
def __copy__(self): | ||
new_instance = type(self).__new__(type(self)) | ||
state = self.__getstate__() | ||
new_instance.__setstate__(state) | ||
return new_instance | ||
|
||
def cuda(self, device=None, non_blocking=False): | ||
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) | ||
|
||
def to(self, *args, **kwargs): | ||
""" | ||
Move the parameter to the given device. Then, if the device is a cuda device, | ||
quantize it. | ||
""" | ||
tensor = super().to(*args, **kwargs) | ||
self._ensure_quantized(tensor) | ||
return tensor | ||
|
||
|
||
class QuantizedLinear(nn.Linear): | ||
""" | ||
Linear layer that implements weight quantization. Parameters | ||
are stored via `QuantizedParameter` and are dequantized on-the-fly during any | ||
forward pass. | ||
""" | ||
|
||
def __init__(self, | ||
input_dim: int, | ||
output_dim: int, | ||
bias: bool = False, | ||
quantization_config: QuantizationConfig = None, | ||
dtype=torch.bfloat16): | ||
super().__init__(input_dim, output_dim, bias=bias, dtype=dtype) | ||
assert dtype == torch.bfloat16, "currently only supports bfloat16 dtype" | ||
self.weight = QuantizedParameter(self.weight.data, quantization_config=quantization_config) | ||
|
||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
return F.linear(input, self.weight.dequantized(), self.bias) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,4 @@ | |
|
||
# DeepSpeed Team | ||
|
||
from .quantize import FP_Quantize | ||
from .quantize import FP_Quantize, Quantizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.