From b6d63c78dba8f743097affac680185951942a889 Mon Sep 17 00:00:00 2001
From: Theodoros Katzalis
Date: Sat, 14 Dec 2024 16:00:24 +0100
Subject: [PATCH] Retain the initial training state during forward

If training is running or paused, forward will retain that state after
completion. Pausing is still required internally, so that memory can be
released for the forward pass.
---
 .../test_grpc/test_training_servicer.py       | 33 ++++++++++++++++++-
 tiktorch/server/session/backend/supervisor.py |  7 ++--
 2 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py
index 9a4e7907..0dcb9107 100644
--- a/tests/test_server/test_grpc/test_training_servicer.py
+++ b/tests/test_server/test_grpc/test_training_servicer.py
@@ -474,7 +474,7 @@ def test_close_trainer_session_twice(self, grpc_stub):
             grpc_stub.CloseTrainerSession(training_session_id)
         assert "Unknown session" in excinfo.value.details()
 
-    def test_forward(self, grpc_stub):
+    def test_forward_while_running(self, grpc_stub):
         init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
         training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
 
@@ -491,6 +491,37 @@ def test_forward(self, grpc_stub):
 
         response = grpc_stub.Predict(predict_request)
 
+        # assert that the predict command has retained the initial state (e.g. RUNNING)
+        self.assert_state(grpc_stub, training_session_id, TrainerState.RUNNING)
+
+        predicted_tensors = [pb_tensor_to_xarray(pb_tensor) for pb_tensor in response.tensors]
+        assert len(predicted_tensors) == 1
+        predicted_tensor = predicted_tensors[0]
+        assert predicted_tensor.dims == ("b", "c", "z", "y", "x")
+        assert predicted_tensor.shape == (batch, out_channels_unet2d, 1, 128, 128)
+
+    def test_forward_while_paused(self, grpc_stub):
+        init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
+        training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
+
+        grpc_stub.Start(training_session_id)
+
+        batch = 5
+        in_channels_unet2d = 3
+        out_channels_unet2d = 2
+        shape = (batch, in_channels_unet2d, 1, 128, 128)
+        data = np.random.rand(*shape).astype(np.float32)
+        xarray_data = xr.DataArray(data, dims=("b", "c", "z", "y", "x"))
+        pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data)
+        predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor])
+
+        grpc_stub.Pause(training_session_id)
+
+        response = grpc_stub.Predict(predict_request)
+
+        # assert that the predict command has retained the initial state (e.g. PAUSED)
+        self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED)
+
         predicted_tensors = [pb_tensor_to_xarray(pb_tensor) for pb_tensor in response.tensors]
         assert len(predicted_tensors) == 1
         predicted_tensor = predicted_tensors[0]
diff --git a/tiktorch/server/session/backend/supervisor.py b/tiktorch/server/session/backend/supervisor.py
index a49ae19a..5bb07e2a 100644
--- a/tiktorch/server/session/backend/supervisor.py
+++ b/tiktorch/server/session/backend/supervisor.py
@@ -126,9 +126,12 @@ def shutdown(self):
         self._session_thread.join()
 
     def forward(self, input_tensors):
-        self.pause()
+        init_state = self.get_state()  # retain the initial state after forward
+        if init_state == TrainerState.RUNNING:
+            self.pause()
         res = self._trainer.forward(input_tensors)
-        self.resume()
+        if init_state == TrainerState.RUNNING:
+            self.resume()
         return res
 
     def save(self):
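
A minimal, self-contained sketch of the pause-for-forward-then-restore
pattern that the supervisor.forward() change above implements. The
Supervisor, DummyTrainer, and state-transition methods below are simplified
stand-ins for illustration, not the actual tiktorch classes.

    from enum import Enum


    class TrainerState(Enum):
        IDLE = "idle"
        RUNNING = "running"
        PAUSED = "paused"


    class DummyTrainer:
        def forward(self, input_tensors):
            # Stand-in for real model inference.
            return input_tensors


    class Supervisor:
        def __init__(self):
            self._trainer = DummyTrainer()
            self._state = TrainerState.IDLE

        def get_state(self):
            return self._state

        def start(self):
            self._state = TrainerState.RUNNING

        def pause(self):
            self._state = TrainerState.PAUSED

        def resume(self):
            self._state = TrainerState.RUNNING

        def forward(self, input_tensors):
            # Remember the state at call time so it can be restored afterwards.
            init_state = self.get_state()
            if init_state == TrainerState.RUNNING:
                # Pause only if training is running, freeing resources for inference.
                self.pause()
            res = self._trainer.forward(input_tensors)
            if init_state == TrainerState.RUNNING:
                # Resume only when the session was running before the call.
                self.resume()
            return res


    if __name__ == "__main__":
        supervisor = Supervisor()
        supervisor.start()
        supervisor.forward(["tensor"])
        assert supervisor.get_state() == TrainerState.RUNNING  # state retained

        supervisor.pause()
        supervisor.forward(["tensor"])
        assert supervisor.get_state() == TrainerState.PAUSED  # state retained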