Faster Metal RmsNorm #34

Closed · wants to merge 1 commit
122 changes: 88 additions & 34 deletions candle-metal-kernels/src/lib.rs
@@ -26,6 +26,7 @@ const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
const SDPA: &str = include_str!("scaled_dot_product_attention.metal");
const RMSNORM: &str = include_str!("rms_norm.metal");

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
@@ -44,6 +45,7 @@ pub enum Source {
Ternary,
Unary,
Sdpa,
RmsNorm,
}

pub mod copy2d {
@@ -235,6 +237,7 @@ impl Kernels {
Source::Ternary => TERNARY,
Source::Unary => UNARY,
Source::Sdpa => SDPA,
Source::RmsNorm => RMSNORM,
Source::Mfa => panic!("Invalid lib"),
}
}
@@ -721,62 +724,113 @@ pub fn call_last_softmax(
Ok(())
}

#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum RmsNormDType {
BF16,
F16,
F32,
}

#[allow(clippy::too_many_arguments)]
pub fn call_rms_norm(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
elements_to_sum: usize,
ty: RmsNormDType,
eps: f32,
input_shape: &[usize],
input: &Buffer,
input_offset: usize,
alpha_stride: &[usize],
alpha: &Buffer,
alpha_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let axis_size = input_shape[input_shape.len() - 1] as u32;
let n_rows = input_shape.iter().product::<usize>() as u32 / axis_size;

let simd_size = 32;
let rms_n_reads = 4;
let rms_looped_limit = 4096;

let name = match (ty, axis_size > rms_looped_limit) {
(RmsNormDType::F16, false) => "rmsfloat16",
(RmsNormDType::F16, true) => "rms_loopedfloat16",
(RmsNormDType::BF16, false) => "rmsbfloat16",
(RmsNormDType::BF16, true) => "rms_loopedbfloat16",
(RmsNormDType::F32, false) => "rmsfloat32",
(RmsNormDType::F32, true) => "rms_loopedfloat32",
};
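
// Selection note: rows of at most `rms_looped_limit` elements use the
// single-pass "rms*" kernels, with each thread reading `rms_n_reads` values;
// longer rows use the "rms_looped*" variants, which are dispatched below with
// the pipeline's maximum threadgroup width.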

let pipeline = kernels.load_pipeline(device, Source::RmsNorm, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);

set_params!(
encoder,
(
length,
elements_to_sum,
(input, input_offset),
output,
(alpha, alpha_offset),
eps
)
);

let out_length = length / elements_to_sum;

let thread_group_count = MTLSize {
width: out_length as u64,
height: 1,
depth: 1,
let max_total_threads_per_threadgroup = pipeline.max_total_threads_per_threadgroup();

let (grid_dims, group_dims) = if axis_size <= rms_looped_limit {
let threadgroup_needed = (axis_size + rms_n_reads - 1) / rms_n_reads;
let simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
let threadgroup_size = simd_size * simds_needed;
assert!(threadgroup_size <= max_total_threads_per_threadgroup as u32);
let n_threads = n_rows * threadgroup_size;
let grid_dims = MTLSize {
width: n_threads as u64,
height: 1,
depth: 1,
};
let group_dims = MTLSize {
width: threadgroup_size as u64,
height: 1,
depth: 1,
};
(grid_dims, group_dims)
} else {
let n_threads = n_rows * max_total_threads_per_threadgroup as u32;
let grid_dims = MTLSize {
width: n_threads as u64,
height: 1,
depth: 1,
};
let group_dims = MTLSize {
width: max_total_threads_per_threadgroup,
height: 1,
depth: 1,
};
(grid_dims, group_dims)
};
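
// Illustrative sizing (not part of the PR): for axis_size = 2048 the first
// branch gives threadgroup_needed = (2048 + 3) / 4 = 512, simds_needed =
// (512 + 31) / 32 = 16, and threadgroup_size = 16 * 32 = 512, so group_dims
// is 512 wide and grid_dims is n_rows * 512 wide.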

let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
elements_to_sum as u64,
)
.next_power_of_two();
let w_stride = alpha_stride[0] as u32;

let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.set_buffer(0, Some(&input), input_offset as NSUInteger);
encoder.set_buffer(1, Some(&alpha), alpha_offset as NSUInteger);
encoder.set_buffer(2, Some(&output), 0);

encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(alpha, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);

encoder.set_bytes(
3,
std::mem::size_of::<f32>() as NSUInteger,
&eps as *const f32 as *const c_void,
);
encoder.set_bytes(
4,
std::mem::size_of::<u32>() as NSUInteger,
&axis_size as *const u32 as *const c_void,
);
encoder.set_bytes(
5,
std::mem::size_of::<u32>() as NSUInteger,
&w_stride as *const u32 as *const c_void,
);

// minimum of 16 bytes
encoder.set_threadgroup_memory_length(0, 16 * 8 as u64);
encoder.set_threadgroup_memory_length(1, simd_size as u64 * std::mem::size_of::<f32>() as u64);
encoder.dispatch_thread_groups(grid_dims, group_dims);
Ok(())
}
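
For reference, a minimal CPU sketch of the per-row computation the kernels above are expected to perform; plain Rust, not part of this PR, with an illustrative function name and an f32-only signature:

// Reference RMSNorm over the last axis: y = x / sqrt(mean(x^2) + eps) * alpha.
fn rms_norm_reference(src: &[f32], alpha: &[f32], eps: f32, axis_size: usize) -> Vec<f32> {
    let mut dst = vec![0.0f32; src.len()];
    for (row_in, row_out) in src.chunks(axis_size).zip(dst.chunks_mut(axis_size)) {
        // Mean of squares over the row, then the reciprocal RMS.
        let mean_sq = row_in.iter().map(|x| x * x).sum::<f32>() / axis_size as f32;
        let inv_rms = 1.0f32 / (mean_sq + eps).sqrt();
        for ((d, &x), &a) in row_out.iter_mut().zip(row_in).zip(alpha) {
            *d = x * inv_rms * a;
        }
    }
    dst
}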

71 changes: 0 additions & 71 deletions candle-metal-kernels/src/reduce.metal
@@ -303,56 +303,6 @@ kernel void NAME( \
softmax<T>(src_numel, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory); \
} \

template<typename T>
METAL_FUNC void rmsnorm(
constant size_t & src_numel,
constant size_t & el_to_sum_per_block,
device const T * src,
device T * dst,
device const T * alpha,
constant float & eps,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup float * shared_memory
) {
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;

float tmp = 0;
while (idx < stop_idx) {
tmp = tmp + float(src[idx]) * float(src[idx]);
idx += block_dim;
}
shared_memory[tid] = tmp;

threadgroup_barrier(mem_flags::mem_threadgroup);

for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}

/* wait for shared_memory[0] to be filled */
threadgroup_barrier(mem_flags::mem_threadgroup);

float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps);
float inv_norm = 1.0f / norm;
idx = start_idx + tid;
while (idx < stop_idx) {
float val = float(src[idx]) * inv_norm;
if (alpha != nullptr) {
val *= float(alpha[idx - start_idx]);
}
dst[idx] = T(val);
idx += block_dim;
}
}

template<typename T>
METAL_FUNC void layernorm(
constant size_t & src_numel,
@@ -412,24 +362,6 @@ METAL_FUNC void layernorm(
}
}

#define RMSNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
device const T *alpha, \
constant float &eps, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = 0; \
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
} \

#define LAYERNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
@@ -587,8 +519,6 @@ ARGMAX(fast_argmax_u8_strided, uint8_t, 0)

SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
LAYERNORM(layernorm_f32, float)
LAYERNORM(layernorm_f16, half)
ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
@@ -626,7 +556,6 @@ REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
LAYERNORM(layernorm_bf16, bfloat)
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
#endif