From 552105ff580bd1a5d10d35bc8ba4d14423518baa Mon Sep 17 00:00:00 2001 From: camillebrianceau Date: Tue, 3 Sep 2024 11:50:49 +0200 Subject: [PATCH] some,minor changes --- clinicadl/trainer/trainer.py | 121 +++++++++++++++++++++-------------- 1 file changed, 73 insertions(+), 48 deletions(-) diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 0c0cb8614..0044c11ae 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -723,7 +723,7 @@ def _train( train_loader: DataLoader, valid_loader: DataLoader, split: int, - network: int = None, + network: Optional[int] = None, resume: bool = False, callbacks: List[Callback] = [], ): @@ -1168,11 +1168,14 @@ def _train_ssdann( _, metrics_train_target, ) = test_da( - self.maps_manager.network_task, - model, - train_target_loader, - criterion, - alpha, + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=train_target_loader, + criterion=criterion, + alpha=alpha, target=True, ) # TO CHECK @@ -1180,11 +1183,14 @@ def _train_ssdann( _, metrics_valid_target, ) = test_da( - self.maps_manager.network_task, - model, - valid_loader, - criterion, - alpha, + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=valid_loader, + criterion=criterion, + alpha=alpha, target=True, ) @@ -1214,21 +1220,27 @@ def _train_ssdann( _, metrics_train_source, ) = test_da( - self.maps_manager.network_task, - model, - train_source_loader, - criterion, - alpha, + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=train_source_loader, + criterion=criterion, + alpha=alpha, ) ( _, metrics_valid_source, ) = test_da( - self.maps_manager.network_task, - model, - valid_source_loader, - criterion, - alpha, + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=valid_source_loader, + criterion=criterion, + alpha=alpha, ) model.train() @@ -1277,22 +1289,28 @@ def _train_ssdann( f"Evaluate source data at the end of the epoch {epoch} with alpha: {alpha}." ) _, metrics_train_source = test_da( - self.maps_manager.network_task, - model, - train_source_loader, - criterion, - alpha, - True, - False, + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=train_source_loader, + criterion=criterion, + alpha=alpha, + target=True, + report_ci=False, ) _, metrics_valid_source = test_da( - self.maps_manager.network_task, - model, - valid_source_loader, - criterion, - alpha, - True, - False, + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=valid_source_loader, + criterion=criterion, + alpha=alpha, + target=True, + report_ci=False, ) log_writer.step( @@ -1313,19 +1331,25 @@ def _train_ssdann( ) _, metrics_train_target = test_da( - self.maps_manager.network_task, - model, - train_target_loader, - criterion, - alpha, + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=train_target_loader, + criterion=criterion, + alpha=alpha, target=True, ) _, metrics_valid_target = test_da( - self.maps_manager.network_task, - model, - valid_loader, - criterion, - alpha, + mode=self.maps_manager.mode, + n_classes=self.maps_manager.n_classes, + metrics_module=self.maps_manager.metrics_module, + network_task=self.maps_manager.network_task, + model=model, + dataloader=valid_loader, + criterion=criterion, + alpha=alpha, target=True, ) @@ -1444,7 +1468,7 @@ def _init_callbacks(self) -> None: def _init_optimizer( self, model: DDP, - split: int = None, + split: Optional[int] = None, resume: bool = False, ) -> torch.optim.Optimizer: """ @@ -1497,7 +1521,8 @@ def _init_profiler(self) -> torch.profiler.profile: Profiler context manager. """ if self.config.optimization.profiler: - from clinicadl.utils.maps_manager.cluster.profiler import ( + # TODO: no more profiler ???? + from clinicadl.utils.cluster.profiler import ( ProfilerActivity, profile, schedule, @@ -1543,7 +1568,7 @@ def _write_weights( state: Dict[str, Any], metrics_dict: Optional[Dict[str, bool]], split: int, - network: int = None, + network: Optional[int] = None, filename: str = "checkpoint.pth.tar", save_all_models: bool = False, ) -> None: