From 0e7735acfd247bc4b481f214859255aafdb486c7 Mon Sep 17 00:00:00 2001
From: Theodoros Katzalis
Date: Sat, 14 Dec 2024 16:23:39 +0100
Subject: [PATCH] Retain the training state when saving and exporting

- If the model was initially paused or running, saving retains that state
  after completion, while temporarily pausing to perform the save.
- Exporting pauses the training if it was not already paused.
---
 .../test_grpc/test_training_servicer.py       | 49 ++++++++++++++++++-
 tiktorch/server/session/backend/supervisor.py | 11 +++--
 2 files changed, 55 insertions(+), 5 deletions(-)

diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py
index f8ba309f..2f0d92ff 100644
--- a/tests/test_server/test_grpc/test_training_servicer.py
+++ b/tests/test_server/test_grpc/test_training_servicer.py
@@ -542,7 +542,7 @@ def test_forward_while_paused(self, grpc_stub):
         assert predicted_tensor.dims == ("b", "c", "z", "y", "x")
         assert predicted_tensor.shape == (batch, out_channels_unet2d, 1, 128, 128)
 
-    def test_save(self, grpc_stub):
+    def test_save_while_running(self, grpc_stub):
         init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
         training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
 
@@ -553,6 +553,7 @@ def test_save(self, grpc_stub):
             save_request = training_pb2.SaveRequest(sessionId=training_session_id, filePath=str(model_checkpoint_file))
             grpc_stub.Save(save_request)
             assert model_checkpoint_file.exists()
+            self.assert_state(grpc_stub, training_session_id, TrainerState.RUNNING)
 
         # assume stopping training to release devices
         grpc_stub.CloseTrainerSession(training_session_id)
@@ -564,17 +565,61 @@ def test_save(self, grpc_stub):
         training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
         grpc_stub.Start(training_session_id)
 
-    def test_export(self, grpc_stub):
+    def test_save_while_paused(self, grpc_stub):
         init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
         training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
 
         grpc_stub.Start(training_session_id)
+        time.sleep(1)
+        grpc_stub.Pause(training_session_id)
+
+        with tempfile.TemporaryDirectory() as model_checkpoint_dir:
+            model_checkpoint_file = Path(model_checkpoint_dir) / "model.pth"
+            save_request = training_pb2.SaveRequest(sessionId=training_session_id, filePath=str(model_checkpoint_file))
+            grpc_stub.Save(save_request)
+            assert model_checkpoint_file.exists()
+            self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED)
+
+        # assume stopping training to release devices
+        grpc_stub.CloseTrainerSession(training_session_id)
+
+        # attempt to init a new model with the new checkpoint and start training
+        init_response = grpc_stub.Init(
+            training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment(resume=model_checkpoint_file))
+        )
+        training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
+        grpc_stub.Start(training_session_id)
+
+    def test_export_while_running(self, grpc_stub):
+        init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
+        training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
+
+        grpc_stub.Start(training_session_id)
+
+        with tempfile.TemporaryDirectory() as model_checkpoint_dir:
+            model_export_file = Path(model_checkpoint_dir) / "bioimageio.zip"
+            export_request = training_pb2.ExportRequest(sessionId=training_session_id, filePath=str(model_export_file))
+            grpc_stub.Export(export_request)
+            assert model_export_file.exists()
+            self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED)
+
+        # assume stopping training since model is exported
+        grpc_stub.CloseTrainerSession(training_session_id)
+
+    def test_export_while_paused(self, grpc_stub):
+        init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
+        training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
+
+        grpc_stub.Start(training_session_id)
+        time.sleep(1)
+        grpc_stub.Pause(training_session_id)
 
         with tempfile.TemporaryDirectory() as model_checkpoint_dir:
             model_export_file = Path(model_checkpoint_dir) / "bioimageio.zip"
             export_request = training_pb2.ExportRequest(sessionId=training_session_id, filePath=str(model_export_file))
             grpc_stub.Export(export_request)
             assert model_export_file.exists()
+            self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED)
 
         # assume stopping training since model is exported
         grpc_stub.CloseTrainerSession(training_session_id)
diff --git a/tiktorch/server/session/backend/supervisor.py b/tiktorch/server/session/backend/supervisor.py
index 8b7043a9..d4d4e45b 100644
--- a/tiktorch/server/session/backend/supervisor.py
+++ b/tiktorch/server/session/backend/supervisor.py
@@ -136,12 +136,17 @@ def forward(self, input_tensors):
         return res
 
     def save(self, file_path: Path):
-        self.pause()
+        init_state = self.get_state()  # retain the state after save
+        if init_state == TrainerState.RUNNING:
+            self.pause()
         self._trainer.save_state_dict(file_path)
-        self.resume()
+        if init_state == TrainerState.RUNNING:
+            self.resume()
 
     def export(self, file_path: Path):
-        self.pause()
+        init_state = self.get_state()
+        if init_state == TrainerState.RUNNING:
+            self.pause()
         self._trainer.export(file_path)
 
     def _should_stop(self):
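
Note (not part of the patch): the retained-state rule applied inline in save() and export() above can be read as the following minimal sketch. The _retain_state helper, the local TrainerState enum, and the free-standing save/export wrappers are hypothetical, introduced only for illustration; only get_state, pause, resume, save_state_dict, and export correspond to the supervisor/trainer calls visible in the hunks above.

from contextlib import contextmanager
from enum import Enum, auto
from pathlib import Path


class TrainerState(Enum):
    # stand-in for tiktorch's TrainerState; only the two states used here
    RUNNING = auto()
    PAUSED = auto()


@contextmanager
def _retain_state(supervisor, resume_after: bool):
    """Pause a running trainer for the duration of the block.

    When resume_after is True, the trainer is resumed only if it was running
    on entry, so a trainer that was already paused stays paused.
    """
    was_running = supervisor.get_state() == TrainerState.RUNNING
    if was_running:
        supervisor.pause()
    try:
        yield
    finally:
        if resume_after and was_running:
            supervisor.resume()


def save(supervisor, file_path: Path):
    # save restores the initial state: RUNNING resumes, PAUSED stays paused
    with _retain_state(supervisor, resume_after=True):
        supervisor._trainer.save_state_dict(file_path)


def export(supervisor, file_path: Path):
    # export always ends up paused, matching the patched behaviour
    with _retain_state(supervisor, resume_after=False):
        supervisor._trainer.export(file_path)

The patch itself keeps the two explicit RUNNING checks inline rather than adding such a helper to the supervisor.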