Skip to content

Commit

Permalink
explain early stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
aditya1503 committed Jan 22, 2024
1 parent 36e2c7b commit 9f12174
Showing 1 changed file with 30 additions and 28 deletions.
58 changes: 30 additions & 28 deletions object_detection/detectron2_training-kfold.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!wget -nc \"http://images.cocodataset.org/annotations/annotations_trainval2017.zip\" && unzip -q -o annotations_trainval2017.zip\n",
Expand Down Expand Up @@ -96,7 +98,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import json\n",
Expand Down Expand Up @@ -150,8 +154,12 @@
"pairs = []\n",
"for fold, (train_indices, test_indices) in enumerate(kf.split(image_ids)):\n",
" train_data, test_data = split_data(train_indices, test_indices)\n",
" # Register COCO instances for training and validation. \n",
" # Note: The 'train2017' folder is retained as the base path for images.\n",
" train_file = f\"train_coco_{fold}_fold.json\"\n",
" test_file = f\"test_coco_{fold}_fold.json\"\n",
" register_coco_instances(train_file, {}, train_file, \"train2017\")\n",
" register_coco_instances(test_file, {}, test_file, \"train2017\")\n",
" pairs.append([train_file,test_file])\n",
" with open(train_file, 'w') as train_file:\n",
" json.dump(train_data, train_file)\n",
Expand All @@ -161,20 +169,7 @@
" print_data_info(train_data, fold)\n",
" print(f\"Data info for test data fold {fold}:\")\n",
" print_data_info(test_data, fold)\n",
" # Register COCO instances for training and validation. \n",
" # Note: The 'train2017' folder is retained as the base path for images.\n",
" register_coco_instances(train_file, {}, train_file, \"train2017\")\n",
" register_coco_instances(test_file, {}, test_file, \"train2017\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
" \n",
"TRAIN_PATH = os.path.join(os.getcwd(),\"train2017\")"
]
},
Expand All @@ -194,14 +189,15 @@
"The number of worker threads is set to 2 and the batch size is set to 2.\n",
"The learning rate and maximum number of iterations are also specified. The model is initialized from the COCO-Detection model zoo and the output directory for the trained model is created. Finally, the configuration is passed to the DefaultTrainer class for training the object detection model.\n",
"\n",
"<strong>Note:</strong> he choice of the number of iterations is informed by the incorporation of [early stopping.](https://en.wikipedia.org/wiki/Early_stopping#:~:text=In%20machine%20learning%2C%20early%20stopping,training%20data%20with%20each%20iteration.) This technique monitors the validation loss throughout training, saving the model upon improvement and halting training if no progress is observed within a defined patience period. Early stopping aims to identify an optimal model iteration, mitigating the risk of overfitting."
"<strong>Note:</strong> The choice of the number of iterations is informed by the incorporation of [early stopping.](https://en.wikipedia.org/wiki/Early_stopping#:~:text=In%20machine%20learning%2C%20early%20stopping,training%20data%20with%20each%20iteration.) This technique monitors the validation loss throughout training, saving the model upon improvement and halting training if no progress is observed within a defined patience period. Early stopping aims to identify an optimal model iteration, mitigating the risk of overfitting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -232,7 +228,6 @@
" self._trainer.has_finished = True\n",
"\n",
" def validation(self):\n",
" # val_loader = self.build_test_loader(cfg=self.cfg, dataset_name=self.cfg.DATASETS.TEST[0])\n",
" val_loader = build_detection_test_loader(self.cfg, self.cfg.DATASETS.TEST[0], evaluators=[evaluator])\n",
" evaluator = COCOEvaluator(self.cfg.DATASETS.TEST[0], self.cfg, True, output_dir=\"./output/\")\n",
" val_results = self._trainer.test(self.cfg, self.model, evaluators=[evaluator])[0]\n",
Expand All @@ -256,7 +251,7 @@
" \n",
" \n",
" cfg.SOLVER.IMS_PER_BATCH = 2 # This is the real \"batch size\" commonly known to deep learning people\n",
" cfg.SOLVER.BASE_LR = 0.0004 # pick a good LR\n",
" cfg.SOLVER.BASE_LR = 0.004 # pick a good LR\n",
" cfg.SOLVER.STEPS = [] # milestones where LR is reduced, in this case there's no decay\n",
" cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 # The \"RoIHead batch size\". \n",
" cfg.MODEL.ROI_HEADS.NUM_CLASSES = 80 \n",
Expand Down Expand Up @@ -285,7 +280,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def format_detectron2_predictions(instances, num_classes):\n",
Expand Down Expand Up @@ -315,7 +312,7 @@
" formatted_results = []\n",
" for i in results:\n",
" if len(i) == 0:\n",
" formatted_array = np.array(i, dtype=np.float32).reshape((0, num_classes))\n",
" formatted_array = np.array(i, dtype=np.float32).reshape((0, 5))\n",
" else:\n",
" formatted_array = np.array(i, dtype=np.float32)\n",
" formatted_results.append(formatted_array)\n",
Expand All @@ -327,7 +324,8 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
Expand All @@ -353,7 +351,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"result_dict = {}\n",
Expand All @@ -366,7 +366,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"dataset = pickle.load(open(\"TRAIN_COCO_ALL_labels.pkl\",'rb'))\n",
Expand All @@ -380,9 +382,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "mmdet3d",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "mmdet3d"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -394,7 +396,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
"version": "3.9.12"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 9f12174

Please sign in to comment.