Skip to content

Commit

Permalink
Add forward action to the training service
Browse files Browse the repository at this point in the history
  • Loading branch information
thodkatz committed Dec 20, 2024
1 parent 27b3923 commit 39672ca
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 14 deletions.
26 changes: 25 additions & 1 deletion tests/test_server/test_grpc/test_training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import h5py
import numpy as np
import pytest
import xarray as xr

from tiktorch.converters import pb_state_to_trainer, trainer_state_to_pb
from tiktorch.converters import pb_state_to_trainer, pb_tensor_to_xarray, trainer_state_to_pb, xarray_to_pb_tensor
from tiktorch.proto import training_pb2, training_pb2_grpc, utils_pb2
from tiktorch.server.device_pool import TorchDevicePool
from tiktorch.server.grpc import training_servicer
Expand Down Expand Up @@ -473,6 +474,29 @@ def test_close_trainer_session_twice(self, grpc_stub):
grpc_stub.CloseTrainerSession(training_session_id)
assert "Unknown session" in excinfo.value.details()

def test_forward(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)

grpc_stub.Start(training_session_id)

batch = 5
in_channels_unet2d = 3
out_channels_unet2d = 2
shape = (batch, in_channels_unet2d, 1, 128, 128)
data = np.random.rand(*shape).astype(np.float32)
xarray_data = xr.DataArray(data, dims=("b", "c", "z", "y", "x"))
pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data)
predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor])

response = grpc_stub.Predict(predict_request)

predicted_tensors = [pb_tensor_to_xarray(pb_tensor) for pb_tensor in response.tensors]
assert len(predicted_tensors) == 1
predicted_tensor = predicted_tensors[0]
assert predicted_tensor.dims == ("b", "c", "z", "y", "x")
assert predicted_tensor.shape == (batch, out_channels_unet2d, 1, 128, 128)

def test_close_session(self, grpc_stub):
"""
Test closing a training session.
Expand Down
23 changes: 18 additions & 5 deletions tiktorch/server/grpc/training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
from typing import Callable, List

import grpc
import torch

from tiktorch.converters import trainer_state_to_pb
from tiktorch.converters import pb_tensor_to_numpy, trainer_state_to_pb
from tiktorch.proto import training_pb2, training_pb2_grpc, utils_pb2
from tiktorch.server.device_pool import IDevicePool
from tiktorch.server.grpc.utils_servicer import list_devices
from tiktorch.server.session.process import start_trainer_process
from tiktorch.server.session.rpc_interface import IRPCTrainer
from tiktorch.server.session_manager import Session, SessionManager
from tiktorch.trainer import TrainerYamlParser
from tiktorch.trainer import Trainer, TrainerYamlParser

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -52,7 +53,7 @@ def Init(self, request: training_pb2.TrainingConfig, context):

return training_pb2.TrainingSessionId(id=session.id)

def Start(self, request, context):
def Start(self, request: training_pb2.TrainingSessionId, context):
session = self._getTrainerSession(context, request.id)
session.client.start_training()
return utils_pb2.Empty()
Expand All @@ -77,8 +78,20 @@ def Export(self, request: training_pb2.ExportRequest, context):
session.client.export(Path(request.filePath))
return utils_pb2.Empty()

def Predict(self, request: training_pb2.TrainingSessionId, context):
raise NotImplementedError

def Predict(self, request: training_pb2.PredictRequest, context):
session = self._getTrainerSession(context, request.sessionId.id)
tensors = [torch.tensor(pb_tensor_to_numpy(pb_tensor)) for pb_tensor in request.tensors]
assert len(tensors) == 1, "We support models with one input"
predictions = session.client.forward(tensors).result()
return training_pb2.PredictResponse(tensors=[self._tensor_to_pb(predictions)])

def _tensor_to_pb(self, tensor: torch.Tensor):
dims = Trainer.get_axes_from_tensor(tensor)
shape = [utils_pb2.NamedInt(size=dim, name=i) for i, dim in zip(dims, tensor.shape)]
np_array = tensor.numpy()
proto_tensor = utils_pb2.Tensor(tensorId="", dtype=str(np_array.dtype), shape=shape, buffer=np_array.tobytes())
return proto_tensor

def StreamUpdates(self, request: training_pb2.TrainingSessionId, context):
raise NotImplementedError
Expand Down
4 changes: 3 additions & 1 deletion tiktorch/server/session/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import logging
from abc import ABC
from concurrent.futures import Future
from typing import List

import torch
from bioimageio.core import PredictionPipeline

from tiktorch.configkeys import TRAINING, VALIDATION
Expand Down Expand Up @@ -61,7 +63,7 @@ def __init__(self, trainer: Trainer):
supervisor = TrainerSupervisor(trainer)
super().__init__(supervisor)

def forward(self, input_tensors):
def forward(self, input_tensors: List[torch.Tensor]):
res = Future()
self._queue_tasks.send_command(commands.ForwardPass(res, input_tensors))
return res
Expand Down
3 changes: 2 additions & 1 deletion tiktorch/server/session/backend/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ def shutdown(self):

def forward(self, input_tensors):
self.pause()
self._trainer.forward(input_tensors)
res = self._trainer.forward(input_tensors)
self.resume()
return res

def save(self):
raise NotImplementedError
Expand Down
3 changes: 2 additions & 1 deletion tiktorch/server/session/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from multiprocessing.connection import Connection
from typing import List, Optional, Tuple, Type, TypeVar, Union

import torch
from bioimageio.core import PredictionPipeline, Tensor, create_prediction_pipeline
from bioimageio.spec import InvalidDescr, load_description
from bioimageio.spec.model import v0_5
Expand Down Expand Up @@ -125,7 +126,7 @@ def init(self, trainer_yaml_config: str):
trainer = parser.parse()
self._worker = base.TrainerSessionBackend(trainer)

def forward(self, input_tensors) -> Future:
def forward(self, input_tensors: List[torch.Tensor]) -> Future:
res = self.worker.forward(input_tensors)
return res

Expand Down
4 changes: 3 additions & 1 deletion tiktorch/server/session/rpc_interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List

import torch

from tiktorch.converters import Sample
from tiktorch.rpc import RPCInterface, exposed
from tiktorch.rpc.exceptions import Shutdown
Expand Down Expand Up @@ -56,7 +58,7 @@ def init(self, trainer_yaml_config: str):
raise NotImplementedError

@exposed
def forward(self, input_tensors: Sample):
def forward(self, input_tensors: List[torch.Tensor]):
raise NotImplementedError

@exposed
Expand Down
46 changes: 42 additions & 4 deletions tiktorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Generic, List, TypeVar
from typing import Any, Callable, Generic, List, Tuple, TypeVar

import torch
import yaml
from pytorch3dunet.datasets.utils import get_train_loaders
from pytorch3dunet.unet3d.losses import get_loss_criterion
from pytorch3dunet.unet3d.metrics import get_evaluation_metric
from pytorch3dunet.unet3d.model import get_model
from pytorch3dunet.unet3d.model import ResidualUNet2D, ResidualUNet3D, ResidualUNetSE3D, UNet2D, UNet3D, get_model
from pytorch3dunet.unet3d.trainer import UNetTrainer
from pytorch3dunet.unet3d.utils import create_lr_scheduler, create_optimizer, get_tensorboard_formatter
from torch import nn
Expand Down Expand Up @@ -96,6 +96,7 @@ class Trainer(UNetTrainer):
def __init__(
self,
model,
device,
optimizer,
lr_scheduler,
loss_criterion,
Expand Down Expand Up @@ -138,6 +139,7 @@ def __init__(
pre_trained=pre_trained,
**kwargs,
)
self._device = device
self.logs_callbacks: LogsCallbacks = BaseCallbacks()
self.should_stop_callbacks: Callbacks = ShouldStopCallbacks()

Expand All @@ -150,10 +152,45 @@ def train(self):
def validate(self):
return super().validate()

def forward(self, input_tensors):
def forward(self, input_tensors: List[torch.Tensor]):
"""
Note:
"The 2D U-Net itself uses the standard 2D convolutional
layers instead of 3D convolutions with kernel size (1, 3, 3) for performance reasons."
source: https://github.com/wolny/pytorch-3dunet
Thus, we drop the z dimension if we have 2d model.
But the input h5 data needs to respect CxDxHxW or DxHxW.
"""
assert len(input_tensors) == 1, "We support models with 1 input"
input_tensor = input_tensors[0]
self.get_axes_from_tensor(input_tensor)
self.model.eval()
b, c, z, y, x = input_tensor.shape
if self.is_2d_model() and z != 1:
raise ValueError(f"2d model detected but z != 1 for tensor {input_tensor.shape}")

with torch.no_grad():
self.model(input_tensors)
if self.is_2d_model():
input_tensor = input_tensor.squeeze(dim=-3) # b, c, [z], y, x
predictions = self.model(input_tensor.to(self._device))
predictions = predictions.unsqueeze(dim=-3) # for consistency
else:
predictions = self.model(input_tensor.to(self._device))

return predictions

@staticmethod
def get_axes_from_tensor(tensor: torch.Tensor) -> Tuple[str, ...]:
if tensor.ndim != 5:
raise ValueError(f"Tensor dims should be 5 (b, c, z, y, x) but got {tensor.ndim} dimensions")
return ("b", "c", "z", "y", "x")

def is_3d_model(self):
return isinstance(self.model, (ResidualUNetSE3D, ResidualUNet3D, UNet3D))

def is_2d_model(self):
return isinstance(self.model, (ResidualUNet2D, UNet2D))

def should_stop(self) -> bool:
"""
Expand Down Expand Up @@ -228,6 +265,7 @@ def parse(self) -> Trainer:
pre_trained = trainer_config.pop("pre_trained", None)

return Trainer(
device=config["device"],
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
Expand Down

0 comments on commit 39672ca

Please sign in to comment.