diff --git a/object_detection/detectron2_training-kfold.ipynb b/object_detection/detectron2_training-kfold.ipynb index 6dff110..b9d229b 100644 --- a/object_detection/detectron2_training-kfold.ipynb +++ b/object_detection/detectron2_training-kfold.ipynb @@ -25,13 +25,16 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "from detectron2.engine import DefaultTrainer\n", "from detectron2.config import get_cfg\n", "import pickle\n", "# import some common libraries\n", + "from detectron2.data import build_detection_test_loader, build_detection_train_loader\n", "import numpy as np\n", "import os, json, cv2, random\n", "from detectron2.data import build_detection_test_loader\n", @@ -46,7 +49,8 @@ "import glob\n", "from sklearn.model_selection import KFold\n", "import json\n", - "from collections import defaultdict" + "from collections import defaultdict\n", + "from detectron2.data.datasets import register_coco_instances" ] }, { @@ -156,7 +160,22 @@ " print(f\"Data info for training data fold {fold}:\")\n", " print_data_info(train_data, fold)\n", " print(f\"Data info for test data fold {fold}:\")\n", - " print_data_info(test_data, fold)\n" + " print_data_info(test_data, fold)\n", + " # Register COCO instances for training and validation. \n", + " # Note: The 'train2017' folder is retained as the base path for images.\n", + " register_coco_instances(train_file, {}, train_file, \"train2017\")\n", + " register_coco_instances(test_file, {}, test_file, \"train2017\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "TRAIN_PATH = os.path.join(os.getcwd(),\"train2017\")" ] }, { @@ -175,7 +194,7 @@ "The number of worker threads is set to 2 and the batch size is set to 2.\n", "The learning rate and maximum number of iterations are also specified. The model is initialized from the COCO-Detection model zoo and the output directory for the trained model is created. Finally, the configuration is passed to the DefaultTrainer class for training the object detection model.\n", "\n", - "Note: The number of iterations was set based on [early stopping.](https://en.wikipedia.org/wiki/Early_stopping#:~:text=In%20machine%20learning%2C%20early%20stopping,training%20data%20with%20each%20iteration.)" + "Note: he choice of the number of iterations is informed by the incorporation of [early stopping.](https://en.wikipedia.org/wiki/Early_stopping#:~:text=In%20machine%20learning%2C%20early%20stopping,training%20data%20with%20each%20iteration.) This technique monitors the validation loss throughout training, saving the model upon improvement and halting training if no progress is observed within a defined patience period. Early stopping aims to identify an optimal model iteration, mitigating the risk of overfitting." ] }, { @@ -186,25 +205,67 @@ }, "outputs": [], "source": [ - "def train_data(TRAIN,VALIDATION,folder):\n", + "class Early_stopping(DefaultTrainer):\n", + " def __init__(self, cfg, early_stop_patience=5, model_checkpoint_path=\"model_checkpoint.pth\"):\n", + " super().__init__(cfg)\n", + " self.early_stop_patience = early_stop_patience\n", + " self.model_checkpoint_path = model_checkpoint_path\n", + " self.best_validation_loss = float('inf')\n", + " self.current_patience = 0\n", + "\n", + " def build_train_loader(self, cfg):\n", + " return build_detection_train_loader(cfg)\n", + " \n", + " def data_loader_mapper(self, batch):\n", + " return batch\n", + "\n", + " def run_hooks(self):\n", + " val_loss = self.validation()\n", + " if val_loss < self.best_validation_loss:\n", + " self.best_validation_loss = val_loss\n", + " self.current_patience = 0\n", + " self.save_checkpoint()\n", + " else:\n", + " self.current_patience += 1\n", + " if self.current_patience >= self.early_stop_patience:\n", + " self._trainer.save_checkpoint()\n", + " self._trainer.has_finished = True\n", + "\n", + " def validation(self):\n", + " # val_loader = self.build_test_loader(cfg=self.cfg, dataset_name=self.cfg.DATASETS.TEST[0])\n", + " val_loader = build_detection_test_loader(self.cfg, self.cfg.DATASETS.TEST[0], evaluators=[evaluator])\n", + " evaluator = COCOEvaluator(self.cfg.DATASETS.TEST[0], self.cfg, True, output_dir=\"./output/\")\n", + " val_results = self._trainer.test(self.cfg, self.model, evaluators=[evaluator])[0]\n", + " val_loss = val_results[\"total_loss\"]\n", + " return val_loss\n", + "\n", + " def save_checkpoint(self):\n", + " checkpointer = DetectionCheckpointer(self.model)\n", + " checkpointer.save(self.model_checkpoint_path)\n", + " \n", + "\n", + "def train_model(TRAIN,VALIDATION,folder):\n", " cfg = get_cfg()\n", " MODEL = 'faster_rcnn_X_101_32x8d_FPN_3x.yaml'\n", " cfg.merge_from_file(model_zoo.get_config_file(\"COCO-Detection/\"+MODEL))\n", " cfg.DATASETS.TRAIN = (TRAIN,)\n", " cfg.DATASETS.TEST = (VALIDATION,)\n", " cfg.DATALOADER.NUM_WORKERS = 2\n", - " cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-Detection/\"+MODEL) # Let training initialize from model zoo\n", + " #Uncomment if you want to use pre-trained weights for finetuning, not recommended for K fold training\n", + " # cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-Detection/\"+MODEL) # Let training initialize from model zoo\n", + " \n", + " \n", " cfg.SOLVER.IMS_PER_BATCH = 2 # This is the real \"batch size\" commonly known to deep learning people\n", - " cfg.SOLVER.BASE_LR = 0.00025 # pick a good LR\n", - " cfg.SOLVER.MAX_ITER = 6000 # \n", + " cfg.SOLVER.BASE_LR = 0.0004 # pick a good LR\n", " cfg.SOLVER.STEPS = [] # milestones where LR is reduced, in this case there's no decay\n", " cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 # The \"RoIHead batch size\". \n", " cfg.MODEL.ROI_HEADS.NUM_CLASSES = 80 \n", " cfg.TEST.EVAL_PERIOD = 500\n", " os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)\n", - " trainer = DefaultTrainer(cfg) \n", + " trainer = Early_stopping(cfg, early_stop_patience=5, model_checkpoint_path=\"model_checkpoint.pth\")\n", " trainer.resume_or_load(resume=False)\n", - " trainer.train();\n" + " trainer.train();\n", + " return cfg\n" ] }, { @@ -274,19 +335,19 @@ " result_dict = {}\n", " train_data = pairs[k][0]\n", " val_data = pairs[k][1]\n", - " train_data(train_data,val_data,\"COCO_TRAIN_\"+str(k)+\"_FOLD\")\n", + " cfg = train_model(train_data,val_data,\"COCO_TRAIN_\"+str(k)+\"_FOLD\")\n", " evaluator = COCOEvaluator(val_data, output_dir=\"output\")\n", " val_loader = build_detection_test_loader(cfg, val_data)\n", " cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, \"model_final.pth\") # path to the model we just trained\n", " cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1 # set a custom testing threshold\n", " predictor = DefaultPredictor(cfg)\n", - " dataset = json.load(open(\"../\"+pairs[k][1]+'.json','rb'))\n", - " for image in dat['images']:\n", - " im_name = os.path.join(TRAIN_PATH, i['file_name'])\n", + " dataset = json.load(open(pairs[k][1],'rb'))\n", + " for image in dataset['images']:\n", + " im_name = os.path.join(TRAIN_PATH, image['file_name'])\n", " im = cv2.imread(im_name)\n", " outputs = predictor(im)\n", - " result_dict[im_name](format_detectron2_predictions(outputs[\"instances\"].to(\"cpu\"),cfg.MODEL.ROI_HEADS.NUM_CLASSES))\n", - " pickle.dump(result_dict,open(\"results_fold_\"+str(k)+\".pkl\",'wb'))" + " result_dict[im_name] = (format_detectron2_predictions(outputs[\"instances\"].to(\"cpu\"),cfg.MODEL.ROI_HEADS.NUM_CLASSES))\n", + " pickle.dump(result_dict,open(\"train_results_fold_\"+str(k)+\".pkl\",'wb'))" ] }, { @@ -319,9 +380,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "mmdet3d", "language": "python", - "name": "python3" + "name": "mmdet3d" }, "language_info": { "codemirror_mode": { @@ -333,7 +394,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.8.17" } }, "nbformat": 4,