From 51f8aee58da605ca2cb084199525fbefbc57df07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= <1286304229@qq.com> Date: Thu, 9 Nov 2023 14:27:16 +0800 Subject: [PATCH] Support LVIS chunked evaluation and image chunked inference of GLIP (#11136) --- configs/glip/README.md | 23 +- ...win-l_fpn_dyhead_pretrain_zeroshot_lvis.py | 12 + ..._fpn_dyhead_pretrain_zeroshot_mini-lvis.py | 12 + ...n-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py | 24 ++ ..._fpn_dyhead_pretrain_zeroshot_mini-lvis.py | 25 ++ ...-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py | 3 + ..._fpn_dyhead_pretrain_zeroshot_mini-lvis.py | 3 + demo/image_demo.py | 37 ++- mmdet/evaluation/functional/class_names.py | 247 +++++++++++++++++- mmdet/evaluation/metrics/lvis_metric.py | 174 +++++++++++- mmdet/models/detectors/glip.py | 232 ++++++++++++---- 11 files changed, 730 insertions(+), 62 deletions(-) create mode 100644 configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_lvis.py create mode 100644 configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_mini-lvis.py create mode 100644 configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py create mode 100644 configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_mini-lvis.py create mode 100644 configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py create mode 100644 configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_mini-lvis.py diff --git a/configs/glip/README.md b/configs/glip/README.md index 1252d922ac8..c45cb6dbb6e 100644 --- a/configs/glip/README.md +++ b/configs/glip/README.md @@ -56,7 +56,7 @@ model.save_pretrained("your path/bert-base-uncased") tokenizer.save_pretrained("your path/bert-base-uncased") ``` -## Results and Models +## COCO Results and Models | Model | Zero-shot or Finetune | COCO mAP | Official COCO mAP | Pre-Train Data | Config | Download | | :--------: | :-------------------: | :------: | ----------------: | :------------------------: | :---------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | @@ -78,3 +78,24 @@ Note: 3. Taking the GLIP-T(A) model as an example, I trained it twice using the official code, and the fine-tuning mAP were 52.5 and 52.6. Therefore, the mAP we achieved in our reproduction is higher than the official results. The main reason is that we modified the `weight_decay` parameter. 4. Our experiments revealed that training for 24 epochs leads to overfitting. Therefore, we chose the best-performing model. If users want to train on a custom dataset, it is advisable to shorten the number of epochs and save the best-performing model. 5. Due to the official absence of fine-tuning hyperparameters for the GLIP-L model, we have not yet reproduced the official accuracy. I have found that overfitting can also occur, so it may be necessary to consider custom modifications to data augmentation and model enhancement. Given the high cost of training, we have not conducted any research on this matter at the moment. + +## LVIS Results + +| Model | Official | MiniVal APr | MiniVal APc | MiniVal APf | MiniVal AP | Val1.0 APr | Val1.0 APc | Val1.0 APf | Val1.0 AP | Pre-Train Data | Config | Download | +| :--------: | :------: | :---------: | :---------: | :---------: | :--------: | :--------: | :--------: | :--------: | :-------: | :------------------------: | :---------------------------------------------------------------------: | :------------------------------------------------------------------------------------------: | +| GLIP-T (A) | ✔ | | | | | | | | | O365 | [config](lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth) | +| GLIP-T (A) | | 12.1 | 15.5 | 25.8 | 20.2 | 6.2 | 10.9 | 22.8 | 14.7 | O365 | [config](lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth) | +| GLIP-T (B) | ✔ | | | | | | | | | O365 | [config](lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_b_mmdet-6dfbd102.pth) | +| GLIP-T (B) | | 8.6 | 13.9 | 26.0 | 19.3 | 4.6 | 9.8 | 22.6 | 13.9 | O365 | [config](lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_b_mmdet-6dfbd102.pth) | +| GLIP-T (C) | ✔ | 14.3 | 19.4 | 31.1 | 24.6 | | | | | O365,GoldG | [config](lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_c_mmdet-2fc427dd.pth) | +| GLIP-T (C) | | 14.4 | 19.8 | 31.9 | 25.2 | 8.3 | 13.2 | 28.1 | 18.2 | O365,GoldG | [config](lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_c_mmdet-2fc427dd.pth) | +| GLIP-T | ✔ | | | | | | | | | O365,GoldG,CC3M,SBU | [config](lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_mmdet-c24ce662.pth) | +| GLIP-T | | 18.1 | 21.2 | 33.1 | 26.7 | 10.8 | 14.7 | 29.0 | 19.6 | O365,GoldG,CC3M,SBU | [config](lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_mmdet-c24ce662.pth) | +| GLIP-L | ✔ | 29.2 | 34.9 | 42.1 | 37.9 | | | | | FourODs,GoldG,CC3M+12M,SBU | [config](lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_l_mmdet-abfe026b.pth) | +| GLIP-L | | 27.9 | 33.7 | 39.7 | 36.1 | | | | | FourODs,GoldG,CC3M+12M,SBU | [config](lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_l_mmdet-abfe026b.pth) | + +Note: + +1. The above are zero-shot evaluation results. +2. The evaluation metric we used is LVIS FixAP. For specific details, please refer to [Evaluating Large-Vocabulary Object Detectors: The Devil is in the Details](https://arxiv.org/pdf/2102.01066.pdf). +3. We found that the performance on small models is better than the official results, but it is lower on large models. This is mainly due to the incomplete alignment of the GLIP post-processing. diff --git a/configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_lvis.py b/configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_lvis.py new file mode 100644 index 00000000000..1f79e447d3f --- /dev/null +++ b/configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_lvis.py @@ -0,0 +1,12 @@ +_base_ = './glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py' + +model = dict( + backbone=dict( + embed_dims=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=12, + drop_path_rate=0.4, + ), + neck=dict(in_channels=[384, 768, 1536]), + bbox_head=dict(early_fuse=True, num_dyhead_blocks=8)) diff --git a/configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_mini-lvis.py b/configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_mini-lvis.py new file mode 100644 index 00000000000..13f1a69082b --- /dev/null +++ b/configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_mini-lvis.py @@ -0,0 +1,12 @@ +_base_ = './glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_mini-lvis.py' + +model = dict( + backbone=dict( + embed_dims=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=12, + drop_path_rate=0.4, + ), + neck=dict(in_channels=[384, 768, 1536]), + bbox_head=dict(early_fuse=True, num_dyhead_blocks=8)) diff --git a/configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py b/configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py new file mode 100644 index 00000000000..4d526d59008 --- /dev/null +++ b/configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py @@ -0,0 +1,24 @@ +_base_ = '../glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py' + +model = dict(test_cfg=dict( + max_per_img=300, + chunked_size=40, +)) + +dataset_type = 'LVISV1Dataset' +data_root = 'data/coco/' + +val_dataloader = dict( + dataset=dict( + data_root=data_root, + type=dataset_type, + ann_file='annotations/lvis_od_val.json', + data_prefix=dict(img=''))) +test_dataloader = val_dataloader + +# numpy < 1.24.0 +val_evaluator = dict( + _delete_=True, + type='LVISFixedAPMetric', + ann_file=data_root + 'annotations/lvis_od_val.json') +test_evaluator = val_evaluator diff --git a/configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_mini-lvis.py b/configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_mini-lvis.py new file mode 100644 index 00000000000..70a57a3f581 --- /dev/null +++ b/configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_mini-lvis.py @@ -0,0 +1,25 @@ +_base_ = '../glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py' + +model = dict(test_cfg=dict( + max_per_img=300, + chunked_size=40, +)) + +dataset_type = 'LVISV1Dataset' +data_root = 'data/coco/' + +val_dataloader = dict( + dataset=dict( + data_root=data_root, + type=dataset_type, + ann_file='annotations/lvis_v1_minival_inserted_image_name.json', + data_prefix=dict(img=''))) +test_dataloader = val_dataloader + +# numpy < 1.24.0 +val_evaluator = dict( + _delete_=True, + type='LVISFixedAPMetric', + ann_file=data_root + + 'annotations/lvis_v1_minival_inserted_image_name.json') +test_evaluator = val_evaluator diff --git a/configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py b/configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py new file mode 100644 index 00000000000..6dc712b3bcb --- /dev/null +++ b/configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py @@ -0,0 +1,3 @@ +_base_ = './glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py' + +model = dict(bbox_head=dict(early_fuse=True)) diff --git a/configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_mini-lvis.py b/configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_mini-lvis.py new file mode 100644 index 00000000000..3babb91101a --- /dev/null +++ b/configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_mini-lvis.py @@ -0,0 +1,3 @@ +_base_ = './glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_mini-lvis.py' + +model = dict(bbox_head=dict(early_fuse=True)) diff --git a/demo/image_demo.py b/demo/image_demo.py index 2e2c27adbf2..5a9c906cef0 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -28,6 +28,16 @@ glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365 \ --texts 'There are a lot of cars here.' + python demo/image_demo.py demo/demo.jpg \ + glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365 \ + --texts '$: coco' + + python demo/image_demo.py demo/demo.jpg \ + glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365 \ + --texts '$: lvis' --pred-score-thr 0.7 \ + --palette random --chunked-size 80 + + Visualize prediction results:: python demo/image_demo.py demo/demo.jpg rtmdet-ins-s --show @@ -41,6 +51,7 @@ from mmengine.logging import print_log from mmdet.apis import DetInferencer +from mmdet.evaluation import get_classes def parse_args(): @@ -60,7 +71,12 @@ def parse_args(): type=str, default='outputs', help='Output directory of images or prediction results.') - parser.add_argument('--texts', help='text prompt') + # Once you input a format similar to $: xxx, it indicates that + # the prompt is based on the dataset class name. + # support $: coco, $: voc, $: cityscapes, $: lvis, $: imagenet_det. + # detail to `mmdet/evaluation/functional/class_names.py` + parser.add_argument( + '--texts', help='text prompt, such as "bench . car .", "$: coco"') parser.add_argument( '--device', default='cuda:0', help='Device used for inference') parser.add_argument( @@ -91,7 +107,7 @@ def parse_args(): default='none', choices=['coco', 'voc', 'citys', 'random', 'none'], help='Color palette used for visualization') - # only for GLIP + # only for GLIP and Grounding DINO parser.add_argument( '--custom-entities', '-c', @@ -99,6 +115,13 @@ def parse_args(): help='Whether to customize entity names? ' 'If so, the input text should be ' '"cls_name1 . cls_name2 . cls_name3 ." format') + parser.add_argument( + '--chunked-size', + '-s', + type=int, + default=-1, + help='If the number of categories is very large, ' + 'you can specify this parameter to truncate multiple predictions.') call_args = vars(parser.parse_args()) @@ -111,6 +134,12 @@ def parse_args(): call_args['weights'] = call_args['model'] call_args['model'] = None + if call_args['texts'] is not None: + if call_args['texts'].startswith('$:'): + dataset_name = call_args['texts'][3:].strip() + class_names = get_classes(dataset_name) + call_args['texts'] = [tuple(class_names)] + init_kws = ['model', 'weights', 'device', 'palette'] init_args = {} for init_kw in init_kws: @@ -125,6 +154,10 @@ def main(): # may consume too much memory if your input folder has a lot of images. # We will be optimized later. inferencer = DetInferencer(**init_args) + + chunked_size = call_args.pop('chunked_size') + inferencer.model.test_cfg.chunked_size = chunked_size + inferencer(**call_args) if call_args['out_dir'] != '' and not (call_args['no_save_vis'] diff --git a/mmdet/evaluation/functional/class_names.py b/mmdet/evaluation/functional/class_names.py index d0ea7094685..623a89cfdc0 100644 --- a/mmdet/evaluation/functional/class_names.py +++ b/mmdet/evaluation/functional/class_names.py @@ -485,6 +485,250 @@ def objects365v2_classes() -> list: ] +def lvis_classes() -> list: + """Class names of LVIS.""" + return [ + 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol', + 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna', + 'apple', 'applesauce', 'apricot', 'apron', 'aquarium', + 'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor', + 'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer', + 'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy', + 'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel', + 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon', + 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo', + 'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow', + 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap', + 'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)', + 'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)', + 'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie', + 'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper', + 'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', + 'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', + 'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath', + 'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card', + 'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket', + 'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry', + 'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg', + 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase', + 'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle', + 'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', + 'bow-tie', 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'box', + 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere', + 'bread-bin', 'bread', 'breechcloth', 'bridal_gown', 'briefcase', + 'broccoli', 'broach', 'broom', 'brownie', 'brussels_sprouts', + 'bubble_gum', 'bucket', 'horse_buggy', 'bull', 'bulldog', 'bulldozer', + 'bullet_train', 'bulletin_board', 'bulletproof_vest', 'bullhorn', + 'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', + 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', + 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf', + 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', + 'can', 'can_opener', 'candle', 'candle_holder', 'candy_bar', + 'candy_cane', 'walking_cane', 'canister', 'canoe', 'cantaloup', + 'canteen', 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino', + 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car', + 'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship', + 'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton', + 'cash_register', 'casserole', 'cassette', 'cast', 'cat', 'cauliflower', + 'cayenne_(spice)', 'CD_player', 'celery', 'cellular_telephone', + 'chain_mail', 'chair', 'chaise_longue', 'chalice', 'chandelier', + 'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard', + 'chicken_(animal)', 'chickpea', 'chili_(vegetable)', 'chime', + 'chinaware', 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar', + 'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker', + 'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider', + 'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet', + 'clasp', 'cleansing_agent', 'cleat_(for_securing_rope)', 'clementine', + 'clip', 'clipboard', 'clippers_(for_plants)', 'cloak', 'clock', + 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', + 'coat', 'coat_hanger', 'coatrack', 'cock', 'cockroach', + 'cocoa_(beverage)', 'coconut', 'coffee_maker', 'coffee_table', + 'coffeepot', 'coil', 'coin', 'colander', 'coleslaw', + 'coloring_material', 'combination_lock', 'pacifier', 'comic_book', + 'compass', 'computer_keyboard', 'condiment', 'cone', 'control', + 'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie', + 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)', + 'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet', + 'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall', + 'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker', + 'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib', + 'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown', + 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch', + 'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup', + 'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain', + 'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard', + 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk', + 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux', + 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher', + 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup', + 'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin', + 'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly', + 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit', + 'dresser', 'drill', 'drone', 'dropper', 'drum_(musical_instrument)', + 'drumstick', 'duck', 'duckling', 'duct_tape', 'duffel_bag', 'dumbbell', + 'dumpster', 'dustpan', 'eagle', 'earphone', 'earplug', 'earring', + 'easel', 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater', + 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk', + 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan', + 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)', + 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm', + 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace', + 'fireplug', 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', + 'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flap', + 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)', + 'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal', + 'folding_chair', 'food_processor', 'football_(American)', + 'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car', + 'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice', + 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage', + 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic', + 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'generator', + 'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture', + 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles', + 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose', + 'gorilla', 'gourd', 'grape', 'grater', 'gravestone', 'gravy_boat', + 'green_bean', 'green_onion', 'griddle', 'grill', 'grits', 'grizzly', + 'grocery_bag', 'guitar', 'gull', 'gun', 'hairbrush', 'hairnet', + 'hairpin', 'halter_top', 'ham', 'hamburger', 'hammer', 'hammock', + 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel', + 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw', + 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil', 'headband', + 'headboard', 'headlight', 'headscarf', 'headset', + 'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet', + 'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog', + 'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah', + 'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce', + 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear', + 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate', + 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board', + 'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey', + 'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak', + 'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono', + 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit', + 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)', + 'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', + 'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard', + 'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather', + 'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade', 'lettuce', + 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb', + 'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor', + 'lizard', 'log', 'lollipop', 'speaker_(stereo_equipment)', 'loveseat', + 'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)', + 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange', 'manger', + 'manhole', 'map', 'marker', 'martini', 'mascot', 'mashed_potato', + 'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox', + 'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine', + 'melon', 'microphone', 'microscope', 'microwave_oven', 'milestone', + 'milk', 'milk_can', 'milkshake', 'minivan', 'mint_candy', 'mirror', + 'mitten', 'mixer_(kitchen_tool)', 'money', + 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor', + 'motor_scooter', 'motor_vehicle', 'motorcycle', 'mound_(baseball)', + 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom', + 'music_stool', 'musical_instrument', 'nailfile', 'napkin', + 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newspaper', + 'newsstand', 'nightshirt', 'nosebag_(for_animals)', + 'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker', + 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil', + 'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'ostrich', + 'ottoman', 'oven', 'overalls_(clothing)', 'owl', 'packet', 'inkpad', + 'pad', 'paddle', 'padlock', 'paintbrush', 'painting', 'pajamas', + 'palette', 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', + 'pantyhose', 'papaya', 'paper_plate', 'paper_towel', 'paperback_book', + 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', 'parasol', + 'parchment', 'parka', 'parking_meter', 'parrot', + 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport', + 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter', + 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg', + 'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box', + 'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)', + 'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet', + 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano', + 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow', + 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball', + 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)', + 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat', + 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)', + 'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)', + 'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)', + 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato', + 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'pretzel', + 'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune', + 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', + 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', + 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish', + 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat', + 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt', + 'recliner', 'record_player', 'reflector', 'remote_control', + 'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map', + 'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade', + 'rolling_pin', 'root_beer', 'router_(computer_equipment)', + 'rubber_band', 'runner_(carpet)', 'plastic_bag', + 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin', + 'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)', + 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)', + 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse', + 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf', + 'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver', + 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane', + 'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark', + 'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl', + 'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt', + 'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass', + 'shoulder_bag', 'shovel', 'shower_head', 'shower_cap', + 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink', + 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole', + 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)', + 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman', + 'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball', + 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon', + 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)', + 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish', + 'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)', + 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish', + 'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel', + 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer', + 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove', 'strainer', + 'strap', 'straw_(for_drinking)', 'strawberry', 'street_sign', + 'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl', + 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', 'sunglasses', + 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants', 'sweatband', + 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit', 'sword', + 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table', + 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight', + 'tambourine', 'army_tank', 'tank_(storage_vessel)', + 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure', + 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup', + 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth', + 'telephone_pole', 'telephoto_lens', 'television_camera', + 'television_set', 'tennis_ball', 'tennis_racket', 'tequila', + 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread', + 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil', + 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven', + 'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush', + 'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel', + 'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light', + 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', + 'tray', 'trench_coat', 'triangle_(musical_instrument)', 'tricycle', + 'tripod', 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', + 'turban', 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)', + 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn', + 'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest', + 'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture', + 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick', + 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe', + 'washbasin', 'automatic_washer', 'watch', 'water_bottle', + 'water_cooler', 'water_faucet', 'water_heater', 'water_jug', + 'water_gun', 'water_scooter', 'water_ski', 'water_tower', + 'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake', + 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream', + 'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)', + 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket', + 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon', + 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt', + 'yoke_(animal_equipment)', 'zebra', 'zucchini' + ] + + dataset_aliases = { 'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'], 'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'], @@ -496,7 +740,8 @@ def objects365v2_classes() -> list: 'oid_challenge': ['oid_challenge', 'openimages_challenge'], 'oid_v6': ['oid_v6', 'openimages_v6'], 'objects365v1': ['objects365v1', 'obj365v1'], - 'objects365v2': ['objects365v2', 'obj365v2'] + 'objects365v2': ['objects365v2', 'obj365v2'], + 'lvis': ['lvis', 'lvis_v1'], } diff --git a/mmdet/evaluation/metrics/lvis_metric.py b/mmdet/evaluation/metrics/lvis_metric.py index e4dd6141c0e..a861c6ee7b4 100644 --- a/mmdet/evaluation/metrics/lvis_metric.py +++ b/mmdet/evaluation/metrics/lvis_metric.py @@ -1,14 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. import itertools +import logging import os.path as osp import tempfile import warnings -from collections import OrderedDict +from collections import OrderedDict, defaultdict from typing import Dict, List, Optional, Sequence, Union import numpy as np +import torch +from mmengine.dist import (all_gather_object, broadcast_object_list, + is_main_process) +from mmengine.evaluator import BaseMetric +from mmengine.evaluator.metric import _to_cpu from mmengine.fileio import get_local_path -from mmengine.logging import MMLogger +from mmengine.logging import MMLogger, print_log from terminaltables import AsciiTable from mmdet.registry import METRICS @@ -18,6 +24,7 @@ try: import lvis + if getattr(lvis, '__version__', '0') >= '10.5.3': warnings.warn( 'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501 @@ -362,3 +369,166 @@ def compute_metrics(self, results: list) -> Dict[str, float]: if tmp_dir is not None: tmp_dir.cleanup() return eval_results + + +def _merge_lists(listA, listB, maxN, key): + result = [] + indA, indB = 0, 0 + while (indA < len(listA) or indB < len(listB)) and len(result) < maxN: + if (indB < len(listB)) and (indA >= len(listA) + or key(listA[indA]) < key(listB[indB])): + result.append(listB[indB]) + indB += 1 + else: + result.append(listA[indA]) + indA += 1 + return result + + +@METRICS.register_module() +class LVISFixedAPMetric(BaseMetric): + default_prefix: Optional[str] = 'lvis_fixed_ap' + + def __init__(self, + ann_file: str, + topk: int = 10000, + format_only: bool = False, + outfile_prefix: Optional[str] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None, + backend_args: dict = None) -> None: + + if lvis is None: + raise RuntimeError( + 'Package lvis is not installed. Please run "pip install ' + 'git+https://github.com/lvis-dataset/lvis-api.git".') + super().__init__(collect_device=collect_device, prefix=prefix) + + self.format_only = format_only + if self.format_only: + assert outfile_prefix is not None, 'outfile_prefix must be not' + 'None when format_only is True, otherwise the result files will' + 'be saved to a temp directory which will be cleaned up at the end.' + + self.outfile_prefix = outfile_prefix + self.backend_args = backend_args + + with get_local_path( + ann_file, backend_args=self.backend_args) as local_path: + self._lvis_api = LVIS(local_path) + + self.cat_ids = self._lvis_api.get_cat_ids() + + self.results = {} + self.topk = topk + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. + """ + cur_results = [] + for data_sample in data_samples: + pred = data_sample['pred_instances'] + xmin, ymin, xmax, ymax = pred['bboxes'].cpu().unbind(1) + boxes = torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), + dim=1).tolist() + + scores = pred['scores'].cpu().numpy() + labels = pred['labels'].cpu().numpy() + + if len(boxes) == 0: + continue + + cur_results.extend([{ + 'image_id': data_sample['img_id'], + 'category_id': self.cat_ids[labels[k]], + 'bbox': box, + 'score': scores[k], + } for k, box in enumerate(boxes)]) + + by_cat = defaultdict(list) + for ann in cur_results: + by_cat[ann['category_id']].append(ann) + + for cat, cat_anns in by_cat.items(): + if cat not in self.results: + self.results[cat] = [] + + cur = sorted( + cat_anns, key=lambda x: x['score'], reverse=True)[:self.topk] + self.results[cat] = _merge_lists( + self.results[cat], cur, self.topk, key=lambda x: x['score']) + + def compute_metrics(self, results: dict) -> dict: + logger: MMLogger = MMLogger.get_current_instance() + + new_results = [] + + missing_dets_cats = set() + for cat, cat_anns in results.items(): + if len(cat_anns) < self.topk: + missing_dets_cats.add(cat) + new_results.extend( + sorted(cat_anns, key=lambda x: x['score'], + reverse=True)[:self.topk]) + + if missing_dets_cats: + logger.info( + f'\n===\n' + f'{len(missing_dets_cats)} classes had less than {self.topk} ' + f'detections!\n Outputting {self.topk} detections for each ' + f'class will improve AP further.\n ===') + + new_results = LVISResults(self._lvis_api, new_results, max_dets=-1) + lvis_eval = LVISEval(self._lvis_api, new_results, iou_type='bbox') + params = lvis_eval.params + params.max_dets = -1 # No limit on detections per image. + lvis_eval.run() + lvis_eval.print_results() + metrics = { + k: v + for k, v in lvis_eval.results.items() if k.startswith('AP') + } + logger.info(f'mAP_copypaste: {metrics}') + return metrics + + def evaluate(self, size: int) -> dict: + if len(self.results) == 0: + print_log( + f'{self.__class__.__name__} got empty `self.results`. Please ' + 'ensure that the processed results are properly added into ' + '`self.results` in `process` method.', + logger='current', + level=logging.WARNING) + + all_cats = all_gather_object(self.results) + results = defaultdict(list) + for cats in all_cats: + for cat, cat_anns in cats.items(): + results[cat].extend(cat_anns) + + if is_main_process(): + # cast all tensors in results list to cpu + results = _to_cpu(results) + _metrics = self.compute_metrics(results) # type: ignore + # Add prefix to metric names + if self.prefix: + _metrics = { + '/'.join((self.prefix, k)): v + for k, v in _metrics.items() + } + metrics = [_metrics] + else: + metrics = [None] # type: ignore + + broadcast_object_list(metrics) + + # reset the results + self.results = {} + return metrics[0] diff --git a/mmdet/models/detectors/glip.py b/mmdet/models/detectors/glip.py index 13cfea960a8..4011e73d09f 100644 --- a/mmdet/models/detectors/glip.py +++ b/mmdet/models/detectors/glip.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import re import warnings from typing import Optional, Tuple, Union @@ -166,6 +167,27 @@ def create_positive_map_label_to_token(positive_map: Tensor, return positive_map_label_to_token +def clean_label_name(name: str) -> str: + name = re.sub(r'\(.*\)', '', name) + name = re.sub(r'_', ' ', name) + name = re.sub(r' ', ' ', name) + return name + + +def chunks(lst: list, n: int) -> list: + """Yield successive n-sized chunks from lst.""" + all_ = [] + for i in range(0, len(lst), n): + data_index = lst[i:i + n] + all_.append(data_index) + counter = 0 + for i in all_: + counter += len(i) + assert (counter == len(lst)) + + return all_ + + @MODELS.register_module() class GLIP(SingleStageDetector): """Implementation of `GLIP `_ @@ -207,6 +229,46 @@ def __init__(self, self._special_tokens = '. ' + def to_enhance_text_prompts(self, original_caption, enhanced_text_prompts): + caption_string = '' + tokens_positive = [] + for idx, word in enumerate(original_caption): + if word in enhanced_text_prompts: + enhanced_text_dict = enhanced_text_prompts[word] + if 'prefix' in enhanced_text_dict: + caption_string += enhanced_text_dict['prefix'] + start_i = len(caption_string) + if 'name' in enhanced_text_dict: + caption_string += enhanced_text_dict['name'] + else: + caption_string += word + end_i = len(caption_string) + tokens_positive.append([[start_i, end_i]]) + + if 'suffix' in enhanced_text_dict: + caption_string += enhanced_text_dict['suffix'] + else: + tokens_positive.append( + [[len(caption_string), + len(caption_string) + len(word)]]) + caption_string += word + + if idx != len(original_caption) - 1: + caption_string += self._special_tokens + return caption_string, tokens_positive + + def to_plain_text_prompts(self, original_caption): + caption_string = '' + tokens_positive = [] + for idx, word in enumerate(original_caption): + tokens_positive.append( + [[len(caption_string), + len(caption_string) + len(word)]]) + caption_string += word + if idx != len(original_caption) - 1: + caption_string += self._special_tokens + return caption_string, tokens_positive + def get_tokens_and_prompts( self, original_caption: Union[str, list, tuple], @@ -221,44 +283,14 @@ def get_tokens_and_prompts( original_caption = list( filter(lambda x: len(x) > 0, original_caption)) + original_caption = [clean_label_name(i) for i in original_caption] + if custom_entities and enhanced_text_prompts is not None: - caption_string = '' - tokens_positive = [] - for idx, word in enumerate(original_caption): - if word in enhanced_text_prompts: - enhanced_text_dict = enhanced_text_prompts[word] - if 'prefix' in enhanced_text_dict: - caption_string += enhanced_text_dict['prefix'] - start_i = len(caption_string) - if 'name' in enhanced_text_dict: - caption_string += enhanced_text_dict['name'] - else: - caption_string += word - end_i = len(caption_string) - tokens_positive.append([[start_i, end_i]]) - - if 'suffix' in enhanced_text_dict: - caption_string += enhanced_text_dict['suffix'] - else: - tokens_positive.append([[ - len(caption_string), - len(caption_string) + len(word) - ]]) - caption_string += word - - if idx != len(original_caption) - 1: - caption_string += self._special_tokens + caption_string, tokens_positive = self.to_enhance_text_prompts( + original_caption, enhanced_text_prompts) else: - caption_string = '' - tokens_positive = [] - for idx, word in enumerate(original_caption): - tokens_positive.append([[ - len(caption_string), - len(caption_string) + len(word) - ]]) - caption_string += word - if idx != len(original_caption) - 1: - caption_string += self._special_tokens + caption_string, tokens_positive = self.to_plain_text_prompts( + original_caption) tokenized = self.language_model.tokenizer([caption_string], return_tensors='pt') @@ -285,14 +317,73 @@ def get_tokens_positive_and_prompts( custom_entities: bool = False, enhanced_text_prompt: Optional[ConfigType] = None ) -> Tuple[dict, str, Tensor, list]: - tokenized, caption_string, tokens_positive, entities = \ - self.get_tokens_and_prompts( - original_caption, custom_entities, enhanced_text_prompt) - positive_map_label_to_token, positive_map = self.get_positive_map( - tokenized, tokens_positive) + chunked_size = self.test_cfg.get('chunked_size', -1) + if not self.training and chunked_size > 0: + assert isinstance(original_caption, + (list, tuple)) or custom_entities is True + all_output = self.get_tokens_positive_and_prompts_chunked( + original_caption, enhanced_text_prompt) + positive_map_label_to_token, \ + caption_string, \ + positive_map, \ + entities = all_output + else: + tokenized, caption_string, tokens_positive, entities = \ + self.get_tokens_and_prompts( + original_caption, custom_entities, enhanced_text_prompt) + positive_map_label_to_token, positive_map = self.get_positive_map( + tokenized, tokens_positive) + if tokenized.input_ids.shape[1] > self.language_model.max_tokens: + warnings.warn('Inputting a text that is too long will result ' + 'in poor prediction performance. ' + 'Please reduce the text length.') return positive_map_label_to_token, caption_string, \ positive_map, entities + def get_tokens_positive_and_prompts_chunked( + self, + original_caption: Union[list, tuple], + enhanced_text_prompts: Optional[ConfigType] = None): + chunked_size = self.test_cfg.get('chunked_size', -1) + original_caption = [clean_label_name(i) for i in original_caption] + + original_caption_chunked = chunks(original_caption, chunked_size) + ids_chunked = chunks( + list(range(1, + len(original_caption) + 1)), chunked_size) + + positive_map_label_to_token_chunked = [] + caption_string_chunked = [] + positive_map_chunked = [] + entities_chunked = [] + + for i in range(len(ids_chunked)): + if enhanced_text_prompts is not None: + caption_string, tokens_positive = self.to_enhance_text_prompts( + original_caption_chunked[i], enhanced_text_prompts) + else: + caption_string, tokens_positive = self.to_plain_text_prompts( + original_caption_chunked[i]) + tokenized = self.language_model.tokenizer([caption_string], + return_tensors='pt') + if tokenized.input_ids.shape[1] > self.language_model.max_tokens: + warnings.warn('Inputting a text that is too long will result ' + 'in poor prediction performance. ' + 'Please reduce the --chunked-size.') + positive_map_label_to_token, positive_map = self.get_positive_map( + tokenized, tokens_positive) + + caption_string_chunked.append(caption_string) + positive_map_label_to_token_chunked.append( + positive_map_label_to_token) + positive_map_chunked.append(positive_map) + entities_chunked.append(original_caption_chunked[i]) + + return positive_map_label_to_token_chunked, \ + caption_string_chunked, \ + positive_map_chunked, \ + entities_chunked + def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> Union[dict, list]: # TODO: Only open vocabulary tasks are supported for training now. @@ -376,12 +467,14 @@ def predict(self, - bboxes (Tensor): Has a shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). """ - text_prompts = [ - data_samples.text for data_samples in batch_data_samples - ] - enhanced_text_prompts = [ - data_samples.caption_prompt for data_samples in batch_data_samples - ] + text_prompts = [] + enhanced_text_prompts = [] + for data_samples in batch_data_samples: + text_prompts.append(data_samples.text) + if 'caption_prompt' in data_samples: + enhanced_text_prompts.append(data_samples.caption_prompt) + else: + enhanced_text_prompts.append(None) if 'custom_entities' in batch_data_samples[0]: # Assuming that the `custom_entities` flag @@ -409,18 +502,45 @@ def predict(self, token_positive_maps, text_prompts, _, entities = zip( *_positive_maps_and_prompts) - language_dict_features = self.language_model(list(text_prompts)) + visual_features = self.extract_feat(batch_inputs) - for i, data_samples in enumerate(batch_data_samples): - data_samples.token_positive_map = token_positive_maps[i] + if isinstance(text_prompts[0], list): + # chunked text prompts, only bs=1 is supported + assert len(batch_inputs) == 1 + count = 0 + results_list = [] + + entities = [[item for lst in entities[0] for item in lst]] + + for b in range(len(text_prompts[0])): + text_prompts_once = [text_prompts[0][b]] + token_positive_maps_once = token_positive_maps[0][b] + language_dict_features = self.language_model(text_prompts_once) + batch_data_samples[ + 0].token_positive_map = token_positive_maps_once + + pred_instances = self.bbox_head.predict( + copy.deepcopy(visual_features), + language_dict_features, + batch_data_samples, + rescale=rescale)[0] + + if len(pred_instances) > 0: + pred_instances.labels += count + count += len(token_positive_maps_once) + results_list.append(pred_instances) + results_list = [results_list[0].cat(results_list)] + else: + language_dict_features = self.language_model(list(text_prompts)) - visual_features = self.extract_feat(batch_inputs) + for i, data_samples in enumerate(batch_data_samples): + data_samples.token_positive_map = token_positive_maps[i] - results_list = self.bbox_head.predict( - visual_features, - language_dict_features, - batch_data_samples, - rescale=rescale) + results_list = self.bbox_head.predict( + visual_features, + language_dict_features, + batch_data_samples, + rescale=rescale) for data_sample, pred_instances, entity in zip(batch_data_samples, results_list, entities):