From 49eae2da2150a42f0242649e3af2be5496a20c13 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Wed, 8 Nov 2023 21:40:52 +0000 Subject: [PATCH] [Fix] Fix `topi.rms_norm` with float32 upscale --- include/tvm/topi/nn/rms_norm.h | 17 ++++++++++------- python/tvm/topi/testing/rms_norm_python.py | 9 +++++---- tests/python/topi/python/test_topi_rms_norm.py | 14 ++++++-------- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h index 55dac39b718e..ba2f7e49ac98 100644 --- a/include/tvm/topi/nn/rms_norm.h +++ b/include/tvm/topi/nn/rms_norm.h @@ -54,15 +54,18 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Arraydtype : data_type; ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type"; - auto square = multiply(data, data); + const auto& data_fp32 = cast(data, DataType::Float(32)); + const auto& weight_fp32 = cast(weight, DataType::Float(32)); + + auto square = multiply(data_fp32, data_fp32); auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true); - auto ndim = data->shape.size(); + auto ndim = data_fp32->shape.size(); ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); - auto reduce_extent = make_const(data->dtype, 1); + auto reduce_extent = make_const(data_fp32->dtype, 1); for (int i : real_axis) { - reduce_extent *= data->shape[i]; + reduce_extent *= data_fp32->shape[i]; } auto rms_norm_func = [&](const Array& indices) { Array reduce_indices, non_reduce_indices; @@ -74,12 +77,12 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Arrayshape, rms_norm_func, name, tag); - return rms_norm; + auto rms_norm = tvm::te::compute(data_fp32->shape, rms_norm_func, name, tag); + return cast(rms_norm, data_type); } } // namespace nn diff --git a/python/tvm/topi/testing/rms_norm_python.py b/python/tvm/topi/testing/rms_norm_python.py index 7fad5d57ce10..651f6f884309 100644 --- a/python/tvm/topi/testing/rms_norm_python.py +++ b/python/tvm/topi/testing/rms_norm_python.py @@ -19,7 +19,7 @@ import numpy as np -def rms_norm_python(data, weight, bias, axis, epsilon=1e-5): +def rms_norm_python(data, weight, axis, epsilon=1e-5): """Root mean square normalization operator in Python. Parameters @@ -44,8 +44,9 @@ def rms_norm_python(data, weight, bias, axis, epsilon=1e-5): result : np.ndarray N-D with shape (d_0, d_1, ..., d_{N-1}) """ + dtype = data.dtype + data = data.astype("float32") + weight = weight.astype("float32") square_mean = np.mean(np.square(data), axis, keepdims=True) result = data * weight / np.sqrt(square_mean + epsilon) - if bias is not None: - result += bias - return result + return result.astype(dtype) diff --git a/tests/python/topi/python/test_topi_rms_norm.py b/tests/python/topi/python/test_topi_rms_norm.py index 35a1485afa6b..c8c1b8795f2d 100644 --- a/tests/python/topi/python/test_topi_rms_norm.py +++ b/tests/python/topi/python/test_topi_rms_norm.py @@ -34,7 +34,8 @@ # only test on llvm because schedule is missing @tvm.testing.parametrize_targets("llvm") @pytest.mark.parametrize( - "shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,))] + "shape,axis", + [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,)), ([2, 8192], (1,))], ) @pytest.mark.parametrize("dtype", ["float32", "float16"]) def test_rms_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-3, atol=1e-4): @@ -42,25 +43,22 @@ def test_rms_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-3, ato scale_shape_te = [shape_te[dim] for dim in axis] data = te.placeholder(shape_te, dtype=dtype, name="data") weight = te.placeholder(scale_shape_te, dtype=dtype, name="weight") - bias = te.placeholder(scale_shape_te, dtype=dtype, name="weight") - B = topi.nn.rms_norm(data, weight, bias, axis, episilon) + B = topi.nn.rms_norm(data, weight, axis, episilon) shape_np = [v[1] if isinstance(v, tuple) else v for v in shape] scale_shape_np = [shape_np[dim] for dim in axis] data_np = np.random.uniform(size=shape_np).astype(dtype) weight_np = np.random.uniform(size=scale_shape_np).astype(dtype) - bias_np = np.random.uniform(size=scale_shape_np).astype(dtype) - b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis, episilon) + b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, axis, episilon) with tvm.target.Target(target): s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule) s = s_func([B]) data_tvm = tvm.nd.array(data_np, dev) weight_tvm = tvm.nd.array(weight_np, dev) - bias_tvm = tvm.nd.array(bias_np, dev) b_tvm = tvm.nd.array(np.zeros(shape_np, dtype=dtype), dev) - f = tvm.build(s, [data, weight, bias, B], target) - f(data_tvm, weight_tvm, bias_tvm, b_tvm) + f = tvm.build(s, [data, weight, B], target) + f(data_tvm, weight_tvm, b_tvm) tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)