From 04e13d13ae7d5d2053aeb6d068267107e3f4b433 Mon Sep 17 00:00:00 2001 From: jhliu17 Date: Mon, 2 Dec 2024 10:37:38 -0800 Subject: [PATCH] enable use_checkpoint flag for Attention Block --- torchcfm/models/unet/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchcfm/models/unet/unet.py b/torchcfm/models/unet/unet.py index e29df92..205ecab 100644 --- a/torchcfm/models/unet/unet.py +++ b/torchcfm/models/unet/unet.py @@ -270,7 +270,7 @@ def __init__( self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) def _forward(self, x): b, c, *spatial = x.shape