move mmgroupquant
Xida Ren committed Nov 30, 2023
1 parent dd1c771 commit 42914be
Showing 2 changed files with 13 additions and 7 deletions.
10 changes: 10 additions & 0 deletions python/shark_turbine/aot/compiled_module.py
@@ -45,6 +45,8 @@
"CompiledModule",
]

from shark_turbine.transforms.rewriter import Pass # for type annotations

################################################################################
# Data structures
################################################################################
@@ -472,6 +474,7 @@ def __new__(
        context: Optional[Context] = None,
        module_op: Optional[Operation] = None,
        import_to: Union[ImportPhase, None, str] = "full",
+        pre_import_passes: Optional[List[Pass]] = None,
    ):
        import_to = ImportPhase.parse(import_to)
        self = super().__new__(cls)
@@ -538,6 +541,13 @@ def invoke_with_self(*args, **kwargs):
            do_export(proc_def)

        module_builder.finalize_construct()
+
+        # `run_import` lowers the module from the torch dialect to linalg;
+        # passes that match torch-level ops (e.g. MMGroupQuantRewriterPass) must run first.
+        module_op = CompiledModule.get_mlir_module(self).operation
+        for p in pre_import_passes or []:
+            p(module_op).run()
+
        CompiledModule.run_import(self, import_to)
        return self

10 changes: 3 additions & 7 deletions python/turbine_models/custom_models/stateless_llama.py
@@ -197,15 +197,11 @@ def forward(token0: torch.Tensor, *state0_flat):
            return token1, *state1_flat

    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
-    inst = StateUpdateModule(context=Context(), import_to=import_to)
    # TODO: Integrate with external parameters to actually be able to run
    # TODO: Make more generalizable to be able to quantize with all compile_to options
+    pre_import_passes = []
    if quantization == "int4" and not compile_to == "linalg":
        from shark_turbine.transforms.quantization import mm_group_quant

-        mm_group_quant.MMGroupQuantRewriterPass(
-            CompiledModule.get_mlir_module(inst).operation
-        ).run()
+        pre_import_passes.append(mm_group_quant.MMGroupQuantRewriterPass)
+    inst = StateUpdateModule(context=Context(), import_to=import_to, pre_import_passes=pre_import_passes)
    module_str = str(CompiledModule.get_mlir_module(inst))
    safe_name = hf_model_name.split("/")[-1].strip()
    safe_name = re.sub("-", "_", safe_name)
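Taken together, the two files define a small contract: each entry in pre_import_passes is called with the finalized module operation and must return an object with a run() method, which executes after finalize_construct() but before run_import lowers the module out of the torch dialect. Passing the class (rather than a constructed instance, as the old stateless_llama code did) lets CompiledModule build the pass against the finished module. A minimal sketch of a custom pre-import pass, assuming the Pass base class accepts the root operation in its constructor the way MMGroupQuantRewriterPass does (the class name and body are illustrative, not part of this commit):

    from shark_turbine.transforms.rewriter import Pass

    class MyTorchLevelRewrite(Pass):  # hypothetical example pass
        def run(self):
            # Match and rewrite torch-dialect ops here; run_import has not
            # lowered the module to linalg yet at this point.
            ...

    inst = StateUpdateModule(
        context=Context(),
        import_to=import_to,
        pre_import_passes=[MyTorchLevelRewrite],
    )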
