Skip to content

Commit

Permalink
some,minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Sep 3, 2024
1 parent bb61475 commit 552105f
Showing 1 changed file with 73 additions and 48 deletions.
121 changes: 73 additions & 48 deletions clinicadl/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ def _train(
train_loader: DataLoader,
valid_loader: DataLoader,
split: int,
network: int = None,
network: Optional[int] = None,
resume: bool = False,
callbacks: List[Callback] = [],
):
Expand Down Expand Up @@ -1168,23 +1168,29 @@ def _train_ssdann(
_,
metrics_train_target,
) = test_da(
self.maps_manager.network_task,
model,
train_target_loader,
criterion,
alpha,
mode=self.maps_manager.mode,
n_classes=self.maps_manager.n_classes,
metrics_module=self.maps_manager.metrics_module,
network_task=self.maps_manager.network_task,
model=model,
dataloader=train_target_loader,
criterion=criterion,
alpha=alpha,
target=True,
) # TO CHECK

(
_,
metrics_valid_target,
) = test_da(
self.maps_manager.network_task,
model,
valid_loader,
criterion,
alpha,
mode=self.maps_manager.mode,
n_classes=self.maps_manager.n_classes,
metrics_module=self.maps_manager.metrics_module,
network_task=self.maps_manager.network_task,
model=model,
dataloader=valid_loader,
criterion=criterion,
alpha=alpha,
target=True,
)

Expand Down Expand Up @@ -1214,21 +1220,27 @@ def _train_ssdann(
_,
metrics_train_source,
) = test_da(
self.maps_manager.network_task,
model,
train_source_loader,
criterion,
alpha,
mode=self.maps_manager.mode,
n_classes=self.maps_manager.n_classes,
metrics_module=self.maps_manager.metrics_module,
network_task=self.maps_manager.network_task,
model=model,
dataloader=train_source_loader,
criterion=criterion,
alpha=alpha,
)
(
_,
metrics_valid_source,
) = test_da(
self.maps_manager.network_task,
model,
valid_source_loader,
criterion,
alpha,
mode=self.maps_manager.mode,
n_classes=self.maps_manager.n_classes,
metrics_module=self.maps_manager.metrics_module,
network_task=self.maps_manager.network_task,
model=model,
dataloader=valid_source_loader,
criterion=criterion,
alpha=alpha,
)

model.train()
Expand Down Expand Up @@ -1277,22 +1289,28 @@ def _train_ssdann(
f"Evaluate source data at the end of the epoch {epoch} with alpha: {alpha}."
)
_, metrics_train_source = test_da(
self.maps_manager.network_task,
model,
train_source_loader,
criterion,
alpha,
True,
False,
mode=self.maps_manager.mode,
n_classes=self.maps_manager.n_classes,
metrics_module=self.maps_manager.metrics_module,
network_task=self.maps_manager.network_task,
model=model,
dataloader=train_source_loader,
criterion=criterion,
alpha=alpha,
target=True,
report_ci=False,
)
_, metrics_valid_source = test_da(
self.maps_manager.network_task,
model,
valid_source_loader,
criterion,
alpha,
True,
False,
mode=self.maps_manager.mode,
n_classes=self.maps_manager.n_classes,
metrics_module=self.maps_manager.metrics_module,
network_task=self.maps_manager.network_task,
model=model,
dataloader=valid_source_loader,
criterion=criterion,
alpha=alpha,
target=True,
report_ci=False,
)

log_writer.step(
Expand All @@ -1313,19 +1331,25 @@ def _train_ssdann(
)

_, metrics_train_target = test_da(
self.maps_manager.network_task,
model,
train_target_loader,
criterion,
alpha,
mode=self.maps_manager.mode,
n_classes=self.maps_manager.n_classes,
metrics_module=self.maps_manager.metrics_module,
network_task=self.maps_manager.network_task,
model=model,
dataloader=train_target_loader,
criterion=criterion,
alpha=alpha,
target=True,
)
_, metrics_valid_target = test_da(
self.maps_manager.network_task,
model,
valid_loader,
criterion,
alpha,
mode=self.maps_manager.mode,
n_classes=self.maps_manager.n_classes,
metrics_module=self.maps_manager.metrics_module,
network_task=self.maps_manager.network_task,
model=model,
dataloader=valid_loader,
criterion=criterion,
alpha=alpha,
target=True,
)

Expand Down Expand Up @@ -1444,7 +1468,7 @@ def _init_callbacks(self) -> None:
def _init_optimizer(
self,
model: DDP,
split: int = None,
split: Optional[int] = None,
resume: bool = False,
) -> torch.optim.Optimizer:
"""
Expand Down Expand Up @@ -1497,7 +1521,8 @@ def _init_profiler(self) -> torch.profiler.profile:
Profiler context manager.
"""
if self.config.optimization.profiler:
from clinicadl.utils.maps_manager.cluster.profiler import (
# TODO: no more profiler ????
from clinicadl.utils.cluster.profiler import (
ProfilerActivity,
profile,
schedule,
Expand Down Expand Up @@ -1543,7 +1568,7 @@ def _write_weights(
state: Dict[str, Any],
metrics_dict: Optional[Dict[str, bool]],
split: int,
network: int = None,
network: Optional[int] = None,
filename: str = "checkpoint.pth.tar",
save_all_models: bool = False,
) -> None:
Expand Down

0 comments on commit 552105f

Please sign in to comment.