[Fix] Fix topi.rms_norm with float32 upscale
cyx-6 committed Nov 8, 2023
1 parent 9100a8e commit 49eae2d
Showing 3 changed files with 21 additions and 19 deletions.
17 changes: 10 additions & 7 deletions include/tvm/topi/nn/rms_norm.h
```diff
@@ -54,15 +54,18 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Int
   const auto& weight_type = weight.defined() ? weight->dtype : data_type;
   ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
 
-  auto square = multiply(data, data);
+  const auto& data_fp32 = cast(data, DataType::Float(32));
+  const auto& weight_fp32 = cast(weight, DataType::Float(32));
+
+  auto square = multiply(data_fp32, data_fp32);
   auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);
 
-  auto ndim = data->shape.size();
+  auto ndim = data_fp32->shape.size();
   ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
   auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
-  auto reduce_extent = make_const(data->dtype, 1);
+  auto reduce_extent = make_const(data_fp32->dtype, 1);
   for (int i : real_axis) {
-    reduce_extent *= data->shape[i];
+    reduce_extent *= data_fp32->shape[i];
   }
   auto rms_norm_func = [&](const Array<Var>& indices) {
     Array<Var> reduce_indices, non_reduce_indices;
@@ -74,12 +77,12 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Int
       }
     }
     auto output =
-        data(indices) * weight(reduce_indices) *
+        data_fp32(indices) * weight_fp32(reduce_indices) *
         tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
     return output;
   };
-  auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
-  return rms_norm;
+  auto rms_norm = tvm::te::compute(data_fp32->shape, rms_norm_func, name, tag);
+  return cast(rms_norm, data_type);
 }
 
 }  // namespace nn
```
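The fix upcasts data and weight to float32, performs the square, sum, and rsqrt in float32, and only casts back to the input dtype at the end. The motivation is accumulation error in float16: once a running sum of squares passes 2048, one float16 ulp is 2.0, so sub-unit terms are largely rounded away. A self-contained NumPy sketch of that failure mode (illustrative only; exact numbers depend on the random draw, and TVM's generated reduction order will differ from this loop):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(size=8192).astype(np.float16)
sq = np.square(x)  # squares stay float16

# Naive sequential float16 accumulation, as a pure-float16 kernel might
# do: past 2048 the float16 spacing is 2.0, so most terms below 1
# contribute little or nothing to the running sum.
acc = np.float16(0.0)
for v in sq:
    acc += v

print(float(acc))                            # drifts well below the true sum
print(float(np.sum(sq.astype(np.float32))))  # float32 reference, ~8192/3
```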
9 changes: 5 additions & 4 deletions python/tvm/topi/testing/rms_norm_python.py
```diff
@@ -19,7 +19,7 @@
 import numpy as np
 
 
-def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
+def rms_norm_python(data, weight, axis, epsilon=1e-5):
     """Root mean square normalization operator in Python.
 
     Parameters
@@ -44,8 +44,9 @@ def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
     result : np.ndarray
         N-D with shape (d_0, d_1, ..., d_{N-1})
     """
+    dtype = data.dtype
+    data = data.astype("float32")
+    weight = weight.astype("float32")
     square_mean = np.mean(np.square(data), axis, keepdims=True)
     result = data * weight / np.sqrt(square_mean + epsilon)
-    if bias is not None:
-        result += bias
-    return result
+    return result.astype(dtype)
```
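The bias path, which the C++ signature of rms_norm never had, is dropped from the reference, and the reference now mirrors the kernel: compute in float32, return in the caller's dtype. A quick usage sketch, assuming a local TVM build is importable:

```python
import numpy as np
from tvm.topi.testing import rms_norm_python

data = np.random.uniform(size=(4, 16)).astype("float16")
weight = np.random.uniform(size=(16,)).astype("float16")

out = rms_norm_python(data, weight, axis=(1,))
# Math runs in float32 internally; the result comes back as float16.
assert out.dtype == np.float16
```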
14 changes: 6 additions & 8 deletions tests/python/topi/python/test_topi_rms_norm.py
```diff
@@ -34,33 +34,31 @@
 # only test on llvm because schedule is missing
 @tvm.testing.parametrize_targets("llvm")
 @pytest.mark.parametrize(
-    "shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,))]
+    "shape,axis",
+    [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,)), ([2, 8192], (1,))],
 )
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
 def test_rms_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-3, atol=1e-4):
     shape_te = [te.var(v[0]) if isinstance(v, tuple) else v for v in shape]
     scale_shape_te = [shape_te[dim] for dim in axis]
     data = te.placeholder(shape_te, dtype=dtype, name="data")
     weight = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
-    bias = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
-    B = topi.nn.rms_norm(data, weight, bias, axis, episilon)
+    B = topi.nn.rms_norm(data, weight, axis, episilon)
 
     shape_np = [v[1] if isinstance(v, tuple) else v for v in shape]
     scale_shape_np = [shape_np[dim] for dim in axis]
     data_np = np.random.uniform(size=shape_np).astype(dtype)
     weight_np = np.random.uniform(size=scale_shape_np).astype(dtype)
-    bias_np = np.random.uniform(size=scale_shape_np).astype(dtype)
-    b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis, episilon)
+    b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, axis, episilon)
 
     with tvm.target.Target(target):
         s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule)
         s = s_func([B])
     data_tvm = tvm.nd.array(data_np, dev)
     weight_tvm = tvm.nd.array(weight_np, dev)
-    bias_tvm = tvm.nd.array(bias_np, dev)
     b_tvm = tvm.nd.array(np.zeros(shape_np, dtype=dtype), dev)
-    f = tvm.build(s, [data, weight, bias, B], target)
-    f(data_tvm, weight_tvm, bias_tvm, b_tvm)
+    f = tvm.build(s, [data, weight, B], target)
+    f(data_tvm, weight_tvm, b_tvm)
     tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)
```
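The new [2, 8192] case exercises a large reduction axis where the float16 accumulation loss is most visible. The tolerances (rtol=5e-3, atol=1e-4) are sized for the float16 cases: inputs and outputs are rounded to float16 even though the math now runs in float32. A quick check of the dtype's resolution, plain NumPy and nothing TVM-specific, shows why a much tighter bound would be unrealistic:

```python
import numpy as np

# float16 keeps a 10-bit mantissa: one ulp of relative error is ~1e-3,
# so a few ulps of round-trip drift justify rtol=5e-3.
print(np.finfo(np.float16).eps)  # 0.000977
print(np.finfo(np.float16).max)  # 65500.0
```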
