From 50b0944cd369c3a7d0da5836abe85d8cadd30d6f Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Thu, 19 Dec 2024 01:40:45 +0100 Subject: [PATCH] Add preprocessing and postprocessing to forward method --- tiktorch/server/grpc/training_servicer.py | 1 - tiktorch/trainer.py | 54 +++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/tiktorch/server/grpc/training_servicer.py b/tiktorch/server/grpc/training_servicer.py index c262e76b..c44a0ad0 100644 --- a/tiktorch/server/grpc/training_servicer.py +++ b/tiktorch/server/grpc/training_servicer.py @@ -78,7 +78,6 @@ def Export(self, request: training_pb2.ExportRequest, context): session.client.export(Path(request.filePath)) return utils_pb2.Empty() - def Predict(self, request: training_pb2.PredictRequest, context): session = self._getTrainerSession(context, request.sessionId.id) tensors = [torch.tensor(pb_tensor_to_numpy(pb_tensor)) for pb_tensor in request.tensors] diff --git a/tiktorch/trainer.py b/tiktorch/trainer.py index 9f4a5ea2..e7afee51 100644 --- a/tiktorch/trainer.py +++ b/tiktorch/trainer.py @@ -8,6 +8,7 @@ import torch import yaml +from pytorch3dunet.augment.transforms import Compose, Normalize, Standardize, ToTensor from pytorch3dunet.datasets.utils import get_train_loaders from pytorch3dunet.unet3d.losses import get_loss_criterion from pytorch3dunet.unet3d.metrics import get_evaluation_metric @@ -97,6 +98,8 @@ def __init__( self, model, device, + in_channels, + out_channels, optimizer, lr_scheduler, loss_criterion, @@ -139,6 +142,8 @@ def __init__( pre_trained=pre_trained, **kwargs, ) + self._in_channels = in_channels + self._out_channels = out_channels self._device = device self.logs_callbacks: LogsCallbacks = BaseCallbacks() self.should_stop_callbacks: Callbacks = ShouldStopCallbacks() @@ -170,16 +175,60 @@ def forward(self, input_tensors: List[torch.Tensor]): if self.is_2d_model() and z != 1: raise ValueError(f"2d model detected but z != 1 for tensor {input_tensor.shape}") + # todo: normalization need to be consistent with the training one (it should be retrieved by the config) + preprocessor = Compose([Standardize(), ToTensor(expand_dims=True)]) + input_tensor = self._apply_transformation(compose=preprocessor, tensor=input_tensor) + + def apply_final_activation(input_tensors) -> torch.Tensor: + if self.model.final_activation is not None: + return self.model.final_activation(input_tensors) + return input_tensors + with torch.no_grad(): if self.is_2d_model(): input_tensor = input_tensor.squeeze(dim=-3) # b, c, [z], y, x predictions = self.model(input_tensor.to(self._device)) + predictions = apply_final_activation(predictions) predictions = predictions.unsqueeze(dim=-3) # for consistency else: predictions = self.model(input_tensor.to(self._device)) + predictions = apply_final_activation(predictions) + predictions = predictions.cpu() + + # this needs to be exposed as well + # currently we scale the features from 0 - 1 (consistent scale for rendering across channels) + postprocessor = Compose([Normalize(norm01=True), ToTensor(expand_dims=True)]) + predictions = self._apply_transformation(compose=postprocessor, tensor=predictions) return predictions + def _apply_transformation(self, compose: Compose, tensor: torch.Tensor) -> torch.Tensor: + """ + To apply transformations pytorch 3d unet requires shape of DxHxW or CxDxHxW + """ + b, c, z, y, x = tensor.shape + non_batch_tensors = [] + for batch_idx in range(b): + # drop batch + non_batch_tensor = tensor[batch_idx, :] + + # drop channel dim if single channel + dropped_channel = non_batch_tensor.squeeze(dim=-4) if self.is_input_single_channel() else non_batch_tensor + + # adds channel back with the`expand_dims` + transformed_tensor = compose(dropped_channel.detach().cpu().numpy()) + + non_batch_tensors.append(transformed_tensor) + + # add batch dim again + return torch.stack(non_batch_tensors, dim=0) + + def is_input_single_channel(self) -> bool: + return self._in_channels == 1 + + def is_output_single_channel(self) -> bool: + return self._out_channels == 1 + @staticmethod def get_axes_from_tensor(tensor: torch.Tensor) -> Tuple[str, ...]: if tensor.ndim != 5: @@ -238,6 +287,9 @@ def parse(self) -> Trainer: model = get_model(config["model"]) + in_channels = config["model"]["in_channels"] + out_channels = config["model"]["out_channels"] + if torch.cuda.device_count() > 1 and not config["device"] == "cpu": model = nn.DataParallel(model) if torch.cuda.is_available() and not config["device"] == "cpu": @@ -266,6 +318,8 @@ def parse(self) -> Trainer: return Trainer( device=config["device"], + in_channels=in_channels, + out_channels=out_channels, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,