Skip to content

Commit

Permalink
Add some fast Metal MLX SDPA kernels (#32)
Browse files Browse the repository at this point in the history
* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format
  • Loading branch information
EricLBuehler authored Oct 26, 2024
1 parent 1f8a28a commit 522531d
Show file tree
Hide file tree
Showing 4 changed files with 1,986 additions and 0 deletions.
339 changes: 339 additions & 0 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const REDUCE: &str = include_str!("reduce.metal");
const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
// Source text of `scaled_dot_product_attention.metal`, embedded into the binary at compile time.
const SDPA: &str = include_str!("scaled_dot_product_attention.metal");

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Expand All @@ -42,6 +43,7 @@ pub enum Source {
Sort,
Ternary,
Unary,
Sdpa,
}

pub mod copy2d {
Expand Down Expand Up @@ -173,6 +175,17 @@ pub enum MetalKernelError {
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
#[error("Sdpa {variation} head size was {got}, expectd {expected:?}")]
SdpaHeadSizeMismatch {
variation: &'static str,
got: usize,
expected: Vec<usize>,
},
#[error("Sdpa {variation} got dtype {got:?}")]
SdpaHeadDTypeMismatch {
variation: &'static str,
got: SdpaDType,
},
}

impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
Expand Down Expand Up @@ -221,6 +234,7 @@ impl Kernels {
Source::Sort => SORT,
Source::Ternary => TERNARY,
Source::Unary => UNARY,
Source::Sdpa => SDPA,
Source::Mfa => panic!("Invalid lib"),
}
}
Expand Down Expand Up @@ -1641,6 +1655,331 @@ pub fn call_gemm(
Ok(())
}

/// Element type of the q/k/v inputs handed to the SDPA kernel launchers.
///
/// The launchers match on this value (together with the head dimension) to
/// pick the name of the Metal kernel to load.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum SdpaDType {
    /// 16-bit brain floating point.
    BF16,
    /// 16-bit IEEE half precision.
    F16,
    /// 32-bit IEEE single precision.
    F32,
}

/// Encodes a launch of the "full" SDPA kernel on the encoder obtained from `ep`.
///
/// SDPA full is supported when:
/// - q head dim is one of 32, 64, 96, 128, 256
/// - no mask
/// - q heads == kv heads
/// - kv sequence length == q sequence length: every launch parameter (`n`,
///   the k/v batch strides, ...) is derived from `q_shape` alone
/// - input type is f16 or f32; bf16 is rejected
///   (TODO maybe just template this kernel too?)
/// - q,k,v are contiguous
///
/// `q_shape` is read as `(bs, qhead, seq, hidden)`. `k_shape` is accepted for
/// symmetry with [`call_sdpa_vector`] but is not consulted.
///
/// When `softcapping != 1.0`, `alpha` is divided by `softcapping` before being
/// passed to the kernel; `softcapping` itself is forwarded unchanged.
///
/// # Errors
///
/// - [`MetalKernelError::SdpaHeadSizeMismatch`] if the last dimension of
///   `q_shape` is not a supported head size.
/// - [`MetalKernelError::SdpaHeadDTypeMismatch`] if `itype` is bf16.
/// - Any error returned by `Kernels::load_pipeline`.
///
/// # Panics
///
/// Panics if `q_shape` has fewer than two dimensions.
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_full(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    q_offset: usize,
    q_shape: &[usize],
    q_buffer: &Buffer,
    k_offset: usize,
    k_shape: &[usize],
    k_buffer: &Buffer,
    v_offset: usize,
    v_buffer: &Buffer,
    output: &Buffer,
    alpha: f32,
    softcapping: f32,
    itype: SdpaDType,
) -> Result<(), MetalKernelError> {
    // Host-side mirror of the parameter block uploaded at buffer index 4.
    // `repr(C)` keeps the field order and layout fixed.
    #[derive(Debug)]
    #[repr(C)]
    struct MLXFastAttentionParams {
        m: i32,
        n: i32,
        k: i32,

        ldq: i32, // ldq == ldo
        ldk: i32,
        ldv: i32,
        lds: i32,
        ldo: i32,

        tiles_n: i32,
        tiles_m: i32,

        batch_stride_q: i32,
        batch_stride_k: i32,
        batch_stride_v: i32,
        batch_stride_o: i32,

        swizzle_log: i32,
        gemm_n_iterations_aligned: i32,
        gemm_k_iterations_aligned: i32,
        gemm_sv_m_block_iterations: i32,

        batch_ndim: i32,
        alpha: f32,
        softcapping: f32,
    }

    // Only the q shape drives the launch configuration (see the doc comment).
    let _ = k_shape;

    let bk = *q_shape.last().expect("q_shape must not be empty");

    const BN: usize = 16;
    const BM: usize = 16;
    const WM: usize = 2;
    const WN: usize = 2;

    let name = match (bk, itype) {
        (32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half",
        (64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half",
        (96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half",
        (128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half",
        (256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half",
        (32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float",
        (64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float",
        (96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float",
        (128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float",
        (256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float",
        (other, SdpaDType::F16 | SdpaDType::F32) => {
            return Err(MetalKernelError::SdpaHeadSizeMismatch {
                variation: "full",
                got: other,
                expected: vec![32, 64, 96, 128, 256],
            })
        }
        (_, SdpaDType::BF16) => {
            return Err(MetalKernelError::SdpaHeadDTypeMismatch {
                variation: "full",
                got: SdpaDType::BF16,
            })
        }
    };

    let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    // q = (bs, qhead, seq, hidden)
    // k/v = (bs, kv_head, seq, hidden)

    let qseq = q_shape[q_shape.len() - 2];

    let m = qseq;
    // The kv sequence length is taken to be the q sequence length.
    let n = m;
    let k = q_shape[q_shape.len() - 1];
    // Batch and head dimensions are folded into a single batch dimension.
    let bs_out = q_shape[0] * q_shape[1];

    // Buffer 6 is uploaded as `i32`s, so build an `i32` array instead of
    // reinterpreting the bytes of a `usize` array.
    let batch_shape = [bs_out as i32];
    let dk = q_shape[q_shape.len() - 1];
    let ldq = dk;
    let ldk = dk;
    let ldv = dk;
    let lds = BN;
    let ldo = dk;

    let tn = 1;
    let tm = (m + BM - 1) / BM;

    let b_stride_q = dk * qseq;
    let b_stride_k = dk * qseq;
    let b_stride_v = dk * qseq;
    let b_stride_o = dk * qseq;
    let swizzle_log = 0;
    let gemm_n_iterations_aligned = (n + BN - 1) / BN;
    let gemm_k_iterations_aligned = (k + bk - 1) / bk;
    let gemm_sv_m_block_iterations = (m + BM - 1) / BM;
    let batch_ndim = batch_shape.len();

    let alpha = if softcapping != 1. {
        alpha / softcapping
    } else {
        alpha
    };

    let params = MLXFastAttentionParams {
        m: m as i32,
        n: n as i32,
        k: k as i32,
        ldq: ldq as i32,
        ldk: ldk as i32,
        ldv: ldv as i32,
        lds: lds as i32,
        ldo: ldo as i32,
        tiles_n: tn,
        tiles_m: tm as i32,
        batch_stride_q: b_stride_q as i32,
        batch_stride_k: b_stride_k as i32,
        batch_stride_v: b_stride_v as i32,
        batch_stride_o: b_stride_o as i32,
        swizzle_log,
        gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32,
        gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32,
        gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32,
        batch_ndim: batch_ndim as i32,
        alpha,
        softcapping,
    };
    let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];

    encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger);
    encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger);
    encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger);
    encoder.set_buffer(3, Some(&output), 0);

    encoder.set_bytes(
        4,
        std::mem::size_of::<MLXFastAttentionParams>() as u64,
        &params as *const MLXFastAttentionParams as *const c_void,
    );
    encoder.set_bytes(
        6,
        std::mem::size_of_val(&batch_shape) as u64,
        batch_shape.as_ptr() as *const c_void,
    );
    encoder.set_bytes(
        7,
        std::mem::size_of_val(&batch_strides) as u64,
        batch_strides.as_ptr() as *const c_void,
    );

    // One threadgroup per (row tile, batch * head) pair.
    let grid_dims = MTLSize {
        width: 1,
        height: tm as u64,
        depth: bs_out as u64,
    };
    let group_dims = MTLSize {
        width: 32,
        height: WM as u64,
        depth: WN as u64,
    };
    encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(grid_dims, group_dims);
    Ok(())
}

/// SDPA full is supported when:
/// - q head dim == 64, 96, 128
/// - no mask
/// - q,k,v are contiguous
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_vector(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
q_offset: usize,
q_shape: &[usize],
q_buffer: &Buffer,
k_offset: usize,
k_shape: &[usize],
k_stride: &[usize],
k_buffer: &Buffer,
v_offset: usize,
v_buffer: &Buffer,
output: &Buffer,
alpha: f32,
softcapping: f32,
itype: SdpaDType,
) -> Result<(), MetalKernelError> {
let bk = q_shape.last().unwrap();

let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
let n = k_shape[2] as i32;
let b = (q_shape[0] * q_shape[1]) as i32;
let stride = k_stride[1];

let name = match (bk, itype) {
(32, SdpaDType::F16) => "sdpa_vector_float16_t_32",
(64, SdpaDType::F16) => "sdpa_vector_float16_t_64",
(96, SdpaDType::F16) => "sdpa_vector_float16_t_96",
(128, SdpaDType::F16) => "sdpa_vector_float16_t_128",
(256, SdpaDType::F16) => "sdpa_vector_float16_t_256",
(32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32",
(64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64",
(96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96",
(128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128",
(256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256",
(32, SdpaDType::F32) => "sdpa_vector_float_32",
(64, SdpaDType::F32) => "sdpa_vector_float_64",
(96, SdpaDType::F32) => "sdpa_vector_float_96",
(128, SdpaDType::F32) => "sdpa_vector_float_128",
(256, SdpaDType::F32) => "sdpa_vector_float_256",
(other, _) => {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "vector",
got: *other,
expected: vec![32, 64, 96, 128, 256],
})
}
};

let alpha = if softcapping != 1. {
alpha / softcapping
} else {
alpha
};

let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);

// q = (bs, qhead, seq, hidden)
// k/v = (bs, kv_head, kv_seq, hidden)

encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger);
encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger);
encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger);
encoder.set_buffer(3, Some(&output), 0);

encoder.set_bytes(
4,
std::mem::size_of::<i32>() as u64,
&gqa_factor as *const i32 as *const c_void,
);
encoder.set_bytes(
5,
std::mem::size_of::<i32>() as u64,
&n as *const i32 as *const c_void,
);
encoder.set_bytes(
6,
std::mem::size_of::<usize>() as u64,
&stride as *const usize as *const c_void,
);
encoder.set_bytes(
7,
std::mem::size_of::<f32>() as u64,
&alpha as *const f32 as *const c_void,
);
encoder.set_bytes(
8,
std::mem::size_of::<f32>() as u64,
&softcapping as *const f32 as *const c_void,
);

let grid_dims = MTLSize {
width: 1,
height: b as u64,
depth: 1 as u64,
};
let group_dims = MTLSize {
width: 1024,
height: 1,
depth: 1,
};
encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_dims, group_dims);
Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
device: &Device,
Expand Down
Loading

0 comments on commit 522531d

Please sign in to comment.