diff --git a/examples/grpc_client.py b/examples/grpc_client.py
index f0fe9ce5..f64ebd5b 100644
--- a/examples/grpc_client.py
+++ b/examples/grpc_client.py
@@ -1,12 +1,12 @@
 import grpc
 
-from tiktorch.proto import inference_pb2, inference_pb2_grpc
+from tiktorch.proto import inference_pb2_grpc, utils_pb2
 
 
 def run():
     with grpc.insecure_channel("127.0.0.1:5567") as channel:
         stub = inference_pb2_grpc.InferenceStub(channel)
-        response = stub.ListDevices(inference_pb2.Empty())
+        response = stub.ListDevices(utils_pb2.Empty())
         print(response)
 
 
diff --git a/proto/inference.proto b/proto/inference.proto
index f2a95159..3d10929b 100644
--- a/proto/inference.proto
+++ b/proto/inference.proto
@@ -1,5 +1,10 @@
 syntax = "proto3";
 
+package inference;
+
+import "utils.proto";
+
+
 service Inference {
 
   rpc CreateModelSession(CreateModelSessionRequest) returns (ModelSession) {}
@@ -14,15 +19,6 @@ service Inference {
   rpc Predict(PredictRequest) returns (PredictResponse) {}
 }
 
-message Device {
-  enum Status {
-    AVAILABLE = 0;
-    IN_USE = 1;
-  }
-
-  string id = 1;
-  Status status = 2;
-}
 
 message CreateDatasetDescriptionRequest {
   string modelSessionId = 1;
@@ -76,26 +72,6 @@ message LogEntry {
   string content = 3;
 }
 
-message Devices {
-  repeated Device devices = 1;
-}
-
-message NamedInt {
-  uint32 size = 1;
-  string name = 2;
-}
-
-message NamedFloat {
-  float size = 1;
-  string name = 2;
-}
-
-message Tensor {
-  bytes buffer = 1;
-  string dtype = 2;
-  string tensorId = 3;
-  repeated NamedInt shape = 4;
-}
 
 message PredictRequest {
   string modelSessionId = 1;
@@ -107,7 +83,6 @@ message PredictResponse {
   repeated Tensor tensors = 1;
 }
 
-message Empty {}
 
 service FlightControl {
   rpc Ping(Empty) returns (Empty) {}
diff --git a/proto/training.proto b/proto/training.proto
new file mode 100644
index 00000000..496a6eaa
--- /dev/null
+++ b/proto/training.proto
@@ -0,0 +1,95 @@
+syntax = "proto3";
+
+package training;
+
+import "utils.proto";
+
+
+
+service Training {
+  rpc ListDevices(Empty) returns (Devices) {}
+
+  rpc Init(TrainingConfig) returns (TrainingSessionId) {}
+
+  rpc Start(TrainingSessionId) returns (Empty) {}
+
+  rpc Resume(TrainingSessionId) returns (Empty) {}
+
+  rpc Pause(TrainingSessionId) returns (Empty) {}
+
+  rpc StreamUpdates(TrainingSessionId) returns (stream StreamUpdateResponse) {}
+
+  rpc GetLogs(TrainingSessionId) returns (GetLogsResponse) {}
+
+  rpc Save(TrainingSessionId) returns (Empty) {}
+
+  rpc Export(TrainingSessionId) returns (Empty) {}
+
+  rpc Predict(PredictRequest) returns (PredictResponse) {}
+
+  rpc GetStatus(TrainingSessionId) returns (GetStatusResponse) {}
+
+  rpc CloseTrainerSession(TrainingSessionId) returns (Empty) {}
+}
+
+message TrainingSessionId {
+  string id = 1;
+}
+
+message Logs {
+  enum ModelPhase {
+    Train = 0;
+    Eval = 1;
+  }
+  ModelPhase mode = 1;
+  double eval_score = 2;
+  double loss = 3;
+  uint32 iteration = 4;
+}
+
+
+message StreamUpdateResponse {
+  uint32 best_model_idx = 1;
+  Logs logs = 2;
+}
+
+
+message GetLogsResponse {
+  repeated Logs logs = 1;
+}
+
+
+
+message PredictRequest {
+  repeated Tensor tensors = 1;
+  TrainingSessionId sessionId = 2;
+}
+
+
+message PredictResponse {
+  repeated Tensor tensors = 1;
+}
+
+message ValidationResponse {
+  double validation_score_average = 1;
+}
+
+message GetStatusResponse {
+  enum State {
+    Idle = 0;
+    Running = 1;
+    Paused = 2;
+    Failed = 3;
+    Finished = 4;
+  }
+  State state = 1;
+}
+
+
+message GetCurrentBestModelIdxResponse {
+  uint32 id = 1;
+}
+
+message TrainingConfig {
+  string yaml_content = 1;
+}
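As an aside (not part of the patch): a minimal Python client sketch for the new Training service defined above, modeled on examples/grpc_client.py. The address and the YAML payload are assumptions; Init expects a trainer config such as the one the tests below construct.

import grpc

from tiktorch.proto import training_pb2, training_pb2_grpc


def run():
    # assumes a tiktorch server exposing the Training service on this address
    with grpc.insecure_channel("127.0.0.1:5567") as channel:
        stub = training_pb2_grpc.TrainingStub(channel)
        # hypothetical YAML payload; see prepare_unet2d_test_environment() in the tests
        session = stub.Init(training_pb2.TrainingConfig(yaml_content="..."))
        stub.Start(session)
        print(stub.GetStatus(session).state)  # a GetStatusResponse.State value, e.g. Running
        stub.CloseTrainerSession(session)


if __name__ == "__main__":
    run()
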
diff --git a/proto/utils.proto b/proto/utils.proto
new file mode 100644
index 00000000..cb24d3e3
--- /dev/null
+++ b/proto/utils.proto
@@ -0,0 +1,34 @@
+syntax = "proto3";
+
+message Empty {}
+
+message NamedInt {
+  uint32 size = 1;
+  string name = 2;
+}
+
+message NamedFloat {
+  float size = 1;
+  string name = 2;
+}
+
+message Tensor {
+  bytes buffer = 1;
+  string dtype = 2;
+  string tensorId = 3;
+  repeated NamedInt shape = 4;
+}
+
+message Device {
+  enum Status {
+    AVAILABLE = 0;
+    IN_USE = 1;
+  }
+
+  string id = 1;
+  Status status = 2;
+}
+
+message Devices {
+  repeated Device devices = 1;
+}
\ No newline at end of file
diff --git a/pytest.ini b/pytest.ini
index 9370a7ea..0e0eea50 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,7 +1,7 @@
 [pytest]
 python_files = test_*.py
 addopts =
-    --timeout 10
+    --timeout 60
     -v
     -s
     --color=yes
diff --git a/setup.cfg b/setup.cfg
index aecec9f7..50d7fac8 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -12,5 +12,5 @@ max-line-length = 120
 
 [flake8]
 max-line-length = 120
-ignore=E203
+ignore=E203,W503
 exclude = tiktorch/proto/*,vendor
diff --git a/tests/test_converters.py b/tests/test_converters.py
index cf5a51d2..26a68d9c 100644
--- a/tests/test_converters.py
+++ b/tests/test_converters.py
@@ -13,7 +13,7 @@
     xarray_to_pb_tensor,
     xr_tensors_to_sample,
 )
-from tiktorch.proto import inference_pb2
+from tiktorch.proto import utils_pb2
 
 
 def _numpy_to_pb_tensor(arr, tensor_id: str = "dummy_tensor_name"):
@@ -21,7 +21,7 @@
     Makes sure that tensor was serialized/deserialized
     """
     tensor = numpy_to_pb_tensor(tensor_id, arr)
-    parsed = inference_pb2.Tensor()
+    parsed = utils_pb2.Tensor()
     parsed.ParseFromString(tensor.SerializeToString())
     return parsed
 
@@ -31,7 +31,7 @@ def to_pb_tensor(tensor_id: str, arr: xr.DataArray):
     Makes sure that tensor was serialized/deserialized
     """
     tensor = xarray_to_pb_tensor(tensor_id, arr)
-    parsed = inference_pb2.Tensor()
+    parsed = utils_pb2.Tensor()
     parsed.ParseFromString(tensor.SerializeToString())
     return parsed
 
@@ -40,7 +40,7 @@ class TestNumpyToPBTensor:
     def test_should_serialize_to_tensor_type(self):
         arr = np.arange(9)
         tensor = _numpy_to_pb_tensor(arr)
-        assert isinstance(tensor, inference_pb2.Tensor)
+        assert isinstance(tensor, utils_pb2.Tensor)
 
     @pytest.mark.parametrize("np_dtype,dtype_str", [(np.int64, "int64"), (np.uint8, "uint8"), (np.float32, "float32")])
     def test_should_have_dtype_as_str(self, np_dtype, dtype_str):
@@ -65,12 +65,12 @@ def test_should_have_serialized_bytes(self):
 
 
 class TestPBTensorToNumpy:
     def test_should_raise_on_empty_dtype(self):
-        tensor = inference_pb2.Tensor(dtype="", shape=[inference_pb2.NamedInt(size=1), inference_pb2.NamedInt(size=2)])
+        tensor = utils_pb2.Tensor(dtype="", shape=[utils_pb2.NamedInt(size=1), utils_pb2.NamedInt(size=2)])
         with pytest.raises(ValueError):
             pb_tensor_to_numpy(tensor)
 
     def test_should_raise_on_empty_shape(self):
-        tensor = inference_pb2.Tensor(dtype="int64", shape=[])
+        tensor = utils_pb2.Tensor(dtype="int64", shape=[])
         with pytest.raises(ValueError):
             pb_tensor_to_numpy(tensor)
@@ -109,7 +109,7 @@ class TestXarrayToPBTensor:
     def test_should_serialize_to_tensor_type(self):
         xarr = xr.DataArray(np.arange(8).reshape((2, 4)), dims=("x", "y"))
         pb_tensor = to_pb_tensor("input0", xarr)
-        assert isinstance(pb_tensor, inference_pb2.Tensor)
+        assert isinstance(pb_tensor, utils_pb2.Tensor)
         assert len(pb_tensor.shape) == 2
         dim1 = pb_tensor.shape[0]
         dim2 = pb_tensor.shape[1]
@@ -137,12 +137,12 @@ def test_should_have_serialized_bytes(self):
class TestPBTensorToXarray: def test_should_raise_on_empty_dtype(self): - tensor = inference_pb2.Tensor(dtype="", shape=[inference_pb2.NamedInt(size=1), inference_pb2.NamedInt(size=2)]) + tensor = utils_pb2.Tensor(dtype="", shape=[utils_pb2.NamedInt(size=1), utils_pb2.NamedInt(size=2)]) with pytest.raises(ValueError): pb_tensor_to_xarray(tensor) def test_should_raise_on_empty_shape(self): - tensor = inference_pb2.Tensor(dtype="int64", shape=[]) + tensor = utils_pb2.Tensor(dtype="int64", shape=[]) with pytest.raises(ValueError): pb_tensor_to_xarray(tensor) @@ -178,19 +178,19 @@ def test_should_same_data(self, shape): class TestSample: def test_pb_tensors_to_sample(self): arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32) - tensor_1 = inference_pb2.Tensor( + tensor_1 = utils_pb2.Tensor( dtype="int64", tensorId="input1", buffer=bytes(arr_1), - shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)], + shape=[utils_pb2.NamedInt(name="x", size=32), utils_pb2.NamedInt(name="y", size=32)], ) arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64) - tensor_2 = inference_pb2.Tensor( + tensor_2 = utils_pb2.Tensor( dtype="int64", tensorId="input2", buffer=bytes(arr_2), - shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)], + shape=[utils_pb2.NamedInt(name="x", size=64), utils_pb2.NamedInt(name="y", size=64)], ) sample = pb_tensors_to_sample([tensor_1, tensor_2]) @@ -218,17 +218,17 @@ def test_sample_to_pb_tensors(self): tensors_ids = ["input1", "input2"] sample = xr_tensors_to_sample(tensors_ids, [tensor_1, tensor_2]) - pb_tensor_1 = inference_pb2.Tensor( + pb_tensor_1 = utils_pb2.Tensor( dtype="int64", tensorId="input1", buffer=bytes(arr_1), - shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)], + shape=[utils_pb2.NamedInt(name="x", size=32), utils_pb2.NamedInt(name="y", size=32)], ) - pb_tensor_2 = inference_pb2.Tensor( + pb_tensor_2 = utils_pb2.Tensor( dtype="int64", tensorId="input2", buffer=bytes(arr_2), - shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)], + shape=[utils_pb2.NamedInt(name="x", size=64), utils_pb2.NamedInt(name="y", size=64)], ) expected_tensors = [pb_tensor_1, pb_tensor_2] diff --git a/tests/test_server/test_grpc/test_fligh_control_servicer.py b/tests/test_server/test_grpc/test_fligh_control_servicer.py index 52f32c41..363987e2 100644 --- a/tests/test_server/test_grpc/test_fligh_control_servicer.py +++ b/tests/test_server/test_grpc/test_fligh_control_servicer.py @@ -29,10 +29,10 @@ def _pinger(): pinger_thread.start() assert not evt.is_set() - assert not evt.wait(timeout=0.2) + assert not evt.wait(timeout=1) stop_pinger.set() - assert evt.wait(timeout=0.2) + assert evt.wait(timeout=1) def test_shutdown_timeout_0_means_no_watchdog(): diff --git a/tests/test_server/test_grpc/test_inference_servicer.py b/tests/test_server/test_grpc/test_inference_servicer.py index b1de0213..52827193 100644 --- a/tests/test_server/test_grpc/test_inference_servicer.py +++ b/tests/test_server/test_grpc/test_inference_servicer.py @@ -9,7 +9,7 @@ from tiktorch import converters from tiktorch.converters import pb_tensor_to_xarray -from tiktorch.proto import inference_pb2, inference_pb2_grpc +from tiktorch.proto import inference_pb2, inference_pb2_grpc, utils_pb2 from tiktorch.server.data_store import DataStore from tiktorch.server.device_pool import TorchDevicePool from tiktorch.server.grpc import inference_servicer @@ -101,13 +101,13 
@@ def test_model_init_failed_close_session(self, bioimage_model_explicit_add_one_s class TestDeviceManagement: def test_list_devices(self, grpc_stub): - resp = grpc_stub.ListDevices(inference_pb2.Empty()) + resp = grpc_stub.ListDevices(utils_pb2.Empty()) device_by_id = {d.id: d for d in resp.devices} assert "cpu" in device_by_id - assert inference_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status + assert utils_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status def _query_devices(self, grpc_stub): - dev_resp = grpc_stub.ListDevices(inference_pb2.Empty()) + dev_resp = grpc_stub.ListDevices(utils_pb2.Empty()) device_by_id = {d.id: d for d in dev_resp.devices} return device_by_id @@ -121,19 +121,19 @@ def test_if_model_create_fails_devices_are_released(self, grpc_stub): device_by_id = self._query_devices(grpc_stub) assert "cpu" in device_by_id - assert inference_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status + assert utils_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status def test_use_device(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5): model_bytes = bioimage_model_explicit_add_one_siso_v5 device_by_id = self._query_devices(grpc_stub) assert "cpu" in device_by_id - assert inference_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status + assert utils_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status grpc_stub.CreateModelSession(valid_model_request(model_bytes, device_ids=["cpu"])) device_by_id = self._query_devices(grpc_stub) assert "cpu" in device_by_id - assert inference_pb2.Device.Status.IN_USE == device_by_id["cpu"].status + assert utils_pb2.Device.Status.IN_USE == device_by_id["cpu"].status def test_using_same_device_fails(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5): model_bytes = bioimage_model_explicit_add_one_siso_v5 @@ -147,20 +147,20 @@ def test_closing_session_releases_devices(self, grpc_stub, bioimage_model_explic device_by_id = self._query_devices(grpc_stub) assert "cpu" in device_by_id - assert inference_pb2.Device.Status.IN_USE == device_by_id["cpu"].status + assert utils_pb2.Device.Status.IN_USE == device_by_id["cpu"].status grpc_stub.CloseModelSession(model) device_by_id_after_close = self._query_devices(grpc_stub) assert "cpu" in device_by_id_after_close - assert inference_pb2.Device.Status.AVAILABLE == device_by_id_after_close["cpu"].status + assert utils_pb2.Device.Status.AVAILABLE == device_by_id_after_close["cpu"].status class TestGetLogs: def test_returns_ack_message(self, bioimage_model_explicit_add_one_siso_v5, grpc_stub): model_bytes = bioimage_model_explicit_add_one_siso_v5 grpc_stub.CreateModelSession(valid_model_request(model_bytes)) - resp = grpc_stub.GetLogs(inference_pb2.Empty()) + resp = grpc_stub.GetLogs(utils_pb2.Empty()) record = next(resp) assert inference_pb2.LogEntry.Level.INFO == record.level assert "Sending model logs" == record.content diff --git a/tests/test_server/test_grpc/test_init.py b/tests/test_server/test_grpc/test_init.py index af7bcf8a..0c8b57a9 100644 --- a/tests/test_server/test_grpc/test_init.py +++ b/tests/test_server/test_grpc/test_init.py @@ -4,8 +4,8 @@ import grpc -from tiktorch.proto.inference_pb2 import Empty from tiktorch.proto.inference_pb2_grpc import FlightControlStub +from tiktorch.proto.utils_pb2 import Empty from tiktorch.server.grpc import serve from tiktorch.utils import wait diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py new file mode 100644 index 00000000..2f054a5f --- 
/dev/null +++ b/tests/test_server/test_grpc/test_training_servicer.py @@ -0,0 +1,507 @@ +import tempfile +import threading +import time +from pathlib import Path +from typing import Callable + +import grpc +import h5py +import numpy as np +import pytest + +from tiktorch.converters import pb_state_to_trainer, trainer_state_to_pb +from tiktorch.proto import training_pb2, training_pb2_grpc, utils_pb2 +from tiktorch.server.device_pool import TorchDevicePool +from tiktorch.server.grpc import training_servicer +from tiktorch.server.session.backend.base import TrainerSessionBackend +from tiktorch.server.session.process import TrainerSessionProcess +from tiktorch.server.session_manager import SessionManager +from tiktorch.trainer import ShouldStopCallbacks, Trainer, TrainerState + + +@pytest.fixture(scope="module") +def grpc_add_to_server(): + return training_pb2_grpc.add_TrainingServicer_to_server + + +@pytest.fixture(scope="module") +def grpc_servicer(): + return training_servicer.TrainingServicer(TorchDevicePool(), SessionManager()) + + +@pytest.fixture(autouse=True) +def clean(grpc_servicer): + yield + grpc_servicer.close_all_sessions() + + +@pytest.fixture(scope="module") +def grpc_stub_cls(): + return training_pb2_grpc.TrainingStub + + +def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: str = "cpu"): + return f""" +device: {device} # Use CPU for faster test execution, change to 'cuda' if GPU is available and necessary +model: + name: UNet2D + in_channels: 3 + out_channels: 2 + layer_order: gcr + f_maps: 16 + num_groups: 4 + final_sigmoid: false + is_segmentation: true +trainer: + checkpoint_dir: {checkpoint_dir} + resume: null + validate_after_iters: 2 + log_after_iters: 2 + max_num_epochs: 1000 + max_num_iterations: 10000 + eval_score_higher_is_better: True +optimizer: + learning_rate: 0.0002 + weight_decay: 0.00001 +loss: + name: CrossEntropyLoss +eval_metric: + name: MeanIoU + ignore_index: null +lr_scheduler: + name: MultiStepLR + milestones: [2, 3] + gamma: 0.5 +loaders: + dataset: StandardHDF5Dataset + batch_size: 1 + num_workers: 1 + raw_internal_path: raw + label_internal_path: label + weight_internal_path: null + train: + file_paths: + - {train_data_dir} + + slice_builder: + name: SliceBuilder + patch_shape: [1, 64, 64] + stride_shape: [1, 64, 64] + skip_shape_check: true + + transformer: + raw: + - name: Standardize + - name: RandomFlip + - name: RandomRotate90 + - name: RandomRotate + axes: [[2, 1]] + angle_spectrum: 30 + mode: reflect + - name: ElasticDeformation + execution_probability: 1.0 + spline_order: 3 + - name: AdditiveGaussianNoise + execution_probability: 1.0 + - name: AdditivePoissonNoise + execution_probability: 1.0 + - name: ToTensor + expand_dims: true + label: + - name: RandomFlip + - name: RandomRotate90 + - name: RandomRotate + axes: [[2, 1]] + angle_spectrum: 30 + mode: reflect + - name: ElasticDeformation + execution_probability: 1.0 + spline_order: 0 + - name: ToTensor + # do not expand dims for cross-entropy loss + expand_dims: false + # cross-entropy loss requires target to be of type 'long' + dtype: 'int64' + weight: + - name: ToTensor + expand_dims: false + val: + file_paths: + - {val_data_path} + + slice_builder: + name: SliceBuilder + patch_shape: [1, 64, 64] + stride_shape: [1, 64, 64] + skip_shape_check: true + + transformer: + raw: + - name: Standardize + - name: ToTensor + expand_dims: true + label: + - name: ToTensor + expand_dims: false + dtype: 'int64' + weight: + - name: ToTensor + expand_dims: false +""" + + +def 
create_random_dataset(shape, channel_per_class): + tmp = tempfile.NamedTemporaryFile(delete=False) + + with h5py.File(tmp.name, "w") as f: + l_shape = w_shape = shape + # make sure that label and weight tensors are 3D + if len(shape) == 4: + l_shape = shape[1:] + w_shape = shape[1:] + + if channel_per_class: + l_shape = (2,) + l_shape + + f.create_dataset("raw", data=np.random.rand(*shape), dtype=np.float32) + f.create_dataset("label", data=np.random.randint(0, 2, l_shape), dtype=np.int64) + f.create_dataset("weight_map", data=np.random.rand(*w_shape), dtype=np.float32) + + return tmp.name + + +def prepare_unet2d_test_environment(device: str = "cpu") -> str: + checkpoint_dir = Path(tempfile.mkdtemp()) + + shape = (3, 1, 128, 128) + binary_loss = False + train = create_random_dataset(shape, binary_loss) + val = create_random_dataset(shape, binary_loss) + + return unet2d_config_path(checkpoint_dir=checkpoint_dir, train_data_dir=train, val_data_path=val, device=device) + + +class TestTrainingServicer: + def assert_state(self, grpc_stub, training_session_id: str, state_to_check: TrainerState): + response = grpc_stub.GetStatus(training_session_id) + assert response.state == trainer_state_to_pb[state_to_check] + + def poll_for_state_grpc(self, grpc_stub, session_id, expected_state: TrainerState, timeout=3, poll_interval=0.1): + def get_status(*args): + return pb_state_to_trainer[grpc_stub.GetStatus(session_id).state] + + self.poll_for_state(get_status, expected_state, timeout, poll_interval) + + def poll_for_state(self, get_status: Callable, expected_state: TrainerState, timeout=3, poll_interval=0.1): + start_time = time.time() + + while True: + current_state = get_status() + + if current_state == expected_state: + return current_state + + if time.time() - start_time > timeout: + pytest.fail(f"Timeout: State did not transition to {expected_state} within {timeout} seconds.") + + time.sleep(poll_interval) + + def test_init_success(self, grpc_stub): + """ + Test that a session initializes successfully with valid YAML. + """ + response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + assert response.id is not None, "Failed to initialize training session" + + def test_init_invalid_yaml(self, grpc_stub): + """ + Test that initializing with invalid YAML raises an error. + """ + invalid_yaml = "invalid_yaml_content: {unbalanced_braces" + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=invalid_yaml)) + assert "expected ',' or '}', but got" in excinfo.value.details() + + def test_init_failed_then_devices_are_released(self, grpc_stub): + invalid_yaml = """ + device: cpu + unknown: 42 + """ + with pytest.raises(grpc.RpcError): + grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=invalid_yaml)) + + # attempt to init with the same device + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + assert init_response.id is not None + + def test_start_training_success(self): + """ + Test starting training after successful initialization. 
+ """ + trainer_is_called = threading.Event() + + class MockedNominalTrainer(Trainer): + def __init__(self): + self.num_epochs = 0 + self.max_num_epochs = 10 + self.num_iterations = 0 + self.max_num_iterations = 100 + self.should_stop_callbacks = ShouldStopCallbacks() + + def fit(self): + print("Training has started") + trainer_is_called.set() + + class MockedTrainerSessionBackend(TrainerSessionProcess): + def init(self, trainer_yaml_config: str = ""): + self._worker = TrainerSessionBackend(MockedNominalTrainer()) + + backend = MockedTrainerSessionBackend() + backend.init() + backend.start_training() + trainer_is_called.wait(timeout=5) + backend.shutdown() + + def test_concurrent_state_transitions(self, grpc_stub): + """ + Test concurrent calls to Start, Pause, and Resume to ensure no deadlocks or race conditions. + + The test should exit gracefully without hanging processes or threads. + """ + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + + threads = [] + for _ in range(2): + threads.append(threading.Thread(target=lambda: grpc_stub.Start(training_session_id))) + threads.append(threading.Thread(target=lambda: grpc_stub.Pause(training_session_id))) + threads.append(threading.Thread(target=lambda: grpc_stub.Resume(training_session_id))) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + def test_queueing_multiple_commands(self, grpc_stub): + def assert_state(state_to_check): + self.assert_state(grpc_stub, training_session_id, state_to_check) + + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + + grpc_stub.Start(training_session_id) + assert_state(TrainerState.RUNNING) + + for _ in range(3): + grpc_stub.Pause(training_session_id) + assert_state(TrainerState.PAUSED) + + grpc_stub.Resume(training_session_id) + assert_state(TrainerState.RUNNING) + + def test_error_handling_on_invalid_state_transitions_after_training_started(self, grpc_stub): + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + + # Attempt to start again while already running + grpc_stub.Start(training_session_id) + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Start(training_session_id) + assert "Invalid state transition: TrainerState.RUNNING -> TrainerState.RUNNING" in excinfo.value.details() + + # Attempt to pause again while already paused + grpc_stub.Pause(training_session_id) + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Pause(training_session_id) + assert "Invalid state transition: TrainerState.PAUSED -> TrainerState.PAUSED" in excinfo.value.details() + + # Attempt to resume again while already resumed + grpc_stub.Resume(training_session_id) + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Resume(training_session_id) + assert "Invalid state transition: TrainerState.RUNNING -> TrainerState.RUNNING" in excinfo.value.details() + + def test_error_handling_on_invalid_state_transitions_before_training_started(self, grpc_stub): + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + + # Attempt to resume before start + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Resume(training_session_id) + assert "Invalid state transition: TrainerState.IDLE -> TrainerState.RUNNING" in excinfo.value.details() + + # Attempt to pause before start + with 
pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Pause(training_session_id) + assert "Invalid state transition: TrainerState.IDLE -> TrainerState.PAUSED" in excinfo.value.details() + + def test_start_training_without_init(self, grpc_stub): + """ + Test starting training without initializing a session. + """ + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Start(utils_pb2.Empty()) + assert excinfo.value.code() == grpc.StatusCode.FAILED_PRECONDITION + assert "trainer-session with id doesn't exist" in excinfo.value.details() + + def test_recover_training_failed(self): + class MockedExceptionTrainer: + def __init__(self): + self.should_stop_callbacks = ShouldStopCallbacks() + + def fit(self): + raise Exception("mocked exception") + + class MockedNominalTrainer: + def __init__(self): + self.num_epochs = 0 + self.max_num_epochs = 10 + self.num_iterations = 0 + self.max_num_iterations = 100 + self.should_stop_callbacks = ShouldStopCallbacks() + + def fit(self): + for epoch in range(self.max_num_epochs): + self.num_epochs += 1 + + class MockedTrainerSessionBackend(TrainerSessionProcess): + def init(self, trainer_yaml_config: str): + if trainer_yaml_config == "nominal": + self._worker = TrainerSessionBackend(MockedNominalTrainer()) + elif trainer_yaml_config == "exception": + self._worker = TrainerSessionBackend(MockedExceptionTrainer()) + else: + # simulate user creating model that raises an exception, + # and then adjusts the config for a nominal run + raise AssertionError + + backend = MockedTrainerSessionBackend() + backend.init("exception") + backend.start_training() + + # client detects that state is failed, closes the session and starts a new one + self.poll_for_state(backend.get_state, expected_state=TrainerState.FAILED) + + backend.shutdown() + + backend.init("nominal") + backend.start_training() + self.poll_for_state(backend.get_state, expected_state=TrainerState.FINISHED) + backend.shutdown() + + def test_perform_operations_after_training_failed(self): + def assert_error(func, expected_message: str): + with pytest.raises(Exception) as excinfo: + func() + assert expected_message in str(excinfo.value) + + class MockedExceptionTrainer: + def __init__(self): + self.should_stop_callbacks = ShouldStopCallbacks() + + def fit(self): + raise Exception("mocked exception") + + class MockedTrainerSessionBackend(TrainerSessionProcess): + def init(self, trainer_yaml_config: str = ""): + self._worker = TrainerSessionBackend(MockedExceptionTrainer()) + + backend = MockedTrainerSessionBackend() + backend.init() + backend.start_training() + + start_thread = threading.Thread(target=backend.start_training) + start_thread.start() + + pause_thread = threading.Thread( + target=lambda: assert_error( + backend.pause_training, + "Invalid state transition: TrainerState.FAILED -> TrainerState.PAUSED", + ) + ) + pause_thread.start() + + resume_thread = threading.Thread( + target=lambda: assert_error( + backend.resume_training, + "Invalid state transition: TrainerState.FAILED -> TrainerState.RUNNING", + ) + ) + resume_thread.start() + + start_thread.join() + pause_thread.join() + resume_thread.join() + backend.shutdown() + + def test_graceful_shutdown_after_init(self, grpc_stub): + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + grpc_stub.CloseTrainerSession(training_session_id) + + def test_graceful_shutdown_after_start(self, grpc_stub): + training_session_id = grpc_stub.Init(
training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + grpc_stub.Start(training_session_id) + grpc_stub.CloseTrainerSession(training_session_id) + + def test_graceful_shutdown_after_pause(self, grpc_stub): + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + grpc_stub.Start(training_session_id) + grpc_stub.Pause(training_session_id) + grpc_stub.CloseTrainerSession(training_session_id) + + def test_graceful_shutdown_after_resume(self, grpc_stub): + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + grpc_stub.Start(training_session_id) + grpc_stub.Pause(training_session_id) + grpc_stub.Resume(training_session_id) + grpc_stub.CloseTrainerSession(training_session_id) + + def test_close_trainer_session_twice(self, grpc_stub): + # Attempt to close the session twice + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + grpc_stub.CloseTrainerSession(training_session_id) + + # The second attempt should raise an error + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.CloseTrainerSession(training_session_id) + assert "Unknown session" in excinfo.value.details() + + def test_close_session(self, grpc_stub): + """ + Test closing a training session. + """ + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + grpc_stub.CloseTrainerSession(training_session_id) + + # attempt to perform an operation while session is closed + operations = [grpc_stub.Start, grpc_stub.Pause, grpc_stub.Resume] + for operation in operations: + with pytest.raises(grpc.RpcError) as excinfo: + operation(training_session_id) + assert "doesn't exist" in excinfo.value.details() + + def test_multiple_sessions(self, grpc_stub): + response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + assert response.id is not None + + response = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment(device="gpu")) + ) + assert response.id is not None diff --git a/tests/test_server/test_training/test_training.py b/tests/test_server/test_training/test_training.py deleted file mode 100644 index 3492d6f4..00000000 --- a/tests/test_server/test_training/test_training.py +++ /dev/null @@ -1,126 +0,0 @@ -import threading -import time -from concurrent.futures import Future - -import numpy as np -import pytest -import xarray as xr - -from tiktorch.server.session import State -from tiktorch.server.session.backend import commands -from tiktorch.server.session.backend.supervisor import Supervisor -from tiktorch.utils import wait - - -class TestExemplumSupervisor: - class DummyCmd(commands.ICommand): - def execute(self, ctx): - pass - - class DummyExemplum: - def __init__(self): - self.iteration_count = 0 - self.max_num_iterations = 0 - self._break_cb = None - self._devs = [] - - def set_break_callback(self, cb): - self._break_cb = cb - - def predict_sample_without_blocking(self, input_tensors): - return [xr.DataArray(np.array([42]), dims=("x",))] - - def set_max_num_iterations(self, val): - self.max_num_iterations = val - - def stop_training(self, max_num_iterations=None, max_num_epochs=None): - return self._break_cb and self._break_cb() or self.iteration_count >= self.max_num_iterations - - def train(self): - while not self.stop_training(): - self.iteration_count += 1 
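As an aside (not part of the patch): the converter helpers changed further below keep their signatures but now build utils_pb2 messages instead of inference_pb2 ones. A minimal round-trip sketch, assuming a checkout with the regenerated utils_pb2 module:

import numpy as np

from tiktorch.converters import numpy_to_pb_tensor, pb_tensor_to_numpy
from tiktorch.proto import utils_pb2

arr = np.arange(9, dtype=np.uint8).reshape(3, 3)
pb = numpy_to_pb_tensor("input0", arr)  # utils_pb2.Tensor with NamedInt shape entries
parsed = utils_pb2.Tensor()
parsed.ParseFromString(pb.SerializeToString())  # simulate the gRPC wire round-trip
assert np.array_equal(arr, pb_tensor_to_numpy(parsed))
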
- time.sleep(0.01) - - @pytest.fixture - def exemplum(self): - return self.DummyExemplum() - - @pytest.fixture - def supervisor(self, exemplum): - return Supervisor(exemplum) - - @pytest.fixture - def worker_thread(self, supervisor): - t = threading.Thread(target=supervisor.run) - t.start() - yield t - supervisor.send_command(commands.StopCmd()) - t.join() - - def test_not_running_worker_has_stopped_status(self, supervisor): - assert State.Stopped == supervisor.state - - def test_started_worker_has_idle_status(self, supervisor, worker_thread): - cmd = self.DummyCmd().awaitable - supervisor.send_command(cmd) - cmd.wait() - - assert State.Paused == supervisor.state - - def test_resuming_transitions_to_idle_with_no_devices(self, supervisor, worker_thread): - cmd = commands.ResumeCmd().awaitable - supervisor.send_command(cmd) - cmd.wait() - - assert State.Idle == supervisor.state - - def test_transition_to_running(self, supervisor, worker_thread): - cmd = commands.ResumeCmd() - supervisor.send_command(cmd) - - add_work = commands.SetMaxNumIterations(2).awaitable - supervisor.send_command(add_work) - add_work.wait() - - assert supervisor.state == State.Running - - def test_exception_during_train_should_transition_to_paused(self, supervisor, worker_thread, exemplum): - train_called = threading.Event() - train_proceed = threading.Event() - - def _exc(): - train_called.set() - train_proceed.wait() - raise Exception() - - exemplum.train = _exc - - cmd = commands.ResumeCmd() - supervisor.send_command(cmd) - - assert supervisor.state == State.Paused - add_work = commands.SetMaxNumIterations(2).awaitable - supervisor.send_command(add_work) - add_work.wait() - - train_called.wait() - wait(lambda: supervisor.state == State.Running, max_wait=1) - train_proceed.set() - wait(lambda: supervisor.state == State.Paused, max_wait=1) - - def test_finished_training_should_transition_to_paused(self, supervisor, worker_thread, exemplum): - cmd = commands.ResumeCmd() - supervisor.send_command(cmd) - - add_work = commands.SetMaxNumIterations(2).awaitable - supervisor.send_command(add_work) - add_work.wait() - assert supervisor.state == State.Running - time.sleep(0.1) # FIXME: Find a better way to wait for pause event with timeout - assert supervisor.state == State.Idle - - def test_forward(self, supervisor, worker_thread, exemplum): - fut = Future() - forward_cmd = commands.ForwardPass(fut, [xr.DataArray(np.array([1]), dims=("x",))]) - supervisor.send_command(forward_cmd) - assert [42] == fut.result() diff --git a/tests/test_server/test_training/test_worker/test_commands.py b/tests/test_server/test_training/test_worker/test_commands.py index f8307fa7..38ad4289 100644 --- a/tests/test_server/test_training/test_worker/test_commands.py +++ b/tests/test_server/test_training/test_worker/test_commands.py @@ -9,17 +9,17 @@ class TestCommandQueue: def test_stop_command_has_higher_priorityj(self): cmd_queue = cmds.CommandPriorityQueue() - stop_cmd = cmds.StopCmd() - cmd_queue.put_nowait(cmds.ResumeCmd()) + stop_cmd = cmds.ShutdownCmd() + cmd_queue.put_nowait(cmds.ResumeTrainingCmd()) cmd_queue.put_nowait(stop_cmd) - cmd_queue.put_nowait(cmds.PauseCmd()) + cmd_queue.put_nowait(cmds.PauseTrainingCmd()) received_cmd = cmd_queue.get_nowait() assert stop_cmd is received_cmd def test_queue_order_is_stable(self): cmd_queue = cmds.CommandPriorityQueue() - stop_cmds = [cmds.StopCmd() for _ in range(100)] + stop_cmds = [cmds.ShutdownCmd() for _ in range(100)] for cmd in stop_cmds: cmd_queue.put_nowait(cmd) diff --git 
a/tiktorch/converters.py b/tiktorch/converters.py index 3159c05c..016b7b3c 100644 --- a/tiktorch/converters.py +++ b/tiktorch/converters.py @@ -7,10 +7,21 @@ from bioimageio.core import Sample, Tensor from bioimageio.spec.model.v0_5 import TensorId -from tiktorch.proto import inference_pb2 +from tiktorch.proto import inference_pb2, training_pb2, utils_pb2 +from tiktorch.trainer import TrainerState +trainer_state_to_pb = { + TrainerState.IDLE: training_pb2.GetStatusResponse.State.Idle, + TrainerState.RUNNING: training_pb2.GetStatusResponse.State.Running, + TrainerState.PAUSED: training_pb2.GetStatusResponse.State.Paused, + TrainerState.FAILED: training_pb2.GetStatusResponse.State.Failed, + TrainerState.FINISHED: training_pb2.GetStatusResponse.State.Finished, +} -def pb_tensors_to_sample(pb_tensors: List[inference_pb2.Tensor]) -> Sample: +pb_state_to_trainer = {value: key for key, value in trainer_state_to_pb.items()} + + +def pb_tensors_to_sample(pb_tensors: List[utils_pb2.Tensor]) -> Sample: return Sample( members={TensorId(tensor.tensorId): Tensor.from_xarray(pb_tensor_to_xarray(tensor)) for tensor in pb_tensors}, id=None, @@ -30,21 +41,21 @@ def xr_tensors_to_sample(tensor_ids: List[str], tensors_data: List[xr.DataArray] ) -def sample_to_pb_tensors(sample: Sample) -> List[inference_pb2.Tensor]: +def sample_to_pb_tensors(sample: Sample) -> List[utils_pb2.Tensor]: return [xarray_to_pb_tensor(tensor_id, res_tensor.data) for tensor_id, res_tensor in sample.members.items()] -def numpy_to_pb_tensor(tensor_id: str, array: np.ndarray, axistags=None) -> inference_pb2.Tensor: +def numpy_to_pb_tensor(tensor_id: str, array: np.ndarray, axistags=None) -> utils_pb2.Tensor: if axistags: - shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, axistags)] + shape = [utils_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, axistags)] else: - shape = [inference_pb2.NamedInt(size=dim) for dim in array.shape] - return inference_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array)) + shape = [utils_pb2.NamedInt(size=dim) for dim in array.shape] + return utils_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array)) -def xarray_to_pb_tensor(tensor_id: str, array: xr.DataArray) -> inference_pb2.Tensor: - shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, array.dims)] - return inference_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array.data)) +def xarray_to_pb_tensor(tensor_id: str, array: xr.DataArray) -> utils_pb2.Tensor: + shape = [utils_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, array.dims)] + return utils_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array.data)) def name_int_tuples_to_pb_NamedInts(name_int_tuples) -> inference_pb2.NamedInts: @@ -59,7 +70,7 @@ def name_float_tuples_to_pb_NamedFloats(name_float_tuples) -> inference_pb2.Name ) -def pb_tensor_to_xarray(tensor: inference_pb2.Tensor) -> inference_pb2.Tensor: +def pb_tensor_to_xarray(tensor: utils_pb2.Tensor) -> xr.DataArray: if not tensor.dtype: raise ValueError("Tensor dtype is not specified") @@ -71,7 +82,7 @@ def pb_tensor_to_xarray(tensor: inference_pb2.Tensor) -> inference_pb2.Tensor: return xr.DataArray(data, dims=[d.name for d in tensor.shape]) -def pb_tensor_to_numpy(tensor: inference_pb2.Tensor) -> np.ndarray: +def pb_tensor_to_numpy(tensor: utils_pb2.Tensor) -> np.ndarray: if not 
tensor.dtype: raise ValueError("Tensor dtype is not specified") diff --git a/tiktorch/proto/data_store_pb2.py b/tiktorch/proto/data_store_pb2.py index 26ac233f..87c21ef4 100644 --- a/tiktorch/proto/data_store_pb2.py +++ b/tiktorch/proto/data_store_pb2.py @@ -2,10 +2,9 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # source: data_store.proto """Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database # @@protoc_insertion_point(imports) @@ -16,49 +15,8 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x64\x61ta_store.proto\":\n\x0eUploadResponse\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04size\x18\x02 \x01(\r\x12\x0e\n\x06sha256\x18\x03 \x01(\t\"\x1a\n\nUploadInfo\x12\x0c\n\x04size\x18\x01 \x01(\r\"J\n\rUploadRequest\x12\x1b\n\x04info\x18\x01 \x01(\x0b\x32\x0b.UploadInfoH\x00\x12\x11\n\x07\x63ontent\x18\x02 \x01(\x0cH\x00\x42\t\n\x07payload\"!\n\rRemoveRequest\x12\x10\n\x08uploadId\x18\x01 \x01(\t\"\x10\n\x0eRemoveResponse2g\n\tDataStore\x12-\n\x06Upload\x12\x0e.UploadRequest\x1a\x0f.UploadResponse\"\x00(\x01\x12+\n\x06Remove\x12\x0e.RemoveRequest\x1a\x0f.RemoveResponse\"\x00\x62\x06proto3') - - -_UPLOADRESPONSE = DESCRIPTOR.message_types_by_name['UploadResponse'] -_UPLOADINFO = DESCRIPTOR.message_types_by_name['UploadInfo'] -_UPLOADREQUEST = DESCRIPTOR.message_types_by_name['UploadRequest'] -_REMOVEREQUEST = DESCRIPTOR.message_types_by_name['RemoveRequest'] -_REMOVERESPONSE = DESCRIPTOR.message_types_by_name['RemoveResponse'] -UploadResponse = _reflection.GeneratedProtocolMessageType('UploadResponse', (_message.Message,), { - 'DESCRIPTOR' : _UPLOADRESPONSE, - '__module__' : 'data_store_pb2' - # @@protoc_insertion_point(class_scope:UploadResponse) - }) -_sym_db.RegisterMessage(UploadResponse) - -UploadInfo = _reflection.GeneratedProtocolMessageType('UploadInfo', (_message.Message,), { - 'DESCRIPTOR' : _UPLOADINFO, - '__module__' : 'data_store_pb2' - # @@protoc_insertion_point(class_scope:UploadInfo) - }) -_sym_db.RegisterMessage(UploadInfo) - -UploadRequest = _reflection.GeneratedProtocolMessageType('UploadRequest', (_message.Message,), { - 'DESCRIPTOR' : _UPLOADREQUEST, - '__module__' : 'data_store_pb2' - # @@protoc_insertion_point(class_scope:UploadRequest) - }) -_sym_db.RegisterMessage(UploadRequest) - -RemoveRequest = _reflection.GeneratedProtocolMessageType('RemoveRequest', (_message.Message,), { - 'DESCRIPTOR' : _REMOVEREQUEST, - '__module__' : 'data_store_pb2' - # @@protoc_insertion_point(class_scope:RemoveRequest) - }) -_sym_db.RegisterMessage(RemoveRequest) - -RemoveResponse = _reflection.GeneratedProtocolMessageType('RemoveResponse', (_message.Message,), { - 'DESCRIPTOR' : _REMOVERESPONSE, - '__module__' : 'data_store_pb2' - # @@protoc_insertion_point(class_scope:RemoveResponse) - }) -_sym_db.RegisterMessage(RemoveResponse) - -_DATASTORE = DESCRIPTOR.services_by_name['DataStore'] +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_store_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None diff --git a/tiktorch/proto/inference_pb2.py b/tiktorch/proto/inference_pb2.py index c300c315..be9fbde2 100644 --- 
a/tiktorch/proto/inference_pb2.py +++ b/tiktorch/proto/inference_pb2.py @@ -2,195 +2,49 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # source: inference.proto """Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() +from . import utils_pb2 as utils__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"i\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12\x1b\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"\x9e\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12\x1e\n\x05level\x18\x02 \x01(\x0e\x32\x0f.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Device\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"S\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x10\n\x08tensorId\x18\x03 \x01(\t\x12\x18\n\x05shape\x18\x04 \x03(\x0b\x32\t.NamedInt\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty2\xc6\x02\n\tInference\x12\x41\n\x12\x43reateModelSession\x12\x1a.CreateModelSessionRequest\x1a\r.ModelSession\"\x00\x12,\n\x11\x43loseModelSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12S\n\x18\x43reateDatasetDescription\x12 .CreateDatasetDescriptionRequest\x1a\x13.DatasetDescription\"\x00\x12 \n\x07GetLogs\x12\x06.Empty\x1a\t.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\x12\tinference\x1a\x0butils.proto\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 
\x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"s\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12%\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x0f.inference.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"\xa8\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12(\n\x05level\x18\x02 \x01(\x0e\x32\x19.inference.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor2\x96\x03\n\tInference\x12U\n\x12\x43reateModelSession\x12$.inference.CreateModelSessionRequest\x1a\x17.inference.ModelSession\"\x00\x12\x36\n\x11\x43loseModelSession\x12\x17.inference.ModelSession\x1a\x06.Empty\"\x00\x12g\n\x18\x43reateDatasetDescription\x12*.inference.CreateDatasetDescriptionRequest\x1a\x1d.inference.DatasetDescription\"\x00\x12*\n\x07GetLogs\x12\x06.Empty\x1a\x13.inference.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12\x42\n\x07Predict\x12\x19.inference.PredictRequest\x1a\x1a.inference.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3') - - -_DEVICE = DESCRIPTOR.message_types_by_name['Device'] -_CREATEDATASETDESCRIPTIONREQUEST = DESCRIPTOR.message_types_by_name['CreateDatasetDescriptionRequest'] -_DATASETDESCRIPTION = DESCRIPTOR.message_types_by_name['DatasetDescription'] -_BLOB = DESCRIPTOR.message_types_by_name['Blob'] -_CREATEMODELSESSIONREQUEST = DESCRIPTOR.message_types_by_name['CreateModelSessionRequest'] -_NAMEDINTS = DESCRIPTOR.message_types_by_name['NamedInts'] -_NAMEDFLOATS = DESCRIPTOR.message_types_by_name['NamedFloats'] -_MODELSESSION = DESCRIPTOR.message_types_by_name['ModelSession'] -_LOGENTRY = DESCRIPTOR.message_types_by_name['LogEntry'] -_DEVICES = DESCRIPTOR.message_types_by_name['Devices'] -_NAMEDINT = DESCRIPTOR.message_types_by_name['NamedInt'] -_NAMEDFLOAT = DESCRIPTOR.message_types_by_name['NamedFloat'] -_TENSOR = DESCRIPTOR.message_types_by_name['Tensor'] -_PREDICTREQUEST = DESCRIPTOR.message_types_by_name['PredictRequest'] -_PREDICTRESPONSE = DESCRIPTOR.message_types_by_name['PredictResponse'] -_EMPTY = DESCRIPTOR.message_types_by_name['Empty'] -_DEVICE_STATUS = _DEVICE.enum_types_by_name['Status'] -_LOGENTRY_LEVEL = _LOGENTRY.enum_types_by_name['Level'] -Device = _reflection.GeneratedProtocolMessageType('Device', (_message.Message,), { - 'DESCRIPTOR' : _DEVICE, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:Device) - }) -_sym_db.RegisterMessage(Device) - -CreateDatasetDescriptionRequest = _reflection.GeneratedProtocolMessageType('CreateDatasetDescriptionRequest', (_message.Message,), { - 'DESCRIPTOR' : _CREATEDATASETDESCRIPTIONREQUEST, - '__module__' : 'inference_pb2' - # 
@@protoc_insertion_point(class_scope:CreateDatasetDescriptionRequest) - }) -_sym_db.RegisterMessage(CreateDatasetDescriptionRequest) - -DatasetDescription = _reflection.GeneratedProtocolMessageType('DatasetDescription', (_message.Message,), { - 'DESCRIPTOR' : _DATASETDESCRIPTION, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:DatasetDescription) - }) -_sym_db.RegisterMessage(DatasetDescription) - -Blob = _reflection.GeneratedProtocolMessageType('Blob', (_message.Message,), { - 'DESCRIPTOR' : _BLOB, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:Blob) - }) -_sym_db.RegisterMessage(Blob) - -CreateModelSessionRequest = _reflection.GeneratedProtocolMessageType('CreateModelSessionRequest', (_message.Message,), { - 'DESCRIPTOR' : _CREATEMODELSESSIONREQUEST, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:CreateModelSessionRequest) - }) -_sym_db.RegisterMessage(CreateModelSessionRequest) - -NamedInts = _reflection.GeneratedProtocolMessageType('NamedInts', (_message.Message,), { - 'DESCRIPTOR' : _NAMEDINTS, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:NamedInts) - }) -_sym_db.RegisterMessage(NamedInts) - -NamedFloats = _reflection.GeneratedProtocolMessageType('NamedFloats', (_message.Message,), { - 'DESCRIPTOR' : _NAMEDFLOATS, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:NamedFloats) - }) -_sym_db.RegisterMessage(NamedFloats) - -ModelSession = _reflection.GeneratedProtocolMessageType('ModelSession', (_message.Message,), { - 'DESCRIPTOR' : _MODELSESSION, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:ModelSession) - }) -_sym_db.RegisterMessage(ModelSession) - -LogEntry = _reflection.GeneratedProtocolMessageType('LogEntry', (_message.Message,), { - 'DESCRIPTOR' : _LOGENTRY, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:LogEntry) - }) -_sym_db.RegisterMessage(LogEntry) - -Devices = _reflection.GeneratedProtocolMessageType('Devices', (_message.Message,), { - 'DESCRIPTOR' : _DEVICES, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:Devices) - }) -_sym_db.RegisterMessage(Devices) - -NamedInt = _reflection.GeneratedProtocolMessageType('NamedInt', (_message.Message,), { - 'DESCRIPTOR' : _NAMEDINT, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:NamedInt) - }) -_sym_db.RegisterMessage(NamedInt) - -NamedFloat = _reflection.GeneratedProtocolMessageType('NamedFloat', (_message.Message,), { - 'DESCRIPTOR' : _NAMEDFLOAT, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:NamedFloat) - }) -_sym_db.RegisterMessage(NamedFloat) - -Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), { - 'DESCRIPTOR' : _TENSOR, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:Tensor) - }) -_sym_db.RegisterMessage(Tensor) - -PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), { - 'DESCRIPTOR' : _PREDICTREQUEST, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:PredictRequest) - }) -_sym_db.RegisterMessage(PredictRequest) - -PredictResponse = _reflection.GeneratedProtocolMessageType('PredictResponse', (_message.Message,), { - 'DESCRIPTOR' : _PREDICTRESPONSE, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:PredictResponse) - }) -_sym_db.RegisterMessage(PredictResponse) - -Empty = _reflection.GeneratedProtocolMessageType('Empty', 
(_message.Message,), { - 'DESCRIPTOR' : _EMPTY, - '__module__' : 'inference_pb2' - # @@protoc_insertion_point(class_scope:Empty) - }) -_sym_db.RegisterMessage(Empty) - -_INFERENCE = DESCRIPTOR.services_by_name['Inference'] -_FLIGHTCONTROL = DESCRIPTOR.services_by_name['FlightControl'] +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'inference_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _DEVICE._serialized_start=19 - _DEVICE._serialized_end=108 - _DEVICE_STATUS._serialized_start=73 - _DEVICE_STATUS._serialized_end=108 - _CREATEDATASETDESCRIPTIONREQUEST._serialized_start=110 - _CREATEDATASETDESCRIPTIONREQUEST._serialized_end=197 - _DATASETDESCRIPTION._serialized_start=199 - _DATASETDESCRIPTION._serialized_end=231 - _BLOB._serialized_start=233 - _BLOB._serialized_end=272 - _CREATEMODELSESSIONREQUEST._serialized_start=274 - _CREATEMODELSESSIONREQUEST._serialized_end=379 - _NAMEDINTS._serialized_start=381 - _NAMEDINTS._serialized_end=422 - _NAMEDFLOATS._serialized_start=424 - _NAMEDFLOATS._serialized_end=471 - _MODELSESSION._serialized_start=473 - _MODELSESSION._serialized_end=499 - _LOGENTRY._serialized_start=502 - _LOGENTRY._serialized_end=660 - _LOGENTRY_LEVEL._serialized_start=582 - _LOGENTRY_LEVEL._serialized_end=660 - _DEVICES._serialized_start=662 - _DEVICES._serialized_end=697 - _NAMEDINT._serialized_start=699 - _NAMEDINT._serialized_end=737 - _NAMEDFLOAT._serialized_start=739 - _NAMEDFLOAT._serialized_end=779 - _TENSOR._serialized_start=781 - _TENSOR._serialized_end=864 - _PREDICTREQUEST._serialized_start=866 - _PREDICTREQUEST._serialized_end=951 - _PREDICTRESPONSE._serialized_start=953 - _PREDICTRESPONSE._serialized_end=996 - _EMPTY._serialized_start=998 - _EMPTY._serialized_end=1005 - _INFERENCE._serialized_start=1008 - _INFERENCE._serialized_end=1334 - _FLIGHTCONTROL._serialized_start=1336 - _FLIGHTCONTROL._serialized_end=1407 + _CREATEDATASETDESCRIPTIONREQUEST._serialized_start=43 + _CREATEDATASETDESCRIPTIONREQUEST._serialized_end=130 + _DATASETDESCRIPTION._serialized_start=132 + _DATASETDESCRIPTION._serialized_end=164 + _BLOB._serialized_start=166 + _BLOB._serialized_end=205 + _CREATEMODELSESSIONREQUEST._serialized_start=207 + _CREATEMODELSESSIONREQUEST._serialized_end=322 + _NAMEDINTS._serialized_start=324 + _NAMEDINTS._serialized_end=365 + _NAMEDFLOATS._serialized_start=367 + _NAMEDFLOATS._serialized_end=414 + _MODELSESSION._serialized_start=416 + _MODELSESSION._serialized_end=442 + _LOGENTRY._serialized_start=445 + _LOGENTRY._serialized_end=613 + _LOGENTRY_LEVEL._serialized_start=535 + _LOGENTRY_LEVEL._serialized_end=613 + _PREDICTREQUEST._serialized_start=615 + _PREDICTREQUEST._serialized_end=700 + _PREDICTRESPONSE._serialized_start=702 + _PREDICTRESPONSE._serialized_end=745 + _INFERENCE._serialized_start=748 + _INFERENCE._serialized_end=1154 + _FLIGHTCONTROL._serialized_start=1156 + _FLIGHTCONTROL._serialized_end=1227 # @@protoc_insertion_point(module_scope) diff --git a/tiktorch/proto/inference_pb2_grpc.py b/tiktorch/proto/inference_pb2_grpc.py index f983d4c5..b49f2e42 100644 --- a/tiktorch/proto/inference_pb2_grpc.py +++ b/tiktorch/proto/inference_pb2_grpc.py @@ -3,6 +3,7 @@ import grpc from . import inference_pb2 as inference__pb2 +from . import utils_pb2 as utils__pb2 class InferenceStub(object): @@ -15,32 +16,32 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.CreateModelSession = channel.unary_unary( - '/Inference/CreateModelSession', + '/inference.Inference/CreateModelSession', request_serializer=inference__pb2.CreateModelSessionRequest.SerializeToString, response_deserializer=inference__pb2.ModelSession.FromString, ) self.CloseModelSession = channel.unary_unary( - '/Inference/CloseModelSession', + '/inference.Inference/CloseModelSession', request_serializer=inference__pb2.ModelSession.SerializeToString, - response_deserializer=inference__pb2.Empty.FromString, + response_deserializer=utils__pb2.Empty.FromString, ) self.CreateDatasetDescription = channel.unary_unary( - '/Inference/CreateDatasetDescription', + '/inference.Inference/CreateDatasetDescription', request_serializer=inference__pb2.CreateDatasetDescriptionRequest.SerializeToString, response_deserializer=inference__pb2.DatasetDescription.FromString, ) self.GetLogs = channel.unary_stream( - '/Inference/GetLogs', - request_serializer=inference__pb2.Empty.SerializeToString, + '/inference.Inference/GetLogs', + request_serializer=utils__pb2.Empty.SerializeToString, response_deserializer=inference__pb2.LogEntry.FromString, ) self.ListDevices = channel.unary_unary( - '/Inference/ListDevices', - request_serializer=inference__pb2.Empty.SerializeToString, - response_deserializer=inference__pb2.Devices.FromString, + '/inference.Inference/ListDevices', + request_serializer=utils__pb2.Empty.SerializeToString, + response_deserializer=utils__pb2.Devices.FromString, ) self.Predict = channel.unary_unary( - '/Inference/Predict', + '/inference.Inference/Predict', request_serializer=inference__pb2.PredictRequest.SerializeToString, response_deserializer=inference__pb2.PredictResponse.FromString, ) @@ -96,7 +97,7 @@ def add_InferenceServicer_to_server(servicer, server): 'CloseModelSession': grpc.unary_unary_rpc_method_handler( servicer.CloseModelSession, request_deserializer=inference__pb2.ModelSession.FromString, - response_serializer=inference__pb2.Empty.SerializeToString, + response_serializer=utils__pb2.Empty.SerializeToString, ), 'CreateDatasetDescription': grpc.unary_unary_rpc_method_handler( servicer.CreateDatasetDescription, @@ -105,13 +106,13 @@ def add_InferenceServicer_to_server(servicer, server): ), 'GetLogs': grpc.unary_stream_rpc_method_handler( servicer.GetLogs, - request_deserializer=inference__pb2.Empty.FromString, + request_deserializer=utils__pb2.Empty.FromString, response_serializer=inference__pb2.LogEntry.SerializeToString, ), 'ListDevices': grpc.unary_unary_rpc_method_handler( servicer.ListDevices, - request_deserializer=inference__pb2.Empty.FromString, - response_serializer=inference__pb2.Devices.SerializeToString, + request_deserializer=utils__pb2.Empty.FromString, + response_serializer=utils__pb2.Devices.SerializeToString, ), 'Predict': grpc.unary_unary_rpc_method_handler( servicer.Predict, @@ -120,7 +121,7 @@ def add_InferenceServicer_to_server(servicer, server): ), } generic_handler = grpc.method_handlers_generic_handler( - 'Inference', rpc_method_handlers) + 'inference.Inference', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) @@ -139,7 +140,7 @@ def CreateModelSession(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/Inference/CreateModelSession', + return grpc.experimental.unary_unary(request, target, '/inference.Inference/CreateModelSession', inference__pb2.CreateModelSessionRequest.SerializeToString, inference__pb2.ModelSession.FromString, options, channel_credentials, @@ 
-156,9 +157,9 @@ def CloseModelSession(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/Inference/CloseModelSession', + return grpc.experimental.unary_unary(request, target, '/inference.Inference/CloseModelSession', inference__pb2.ModelSession.SerializeToString, - inference__pb2.Empty.FromString, + utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -173,7 +174,7 @@ def CreateDatasetDescription(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/Inference/CreateDatasetDescription', + return grpc.experimental.unary_unary(request, target, '/inference.Inference/CreateDatasetDescription', inference__pb2.CreateDatasetDescriptionRequest.SerializeToString, inference__pb2.DatasetDescription.FromString, options, channel_credentials, @@ -190,8 +191,8 @@ def GetLogs(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_stream(request, target, '/Inference/GetLogs', - inference__pb2.Empty.SerializeToString, + return grpc.experimental.unary_stream(request, target, '/inference.Inference/GetLogs', + utils__pb2.Empty.SerializeToString, inference__pb2.LogEntry.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -207,9 +208,9 @@ def ListDevices(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/Inference/ListDevices', - inference__pb2.Empty.SerializeToString, - inference__pb2.Devices.FromString, + return grpc.experimental.unary_unary(request, target, '/inference.Inference/ListDevices', + utils__pb2.Empty.SerializeToString, + utils__pb2.Devices.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -224,7 +225,7 @@ def Predict(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/Inference/Predict', + return grpc.experimental.unary_unary(request, target, '/inference.Inference/Predict', inference__pb2.PredictRequest.SerializeToString, inference__pb2.PredictResponse.FromString, options, channel_credentials, @@ -241,14 +242,14 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.Ping = channel.unary_unary( - '/FlightControl/Ping', - request_serializer=inference__pb2.Empty.SerializeToString, - response_deserializer=inference__pb2.Empty.FromString, + '/inference.FlightControl/Ping', + request_serializer=utils__pb2.Empty.SerializeToString, + response_deserializer=utils__pb2.Empty.FromString, ) self.Shutdown = channel.unary_unary( - '/FlightControl/Shutdown', - request_serializer=inference__pb2.Empty.SerializeToString, - response_deserializer=inference__pb2.Empty.FromString, + '/inference.FlightControl/Shutdown', + request_serializer=utils__pb2.Empty.SerializeToString, + response_deserializer=utils__pb2.Empty.FromString, ) @@ -272,17 +273,17 @@ def add_FlightControlServicer_to_server(servicer, server): rpc_method_handlers = { 'Ping': grpc.unary_unary_rpc_method_handler( servicer.Ping, - request_deserializer=inference__pb2.Empty.FromString, - response_serializer=inference__pb2.Empty.SerializeToString, + request_deserializer=utils__pb2.Empty.FromString, + response_serializer=utils__pb2.Empty.SerializeToString, ), 'Shutdown': grpc.unary_unary_rpc_method_handler( servicer.Shutdown, - request_deserializer=inference__pb2.Empty.FromString, - response_serializer=inference__pb2.Empty.SerializeToString, + request_deserializer=utils__pb2.Empty.FromString, + response_serializer=utils__pb2.Empty.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( - 'FlightControl', rpc_method_handlers) + 'inference.FlightControl', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) @@ -301,9 +302,9 @@ def Ping(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/FlightControl/Ping', - inference__pb2.Empty.SerializeToString, - inference__pb2.Empty.FromString, + return grpc.experimental.unary_unary(request, target, '/inference.FlightControl/Ping', + utils__pb2.Empty.SerializeToString, + utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -318,8 +319,8 @@ def Shutdown(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/FlightControl/Shutdown', - inference__pb2.Empty.SerializeToString, - inference__pb2.Empty.FromString, + return grpc.experimental.unary_unary(request, target, '/inference.FlightControl/Shutdown', + utils__pb2.Empty.SerializeToString, + utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/tiktorch/proto/training_pb2.py b/tiktorch/proto/training_pb2.py new file mode 100644 index 00000000..70d93933 --- /dev/null +++ b/tiktorch/proto/training_pb2.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: training.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from . 
import utils_pb2 as utils__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a\x0butils.proto\"\x1f\n\x11TrainingSessionId\x12\n\n\x02id\x18\x01 \x01(\t\"\x87\x01\n\x04Logs\x12\'\n\x04mode\x18\x01 \x01(\x0e\x32\x19.training.Logs.ModelPhase\x12\x12\n\neval_score\x18\x02 \x01(\x01\x12\x0c\n\x04loss\x18\x03 \x01(\x01\x12\x11\n\titeration\x18\x04 \x01(\r\"!\n\nModelPhase\x12\t\n\x05Train\x10\x00\x12\x08\n\x04\x45val\x10\x01\"L\n\x14StreamUpdateResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12\x1c\n\x04logs\x18\x02 \x01(\x0b\x32\x0e.training.Logs\"/\n\x0fGetLogsResponse\x12\x1c\n\x04logs\x18\x01 \x03(\x0b\x32\x0e.training.Logs\"Z\n\x0ePredictRequest\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\x12.\n\tsessionId\x18\x02 \x01(\x0b\x32\x1b.training.TrainingSessionId\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"6\n\x12ValidationResponse\x12 \n\x18validation_score_average\x18\x01 \x01(\x01\"\x8b\x01\n\x11GetStatusResponse\x12\x30\n\x05state\x18\x01 \x01(\x0e\x32!.training.GetStatusResponse.State\"D\n\x05State\x12\x08\n\x04Idle\x10\x00\x12\x0b\n\x07Running\x10\x01\x12\n\n\x06Paused\x10\x02\x12\n\n\x06\x46\x61iled\x10\x03\x12\x0c\n\x08\x46inished\x10\x04\",\n\x1eGetCurrentBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\r\"&\n\x0eTrainingConfig\x12\x14\n\x0cyaml_content\x18\x01 \x01(\t2\xbf\x05\n\x08Training\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12?\n\x04Init\x12\x18.training.TrainingConfig\x1a\x1b.training.TrainingSessionId\"\x00\x12.\n\x05Start\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12/\n\x06Resume\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12.\n\x05Pause\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12P\n\rStreamUpdates\x12\x1b.training.TrainingSessionId\x1a\x1e.training.StreamUpdateResponse\"\x00\x30\x01\x12\x43\n\x07GetLogs\x12\x1b.training.TrainingSessionId\x1a\x19.training.GetLogsResponse\"\x00\x12-\n\x04Save\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12/\n\x06\x45xport\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12@\n\x07Predict\x12\x18.training.PredictRequest\x1a\x19.training.PredictResponse\"\x00\x12G\n\tGetStatus\x12\x1b.training.TrainingSessionId\x1a\x1b.training.GetStatusResponse\"\x00\x12<\n\x13\x43loseTrainerSession\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'training_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _TRAININGSESSIONID._serialized_start=41 + _TRAININGSESSIONID._serialized_end=72 + _LOGS._serialized_start=75 + _LOGS._serialized_end=210 + _LOGS_MODELPHASE._serialized_start=177 + _LOGS_MODELPHASE._serialized_end=210 + _STREAMUPDATERESPONSE._serialized_start=212 + _STREAMUPDATERESPONSE._serialized_end=288 + _GETLOGSRESPONSE._serialized_start=290 + _GETLOGSRESPONSE._serialized_end=337 + _PREDICTREQUEST._serialized_start=339 + _PREDICTREQUEST._serialized_end=429 + _PREDICTRESPONSE._serialized_start=431 + _PREDICTRESPONSE._serialized_end=474 + _VALIDATIONRESPONSE._serialized_start=476 + _VALIDATIONRESPONSE._serialized_end=530 + _GETSTATUSRESPONSE._serialized_start=533 + _GETSTATUSRESPONSE._serialized_end=672 + _GETSTATUSRESPONSE_STATE._serialized_start=604 + _GETSTATUSRESPONSE_STATE._serialized_end=672 + _GETCURRENTBESTMODELIDXRESPONSE._serialized_start=674 + 
_GETCURRENTBESTMODELIDXRESPONSE._serialized_end=718 + _TRAININGCONFIG._serialized_start=720 + _TRAININGCONFIG._serialized_end=758 + _TRAINING._serialized_start=761 + _TRAINING._serialized_end=1464 +# @@protoc_insertion_point(module_scope) diff --git a/tiktorch/proto/training_pb2_grpc.py b/tiktorch/proto/training_pb2_grpc.py new file mode 100644 index 00000000..79bf33df --- /dev/null +++ b/tiktorch/proto/training_pb2_grpc.py @@ -0,0 +1,430 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from . import training_pb2 as training__pb2 +from . import utils_pb2 as utils__pb2 + + +class TrainingStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.ListDevices = channel.unary_unary( + '/training.Training/ListDevices', + request_serializer=utils__pb2.Empty.SerializeToString, + response_deserializer=utils__pb2.Devices.FromString, + ) + self.Init = channel.unary_unary( + '/training.Training/Init', + request_serializer=training__pb2.TrainingConfig.SerializeToString, + response_deserializer=training__pb2.TrainingSessionId.FromString, + ) + self.Start = channel.unary_unary( + '/training.Training/Start', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=utils__pb2.Empty.FromString, + ) + self.Resume = channel.unary_unary( + '/training.Training/Resume', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=utils__pb2.Empty.FromString, + ) + self.Pause = channel.unary_unary( + '/training.Training/Pause', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=utils__pb2.Empty.FromString, + ) + self.StreamUpdates = channel.unary_stream( + '/training.Training/StreamUpdates', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=training__pb2.StreamUpdateResponse.FromString, + ) + self.GetLogs = channel.unary_unary( + '/training.Training/GetLogs', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=training__pb2.GetLogsResponse.FromString, + ) + self.Save = channel.unary_unary( + '/training.Training/Save', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=utils__pb2.Empty.FromString, + ) + self.Export = channel.unary_unary( + '/training.Training/Export', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=utils__pb2.Empty.FromString, + ) + self.Predict = channel.unary_unary( + '/training.Training/Predict', + request_serializer=training__pb2.PredictRequest.SerializeToString, + response_deserializer=training__pb2.PredictResponse.FromString, + ) + self.GetStatus = channel.unary_unary( + '/training.Training/GetStatus', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=training__pb2.GetStatusResponse.FromString, + ) + self.CloseTrainerSession = channel.unary_unary( + '/training.Training/CloseTrainerSession', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=utils__pb2.Empty.FromString, + ) + + +class TrainingServicer(object): + """Missing associated documentation comment in .proto file.""" + + def ListDevices(self, request, context): + """Missing associated documentation comment in 
.proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Init(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Start(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Resume(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Pause(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def StreamUpdates(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetLogs(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Save(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Export(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Predict(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetStatus(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CloseTrainerSession(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_TrainingServicer_to_server(servicer, server): + rpc_method_handlers = { + 'ListDevices': grpc.unary_unary_rpc_method_handler( + servicer.ListDevices, + request_deserializer=utils__pb2.Empty.FromString, + response_serializer=utils__pb2.Devices.SerializeToString, + ), + 'Init': grpc.unary_unary_rpc_method_handler( + servicer.Init, + request_deserializer=training__pb2.TrainingConfig.FromString, + response_serializer=training__pb2.TrainingSessionId.SerializeToString, + ), + 'Start': 
grpc.unary_unary_rpc_method_handler( + servicer.Start, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=utils__pb2.Empty.SerializeToString, + ), + 'Resume': grpc.unary_unary_rpc_method_handler( + servicer.Resume, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=utils__pb2.Empty.SerializeToString, + ), + 'Pause': grpc.unary_unary_rpc_method_handler( + servicer.Pause, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=utils__pb2.Empty.SerializeToString, + ), + 'StreamUpdates': grpc.unary_stream_rpc_method_handler( + servicer.StreamUpdates, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=training__pb2.StreamUpdateResponse.SerializeToString, + ), + 'GetLogs': grpc.unary_unary_rpc_method_handler( + servicer.GetLogs, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=training__pb2.GetLogsResponse.SerializeToString, + ), + 'Save': grpc.unary_unary_rpc_method_handler( + servicer.Save, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=utils__pb2.Empty.SerializeToString, + ), + 'Export': grpc.unary_unary_rpc_method_handler( + servicer.Export, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=utils__pb2.Empty.SerializeToString, + ), + 'Predict': grpc.unary_unary_rpc_method_handler( + servicer.Predict, + request_deserializer=training__pb2.PredictRequest.FromString, + response_serializer=training__pb2.PredictResponse.SerializeToString, + ), + 'GetStatus': grpc.unary_unary_rpc_method_handler( + servicer.GetStatus, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=training__pb2.GetStatusResponse.SerializeToString, + ), + 'CloseTrainerSession': grpc.unary_unary_rpc_method_handler( + servicer.CloseTrainerSession, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=utils__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'training.Training', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. 
+class Training(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def ListDevices(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/ListDevices', + utils__pb2.Empty.SerializeToString, + utils__pb2.Devices.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Init(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Init', + training__pb2.TrainingConfig.SerializeToString, + training__pb2.TrainingSessionId.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Start(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Start', + training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Resume(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Resume', + training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Pause(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Pause', + training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def StreamUpdates(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/training.Training/StreamUpdates', + training__pb2.TrainingSessionId.SerializeToString, + training__pb2.StreamUpdateResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetLogs(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/GetLogs', + training__pb2.TrainingSessionId.SerializeToString, + training__pb2.GetLogsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, 
wait_for_ready, timeout, metadata) + + @staticmethod + def Save(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Save', + training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Export(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Export', + training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Predict(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Predict', + training__pb2.PredictRequest.SerializeToString, + training__pb2.PredictResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/GetStatus', + training__pb2.TrainingSessionId.SerializeToString, + training__pb2.GetStatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CloseTrainerSession(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/CloseTrainerSession', + training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/tiktorch/proto/utils_pb2.py b/tiktorch/proto/utils_pb2.py new file mode 100644 index 00000000..c0709d97 --- /dev/null +++ b/tiktorch/proto/utils_pb2.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: utils.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0butils.proto\"\x07\n\x05\x45mpty\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"S\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x10\n\x08tensorId\x18\x03 \x01(\t\x12\x18\n\x05shape\x18\x04 \x03(\x0b\x32\t.NamedInt\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Deviceb\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'utils_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _EMPTY._serialized_start=15 + _EMPTY._serialized_end=22 + _NAMEDINT._serialized_start=24 + _NAMEDINT._serialized_end=62 + _NAMEDFLOAT._serialized_start=64 + _NAMEDFLOAT._serialized_end=104 + _TENSOR._serialized_start=106 + _TENSOR._serialized_end=189 + _DEVICE._serialized_start=191 + _DEVICE._serialized_end=280 + _DEVICE_STATUS._serialized_start=245 + _DEVICE_STATUS._serialized_end=280 + _DEVICES._serialized_start=282 + _DEVICES._serialized_end=317 +# @@protoc_insertion_point(module_scope) diff --git a/tiktorch/proto/utils_pb2_grpc.py b/tiktorch/proto/utils_pb2_grpc.py new file mode 100644 index 00000000..2daafffe --- /dev/null +++ b/tiktorch/proto/utils_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/tiktorch/rpc/interface.py b/tiktorch/rpc/interface.py index d47543a5..f00f98b1 100644 --- a/tiktorch/rpc/interface.py +++ b/tiktorch/rpc/interface.py @@ -6,6 +6,7 @@ class RPCInterfaceMeta(type): def __new__(mcls, name, bases, namespace, **kwargs): cls = super().__new__(mcls, name, bases, namespace, **kwargs) + exposed = {name for name, value in namespace.items() if getattr(value, "__exposed__", False)} for base in bases: diff --git a/tiktorch/server/grpc/__init__.py b/tiktorch/server/grpc/__init__.py index a2132a51..21aaf93c 100644 --- a/tiktorch/server/grpc/__init__.py +++ b/tiktorch/server/grpc/__init__.py @@ -6,7 +6,7 @@ import grpc -from tiktorch.proto import data_store_pb2_grpc, inference_pb2_grpc +from tiktorch.proto import data_store_pb2_grpc, inference_pb2_grpc, training_pb2_grpc from tiktorch.server.data_store import DataStore from tiktorch.server.device_pool import IDevicePool, TorchDevicePool from tiktorch.server.session_manager import SessionManager @@ -14,6 +14,7 @@ from .data_store_servicer import DataStoreServicer from .flight_control_servicer import FlightControlServicer from .inference_servicer import InferenceServicer +from .training_servicer import TrainingServicer def _print_available_devices(device_pool: IDevicePool) -> None: @@ -51,16 +52,20 @@ def serve(host, port, *, connection_file_path: Optional[str] = None, kill_timeou ) data_store = DataStore() - device_pool = TorchDevicePool() + inference_svc = InferenceServicer(device_pool, SessionManager(), data_store) fligh_svc = FlightControlServicer(done_evt=done_evt, kill_timeout=kill_timeout) data_svc = DataStoreServicer(data_store) + + training_svc = TrainingServicer(device_pool=device_pool, session_manager=SessionManager()) + _print_available_devices(device_pool) inference_pb2_grpc.add_InferenceServicer_to_server(inference_svc, server) inference_pb2_grpc.add_FlightControlServicer_to_server(fligh_svc, server) data_store_pb2_grpc.add_DataStoreServicer_to_server(data_svc, server) + training_pb2_grpc.add_TrainingServicer_to_server(training_svc, server) acquired_port = server.add_insecure_port(f"{host}:{port}") print() diff --git a/tiktorch/server/grpc/flight_control_servicer.py b/tiktorch/server/grpc/flight_control_servicer.py index b1c41cd0..4ac2866e 100644 --- a/tiktorch/server/grpc/flight_control_servicer.py +++ b/tiktorch/server/grpc/flight_control_servicer.py @@ -3,7 +3,7 @@ import time from typing import Optional -from tiktorch.proto import inference_pb2, inference_pb2_grpc +from tiktorch.proto import inference_pb2_grpc, utils_pb2 logger = logging.getLogger(__name__) @@ -43,11 +43,11 @@ def _run_watchdog(): watchdog_thread.start() return watchdog_thread - def Ping(self, request: inference_pb2.Empty, context) -> inference_pb2.Empty: + def Ping(self, request: utils_pb2.Empty, context) -> utils_pb2.Empty: self.__last_ping = time.time() - return inference_pb2.Empty() + return utils_pb2.Empty() - def Shutdown(self, request: inference_pb2.Empty, context) -> inference_pb2.Empty: + def Shutdown(self, request: utils_pb2.Empty, context) -> utils_pb2.Empty: if self.__done_evt: self.__done_evt.set() - return inference_pb2.Empty() + return utils_pb2.Empty() diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index 230674ee..4318fa50 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -3,10 +3,11 @@ import grpc from 
tiktorch.converters import pb_tensors_to_sample, sample_to_pb_tensors -from tiktorch.proto import inference_pb2, inference_pb2_grpc +from tiktorch.proto import inference_pb2, inference_pb2_grpc, utils_pb2 from tiktorch.rpc.mp import BioModelClient from tiktorch.server.data_store import IDataStore -from tiktorch.server.device_pool import DeviceStatus, IDevicePool +from tiktorch.server.device_pool import IDevicePool +from tiktorch.server.grpc.utils_servicer import list_devices from tiktorch.server.session.process import InputSampleValidator, start_model_session_process from tiktorch.server.session_manager import Session, SessionManager @@ -55,9 +56,9 @@ def CreateDatasetDescription( id = session.client.api.create_dataset_description(mean=request.mean, stddev=request.stddev) return inference_pb2.DatasetDescription(id=id) - def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> inference_pb2.Empty: + def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> utils_pb2.Empty: self.__session_manager.close_session(request.id) - return inference_pb2.Empty() + return utils_pb2.Empty() def close_all_sessions(self): """ @@ -68,25 +69,13 @@ def close_all_sessions(self): self.__session_manager.close_all_sessions() assert len(self.__device_pool.list_reserved_devices()) == 0 - def GetLogs(self, request: inference_pb2.Empty, context): + def GetLogs(self, request: utils_pb2.Empty, context): yield inference_pb2.LogEntry( timestamp=int(time.time()), level=inference_pb2.LogEntry.Level.INFO, content="Sending model logs" ) - def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.Devices: - devices = self.__device_pool.list_devices() - pb_devices = [] - for dev in devices: - if dev.status == DeviceStatus.AVAILABLE: - pb_status = inference_pb2.Device.Status.AVAILABLE - elif dev.status == DeviceStatus.IN_USE: - pb_status = inference_pb2.Device.Status.IN_USE - else: - raise ValueError(f"Unknown status value {dev.status}") - - pb_devices.append(inference_pb2.Device(id=dev.id, status=pb_status)) - - return inference_pb2.Devices(devices=pb_devices) + def ListDevices(self, request: utils_pb2.Empty, context) -> utils_pb2.Devices: + return list_devices(self.__device_pool) def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse: session = self._getModelSession(context, request.modelSessionId) diff --git a/tiktorch/server/grpc/training_servicer.py b/tiktorch/server/grpc/training_servicer.py new file mode 100644 index 00000000..9488a592 --- /dev/null +++ b/tiktorch/server/grpc/training_servicer.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import logging +import queue +from pathlib import Path +from typing import Callable, List + +import grpc + +from tiktorch.converters import trainer_state_to_pb +from tiktorch.proto import training_pb2, training_pb2_grpc, utils_pb2 +from tiktorch.server.device_pool import IDevicePool +from tiktorch.server.grpc.utils_servicer import list_devices +from tiktorch.server.session.process import start_trainer_process +from tiktorch.server.session.rpc_interface import IRPCTrainer +from tiktorch.server.session_manager import Session, SessionManager +from tiktorch.trainer import TrainerYamlParser + +logger = logging.getLogger(__name__) + + +class TrainingServicer(training_pb2_grpc.TrainingServicer): + def __init__( + self, + device_pool: IDevicePool, + session_manager: SessionManager[IRPCTrainer], + ) -> None: + self._device_pool = device_pool + self._logs_queue_stream = 
queue.Queue()
+        self._should_stop_callbacks: List[Callable] = []
+        self._session_manager = session_manager
+
+    def ListDevices(self, request: utils_pb2.Empty, context) -> utils_pb2.Devices:
+        return list_devices(self._device_pool)
+
+    def Init(self, request: training_pb2.TrainingConfig, context):
+        parser = TrainerYamlParser(request.yaml_content)
+        device = parser.get_device()
+
+        _, client = start_trainer_process()
+        session = self._session_manager.create_session(client)
+        session.on_close(client.shutdown)
+
+        lease = self._device_pool.lease([device])
+        session.on_close(lease.terminate)
+
+        try:
+            client.init(request.yaml_content)
+        except Exception as e:
+            self._session_manager.close_session(session.id)
+            raise e
+
+        return training_pb2.TrainingSessionId(id=session.id)
+
+    def Start(self, request, context):
+        session = self._getTrainerSession(context, request.id)
+        session.client.start_training()
+        return utils_pb2.Empty()
+
+    def Resume(self, request, context):
+        session = self._getTrainerSession(context, request.id)
+        session.client.resume_training()
+        return utils_pb2.Empty()
+
+    def Pause(self, request: training_pb2.TrainingSessionId, context):
+        session = self._getTrainerSession(context, request.id)
+        session.client.pause_training()
+        return utils_pb2.Empty()
+
+    def Save(self, request: training_pb2.TrainingSessionId, context):
+        session = self._getTrainerSession(context, request.id)
+        session.client.save()
+        return utils_pb2.Empty()
+
+    def Export(self, request: training_pb2.TrainingSessionId, context):
+        session = self._getTrainerSession(context, request.id)
+        session.client.export()
+        return utils_pb2.Empty()
+
+    def Predict(self, request: training_pb2.TrainingSessionId, context):
+        raise NotImplementedError
+
+    def StreamUpdates(self, request: training_pb2.TrainingSessionId, context):
+        raise NotImplementedError
+
+    def GetLogs(self, request: training_pb2.TrainingSessionId, context):
+        raise NotImplementedError
+
+    def GetStatus(self, request: training_pb2.TrainingSessionId, context):
+        session = self._getTrainerSession(context, request.id)
+        state = session.client.get_state()
+        return training_pb2.GetStatusResponse(state=trainer_state_to_pb[state])
+
+    def CloseTrainerSession(self, request: training_pb2.TrainingSessionId, context) -> utils_pb2.Empty:
+        self._session_manager.close_session(request.id)
+        return utils_pb2.Empty()
+
+    def close_all_sessions(self):
+        self._session_manager.close_all_sessions()
+
+    def _getTrainerSession(self, context, trainer_session_id: str) -> Session[IRPCTrainer]:
+        session = self._session_manager.get(trainer_session_id)
+
+        if session is None:
+            context.abort(
+                grpc.StatusCode.FAILED_PRECONDITION, f"trainer-session with id {trainer_session_id} doesn't exist"
+            )
+
+        return session
diff --git a/tiktorch/server/grpc/utils_servicer.py b/tiktorch/server/grpc/utils_servicer.py
new file mode 100644
index 00000000..bb23b40c
--- /dev/null
+++ b/tiktorch/server/grpc/utils_servicer.py
@@ -0,0 +1,18 @@
+from tiktorch.proto import utils_pb2
+from tiktorch.server.device_pool import DeviceStatus, IDevicePool
+
+
+def list_devices(device_pool: IDevicePool) -> utils_pb2.Devices:
+    devices = device_pool.list_devices()
+    pb_devices = []
+    for dev in devices:
+        if dev.status == DeviceStatus.AVAILABLE:
+            pb_status = utils_pb2.Device.Status.AVAILABLE
+        elif dev.status == DeviceStatus.IN_USE:
+            pb_status = utils_pb2.Device.Status.IN_USE
+        else:
+            raise ValueError(f"Unknown status value {dev.status}")
+
pb_devices.append(utils_pb2.Device(id=dev.id, status=pb_status)) + + return utils_pb2.Devices(devices=pb_devices) diff --git a/tiktorch/server/session/backend/base.py b/tiktorch/server/session/backend/base.py index eab3ea44..471c2d71 100644 --- a/tiktorch/server/session/backend/base.py +++ b/tiktorch/server/session/backend/base.py @@ -1,60 +1,91 @@ from __future__ import annotations import logging -import threading -import typing +from abc import ABC from concurrent.futures import Future from bioimageio.core import PredictionPipeline from tiktorch.configkeys import TRAINING, VALIDATION -from tiktorch.server.session import types -from tiktorch.server.session.backend import commands, supervisor +from tiktorch.server.session.backend import commands +from tiktorch.server.session.backend.supervisor import BioModelSupervisor, QueueTasks, TrainerState, TrainerSupervisor from tiktorch.tiktypes import TikTensorBatch +from tiktorch.trainer import Trainer logger = logging.getLogger(__name__) -class SessionBackend: +class SessionBackend(ABC): + def __init__(self, supervisor): + self._supervisor = supervisor + self._queue_tasks = QueueTasks(supervisor) + self._queue_tasks.start() + + def shutdown(self): + self._queue_tasks.shutdown() + logger.debug("Shutdown complete") + + +class BioModelSessionBackend(SessionBackend): + """Session backend for bioimageio models + + Currently used only for inference. + """ + def __init__(self, pipeline: PredictionPipeline): - self._supervisor = supervisor.Supervisor(pipeline) - self._supervisor_thread = threading.Thread(target=self._supervisor.run, name="ModelThread") - self._supervisor_thread.start() + supervisor = BioModelSupervisor(pipeline) + super().__init__(supervisor) def update_dataset(self, name: str, *, data: TikTensorBatch, labels: TikTensorBatch) -> None: assert name in (TRAINING, VALIDATION), f"{name} not in ({TRAINING}, {VALIDATION})" update_cmd = commands.UpdateDatasetCmd(name, raw_data=data, labels=labels) - self._supervisor.send_command(update_cmd) + self._queue_tasks.send_command(update_cmd) def set_max_num_iterations(self, num: int) -> None: - self._supervisor.send_command(commands.SetMaxNumIterations(num)) + self._queue_tasks.send_command(commands.SetMaxNumIterations(num)) def forward(self, input_tensors): res = Future() - self._supervisor.send_command(commands.ForwardPass(res, input_tensors)) + self._queue_tasks.send_command(commands.ForwardPass(res, input_tensors)) return res - def shutdown(self) -> None: - logger.debug("Shutting down...") - stop_cmd = commands.StopCmd() - self._supervisor.send_command(stop_cmd.awaitable) - stop_cmd.awaitable.wait() +class TrainerSessionBackend(SessionBackend): + """Session backend for training - self._supervisor_thread.join() + Currently, supports only custom unet models decoupled from bioimageio models + """ - logger.debug("Shutdown complete") + def __init__(self, trainer: Trainer): + self._trainer = trainer + supervisor = TrainerSupervisor(trainer) + super().__init__(supervisor) + + def forward(self, input_tensors): + res = Future() + self._queue_tasks.send_command(commands.ForwardPass(res, input_tensors)) + return res def resume_training(self) -> None: - resume_cmd = commands.ResumeCmd() - self._supervisor.send_command(resume_cmd.awaitable) + resume_cmd = commands.ResumeTrainingCmd() + self._queue_tasks.send_command(resume_cmd.awaitable) resume_cmd.awaitable.wait() def pause_training(self) -> None: - self._supervisor.send_command(commands.PauseCmd()) + pause_cmd = commands.PauseTrainingCmd() + 
self._queue_tasks.send_command(pause_cmd.awaitable) + pause_cmd.awaitable.wait() + + def start_training(self) -> None: + start_cmd = commands.StartTrainingCmd() + self._queue_tasks.send_command(start_cmd.awaitable) + start_cmd.awaitable.wait() + + def save(self) -> None: + raise NotImplementedError - def get_idle(self) -> bool: - return self._supervisor.state == types.State.Paused + def export(self) -> None: + raise NotImplementedError - def on_idle(self, callback: typing.Callable[[], None]) -> None: - self._supervisor.on_idle(callback) + def get_state(self) -> TrainerState: + return self._supervisor.get_state() diff --git a/tiktorch/server/session/backend/commands.py b/tiktorch/server/session/backend/commands.py index 8943b325..f19a45d2 100644 --- a/tiktorch/server/session/backend/commands.py +++ b/tiktorch/server/session/backend/commands.py @@ -6,11 +6,12 @@ import threading import typing from dataclasses import dataclass, field +from typing import Generic, Type, TypeVar -from tiktorch.server.session import types +from tiktorch.trainer import TrainerAction, TrainerState if typing.TYPE_CHECKING: - from tiktorch.server.session.backend.supervisor import Supervisor + from tiktorch.server.session.backend.supervisor import BioModelSupervisor, Supervisors, TrainerSupervisor # from tiktorch.server.datasets import DynamicDataset @@ -20,27 +21,39 @@ __all__ = [ "ICommand", "AwaitableCommand", - "PauseCmd", - "ResumeCmd", - "StopCmd", + "StartTrainingCmd", + "PauseTrainingCmd", + "ResumeTrainingCmd", + "ShutdownWithTeardownCmd", + "SetResumeStateTrainingCmd", + "SetPauseStateTrainingCmd", + "SetStartStateTrainingCmd", "UpdateDatasetCmd", "SetMaxNumIterations", ] +SupervisorType = TypeVar("SupervisorType") -class Context: + +class Context(Generic[SupervisorType]): """ Command execution context Contains modifiable entities as attributes """ - def __init__(self, *, supervisor: Supervisor) -> None: + def __init__(self, *, supervisor: SupervisorType) -> None: self.session = supervisor class ICommand: __awaitable = None + def __init__(self, is_termination_signal: bool = False): + self._is_termination_signal = is_termination_signal + + def is_stop(self): + return self._is_termination_signal + @property def awaitable(self): if not self.__awaitable: @@ -51,18 +64,33 @@ def awaitable(self): def execute(self, ctx: Context) -> None: raise NotImplementedError() + def is_command(self, command_to_check: Type[ICommand]): + """ + Identify the command even if it is wrapped as an awaitable one + """ + if isinstance(self, AwaitableCommand): + return isinstance(self._cmd, command_to_check) + else: + return isinstance(self, command_to_check) + class AwaitableCommand(ICommand): def __init__(self, cmd: ICommand): self._cmd = cmd self._done_evt = threading.Event() + self._exception: Exception | None = None # Store the exception + super().__init__(is_termination_signal=self._cmd.is_stop()) def wait(self): self._done_evt.wait() + if self._exception is not None: + raise self._exception def execute(self, ctx: Context) -> None: try: self._cmd.execute(ctx) + except Exception as e: + self._exception = e finally: self._done_evt.set() @@ -70,55 +98,118 @@ def __repr__(self): return f"Awaitable {self._cmd!r}" -class PauseCmd(ICommand): - def execute(self, ctx: Context) -> None: - ctx.session.transition_to(types.State.Paused) +class PauseTrainingCmd(ICommand): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.pause() -class ResumeCmd(ICommand): - def execute(self, ctx: Context) -> None: - 
ctx.session.transition_to(types.State.Running) +class ResumeTrainingCmd(ICommand): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.resume() + + +class SetStartStateTrainingCmd(ICommand): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.transition_to_state(new_state=TrainerState.RUNNING, trainer_action=TrainerAction.START) -class StopCmd(ICommand): +class SetPauseStateTrainingCmd(ICommand): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.transition_to_state(new_state=TrainerState.PAUSED, trainer_action=TrainerAction.PAUSE) + + +class SetResumeStateTrainingCmd(ICommand): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.transition_to_state(new_state=TrainerState.RUNNING, trainer_action=TrainerAction.RESUME) + + +class ShutdownCmd(ICommand): + def __init__(self): + super().__init__(is_termination_signal=True) + def execute(self, ctx: Context) -> None: - ctx.session.transition_to(types.State.Stopped) + pass + + +class ShutdownWithTeardownCmd(ShutdownCmd): + def execute(self, ctx: Context[Supervisors]) -> None: + ctx.session.shutdown() + + +class StartTrainingCmd(ICommand): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.start() class UpdateDatasetCmd(ICommand): def __init__(self, name, *, raw_data, labels): + super().__init__() self._name = name self._raw_data = raw_data self._labels = labels - def execute(self, ctx: Context) -> None: + def execute(self, ctx: Context[BioModelSupervisor]) -> None: logger.warning("Not Implemented") + ctx.session.update_dataset() # dataset = ctx.exemplum.get_dataset(self._name) # dataset.update(self._raw_data, self._labels) class SetMaxNumIterations(ICommand): def __init__(self, num_iterations: int) -> None: + super().__init__() self._num_iterations = num_iterations - def execute(self, ctx: Context) -> None: + def execute(self, ctx: Context[BioModelSupervisor]) -> None: ctx.session.set_max_num_iterations(self._num_iterations) class ForwardPass(ICommand): def __init__(self, future, input_tensors): + super().__init__() self._input_tensors = input_tensors self._future = future - def execute(self, ctx: Context) -> None: + def execute(self, ctx: Context[Supervisors]) -> None: try: self._future.set_result(ctx.session.forward(self._input_tensors)) except Exception as e: self._future.set_exception(e) +class CommandPriorityQueueUtils: + """ + Utility for managing and processing commands in a priority queue. 
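+
+    Stop commands (``ICommand.is_stop()`` is True) are dequeued at priority 0,
+    ahead of all other commands, which run in FIFO order at the default
+    priority of 999 (see ``CommandPriorityQueue._make_queue_item`` below).
+
+    A minimal usage sketch; ``supervisor`` stands for any supervisor object
+    accepted by ``Context`` and is a placeholder, not defined here:
+
+        utils = CommandPriorityQueueUtils()
+        utils.send_command(SetMaxNumIterations(10))
+        utils.send_command(ShutdownCmd())    # priority 0: dequeued first
+        utils.process_commands(supervisor)   # executes ShutdownCmd, returns True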
+ """ + + def __init__(self) -> None: + self.queue = CommandPriorityQueue() + + def send_command(self, cmd: ICommand) -> None: + if not isinstance(cmd, ICommand): + raise ValueError(f"Expected instance of ICommand got {cmd}") + + logger.debug("Sending command %s", cmd) + self.queue.put(cmd) + + def process_commands(self, session): + cmd: ICommand = self.queue.get() + ctx = Context(supervisor=session) + logger.debug("Executing %s", cmd) + + try: + cmd.execute(ctx) + except Exception as e: + logger.exception(f"Failed to execute %s with exception {e}", cmd) + finally: + self.queue.task_done() + logger.debug(f"Finished executing {cmd}") + + return cmd.is_stop() + + class CommandPriorityQueue(queue.PriorityQueue): - COMMAND_PRIORITIES = {StopCmd: 0} + COMMAND_PRIORITIES = {ShutdownWithTeardownCmd: 0} @dataclass(order=True) class _PrioritizedItem: @@ -129,7 +220,10 @@ class _PrioritizedItem: @classmethod def _make_queue_item(cls, cmd: ICommand): - priority = cls.COMMAND_PRIORITIES.get(type(cmd), 999) + if cmd.is_stop(): + priority = 0 + else: + priority = cls.COMMAND_PRIORITIES.get(type(cmd), 999) return cls._PrioritizedItem((priority, next(cls.__counter)), cmd) def put(self, cmd: ICommand, block=True, timeout=None) -> None: diff --git a/tiktorch/server/session/backend/supervisor.py b/tiktorch/server/session/backend/supervisor.py index a5594e08..fa04e599 100644 --- a/tiktorch/server/session/backend/supervisor.py +++ b/tiktorch/server/session/backend/supervisor.py @@ -1,137 +1,239 @@ -from __future__ import annotations - import logging -import queue +import threading +from typing import Generic, Set, TypeVar, Union from bioimageio.core import PredictionPipeline, Sample -from tiktorch.server.session import types from tiktorch.server.session.backend import commands +from tiktorch.server.session.backend.commands import CommandPriorityQueueUtils, ShutdownWithTeardownCmd +from tiktorch.trainer import BaseCallbacks, ErrorCallbacks, Trainer, TrainerAction, TrainerState logger = logging.getLogger(__name__) -class Supervisor: - def __init__(self, pipeline: PredictionPipeline) -> None: - self._state = types.State.Stopped +def requires_queue_alive(func): + def wrapper(self, *args, **kwargs): + if not self._session_thread.is_alive(): + raise RuntimeError("Training hasn't started") + func(self, *args, **kwargs) - self._command_queue = commands.CommandPriorityQueue() - self._pipeline = pipeline - # self._pipeline.set_break_callback(self.has_commands) - self._idle_callbacks = [] + return wrapper - def send_command(self, cmd: commands.ICommand) -> None: - if not isinstance(cmd, commands.ICommand): - raise ValueError(f"Expected instance of ICommand got {cmd}") - logger.debug("Sending command %s", cmd) - self._command_queue.put(cmd) +class StateTransitionError(Exception): + def __init__(self, current_state: TrainerState, transitioning_state: TrainerState, valid_states: Set[TrainerState]): + super().__init__( + f"Invalid state transition: {current_state} -> {transitioning_state}. Valids are {valid_states}" + ) + self.current_state = current_state + self.transitioning_state = transitioning_state + self.valid_states = valid_states + + def __reduce__(self): + return ( + self.__class__, + (self.current_state, self.transitioning_state, self.valid_states), + ) + - @property - def state(self): +class TrainerSupervisor: + """Training supervisor for custom models supported by the 'Trainer' interface. + + Monitoring the training thread and its status. 
+ """ + + def __init__(self, trainer: Trainer) -> None: + super().__init__() + self._trainer = trainer + self._trainer.should_stop_callbacks.register(self._should_stop) + self._state = TrainerState.IDLE + self._pause_triggered = False + self._session_thread = threading.Thread(target=self._start_session, name="SessionThread") + self._command_queue_utils = CommandPriorityQueueUtils() + self.training_error_callbacks: ErrorCallbacks = BaseCallbacks() + + def get_state(self) -> TrainerState: + logger.debug(f"Get state called {self._state}") return self._state - def has_commands(self): - return not self._command_queue.empty() + def start(self): + self._check_transition_to_start() + self._session_thread.start() + self._pause_triggered = False + start_cmd = commands.SetStartStateTrainingCmd() + self._command_queue_utils.send_command(start_cmd.awaitable) + start_cmd.awaitable.wait() + + def _start_session(self): + logger.info("Starting session worker") + try: + while True: + if self._command_queue_utils.process_commands(self): + break + + if self._state == TrainerState.RUNNING: + self._fit() + except Exception as e: + logger.exception(f"Uncaught exception in session worker. Exception: {e}") + finally: + logger.info("Stopped session worker") + + def _fit(self): + try: + self._trainer.fit() + except Exception as e: + logger.exception(f"Training error: {e}") + self.training_error_callbacks(e) + self._state = TrainerState.FAILED + return + + if self.is_training_finished(): + logger.info(f"Training has finished: {self._get_num_iterations_epochs()} ") + self._state = TrainerState.FINISHED + + def is_training_finished(self): + return ( + self._trainer.num_epochs == self._trainer.max_num_epochs + or self._trainer.num_iterations == self._trainer.max_num_iterations + ) or self._trainer.should_stop_model_criteria() + + def _get_num_iterations_epochs(self) -> str: + iterations = f"Iterations[{self._trainer.num_iterations}/{self._trainer.max_num_iterations}]" + epochs = f"Epochs[{self._trainer.num_epochs}/{self._trainer.max_num_epochs}]" + return f"{iterations}, {epochs}" + + def resume(self): + self._check_transition_to_resume() + self._pause_triggered = False + resume_cmd = commands.SetResumeStateTrainingCmd() + self._command_queue_utils.send_command(resume_cmd.awaitable) + resume_cmd.awaitable.wait() # make sure that the state has actually changed (acknowledge) + logger.info(f"Resume training: {self._get_num_iterations_epochs()}") + + def pause(self): + self._check_transition_to_pause() + self._pause_triggered = True + pause_cmd = commands.SetPauseStateTrainingCmd() + self._command_queue_utils.send_command(pause_cmd.awaitable) + pause_cmd.awaitable.wait() # make sure that the state has actually changed (acknowledge) + + def shutdown(self): + if not self._session_thread.is_alive(): + # nothing to do if session thread not alive + return + self._pause_triggered = True + self._command_queue_utils.send_command(commands.ShutdownCmd()) + self._session_thread.join() + + def forward(self, input_tensors): + self.pause() + self._trainer.forward(input_tensors) + self.resume() + + def save(self): + raise NotImplementedError + + def export(self): + raise NotImplementedError + + def _should_stop(self): + return self._pause_triggered + + def transition_to_state(self, new_state: TrainerState, trainer_action: TrainerAction): + """ + Should be used via the ICommands to monitor the state of the training + """ + if trainer_action == TrainerAction.START: + self._check_transition_to_start() + elif trainer_action == 
+            self._check_transition_to_pause()
+        elif trainer_action == TrainerAction.RESUME:
+            self._check_transition_to_resume()
+        logger.info(f"State transition: {self._state} -> {new_state}")
+        self._state = new_state
+
+    def _check_transition_to_start(self):
+        return self._check_transition_to_state(TrainerState.RUNNING, {TrainerState.IDLE})
+
+    def _check_transition_to_pause(self):
+        return self._check_transition_to_state(TrainerState.PAUSED, {TrainerState.RUNNING})
+
+    def _check_transition_to_resume(self):
+        return self._check_transition_to_state(TrainerState.RUNNING, {TrainerState.PAUSED})
+
+    def _check_transition_to_state(self, new_state: TrainerState, valid_states: Set[TrainerState]):
+        if self._state not in valid_states:
+            raise StateTransitionError(
+                current_state=self._state, transitioning_state=new_state, valid_states=valid_states
+            )
+
-    def has_work(self):
-        return self._pipeline.max_num_iterations and self._pipeline.max_num_iterations > self._pipeline.iteration_count
+class BioModelSupervisor:
+    """Supervisor for bioimageio models
+
+    Currently used only for inference.
+
+    Serializes and offloads commands issued concurrently from multiple threads.
+    """
+
+    def __init__(self, pipeline: PredictionPipeline) -> None:
+        super().__init__()
+        self._pipeline = pipeline
 
     def forward(self, sample: Sample):
         results = self._pipeline.predict_sample_without_blocking(sample)
         return results
 
-    def transition_to(self, new_state: types.State) -> None:
-        logger.debug("Attempting transition to state %s", new_state)
-        self._state = new_state
-        self._update_state()
-
-    def set_max_num_iterations(self, num: int):
-        self._pipeline.set_max_num_iterations(num)
-        self._update_state()
-
-    def on_idle(self, callback):
-        self._idle_callbacks.append(callback)
-        self._notify_idle()
-
-    def _notify_idle(self):
-        if self._state in (types.State.Idle, types.State.Paused):
-            idle_cbs = self._idle_callbacks
-            self._idle_callbacks = []
-            for cb in idle_cbs:
-                try:
-                    cb()
-                except Exception:
-                    logger.exception("Exception during idle callback")
-
-    def run(self):
+    def set_max_num_iterations(self, num_iterations: int):
+        raise NotImplementedError
+
+    def update_dataset(self):
+        raise NotImplementedError
+
+    def shutdown(self):
+        pass
+
+
+Supervisors = Union[BioModelSupervisor, TrainerSupervisor]
+SupervisorTypeVar = TypeVar("SupervisorTypeVar", bound=Supervisors)
+
+
+class QueueTasks(Generic[SupervisorTypeVar]):
+    """
+    A task queue manager for processing commands with a supervisor.
+
+    Serializes multiple async requests wrapped as commands.
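+
+    A minimal lifecycle sketch, assuming ``pipeline`` is a loaded bioimageio
+    PredictionPipeline and ``sample`` a prepared Sample (both placeholders,
+    not defined here):
+
+        tasks = QueueTasks(BioModelSupervisor(pipeline))
+        tasks.start()            # spawns the worker thread
+        fut = Future()           # concurrent.futures.Future
+        tasks.send_command(commands.ForwardPass(fut, sample))
+        result = fut.result()    # blocks until the command has run
+        tasks.shutdown()         # enqueues ShutdownWithTeardownCmd and joins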
+ """ + + def __init__(self, supervisor: SupervisorTypeVar) -> None: + self._command_queue = CommandPriorityQueueUtils() + self._supervisor = supervisor + self._thread = threading.Thread(target=self._run, name="QueueTasksWorker") + + def start(self): + self._thread.start() + + def _run(self): logger.info("Starting session worker") try: - self._run() - except Exception: - logger.exception("Uncaught exception in session worker") + while True: + if self._command_queue.process_commands(self._supervisor): + break + except Exception as e: + logger.exception(f"Uncaught exception in session worker {e}") finally: logger.info("Stopped session worker") - def _run(self): - self._set_state(types.State.Paused) - - while True: - self._process_commands() - - if self.state == types.State.Stopped: - break - - elif self._state == types.State.Idle or self._state == types.State.Paused: - with self._command_queue.not_empty: - self._command_queue.not_empty.wait() - - elif self._state == types.State.Running: - self._train() - self._update_state() - - def _process_commands(self): - while not self._command_queue.empty(): - try: - cmd = self._command_queue.get_nowait() - logger.debug("Executing %s", cmd) - ctx = commands.Context(supervisor=self) - - try: - cmd.execute(ctx) - except Exception: - logger.exception("Failed to execute %s", cmd) - finally: - self._command_queue.task_done() - - except queue.Empty: - pass - - def _train(self): - logger.info( - "Start session for %d iterations", self._pipeline.max_num_iterations - self._pipeline.iteration_count - ) - try: - self._pipeline.train() - except Exception: - logger.error("Exception during session training. Pausing...", exc_info=True) - # FIXME: Should we use PauseCmd here? Maybe we should only know about ICommand on this level. 
diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py
index f289697e..3a1b949e 100644
--- a/tiktorch/server/session/process.py
+++ b/tiktorch/server/session/process.py
@@ -1,3 +1,4 @@
+import logging
 import multiprocessing as _mp
 import pathlib
 import tempfile
@@ -17,8 +18,11 @@
 from tiktorch.rpc.mp import BioModelClient, MPServer
 
 from ...converters import Sample
+from ...trainer import TrainerYamlParser
 from .backend import base
-from .rpc_interface import IRPCModelSession
+from .rpc_interface import IRPCModelSession, IRPCTrainer
+
+logger = logging.getLogger(__name__)
 
 
 class InputSampleValidator:
@@ -77,11 +81,11 @@ class ModelSessionProcess(IRPCModelSession):
     def __init__(self) -> None:
         super().__init__()
         self._datasets = {}
-        self._worker: Optional[base.SessionBackend] = None
+        self._worker: Optional[base.BioModelSessionBackend] = None
 
     def init(self, model_bytes: bytes, devices: List[str]):
         prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_bytes, devices)
-        self._worker = base.SessionBackend(prediction_pipeline)
+        self._worker = base.BioModelSessionBackend(prediction_pipeline)
 
     def forward(self, sample: Sample) -> Future:
         res = self.worker.forward(sample)
@@ -99,11 +103,56 @@ def shutdown(self) -> Shutdown:
         return Shutdown()
 
     @property
-    def worker(self) -> base.SessionBackend:
+    def worker(self) -> base.BioModelSessionBackend:
+        if self._worker is None:
+            raise ValueError("Server isn't initialized")
+        return self._worker
+
+
+class TrainerSessionProcess(IRPCTrainer):
+    def __init__(self):
+        self._worker: Optional[base.TrainerSessionBackend] = None
+
+    @property
+    def worker(self) -> base.TrainerSessionBackend:
         if self._worker is None:
             raise ValueError("Server isn't initialized")
         return self._worker
 
+    def init(self, trainer_yaml_config: str):
+        parser = TrainerYamlParser(trainer_yaml_config)
+        logger.debug(f"Config file {trainer_yaml_config}")
+        trainer = parser.parse()
+        self._worker = base.TrainerSessionBackend(trainer)
+
+    def forward(self, input_tensors) -> Future:
+        res = self.worker.forward(input_tensors)
+        return res
+
+    def resume_training(self):
+        self.worker.resume_training()
+
+    def start_training(self):
+        self.worker.start_training()
+
+    def pause_training(self):
+        self.worker.pause_training()
+
+    def save(self):
+        self.worker.save()
+
+    def export(self):
+        self.worker.export()
+
+    def get_state(self):
+        return self.worker.get_state()
+
+    def shutdown(self):
+        if self._worker is None:
+            return Shutdown()
+        self.worker.shutdown()
+        return Shutdown()
+
 
 def _run_server(api: RPCInterface, conn: Connection, log_queue: Optional[_mp.Queue] = None):
     try:
@@ -125,6 +174,10 @@ def _run_server(api: RPCInterface, conn: Connection, log_queue: Optional[_mp.Que
 T = TypeVar("T", bound=RPCInterface)
 
 
+def start_trainer_process(log_queue: Optional[_mp.Queue] = None) -> Tuple[_mp.Process, TrainerSessionProcess]:
+    return start_process(interface_class=TrainerSessionProcess, log_queue=log_queue)
+
+
 def start_process(interface_class: Type[T], log_queue: Optional[_mp.Queue] = None) -> Tuple[_mp.Process, T]:
     client_conn, server_conn = _mp.Pipe()
     proc = _mp.Process(
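
Taken together, `TrainerSessionProcess` and `start_trainer_process` suggest the following driver code. This is a sketch rather than an example from the changeset: the YAML content is elided and must be a valid pytorch-3dunet config, and error handling is omitted.

```python
from tiktorch.server.session.process import start_trainer_process

yaml_config = "..."  # a valid pytorch-3dunet trainer config; contents elided

proc, trainer = start_trainer_process(log_queue=None)
trainer.init(yaml_config)   # parse the YAML and build a TrainerSessionBackend
trainer.start_training()    # IDLE -> RUNNING
trainer.pause_training()    # RUNNING -> PAUSED
trainer.resume_training()   # PAUSED -> RUNNING
print(trainer.get_state())  # expected: TrainerState.RUNNING
trainer.shutdown()
proc.join()
```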
diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py
index 4efface8..db714cb5 100644
--- a/tiktorch/server/session/rpc_interface.py
+++ b/tiktorch/server/session/rpc_interface.py
@@ -2,7 +2,9 @@
 from tiktorch.converters import Sample
 from tiktorch.rpc import RPCInterface, exposed
+from tiktorch.rpc.exceptions import Shutdown
 from tiktorch.tiktypes import TikTensorBatch
+from tiktorch.trainer import TrainerState
 from tiktorch.types import ModelState
@@ -46,3 +48,41 @@ def create_dataset_description(self, mean, stddev) -> str:
     @exposed
     def forward(self, input_tensors: Sample):
         raise NotImplementedError
+
+
+class IRPCTrainer(RPCInterface):
+    @exposed
+    def init(self, trainer_yaml_config: str):
+        raise NotImplementedError
+
+    @exposed
+    def forward(self, input_tensors: Sample):
+        raise NotImplementedError
+
+    @exposed
+    def resume_training(self) -> None:
+        raise NotImplementedError
+
+    @exposed
+    def pause_training(self) -> None:
+        raise NotImplementedError
+
+    @exposed
+    def start_training(self) -> None:
+        raise NotImplementedError
+
+    @exposed
+    def shutdown(self) -> Shutdown:
+        raise NotImplementedError
+
+    @exposed
+    def save(self):
+        raise NotImplementedError
+
+    @exposed
+    def export(self):
+        raise NotImplementedError
+
+    @exposed
+    def get_state(self) -> TrainerState:
+        raise NotImplementedError
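
`IRPCTrainer` is the exposed contract that `TrainerSessionProcess` (above) implements, so a new trainer operation touches both layers plus the backend worker. A sketch with a hypothetical `get_best_model_idx` method, which does not exist in this changeset:

```python
from tiktorch.rpc import RPCInterface, exposed


class IRPCTrainer(RPCInterface):
    # ...existing exposed methods elided...

    @exposed
    def get_best_model_idx(self) -> int:  # hypothetical method, for illustration
        raise NotImplementedError


class TrainerSessionProcess(IRPCTrainer):
    def get_best_model_idx(self) -> int:
        # Delegate to the backend worker, as the other methods do.
        return self.worker.get_best_model_idx()
```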
diff --git a/tiktorch/trainer.py b/tiktorch/trainer.py
new file mode 100644
index 00000000..79b371f9
--- /dev/null
+++ b/tiktorch/trainer.py
@@ -0,0 +1,241 @@
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, Generic, List, TypeVar
+
+import torch
+import yaml
+from pytorch3dunet.datasets.utils import get_train_loaders
+from pytorch3dunet.unet3d.losses import get_loss_criterion
+from pytorch3dunet.unet3d.metrics import get_evaluation_metric
+from pytorch3dunet.unet3d.model import get_model
+from pytorch3dunet.unet3d.trainer import UNetTrainer
+from pytorch3dunet.unet3d.utils import create_lr_scheduler, create_optimizer, get_tensorboard_formatter
+from torch import nn
+
+T = TypeVar("T", bound=Callable)
+
+logger = logging.getLogger(__name__)
+
+
+class Callbacks(ABC, Generic[T]):
+    def __init__(self):
+        self._callbacks: List[T] = []
+
+    def register(self, callback: T):
+        self._callbacks.append(callback)
+
+    def unregister(self, callback: T):
+        self._callbacks.remove(callback)
+
+    @abstractmethod
+    def __call__(self, *args, **kwargs) -> Any:
+        pass
+
+
+class BaseCallbacks(Callbacks[T]):
+    def __call__(self, *args, **kwargs):
+        for callback in self._callbacks:
+            callback(*args, **kwargs)
+
+
+class ShouldStopCallbacks(Callbacks[Callable[[], bool]]):
+    def __call__(self, *args, **kwargs) -> bool:
+        for callback in self._callbacks:
+            if callback():
+                return True
+        return False
+
+
+ErrorCallbacks = BaseCallbacks[Callable[[Exception], None]]
+
+
+class ModelPhase(Enum):
+    Train = "train"
+    Eval = "val"
+
+
+@dataclass(frozen=True)
+class Logs:
+    mode: ModelPhase
+    loss: float
+    eval_score: float
+    iteration: int
+    epoch: int
+    max_epochs: int
+    max_iterations: int
+
+    def __str__(self):
+        iterations = f"Iteration[{self.iteration}/{self.max_iterations}]"
+        epochs = f"Epochs[{self.epoch}/{self.max_epochs}]"
+        return f"{epochs}, {iterations}: mode={self.mode}, loss={self.loss}, eval_score={self.eval_score}"
+
+
+LogsCallbacks = Callbacks[Callable[[Logs], None]]
+
+
+class TrainerAction(Enum):
+    START = "start"
+    PAUSE = "pause"
+    RESUME = "resume"
+    SHUTDOWN = "shutdown"
+
+
+class TrainerState(Enum):
+    IDLE = 0
+    RUNNING = 1
+    PAUSED = 2
+    FAILED = 3
+    FINISHED = 4
+
+
+class Trainer(UNetTrainer):
+    def __init__(
+        self,
+        model,
+        optimizer,
+        lr_scheduler,
+        loss_criterion,
+        eval_criterion,
+        loaders,
+        checkpoint_dir,
+        max_num_epochs,
+        max_num_iterations,
+        validate_after_iters=200,
+        log_after_iters=100,
+        validate_iters=None,
+        num_iterations=1,
+        num_epoch=0,
+        eval_score_higher_is_better=True,
+        tensorboard_formatter=None,
+        skip_train_validation=False,
+        resume=None,
+        pre_trained=None,
+        **kwargs,
+    ):
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            loss_criterion=loss_criterion,
+            eval_criterion=eval_criterion,
+            loaders=loaders,
+            checkpoint_dir=checkpoint_dir,
+            max_num_epochs=max_num_epochs,
+            max_num_iterations=max_num_iterations,
+            validate_after_iters=validate_after_iters,
+            log_after_iters=log_after_iters,
+            validate_iters=validate_iters,
+            num_iterations=num_iterations,
+            num_epoch=num_epoch,
+            eval_score_higher_is_better=eval_score_higher_is_better,
+            tensorboard_formatter=tensorboard_formatter,
+            skip_train_validation=skip_train_validation,
+            resume=resume,
+            pre_trained=pre_trained,
+            **kwargs,
+        )
+        self.logs_callbacks: LogsCallbacks = BaseCallbacks()
+        self.should_stop_callbacks: Callbacks = ShouldStopCallbacks()
+
+    def fit(self):
+        return super().fit()
+
+    def train(self):
+        return super().train()
+
+    def validate(self):
+        return super().validate()
+
+    def forward(self, input_tensors):
+        self.model.eval()
+        with torch.no_grad():
+            return self.model(input_tensors)
+
+    def should_stop(self) -> bool:
+        """
+        Decide whether to stop training: registered callbacks are checked first, then the model's own criteria.
+        """
+        return self.should_stop_callbacks() or self.should_stop_model_criteria()
+
+    def should_stop_model_criteria(self) -> bool:
+        """
+        Preserve the stopping logic built into the underlying trainer,
+        e.g. the learning rate dropping below a threshold.
+        """
+        return super().should_stop()
+
+    def _log_stats(self, phase, loss_avg, eval_score_avg):
+        logs = Logs(
+            mode=ModelPhase(phase),
+            loss=loss_avg,
+            eval_score=eval_score_avg,
+            iteration=self.num_iterations,
+            epoch=self.num_epochs,
+            max_epochs=self.max_num_epochs,
+            max_iterations=self.max_num_iterations,
+        )
+        self.logs_callbacks(logs)
+        # TODO: figure out why the trainer's internal logging does not reach stdout even though it is configured
+        logger.info(str(logs))
+        return super()._log_stats(phase, loss_avg, eval_score_avg)
+
+
+class TrainerYamlParser:
+    def __init__(self, yaml_string: str):
+        self._yaml_string = yaml_string
+        self._yaml_config = yaml.safe_load(self._yaml_string)
+
+    def get_device(self):
+        return self._yaml_config["device"]
+
+    def parse(self) -> Trainer:
+        """
+        Adapted from pytorch-3dunet.
+        """
+
+        config = self._yaml_config
+
+        model = get_model(config["model"])
+
+        if torch.cuda.device_count() > 1 and config["device"] != "cpu":
+            model = nn.DataParallel(model)
+        if torch.cuda.is_available() and config["device"] != "cpu":
+            model = model.cuda()
+
+        # Create loss criterion
+        loss_criterion = get_loss_criterion(config)
+        # Create evaluation metric
+        eval_criterion = get_evaluation_metric(config)
+
+        # Create data loaders
+        loaders = get_train_loaders(config)
+
+        # Create the optimizer
+        optimizer = create_optimizer(config["optimizer"], model)
+
+        # Create learning rate adjustment strategy
+        lr_scheduler = create_lr_scheduler(config.get("lr_scheduler", None), optimizer)
+
+        trainer_config = config["trainer"]
+        # Create tensorboard formatter
+        tensorboard_formatter = get_tensorboard_formatter(trainer_config.pop("tensorboard_formatter", None))
+        # Create trainer
+        resume = trainer_config.pop("resume", None)
+        pre_trained = trainer_config.pop("pre_trained", None)
+
+        return Trainer(
+            model=model,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            loss_criterion=loss_criterion,
+            eval_criterion=eval_criterion,
+            loaders=loaders,
+            tensorboard_formatter=tensorboard_formatter,
+            resume=resume,
+            pre_trained=pre_trained,
+            **trainer_config,
+        )