From 6033c6c8198901e7c5aeeaee2a26b0d1b6598df7 Mon Sep 17 00:00:00 2001
From: Yury Parfenov <4665475+warpuv@users.noreply.github.com>
Date: Thu, 17 Oct 2024 20:00:43 +0300
Subject: [PATCH] [fix] Fix activation checkpointing when using SwiGLUPackedFusedOp

According to the docs
(https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function),
the forward() method should not be called directly; the apply() method has
to be used instead. After removing the direct forward() call, activation
checkpointing starts working.

(alternative variant 2)
---
 xformers/csrc/swiglu/swiglu_packedw.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/xformers/csrc/swiglu/swiglu_packedw.cpp b/xformers/csrc/swiglu/swiglu_packedw.cpp
index 42dae758a..880cda2b8 100644
--- a/xformers/csrc/swiglu/swiglu_packedw.cpp
+++ b/xformers/csrc/swiglu/swiglu_packedw.cpp
@@ -211,7 +211,12 @@ at::Tensor swiglu_packedw_cuda(
     const at::Tensor x,
     const at::Tensor w1w2,
     const std::optional<at::Tensor> b1b2,
     const at::Tensor w3,
     const std::optional<at::Tensor> b3) {
+  if (torch::GradMode::is_enabled()) {
   return SwiGLUPackedWeights::apply(x, w1w2, b1b2, w3, b3);
+  } else {
+    return SwiGLUPackedWeights::forward(
+        /* ctx */ nullptr, x, w1w2, b1b2, w3, b3);
+  }
 }
 } // namespace
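
Note (not part of the patch): below is a minimal standalone sketch of the pattern the fix relies on. The SquareOp / square names are hypothetical and only for illustration; what is taken from the patch and the libtorch custom-autograd-function API is torch::GradMode::is_enabled(), the static apply() entry point, and the static forward(AutogradContext*, ...) signature. apply() builds the autograd node, which is what activation checkpointing needs to recompute the forward pass during backward, while calling forward() directly with a null context skips graph construction and is only safe when gradients are disabled.

// Minimal sketch (assumed names: SquareOp, square). Not the xformers code.
#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::Function;
using torch::autograd::variable_list;

struct SquareOp : public Function<SquareOp> {
  // forward() may receive a null ctx when no autograd graph is needed.
  static at::Tensor forward(AutogradContext* ctx, const at::Tensor& x) {
    if (ctx != nullptr) {
      ctx->save_for_backward({x});
    }
    return x * x;
  }

  static variable_list backward(AutogradContext* ctx, variable_list grad_out) {
    auto saved = ctx->get_saved_variables();
    return {2 * saved[0] * grad_out[0]};
  }
};

at::Tensor square(const at::Tensor& x) {
  if (torch::GradMode::is_enabled()) {
    // apply() records the autograd node, so backward() and activation
    // checkpointing recomputation work as expected.
    return SquareOp::apply(x);
  }
  // With gradients disabled there is no graph to build, so calling
  // forward() directly with a null context is safe and skips the overhead.
  return SquareOp::forward(/*ctx=*/nullptr, x);
}

The same two-branch dispatch on GradMode is what the patch adds to swiglu_packedw_cuda: the grad-enabled path goes through apply() so checkpointed recomputation rebuilds the graph, and the no-grad path keeps the cheaper direct forward() call.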