Skip to content

Commit

Permalink
Changed pointer variable names for clarity for SwiGLU (#46)
Browse files Browse the repository at this point in the history
## Summary
Added _ptr to pointer variables for clarity in SwiGLU as we have done in
the other files.

## Testing Done
- [x] run `make checkstyle` to ensure code style
  • Loading branch information
zain-merchant authored Aug 20, 2024
1 parent b46485b commit 0c8aba8
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions src/liger_kernel/ops/swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,52 @@ def silu(x):

@triton.jit
def _swiglu_forward_kernel(
    a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    """Compute one row of SwiGLU: c = silu(a) * b.

    Each program instance processes a single row of ``n_cols`` elements.
    ``a_ptr`` and ``b_ptr`` point to the two input matrices and ``c_ptr``
    to the output; ``stride`` is the row stride (in elements) shared by
    all three tensors.
    """
    program_id = tl.program_id(0)

    # Advance each base pointer to the start of this program's row.
    a_ptr += program_id * stride
    b_ptr += program_id * stride
    c_ptr += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    # Mask off columns beyond the logical row width (BLOCK_SIZE >= n_cols).
    mask = col_offsets < n_cols

    # sigmoid requires type float32, so upcast `a` before applying silu.
    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
    c_row = silu(a_row) * b_row
    tl.store(c_ptr + col_offsets, c_row, mask=mask)


@triton.jit
def _swiglu_backward_kernel(
    dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    """Compute one row of the SwiGLU backward pass.

    Given the upstream gradient ``dc`` and the saved forward inputs
    ``a`` and ``b``, this writes the input gradients IN PLACE:
    ``a_ptr`` receives da = dc * d/da[silu(a)] * b and ``b_ptr``
    receives db = dc * silu(a). ``stride`` is the shared row stride.
    """
    program_id = tl.program_id(0)

    # Advance each base pointer to the start of this program's row.
    dc_ptr += program_id * stride
    a_ptr += program_id * stride
    b_ptr += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
    # sigmoid requires type float32, so upcast `a` before applying silu.
    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)

    # Recompute silu(a) instead of saving it in forward — saves memory.
    sig_a = tl.sigmoid(a_row)
    silu_a = a_row * sig_a
    db_row = dc_row * silu_a
    # d/da[a * sigmoid(a)] = sigmoid(a) + a * sigmoid(a) * (1 - sigmoid(a))
    #                      = silu(a) * (1 - sigmoid(a)) + sigmoid(a)
    da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row

    # Overwrite the forward inputs with their gradients (in-place).
    tl.store(a_ptr + col_offsets, da_row, mask=mask)
    tl.store(b_ptr + col_offsets, db_row, mask=mask)


class LigerSiLUMulFunction(torch.autograd.Function):
Expand Down

0 comments on commit 0c8aba8

Please sign in to comment.