Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OptimizedLinear implementation #5355

Merged
merged 13 commits into from
Apr 23, 2024
7 changes: 7 additions & 0 deletions deepspeed/linear/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .optimized_linear import OptimizedLinear
from .config import LoRAConfig, QuantizationConfig
39 changes: 39 additions & 0 deletions deepspeed/linear/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from dataclasses import dataclass


@dataclass
class LoRAConfig:
    """
    Configuration settings for LoRAOptimizedLinear.

    Attributes:
        lora_r (int): LoRA attention dimension, also known as the rank. Default is 64.
        lora_alpha (float): LoRA scaling factor. Default is 16.
        base_weight_sharding (int): The degree to which the base weights are sharded,
            should typically be set to the data-parallel world size to maximize the memory
            reduction benefits. Defaults to 1, which means this feature is disabled.

    Raises:
        ValueError: If ``lora_r`` is not positive or ``base_weight_sharding`` is less than 1.
    """
    lora_r: int = 64
    lora_alpha: float = 16.
    base_weight_sharding: int = 1

    def __post_init__(self):
        # Both values end up as divisors in the layer that consumes this config
        # (rank inside a square root, sharding degree in a floor division), so
        # reject non-positive values here with a clear message instead of
        # failing later with an arithmetic error.
        if self.lora_r <= 0:
            raise ValueError(f"lora_r must be a positive integer, got {self.lora_r}")
        if self.base_weight_sharding < 1:
            raise ValueError(f"base_weight_sharding must be >= 1, got {self.base_weight_sharding}")


@dataclass
class QuantizationConfig:
    """
    Configuration settings for quantization for LoRAOptimizedLinear, QuantizedLinear,
    and QuantizedParameter.

    Attributes:
        q_bits (int): The number of bits used for quantization. Default is 8.
        mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point
            quantization. Default is 3.
        group_size (int): The size of the group used for quantization. Default is 512.
    """
    q_bits: int = 8
    mantissa_bits: int = 3
    group_size: int = 512
150 changes: 150 additions & 0 deletions deepspeed/linear/optimized_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

from .config import LoRAConfig, QuantizationConfig
from .quantization import QuantizedParameter, QuantizedLinear


class OptimizedLinear(nn.Module):
    """
    Optimized version of nn.Linear that adds features such as:
      * LoRA w. base weight sharding
      * FP [6,8,12] quantization

    This class acts as a factory: ``__new__`` returns an instance of a different
    module type chosen from the supplied configs, so ``OptimizedLinear.__init__``
    is never run on the returned object.

    Arguments:
        input_dim: Required: size of each input sample
        output_dim: Required: size of each output sample
        bias: Optional: If set to False, the layer will not learn an additive bias. Default: False
        lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree
        quantization_config: Optional: QuantizationConfig defining quantization features
        dtype: Optional: parameter dtype, only supports bfloat16 currently

    Returns:
        Returns a new nn.Module depending on the input config. Either native
        torch.nn.Linear, QuantizedLinear, or the full-featured LoRAOptimizedLinear.

    Raises:
        ValueError: If ``lora_config`` or ``quantization_config`` is given but is not an
            instance of the expected config class.
    """

    def __new__(cls,
                input_dim: int,
                output_dim: int,
                bias: bool = False,
                lora_config: LoRAConfig = None,
                quantization_config: QuantizationConfig = None,
                dtype=torch.bfloat16):

        if quantization_config is not None and not isinstance(quantization_config, QuantizationConfig):
            raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}")
        if lora_config is not None and not isinstance(lora_config, LoRAConfig):
            raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}")

        if lora_config is not None:
            # lora enabled, quantization may or may not be
            return LoRAOptimizedLinear(input_dim=input_dim,
                                       output_dim=output_dim,
                                       bias=bias,
                                       lora_config=lora_config,
                                       quantization_config=quantization_config,
                                       dtype=dtype)

        if quantization_config is not None:
            # only quantization enabled, no lora
            return QuantizedLinear(input_dim=input_dim,
                                   output_dim=output_dim,
                                   bias=bias,
                                   quantization_config=quantization_config,
                                   dtype=dtype)

        # Everything disabled, fall back to normal nn.Linear
        return nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype)


class LoRAOptimizedLinear(nn.Module):
    """
    Linear layer with a frozen (optionally quantized and/or sharded) base weight
    plus a trainable low-rank LoRA update.

    The forward pass computes
    ``F.linear(x, base_weight) + lora_scaling_factor * lora_weight_2(lora_weight_1(x))``.

    Arguments:
        input_dim: size of each input sample
        output_dim: size of each output sample
        bias: forwarded as the ``bias`` argument of both LoRA ``nn.Linear`` layers.
            The base-weight matmul never adds a bias.
        lora_config: Required: LoRAConfig providing rank, alpha and base-weight-sharding degree
        quantization_config: Optional: when given, the base weight is stored as a QuantizedParameter
        device: device for the LoRA layers; defaults to the accelerator's current device
        dtype: parameter dtype; must be bfloat16 when quantization is enabled
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 bias: bool = False,
                 lora_config: LoRAConfig = None,
                 quantization_config: QuantizationConfig = None,
                 device=None,
                 dtype=torch.bfloat16):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias = bias
        self.lora_config = lora_config
        self.quantization_config = quantization_config
        device = get_accelerator().current_device() if device is None else device
        assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config"

        self.zero_shards = self.lora_config.base_weight_sharding
        # Each rank holds a column slice of the base weight.
        # NOTE(review): floor division silently drops the remainder when input_dim is not
        # a multiple of zero_shards; full_weight() assumes an even split -- confirm callers
        # only use evenly divisible sizes.
        self.sharded_weight_size = self.input_dim // self.zero_shards
        w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype))
        torch.nn.init.xavier_uniform_(w)

        if self.quantization_config is not None:
            assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
            self.base_weight = QuantizedParameter(w, quantization_config=quantization_config)
        else:
            self.base_weight = w

        # The base weight is frozen; only the LoRA weights are trained.
        self.base_weight.requires_grad = False

        # Use RS lora for now.
        self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r)
        # Keeping lora weights in bf16 precision for ease of training.
        self.lora_weight_1 = nn.Linear(self.input_dim,
                                       self.lora_config.lora_r,
                                       bias=self.bias,
                                       device=device,
                                       dtype=dtype)
        self.lora_weight_2 = nn.Linear(self.lora_config.lora_r,
                                       self.output_dim,
                                       bias=self.bias,
                                       device=device,
                                       dtype=dtype)
        self.lora_weight_1.weight.requires_grad = True
        self.lora_weight_2.weight.requires_grad = True

    def full_weight(self):
        """
        All-gather the local base-weight shard from every rank and concatenate the
        shards along dim=1.

        Returns:
            An ``nn.Parameter`` with ``requires_grad=False`` holding the concatenated
            weight. The base weight is frozen, so the gathered copy must not be
            tracked by autograd.
        """
        # This assumes weights are evenly sharded across gpus. which might not be correct.
        # in that case, we should flatten before all_gather.
        if isinstance(self.base_weight, QuantizedParameter):
            local_weight = self.base_weight.dequantized()
        else:
            local_weight = self.base_weight
        tensor_list = [torch.zeros_like(local_weight) for _ in range(self.zero_shards)]
        dist.all_gather(tensor_list, local_weight)
        # nn.Parameter defaults to requires_grad=True regardless of any enclosing
        # no_grad() context, so the flag is set explicitly here.
        return nn.Parameter(torch.cat(tensor_list, dim=1), requires_grad=False)

    def linear_without_F_linear(self, input, weight):
        """
        Multiply ``input`` (flattened to 2-D over its leading dims) by ``weight`` with
        ``torch.mm`` and restore the leading dims. Unlike ``F.linear`` the weight is
        not transposed, so ``weight.shape[1]`` becomes the size of the output's last dim.
        """
        output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
        output = output.view(*input.shape[:-1], weight.shape[1])
        return output

    def forward(self, input_tensor):
        """
        Apply the frozen base weight and add the scaled LoRA update.

        Returns:
            ``F.linear(input_tensor, base_weight) + lora_scaling_factor * lora_output``
        """
        # Gather the sharded base weight
        if self.zero_shards > 1:
            with torch.no_grad():
                base_weight = self.full_weight()
        elif self.quantization_config:
            base_weight = self.base_weight.dequantized()
        else:
            base_weight = self.base_weight

        base_weight_output = F.linear(input_tensor, base_weight)
        lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
        return base_weight_output + self.lora_scaling_factor * lora_output
137 changes: 137 additions & 0 deletions deepspeed/linear/quantization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional

from deepspeed.accelerator import get_accelerator
from deepspeed.ops.fp_quantizer import Quantizer, FP_Quantize
from .config import QuantizationConfig


class QuantizedParameter(nn.Parameter):
    """
    Quantized parameter class that implements weight quantization. Weights
    are stored in quantized form on GPUs, and can be dequantized on-the-fly when
    needed by the model. The weights are actually quantized during any `.to(device)`.

    Arguments:
        data (Tensor): parameter tensor.
        requires_grad (bool, optional): if the parameter requires gradient. Defaults
            to False and is not supported to be True. Argument provided only for interface
            compatibility with torch.nn.Parameter.
        quantization_config (QuantizationConfig, optional): quantization settings; a
            default ``QuantizationConfig()`` is used when omitted.
        quantizer (Quantizer, optional): Defaults to FP_Quantize but can be any quantizer
            that implements deepspeed.ops.fp_quantizer.Quantizer. This argument is also
            required since the quantizer is stashed in the Parameter itself, some models
            may clone the Parameter by passing an attribute __dict__. For an example, see
            tests/unit/linear/test_quant_param.py::TestQuantParam::test_hf_clone
    """

    def __new__(
        cls,
        data: Optional[torch.Tensor] = None,
        requires_grad: bool = False,  # quantized weights must be frozen
        quantization_config: QuantizationConfig = None,
        quantizer: Quantizer = None,
    ):
        if requires_grad:
            raise ValueError("requires_grad=True is not supported with QuantizedParameter")
        if data is None:
            data = torch.empty(0)
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config
        if quantizer is not None:
            self.quantizer = quantizer
        else:
            # if FPQuantizerBuilder is not compatible in this env this init will fail
            self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size)
        self._ensure_quantized(self)
        return self

    def _ensure_quantized(self, tensor: torch.Tensor):
        """Quantize ``tensor.data`` in-place if it lives on the accelerator and is not int8 yet."""
        # If the tensor is on the accelerator and is not quantized, then quantize it in-place.
        if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.int8:
            with get_accelerator().stream(get_accelerator().current_stream(tensor.device)):
                tensor.data = self.quantizer.quantize(tensor.data,
                                                      q_bits=self.quantization_config.q_bits,
                                                      q_mantisa_bits=self.quantization_config.mantissa_bits)
            assert tensor.dtype == torch.int8

    def dequantized(self) -> torch.Tensor:
        """
        Return a tensor containing the dequantized weights of this parameter.

        If the data is not on the accelerator, or is not stored as int8, the raw
        ``self.data`` is returned unchanged.
        """
        if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.int8:
            with get_accelerator().stream(get_accelerator().current_stream(self.data.device)):
                return self.quantizer.dequantize(self.data,
                                                 q_bits=self.quantization_config.q_bits,
                                                 q_mantisa_bits=self.quantization_config.mantissa_bits)
        return self.data

    def __getstate__(self):
        """Return a snapshot dict with the instance attributes plus ``data`` and ``requires_grad``."""
        # Work on a copy: writing "data"/"requires_grad" into the live __dict__ would
        # keep a stale tensor reference alive and break clones made via
        # ``cls(data, requires_grad=..., **param.__dict__)`` (duplicate keyword arguments).
        state = dict(self.__dict__)
        state["data"] = self.data
        state["quantization_config"] = self.quantization_config
        state["requires_grad"] = self.requires_grad
        return state

    def __setstate__(self, state):
        """Restore quantizer, config, data and requires_grad from a ``__getstate__`` dict."""
        self.quantizer = state["quantizer"]
        self.quantization_config = state["quantization_config"]
        self.data = state["data"]
        self.requires_grad = state["requires_grad"]

    def __deepcopy__(self, memo):
        """Return an independent copy with deep-copied quantizer, config and data."""
        cached = memo.get(id(self))
        if cached is not None:
            return cached
        new_instance = type(self).__new__(type(self))
        # Register before copying members so shared/cyclic references resolve to one copy.
        memo[id(self)] = new_instance
        state = self.__getstate__()
        new_instance.__setstate__(state)
        new_instance.quantizer = copy.deepcopy(state["quantizer"], memo)
        new_instance.quantization_config = copy.deepcopy(state["quantization_config"], memo)
        new_instance.data = copy.deepcopy(state["data"], memo)
        return new_instance

    def __copy__(self):
        """Return a shallow copy sharing quantizer, config and data with this parameter."""
        new_instance = type(self).__new__(type(self))
        state = self.__getstate__()
        new_instance.__setstate__(state)
        return new_instance

    def cuda(self, device=None, non_blocking=False):
        """Shorthand for ``self.to(device="cuda" if device is None else device, ...)``."""
        return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

    def to(self, *args, **kwargs):
        """
        Move the parameter to the given device. Then, if the device is a cuda device,
        quantize it.
        """
        tensor = super().to(*args, **kwargs)
        self._ensure_quantized(tensor)
        return tensor


class QuantizedLinear(nn.Linear):
    """
    Linear layer that implements weight quantization. Parameters
    are stored via `QuantizedParameter` and are dequantized on-the-fly during any
    forward pass.

    Arguments:
        input_dim: size of each input sample
        output_dim: size of each output sample
        bias: If set to False, the layer will not learn an additive bias. Default: False
        quantization_config: Optional: QuantizationConfig forwarded to the QuantizedParameter
        dtype: parameter dtype, currently must be bfloat16
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 bias: bool = False,
                 quantization_config: QuantizationConfig = None,
                 dtype=torch.bfloat16):
        # Check the dtype before the base class allocates any weights.
        assert dtype == torch.bfloat16, "currently only supports bfloat16 dtype"
        super().__init__(input_dim, output_dim, bias=bias, dtype=dtype)
        # Replace the regular weight with its quantized counterpart.
        self.weight = QuantizedParameter(self.weight.data, quantization_config=quantization_config)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the linear transform using the dequantized weight and the (unquantized) bias."""
        return F.linear(input, self.weight.dequantized(), self.bias)
2 changes: 1 addition & 1 deletion deepspeed/ops/fp_quantizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

# DeepSpeed Team

from .quantize import FP_Quantize
from .quantize import FP_Quantize, Quantizer
33 changes: 30 additions & 3 deletions deepspeed/ops/fp_quantizer/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,47 @@
# DeepSpeed Team

import torch
import abc
from abc import ABC

from deepspeed.ops.op_builder import FPQuantizerBuilder

fp_quant_module = None


class FP_Quantize:
class Quantizer(ABC):
    """
    Abstract Quantizer class that defines the quantize/dequantize interface.
    Subclasses must implement both methods.

    Arguments:
        group_size (int, optional): number of values or elements that are grouped
            together for the quantization process.
    """

    def __init__(self, group_size=512) -> None:
        self.group_size = group_size

    @abc.abstractmethod
    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        """Quantize ``input`` and return the quantized tensor. Must be overridden."""
        ...

    @abc.abstractmethod
    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        """Dequantize ``input_q`` and return the resulting tensor. Must be overridden."""
        ...


class FP_Quantize(Quantizer):

def __init__(self, group_size=512) -> None:
global fp_quant_module
super().__init__(group_size=group_size)
if fp_quant_module is None:
fp_quant_module = FPQuantizerBuilder().load()

self.group_size = group_size
self.orig_dtype = None

def quantize(self,
Expand Down
Loading
Loading