Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add training service #225

Merged
merged 13 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/grpc_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import grpc

from tiktorch.proto import inference_pb2, inference_pb2_grpc
from tiktorch.proto import inference_pb2_grpc, utils_pb2


def run():
with grpc.insecure_channel("127.0.0.1:5567") as channel:
stub = inference_pb2_grpc.InferenceStub(channel)
response = stub.ListDevices(inference_pb2.Empty())
response = stub.ListDevices(utils_pb2.Empty())
print(response)


Expand Down
35 changes: 5 additions & 30 deletions proto/inference.proto
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
syntax = "proto3";

package inference;

import "utils.proto";


service Inference {
rpc CreateModelSession(CreateModelSessionRequest) returns (ModelSession) {}

Expand All @@ -14,15 +19,6 @@ service Inference {
rpc Predict(PredictRequest) returns (PredictResponse) {}
}

message Device {
enum Status {
AVAILABLE = 0;
IN_USE = 1;
}

string id = 1;
Status status = 2;
}

message CreateDatasetDescriptionRequest {
string modelSessionId = 1;
Expand Down Expand Up @@ -76,26 +72,6 @@ message LogEntry {
string content = 3;
}

message Devices {
repeated Device devices = 1;
}

message NamedInt {
uint32 size = 1;
string name = 2;
}

message NamedFloat {
float size = 1;
string name = 2;
}

message Tensor {
bytes buffer = 1;
string dtype = 2;
string tensorId = 3;
repeated NamedInt shape = 4;
}

message PredictRequest {
string modelSessionId = 1;
Expand All @@ -107,7 +83,6 @@ message PredictResponse {
repeated Tensor tensors = 1;
}

message Empty {}

service FlightControl {
rpc Ping(Empty) returns (Empty) {}
Expand Down
95 changes: 95 additions & 0 deletions proto/training.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
syntax = "proto3";

package training;

import "utils.proto";



// Lifecycle and inference RPCs for server-side training sessions.
// Empty, Devices and Tensor come from utils.proto (imported above).
service Training {
    // Enumerate compute devices known to the server.
    rpc ListDevices(Empty) returns (Devices) {}

    // Create a new training session from a configuration; returns its id.
    rpc Init(TrainingConfig) returns (TrainingSessionId) {}

    rpc Start(TrainingSessionId) returns (Empty) {}

    rpc Resume(TrainingSessionId) returns (Empty) {}

    rpc Pause(TrainingSessionId) returns (Empty) {}

    // Server-streaming feed of training progress updates.
    rpc StreamUpdates(TrainingSessionId) returns (stream StreamUpdateResponse) {}

    // Fetch accumulated log entries for the session.
    rpc GetLogs(TrainingSessionId) returns (GetLogsResponse) {}

    rpc Save(TrainingSessionId) returns (Empty) {}

    rpc Export(TrainingSessionId) returns (Empty) {}

    // Run prediction with the session's current model state.
    rpc Predict(PredictRequest) returns (PredictResponse) {}

    // Query the session's state machine (see GetStatusResponse.State).
    rpc GetStatus(TrainingSessionId) returns (GetStatusResponse) {}

    // Tear down the session and release its resources.
    rpc CloseTrainerSession(TrainingSessionId) returns (Empty) {}
}

// Opaque identifier for a training session, returned by Init.
message TrainingSessionId {
  string id = 1;
}

// A single training-progress record (one per iteration/phase).
message Logs {
  enum ModelPhase {
    Train = 0;
    Eval = 1;
  }
  ModelPhase mode = 1;   // phase the metrics below were recorded in
  double eval_score = 2;
  double loss = 3;
  uint32 iteration = 4;  // training iteration the record belongs to
}


// One item of the StreamUpdates server stream: the index of the best
// model so far plus the latest progress record.
message StreamUpdateResponse {
  uint32 best_model_idx = 1;
  Logs logs = 2;
}


// Response of GetLogs: all progress records accumulated so far.
message GetLogsResponse {
  repeated Logs logs = 1;
}



// Request for Training.Predict: input tensors plus the session whose
// current model state should be used. Tensor is defined in utils.proto.
message PredictRequest {
  repeated Tensor tensors = 1;
  TrainingSessionId sessionId = 2;
}


// Response of Training.Predict: the model's output tensors.
message PredictResponse {
  repeated Tensor tensors = 1;
}

// Average validation score.
// NOTE(review): not referenced by any RPC in this file — presumably
// reserved for a future/other endpoint; confirm before removing.
message ValidationResponse {
  double validation_score_average = 1;
}

// Response of GetStatus: current state of the training session.
message GetStatusResponse {
  enum State {
    Idle = 0;
    Running = 1;
    Paused = 2;
    Failed = 3;
    Finished = 4;
  }
  State state = 1;
}


// Index of the best model so far.
// NOTE(review): not referenced by any RPC in this file (StreamUpdateResponse
// already carries best_model_idx) — confirm whether this is still needed.
message GetCurrentBestModelIdxResponse {
  uint32 id = 1;
}

// Request for Training.Init: the training configuration as raw YAML text,
// parsed server-side.
message TrainingConfig {
  string yaml_content = 1;
}
34 changes: 34 additions & 0 deletions proto/utils.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
syntax = "proto3";

// Generic empty request/response payload shared by all services.
message Empty {}

// An integer size labelled with a name (used for named tensor axes).
message NamedInt {
  uint32 size = 1;
  string name = 2;
}

// A float value labelled with a name (float counterpart of NamedInt).
message NamedFloat {
  float size = 1;
  string name = 2;
}

// A serialized n-dimensional array with named axes.
message Tensor {
  bytes buffer = 1;           // raw element data; layout implied by dtype/shape
  string dtype = 2;           // element type name, e.g. "int64", "float32"
  string tensorId = 3;        // caller-chosen identifier, e.g. "input0"
  repeated NamedInt shape = 4; // one (name, size) entry per axis
}

// A compute device known to the server and its availability.
message Device {
  enum Status {
    AVAILABLE = 0;
    IN_USE = 1;
  }

  string id = 1;
  Status status = 2;
}

// Response of ListDevices: all devices known to the server.
message Devices {
  repeated Device devices = 1;
}
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[pytest]
python_files = test_*.py
addopts =
--timeout 10
--timeout 60
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating threads and processes with the "spawn" start method instead of "fork" is significantly slower, which is why I bumped the timeout — it lets the tests pass on macOS and Windows.

-v
-s
--color=yes
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ max-line-length = 120

[flake8]
max-line-length = 120
ignore=E203
ignore=E203,W503
exclude = tiktorch/proto/*,vendor
34 changes: 17 additions & 17 deletions tests/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
xarray_to_pb_tensor,
xr_tensors_to_sample,
)
from tiktorch.proto import inference_pb2
from tiktorch.proto import utils_pb2


def _numpy_to_pb_tensor(arr, tensor_id: str = "dummy_tensor_name"):
"""
Makes sure that tensor was serialized/deserialized
"""
tensor = numpy_to_pb_tensor(tensor_id, arr)
parsed = inference_pb2.Tensor()
parsed = utils_pb2.Tensor()
parsed.ParseFromString(tensor.SerializeToString())
return parsed

Expand All @@ -31,7 +31,7 @@ def to_pb_tensor(tensor_id: str, arr: xr.DataArray):
Makes sure that tensor was serialized/deserialized
"""
tensor = xarray_to_pb_tensor(tensor_id, arr)
parsed = inference_pb2.Tensor()
parsed = utils_pb2.Tensor()
parsed.ParseFromString(tensor.SerializeToString())
return parsed

Expand All @@ -40,7 +40,7 @@ class TestNumpyToPBTensor:
def test_should_serialize_to_tensor_type(self):
arr = np.arange(9)
tensor = _numpy_to_pb_tensor(arr)
assert isinstance(tensor, inference_pb2.Tensor)
assert isinstance(tensor, utils_pb2.Tensor)

@pytest.mark.parametrize("np_dtype,dtype_str", [(np.int64, "int64"), (np.uint8, "uint8"), (np.float32, "float32")])
def test_should_have_dtype_as_str(self, np_dtype, dtype_str):
Expand All @@ -65,12 +65,12 @@ def test_should_have_serialized_bytes(self):

class TestPBTensorToNumpy:
def test_should_raise_on_empty_dtype(self):
tensor = inference_pb2.Tensor(dtype="", shape=[inference_pb2.NamedInt(size=1), inference_pb2.NamedInt(size=2)])
tensor = utils_pb2.Tensor(dtype="", shape=[utils_pb2.NamedInt(size=1), utils_pb2.NamedInt(size=2)])
with pytest.raises(ValueError):
pb_tensor_to_numpy(tensor)

def test_should_raise_on_empty_shape(self):
tensor = inference_pb2.Tensor(dtype="int64", shape=[])
tensor = utils_pb2.Tensor(dtype="int64", shape=[])
with pytest.raises(ValueError):
pb_tensor_to_numpy(tensor)

Expand Down Expand Up @@ -109,7 +109,7 @@ class TestXarrayToPBTensor:
def test_should_serialize_to_tensor_type(self):
xarr = xr.DataArray(np.arange(8).reshape((2, 4)), dims=("x", "y"))
pb_tensor = to_pb_tensor("input0", xarr)
assert isinstance(pb_tensor, inference_pb2.Tensor)
assert isinstance(pb_tensor, utils_pb2.Tensor)
assert len(pb_tensor.shape) == 2
dim1 = pb_tensor.shape[0]
dim2 = pb_tensor.shape[1]
Expand Down Expand Up @@ -137,12 +137,12 @@ def test_should_have_serialized_bytes(self):

class TestPBTensorToXarray:
def test_should_raise_on_empty_dtype(self):
tensor = inference_pb2.Tensor(dtype="", shape=[inference_pb2.NamedInt(size=1), inference_pb2.NamedInt(size=2)])
tensor = utils_pb2.Tensor(dtype="", shape=[utils_pb2.NamedInt(size=1), utils_pb2.NamedInt(size=2)])
with pytest.raises(ValueError):
pb_tensor_to_xarray(tensor)

def test_should_raise_on_empty_shape(self):
tensor = inference_pb2.Tensor(dtype="int64", shape=[])
tensor = utils_pb2.Tensor(dtype="int64", shape=[])
with pytest.raises(ValueError):
pb_tensor_to_xarray(tensor)

Expand Down Expand Up @@ -178,19 +178,19 @@ def test_should_same_data(self, shape):
class TestSample:
def test_pb_tensors_to_sample(self):
arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32)
tensor_1 = inference_pb2.Tensor(
tensor_1 = utils_pb2.Tensor(
dtype="int64",
tensorId="input1",
buffer=bytes(arr_1),
shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)],
shape=[utils_pb2.NamedInt(name="x", size=32), utils_pb2.NamedInt(name="y", size=32)],
)

arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64)
tensor_2 = inference_pb2.Tensor(
tensor_2 = utils_pb2.Tensor(
dtype="int64",
tensorId="input2",
buffer=bytes(arr_2),
shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)],
shape=[utils_pb2.NamedInt(name="x", size=64), utils_pb2.NamedInt(name="y", size=64)],
)

sample = pb_tensors_to_sample([tensor_1, tensor_2])
Expand Down Expand Up @@ -218,17 +218,17 @@ def test_sample_to_pb_tensors(self):
tensors_ids = ["input1", "input2"]
sample = xr_tensors_to_sample(tensors_ids, [tensor_1, tensor_2])

pb_tensor_1 = inference_pb2.Tensor(
pb_tensor_1 = utils_pb2.Tensor(
dtype="int64",
tensorId="input1",
buffer=bytes(arr_1),
shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)],
shape=[utils_pb2.NamedInt(name="x", size=32), utils_pb2.NamedInt(name="y", size=32)],
)
pb_tensor_2 = inference_pb2.Tensor(
pb_tensor_2 = utils_pb2.Tensor(
dtype="int64",
tensorId="input2",
buffer=bytes(arr_2),
shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)],
shape=[utils_pb2.NamedInt(name="x", size=64), utils_pb2.NamedInt(name="y", size=64)],
)
expected_tensors = [pb_tensor_1, pb_tensor_2]

Expand Down
4 changes: 2 additions & 2 deletions tests/test_server/test_grpc/test_fligh_control_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def _pinger():
pinger_thread.start()

assert not evt.is_set()
assert not evt.wait(timeout=0.2)
assert not evt.wait(timeout=1)

stop_pinger.set()
assert evt.wait(timeout=0.2)
assert evt.wait(timeout=1)


def test_shutdown_timeout_0_means_no_watchdog():
Expand Down
Loading
Loading