
Commit

Apply suggestions from code review
Co-authored-by: Théo Tournier <[email protected]>
LBerth and tourniert authored Sep 9, 2024
1 parent 9c1be91 commit 15450d3
Showing 1 changed file with 8 additions and 8 deletions.
mfai/torch/segmentation_module.py (16 changes: 8 additions & 8 deletions)
@@ -91,7 +91,7 @@ def get_metrics(self):
         }
         return torch.nn.ModuleDict(metrics_dict), metrics_kwargs
 
-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor):
         """Runs data through the model. Separate from training step."""
         if self.channels_last:
             inputs = inputs.to(memory_format=torch.channels_last)
@@ -101,7 +101,7 @@ def forward(self, inputs):
             y_hat = self.last_activation(y_hat)
         return y_hat
 
-    def _shared_forward_step(self, x, y):
+    def _shared_forward_step(self, x: torch.Tensor, y: torch.Tensor):
         """Computes forward pass and loss for a batch.
         Step shared by training, validation and test steps"""
         if self.channels_last:
@@ -124,7 +124,7 @@ def on_train_start(self):
         hparams["model"] = self.model.__class__.__name__
         self.logger.log_hyperparams(hparams, {"val_loss": 0, "val_f1": 0})
 
-    def _shared_epoch_end(self, outputs, label):
+    def _shared_epoch_end(self, outputs: torch.Tensor, label: torch.Tensor):
         """Computes and logs the averaged loss at the end of an epoch on custom layout.
         Step shared by training and validation epochs.
         """
@@ -143,8 +143,8 @@ def on_train_epoch_end(self):
         self._shared_epoch_end(self.training_loss, "train")
         self.training_loss.clear() # free memory
 
-    def val_plot_step(self, batch_idx, y, y_hat):
-        """Plots images on first batch of validation and log them in tensorboard.
+    def val_plot_step(self, batch_idx: int, y: torch.Tensor, y_hat: torch.Tensor):
+        """Plots images on first batch of validation and log them in logger.
         Should be overwrited for each specific project, with matplotlib plots."""
         if batch_idx == 0:
             tb = self.logger.experiment
@@ -209,20 +209,20 @@ def save_test_metrics_as_csv(self, df: pd.DataFrame) -> None:
         print(f"--> Metrics for all samples saved in \033[91m\033[1m{path_csv}\033[0m")
 
     def on_test_epoch_end(self):
-        """Logs metrics in tensorboard hparams view, at the end of run."""
+        """Logs metrics in logger hparams view, at the end of run."""
         df = self.build_metrics_dataframe()
         self.save_test_metrics_as_csv(df)
         df = df.drop("Name", axis=1)
 
-    def last_activation(self, y_hat):
+    def last_activation(self, y_hat: torch.Tensor):
         """Applies appropriate activation according to task."""
         if self.type_segmentation == "multiclass":
             y_hat = y_hat.log_softmax(dim=1).exp()
         elif self.type_segmentation in ["binary", "multilabel"]:
             y_hat = torch.nn.functional.logsigmoid(y_hat).exp()
         return y_hat
 
-    def probabilities_to_classes(self, y_hat):
+    def probabilities_to_classes(self, y_hat: torch.Tensor):
         """Transfrom probalistics predictions to discrete classes"""
         if self.type_segmentation == "multiclass":
             y_hat = y_hat.argmax(dim=1)
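Note: the commit only adds torch.Tensor/int type hints to the method signatures and generalizes "tensorboard" to "logger" in the docstrings. For readers skimming the diff, here is a minimal standalone sketch of what the two activation helpers above compute; it is an illustration, not part of the commit, and assumes a hypothetical 4-class multiclass setting with plain tensors.

import torch

# Hypothetical multiclass example: (batch, classes, height, width) logits.
logits = torch.randn(2, 4, 8, 8)

# last_activation, "multiclass" branch: log_softmax followed by exp is a
# numerically stable softmax, so the class dimension sums to one.
probs = logits.log_softmax(dim=1).exp()
assert torch.allclose(probs.sum(dim=1), torch.ones(2, 8, 8))

# last_activation, "binary"/"multilabel" branch: logsigmoid followed by exp
# is a stable elementwise sigmoid.
probs_multilabel = torch.nn.functional.logsigmoid(logits).exp()

# probabilities_to_classes, "multiclass" branch: argmax over the class
# dimension yields integer class indices of shape (batch, height, width).
classes = probs.argmax(dim=1)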
