[AO][Inductor] Enable WOQ fusion pattern with permute
ghstack-source-id: 68ae4d478336c580a4d3fb036a2c27f5439f681f
Pull Request resolved: #135928
leslie-fang-intel committed Sep 13, 2024
1 parent 86335e9 commit 5f4576b
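
For context, a minimal sketch of the user-level computation this fusion pattern targets, mirroring the new test case added in this PR (a weight-only-quantized matmul where the int8 weight is transposed/permuted before dequantization). When compiled with Inductor on CPU, the mul(mm(x, permute(weight).to(bf16)), scales) chain is expected to lower to aten._weight_int8pack_mm; whether the fold/fusion actually fires also depends on the weight being constant-folded (e.g. under freezing), so treat this as an illustration rather than a guaranteed repro:

    import torch

    x = torch.randn(1, 1, 256, dtype=torch.bfloat16)
    weight = torch.randint(-128, 127, (12, 256), dtype=torch.int8)  # int8 WOQ weight
    scales = torch.randn(12, dtype=torch.bfloat16)

    def woq_permute(x, weight, scales):
        w = weight.t()                                            # permute of the int8 weight
        m = torch.mm(x.reshape(-1, x.shape[-1]), w.to(x.dtype))   # dequantize to bf16 + mm
        y = m * scales.to(m.dtype)                                 # apply per-channel scales
        return y.reshape(*x.shape[:-1], y.shape[-1])

    out = torch.compile(woq_permute)(x, weight, scales)           # should match the new pattern 4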
Showing 3 changed files with 53 additions and 6 deletions.
25 changes: 21 additions & 4 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -2585,19 +2585,36 @@ def forward(self, x):
     @skipIfNoDynamoSupport
     def test_woq_int8(self):
         class M(torch.nn.Module):
+            def __init__(self, is_permute):
+                super().__init__()
+                self.is_permute = is_permute
+
             def forward(self, x, weight, scales):
-                return torch.nn.functional.linear(x, weight.to(dtype=x.dtype)) * scales
+                if self.is_permute:
+                    weight = weight.t()
+                    m = torch.mm(
+                        x.reshape(-1, x.shape[-1]),
+                        weight.to(x.dtype),
+                    )
+                    y = m * scales.to(m.dtype)
+                    y = y.reshape(*x.shape[:-1], y.shape[-1])
+                    return y
+                else:
+                    return (
+                        torch.nn.functional.linear(x, weight.to(dtype=x.dtype)) * scales
+                    )
 
-        mod = M().eval()
         x_shape = (1, 1, 256)
+        w_shape = (12, 256)
         s_shape = 12
         x_strides = [
             (256, 256, 1),  # linear dispatching to mm
             (256, 32, 1),  # linear dispatching to bmm
         ]
-        for x_stride in x_strides:
+        is_permutes = [False, True]
+        for x_stride, is_permute in itertools.product(x_strides, is_permutes):
+            mod = M(is_permute=is_permute).eval()
             x = torch.randn(x_shape, dtype=torch.bfloat16).as_strided(x_shape, x_stride)
-            w_shape = (12, 256)
             w = torch.randint(-128, 127, w_shape, dtype=torch.int8)
             s = torch.randn(s_shape, dtype=torch.bfloat16)
 
12 changes: 10 additions & 2 deletions torch/_inductor/constant_folding.py
@@ -88,11 +88,19 @@ def _deduce_value(self, node: torch.fx.Node) -> Any:
     def is_impure(self, node: torch.fx.node.Node) -> bool:
         if (
             node.target == torch.ops.prims.convert_element_type.default
-            and is_const_source(node.args[0], self.lifted_constants)  # type: ignore[arg-type]
+            and (
+                is_const_source(node.args[0], self.lifted_constants)  # type: ignore[arg-type]
+                or (
+                    isinstance(node.args[0], torch.fx.Node)
+                    and node.args[0].target == torch.ops.aten.permute.default
+                    and is_const_source(node.args[0].args[0], self.lifted_constants)  # type: ignore[arg-type]
+                )
+            )
             and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
             and node.args[1] == torch.bfloat16
         ):
-            # For int8_weight -> dq -> bf16_weight
+            # Case 1: int8_weight -> dq -> bf16_weight
+            # Case 2: int8_weight -> permute -> dq -> bf16_weight
             return True
         if node.target in [
             torch.ops.quantized_decomposed.dequantize_per_channel.default,
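
A hypothetical standalone sketch of the condition this hunk adds (not the Inductor code itself; the is_const callback stands in for is_const_source with self.lifted_constants): the bf16 conversion of a constant int8 weight is treated as impure whether or not an aten.permute sits between the weight and the conversion, so the dequantize survives constant folding and the WOQ pattern registered below can still match.

    import torch

    def keeps_woq_dequant(node: torch.fx.Node, is_const) -> bool:
        # Keep int8_weight [-> permute] -> to(bf16) in the graph (do not constant-fold it),
        # so the WOQ fusion pass can still see the dequantize.
        if node.target != torch.ops.prims.convert_element_type.default:
            return False
        src = node.args[0]
        if isinstance(src, torch.fx.Node) and src.target == torch.ops.aten.permute.default:
            src = src.args[0]  # Case 2: look through the permute to the constant weight
        return (
            is_const(src)                                   # stands in for is_const_source(...)
            and node.args[0].meta["val"].dtype == torch.int8
            and node.args[1] == torch.bfloat16
        )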
22 changes: 22 additions & 0 deletions torch/_inductor/fx_passes/quantization.py
@@ -1561,6 +1561,27 @@ def _register_woq_mm_int8_pattern3():
 _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
 
 
+def _register_woq_mm_int8_pattern4():
+    _woq_pattern = CallFunction(
+        aten.mul.Tensor,
+        CallFunction(
+            aten.mm.default,
+            KeywordArg("x"),
+            CallFunction(
+                prims.convert_element_type.default,
+                CallFunction(
+                    aten.permute.default,
+                    KeywordArg("weight"),
+                    Arg(),
+                ),
+                Arg(),
+            ),
+        ),
+        KeywordArg("scales"),
+    )
+    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
+
+
 def _register_quantization_lowerings():
     _register_quantization_unary_fusion()
     _register_quantization_binary_fusion()
@@ -1573,6 +1594,7 @@ def _register_woq_lowerings():
     _register_woq_mm_int8_pattern1()
     _register_woq_mm_int8_pattern2()
     _register_woq_mm_int8_pattern3()
+    _register_woq_mm_int8_pattern4()
 
 
 def _is_valid_dequant_promotion_pattern(dtype=torch.float32):
