From 0c8aba8fdc2e93b9b93d4a61502259ecfe18d0c9 Mon Sep 17 00:00:00 2001
From: Zain Merchant
Date: Tue, 20 Aug 2024 07:25:29 -0700
Subject: [PATCH] Changed pointer variable names for clarity in SwiGLU (#46)

## Summary
Added a `_ptr` suffix to the pointer variables in SwiGLU for clarity, matching the convention used in the other kernel files.

## Testing Done
- [x] Run `make checkstyle` to ensure code style
---
 src/liger_kernel/ops/swiglu.py | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/liger_kernel/ops/swiglu.py b/src/liger_kernel/ops/swiglu.py
index 39e309478..d83625be4 100644
--- a/src/liger_kernel/ops/swiglu.py
+++ b/src/liger_kernel/ops/swiglu.py
@@ -12,43 +12,43 @@ def silu(x):
 
 @triton.jit
 def _swiglu_forward_kernel(
-    a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
+    a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
 ):
     program_id = tl.program_id(0)
 
     # locate start index
-    a += program_id * stride
-    b += program_id * stride
-    c += program_id * stride
+    a_ptr += program_id * stride
+    b_ptr += program_id * stride
+    c_ptr += program_id * stride
 
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
     # sigmoid requires type float32
-    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
-    b_row = tl.load(b + col_offsets, mask=mask, other=0)
+    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
+    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
     c_row = silu(a_row) * b_row
-    tl.store(c + col_offsets, c_row, mask=mask)
+    tl.store(c_ptr + col_offsets, c_row, mask=mask)
 
 
 @triton.jit
 def _swiglu_backward_kernel(
-    dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
+    dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
 ):
     program_id = tl.program_id(0)
 
     # locate start index
-    dc += program_id * stride
-    a += program_id * stride
-    b += program_id * stride
+    dc_ptr += program_id * stride
+    a_ptr += program_id * stride
+    b_ptr += program_id * stride
 
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
-    dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
+    dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
 
     # sigmoid requires type float32
-    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
-    b_row = tl.load(b + col_offsets, mask=mask, other=0)
+    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
+    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
 
     # recomputation to save memory
     sig_a = tl.sigmoid(a_row)
@@ -56,8 +56,8 @@ def _swiglu_backward_kernel(
     db_row = dc_row * silu_a
     da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row
 
-    tl.store(a + col_offsets, da_row, mask=mask)
-    tl.store(b + col_offsets, db_row, mask=mask)
+    tl.store(a_ptr + col_offsets, da_row, mask=mask)
+    tl.store(b_ptr + col_offsets, db_row, mask=mask)
 
 
 class LigerSiLUMulFunction(torch.autograd.Function):
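
For context, a minimal sketch of how a row-wise kernel like the renamed `_swiglu_forward_kernel` is typically launched from the host. This is not part of the patch: the wrapper name `swiglu_forward_sketch` and the shape handling are illustrative assumptions; the actual call sites live in `LigerSiLUMulFunction`, which this change does not touch.

```python
# Illustrative host-side launch of the forward kernel (hypothetical wrapper,
# assuming the patched module is importable as liger_kernel.ops.swiglu).
import torch
import triton

from liger_kernel.ops.swiglu import _swiglu_forward_kernel


def swiglu_forward_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Flatten leading dims so each Triton program handles one row of n_cols elements.
    ori_shape = a.shape
    n_cols = ori_shape[-1]
    a = a.contiguous().view(-1, n_cols)
    b = b.contiguous().view(-1, n_cols)
    c = torch.empty_like(a)
    n_rows = a.shape[0]

    # BLOCK_SIZE must cover a full row; the in-kernel mask handles the tail.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # One program per row; a_ptr / b_ptr / c_ptr receive the tensors' data pointers.
    _swiglu_forward_kernel[(n_rows,)](
        a, b, c,
        stride=a.stride(0),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return c.view(*ori_shape)
```

Since the rename only touches kernel parameter names, callers like the sketch above (and the real autograd wrapper) are unaffected.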