Skip to content

Commit

Permalink
Retain the training state when saving and exporting
Browse files Browse the repository at this point in the history
- If the model was initially paused or running, it retains that state after
  the save completes; training is only paused temporarily to perform the save.
- The export pauses the training if it was not already paused.
  • Loading branch information
thodkatz committed Dec 14, 2024
1 parent 30eda0d commit 0e7735a
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
49 changes: 47 additions & 2 deletions tests/test_server/test_grpc/test_training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def test_forward_while_paused(self, grpc_stub):
assert predicted_tensor.dims == ("b", "c", "z", "y", "x")
assert predicted_tensor.shape == (batch, out_channels_unet2d, 1, 128, 128)

def test_save(self, grpc_stub):
def test_save_while_running(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)

Expand All @@ -553,6 +553,7 @@ def test_save(self, grpc_stub):
save_request = training_pb2.SaveRequest(sessionId=training_session_id, filePath=str(model_checkpoint_file))
grpc_stub.Save(save_request)
assert model_checkpoint_file.exists()
self.assert_state(grpc_stub, training_session_id, TrainerState.RUNNING)

# assume stopping training to release devices
grpc_stub.CloseTrainerSession(training_session_id)
Expand All @@ -564,17 +565,61 @@ def test_save(self, grpc_stub):
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
grpc_stub.Start(training_session_id)

def test_export(self, grpc_stub):
def test_save_while_paused(self, grpc_stub):
# Saving a paused session must write the checkpoint file and leave the
# session in the PAUSED state.
# NOTE(review): indentation was lost in this rendering, so the extent of the
# `with` block below cannot be read from the text.
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)

# Start training, wait one second, then pause before issuing the save.
# (presumably the sleep lets training get under way first -- TODO confirm)
grpc_stub.Start(training_session_id)
time.sleep(1)
grpc_stub.Pause(training_session_id)

with tempfile.TemporaryDirectory() as model_checkpoint_dir:
model_checkpoint_file = Path(model_checkpoint_dir) / "model.pth"
save_request = training_pb2.SaveRequest(sessionId=training_session_id, filePath=str(model_checkpoint_file))
grpc_stub.Save(save_request)
assert model_checkpoint_file.exists()
# the save must not change the state the session was in beforehand
self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED)

# assume stopping training to release devices
grpc_stub.CloseTrainerSession(training_session_id)

# attempt to init a new model with the new checkpoint and start training
init_response = grpc_stub.Init(
training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment(resume=model_checkpoint_file))
)
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
grpc_stub.Start(training_session_id)

def test_export_while_running(self, grpc_stub):
# Exporting right after Start (no explicit pause) must produce the export
# file and leave the session in the PAUSED state.
# NOTE(review): indentation was lost in this rendering, so the extent of the
# `with` block below cannot be read from the text.
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)

grpc_stub.Start(training_session_id)

with tempfile.TemporaryDirectory() as model_checkpoint_dir:
model_export_file = Path(model_checkpoint_dir) / "bioimageio.zip"
export_request = training_pb2.ExportRequest(sessionId=training_session_id, filePath=str(model_export_file))
grpc_stub.Export(export_request)
assert model_export_file.exists()
# unlike save, export is expected to leave the session paused
self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED)

# assume stopping training since model is exported
grpc_stub.CloseTrainerSession(training_session_id)

def test_export_while_paused(self, grpc_stub):
# Exporting an already paused session must produce the export file and
# leave the session in the PAUSED state.
# NOTE(review): indentation was lost in this rendering, so the extent of the
# `with` block below cannot be read from the text.
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)

# Start training, wait one second, then pause before issuing the export.
# (presumably the sleep lets training get under way first -- TODO confirm)
grpc_stub.Start(training_session_id)
time.sleep(1)
grpc_stub.Pause(training_session_id)

with tempfile.TemporaryDirectory() as model_checkpoint_dir:
model_export_file = Path(model_checkpoint_dir) / "bioimageio.zip"
export_request = training_pb2.ExportRequest(sessionId=training_session_id, filePath=str(model_export_file))
grpc_stub.Export(export_request)
assert model_export_file.exists()
# the session was paused before the export and must still be paused after it
self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED)

# assume stopping training since model is exported
grpc_stub.CloseTrainerSession(training_session_id)
Expand Down
11 changes: 8 additions & 3 deletions tiktorch/server/session/backend/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,17 @@ def forward(self, input_tensors):
return res

def save(self, file_path: Path):
    """Write the trainer's state dict to ``file_path`` and retain the training state.

    The state is sampled before anything else happens. A running trainer is
    paused for the duration of the save and resumed afterwards; a trainer in
    any other state is neither paused nor resumed.

    Args:
        file_path: Destination handed to ``self._trainer.save_state_dict``.
    """
    # Sample the state *before* pausing; sampling after an unconditional
    # pause would never observe RUNNING.
    was_running = self.get_state() == TrainerState.RUNNING
    if was_running:
        self.pause()
    try:
        self._trainer.save_state_dict(file_path)
    finally:
        # Resume even when the save raises, so a failed save does not leave
        # a previously running trainer stuck in the paused state.
        if was_running:
            self.resume()

def export(self, file_path: Path):
    """Export the trainer's model to ``file_path``, pausing training first if needed.

    A running trainer is paused before the export and is not resumed
    afterwards; a trainer in any other state is not sent another pause.

    Args:
        file_path: Destination handed to ``self._trainer.export``.
    """
    # Only pause when actually running, so an already paused trainer is not
    # asked to pause a second time.
    if self.get_state() == TrainerState.RUNNING:
        self.pause()
    self._trainer.export(file_path)

def _should_stop(self):
Expand Down

0 comments on commit 0e7735a

Please sign in to comment.