Skip to content

Commit

Permalink
Enable split kernel in bwd pass (#303)
Browse files Browse the repository at this point in the history
* Add fwd and bwd v2

Changes are largely from upstream.

* Split bwd kernel in dq and dk+dv

Only adds the split kernels. They are not enabled yet.

* Pull scalar multiplies out of the loop

* Enable split kernel for bwd pass

* Put back P_SEQ=128 in fwd test

Not used for bwd test

* Address review comments

* Address comments

Conditionally set causal/ splitkernel to False for bwd.

* Add block pointer semantics to bwd pass

This significantly increases perf for bwd, similar to fwd.
  • Loading branch information
vgokhale authored Aug 29, 2023
1 parent b834f42 commit 9cdf3a5
Showing 1 changed file with 334 additions and 53 deletions.
Loading

0 comments on commit 9cdf3a5

Please sign in to comment.