diff --git a/monai/handlers/validation_handler.py b/monai/handlers/validation_handler.py index 19183fb4e2..89c7715f42 100644 --- a/monai/handlers/validation_handler.py +++ b/monai/handlers/validation_handler.py @@ -31,7 +31,9 @@ class ValidationHandler: """ - def __init__(self, interval: int, validator: Evaluator | None = None, epoch_level: bool = True) -> None: + def __init__( + self, interval: int, validator: Evaluator | None = None, epoch_level: bool = True, exec_at_start: bool = False + ) -> None: """ Args: interval: do validation every N epochs or every N iterations during training. @@ -39,6 +41,9 @@ def __init__(self, interval: int, validator: Evaluator | None = None, epoch_leve if None, should call `set_validator()` before training. epoch_level: execute validation every N epochs or N iterations. `True` is epoch level, `False` is iteration level. + exec_at_start: whether to execute a validation first when starting the training. + default to `False`. It can be useful especially for some transfer-learning cases + to validate the initial model before training. Raises: TypeError: When ``validator`` is not a ``monai.engines.evaluator.Evaluator``. @@ -49,6 +54,7 @@ def __init__(self, interval: int, validator: Evaluator | None = None, epoch_leve self.validator = validator self.interval = interval self.epoch_level = epoch_level + self.exec_at_start = exec_at_start def set_validator(self, validator: Evaluator) -> None: """ @@ -67,6 +73,8 @@ def attach(self, engine: Engine) -> None: engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.interval), self) else: engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self) + if self.exec_at_start: + engine.add_event_handler(Events.STARTED, self) def __call__(self, engine: Engine) -> None: """ diff --git a/tests/test_handler_validation.py b/tests/test_handler_validation.py index deb762917d..23fcf5e75c 100644 --- a/tests/test_handler_validation.py +++ b/tests/test_handler_validation.py @@ -39,8 +39,11 @@ def _train_func(engine, batch): # set up testing handler val_data_loader = torch.utils.data.DataLoader(Dataset(data)) evaluator = TestEvaluator(torch.device("cpu:0"), val_data_loader) - saver = ValidationHandler(interval=2, validator=evaluator) - saver.attach(engine) + ValidationHandler(interval=2, validator=evaluator, exec_at_start=True).attach(engine) + # test execution at start + engine.run(data, max_epochs=1) + self.assertEqual(evaluator.state.max_epochs, 0) + self.assertEqual(evaluator.state.epoch_length, 8) engine.run(data, max_epochs=5) self.assertEqual(evaluator.state.max_epochs, 4)