Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable split kernel in bwd pass (#303)
* Add fwd and bwd v2 Changes are largely from upstream. * Split bwd kernel in dq and dk+dv Only adds the split kernels. They are not enabled yet. * Pull scalar multiplies out of the loop * Enable split kernel for bwd pass * Put back P_SEQ=128 in fwd test Not used for bwd test * Address review comments * Address comments Conditionally set causal/ splitkernel to False for bwd. * Add block pointer semantics to bwd pass This significantly increases perf for bwd, similar to fwd.
- Loading branch information