Skip to content

Commit

Permalink
Retain the initial training state when running a forward pass
Browse files Browse the repository at this point in the history
If training is running or paused, the forward pass now retains that state
after it completes. Training still has to be paused while the forward pass
runs, so that memory can be released for it.
  • Loading branch information
thodkatz committed Dec 14, 2024
1 parent 32cd26d commit 619ef5f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
33 changes: 32 additions & 1 deletion tests/test_server/test_grpc/test_training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def test_close_trainer_session_twice(self, grpc_stub):
grpc_stub.CloseTrainerSession(training_session_id)
assert "Unknown session" in excinfo.value.details()

def test_forward(self, grpc_stub):
def test_forward_while_running(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)

Expand All @@ -491,6 +491,37 @@ def test_forward(self, grpc_stub):

response = grpc_stub.Predict(predict_request)

# assert that predict command has retained the init state (e.g. RUNNING)
self.assert_state(grpc_stub, training_session_id, TrainerState.RUNNING)

predicted_tensors = [pb_tensor_to_xarray(pb_tensor) for pb_tensor in response.tensors]
assert len(predicted_tensors) == 1
predicted_tensor = predicted_tensors[0]
assert predicted_tensor.dims == ("b", "c", "z", "y", "x")
assert predicted_tensor.shape == (batch, out_channels_unet2d, 1, 128, 128)

def test_forward_while_paused(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)

grpc_stub.Start(training_session_id)

batch = 5
in_channels_unet2d = 3
out_channels_unet2d = 2
shape = (batch, in_channels_unet2d, 1, 128, 128)
data = np.random.rand(*shape).astype(np.float32)
xarray_data = xr.DataArray(data, dims=("b", "c", "z", "y", "x"))
pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data)
predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor])

grpc_stub.Pause(training_session_id)

response = grpc_stub.Predict(predict_request)

# assert that predict command has retained the init state (e.g. PAUSED)
self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED)

predicted_tensors = [pb_tensor_to_xarray(pb_tensor) for pb_tensor in response.tensors]
assert len(predicted_tensors) == 1
predicted_tensor = predicted_tensors[0]
Expand Down
7 changes: 5 additions & 2 deletions tiktorch/server/session/backend/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,12 @@ def shutdown(self):
self._session_thread.join()

def forward(self, input_tensors):
    """Run a forward pass and leave the trainer in the state it was found in.

    If training is RUNNING it is paused for the duration of the forward
    pass (per the commit description, so that memory can be released) and
    resumed afterwards. In any other state (e.g. PAUSED) the state is left
    untouched.

    Args:
        input_tensors: Inputs handed unchanged to ``self._trainer.forward``.

    Returns:
        Whatever ``self._trainer.forward(input_tensors)`` returns.
    """
    # Capture the state *before* pausing; reading it after an unconditional
    # pause would always hide the fact that training had been running.
    init_state = self.get_state()
    was_running = init_state == TrainerState.RUNNING
    if was_running:
        self.pause()
    try:
        return self._trainer.forward(input_tensors)
    finally:
        # Resume even when the forward pass raises, so a failed prediction
        # does not leave a previously running training stuck in pause.
        if was_running:
            self.resume()

def save(self):
Expand Down

0 comments on commit 619ef5f

Please sign in to comment.