diff --git a/mace/tools/train.py b/mace/tools/train.py index 62013e9d..87a958f9 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -244,12 +244,12 @@ def train( if valid_loss >= lowest_loss: patience_counter += 1 - if patience_counter >= patience and epoch < swa.start: + if patience_counter >= patience and (swa.start is not None and epoch < swa.start): logging.info( f"Stopping optimization after {patience_counter} epochs without improvement and starting swa" ) epoch = swa.start - elif patience_counter >= patience and epoch >= swa.start: + elif patience_counter >= patience and (swa.start is None or epoch >= swa.start): logging.info( f"Stopping optimization after {patience_counter} epochs without improvement" )