diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 7ed69e1e9bf8..457ad6e11ba3 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -1402,20 +1402,23 @@ def callback(self, pre, post, node_map): return ethosu_fc -class MatMulRewriter(DFPatternCallback): - """Legalize matrix multiplication to an NPU operator""" +class MatrixMultiplicationRewriter(DFPatternCallback): + """Legalize matrix multiplication with two tensors into sequence of NPU operators""" - def __init__(self): + def __init__( + self, + params_class: Type, + pattern: CallPattern, + ): super().__init__(require_type=True) - self.pattern = ( - wildcard().has_attr({"Composite": ethosu_patterns.MatMulParams.composite_name}) - )(wildcard(), wildcard()) + self.pattern = pattern + self.params_class = params_class def callback(self, pre, post, node_map): - params = ethosu_patterns.MatMulParams(post.op.body) + params = self.params_class(post.op.body) ifm = post.args[0] ifm2 = post.args[1] - lut = relay.const([], dtype="int8") + lut = relay.const([], dtype=params.ifm.dtype) activation_map = {"clip": "CLIP"} if params.activation: activation = activation_map[params.activation.op.name] @@ -1471,7 +1474,7 @@ def callback(self, pre, post, node_map): rounding_mode="NATURAL", ) - # Convert tensor dtype from int32 to int8 + # Convert tensor dtype from int32 to output dtype scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int32"), dtype="int32") reduce_sum = ethosu_ops.ethosu_binary_elementwise( ifm=reduce_sum, @@ -1487,7 +1490,7 @@ def callback(self, pre, post, node_map): ifm_channels=1, ifm2_channels=1, reversed_operands=False, - ofm_dtype="int8", + ofm_dtype=params.ofm.dtype, ) res_columns.append(reduce_sum) @@ -1497,6 +1500,32 @@ def callback(self, pre, post, node_map): return relay.reshape(concat, params.ofm.shape) +class MatMulRewriter(MatrixMultiplicationRewriter): + """Convert ethos-u.matmul composite function to sequence of NPU operators""" + + def __init__(self): + super().__init__( + params_class=ethosu_patterns.MatMulParams, + pattern=( + wildcard().has_attr({"Composite": ethosu_patterns.MatMulParams.composite_name}) + )(wildcard(), wildcard()), + ) + + +class MatMulFixedPointRewriter(MatrixMultiplicationRewriter): + """Convert ethos-u.matmul_fixed_point composite function to sequence of NPU operators""" + + def __init__(self): + super().__init__( + params_class=ethosu_patterns.MatMulFixedPointParams, + pattern=( + wildcard().has_attr( + {"Composite": ethosu_patterns.MatMulFixedPointParams.composite_name} + ) + )(wildcard(), wildcard()), + ) + + class PadRewriter(DFPatternCallback): """Convert ethos-u.pad2d composite function to ethosu_depthwise_conv2d operator""" @@ -1644,6 +1673,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function: PartitionedSplitRewriter(), FullyConnectedRewriter(), MatMulRewriter(), + MatMulFixedPointRewriter(), SplitRewriter(), ChannelPadRewriter(), Conv2DRewriter(), diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 73cee9d0cd23..f24538242cf9 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1917,7 +1917,7 @@ def is_valid(self) -> bool: Checks whether matrix multiplication has compatible attributes with HW """ - if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]): + if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8, np.int16]): return False if not len(self.ifm.shape) == 2: return False @@ -1938,6 +1938,75 @@ def matmul_pattern(): return optional_clip +class MatMulFixedPointParams: + """ + This class will parse a call to an ethos-u.matmul_fixed_point composite + function and extract the parameter information. + """ + + composite_name = "ethos-u.matmul_fixed_point" + + @requires_vela + def __init__(self, func_body): + from tvm.relay.backend.contrib.ethosu.util import QDenseArgs + + dense_fixed_point = func_body.args[0] + dense = dense_fixed_point.args[0] + # fixed_point_multiply relay operation uses multiplier with 31 fractional bits + # so to determine the size of the fraction use the formula: 31 - shift + self.fraction_size = 31 - dense_fixed_point.attrs.shift + fract_scale = tvm.relay.Constant(tvm.nd.array(np.array(1 / 2**self.fraction_size))) + fract_zero_point = tvm.relay.Constant(tvm.nd.array(np.array(0, dtype="int32"))) + + self.activation = None + self.weights = TensorParams( + dense.args[QDenseArgs.WEIGHTS.value].args[0].args[0], + None, + fract_scale, + fract_zero_point, + ) + self.ifm = TensorParams( + dense.args[QDenseArgs.IFM.value].args[0].args[0], + None, + fract_scale, + fract_zero_point, + ) + self.ofm = TensorParams( + func_body, + None, + fract_scale, + fract_zero_point, + ) + + def is_valid(self) -> bool: + """ + Checks whether matrix multiplication has compatible attributes with HW + """ + + if self.fraction_size < 0 or self.fraction_size > 16: + return False + if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int16]): + return False + if not len(self.ifm.shape) == 2: + return False + if not len(self.ofm.shape) == 2: + return False + # The weights must be transposed + if self.ifm.shape[1] != self.weights.shape[1]: + return False + return True + + +def matmul_fixed_point_pattern(): + ifm = is_op("cast")(wildcard()) + ifm2 = is_op("cast")(wildcard()) + ifm = is_op("fixed_point_multiply")(ifm) + ifm2 = is_op("fixed_point_multiply")(ifm2) + dense = is_op("nn.dense")(ifm, ifm2) + dense = is_op("fixed_point_multiply")(dense) + return is_op("cast")(dense) + + class HardSwishParams: """ This class will parse a call to a ethos-u.hard_swish composite function @@ -2228,6 +2297,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal matmul_pattern(), lambda pat: MatMulParams(pat).is_valid(), ), + ( + MatMulFixedPointParams.composite_name, + matmul_fixed_point_pattern(), + lambda pat: MatMulFixedPointParams(pat).is_valid(), + ), ( MaxPool2DParams.composite_name, qnn_maxpool2d_pattern(), diff --git a/src/relay/op/contrib/ethosu/identity.cc b/src/relay/op/contrib/ethosu/identity.cc index f808e8c21902..9ec6c6f42ce0 100644 --- a/src/relay/op/contrib/ethosu/identity.cc +++ b/src/relay/op/contrib/ethosu/identity.cc @@ -46,7 +46,8 @@ bool EthosuIdentityRel(const Array& types, int num_inputs, const Attrs& at const String operator_name = "ethosu_identity"; - CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm"); + CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8), DataType::Int(16)}, + operator_name, "ifm"); if (ifm->shape.size() > 4) { reporter->GetDiagCtx().EmitFatal( diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index f69c114cabd1..afcf27bb4517 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1624,5 +1624,56 @@ def subtract_sigmoid_function(lhs, rhs): ) +@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"]) +@pytest.mark.parametrize( + "ifm_shape,ofm_channels,fract_size,tolerance", + [[(1, 16), 8, 15, 0.001], [(2, 8), 16, 14, 0.001], [(4, 8), 16, 12, 0.001]], +) +def test_ethosu_matmul_fixed_point(accel_type, ifm_shape, ofm_channels, fract_size, tolerance): + np.random.seed(0) + dtype = "int16" + weights_shape = (ofm_channels, ifm_shape[1]) + + def create_model(): + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=weights_shape, dtype=dtype) + ifm_fixed_point = relay.cast(ifm, "int32") + ifm2_fixed_point = relay.cast(ifm2, "int32") + ifm_fixed_point = relay.fixed_point_multiply(ifm_fixed_point, 2**31 - 1, 0) + ifm2_fixed_point = relay.fixed_point_multiply(ifm2_fixed_point, 2**31 - 1, 0) + dense = relay.nn.dense(ifm_fixed_point, ifm2_fixed_point) + dense = relay.fixed_point_multiply(dense, 1, 16) + dense = relay.cast(dense, dtype) + return tvm.IRModule.from_expr(relay.Function([ifm, ifm2], dense)) + + def convert_to_fixed_point(arr, fract_size): + fract_fact = 0b1 << fract_size + return np.array(arr * fract_fact, dtype=np.int16) + + cpu_mod = create_model() + ethosu_mod = partition_for_ethosu(cpu_mod) + + input_data = { + "ifm": np.random.uniform(-0.5, 0.5, size=ifm_shape), + "ifm2": np.random.uniform(-0.5, 0.5, size=weights_shape), + } + input_data = { + "ifm": convert_to_fixed_point(input_data["ifm"], fract_size), + "ifm2": convert_to_fixed_point(input_data["ifm2"], fract_size), + } + output_data = generate_ref_data(cpu_mod, input_data) + output_data = {"output": output_data["output"].astype("int16")} + tolerance = convert_to_fixed_point(tolerance, fract_size) + + infra.compare_ethosu_with_reference( + ethosu_mod, + input_data, + output_data, + accel_type, + enable_cascader=False, + output_tolerance=tolerance, + ) + + if __name__ == "__main__": tvm.testing.main()