diff --git a/object_detection/detectron2_training-kfold.ipynb b/object_detection/detectron2_training-kfold.ipynb
index 6dff110..b9d229b 100644
--- a/object_detection/detectron2_training-kfold.ipynb
+++ b/object_detection/detectron2_training-kfold.ipynb
@@ -25,13 +25,16 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {},
+ "metadata": {
+ "tags": []
+ },
"outputs": [],
"source": [
"from detectron2.engine import DefaultTrainer\n",
"from detectron2.config import get_cfg\n",
"import pickle\n",
"# import some common libraries\n",
+ "from detectron2.data import build_detection_test_loader, build_detection_train_loader\n",
"import numpy as np\n",
"import os, json, cv2, random\n",
"from detectron2.data import build_detection_test_loader\n",
@@ -46,7 +49,8 @@
"import glob\n",
"from sklearn.model_selection import KFold\n",
"import json\n",
- "from collections import defaultdict"
+ "from collections import defaultdict\n",
+ "from detectron2.data.datasets import register_coco_instances"
]
},
{
@@ -156,7 +160,22 @@
" print(f\"Data info for training data fold {fold}:\")\n",
" print_data_info(train_data, fold)\n",
" print(f\"Data info for test data fold {fold}:\")\n",
- " print_data_info(test_data, fold)\n"
+ " print_data_info(test_data, fold)\n",
+ " # Register COCO instances for training and validation. \n",
+ " # Note: The 'train2017' folder is retained as the base path for images.\n",
+ " register_coco_instances(train_file, {}, train_file, \"train2017\")\n",
+ " register_coco_instances(test_file, {}, test_file, \"train2017\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "TRAIN_PATH = os.path.join(os.getcwd(),\"train2017\")"
]
},
{
@@ -175,7 +194,7 @@
"The number of worker threads is set to 2 and the batch size is set to 2.\n",
"The learning rate and maximum number of iterations are also specified. The model is initialized from the COCO-Detection model zoo and the output directory for the trained model is created. Finally, the configuration is passed to the DefaultTrainer class for training the object detection model.\n",
"\n",
- "Note: The number of iterations was set based on [early stopping.](https://en.wikipedia.org/wiki/Early_stopping#:~:text=In%20machine%20learning%2C%20early%20stopping,training%20data%20with%20each%20iteration.)"
+ "Note: he choice of the number of iterations is informed by the incorporation of [early stopping.](https://en.wikipedia.org/wiki/Early_stopping#:~:text=In%20machine%20learning%2C%20early%20stopping,training%20data%20with%20each%20iteration.) This technique monitors the validation loss throughout training, saving the model upon improvement and halting training if no progress is observed within a defined patience period. Early stopping aims to identify an optimal model iteration, mitigating the risk of overfitting."
]
},
{
@@ -186,25 +205,67 @@
},
"outputs": [],
"source": [
- "def train_data(TRAIN,VALIDATION,folder):\n",
+ "class Early_stopping(DefaultTrainer):\n",
+ " def __init__(self, cfg, early_stop_patience=5, model_checkpoint_path=\"model_checkpoint.pth\"):\n",
+ " super().__init__(cfg)\n",
+ " self.early_stop_patience = early_stop_patience\n",
+ " self.model_checkpoint_path = model_checkpoint_path\n",
+ " self.best_validation_loss = float('inf')\n",
+ " self.current_patience = 0\n",
+ "\n",
+ " def build_train_loader(self, cfg):\n",
+ " return build_detection_train_loader(cfg)\n",
+ " \n",
+ " def data_loader_mapper(self, batch):\n",
+ " return batch\n",
+ "\n",
+ " def run_hooks(self):\n",
+ " val_loss = self.validation()\n",
+ " if val_loss < self.best_validation_loss:\n",
+ " self.best_validation_loss = val_loss\n",
+ " self.current_patience = 0\n",
+ " self.save_checkpoint()\n",
+ " else:\n",
+ " self.current_patience += 1\n",
+ " if self.current_patience >= self.early_stop_patience:\n",
+ " self._trainer.save_checkpoint()\n",
+ " self._trainer.has_finished = True\n",
+ "\n",
+ " def validation(self):\n",
+ " # val_loader = self.build_test_loader(cfg=self.cfg, dataset_name=self.cfg.DATASETS.TEST[0])\n",
+ " val_loader = build_detection_test_loader(self.cfg, self.cfg.DATASETS.TEST[0], evaluators=[evaluator])\n",
+ " evaluator = COCOEvaluator(self.cfg.DATASETS.TEST[0], self.cfg, True, output_dir=\"./output/\")\n",
+ " val_results = self._trainer.test(self.cfg, self.model, evaluators=[evaluator])[0]\n",
+ " val_loss = val_results[\"total_loss\"]\n",
+ " return val_loss\n",
+ "\n",
+ " def save_checkpoint(self):\n",
+ " checkpointer = DetectionCheckpointer(self.model)\n",
+ " checkpointer.save(self.model_checkpoint_path)\n",
+ " \n",
+ "\n",
+ "def train_model(TRAIN,VALIDATION,folder):\n",
" cfg = get_cfg()\n",
" MODEL = 'faster_rcnn_X_101_32x8d_FPN_3x.yaml'\n",
" cfg.merge_from_file(model_zoo.get_config_file(\"COCO-Detection/\"+MODEL))\n",
" cfg.DATASETS.TRAIN = (TRAIN,)\n",
" cfg.DATASETS.TEST = (VALIDATION,)\n",
" cfg.DATALOADER.NUM_WORKERS = 2\n",
- " cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-Detection/\"+MODEL) # Let training initialize from model zoo\n",
+ " #Uncomment if you want to use pre-trained weights for finetuning, not recommended for K fold training\n",
+ " # cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-Detection/\"+MODEL) # Let training initialize from model zoo\n",
+ " \n",
+ " \n",
" cfg.SOLVER.IMS_PER_BATCH = 2 # This is the real \"batch size\" commonly known to deep learning people\n",
- " cfg.SOLVER.BASE_LR = 0.00025 # pick a good LR\n",
- " cfg.SOLVER.MAX_ITER = 6000 # \n",
+ " cfg.SOLVER.BASE_LR = 0.0004 # pick a good LR\n",
" cfg.SOLVER.STEPS = [] # milestones where LR is reduced, in this case there's no decay\n",
" cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 # The \"RoIHead batch size\". \n",
" cfg.MODEL.ROI_HEADS.NUM_CLASSES = 80 \n",
" cfg.TEST.EVAL_PERIOD = 500\n",
" os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)\n",
- " trainer = DefaultTrainer(cfg) \n",
+ " trainer = Early_stopping(cfg, early_stop_patience=5, model_checkpoint_path=\"model_checkpoint.pth\")\n",
" trainer.resume_or_load(resume=False)\n",
- " trainer.train();\n"
+ " trainer.train();\n",
+ " return cfg\n"
]
},
{
@@ -274,19 +335,19 @@
" result_dict = {}\n",
" train_data = pairs[k][0]\n",
" val_data = pairs[k][1]\n",
- " train_data(train_data,val_data,\"COCO_TRAIN_\"+str(k)+\"_FOLD\")\n",
+ " cfg = train_model(train_data,val_data,\"COCO_TRAIN_\"+str(k)+\"_FOLD\")\n",
" evaluator = COCOEvaluator(val_data, output_dir=\"output\")\n",
" val_loader = build_detection_test_loader(cfg, val_data)\n",
" cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, \"model_final.pth\") # path to the model we just trained\n",
" cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1 # set a custom testing threshold\n",
" predictor = DefaultPredictor(cfg)\n",
- " dataset = json.load(open(\"../\"+pairs[k][1]+'.json','rb'))\n",
- " for image in dat['images']:\n",
- " im_name = os.path.join(TRAIN_PATH, i['file_name'])\n",
+ " dataset = json.load(open(pairs[k][1],'rb'))\n",
+ " for image in dataset['images']:\n",
+ " im_name = os.path.join(TRAIN_PATH, image['file_name'])\n",
" im = cv2.imread(im_name)\n",
" outputs = predictor(im)\n",
- " result_dict[im_name](format_detectron2_predictions(outputs[\"instances\"].to(\"cpu\"),cfg.MODEL.ROI_HEADS.NUM_CLASSES))\n",
- " pickle.dump(result_dict,open(\"results_fold_\"+str(k)+\".pkl\",'wb'))"
+ " result_dict[im_name] = (format_detectron2_predictions(outputs[\"instances\"].to(\"cpu\"),cfg.MODEL.ROI_HEADS.NUM_CLASSES))\n",
+ " pickle.dump(result_dict,open(\"train_results_fold_\"+str(k)+\".pkl\",'wb'))"
]
},
{
@@ -319,9 +380,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "mmdet3d",
"language": "python",
- "name": "python3"
+ "name": "mmdet3d"
},
"language_info": {
"codemirror_mode": {
@@ -333,7 +394,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.12"
+ "version": "3.8.17"
}
},
"nbformat": 4,