diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index fbcb2d72c1..0b0274906a 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -397,7 +397,7 @@ def forward( else: fp8_context = nullcontext() - with rng_context and fp8_context: + with rng_context, fp8_context: # Forward pass. if self.config.recompute_granularity == 'full' and self.training: hidden_states = self._checkpointed_forward(