[LoRA, Performance] Add gemm expand triton kernel for multi-LoRA #1728

Draft: wants to merge 1 commit into base: main
209 changes: 152 additions & 57 deletions python/sglang/srt/lora/lora.py
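Before the diff: this PR threads an optional `gemm_expand` Triton kernel through each LoRA layer wrapper and, when it is present, fuses the LoRA B ("expand") matmul directly into `base_output` one weight slice at a time, instead of running per-slice `segment_gemm.run` calls into a temporary `lora_output` and adding the result afterwards. The following is a plain-PyTorch reference sketch of the per-slice computation implied by the call sites in this diff; `gemm_expand_reference` and its defaults are hypothetical, and the actual Triton kernel lives elsewhere in the PR and is not shown here.

```python
import torch


def gemm_expand_reference(
    output,             # (total_tokens, out_features): base layer output, updated in place
    x,                  # (total_tokens, num_slices * rank): LoRA-A ("shrink") activations
    weights,            # (num_loras, slice_dim, rank): stacked LoRA-B weights for one slice
    *,
    batch_size,         # number of segments (requests) in the batch
    seg_lens,           # (batch_size,): tokens per segment
    seg_start,          # (batch_size + 1,): prefix offsets of the segments (seg_indptr)
    weight_indices,     # (batch_size,): which LoRA adapter each segment uses
    max_len=None,       # longest segment; only relevant for the Triton launch grid
    input_slice_offset=0,
    output_slice_offset=0,
    output_add=True,
    scaling=1.0,
):
    rank = weights.shape[-1]
    slice_dim = weights.shape[-2]
    for b in range(batch_size):
        s = int(seg_start[b])
        e = s + int(seg_lens[b])
        w = weights[weight_indices[b]]                                # (slice_dim, rank)
        xa = x[s:e, input_slice_offset : input_slice_offset + rank]   # (seg_len, rank)
        y = scaling * (xa @ w.t())                                    # (seg_len, slice_dim)
        out = output[s:e, output_slice_offset : output_slice_offset + slice_dim]
        if output_add:
            out += y
        else:
            out.copy_(y)
    return output
```

With `output_add=True` the result is accumulated into `base_output` in place, which is why the fused branches in the diff simply `return base_output`.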
@@ -44,13 +44,14 @@


class BaseLayerWithLoRA(nn.Module):
def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
def __init__(self, base_layer, segment_gemm, lora_rank, scaling, gemm_expand=None):
super().__init__()
self.base_layer = base_layer
self.segment_gemm = segment_gemm
self.lora_rank = lora_rank
self.scaling = scaling
self.set_lora = False
self.gemm_expand = gemm_expand

def forward(self, x: torch.Tensor):
return self.base_layer.forward(x)
@@ -61,17 +62,27 @@ def set_lora_info(self, *args):

class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
self,
base_layer: VocabParallelEmbedding,
segment_gemm,
lora_rank,
scaling,
gemm_expand=None,
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, segment_gemm, lora_rank, scaling, gemm_expand)
self.weight = base_layer.weight


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
self,
base_layer: ColumnParallelLinear,
segment_gemm,
lora_rank,
scaling,
gemm_expand=None,
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, segment_gemm, lora_rank, scaling, gemm_expand)

def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
# TODO
@@ -97,16 +108,32 @@ def forward(self, input_: torch.Tensor):

class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
self,
base_layer: MergedColumnParallelLinear,
segment_gemm,
lora_rank,
scaling,
gemm_expand=None,
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, segment_gemm, lora_rank, scaling, gemm_expand)

def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
def set_lora_info(
self,
A_buffer,
B_buffer,
bs,
seg_indptr,
weight_indices,
seg_lens=None,
max_len=None,
):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seg_indptr = seg_indptr
self.seg_lens = seg_lens
self.max_len = max_len
self.weight_indices = weight_indices

def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -118,40 +145,73 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
# FIXME
lora_output = torch.empty_like(base_output)
output_dim = lora_output.shape[-1] // 2
for i in range(2):
left = output_dim * i
right = left + output_dim
lora_output[:, left:right] = self.segment_gemm.run(
x=lora_a_output[
:, self.lora_rank * i : self.lora_rank * (i + 1)
].contiguous(),
weights=self.B_buffer[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
return base_output + lora_output * self.scaling
        output_dim = base_output.shape[-1] // 2
if self.gemm_expand is not None:
for i in range(2):
self.gemm_expand(
base_output,
lora_a_output,
self.B_buffer[i],
batch_size=self.bs,
seg_lens=self.seg_lens,
seg_start=self.seg_indptr,
weight_indices=self.weight_indices,
max_len=self.max_len,
input_slice_offset=self.lora_rank * i,
output_slice_offset=output_dim * i,
output_add=True,
scaling=self.scaling,
)
return base_output
else:
# FIXME wait for flashinfer segment gemm update
lora_output = torch.empty_like(base_output)
for i in range(2):
left = output_dim * i
right = left + output_dim
lora_output[:, left:right] = self.segment_gemm.run(
x=lora_a_output[
:, self.lora_rank * i : self.lora_rank * (i + 1)
].contiguous(),
weights=self.B_buffer[i],
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
return base_output + lora_output * self.scaling
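In the fused branch above, the two gate/up slices never materialize an intermediate `lora_output`: slice `i` reads `rank` columns of the LoRA-A result at `input_slice_offset = lora_rank * i` and accumulates into the matching half of `base_output` at `output_slice_offset = output_dim * i`. A tiny shape illustration with made-up sizes:

```python
import torch

tokens, rank, output_dim = 8, 16, 4096           # hypothetical sizes
lora_a_output = torch.zeros(tokens, 2 * rank)    # shrink result for both slices
base_output = torch.zeros(tokens, 2 * output_dim)

for i in range(2):
    a_slice = lora_a_output[:, rank * i : rank * (i + 1)]             # (8, 16)
    out_half = base_output[:, output_dim * i : output_dim * (i + 1)]  # (8, 4096)
    # gemm_expand accumulates scaling * (a_slice @ B[i].T) into out_half in place
```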


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
    def __init__(
self,
base_layer: QKVParallelLinear,
segment_gemm,
lora_rank,
scaling,
gemm_expand=None,
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, segment_gemm, lora_rank, scaling, gemm_expand)

def set_lora_info(
self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
self,
A_buffer_qkv,
B_buffer_q,
B_buffer_kv,
bs,
seg_indptr,
weight_indices,
seg_lens=None,
max_len=None,
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
self.B_buffer_q = B_buffer_q
self.B_buffer_kv = B_buffer_kv
self.bs = bs
self.seg_indptr = seg_indptr
self.seg_lens = seg_lens
self.max_len = max_len
self.weight_indices = weight_indices

def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -169,37 +229,60 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
output_dim_q = self.B_buffer_q.shape[-2]
lora_output[:, :output_dim_q] = self.segment_gemm.run(
x=lora_a_output[:, : self.lora_rank].contiguous(),
weights=self.B_buffer_q,
weights=self.B_buffer_q[0],
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
# kv
output_dim_kv = self.B_buffer_kv.shape[-2] // 2
for i in range(2):
left = output_dim_kv * i
right = left + output_dim_kv
lora_output[:, output_dim_q + left : output_dim_q + right] = (
self.segment_gemm.run(
x=lora_a_output[
:, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
].contiguous(),
weights=self.B_buffer_kv[:, left:right, :].contiguous(),
output_dim_kv = self.B_buffer_kv.shape[-2]
if self.gemm_expand is not None:
for i in range(2):
self.gemm_expand(
base_output,
lora_a_output,
self.B_buffer_kv[i],
batch_size=self.bs,
seg_lens=self.seg_lens,
seg_start=self.seg_indptr,
weight_indices=self.weight_indices,
max_len=self.max_len,
input_slice_offset=self.lora_rank * (i + 1),
output_slice_offset=output_dim_q + output_dim_kv * i,
output_add=True,
scaling=self.scaling,
)
)
return base_output + lora_output * self.scaling
return base_output
else:
for i in range(2):
left = output_dim_kv * i
right = left + output_dim_kv
lora_output[:, output_dim_q + left : output_dim_q + right] = (
self.segment_gemm.run(
x=lora_a_output[
:, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
].contiguous(),
weights=self.B_buffer_kv[i],
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
)
return base_output + lora_output * self.scaling
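For orientation, `B_buffer_q[0]` and `B_buffer_kv[i]` above rely on the stacked layout introduced in the weight-loading changes further down: each index selects one `(num_loras, output_dim, rank)` tensor, and `B_buffer_kv.shape[-2]` is already the per-slice width, which is why the old `// 2` disappears. Illustrative, hypothetical shapes only; the memory pool's real allocation is not part of this diff:

```python
import torch

num_loras, rank = 4, 16
output_dim_q, output_dim_kv = 4096, 1024                       # hypothetical widths

B_buffer_q = torch.zeros(1, num_loras, output_dim_q, rank)     # single q slice
B_buffer_kv = torch.zeros(2, num_loras, output_dim_kv, rank)   # k and v slices

assert B_buffer_q[0].shape == (num_loras, output_dim_q, rank)
assert B_buffer_kv[1].shape == (num_loras, output_dim_kv, rank)
assert B_buffer_kv.shape[-2] == output_dim_kv                  # per-slice width, no "// 2"
```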


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
self,
base_layer: RowParallelLinear,
segment_gemm,
lora_rank,
scaling,
gemm_expand=None,
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, segment_gemm, lora_rank, scaling, gemm_expand)

def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
self.set_lora = True
@@ -220,7 +303,7 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
)
lora_output = self.segment_gemm.run(
x=lora_output,
weights=self.B_buffer,
weights=self.B_buffer[0],
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
@@ -264,7 +347,11 @@ def forward(self, input_):


def get_lora_layer(
layer: nn.Module, segment_gemm, lora_rank, scaling
layer: nn.Module,
segment_gemm,
lora_rank,
scaling,
gemm_expand,
) -> BaseLayerWithLoRA:
supported_layer_types = {
# the order matters
Expand All @@ -276,7 +363,9 @@ def get_lora_layer(
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
ret = lora_layer_type(
layer, segment_gemm, lora_rank, scaling, gemm_expand=gemm_expand
)
return ret
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")

@@ -306,13 +395,14 @@ def offload_from_gpu(self):


class LoRAAdapter(nn.Module):
def __init__(self, uid, config, base_hf_config, load_config):
def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
super().__init__()
self.uid = uid
self.config = config
assert self.config.hf_config["peft_type"].lower() == "lora"
self.base_hf_config = base_hf_config
self.load_config = load_config
self.lora_backend = lora_backend
self.scaling = self.config.lora_alpha / self.config.r

self.layers = nn.ModuleList(
@@ -383,20 +473,25 @@ def initialize_weights(self):
layer.weights.pop(weight_name)
layer.weights.pop(v_name)
else:
layer.weights[kv_name] = torch.cat(
(
layer.weights[kv_name] = torch.stack(
[
layer.weights[weight_name],
layer.weights[v_name],
),
0,
],
dim=0,
)
layer.weights.pop(weight_name)
layer.weights.pop(v_name)
elif "gate_proj" in weight_name:
up_name = weight_name.replace("gate_proj", "up_proj")
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
layer.weights[gate_up_name] = torch.cat(
(layer.weights[weight_name], layer.weights[up_name]), 0
)
if "lora_A" in weight_name:
layer.weights[gate_up_name] = torch.cat(
(layer.weights[weight_name], layer.weights[up_name]), 0
)
else:
layer.weights[gate_up_name] = torch.stack(
[layer.weights[weight_name], layer.weights[up_name]], dim=0
)
layer.weights.pop(weight_name)
layer.weights.pop(up_name)
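Note the asymmetry above: the LoRA-A ("shrink") weights stay concatenated because one segment GEMM still produces all slices of the low-rank activation at once, while the LoRA-B ("expand") weights are stacked along a new leading dimension so each slice can be selected as a contiguous view, without the `.contiguous()` copies the old `[:, left:right, :]` slicing required. A small sketch of the two layouts, using hypothetical per-adapter tensors rather than the loader's actual variables:

```python
import torch

rank, hidden_size, output_dim = 16, 4096, 11008   # hypothetical sizes

A_gate, A_up = torch.zeros(rank, hidden_size), torch.zeros(rank, hidden_size)
B_gate, B_up = torch.zeros(output_dim, rank), torch.zeros(output_dim, rank)

# lora_A: concatenate along dim 0 -> one (2 * rank, hidden) matrix, single shrink GEMM
A_gate_up = torch.cat((A_gate, A_up), 0)
assert A_gate_up.shape == (2 * rank, hidden_size)

# lora_B: stack along a new leading dim -> (2, output_dim, rank), contiguous per-slice views
B_gate_up = torch.stack([B_gate, B_up], dim=0)
assert B_gate_up[0].is_contiguous() and B_gate_up.shape == (2, output_dim, rank)
```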