From 14f81afde1f5ecd5ee1267b945374f36e13d4007 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Mon, 16 Dec 2024 10:36:43 +0100 Subject: [PATCH] Return an incremental id instead of just pinging The response of the best model stream now returns an id. The id is incremented by one each time there is a new best model. A client can use the id to identify whether an action has been performed by an outdated model: if the current id is greater than the id associated with that action, then a new best model exists. --- proto/training.proto | 2 +- .../test_grpc/test_training_servicer.py | 8 ++++---- tiktorch/proto/training_pb2.py | 4 ++-- tiktorch/proto/training_pb2_grpc.py | 20 +++++++++---------- tiktorch/server/grpc/training_servicer.py | 4 ++-- tiktorch/trainer.py | 1 + 6 files changed, 20 insertions(+), 19 deletions(-) diff --git a/proto/training.proto b/proto/training.proto index 2ae5eb82..5e4eeb52 100644 --- a/proto/training.proto +++ b/proto/training.proto @@ -21,7 +21,7 @@ service Training { rpc GetLogs(ModelSession) returns (GetLogsResponse) {} - rpc IsBestModel(ModelSession) returns (stream Empty) {} + rpc GetBestModelIdx(ModelSession) returns (stream GetBestModelIdxResponse) {} rpc Save(SaveRequest) returns (Empty) {} diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py index 3270159f..000a580d 100644 --- a/tests/test_server/test_grpc/test_training_servicer.py +++ b/tests/test_server/test_grpc/test_training_servicer.py @@ -642,13 +642,13 @@ def test_best_model_ping(self, grpc_stub): grpc_stub.Start(training_session_id) - responses = grpc_stub.IsBestModel(training_session_id) + responses = grpc_stub.GetBestModelIdx(training_session_id) received_updates = 0 for response in responses: - assert isinstance(response, utils_pb2.Empty) + assert isinstance(response, training_pb2.GetBestModelIdxResponse) + assert response.id is not None received_updates += 1 - - if received_updates >= 3: + if received_updates >= 2: break def test_close_session(self, 
grpc_stub): diff --git a/tiktorch/proto/training_pb2.py b/tiktorch/proto/training_pb2.py index f5e47f0c..696528bf 100644 --- a/tiktorch/proto/training_pb2.py +++ b/tiktorch/proto/training_pb2.py @@ -14,7 +14,7 @@ from . import utils_pb2 as utils__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a\x0butils.proto\"%\n\x17GetBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\t\"\x87\x01\n\x04Logs\x12\'\n\x04mode\x18\x01 \x01(\x0e\x32\x19.training.Logs.ModelPhase\x12\x12\n\neval_score\x18\x02 \x01(\x01\x12\x0c\n\x04loss\x18\x03 \x01(\x01\x12\x11\n\titeration\x18\x04 \x01(\r\"!\n\nModelPhase\x12\t\n\x05Train\x10\x00\x12\x08\n\x04\x45val\x10\x01\"L\n\x14StreamUpdateResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12\x1c\n\x04logs\x18\x02 \x01(\x0b\x32\x0e.training.Logs\"/\n\x0fGetLogsResponse\x12\x1c\n\x04logs\x18\x01 \x03(\x0b\x32\x0e.training.Logs\"F\n\x0bSaveRequest\x12%\n\x0emodelSessionId\x18\x01 \x01(\x0b\x32\r.ModelSession\x12\x10\n\x08\x66ilePath\x18\x02 \x01(\t\"H\n\rExportRequest\x12%\n\x0emodelSessionId\x18\x01 \x01(\x0b\x32\r.ModelSession\x12\x10\n\x08\x66ilePath\x18\x02 \x01(\t\"6\n\x12ValidationResponse\x12 \n\x18validation_score_average\x18\x01 \x01(\x01\"\x8b\x01\n\x11GetStatusResponse\x12\x30\n\x05state\x18\x01 \x01(\x0e\x32!.training.GetStatusResponse.State\"D\n\x05State\x12\x08\n\x04Idle\x10\x00\x12\x0b\n\x07Running\x10\x01\x12\n\n\x06Paused\x10\x02\x12\n\n\x06\x46\x61iled\x10\x03\x12\x0c\n\x08\x46inished\x10\x04\",\n\x1eGetCurrentBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\r\"&\n\x0eTrainingConfig\x12\x14\n\x0cyaml_content\x18\x01 \x01(\t2\xdd\x04\n\x08Training\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12\x31\n\x04Init\x12\x18.training.TrainingConfig\x1a\r.ModelSession\"\x00\x12 \n\x05Start\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12!\n\x06Resume\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12 
\n\x05Pause\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12\x42\n\rStreamUpdates\x12\r.ModelSession\x1a\x1e.training.StreamUpdateResponse\"\x00\x30\x01\x12\x35\n\x07GetLogs\x12\r.ModelSession\x1a\x19.training.GetLogsResponse\"\x00\x12(\n\x0bIsBestModel\x12\r.ModelSession\x1a\x06.Empty\"\x00\x30\x01\x12\'\n\x04Save\x12\x15.training.SaveRequest\x1a\x06.Empty\"\x00\x12+\n\x06\x45xport\x12\x17.training.ExportRequest\x1a\x06.Empty\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x12\x39\n\tGetStatus\x12\r.ModelSession\x1a\x1b.training.GetStatusResponse\"\x00\x12.\n\x13\x43loseTrainerSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a\x0butils.proto\"%\n\x17GetBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\t\"\x87\x01\n\x04Logs\x12\'\n\x04mode\x18\x01 \x01(\x0e\x32\x19.training.Logs.ModelPhase\x12\x12\n\neval_score\x18\x02 \x01(\x01\x12\x0c\n\x04loss\x18\x03 \x01(\x01\x12\x11\n\titeration\x18\x04 \x01(\r\"!\n\nModelPhase\x12\t\n\x05Train\x10\x00\x12\x08\n\x04\x45val\x10\x01\"L\n\x14StreamUpdateResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12\x1c\n\x04logs\x18\x02 \x01(\x0b\x32\x0e.training.Logs\"/\n\x0fGetLogsResponse\x12\x1c\n\x04logs\x18\x01 \x03(\x0b\x32\x0e.training.Logs\"F\n\x0bSaveRequest\x12%\n\x0emodelSessionId\x18\x01 \x01(\x0b\x32\r.ModelSession\x12\x10\n\x08\x66ilePath\x18\x02 \x01(\t\"H\n\rExportRequest\x12%\n\x0emodelSessionId\x18\x01 \x01(\x0b\x32\r.ModelSession\x12\x10\n\x08\x66ilePath\x18\x02 \x01(\t\"6\n\x12ValidationResponse\x12 \n\x18validation_score_average\x18\x01 \x01(\x01\"\x8b\x01\n\x11GetStatusResponse\x12\x30\n\x05state\x18\x01 \x01(\x0e\x32!.training.GetStatusResponse.State\"D\n\x05State\x12\x08\n\x04Idle\x10\x00\x12\x0b\n\x07Running\x10\x01\x12\n\n\x06Paused\x10\x02\x12\n\n\x06\x46\x61iled\x10\x03\x12\x0c\n\x08\x46inished\x10\x04\",\n\x1eGetCurrentBestModelIdxResponse\x12\n\n\x02id\x18\x01 
\x01(\r\"&\n\x0eTrainingConfig\x12\x14\n\x0cyaml_content\x18\x01 \x01(\t2\xfc\x04\n\x08Training\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12\x31\n\x04Init\x12\x18.training.TrainingConfig\x1a\r.ModelSession\"\x00\x12 \n\x05Start\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12!\n\x06Resume\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12 \n\x05Pause\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12\x42\n\rStreamUpdates\x12\r.ModelSession\x1a\x1e.training.StreamUpdateResponse\"\x00\x30\x01\x12\x35\n\x07GetLogs\x12\r.ModelSession\x1a\x19.training.GetLogsResponse\"\x00\x12G\n\x0fGetBestModelIdx\x12\r.ModelSession\x1a!.training.GetBestModelIdxResponse\"\x00\x30\x01\x12\'\n\x04Save\x12\x15.training.SaveRequest\x1a\x06.Empty\"\x00\x12+\n\x06\x45xport\x12\x17.training.ExportRequest\x1a\x06.Empty\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x12\x39\n\tGetStatus\x12\r.ModelSession\x1a\x1b.training.GetStatusResponse\"\x00\x12.\n\x13\x43loseTrainerSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'training_pb2', globals()) @@ -46,5 +46,5 @@ _TRAININGCONFIG._serialized_start=735 _TRAININGCONFIG._serialized_end=773 _TRAINING._serialized_start=776 - _TRAINING._serialized_end=1381 + _TRAINING._serialized_end=1412 # @@protoc_insertion_point(module_scope) diff --git a/tiktorch/proto/training_pb2_grpc.py b/tiktorch/proto/training_pb2_grpc.py index a63ad6bc..b51ae922 100644 --- a/tiktorch/proto/training_pb2_grpc.py +++ b/tiktorch/proto/training_pb2_grpc.py @@ -50,10 +50,10 @@ def __init__(self, channel): request_serializer=utils__pb2.ModelSession.SerializeToString, response_deserializer=training__pb2.GetLogsResponse.FromString, ) - self.IsBestModel = channel.unary_stream( - '/training.Training/IsBestModel', + self.GetBestModelIdx = channel.unary_stream( + '/training.Training/GetBestModelIdx', 
request_serializer=utils__pb2.ModelSession.SerializeToString, - response_deserializer=utils__pb2.Empty.FromString, + response_deserializer=training__pb2.GetBestModelIdxResponse.FromString, ) self.Save = channel.unary_unary( '/training.Training/Save', @@ -127,7 +127,7 @@ def GetLogs(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def IsBestModel(self, request, context): + def GetBestModelIdx(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -201,10 +201,10 @@ def add_TrainingServicer_to_server(servicer, server): request_deserializer=utils__pb2.ModelSession.FromString, response_serializer=training__pb2.GetLogsResponse.SerializeToString, ), - 'IsBestModel': grpc.unary_stream_rpc_method_handler( - servicer.IsBestModel, + 'GetBestModelIdx': grpc.unary_stream_rpc_method_handler( + servicer.GetBestModelIdx, request_deserializer=utils__pb2.ModelSession.FromString, - response_serializer=utils__pb2.Empty.SerializeToString, + response_serializer=training__pb2.GetBestModelIdxResponse.SerializeToString, ), 'Save': grpc.unary_unary_rpc_method_handler( servicer.Save, @@ -361,7 +361,7 @@ def GetLogs(request, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def IsBestModel(request, + def GetBestModelIdx(request, target, options=(), channel_credentials=None, @@ -371,9 +371,9 @@ def IsBestModel(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_stream(request, target, '/training.Training/IsBestModel', + return grpc.experimental.unary_stream(request, target, '/training.Training/GetBestModelIdx', utils__pb2.ModelSession.SerializeToString, - utils__pb2.Empty.FromString, + training__pb2.GetBestModelIdxResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, 
wait_for_ready, timeout, metadata) diff --git a/tiktorch/server/grpc/training_servicer.py b/tiktorch/server/grpc/training_servicer.py index 8c53c6d6..aa8966ab 100644 --- a/tiktorch/server/grpc/training_servicer.py +++ b/tiktorch/server/grpc/training_servicer.py @@ -98,14 +98,14 @@ def StreamUpdates(self, request: utils_pb2.ModelSession, context): def GetLogs(self, request: utils_pb2.ModelSession, context): raise NotImplementedError - def IsBestModel(self, request, context): + def GetBestModelIdx(self, request, context): session = self._getTrainerSession(context, request) prev_best_model_idx = None while context.is_active(): current_best_model_idx = session.client.get_best_model_idx() if current_best_model_idx != prev_best_model_idx: prev_best_model_idx = current_best_model_idx - yield utils_pb2.Empty() + yield training_pb2.GetBestModelIdxResponse(id=str(current_best_model_idx)) time.sleep(1) logger.info("Client disconnected. Stopping stream.") diff --git a/tiktorch/trainer.py b/tiktorch/trainer.py index b41a0dac..94b6173e 100644 --- a/tiktorch/trainer.py +++ b/tiktorch/trainer.py @@ -315,6 +315,7 @@ def export(self, file_to_save: Path): architecture=ArchitectureFromLibraryDescr( import_from=f"{self.get_model_import_file_path()}", callable=Identifier(f"{self.model.__class__.__name__}"), + kwargs={"in_channels": self._in_channels, "out_channels": self._out_channels}, ), pytorch_version=Version("1.1.1"), )