diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py index 93e05ba2..9d8dc26f 100644 --- a/tests/test_server/test_grpc/test_training_servicer.py +++ b/tests/test_server/test_grpc/test_training_servicer.py @@ -474,7 +474,7 @@ def test_close_trainer_session_twice(self, grpc_stub): grpc_stub.CloseTrainerSession(training_session_id) assert "Unknown session" in excinfo.value.details() - def test_forward(self, grpc_stub): + def test_forward_while_running(self, grpc_stub): init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) training_session_id = training_pb2.TrainingSessionId(id=init_response.id) @@ -491,6 +491,37 @@ def test_forward(self, grpc_stub): response = grpc_stub.Predict(predict_request) + # assert that predict command has retained the init state (e.g. RUNNING) + self.assert_state(grpc_stub, training_session_id, TrainerState.RUNNING) + + predicted_tensors = [pb_tensor_to_xarray(pb_tensor) for pb_tensor in response.tensors] + assert len(predicted_tensors) == 1 + predicted_tensor = predicted_tensors[0] + assert predicted_tensor.dims == ("b", "c", "z", "y", "x") + assert predicted_tensor.shape == (batch, out_channels_unet2d, 1, 128, 128) + + def test_forward_while_paused(self, grpc_stub): + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + + grpc_stub.Start(training_session_id) + + batch = 5 + in_channels_unet2d = 3 + out_channels_unet2d = 2 + shape = (batch, in_channels_unet2d, 1, 128, 128) + data = np.random.rand(*shape).astype(np.float32) + xarray_data = xr.DataArray(data, dims=("b", "c", "z", "y", "x")) + pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data) + predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor]) + + grpc_stub.Pause(training_session_id) + + response = grpc_stub.Predict(predict_request) + + # assert that predict command has retained the init state (e.g. PAUSED) + self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED) + predicted_tensors = [pb_tensor_to_xarray(pb_tensor) for pb_tensor in response.tensors] assert len(predicted_tensors) == 1 predicted_tensor = predicted_tensors[0] diff --git a/tiktorch/server/session/backend/supervisor.py b/tiktorch/server/session/backend/supervisor.py index a49ae19a..5bb07e2a 100644 --- a/tiktorch/server/session/backend/supervisor.py +++ b/tiktorch/server/session/backend/supervisor.py @@ -126,9 +126,12 @@ def shutdown(self): self._session_thread.join() def forward(self, input_tensors): - self.pause() + init_state = self.get_state() # retain the state after forward + if init_state == TrainerState.RUNNING: + self.pause() res = self._trainer.forward(input_tensors) - self.resume() + if init_state == TrainerState.RUNNING: + self.resume() return res def save(self):