Skip to content

Commit

Permalink
Faster metal rms norm
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Oct 27, 2024
1 parent 522531d commit ea8b6fe
Show file tree
Hide file tree
Showing 5 changed files with 571 additions and 115 deletions.
122 changes: 88 additions & 34 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
const SDPA: &str = include_str!("scaled_dot_product_attention.metal");
const RMSNORM: &str = include_str!("rms_norm.metal");

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Expand All @@ -44,6 +45,7 @@ pub enum Source {
Ternary,
Unary,
Sdpa,
RmsNorm,
}

pub mod copy2d {
Expand Down Expand Up @@ -235,6 +237,7 @@ impl Kernels {
Source::Ternary => TERNARY,
Source::Unary => UNARY,
Source::Sdpa => SDPA,
Source::RmsNorm => RMSNORM,
Source::Mfa => panic!("Invalid lib"),
}
}
Expand Down Expand Up @@ -721,62 +724,113 @@ pub fn call_last_softmax(
Ok(())
}

/// Element type selector for [`call_rms_norm`].
///
/// Together with the size of the last input dimension, the variant decides
/// which kernel name is requested from the `RmsNorm` source.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RmsNormDType {
    /// Maps to the `rmsbfloat16` / `rms_loopedbfloat16` kernels.
    BF16,
    /// Maps to the `rmsfloat16` / `rms_loopedfloat16` kernels.
    F16,
    /// Maps to the `rmsfloat32` / `rms_loopedfloat32` kernels.
    F32,
}

/// Encodes an RMS-norm kernel dispatch onto the compute encoder obtained from `ep`.
///
/// NOTE(review): this listing is a rendered diff without +/- markers. Lines of an older
/// implementation (pipeline loaded from `Source::Reduce` by `kernel_name`, `set_params!`,
/// `thread_group_count` / `thread_group_size`) are interleaved with lines of a newer one
/// (kernel name derived from `ty`, `grid_dims` / `group_dims`, explicit `set_buffer` /
/// `set_bytes`). The function does not compile as shown; the inline comments mark which
/// version each span appears to belong to.
///
/// # Errors
/// Propagates the error returned by `Kernels::load_pipeline`.
///
/// # Panics
/// - if `input_shape` or `alpha_stride` is empty (indexing),
/// - if the last entry of `input_shape` is 0 (integer division by zero),
/// - if the computed threadgroup size exceeds the pipeline's
///   `max_total_threads_per_threadgroup()` (the `assert!` in the non-looped branch).
#[allow(clippy::too_many_arguments)]
pub fn call_rms_norm(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
// `kernel_name`, `length` and `elements_to_sum` are only read by the older-path lines below.
kernel_name: &'static str,
length: usize,
elements_to_sum: usize,
// Picks the kernel family by element type (see the `match` below).
ty: RmsNormDType,
eps: f32,
// Only the last entry and the product of all entries are read.
input_shape: &[usize],
input: &Buffer,
input_offset: usize,
// Only `alpha_stride[0]` is read.
alpha_stride: &[usize],
alpha: &Buffer,
alpha_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
// Older path (appears to be removed by this commit): caller-named kernel from the Reduce source.
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
// `axis_size` is the extent of the last dimension; `n_rows` is the total element count
// divided by it. `as` binds tighter than `/`, so the product is truncated to u32 before
// the division.
let axis_size = input_shape[input_shape.len() - 1] as u32;
let n_rows = input_shape.iter().product::<usize>() as u32 / axis_size;

// Tuning constants. `simd_size` is presumably the SIMD-group width the kernels assume and
// `rms_n_reads` the number of elements each thread handles -- TODO confirm against
// rms_norm.metal. Rows longer than `rms_looped_limit` use the `rms_looped*` kernels.
let simd_size = 32;
let rms_n_reads = 4;
let rms_looped_limit = 4096;

// Kernel name = element type x (looped or not).
let name = match (ty, axis_size > rms_looped_limit) {
(RmsNormDType::F16, false) => "rmsfloat16",
(RmsNormDType::F16, true) => "rms_loopedfloat16",
(RmsNormDType::BF16, false) => "rmsbfloat16",
(RmsNormDType::BF16, true) => "rms_loopedbfloat16",
(RmsNormDType::F32, false) => "rmsfloat32",
(RmsNormDType::F32, true) => "rms_loopedfloat32",
};

// Newer path: this binding shadows the `Source::Reduce` pipeline loaded above.
let pipeline = kernels.load_pipeline(device, Source::RmsNorm, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);

// Older path: positional kernel arguments set through the `set_params!` macro.
set_params!(
encoder,
(
length,
elements_to_sum,
(input, input_offset),
output,
(alpha, alpha_offset),
eps
)
);

// Older path: one threadgroup per output row.
let out_length = length / elements_to_sum;

// NOTE(review): in this listing the struct literal below has no closing `};` of its own;
// the next statement starts inside it.
let thread_group_count = MTLSize {
width: out_length as u64,
height: 1,
depth: 1,
let max_total_threads_per_threadgroup = pipeline.max_total_threads_per_threadgroup();

// Newer path: choose launch dimensions.
let (grid_dims, group_dims) = if axis_size <= rms_looped_limit {
// Non-looped: ceil(axis_size / rms_n_reads) threads, rounded up to a whole number of
// `simd_size`-wide groups.
let threadgroup_needed = (axis_size + rms_n_reads - 1) / rms_n_reads;
let simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
let threadgroup_size = simd_size * simds_needed;
assert!(threadgroup_size <= max_total_threads_per_threadgroup as u32);
// Total thread count across all rows (u32 arithmetic).
let n_threads = n_rows * threadgroup_size;
let grid_dims = MTLSize {
width: n_threads as u64,
height: 1,
depth: 1,
};
let group_dims = MTLSize {
width: threadgroup_size as u64,
height: 1,
depth: 1,
};
(grid_dims, group_dims)
} else {
// Looped: every threadgroup uses the pipeline's maximum thread count.
let n_threads = n_rows * max_total_threads_per_threadgroup as u32;
let grid_dims = MTLSize {
width: n_threads as u64,
height: 1,
depth: 1,
};
let group_dims = MTLSize {
width: max_total_threads_per_threadgroup,
height: 1,
depth: 1,
};
(grid_dims, group_dims)
};

// Older path: threads per group = min(pipeline max, elements_to_sum), rounded up to a
// power of two.
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
elements_to_sum as u64,
)
.next_power_of_two();
// Newer path: first stride of `alpha`, passed to the kernel as a u32 (bytes slot 5 below).
let w_stride = alpha_stride[0] as u32;

// Older path.
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
// Newer path: buffer slots 0 = input, 1 = alpha, 2 = output (output bound at offset 0).
encoder.set_buffer(0, Some(&input), input_offset as NSUInteger);
encoder.set_buffer(1, Some(&alpha), alpha_offset as NSUInteger);
encoder.set_buffer(2, Some(&output), 0);

encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(alpha, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
// Older path: 4 bytes per thread (at least 16 bytes) of threadgroup memory, then dispatch
// with a threadgroup count.
encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);

// Newer path: scalar arguments in slots 3 = eps (f32), 4 = axis_size (u32),
// 5 = w_stride (u32). Each pointer refers to a local that outlives the call.
encoder.set_bytes(
3,
std::mem::size_of::<f32>() as NSUInteger,
&eps as *const f32 as *const c_void,
);
encoder.set_bytes(
4,
std::mem::size_of::<u32>() as NSUInteger,
&axis_size as *const u32 as *const c_void,
);
encoder.set_bytes(
5,
std::mem::size_of::<u32>() as NSUInteger,
&w_stride as *const u32 as *const c_void,
);

// minimum of 16 bytes
// `as` binds tighter than `*`: this is 16 * (8 as u64) = 128 bytes for slot 0.
encoder.set_threadgroup_memory_length(0, 16 * 8 as u64);
// Slot 1: one f32 per `simd_size` lane.
encoder.set_threadgroup_memory_length(1, simd_size as u64 * std::mem::size_of::<f32>() as u64);
// NOTE(review): `grid_dims.width` is `n_rows * threads-per-group`, i.e. a total thread
// count, while the older call above passed a threadgroup count (one per output row) to
// this same method. If the first argument is a threadgroup count, this launches
// threads-per-group times more groups than there are rows -- confirm whether
// `dispatch_threads(grid_dims, group_dims)` was intended.
encoder.dispatch_thread_groups(grid_dims, group_dims);
Ok(())
}

Expand Down
71 changes: 0 additions & 71 deletions candle-metal-kernels/src/reduce.metal
Original file line number Diff line number Diff line change
Expand Up @@ -303,56 +303,6 @@ kernel void NAME( \
softmax<T>(src_numel, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory); \
} \

// RMS-normalises one block of `src` into `dst`, cooperatively across one threadgroup.
//
// For the block [start_idx, stop_idx) owned by threadgroup `dst_id`:
//   dst[i] = src[i] / sqrt(sum(src[j]^2) / el_to_sum_per_block + eps) * alpha[i - start_idx]
// where the `alpha` factor is skipped when `alpha` is null. All arithmetic is done in
// float and converted back to T on store.
//
// Parameters:
//   src_numel            total number of elements in `src`; clamps the last block.
//   el_to_sum_per_block  block length; also the divisor of the mean, even when the block
//                        is clamped by `src_numel`.
//   alpha                optional per-position scale, indexed relative to the block start.
//   id                   not used in this function.
//   tid / dst_id / block_dim
//                        the RMSNORM macro in this file passes thread_index_in_threadgroup,
//                        threadgroup_position_in_grid and threads_per_threadgroup.
//   shared_memory        threadgroup scratch array; indices [0, block_dim) are used.
template<typename T>
METAL_FUNC void rmsnorm(
constant size_t & src_numel,
constant size_t & el_to_sum_per_block,
device const T * src,
device T * dst,
device const T * alpha,
constant float & eps,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup float * shared_memory
) {
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;

// Each thread accumulates the squares of every block_dim-th element of the block.
float tmp = 0;
while (idx < stop_idx) {
tmp = tmp + float(src[idx]) * float(src[idx]);
idx += block_dim;
}
shared_memory[tid] = tmp;

// All partial sums must be visible before the reduction reads them.
threadgroup_barrier(mem_flags::mem_threadgroup);

// Pairwise reduction: each pass adds the upper half of the active range onto the lower
// half. NOTE(review): this covers every slot only when block_dim is a power of two.
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}

/* wait for shared_memory[0] to be filled */
threadgroup_barrier(mem_flags::mem_threadgroup);

// shared_memory[0] now holds the block's sum of squares.
float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps);
float inv_norm = 1.0f / norm;
// Second strided pass over the same elements: scale and store.
idx = start_idx + tid;
while (idx < stop_idx) {
float val = float(src[idx]) * inv_norm;
if (alpha != nullptr) {
val *= float(alpha[idx - start_idx]);
}
dst[idx] = T(val);
idx += block_dim;
}
}

template<typename T>
METAL_FUNC void layernorm(
constant size_t & src_numel,
Expand Down Expand Up @@ -412,24 +362,6 @@ METAL_FUNC void layernorm(
}
}

// RMSNORM(NAME, T): defines a kernel entry point NAME for element type T.
// The kernel declares a threadgroup float array of THREADGROUP_SIZE elements
// (THREADGROUP_SIZE is defined outside this view), zeroes the calling thread's
// slot, and forwards all arguments plus that array to rmsnorm<T>.
// Thread indices come from the Metal attributes on the parameters:
// thread_position_in_grid, thread_index_in_threadgroup,
// threadgroup_position_in_grid and threads_per_threadgroup.
// Comments are kept above the #define because every macro line ends in a
// line continuation.
#define RMSNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
device const T *alpha, \
constant float &eps, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = 0; \
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
} \

#define LAYERNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
Expand Down Expand Up @@ -587,8 +519,6 @@ ARGMAX(fast_argmax_u8_strided, uint8_t, 0)

SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
LAYERNORM(layernorm_f32, float)
LAYERNORM(layernorm_f16, half)
ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
Expand Down Expand Up @@ -626,7 +556,6 @@ REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
LAYERNORM(layernorm_bf16, bfloat)
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
#endif
Loading

0 comments on commit ea8b6fe

Please sign in to comment.