Rename brevitas quant custom op #693

Merged · 1 commit · Aug 17, 2023
18 changes: 9 additions & 9 deletions src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py
@@ -56,14 +56,14 @@ def matmul_rhs_group_quant(
         raise ValueError("Input shapes not supported.")


-brevitas_lib = torch.library.Library("brevitas", "DEF")
+brevitas_lib = torch.library.Library("quant", "DEF")
 brevitas_lib.define(
     "matmul_rhs_group_quant(Tensor lhs, Tensor rhs, Tensor rhs_scale, Tensor rhs_zero_point, int rhs_bit_width, int rhs_group_size) -> Tensor"
 )
 brevitas_lib.impl("matmul_rhs_group_quant", matmul_rhs_group_quant)


-def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
+def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
     if len(lhs) == 3 and len(rhs) == 2:
         return [lhs[0], lhs[1], rhs[0]]
     elif len(lhs) == 2 and len(rhs) == 2:
@@ -72,20 +72,20 @@ def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rh
         raise ValueError("Input shapes not supported.")


-def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
+def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
     # output dtype is the dtype of the lhs float input
     lhs_rank, lhs_dtype = lhs_rank_dtype
     return lhs_dtype


-def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
+def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
     return
 # yapf: enable

 brevitas_matmul_rhs_group_quant_library = [
-    brevitas〇matmul_rhs_group_quant〡shape,
-    brevitas〇matmul_rhs_group_quant〡dtype,
-    brevitas〇matmul_rhs_group_quant〡has_value_semantics]
+    quant〇matmul_rhs_group_quant〡shape,
+    quant〇matmul_rhs_group_quant〡dtype,
+    quant〇matmul_rhs_group_quant〡has_value_semantics]

 if __name__ == '__main__':

@@ -100,7 +100,7 @@ def forward(
                 rhs: torch.Tensor,
                 rhs_scale: torch.Tensor,
                 rhs_zero_point: torch.Tensor):
-            return torch.ops.brevitas.matmul_rhs_group_quant(
+            return torch.ops.quant.matmul_rhs_group_quant(
                 lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width=8, rhs_group_size=128)

     mod = CustomOpExampleModule()
@@ -109,6 +109,6 @@ def forward(
     module = torch_mlir.compile(
         mod, (torch.ones(3, 4), torch.ones(5, 4), torch.ones(1), torch.ones(1)),
         output_type="torch",
-        backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
+        backend_legal_ops=["quant.matmul_rhs_group_quant"],
         extra_library=brevitas_matmul_rhs_group_quant_library)
     print(module)
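
Note on this first file: the only behavioral change is the namespace string passed to torch.library.Library, which moves the op's qualified name from torch.ops.brevitas.* to torch.ops.quant.*. A minimal sketch of calling the renamed op in eager mode (editor's illustration, not part of the PR; shapes borrowed from the __main__ example above, and assuming the eager implementation accepts plain float tensors):

import torch

# Importing the module runs the registration shown above under the "quant" namespace.
from brevitas_examples.llm.llm_quant import mlir_custom_mm  # noqa: F401

lhs = torch.ones(3, 4)            # float activations
rhs = torch.ones(5, 4)            # stand-in for group-quantized weights
rhs_scale = torch.ones(1)
rhs_zero_point = torch.ones(1)

# Before this PR the qualified name was torch.ops.brevitas.matmul_rhs_group_quant.
out = torch.ops.quant.matmul_rhs_group_quant(
    lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width=8, rhs_group_size=128)
print(out.shape)  # expected torch.Size([3, 5]), per the registered shape function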
12 changes: 5 additions & 7 deletions src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py
@@ -60,11 +60,11 @@

 # Due a tracing issue this annotation needs to be
 # in the same module (== file) from which make_fx is called
-# We also can't directly annotate torch.ops.brevitas.matmul_rhs_group_quant
+# We also can't directly annotate torch.ops.quant.matmul_rhs_group_quant
 # and so we trace a placeholder first and then replace it post tracing
 @wrap(visible_to_make_fx=True)
 def matmul_rhs_group_quant_placeholder(*args, **kwargs):
-    return torch.ops.brevitas.matmul_rhs_group_quant(*args, **kwargs)
+    return torch.ops.quant.matmul_rhs_group_quant(*args, **kwargs)


 class LinearWeightBlockQuantHandlerFwd(LinearWeightBlockQuantHandler):
@@ -261,9 +261,7 @@ def transform_fx(fx_g):

     transform_fx(fx_g)
     replace_call_fn_target(
-        fx_g,
-        src=matmul_rhs_group_quant_placeholder,
-        target=torch.ops.brevitas.matmul_rhs_group_quant)
+        fx_g, src=matmul_rhs_group_quant_placeholder, target=torch.ops.quant.matmul_rhs_group_quant)

     fx_g.recompile()
     removed_none_indexes = _remove_nones(fx_g)
@@ -319,7 +317,7 @@ def compile_to_vmfb(inputs, layers, export_context_manager, export_class, is_fir
         module = torch_mlir.compile(
             ts_g, (hidden_states_placeholder, inputs[1], inputs[2]),
             output_type="torch",
-            backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
+            backend_legal_ops=["quant.matmul_rhs_group_quant"],
             extra_library=brevitas_matmul_rhs_group_quant_library,
             use_tracing=False,
             verbose=False)
@@ -342,7 +340,7 @@ def compile_to_vmfb(inputs, layers, export_context_manager, export_class, is_fir
                 pkv0_placeholder,
                 pkv1_placeholder),
             output_type="torch",
-            backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
+            backend_legal_ops=["quant.matmul_rhs_group_quant"],
             extra_library=brevitas_matmul_rhs_group_quant_library,
             use_tracing=False,
             verbose=False)
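
The placeholder-then-retarget pattern described in the comments above can be shown in miniature (editor's sketch using the stock torch.fx.wrap and symbolic_trace rather than torch-mlir's @wrap(visible_to_make_fx=True) helper; all names here are illustrative): the stand-in is kept opaque during tracing so it appears as a call_function node, and afterwards every node pointing at it is repointed at the real op.

import torch
import torch.fx

@torch.fx.wrap  # keep this call opaque so tracing records a call_function node
def placeholder(x):
    return x

class M(torch.nn.Module):
    def forward(self, x):
        return placeholder(x) + 1

gm = torch.fx.symbolic_trace(M())

# Post-tracing, retarget the placeholder node (the real code points it at
# torch.ops.quant.matmul_rhs_group_quant via replace_call_fn_target).
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is placeholder:
        node.target = torch.neg  # illustrative replacement op
gm.recompile()

print(gm(torch.ones(3)))  # tensor([0., 0., 0.]) == neg(1) + 1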
8 changes: 4 additions & 4 deletions src/brevitas_examples/llm/test_linear_mlir_export.py
@@ -18,11 +18,11 @@

 # Due a tracing issue this annotation needs to be
 # in the same module (== file) from which make_fx is called
-# We also can't directly annotate torch.ops.brevitas.matmul_rhs_group_quant
+# We also can't directly annotate torch.ops.quant.matmul_rhs_group_quant
 # and so we trace a placeholder first and then replace it post tracing
 @wrap(visible_to_make_fx=True)
 def matmul_rhs_group_quant_placeholder(*args, **kwargs):
-    return torch.ops.brevitas.matmul_rhs_group_quant(*args, **kwargs)
+    return torch.ops.quant.matmul_rhs_group_quant(*args, **kwargs)


 class LinearWeightBlockQuantHandlerFwd(LinearWeightBlockQuantHandler):
@@ -84,7 +84,7 @@ def quantize_and_export(args):
     replace_call_fn_target(
         traced_model,
         src=matmul_rhs_group_quant_placeholder,
-        target=torch.ops.brevitas.matmul_rhs_group_quant)
+        target=torch.ops.quant.matmul_rhs_group_quant)

     # print the output graph
     print(traced_model.graph)
@@ -93,7 +93,7 @@ def quantize_and_export(args):
         traced_model,
         torch.randn(2, 128),
         output_type="torch",
-        backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
+        backend_legal_ops=["quant.matmul_rhs_group_quant"],
         extra_library=brevitas_matmul_rhs_group_quant_library,
         use_tracing=True,
         verbose=False)
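
Across all three files the torch_mlir.compile call sites change only in their backend_legal_ops entry, which tells torch-mlir to keep the renamed op intact in the torch-dialect output rather than attempt to decompose it, while extra_library continues to supply the shape/dtype refinement functions renamed above. A quick registration sanity check (editor's sketch, not part of the PR):

import torch

from brevitas_examples.llm.llm_quant import mlir_custom_mm  # noqa: F401

# The new qualified name resolves...
assert hasattr(torch.ops.quant, "matmul_rhs_group_quant")

# ...and the old one no longer does.
try:
    torch.ops.brevitas.matmul_rhs_group_quant
except (AttributeError, RuntimeError):
    print("old brevitas.* qualified name is gone, as expected")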