Add preprocessing and postprocessing to forward method
thodkatz committed Dec 20, 2024
1 parent b6d63c7 commit 50b0944
Showing 2 changed files with 54 additions and 1 deletion.
1 change: 0 additions & 1 deletion tiktorch/server/grpc/training_servicer.py
@@ -78,7 +78,6 @@ def Export(self, request: training_pb2.ExportRequest, context):
        session.client.export(Path(request.filePath))
        return utils_pb2.Empty()

    def Predict(self, request: training_pb2.PredictRequest, context):
        session = self._getTrainerSession(context, request.sessionId.id)
        tensors = [torch.tensor(pb_tensor_to_numpy(pb_tensor)) for pb_tensor in request.tensors]
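The forward pass shown below expects each tensor in a 5D (b, c, z, y, x) layout. A minimal sketch of building such an input the way the servicer does, with a purely illustrative shape:

import numpy as np
import torch

# hypothetical single-sample, single-channel volume in (b, c, z, y, x) layout
arr = np.random.rand(1, 1, 8, 64, 64).astype(np.float32)
tensor = torch.tensor(arr)  # mirrors torch.tensor(pb_tensor_to_numpy(pb_tensor)) above
assert tensor.ndim == 5  # the trainer validates that tensors are 5D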
54 changes: 54 additions & 0 deletions tiktorch/trainer.py
@@ -8,6 +8,7 @@

import torch
import yaml
from pytorch3dunet.augment.transforms import Compose, Normalize, Standardize, ToTensor
from pytorch3dunet.datasets.utils import get_train_loaders
from pytorch3dunet.unet3d.losses import get_loss_criterion
from pytorch3dunet.unet3d.metrics import get_evaluation_metric
@@ -97,6 +98,8 @@ def __init__(
        self,
        model,
        device,
        in_channels,
        out_channels,
        optimizer,
        lr_scheduler,
        loss_criterion,
@@ -139,6 +142,8 @@ def __init__(
            pre_trained=pre_trained,
            **kwargs,
        )
        self._in_channels = in_channels
        self._out_channels = out_channels
        self._device = device
        self.logs_callbacks: LogsCallbacks = BaseCallbacks()
        self.should_stop_callbacks: Callbacks = ShouldStopCallbacks()

@@ -170,16 +175,60 @@ def forward(self, input_tensors: List[torch.Tensor]):
        if self.is_2d_model() and z != 1:
            raise ValueError(f"2d model detected but z != 1 for tensor {input_tensor.shape}")

        # todo: normalization needs to be consistent with the one used during training (it should be retrieved from the config)
        preprocessor = Compose([Standardize(), ToTensor(expand_dims=True)])
        input_tensor = self._apply_transformation(compose=preprocessor, tensor=input_tensor)

        def apply_final_activation(input_tensors) -> torch.Tensor:
            if self.model.final_activation is not None:
                return self.model.final_activation(input_tensors)
            return input_tensors

        with torch.no_grad():
            if self.is_2d_model():
                input_tensor = input_tensor.squeeze(dim=-3)  # b, c, [z], y, x
                predictions = self.model(input_tensor.to(self._device))
                predictions = apply_final_activation(predictions)
                predictions = predictions.unsqueeze(dim=-3)  # for consistency
            else:
                predictions = self.model(input_tensor.to(self._device))
                predictions = apply_final_activation(predictions)

        predictions = predictions.cpu()

        # this needs to be exposed as well
        # currently we scale the features to [0, 1] (a consistent scale for rendering across channels)
        postprocessor = Compose([Normalize(norm01=True), ToTensor(expand_dims=True)])
        predictions = self._apply_transformation(compose=postprocessor, tensor=predictions)
        return predictions
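The 2d-model branch above squeezes the singleton z axis before the model call and restores it afterwards, so callers always get 5D output back. A self-contained sketch of that shape round-trip (shapes are illustrative):

import torch

x = torch.rand(1, 1, 1, 64, 64)        # b, c, z=1, y, x
squeezed = x.squeeze(dim=-3)           # (1, 1, 64, 64): what a 2d model consumes
restored = squeezed.unsqueeze(dim=-3)  # (1, 1, 1, 64, 64): back to b, c, z, y, x
assert restored.shape == x.shape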

    def _apply_transformation(self, compose: Compose, tensor: torch.Tensor) -> torch.Tensor:
        """
        To apply transformations, pytorch-3dunet requires tensors of shape DxHxW or CxDxHxW.
        """
        b, c, z, y, x = tensor.shape
        non_batch_tensors = []
        for batch_idx in range(b):
            # drop batch
            non_batch_tensor = tensor[batch_idx, :]

            # drop channel dim if single channel
            dropped_channel = non_batch_tensor.squeeze(dim=-4) if self.is_input_single_channel() else non_batch_tensor

            # adds the channel dim back via `expand_dims`
            transformed_tensor = compose(dropped_channel.detach().cpu().numpy())

            non_batch_tensors.append(transformed_tensor)

        # add batch dim again
        return torch.stack(non_batch_tensors, dim=0)
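As a standalone illustration of the loop above: once the batch and (single) channel dims are dropped, a z, y, x volume goes through the Compose pipeline, and ToTensor(expand_dims=True) re-adds the channel axis. This sketch reuses the exact transform calls from the diff and assumes pytorch-3dunet's expand_dims behavior of prepending a channel axis to 3D input (consistent with the "adds the channel dim back" comment above); the shape is illustrative:

import numpy as np
from pytorch3dunet.augment.transforms import Compose, Standardize, ToTensor

compose = Compose([Standardize(), ToTensor(expand_dims=True)])
volume = np.random.rand(8, 64, 64).astype(np.float32)  # z, y, x after the squeeze
out = compose(volume)  # torch tensor of shape (1, 8, 64, 64): c, z, y, x
assert out.shape == (1, 8, 64, 64)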

    def is_input_single_channel(self) -> bool:
        return self._in_channels == 1

    def is_output_single_channel(self) -> bool:
        return self._out_channels == 1

    @staticmethod
    def get_axes_from_tensor(tensor: torch.Tensor) -> Tuple[str, ...]:
        if tensor.ndim != 5:
@@ -238,6 +287,9 @@ def parse(self) -> Trainer:

        model = get_model(config["model"])

        in_channels = config["model"]["in_channels"]
        out_channels = config["model"]["out_channels"]

        if torch.cuda.device_count() > 1 and not config["device"] == "cpu":
            model = nn.DataParallel(model)
        if torch.cuda.is_available() and not config["device"] == "cpu":
@@ -266,6 +318,8 @@ def parse(self) -> Trainer:

        return Trainer(
            device=config["device"],
            in_channels=in_channels,
            out_channels=out_channels,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
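For context, parse() pulls in_channels and out_channels straight out of the model section of the YAML config. A hypothetical minimal config covering only the keys accessed above (real pytorch-3dunet configs carry many more fields):

config = {
    "device": "cpu",
    "model": {
        "name": "UNet2D",  # illustrative; any model get_model() accepts
        "in_channels": 1,
        "out_channels": 2,
    },
}
in_channels = config["model"]["in_channels"]    # -> 1
out_channels = config["model"]["out_channels"]  # -> 2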
