Skip to content

Commit

Permalink
Return an incremental id instead of just pinging
Browse files Browse the repository at this point in the history
The response of the best model stream will return an id. The id is
increased by one, each time we have a new model. A client can identify
if an action has been performed by an outdated model based on the id. If
    the current is greater, then a new best model exists.
  • Loading branch information
thodkatz committed Dec 21, 2024
1 parent c9663d2 commit 14f81af
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 19 deletions.
2 changes: 1 addition & 1 deletion proto/training.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ service Training {

rpc GetLogs(ModelSession) returns (GetLogsResponse) {}

rpc IsBestModel(ModelSession) returns (stream Empty) {}
rpc GetBestModelIdx(ModelSession) returns (stream GetBestModelIdxResponse) {}

rpc Save(SaveRequest) returns (Empty) {}

Expand Down
8 changes: 4 additions & 4 deletions tests/test_server/test_grpc/test_training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,13 +642,13 @@ def test_best_model_ping(self, grpc_stub):

grpc_stub.Start(training_session_id)

responses = grpc_stub.IsBestModel(training_session_id)
responses = grpc_stub.GetBestModelIdx(training_session_id)
received_updates = 0
for response in responses:
assert isinstance(response, utils_pb2.Empty)
assert isinstance(response, training_pb2.GetBestModelIdxResponse)
assert response.id is not None
received_updates += 1

if received_updates >= 3:
if received_updates >= 2:
break

def test_close_session(self, grpc_stub):
Expand Down
4 changes: 2 additions & 2 deletions tiktorch/proto/training_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 10 additions & 10 deletions tiktorch/proto/training_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def __init__(self, channel):
request_serializer=utils__pb2.ModelSession.SerializeToString,
response_deserializer=training__pb2.GetLogsResponse.FromString,
)
self.IsBestModel = channel.unary_stream(
'/training.Training/IsBestModel',
self.GetBestModelIdx = channel.unary_stream(
'/training.Training/GetBestModelIdx',
request_serializer=utils__pb2.ModelSession.SerializeToString,
response_deserializer=utils__pb2.Empty.FromString,
response_deserializer=training__pb2.GetBestModelIdxResponse.FromString,
)
self.Save = channel.unary_unary(
'/training.Training/Save',
Expand Down Expand Up @@ -127,7 +127,7 @@ def GetLogs(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

Check warning on line 128 in tiktorch/proto/training_pb2_grpc.py

View check run for this annotation

Codecov / codecov/patch

tiktorch/proto/training_pb2_grpc.py#L126-L128

Added lines #L126 - L128 were not covered by tests

def IsBestModel(self, request, context):
def GetBestModelIdx(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
Expand Down Expand Up @@ -201,10 +201,10 @@ def add_TrainingServicer_to_server(servicer, server):
request_deserializer=utils__pb2.ModelSession.FromString,
response_serializer=training__pb2.GetLogsResponse.SerializeToString,
),
'IsBestModel': grpc.unary_stream_rpc_method_handler(
servicer.IsBestModel,
'GetBestModelIdx': grpc.unary_stream_rpc_method_handler(
servicer.GetBestModelIdx,
request_deserializer=utils__pb2.ModelSession.FromString,
response_serializer=utils__pb2.Empty.SerializeToString,
response_serializer=training__pb2.GetBestModelIdxResponse.SerializeToString,
),
'Save': grpc.unary_unary_rpc_method_handler(
servicer.Save,
Expand Down Expand Up @@ -361,7 +361,7 @@ def GetLogs(request,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def IsBestModel(request,
def GetBestModelIdx(request,
target,
options=(),
channel_credentials=None,
Expand All @@ -371,9 +371,9 @@ def IsBestModel(request,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(request, target, '/training.Training/IsBestModel',
return grpc.experimental.unary_stream(request, target, '/training.Training/GetBestModelIdx',

Check warning on line 374 in tiktorch/proto/training_pb2_grpc.py

View check run for this annotation

Codecov / codecov/patch

tiktorch/proto/training_pb2_grpc.py#L374

Added line #L374 was not covered by tests
utils__pb2.ModelSession.SerializeToString,
utils__pb2.Empty.FromString,
training__pb2.GetBestModelIdxResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

Expand Down
4 changes: 2 additions & 2 deletions tiktorch/server/grpc/training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ def StreamUpdates(self, request: utils_pb2.ModelSession, context):
def GetLogs(self, request: utils_pb2.ModelSession, context):
raise NotImplementedError

Check warning on line 99 in tiktorch/server/grpc/training_servicer.py

View check run for this annotation

Codecov / codecov/patch

tiktorch/server/grpc/training_servicer.py#L99

Added line #L99 was not covered by tests

def IsBestModel(self, request, context):
def GetBestModelIdx(self, request, context):
session = self._getTrainerSession(context, request)
prev_best_model_idx = None
while context.is_active():
current_best_model_idx = session.client.get_best_model_idx()
if current_best_model_idx != prev_best_model_idx:
prev_best_model_idx = current_best_model_idx
yield utils_pb2.Empty()
yield training_pb2.GetBestModelIdxResponse(id=str(current_best_model_idx))
time.sleep(1)
logger.info("Client disconnected. Stopping stream.")

Expand Down
1 change: 1 addition & 0 deletions tiktorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def export(self, file_to_save: Path):
architecture=ArchitectureFromLibraryDescr(
import_from=f"{self.get_model_import_file_path()}",
callable=Identifier(f"{self.model.__class__.__name__}"),
kwargs={"in_channels": self._in_channels, "out_channels": self._out_channels},
),
pytorch_version=Version("1.1.1"),
)
Expand Down

0 comments on commit 14f81af

Please sign in to comment.