In-place operations in Triton kernels might result in incorrect gradient calculations #272
Comments
Should we adopt the second solution, since the first one introduces quite a lot of overhead? Also, can you elaborate on the cases in which this behavior happens?
Consider the following forward graph:

```mermaid
graph TD
    A[input] -->|a| B[exp]
    B -->|b| C[liger_ce]
    C -->|loss| output
```
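To calculate the gradient of the exp layer, which is $\frac{\partial b}{\partial a} = e^{a} = b$, autograd can either (1) save the input $a$ and recompute $e^{a}$ in backward, or (2) save the output $b$ directly. Normally we take the option with the least computation and memory, which is option 2 in this case:

```mermaid
graph TD
    A[input] -->|a| B["exp <br> saved tensors: b (v0)"]
    B -->|b| C[liger_ce]
    C -->|loss| output
```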
After a complete forward pass from `input` to `loss`, backpropagation proceeds as follows:

```mermaid
graph TD
    A[input] <-->|"dx * grad_ce = b' * grad_ce"| B["exp <br> saved tensors: b' (v0)<br>(changed by liger_ce)"]
    B <-->|grad_ce| C[liger_ce]
    C <-->|loss| output
```
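Written as equations, with $\mathcal{L}$ for `loss` and $b'$ for whatever `liger_ce` left in the memory of $b$, this is the gradient `exp` should send back to `input` compared with the one it actually computes from its overwritten saved tensor (a restatement of the backward graph, not new analysis):

```math
\frac{\partial \mathcal{L}}{\partial a}
  = \frac{\partial \mathcal{L}}{\partial b} \odot \frac{\partial b}{\partial a}
  = \text{grad\_ce} \odot e^{a}
  = \text{grad\_ce} \odot b,
\qquad
\text{computed: } \text{grad\_ce} \odot b' \neq \text{grad\_ce} \odot b
```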
Notice that in the forward pass, `liger_ce` stored the gradients of `loss` with respect to `b` in place, replacing `b` with `b'` in the same memory. Because `exp` saved `b` for backward, it now multiplies `grad_ce` by `b'` instead of `b`, so the gradient it returns for `input` is wrong.

**Why no error?**

No error is raised because a Triton kernel does not bump the tensor's version counter when it writes in place, so the saved tensor is still at `v0` when gradients are computed in backward. If we perform the in-place operation outside the kernel by calling a torch function, the version is updated correctly:

```mermaid
graph TD
    A[input] <-->|"dx * grad_output <br>= b' * grad_output"| B["exp <br> saved tensors: b' (v1)<br>(changed by inplace op)"]
    B <-->|grad_output| C["torch's inplace op"]
    C <-->|something| something
```

Thus, the error can be detected.
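The version counter can be inspected through the tensor's private `_version` attribute. The minimal sketch below assumes that writing through `.data` is a fair CPU stand-in for a Triton kernel's in-place write, since `.data` shares the storage but gets its own version counter; a torch in-place op on the tensor itself is shown for comparison.

```python
import torch

b = torch.ones(3)
v_start = b._version  # counter of a freshly created tensor

# Stand-in for a Triton kernel writing into b's memory:
# .data shares b's storage but has a separate version counter.
b.data.mul_(2)
v_after_kernel_like_write = b._version  # unchanged, although b's values changed

# A no-op torch in-place op on b itself is tracked by autograd.
b.mul_(1)
v_after_torch_op = b._version  # incremented

print(v_start, v_after_kernel_like_write, v_after_torch_op)
print(b)  # the values were modified by the untracked write
```

Only the second write is visible to autograd's in-place correctness check, which is exactly the difference between the two graphs above.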
We can keep separate pointers for the gradients when designing a kernel, and add a boolean argument to the `autograd.Function` so users can decide whether gradients are stored in place. If `False`, we allocate new memory and pass it to the kernel, e.g. `Liger-Kernel/src/liger_kernel/ops/jsd.py`, lines 64 to 77 (commit ff6650b).
If `True`, we simply pass the existing tensor we want to overwrite, e.g. `X_ptr` and `dX_ptr` in `Liger-Kernel/src/liger_kernel/ops/fused_linear_jsd.py`, lines 75 to 88 (commit ff6650b).
These examples show that we can design a kernel whose interface looks out-of-place but that can still store its results in place.
Since the trivial solution introduces quite a lot of overhead, we could apply it only in the first pass, as an in-place correctness check. A possible implementation could look like this:

```python
import torch
import triton


@triton.jit
def _kernel(
    x_ptr,   # input tensor
    y_ptr,   # output tensor
    dx_ptr,  # gradients of input (may point to the same memory as x_ptr)
    # ... other arguments
):
    ...  # do something


def forward(_input, inplace: bool):  # ... other arguments
    ...  # do something
    if inplace:
        dx = _input  # gradients overwrite the input's memory
        if first_pass:  # I haven't come up with a good way to detect first pass or not
            _input.add_(0)  # no-op in-place op: bumps the version counter
    else:
        dx = torch.zeros_like(_input)  # separate buffer; input left untouched
    _kernel[(...)](
        x_ptr=_input,
        y_ptr=output,
        dx_ptr=dx,
        # ... other arguments
    )
    return output
```
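To make the proposal concrete, here is a CPU-only sketch in plain PyTorch. `CEGradInForward` and `input_grad` are hypothetical names, not Liger-Kernel APIs: the class computes the cross-entropy gradient during forward, as `liger_ce` does, and writes it through `.data` to mimic a Triton kernel (an assumption standing in for the real kernel). `inplace` chooses between the input's memory and a fresh buffer, and `bump_version` plays the role of `first_pass` by issuing the no-op `.add_(0)`.

```python
import torch


class CEGradInForward(torch.autograd.Function):
    """Hypothetical fused cross-entropy: d(loss)/d(logits) is computed in forward."""

    @staticmethod
    def forward(ctx, logits, target, inplace: bool, bump_version: bool):
        n = logits.shape[0]
        rows = torch.arange(n)
        probs = torch.softmax(logits, dim=-1)
        loss = -probs[rows, target].log().mean()
        probs[rows, target] -= 1.0
        grad = probs / n  # d(mean CE)/d(logits) = (softmax - onehot) / n
        if inplace:
            logits.data.copy_(grad)  # mimics a Triton write: no version bump
            if bump_version:  # plays the role of `first_pass`
                logits.add_(0)  # no-op in-place op: bumps the version counter
            dx = logits
        else:
            dx = grad  # fresh buffer; logits left untouched
        ctx.save_for_backward(dx.detach())
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (dx,) = ctx.saved_tensors
        return dx * grad_output, None, None, None


def input_grad(inplace: bool, bump_version: bool = False) -> torch.Tensor:
    torch.manual_seed(0)
    a = torch.randn(4, 5, requires_grad=True)
    target = torch.tensor([0, 2, 4, 1])
    b = a.exp()  # ExpBackward0 saves b for backward
    loss = CEGradInForward.apply(b, target, inplace, bump_version)
    loss.backward()
    return a.grad
```

With `inplace=False` the gradient matches `torch.nn.functional.cross_entropy(a.exp(), target)`; with `inplace=True` alone it is silently wrong, because `exp` multiplies by the overwritten `b`; adding the version bump turns the same situation into a `RuntimeError` at backward time.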
🐛 Describe the bug
#254 #262 (comments)
PyTorch’s autograd system records operations on tensors to construct a computational graph, which is used to compute gradients. When an in-place operation modifies a tensor that an earlier operation saved for backward, the saved values no longer match the ones used in the forward pass, so autograd needs a way to detect this.
https://pytorch.org/docs/stable/autograd.html#in-place-correctness-checks
Each tensor in PyTorch has an internal version counter that is incremented every time an in-place operation is performed.
https://github.com/pytorch/pytorch/blob/190e09d8b6a13f789b143f0fbd1325f924550967/c10/core/TensorImpl.h#L382
Since we don't explicitly call PyTorch in-place operations, the version counter doesn't change when a Triton kernel modifies a tensor in place. As a result, PyTorch's "In-place correctness checks" cannot detect the modification, and no error is shown to the user, who silently gets incorrect gradients.
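The effect can be illustrated without Triton. This is a sketch, not the original reproduction: `exp_input_grad` and its mode labels are hypothetical, and a write through `.data` is assumed as a stand-in for a kernel that overwrites `exp`'s saved output without bumping its version.

```python
import torch


def exp_input_grad(mode: str) -> torch.Tensor:
    """Gradient of sum(exp(a)) w.r.t. a after exp's saved output is tampered with.

    mode (hypothetical labels):
      "untouched" - nothing overwrites b
      "silent"    - b is overwritten through .data, mimicking a Triton kernel
      "declared"  - same overwrite, followed by the no-op b.add_(0)
    """
    a = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
    b = a.exp()  # ExpBackward0 saves its output b, since d(e^a)/da = e^a = b
    if mode in ("silent", "declared"):
        b.data.fill_(7.0)  # overwrite b's memory without touching its version counter
    if mode == "declared":
        b.add_(0)  # no-op torch in-place op: bumps b's version counter
    b.sum().backward()  # exp's backward multiplies the incoming ones by the saved b
    return a.grad


for mode in ("untouched", "silent", "declared"):
    try:
        print(mode, exp_input_grad(mode))
    except RuntimeError as err:
        print(mode, "RuntimeError:", str(err).splitlines()[0])
```

`untouched` returns $e^{a}$, `silent` returns `[7., 7., 7.]` without any complaint, and `declared` raises the in-place modification `RuntimeError` during backward.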
Reproduce
Solution
One trivial solution is to perform a no-op in-place operation, such as `.add_(0)` and `.mul_(1)`, to explicitly declare that we have changed the tensor values in place; the errors will then be raised. With this approach, I suggest adding an `inplace=True/False` parameter to the functions that involve in-place operations, so users can set it to `False` (using extra tensors) when they get errors.

Versions
Environment Report:
Operating System: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.10.12
PyTorch version: 2.4.1+cu121
CUDA version: 12.1
Triton version: 3.0.0
Transformers version: 4.45.0