Add save and export to the training service
thodkatz committed Dec 14, 2024
1 parent 619ef5f commit 30eda0d
Showing 11 changed files with 343 additions and 64 deletions.
12 changes: 10 additions & 2 deletions proto/training.proto
@@ -21,9 +21,9 @@ service Training {

rpc GetLogs(TrainingSessionId) returns (GetLogsResponse) {}

rpc Save(TrainingSessionId) returns (Empty) {}
rpc Save(SaveRequest) returns (Empty) {}

rpc Export(TrainingSessionId) returns (Empty) {}
rpc Export(ExportRequest) returns (Empty) {}

rpc Predict(PredictRequest) returns (PredictResponse) {}

@@ -58,7 +58,15 @@ message GetLogsResponse {
repeated Logs logs = 1;
}

message SaveRequest {
TrainingSessionId sessionId = 1;
string filePath = 2;
}

message ExportRequest {
TrainingSessionId sessionId = 1;
string filePath = 2;
}

message PredictRequest {
repeated Tensor tensors = 1;
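For orientation, a minimal client-side sketch of the reworked RPCs (assumptions: a training server is already listening on the placeholder address, and the session id comes from an earlier Init call; the file paths are placeholders too):

```python
import grpc

from tiktorch.proto import training_pb2, training_pb2_grpc

channel = grpc.insecure_channel("127.0.0.1:5567")  # placeholder address
stub = training_pb2_grpc.TrainingStub(channel)

session_id = training_pb2.TrainingSessionId(id="...")  # id returned by stub.Init(...)

# Ask the server to write a checkpoint of the current trainer state.
stub.Save(training_pb2.SaveRequest(sessionId=session_id, filePath="/tmp/model.pth"))

# Ask the server to export the model (the tests use a bioimageio.zip target).
stub.Export(training_pb2.ExportRequest(sessionId=session_id, filePath="/tmp/bioimageio.zip"))
```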
69 changes: 60 additions & 9 deletions tests/test_server/test_grpc/test_training_servicer.py
@@ -2,7 +2,7 @@
import threading
import time
from pathlib import Path
from typing import Callable
from typing import Callable, Optional

import grpc
import h5py
@@ -41,8 +41,11 @@ def grpc_stub_cls():
return training_pb2_grpc.TrainingStub


def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: str = "cpu"):
return f"""
def unet2d_config_path(
checkpoint_dir: Path, train_data_dir: str, val_data_path: str, resume: Optional[str] = None, device: str = "cpu"
):
# todo: upsampling makes model torchscript incompatible
base = f"""
device: {device} # Use CPU for faster test execution, change to 'cuda' if GPU is available and necessary
model:
name: UNet2D
@@ -53,13 +56,14 @@ def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: st
num_groups: 4
final_sigmoid: false
is_segmentation: true
upsample: default
trainer:
checkpoint_dir: {checkpoint_dir}
resume: null
validate_after_iters: 2
validate_after_iters: 250
log_after_iters: 2
max_num_epochs: 1000
max_num_iterations: 10000
max_num_epochs: 10000
max_num_iterations: 100000
eval_score_higher_is_better: True
optimizer:
learning_rate: 0.0002
@@ -149,6 +153,9 @@ def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: st
- name: ToTensor
expand_dims: false
"""
if resume:
return f"resume: {resume}{base}"
return base


def create_random_dataset(shape, channel_per_class):
@@ -171,15 +178,22 @@ def create_random_dataset(shape, channel_per_class):
return tmp.name


def prepare_unet2d_test_environment(device: str = "cpu") -> str:
def prepare_unet2d_test_environment(resume: Optional[str] = None, device: str = "cpu") -> str:
checkpoint_dir = Path(tempfile.mkdtemp())

shape = (3, 1, 128, 128)
in_channel = 3
z = 1 # 2d
y = 128
x = 128
shape = (in_channel, z, y, x)
binary_loss = False
train = create_random_dataset(shape, binary_loss)
val = create_random_dataset(shape, binary_loss)

return unet2d_config_path(checkpoint_dir=checkpoint_dir, train_data_dir=train, val_data_path=val, device=device)
config = unet2d_config_path(
resume=resume, checkpoint_dir=checkpoint_dir, train_data_dir=train, val_data_path=val, device=device
)
return config


class TestTrainingServicer:
@@ -528,6 +542,43 @@ def test_forward_while_paused(self, grpc_stub):
assert predicted_tensor.dims == ("b", "c", "z", "y", "x")
assert predicted_tensor.shape == (batch, out_channels_unet2d, 1, 128, 128)

def test_save(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)

grpc_stub.Start(training_session_id)

with tempfile.TemporaryDirectory() as model_checkpoint_dir:
model_checkpoint_file = Path(model_checkpoint_dir) / "model.pth"
save_request = training_pb2.SaveRequest(sessionId=training_session_id, filePath=str(model_checkpoint_file))
grpc_stub.Save(save_request)
assert model_checkpoint_file.exists()

# assume stopping training to release devices
grpc_stub.CloseTrainerSession(training_session_id)

# attempt to init a new model with the new checkpoint and start training
init_response = grpc_stub.Init(
training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment(resume=model_checkpoint_file))
)
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
grpc_stub.Start(training_session_id)

def test_export(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)

grpc_stub.Start(training_session_id)

with tempfile.TemporaryDirectory() as model_checkpoint_dir:
model_export_file = Path(model_checkpoint_dir) / "bioimageio.zip"
export_request = training_pb2.ExportRequest(sessionId=training_session_id, filePath=str(model_export_file))
grpc_stub.Export(export_request)
assert model_export_file.exists()

# assume stopping training since model is exported
grpc_stub.CloseTrainerSession(training_session_id)

def test_close_session(self, grpc_stub):
"""
Test closing a training session.
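One subtlety worth illustrating from the test helpers above: when `resume` is set, the helper prepends the checkpoint path to `base`, and since `base` starts with a newline the entry becomes the first top-level key of the generated YAML. A rough sketch (the checkpoint path is a placeholder):

```python
config = prepare_unet2d_test_environment(resume="/tmp/checkpoints/model.pth")

# The returned YAML document then starts roughly like:
#
#   resume: /tmp/checkpoints/model.pth
#   device: cpu  # Use CPU for faster test execution, ...
#   model:
#     name: UNet2D
#     ...
```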
38 changes: 21 additions & 17 deletions tiktorch/proto/training_pb2.py

Some generated files are not rendered by default.

12 changes: 6 additions & 6 deletions tiktorch/proto/training_pb2_grpc.py
@@ -66,12 +66,12 @@ def __init__(self, channel):
_registered_method=True)
self.Save = channel.unary_unary(
'/training.Training/Save',
request_serializer=training__pb2.TrainingSessionId.SerializeToString,
request_serializer=training__pb2.SaveRequest.SerializeToString,
response_deserializer=training__pb2.Empty.FromString,
_registered_method=True)
self.Export = channel.unary_unary(
'/training.Training/Export',
request_serializer=training__pb2.TrainingSessionId.SerializeToString,
request_serializer=training__pb2.ExportRequest.SerializeToString,
response_deserializer=training__pb2.Empty.FromString,
_registered_method=True)
self.Predict = channel.unary_unary(
@@ -195,12 +195,12 @@ def add_TrainingServicer_to_server(servicer, server):
),
'Save': grpc.unary_unary_rpc_method_handler(
servicer.Save,
request_deserializer=training__pb2.TrainingSessionId.FromString,
request_deserializer=training__pb2.SaveRequest.FromString,
response_serializer=training__pb2.Empty.SerializeToString,
),
'Export': grpc.unary_unary_rpc_method_handler(
servicer.Export,
request_deserializer=training__pb2.TrainingSessionId.FromString,
request_deserializer=training__pb2.ExportRequest.FromString,
response_serializer=training__pb2.Empty.SerializeToString,
),
'Predict': grpc.unary_unary_rpc_method_handler(
@@ -406,7 +406,7 @@ def Save(request,
request,
target,
'/training.Training/Save',
training__pb2.TrainingSessionId.SerializeToString,
training__pb2.SaveRequest.SerializeToString,
training__pb2.Empty.FromString,
options,
channel_credentials,
@@ -433,7 +433,7 @@ def Export(request,
request,
target,
'/training.Training/Export',
training__pb2.TrainingSessionId.SerializeToString,
training__pb2.ExportRequest.SerializeToString,
training__pb2.Empty.FromString,
options,
channel_credentials,
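The regenerated stubs only swap the request (de)serializers; registering the service stays the usual grpcio pattern. A minimal sketch, assuming a servicer object implementing the Training service is available (its construction is omitted; the real server entry point lives elsewhere in tiktorch):

```python
from concurrent import futures

import grpc

from tiktorch.proto import training_pb2_grpc


def serve(servicer) -> None:
    # `servicer` should implement the Training service, e.g. the TrainingServicer
    # from tiktorch/server/grpc/training_servicer.py below.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    training_pb2_grpc.add_TrainingServicer_to_server(servicer, server)
    server.add_insecure_port("127.0.0.1:5567")  # placeholder address
    server.start()
    server.wait_for_termination()
```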
13 changes: 7 additions & 6 deletions tiktorch/server/grpc/training_servicer.py
@@ -2,6 +2,7 @@

import logging
import queue
from pathlib import Path
from typing import Callable, List

import grpc
@@ -63,14 +64,14 @@ def Pause(self, request: training_pb2.TrainingSessionId, context):
session.client.pause_training()
return training_pb2.Empty()

def Save(self, request: training_pb2.TrainingSessionId, context):
session = self._getTrainerSession(context, request.id)
session.client.save()
def Save(self, request: training_pb2.SaveRequest, context):
session = self._getTrainerSession(context, request.sessionId.id)
session.client.save(Path(request.filePath))
return training_pb2.Empty()

def Export(self, request: training_pb2.TrainingSessionId, context):
session = self._getTrainerSession(context, request.id)
session.client.export()
def Export(self, request: training_pb2.ExportRequest, context):
session = self._getTrainerSession(context, request.sessionId.id)
session.client.export(Path(request.filePath))
return training_pb2.Empty()

def Predict(self, request: training_pb2.PredictRequest, context):
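Because the handlers now receive wrapper messages rather than a bare TrainingSessionId, the session id is read from the nested sessionId field and the target path from filePath. A small sketch of building and unpacking the new requests (id and path are placeholders):

```python
from pathlib import Path

from tiktorch.proto import training_pb2

request = training_pb2.SaveRequest(
    sessionId=training_pb2.TrainingSessionId(id="abc123"),
    filePath="/tmp/model.pth",
)

# Mirrors what Save() above does with an incoming message.
assert request.sessionId.id == "abc123"
assert Path(request.filePath) == Path("/tmp/model.pth")
```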
13 changes: 9 additions & 4 deletions tiktorch/server/session/backend/base.py
@@ -3,6 +3,7 @@
import logging
from abc import ABC
from concurrent.futures import Future
from pathlib import Path
from typing import List

import torch
@@ -83,11 +84,15 @@ def start_training(self) -> None:
self._queue_tasks.send_command(start_cmd.awaitable)
start_cmd.awaitable.wait()

def save(self) -> None:
raise NotImplementedError
def save(self, file_path: Path) -> None:
save_cmd = commands.SaveTrainingCmd(file_path)
self._queue_tasks.send_command(save_cmd.awaitable)
save_cmd.awaitable.wait()

def export(self) -> None:
raise NotImplementedError
def export(self, file_path: Path) -> None:
export_cmd = commands.ExportTrainingCmd(file_path)
self._queue_tasks.send_command(export_cmd.awaitable)
export_cmd.awaitable.wait()

def get_state(self) -> TrainerState:
return self._supervisor.get_state()
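save() and export() reuse the backend's command/awaitable pattern: build a command, push its awaitable onto the task queue, and block until the worker thread has executed it. A stripped-down sketch of that pattern with simplified stand-ins (not the real ICommand/awaitable classes from tiktorch.server.session.backend.commands):

```python
import queue
import threading
from typing import Callable


class AwaitableCmd:
    """Simplified stand-in for a command plus its awaitable wrapper."""

    def __init__(self, fn: Callable[[], None]) -> None:
        self._fn = fn
        self._done = threading.Event()

    def execute(self) -> None:
        self._fn()
        self._done.set()

    def wait(self) -> None:
        self._done.wait()


tasks: "queue.Queue[AwaitableCmd]" = queue.Queue()


def worker() -> None:
    while True:
        tasks.get().execute()


threading.Thread(target=worker, daemon=True).start()

# Analogous to save(): enqueue the command, then block until it has run.
cmd = AwaitableCmd(lambda: print("saving checkpoint ..."))
tasks.put(cmd)
cmd.wait()
```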
19 changes: 19 additions & 0 deletions tiktorch/server/session/backend/commands.py
@@ -6,6 +6,7 @@
import threading
import typing
from dataclasses import dataclass, field
from pathlib import Path
from typing import Generic, Type, TypeVar

from tiktorch.trainer import TrainerAction, TrainerState
@@ -131,6 +132,24 @@ def execute(self, ctx: Context) -> None:
pass


class ExportTrainingCmd(ICommand):
def __init__(self, file_path: Path):
super().__init__()
self._file_path = file_path

def execute(self, ctx: Context[TrainerSupervisor]) -> None:
ctx.session.export(self._file_path)


class SaveTrainingCmd(ICommand):
def __init__(self, file_path: Path):
super().__init__()
self._file_path = file_path

def execute(self, ctx: Context[TrainerSupervisor]) -> None:
ctx.session.save(self._file_path)


class ShutdownWithTeardownCmd(ShutdownCmd):
def execute(self, ctx: Context[Supervisors]) -> None:
ctx.session.shutdown()
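Each command's execute() receives a Context whose session attribute is the TrainerSupervisor, so the new commands simply forward the file path to the corresponding supervisor method. A self-contained sketch of that flow using dummy stand-ins (the real Context and supervisor classes live in the backend modules):

```python
from dataclasses import dataclass
from pathlib import Path


class DummySupervisor:
    def save(self, file_path: Path) -> None:
        print(f"would write the trainer state dict to {file_path}")


@dataclass
class DummyContext:
    session: DummySupervisor


class SaveCmdSketch:
    """Mirrors SaveTrainingCmd above, minus the ICommand/awaitable machinery."""

    def __init__(self, file_path: Path) -> None:
        self._file_path = file_path

    def execute(self, ctx: DummyContext) -> None:
        ctx.session.save(self._file_path)


SaveCmdSketch(Path("/tmp/model.pth")).execute(DummyContext(session=DummySupervisor()))
```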
12 changes: 8 additions & 4 deletions tiktorch/server/session/backend/supervisor.py
@@ -1,5 +1,6 @@
import logging
import threading
from pathlib import Path
from typing import Generic, Set, TypeVar, Union

from bioimageio.core import PredictionPipeline, Sample
@@ -134,11 +135,14 @@ def forward(self, input_tensors):
self.resume()
return res

def save(self):
raise NotImplementedError
def save(self, file_path: Path):
self.pause()
self._trainer.save_state_dict(file_path)
self.resume()

def export(self):
raise NotImplementedError
def export(self, file_path: Path):
self.pause()
self._trainer.export(file_path)

def _should_stop(self):
return self._pause_triggered
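Note the asymmetry in the supervisor: save() pauses the trainer, writes the state dict, and resumes, while export() pauses and leaves the trainer paused (the export test above closes the session afterwards). The expected call pattern, sketched with placeholders (`supervisor` is assumed to be an already-running TrainerSupervisor; save_state_dict and export are the trainer-side methods named in the diff):

```python
from pathlib import Path

supervisor.save(Path("/tmp/model.pth"))         # pause -> save_state_dict -> resume
supervisor.export(Path("/tmp/bioimageio.zip"))  # pause -> export; training stays paused
```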