diff --git a/xformers/csrc/swiglu/swiglu_packedw.cpp b/xformers/csrc/swiglu/swiglu_packedw.cpp index e70a3a72fe..880cda2b87 100644 --- a/xformers/csrc/swiglu/swiglu_packedw.cpp +++ b/xformers/csrc/swiglu/swiglu_packedw.cpp @@ -211,7 +211,7 @@ at::Tensor swiglu_packedw_cuda( const std::optional b1b2, const at::Tensor w3, const std::optional b3) { - if (x.requires_grad()) { + if (torch::GradMode::is_enabled()) { return SwiGLUPackedWeights::apply(x, w1w2, b1b2, w3, b3); } else { return SwiGLUPackedWeights::forward(