diff --git a/xformers/csrc/swiglu/swiglu_packedw.cpp b/xformers/csrc/swiglu/swiglu_packedw.cpp index 42dae758a..880cda2b8 100644 --- a/xformers/csrc/swiglu/swiglu_packedw.cpp +++ b/xformers/csrc/swiglu/swiglu_packedw.cpp @@ -211,7 +211,12 @@ at::Tensor swiglu_packedw_cuda( const std::optional b1b2, const at::Tensor w3, const std::optional b3) { + if (torch::GradMode::is_enabled()) { return SwiGLUPackedWeights::apply(x, w1w2, b1b2, w3, b3); + } else { + return SwiGLUPackedWeights::forward( + /* ctx */ nullptr, x, w1w2, b1b2, w3, b3); + } } } // namespace