diff --git a/examples/tensorflow/object_detection/main.py b/examples/tensorflow/object_detection/main.py index 091a67d50ca..8b538268c76 100644 --- a/examples/tensorflow/object_detection/main.py +++ b/examples/tensorflow/object_detection/main.py @@ -323,8 +323,7 @@ def run(config): # Training parameters epochs = config.epochs - steps_per_epoch = train_builder.steps_per_epoch - num_test_batches = test_builder.steps_per_epoch + steps_per_epoch, num_test_batches = train_builder.steps_per_epoch, test_builder.steps_per_epoch # Create model builder model_builder = get_model_builder(config) @@ -335,10 +334,9 @@ def run(config): nncf_config=config.nncf_config, data_loader=train_dataset, batch_size=train_builder.global_batch_size ) - resume_training = config.ckpt_path is not None - compression_state = None - if resume_training: + if config.ckpt_path is not None: + # Resume training compression_state = load_compression_state(config.ckpt_path) with TFModelManager(model_builder.build_model, config.nncf_config, weights=config.get("weights", None)) as model: