Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse MoE #116

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mttl/models/modifiers/sm_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

from mttl.logging import logger
from mttl.models.modifiers.base import Modifier
from mttl.models.modifiers.sm_config import SparseMaskConfig
from mttl.models.modifiers.sparse_utils.sparse_linear import MaskedLinear, SparseLinear
from mttl.models.modifiers.sparse_utils.utils import (
from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig
from mttl.models.modifiers.sparsity.sparse_linear import MaskedLinear, SparseLinear
from mttl.models.modifiers.sparsity.sparse_utils.utils import (
get_2d_indices_from_csr_matrix,
get_top_k_sparcity,
scipy_csr_to_torch_csr,
Expand Down
17 changes: 4 additions & 13 deletions mttl/models/modifiers/sparse_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,19 @@
from dataclasses import dataclass
from typing import Union

import numpy as np
import torch
from scipy.sparse import csr_matrix
from torch import nn
from triton.ops.blocksparse.matmul import dsd_lut, sdd_lut

from mttl.logging import logger
from mttl.models.modifiers.base import Modifier, ModifierConfig
from mttl.models.modifiers.sm_config import SparseMaskConfig
from mttl.models.modifiers.sm_updater import MaskUpdater
from mttl.models.modifiers.sparse_utils.sparse_linear import (
from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig
from mttl.models.modifiers.sparsity.mask_updater import MaskUpdater
from mttl.models.modifiers.sparsity.sparse_linear import (
MaskedLinear,
ScatteredSparseLinearModule,
SparseLinear,
SparseLinearConfig,
)
from mttl.models.modifiers.sparse_utils.utils import (
get_2d_indices_from_csr_matrix,
get_top_k_sparcity,
scipy_csr_to_torch_csr,
torch_csr_to_scipy_csr,
)


class SparseMaskAdapter(Modifier):
Expand Down Expand Up @@ -55,7 +46,7 @@ def __init__(

def forward(self, input):
if self.maks_update_mode and self.training:
return self.mask_updater(self.sparse_layer, input)
return self.mask_updater(input)
return self.sparse_layer(input)

def prepare_for_mask_update(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass

from mttl.models.modifiers.sparse_utils.sparse_linear import SparseLinearConfig
from mttl.models.modifiers.sparsity.sparse_linear import SparseLinearConfig


@dataclass
Expand Down
200 changes: 0 additions & 200 deletions mttl/models/modifiers/sparse_utils/profile_block_sparsity.py

This file was deleted.

Loading