From ea8b6fe90c97bbd9b28fad583a2ff0fc3d0ff091 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 27 Oct 2024 14:45:52 -0400 Subject: [PATCH] Faster metal rms norm --- candle-metal-kernels/src/lib.rs | 122 +++++-- candle-metal-kernels/src/reduce.metal | 71 ---- candle-metal-kernels/src/rms_norm.metal | 443 ++++++++++++++++++++++++ candle-nn/src/ops.rs | 26 +- candle-nn/tests/layer_norm.rs | 24 ++ 5 files changed, 571 insertions(+), 115 deletions(-) create mode 100644 candle-metal-kernels/src/rms_norm.metal diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 1f53e710d7..27361fdc24 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -26,6 +26,7 @@ const SORT: &str = include_str!("sort.metal"); const TERNARY: &str = include_str!("ternary.metal"); const UNARY: &str = include_str!("unary.metal"); const SDPA: &str = include_str!("scaled_dot_product_attention.metal"); +const RMSNORM: &str = include_str!("rms_norm.metal"); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { @@ -44,6 +45,7 @@ pub enum Source { Ternary, Unary, Sdpa, + RmsNorm, } pub mod copy2d { @@ -235,6 +237,7 @@ impl Kernels { Source::Ternary => TERNARY, Source::Unary => UNARY, Source::Sdpa => SDPA, + Source::RmsNorm => RMSNORM, Source::Mfa => panic!("Invalid lib"), } } @@ -721,62 +724,113 @@ pub fn call_last_softmax( Ok(()) } +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +pub enum RmsNormDType { + BF16, + F16, + F32, +} + #[allow(clippy::too_many_arguments)] pub fn call_rms_norm( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, - kernel_name: &'static str, - length: usize, - elements_to_sum: usize, + ty: RmsNormDType, eps: f32, + input_shape: &[usize], input: &Buffer, input_offset: usize, + alpha_stride: &[usize], alpha: &Buffer, alpha_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let axis_size = input_shape[input_shape.len() - 1] as u32; + let n_rows = input_shape.iter().product::() as u32 / axis_size; + + let simd_size = 32; + let rms_n_reads = 4; + let rms_looped_limit = 4096; + + let name = match (ty, axis_size > rms_looped_limit) { + (RmsNormDType::F16, false) => "rmsfloat16", + (RmsNormDType::F16, true) => "rms_loopedfloat16", + (RmsNormDType::BF16, false) => "rmsbfloat16", + (RmsNormDType::BF16, true) => "rms_loopedbfloat16", + (RmsNormDType::F32, false) => "rmsfloat32", + (RmsNormDType::F32, true) => "rms_loopedfloat32", + }; + + let pipeline = kernels.load_pipeline(device, Source::RmsNorm, name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); - set_params!( - encoder, - ( - length, - elements_to_sum, - (input, input_offset), - output, - (alpha, alpha_offset), - eps - ) - ); - - let out_length = length / elements_to_sum; - - let thread_group_count = MTLSize { - width: out_length as u64, - height: 1, - depth: 1, + let max_total_threads_per_threadgroup = pipeline.max_total_threads_per_threadgroup(); + + let (grid_dims, group_dims) = if axis_size <= rms_looped_limit { + let threadgroup_needed = (axis_size + rms_n_reads - 1) / rms_n_reads; + let simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; + let threadgroup_size = simd_size * simds_needed; + assert!(threadgroup_size <= max_total_threads_per_threadgroup as u32); + let n_threads = n_rows * threadgroup_size; + let grid_dims = MTLSize { + width: n_threads as u64, + height: 1, + depth: 1, + }; + let group_dims = MTLSize { + width: threadgroup_size as u64, + height: 1, + depth: 1, + }; + (grid_dims, group_dims) + } else { + let n_threads = n_rows * max_total_threads_per_threadgroup as u32; + let grid_dims = MTLSize { + width: n_threads as u64, + height: 1, + depth: 1, + }; + let group_dims = MTLSize { + width: max_total_threads_per_threadgroup, + height: 1, + depth: 1, + }; + (grid_dims, group_dims) }; - let width = std::cmp::min( - pipeline.max_total_threads_per_threadgroup(), - elements_to_sum as u64, - ) - .next_power_of_two(); + let w_stride = alpha_stride[0] as u32; - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; + encoder.set_buffer(0, Some(&input), input_offset as NSUInteger); + encoder.set_buffer(1, Some(&alpha), alpha_offset as NSUInteger); + encoder.set_buffer(2, Some(&output), 0); encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(alpha, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); - encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + + encoder.set_bytes( + 3, + std::mem::size_of::() as NSUInteger, + &eps as *const f32 as *const c_void, + ); + encoder.set_bytes( + 4, + std::mem::size_of::() as NSUInteger, + &axis_size as *const u32 as *const c_void, + ); + encoder.set_bytes( + 5, + std::mem::size_of::() as NSUInteger, + &w_stride as *const u32 as *const c_void, + ); + + // minimum of 16 bytes + encoder.set_threadgroup_memory_length(0, 16 * 8 as u64); + encoder.set_threadgroup_memory_length(1, simd_size as u64 * std::mem::size_of::() as u64); + encoder.dispatch_thread_groups(grid_dims, group_dims); Ok(()) } diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 56ef56f7e0..f73b6c98f4 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -303,56 +303,6 @@ kernel void NAME( \ softmax(src_numel, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory); \ } \ -template -METAL_FUNC void rmsnorm( - constant size_t & src_numel, - constant size_t & el_to_sum_per_block, - device const T * src, - device T * dst, - device const T * alpha, - constant float & eps, - uint id, - uint tid, - uint dst_id, - uint block_dim, - threadgroup float * shared_memory -) { - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); - size_t idx = start_idx + tid; - - float tmp = 0; - while (idx < stop_idx) { - tmp = tmp + float(src[idx]) * float(src[idx]); - idx += block_dim; - } - shared_memory[tid] = tmp; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (uint s = block_dim / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - /* wait for shared_memory[0] to be filled */ - threadgroup_barrier(mem_flags::mem_threadgroup); - - float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps); - float inv_norm = 1.0f / norm; - idx = start_idx + tid; - while (idx < stop_idx) { - float val = float(src[idx]) * inv_norm; - if (alpha != nullptr) { - val *= float(alpha[idx - start_idx]); - } - dst[idx] = T(val); - idx += block_dim; - } -} - template METAL_FUNC void layernorm( constant size_t & src_numel, @@ -412,24 +362,6 @@ METAL_FUNC void layernorm( } } -#define RMSNORM(NAME, T) \ -kernel void NAME( \ - constant size_t &src_numel, \ - constant size_t &el_to_sum_per_block, \ - device const T *src, \ - device T *dst, \ - device const T *alpha, \ - constant float &eps, \ - uint id [[ thread_position_in_grid ]], \ - uint tid [[ thread_index_in_threadgroup ]], \ - uint dst_id [[ threadgroup_position_in_grid ]], \ - uint block_dim [[ threads_per_threadgroup ]] \ -) { \ - threadgroup float shared_memory[THREADGROUP_SIZE]; \ - shared_memory[tid] = 0; \ - rmsnorm(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \ -} \ - #define LAYERNORM(NAME, T) \ kernel void NAME( \ constant size_t &src_numel, \ @@ -587,8 +519,6 @@ ARGMAX(fast_argmax_u8_strided, uint8_t, 0) SOFTMAX(softmax_f32, float) SOFTMAX(softmax_f16, half) -RMSNORM(rmsnorm_f32, float) -RMSNORM(rmsnorm_f16, half) LAYERNORM(layernorm_f32, float) LAYERNORM(layernorm_f16, half) ROPE(rope_f32, rope_i_f32, rope_thd_f32, float) @@ -626,7 +556,6 @@ REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF) ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF) ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF) SOFTMAX(softmax_bf16, bfloat) -RMSNORM(rmsnorm_bf16, bfloat) LAYERNORM(layernorm_bf16, bfloat) ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat) #endif diff --git a/candle-metal-kernels/src/rms_norm.metal b/candle-metal-kernels/src/rms_norm.metal new file mode 100644 index 0000000000..776190bd4a --- /dev/null +++ b/candle-metal-kernels/src/rms_norm.metal @@ -0,0 +1,443 @@ +// Copyright © 2024 Apple Inc. + +#include +#include + +static constant constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; +static constant constexpr int REDUCE_N_READS = 4; +static constant constexpr int REDUCE_N_WRITES = 4; +static constant constexpr int SOFTMAX_N_READS = 4; +static constant constexpr int RMS_N_READS = 4; +static constant constexpr int RMS_LOOPED_LIMIT = 4096; + +using namespace metal; + +template +[[kernel]] void rms_single_row( + const device T* x, + const device T* w, + device T* out, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + threadgroup float* local_inv_mean [[threadgroup(0)]], + threadgroup float* local_sums [[threadgroup(1)]], + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + float acc = 0; + x += gid * size_t(axis_size) + lid * N_READS; + w += w_stride * lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = x[i]; + acc += xi * xi; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + float xi = x[i]; + acc += xi * xi; + } + } + } + acc = simd_sum(acc); + // Initialize shared memory + if (simd_group_id == 0) { + local_sums[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write simd accumulations into shared memory + if (simd_lane_id == 0) { + local_sums[simd_group_id] = acc; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Accumulate over simd groups + if (simd_group_id == 0) { + acc = simd_sum(local_sums[simd_lane_id]); + if (simd_lane_id == 0) { + local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write the outputs + out += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[i] = w[w_stride * i] * static_cast(x[i] * local_inv_mean[0]); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + out[i] = w[w_stride * i] * static_cast(x[i] * local_inv_mean[0]); + } + } + } +} + +template +[[kernel]] void rms_looped( + const device T* x, + const device T* w, + device T* out, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + threadgroup float* local_inv_mean [[threadgroup(0)]], + threadgroup float* local_sums [[threadgroup(1)]], + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + float acc = 0; + x += gid * size_t(axis_size) + lid * N_READS; + w += w_stride * lid * N_READS; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = x[i + r]; + acc += xi * xi; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float xi = x[i + r]; + acc += xi * xi; + } + } + } + } + acc = simd_sum(acc); + // Initialize shared memory + if (simd_group_id == 0) { + local_sums[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write simd accumulations into shared memory + if (simd_lane_id == 0) { + local_sums[simd_group_id] = acc; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Accumulate over simd groups + if (simd_group_id == 0) { + acc = simd_sum(local_sums[simd_lane_id]); + if (simd_lane_id == 0) { + local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write the outputs + out += gid * size_t(axis_size) + lid * N_READS; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[r + i] = w[w_stride * (i + r)] * + static_cast(x[r + i] * local_inv_mean[0]); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + out[r + i] = w[w_stride * (i + r)] * + static_cast(x[r + i] * local_inv_mean[0]); + } + } + } + } +} + +template +[[kernel]] void vjp_rms_single_row( + const device T* x, + const device T* w, + const device T* g, + device T* gx, + device T* gw, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + // Advance the input pointers + x += gid * size_t(axis_size) + lid * N_READS; + g += gid * size_t(axis_size) + lid * N_READS; + w += w_stride * lid * N_READS; + + // Allocate registers for the computation and accumulators + float thread_x[N_READS]; + float thread_w[N_READS]; + float thread_g[N_READS]; + float sumx2 = 0; + float sumgwx = 0; + + // Allocate shared memory to implement the reduction + constexpr int SIMD_SIZE = 32; + threadgroup float local_sumx2[SIMD_SIZE]; + threadgroup float local_sumgwx[SIMD_SIZE]; + threadgroup float local_normalizer[1]; + threadgroup float local_meangwx[1]; + + // Read and accumulate locally + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + thread_x[i] = x[i]; + thread_w[i] = w[w_stride * i]; + thread_g[i] = g[i]; + + sumx2 += thread_x[i] * thread_x[i]; + sumgwx += thread_x[i] * thread_w[i] * thread_g[i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + thread_x[i] = x[i]; + thread_w[i] = w[w_stride * i]; + thread_g[i] = g[i]; + + sumx2 += thread_x[i] * thread_x[i]; + sumgwx += thread_x[i] * thread_w[i] * thread_g[i]; + } + } + } + + // Accumulate across threads + sumx2 = simd_sum(sumx2); + sumgwx = simd_sum(sumgwx); + if (simd_group_id == 0) { + local_sumx2[simd_lane_id] = 0; + local_sumgwx[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_lane_id == 0) { + local_sumx2[simd_group_id] = sumx2; + local_sumgwx[simd_group_id] = sumgwx; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + sumx2 = simd_sum(local_sumx2[simd_lane_id]); + sumgwx = simd_sum(local_sumgwx[simd_lane_id]); + if (simd_lane_id == 0) { + local_meangwx[0] = sumgwx / axis_size; + local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float meangwx = local_meangwx[0]; + float normalizer = local_normalizer[0]; + float normalizer3 = normalizer * normalizer * normalizer; + + // Write the outputs + gx += gid * size_t(axis_size) + lid * N_READS; + gw += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + gx[i] = static_cast( + thread_g[i] * thread_w[i] * normalizer - + thread_x[i] * meangwx * normalizer3); + gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + gx[i] = static_cast( + thread_g[i] * thread_w[i] * normalizer - + thread_x[i] * meangwx * normalizer3); + gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); + } + } + } +} + +template +[[kernel]] void vjp_rms_looped( + const device T* x, + const device T* w, + const device T* g, + device T* gx, + device T* gw, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + // Advance the input pointers + x += gid * size_t(axis_size) + lid * N_READS; + g += gid * size_t(axis_size) + lid * N_READS; + w += w_stride * lid * N_READS; + + // Allocate registers for the accumulators + float sumx2 = 0; + float sumgwx = 0; + + // Allocate shared memory to implement the reduction + constexpr int SIMD_SIZE = 32; + threadgroup float local_sumx2[SIMD_SIZE]; + threadgroup float local_sumgwx[SIMD_SIZE]; + threadgroup float local_normalizer[1]; + threadgroup float local_meangwx[1]; + + // Read and accumulate locally + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = x[i + r]; + float wi = w[w_stride * (i + r)]; + float gi = g[i + r]; + + sumx2 += xi * xi; + sumgwx += xi * wi * gi; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float xi = x[i + r]; + float wi = w[w_stride * (i + r)]; + float gi = g[i + r]; + + sumx2 += xi * xi; + sumgwx += xi * wi * gi; + } + } + } + } + + // Accumulate across threads + sumx2 = simd_sum(sumx2); + sumgwx = simd_sum(sumgwx); + if (simd_group_id == 0) { + local_sumx2[simd_lane_id] = 0; + local_sumgwx[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_lane_id == 0) { + local_sumx2[simd_group_id] = sumx2; + local_sumgwx[simd_group_id] = sumgwx; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + sumx2 = simd_sum(local_sumx2[simd_lane_id]); + sumgwx = simd_sum(local_sumgwx[simd_lane_id]); + if (simd_lane_id == 0) { + local_meangwx[0] = sumgwx / axis_size; + local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float meangwx = local_meangwx[0]; + float normalizer = local_normalizer[0]; + float normalizer3 = normalizer * normalizer * normalizer; + + // Write the outputs + gx += gid * size_t(axis_size) + lid * N_READS; + gw += gid * size_t(axis_size) + lid * N_READS; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = x[i + r]; + float wi = w[w_stride * (i + r)]; + float gi = g[i + r]; + + gx[i + r] = + static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); + gw[i + r] = static_cast(gi * xi * normalizer); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float xi = x[i + r]; + float wi = w[w_stride * (i + r)]; + float gi = g[i + r]; + + gx[i + r] = + static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); + gw[i + r] = static_cast(gi * xi * normalizer); + } + } + } + } +} + +// clang-format off +#define instantiate_rms_single_row(name, itype) \ + template [[host_name("rms" #name)]] [[kernel]] void \ + rms_single_row( \ + const device itype* x, \ + const device itype* w, \ + device itype* out, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + threadgroup float* local_inv_mean [[threadgroup(0)]], \ + threadgroup float* local_sums [[threadgroup(1)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ + \ + template [[host_name("vjp_rms" #name)]] [[kernel]] void \ + vjp_rms_single_row( \ + const device itype* x, \ + const device itype* w, \ + const device itype* g, \ + device itype* gx, \ + device itype* gw, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_rms_looped(name, itype) \ + template [[host_name("rms_looped" #name)]] [[kernel]] void \ + rms_looped( \ + const device itype* x, \ + const device itype* w, \ + device itype* out, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + threadgroup float* local_inv_mean [[threadgroup(0)]], \ + threadgroup float* local_sums [[threadgroup(1)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ + \ + template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \ + vjp_rms_looped( \ + const device itype* x, \ + const device itype* w, \ + const device itype* g, \ + device itype* gx, \ + device itype* gw, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_rms(name, itype) \ + instantiate_rms_single_row(name, itype) \ + instantiate_rms_looped(name, itype) + +instantiate_rms(float32, float) +instantiate_rms(float16, half) +instantiate_rms(bfloat16, bfloat) // clang-format on diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index b37c778c88..da6f127c3a 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -577,13 +577,13 @@ impl candle::CustomOp2 for RmsNorm { l2: &Layout, ) -> Result<(candle::MetalStorage, Shape)> { use candle::backend::BackendStorage; + use candle_metal_kernels::RmsNormDType; + let device = s1.device(); let command_buffer = device.command_buffer()?; let kernels = device.kernels(); - let name = match (s1.dtype(), s2.dtype()) { - (DType::F32, DType::F32) => "rmsnorm_f32", - (DType::F16, DType::F16) => "rmsnorm_f16", - (DType::BF16, DType::BF16) => "rmsnorm_bf16", + match (s1.dtype(), s2.dtype()) { + (DType::F32, DType::F32) | (DType::F16, DType::F16) | (DType::BF16, DType::BF16) => (), (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"), }; @@ -591,21 +591,27 @@ impl candle::CustomOp2 for RmsNorm { candle::bail!("Non contiguous rmsnorm is not implemented"); } - let last_dim = l1.dims()[l1.shape().rank() - 1]; + let ty: RmsNormDType = match s1.dtype() { + DType::BF16 => RmsNormDType::BF16, + DType::F16 => RmsNormDType::F16, + DType::F32 => RmsNormDType::F32, + _ => unreachable!(), + }; + let elem_count = l1.shape().elem_count(); let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?; candle_metal_kernels::call_rms_norm( device.metal_device(), &command_buffer, kernels, - name, - elem_count, - last_dim, + ty, self.eps, + l1.dims(), s1.buffer(), - l1.start_offset() * s1.dtype().size_in_bytes(), + l1.start_offset(), + l2.stride(), s2.buffer(), - l2.start_offset() * s2.dtype().size_in_bytes(), + l2.start_offset(), &output, ) .map_err(candle::Error::wrap)?; diff --git a/candle-nn/tests/layer_norm.rs b/candle-nn/tests/layer_norm.rs index 30f598b329..8186fc82a9 100644 --- a/candle-nn/tests/layer_norm.rs +++ b/candle-nn/tests/layer_norm.rs @@ -53,3 +53,27 @@ fn layer_norm() -> Result<()> { ); Ok(()) } + +#[test] +fn rms_norm() -> Result<()> { + #[cfg(not(feature = "metal"))] + let dev = Device::cuda_if_available(0)?; + #[cfg(feature = "metal")] + let dev = Device::new_metal(0)?; + + const DIM: usize = 4096; + + let data = Tensor::randn(0f32, 1f32, (32, DIM), &dev)?; + let w = Tensor::randn(0f32, 1f32, (DIM,), &dev)?; + + let fused = candle_nn::ops::rms_norm(&data, &w, 1e-6)?.clone(); + let truth = candle_nn::ops::rms_norm_slow(&data, &w, 1e-6)?; + + let error: f32 = ((&truth - &fused)?.abs()? / &truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.008, "{}", error); + + Ok(()) +}