Skip to content

Commit

Permalink
Merge pull request #87 from cleanlab/kfold_early
Browse files Browse the repository at this point in the history
 add early stopping support for object detection example
  • Loading branch information
aditya1503 authored Jan 29, 2024
2 parents f85155b + 704a473 commit 3643335
Showing 1 changed file with 103 additions and 26 deletions.
129 changes: 103 additions & 26 deletions object_detection/detectron2_training-kfold.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from detectron2.engine import DefaultTrainer\n",
"from detectron2.config import get_cfg\n",
"import pickle\n",
"# import some common libraries\n",
"from detectron2.data import build_detection_test_loader, build_detection_train_loader\n",
"import numpy as np\n",
"import os, json, cv2, random\n",
"from detectron2.data import build_detection_test_loader\n",
Expand All @@ -52,12 +55,15 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!wget -nc \"http://images.cocodataset.org/annotations/annotations_trainval2017.zip\" && unzip -q -o annotations_trainval2017.zip\n",
"!wget -nc \"http://images.cocodataset.org/zips/val2017.zip\" && unzip -q -o val2017.zip\n",
"!wget -nc \"http://images.cocodataset.org/zips/train2017.zip\" && unzip -q -o train2017.zip"
"!wget -nc \"http://images.cocodataset.org/zips/train2017.zip\" && unzip -q -o train2017.zip\n",
"!wget -nc \"https://cleanlab-public.s3.amazonaws.com/ObjectDetectionBenchmarking/tutorial/TRAIN_COCO_ALL_labels.pkl\""
]
},
{
Expand Down Expand Up @@ -92,7 +98,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import json\n",
Expand Down Expand Up @@ -141,13 +149,26 @@
" annotations_count = len(data_dict['annotations'])\n",
" print(f\"Number of images: {images_count}, Number of annotations: {annotations_count}\")\n",
"\n",
" \n",
"def unregister_coco_instances(name):\n",
" if name in DatasetCatalog.list():\n",
" DatasetCatalog.remove(name)\n",
" MetadataCatalog.remove(name)\n",
"\n",
"# Generate K-Fold cross-validation\n",
"kf = KFold(n_splits=NUM_FOLDS)\n",
"pairs = []\n",
"for fold, (train_indices, test_indices) in enumerate(kf.split(image_ids)):\n",
" train_data, test_data = split_data(train_indices, test_indices)\n",
" train_file = f\"train_coco_{fold}_fold.json\"\n",
" test_file = f\"test_coco_{fold}_fold.json\"\n",
" # Unregister instances with the same names only if they exist\n",
" unregister_coco_instances(train_file)\n",
" unregister_coco_instances(test_file)\n",
" # Register COCO instances for training and validation. \n",
" # Note: The 'train2017' folder is retained as the base path for images.\n",
" register_coco_instances(train_file, {}, train_file, \"train2017\")\n",
" register_coco_instances(test_file, {}, test_file, \"train2017\")\n",
" pairs.append([train_file,test_file])\n",
" with open(train_file, 'w') as train_file:\n",
" json.dump(train_data, train_file)\n",
Expand All @@ -156,7 +177,9 @@
" print(f\"Data info for training data fold {fold}:\")\n",
" print_data_info(train_data, fold)\n",
" print(f\"Data info for test data fold {fold}:\")\n",
" print_data_info(test_data, fold)\n"
" print_data_info(test_data, fold)\n",
" \n",
"TRAIN_PATH = os.path.join(os.getcwd(),\"train2017\")"
]
},
{
Expand All @@ -175,36 +198,83 @@
"The number of worker threads is set to 2 and the batch size is set to 2.\n",
"The learning rate and maximum number of iterations are also specified. The model is initialized from the COCO-Detection model zoo and the output directory for the trained model is created. Finally, the configuration is passed to the DefaultTrainer class for training the object detection model.\n",
"\n",
"<strong>Note:</strong> The number of iterations was set based on [early stopping.](https://en.wikipedia.org/wiki/Early_stopping#:~:text=In%20machine%20learning%2C%20early%20stopping,training%20data%20with%20each%20iteration.)"
"<strong>Note:</strong> The choice of the number of iterations is informed by the incorporation of [early stopping.](https://en.wikipedia.org/wiki/Early_stopping#:~:text=In%20machine%20learning%2C%20early%20stopping,training%20data%20with%20each%20iteration.) This technique monitors the validation loss throughout training, saving the model upon improvement and halting training if no progress is observed within a defined patience period. Early stopping aims to identify an optimal model iteration, mitigating the risk of overfitting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"def train_data(TRAIN,VALIDATION,folder):\n",
"class Early_stopping(DefaultTrainer):\n",
" def __init__(self, cfg, early_stop_patience=5, model_checkpoint_path=\"model_checkpoint.pth\"):\n",
" super().__init__(cfg)\n",
" self.early_stop_patience = early_stop_patience\n",
" self.model_checkpoint_path = model_checkpoint_path\n",
" self.best_validation_loss = float('inf')\n",
" self.current_patience = 0\n",
"\n",
" def build_train_loader(self, cfg):\n",
" return build_detection_train_loader(cfg)\n",
" \n",
" def data_loader_mapper(self, batch):\n",
" return batch\n",
"\n",
" def run_hooks(self):\n",
" val_loss = self.validation()\n",
" if val_loss < self.best_validation_loss:\n",
" self.best_validation_loss = val_loss\n",
" self.current_patience = 0\n",
" self.save_checkpoint()\n",
" else:\n",
" self.current_patience += 1\n",
" if self.current_patience >= self.early_stop_patience:\n",
" self._trainer.save_checkpoint()\n",
" self._trainer.has_finished = True\n",
"\n",
" def validation(self):\n",
" # Define evaluator here\n",
" evaluator = COCOEvaluator(self.cfg.DATASETS.TEST[0], self.cfg, True, output_dir=\"./output/\")\n",
" val_loader = build_detection_test_loader(self.cfg, self.cfg.DATASETS.TEST[0], evaluators=[evaluator])\n",
" val_results = self._trainer.test(self.cfg, self.model, evaluators=[evaluator])[0]\n",
" val_loss = val_results[\"total_loss\"]\n",
" return val_loss\n",
"\n",
" def save_checkpoint(self):\n",
" checkpointer = DetectionCheckpointer(self.model)\n",
" checkpointer.save(self.model_checkpoint_path)\n",
" \n",
"\n",
"def train_model(TRAIN,VALIDATION,folder):\n",
" cfg = get_cfg()\n",
" MODEL = 'faster_rcnn_X_101_32x8d_FPN_3x.yaml'\n",
" cfg.merge_from_file(model_zoo.get_config_file(\"COCO-Detection/\"+MODEL))\n",
" cfg.DATASETS.TRAIN = (TRAIN,)\n",
" cfg.DATASETS.TEST = (VALIDATION,)\n",
" cfg.DATALOADER.NUM_WORKERS = 2\n",
" cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-Detection/\"+MODEL) # Let training initialize from model zoo\n",
" #Uncomment if you want to use pre-trained weights for finetuning, not recommended for K fold training\n",
" # cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-Detection/\"+MODEL) # Let training initialize from model zoo\n",
" \n",
" \n",
" cfg.SOLVER.IMS_PER_BATCH = 2 # This is the real \"batch size\" commonly known to deep learning people\n",
" cfg.SOLVER.BASE_LR = 0.00025 # pick a good LR\n",
" cfg.SOLVER.MAX_ITER = 6000 # \n",
" cfg.SOLVER.BASE_LR = 0.004 # pick a good LR\n",
" cfg.SOLVER.STEPS = [] # milestones where LR is reduced, in this case there's no decay\n",
" cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 # The \"RoIHead batch size\". \n",
" cfg.MODEL.ROI_HEADS.NUM_CLASSES = 80 \n",
" cfg.TEST.EVAL_PERIOD = 500\n",
" cfg.TEST.EVAL_PERIOD = 15000\n",
" os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)\n",
" trainer = DefaultTrainer(cfg) \n",
" trainer = Early_stopping(cfg, early_stop_patience=5, model_checkpoint_path=\"model_checkpoint.pth\")\n",
" # Specify evaluators during testing\n",
" evaluator = COCOEvaluator(cfg.DATASETS.TEST[0], cfg, True, output_dir=\"./output/\")\n",
" trainer.resume_or_load(resume=False)\n",
" trainer.test(cfg, trainer.model, evaluators=[evaluator])\n",
" trainer.resume_or_load(resume=False)\n",
" trainer.train();\n"
" trainer.train();\n",
" return cfg\n"
]
},
{
Expand All @@ -224,7 +294,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def format_detectron2_predictions(instances, num_classes):\n",
Expand Down Expand Up @@ -254,7 +326,7 @@
" formatted_results = []\n",
" for i in results:\n",
" if len(i) == 0:\n",
" formatted_array = np.array(i, dtype=np.float32).reshape((0, num_classes))\n",
" formatted_array = np.array(i, dtype=np.float32).reshape((0, 5))\n",
" else:\n",
" formatted_array = np.array(i, dtype=np.float32)\n",
" formatted_results.append(formatted_array)\n",
Expand All @@ -266,46 +338,51 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"for k in range(0,NUM_FOLDS):\n",
" result_dict = {}\n",
" train_data = pairs[k][0]\n",
" val_data = pairs[k][1]\n",
" train_data(train_data,val_data,\"COCO_TRAIN_\"+str(k)+\"_FOLD\")\n",
" cfg = train_model(train_data,val_data,\"COCO_TRAIN_\"+str(k)+\"_FOLD\")\n",
" evaluator = COCOEvaluator(val_data, output_dir=\"output\")\n",
" val_loader = build_detection_test_loader(cfg, val_data)\n",
" cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, \"model_final.pth\") # path to the model we just trained\n",
" cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1 # set a custom testing threshold\n",
" predictor = DefaultPredictor(cfg)\n",
" dataset = json.load(open(\"../\"+pairs[k][1]+'.json','rb'))\n",
" for image in dat['images']:\n",
" im_name = os.path.join(TRAIN_PATH, i['file_name'])\n",
" dataset = json.load(open(pairs[k][1],'rb'))\n",
" for image in dataset['images']:\n",
" im_name = os.path.join(TRAIN_PATH, image['file_name'])\n",
" im = cv2.imread(im_name)\n",
" outputs = predictor(im)\n",
" result_dict[im_name](format_detectron2_predictions(outputs[\"instances\"].to(\"cpu\"),cfg.MODEL.ROI_HEADS.NUM_CLASSES))\n",
" result_dict[im_name] = (format_detectron2_predictions(outputs[\"instances\"].to(\"cpu\"),cfg.MODEL.ROI_HEADS.NUM_CLASSES))\n",
" pickle.dump(result_dict,open(\"results_fold_\"+str(k)+\".pkl\",'wb'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"result_dict = {}\n",
"for k in range(0,NUM_FOLDS):\n",
" res_d = pickle.load(open(\"results_fold_\"+str(k)+'.pkl','rb'))\n",
" for r in res_d:\n",
" result_dict[r] = res_d[i]"
" result_dict[r] = res_d[r]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"dataset = pickle.load(open(\"TRAIN_COCO_ALL_labels.pkl\",'rb'))\n",
Expand Down Expand Up @@ -333,7 +410,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 3643335

Please sign in to comment.