From 32cd26d2a74cbc61c88f6ecf2d5b2d2a7b996cb0 Mon Sep 17 00:00:00 2001
From: Theodoros Katzalis
Date: Thu, 12 Dec 2024 00:58:24 +0100
Subject: [PATCH] Add forward action to the training service

---
 .../test_grpc/test_training_servicer.py       | 26 ++++++++++-
 tiktorch/server/grpc/training_servicer.py     | 28 +++++++----
 tiktorch/server/session/backend/base.py       |  4 +-
 tiktorch/server/session/backend/supervisor.py |  3 +-
 tiktorch/server/session/process.py            |  3 +-
 tiktorch/server/session/rpc_interface.py      |  4 +-
 tiktorch/trainer.py                           | 46 +++++++++++++++++--
 7 files changed, 97 insertions(+), 17 deletions(-)

diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py
index f40a8fcc..93e05ba2 100644
--- a/tests/test_server/test_grpc/test_training_servicer.py
+++ b/tests/test_server/test_grpc/test_training_servicer.py
@@ -8,8 +8,9 @@
 import h5py
 import numpy as np
 import pytest
+import xarray as xr
 
-from tiktorch.converters import pb_state_to_trainer, trainer_state_to_pb
+from tiktorch.converters import pb_state_to_trainer, pb_tensor_to_xarray, trainer_state_to_pb, xarray_to_pb_tensor
 from tiktorch.proto import training_pb2, training_pb2_grpc
 from tiktorch.server.device_pool import TorchDevicePool
 from tiktorch.server.grpc import training_servicer
@@ -473,6 +474,29 @@ def test_close_trainer_session_twice(self, grpc_stub):
             grpc_stub.CloseTrainerSession(training_session_id)
         assert "Unknown session" in excinfo.value.details()
 
+    def test_forward(self, grpc_stub):
+        init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
+        training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
+
+        grpc_stub.Start(training_session_id)
+
+        batch = 5
+        in_channels_unet2d = 3
+        out_channels_unet2d = 2
+        shape = (batch, in_channels_unet2d, 1, 128, 128)
+        data = np.random.rand(*shape).astype(np.float32)
+        xarray_data = xr.DataArray(data, dims=("b", "c", "z", "y", "x"))
+        pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data)
+        predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor])
+
+        response = grpc_stub.Predict(predict_request)
+
+        predicted_tensors = [pb_tensor_to_xarray(pb_tensor) for pb_tensor in response.tensors]
+        assert len(predicted_tensors) == 1
+        predicted_tensor = predicted_tensors[0]
+        assert predicted_tensor.dims == ("b", "c", "z", "y", "x")
+        assert predicted_tensor.shape == (batch, out_channels_unet2d, 1, 128, 128)
+
     def test_close_session(self, grpc_stub):
         """
         Test closing a training session.
diff --git a/tiktorch/server/grpc/training_servicer.py b/tiktorch/server/grpc/training_servicer.py
index 4dac6b18..86ea1726 100644
--- a/tiktorch/server/grpc/training_servicer.py
+++ b/tiktorch/server/grpc/training_servicer.py
@@ -5,14 +5,15 @@
 from typing import Callable, List
 
 import grpc
+import torch
 
-from tiktorch.converters import trainer_state_to_pb
-from tiktorch.proto import training_pb2, training_pb2_grpc
+from tiktorch.converters import pb_tensor_to_numpy, trainer_state_to_pb
+from tiktorch.proto import training_pb2, training_pb2_grpc, utils_pb2
 from tiktorch.server.device_pool import IDevicePool
 from tiktorch.server.session.process import start_trainer_process
 from tiktorch.server.session.rpc_interface import IRPCTrainer
 from tiktorch.server.session_manager import Session, SessionManager
-from tiktorch.trainer import TrainerYamlParser
+from tiktorch.trainer import Trainer, TrainerYamlParser
 
 logger = logging.getLogger(__name__)
 
@@ -47,7 +48,7 @@ def Init(self, request: training_pb2.TrainingConfig, context):
 
         return training_pb2.TrainingSessionId(id=session.id)
 
-    def Start(self, request, context):
+    def Start(self, request: training_pb2.TrainingSessionId, context):
         session = self._getTrainerSession(context, request.id)
         session.client.start_training()
         return training_pb2.Empty()
@@ -63,17 +64,28 @@ def Pause(self, request: training_pb2.TrainingSessionId, context):
         return training_pb2.Empty()
 
     def Save(self, request: training_pb2.TrainingSessionId, context):
-        session = self._getTrainerSession(context, request.modelSessionId)
+        session = self._getTrainerSession(context, request.id)
         session.client.save()
         return training_pb2.Empty()
 
     def Export(self, request: training_pb2.TrainingSessionId, context):
-        session = self._getTrainerSession(context, request.modelSessionId)
+        session = self._getTrainerSession(context, request.id)
         session.client.export()
         return training_pb2.Empty()
 
-    def Predict(self, request: training_pb2.TrainingSessionId, context):
-        raise NotImplementedError
+    def Predict(self, request: training_pb2.PredictRequest, context):
+        session = self._getTrainerSession(context, request.sessionId.id)
+        tensors = [torch.tensor(pb_tensor_to_numpy(pb_tensor)) for pb_tensor in request.tensors]
+        assert len(tensors) == 1, "We support models with one input"
+        predictions = session.client.forward(tensors).result()
+        return training_pb2.PredictResponse(tensors=[self._tensor_to_pb(predictions)])
+
+    def _tensor_to_pb(self, tensor: torch.Tensor):
+        dims = Trainer.get_axes_from_tensor(tensor)
+        shape = [utils_pb2.NamedInt(size=size, name=name) for name, size in zip(dims, tensor.shape)]
+        np_array = tensor.detach().cpu().numpy()
+        proto_tensor = utils_pb2.Tensor(tensorId="", dtype=str(np_array.dtype), shape=shape, buffer=np_array.tobytes())
+        return proto_tensor
 
     def StreamUpdates(self, request: training_pb2.TrainingSessionId, context):
         raise NotImplementedError
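For reference, the `utils_pb2.Tensor` message that `_tensor_to_pb` builds above can be decoded back into a numpy array plus its axis names from exactly the fields it sets (`dtype`, `shape` as `NamedInt(size, name)`, `buffer`). This is a minimal client-side sketch, not part of this patch; the helper name `pb_tensor_to_named_numpy` is hypothetical:

import numpy as np

def pb_tensor_to_named_numpy(pb_tensor):
    # Sketch only (hypothetical helper): invert _tensor_to_pb using the fields it sets.
    axes = tuple(named_int.name for named_int in pb_tensor.shape)
    sizes = tuple(named_int.size for named_int in pb_tensor.shape)
    array = np.frombuffer(pb_tensor.buffer, dtype=np.dtype(pb_tensor.dtype)).reshape(sizes)
    return array, axes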
diff --git a/tiktorch/server/session/backend/base.py b/tiktorch/server/session/backend/base.py
index 471c2d71..091c64c0 100644
--- a/tiktorch/server/session/backend/base.py
+++ b/tiktorch/server/session/backend/base.py
@@ -3,7 +3,9 @@
 import logging
 from abc import ABC
 from concurrent.futures import Future
+from typing import List
 
+import torch
 from bioimageio.core import PredictionPipeline
 
 from tiktorch.configkeys import TRAINING, VALIDATION
@@ -61,7 +63,7 @@ def __init__(self, trainer: Trainer):
         supervisor = TrainerSupervisor(trainer)
         super().__init__(supervisor)
 
-    def forward(self, input_tensors):
+    def forward(self, input_tensors: List[torch.Tensor]):
         res = Future()
         self._queue_tasks.send_command(commands.ForwardPass(res, input_tensors))
         return res
diff --git a/tiktorch/server/session/backend/supervisor.py b/tiktorch/server/session/backend/supervisor.py
index fa04e599..a49ae19a 100644
--- a/tiktorch/server/session/backend/supervisor.py
+++ b/tiktorch/server/session/backend/supervisor.py
@@ -127,8 +127,9 @@ def shutdown(self):
 
     def forward(self, input_tensors):
         self.pause()
-        self._trainer.forward(input_tensors)
+        res = self._trainer.forward(input_tensors)
         self.resume()
+        return res
 
     def save(self):
         raise NotImplementedError
diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py
index 3a1b949e..81aa6b82 100644
--- a/tiktorch/server/session/process.py
+++ b/tiktorch/server/session/process.py
@@ -7,6 +7,7 @@
 from multiprocessing.connection import Connection
 from typing import List, Optional, Tuple, Type, TypeVar, Union
 
+import torch
 from bioimageio.core import PredictionPipeline, Tensor, create_prediction_pipeline
 from bioimageio.spec import InvalidDescr, load_description
 from bioimageio.spec.model import v0_5
@@ -125,7 +126,7 @@ def init(self, trainer_yaml_config: str):
         trainer = parser.parse()
         self._worker = base.TrainerSessionBackend(trainer)
 
-    def forward(self, input_tensors) -> Future:
+    def forward(self, input_tensors: List[torch.Tensor]) -> Future:
         res = self.worker.forward(input_tensors)
         return res
 
diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py
index db714cb5..36146b1e 100644
--- a/tiktorch/server/session/rpc_interface.py
+++ b/tiktorch/server/session/rpc_interface.py
@@ -1,5 +1,7 @@
 from typing import List
 
+import torch
+
 from tiktorch.converters import Sample
 from tiktorch.rpc import RPCInterface, exposed
 from tiktorch.rpc.exceptions import Shutdown
@@ -56,7 +58,7 @@ def init(self, trainer_yaml_config: str):
         raise NotImplementedError
 
     @exposed
-    def forward(self, input_tensors: Sample):
+    def forward(self, input_tensors: List[torch.Tensor]):
        raise NotImplementedError
 
     @exposed
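Taken together, the changes above make the forward path Future-based end to end: the servicer calls `forward()` on the session client, the backend enqueues a `commands.ForwardPass`, and the supervisor pauses training, runs the model, resumes, and fulfils the `Future`. A minimal caller-side sketch, not part of this patch; `client`, `run_forward`, and the timeout value are illustrative assumptions:

import torch

def run_forward(client, input_tensor: torch.Tensor, timeout_s: float = 60.0) -> torch.Tensor:
    # Sketch only: client.forward() returns a concurrent.futures.Future, as used in
    # TrainingServicer.Predict above; result() blocks until the supervisor has
    # paused training, run inference, and resumed.
    future = client.forward([input_tensor])
    return future.result(timeout=timeout_s)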
diff --git a/tiktorch/trainer.py b/tiktorch/trainer.py
index 79b371f9..9f4a5ea2 100644
--- a/tiktorch/trainer.py
+++ b/tiktorch/trainer.py
@@ -4,14 +4,14 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Callable, Generic, List, TypeVar
+from typing import Any, Callable, Generic, List, Tuple, TypeVar
 
 import torch
 import yaml
 from pytorch3dunet.datasets.utils import get_train_loaders
 from pytorch3dunet.unet3d.losses import get_loss_criterion
 from pytorch3dunet.unet3d.metrics import get_evaluation_metric
-from pytorch3dunet.unet3d.model import get_model
+from pytorch3dunet.unet3d.model import ResidualUNet2D, ResidualUNet3D, ResidualUNetSE3D, UNet2D, UNet3D, get_model
 from pytorch3dunet.unet3d.trainer import UNetTrainer
 from pytorch3dunet.unet3d.utils import create_lr_scheduler, create_optimizer, get_tensorboard_formatter
 from torch import nn
@@ -96,6 +96,7 @@ class Trainer(UNetTrainer):
     def __init__(
         self,
         model,
+        device,
         optimizer,
         lr_scheduler,
         loss_criterion,
@@ -138,6 +139,7 @@ def __init__(
             pre_trained=pre_trained,
             **kwargs,
         )
+        self._device = device
 
         self.logs_callbacks: LogsCallbacks = BaseCallbacks()
         self.should_stop_callbacks: Callbacks = ShouldStopCallbacks()
@@ -150,10 +152,45 @@ def train(self):
     def validate(self):
         return super().validate()
 
-    def forward(self, input_tensors):
+    def forward(self, input_tensors: List[torch.Tensor]):
+        """
+        Note:
+            "The 2D U-Net itself uses the standard 2D convolutional layers
+            instead of 3D convolutions with kernel size (1, 3, 3) for performance reasons."
+            source: https://github.com/wolny/pytorch-3dunet
+
+        Thus, we drop the z dimension if we have a 2D model.
+        The input h5 data, however, still needs to respect CxDxHxW or DxHxW.
+        """
+        assert len(input_tensors) == 1, "We support models with one input"
+        input_tensor = input_tensors[0]
+        self.get_axes_from_tensor(input_tensor)  # validates the expected (b, c, z, y, x) layout
         self.model.eval()
+        b, c, z, y, x = input_tensor.shape
+        if self.is_2d_model() and z != 1:
+            raise ValueError(f"2D model detected but z != 1 for tensor of shape {input_tensor.shape}")
+
         with torch.no_grad():
-            self.model(input_tensors)
+            if self.is_2d_model():
+                input_tensor = input_tensor.squeeze(dim=-3)  # b, c, [z], y, x
+                predictions = self.model(input_tensor.to(self._device))
+                predictions = predictions.unsqueeze(dim=-3)  # restore z for a consistent (b, c, z, y, x) output
+            else:
+                predictions = self.model(input_tensor.to(self._device))
+
+        return predictions
+
+    @staticmethod
+    def get_axes_from_tensor(tensor: torch.Tensor) -> Tuple[str, ...]:
+        if tensor.ndim != 5:
+            raise ValueError(f"Tensor dims should be 5 (b, c, z, y, x) but got {tensor.ndim} dimensions")
+        return ("b", "c", "z", "y", "x")
+
+    def is_3d_model(self) -> bool:
+        return isinstance(self.model, (ResidualUNetSE3D, ResidualUNet3D, UNet3D))
+
+    def is_2d_model(self) -> bool:
+        return isinstance(self.model, (ResidualUNet2D, UNet2D))
 
     def should_stop(self) -> bool:
         """
@@ -228,6 +265,7 @@ def parse(self) -> Trainer:
         pre_trained = trainer_config.pop("pre_trained", None)
 
         return Trainer(
+            device=config["device"],
             model=model,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
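The docstring above motivates the z-axis handling: 2D models expect (b, c, y, x) input, so the singleton z axis is squeezed away before inference and restored afterwards, keeping every Predict response in the (b, c, z, y, x) layout. A standalone sketch of that shape handling on a toy module, not part of this patch; `forward_2d_like` and the `Conv2d` stand-in are illustrative assumptions:

import torch

def forward_2d_like(model: torch.nn.Module, input_tensor: torch.Tensor) -> torch.Tensor:
    # Sketch only: mirrors the 2D branch of Trainer.forward above.
    b, c, z, y, x = input_tensor.shape
    assert z == 1, "2D path expects a singleton z axis"
    with torch.no_grad():
        squeezed = input_tensor.squeeze(dim=-3)  # (b, c, y, x)
        predictions = model(squeezed)
        return predictions.unsqueeze(dim=-3)  # back to (b, c, z=1, y, x)

# Example: a 3-channel -> 2-channel Conv2d stands in for the UNet2D.
toy_model = torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=1).eval()
out = forward_2d_like(toy_model, torch.rand(5, 3, 1, 128, 128))
assert out.shape == (5, 2, 1, 128, 128)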