From 44419a4ae6b0f1fb18d84632104c5811d9b77ec0 Mon Sep 17 00:00:00 2001 From: 514flowey <1114811901@qq.com> Date: Sun, 20 Feb 2022 16:27:11 +0800 Subject: [PATCH 01/10] add HRSC2016Dataset --- .../preprocess/hrsc2016_preprocess_config.py | 26 +++ configs/s2anet/s2anet_r50_fpn_1x_hrsc2016.py | 182 ++++++++++++++++++ python/jdet/config/constant.py | 3 + python/jdet/data/HRSC2016.py | 41 ++++ python/jdet/data/__init__.py | 1 + .../jdet/data/devkits/conver_hrsc_to_mmdet.py | 85 ++++++++ python/jdet/data/devkits/hrsc_to_dota.py | 43 +++++ python/jdet/data/dota.py | 5 +- tools/preprocess.py | 27 +++ 9 files changed, 411 insertions(+), 2 deletions(-) create mode 100644 configs/preprocess/hrsc2016_preprocess_config.py create mode 100644 configs/s2anet/s2anet_r50_fpn_1x_hrsc2016.py create mode 100644 python/jdet/data/HRSC2016.py create mode 100644 python/jdet/data/devkits/conver_hrsc_to_mmdet.py create mode 100644 python/jdet/data/devkits/hrsc_to_dota.py diff --git a/configs/preprocess/hrsc2016_preprocess_config.py b/configs/preprocess/hrsc2016_preprocess_config.py new file mode 100644 index 00000000..470eed10 --- /dev/null +++ b/configs/preprocess/hrsc2016_preprocess_config.py @@ -0,0 +1,26 @@ +from numpy import source + + +type='HRSC2016' +source_dataset_path='/mnt/disk/flowey/dataset/HRSC2016' + +tasks=[ + dict( + label='train', + config=dict( + images_path=source_dataset_path+'/Train/AllImages', + xml_path=source_dataset_path+'/Train/Annotations', + imageset_file=source_dataset_path+'/Train/train.txt', + out_annotation_file=source_dataset_path+'/Train/labels.pkl', + ) + ), + dict( + label='test', + config=dict( + images_path=source_dataset_path+'/Test/AllImages', + xml_path=source_dataset_path+'/Test/Annotations', + imageset_file=source_dataset_path+'/Test/test.txt', + out_annotation_file=source_dataset_path+'/Test/labels.pkl', + ) + ) +] diff --git a/configs/s2anet/s2anet_r50_fpn_1x_hrsc2016.py b/configs/s2anet/s2anet_r50_fpn_1x_hrsc2016.py new file mode 100644 index 
00000000..dc91b156 --- /dev/null +++ b/configs/s2anet/s2anet_r50_fpn_1x_hrsc2016.py @@ -0,0 +1,182 @@ +# model settings +model = dict( + type='S2ANet', + backbone=dict( + type='Resnet50', + frozen_stages=1, + return_stages=["layer1","layer2","layer3","layer4"], + pretrained= True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs="on_input", + num_outs=5), + bbox_head=dict( + type='S2ANetHead', + num_classes=2, + in_channels=256, + feat_channels=256, + stacked_convs=2, + with_orconv=True, + anchor_ratios=[1.0], + anchor_strides=[8, 16, 32, 64, 128], + anchor_scales=[4], + target_means=[.0, .0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0, 1.0], + loss_fam_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_fam_bbox=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), + loss_odm_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_odm_bbox=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms_rotated', iou_thr=0.1), + max_per_img=2000), + train_cfg=dict( + fam_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1, + iou_calculator=dict(type='BboxOverlaps2D_rotated')), + bbox_coder=dict(type='DeltaXYWHABBoxCoder', + target_means=(0., 0., 0., 0., 0.), + target_stds=(1., 1., 1., 1., 1.), + clip_border=True), + allowed_border=-1, + pos_weight=-1, + debug=False), + odm_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1, + iou_calculator=dict(type='BboxOverlaps2D_rotated')), + bbox_coder=dict(type='DeltaXYWHABBoxCoder', + target_means=(0., 0., 0., 0., 0.), + target_stds=(1., 1., 1., 1., 1.), + clip_border=True), + allowed_border=-1, + 
pos_weight=-1, + debug=False)) + ) + ) +base_data_root = '/mnt/disk/flowey/dataset/HRSC2016' +dataset = dict( + train=dict( + type="HRSC2016Dataset", + images_dir = base_data_root + '/Train/AllImages', + annotations_file = base_data_root + '/Train/labels.pkl', + transforms=[ + dict( + type="RotatedResize", + min_size=512, + max_size=800 + ), + dict(type='RotatedRandomFlip', prob=0.5), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False,) + + ], + batch_size=2, + num_workers=4, + shuffle=True, + filter_empty_gt=False + ), + val=dict( + type="HRSC2016Dataset", + images_dir = base_data_root + '/Test/AllImages', + annotations_file = base_data_root + '/Train/labels.pkl', + transforms=[ + dict( + type="RotatedResize", + min_size=512, + max_size=800 + ), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False), + ], + batch_size=2, + num_workers=4, + shuffle=False + ), + test=dict( + type="HRSC2016Dataset", + images_dir = base_data_root + '/Test/AllImages', + annotations_file = base_data_root + '/Train/labels.pkl', + transforms=[ + dict( + type="RotatedResize", + min_size=512, + max_size=800 + ), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False,), + ], + num_workers=4, + batch_size=1, + ) +) + +optimizer = dict( + type='SGD', + lr=0.01/4., #0.0,#0.01*(1/8.), + momentum=0.9, + weight_decay=0.0001, + grad_clip=dict( + max_norm=35, + norm_type=2)) + +scheduler = dict( + type='StepLR', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + milestones=[24, 33]) + + +logger = dict( + type="RunLogger") + +# when we the trained model from cshuan, image is rgb +max_epoch = 12 +eval_interval = 1 +checkpoint_interval = 1 +log_interval = 50 \ No newline at end of 
file diff --git a/python/jdet/config/constant.py b/python/jdet/config/constant.py index 8eec9c8b..3231c318 100644 --- a/python/jdet/config/constant.py +++ b/python/jdet/config/constant.py @@ -200,6 +200,8 @@ SSDD_CLASSES = ['ship'] +HRSC2016_CLASSES = ['ship'] + def get_classes_by_name(name): res = { 'VOC': VOC_CLASSES, @@ -213,6 +215,7 @@ def get_classes_by_name(name): 'FAIR': FAIR_CLASSES_, 'SSDD': SSDD_CLASSES, 'SSDD+': SSDD_CLASSES, + 'HRSC2016': HRSC2016_CLASSES, } assert(name in res) return res[name] \ No newline at end of file diff --git a/python/jdet/data/HRSC2016.py b/python/jdet/data/HRSC2016.py new file mode 100644 index 00000000..c6ba6b66 --- /dev/null +++ b/python/jdet/data/HRSC2016.py @@ -0,0 +1,41 @@ +from jdet.data.dota import DOTADataset +from jdet.utils.registry import DATASETS +from jdet.config.constant import HRSC2016_CLASSES, get_classes_by_name +from jdet.utils.general import check_dir +from tqdm import tqdm +from PIL import Image +import os.path as osp +import xml.etree.ElementTree as ET +import numpy as np + +def list_from_file(filename, prefix='', offset=0, max_num=0): + """Load a text file and parse the content as a list of strings. + + Args: + filename (str): Filename. + prefix (str): The prefix to be inserted to the begining of each item. + offset (int): The offset of lines. + max_num (int): The maximum number of lines to be read, + zeros and negatives mean no limitation. + + Returns: + list[str]: A list of strings. 
+ """ + cnt = 0 + item_list = [] + with open(filename, 'r') as f: + for _ in range(offset): + f.readline() + for line in f: + if max_num > 0 and cnt >= max_num: + break + item_list.append(prefix + line.rstrip('\n')) + cnt += 1 + return item_list + + +@DATASETS.register_module() +class HRSC2016Dataset(DOTADataset): + def __init__(self,*arg,**kwargs): + super().__init__(*arg,**kwargs) + self.CLASSES = HRSC2016_CLASSES diff --git a/python/jdet/data/__init__.py b/python/jdet/data/__init__.py index 7aa3b52c..29134bd4 100644 --- a/python/jdet/data/__init__.py +++ b/python/jdet/data/__init__.py @@ -5,3 +5,4 @@ from .dota import DOTADataset from .fair import FAIRDataset from .ssdd_plus import SSDDDataset +from .HRSC2016 import HRSC2016Dataset diff --git a/python/jdet/data/devkits/conver_hrsc_to_mmdet.py b/python/jdet/data/devkits/conver_hrsc_to_mmdet.py new file mode 100644 index 00000000..2188e509 --- /dev/null +++ b/python/jdet/data/devkits/conver_hrsc_to_mmdet.py @@ -0,0 +1,85 @@ +import os +import os.path as osp +import xml.etree.ElementTree as ET + +import pickle +import numpy as np +from PIL import Image +from jdet.config.constant import get_classes_by_name +from tqdm import tqdm + +def list_from_file(filename, prefix='', offset=0, max_num=0): + """Load a text file and parse the content as a list of strings. + + Args: + filename (str): Filename. + prefix (str): The prefix to be inserted to the begining of each item. + offset (int): The offset of lines. + max_num (int): The maximum number of lines to be read, + zeros and negatives mean no limitation. + + Returns: + list[str]: A list of strings. 
+ """ + cnt = 0 + item_list = [] + with open(filename, 'r') as f: + for _ in range(offset): + f.readline() + for line in f: + if max_num > 0 and cnt >= max_num: + break + item_list.append(prefix + line.rstrip('\n')) + cnt += 1 + return item_list + +def convert_hrsc_to_mmdet(img_path, xml_path, ann_file, out_path, convert_labels=True, filter_empty_gt=True, ext='.bmp', type="HRSC2016"): + """Generate .pkl format annotation that is consistent with mmdet. + Args: + image_path: path of all images + xml_path: path for annotations in xml format + ann_file: imageset file + out_path: output pkl file path + trainval: trainval or test + """ + label_ids = {name: i + 1 for i, name in enumerate(get_classes_by_name(type))} + img_ids = list_from_file(ann_file) + data_dict = [] + for img_id in tqdm(img_ids): + img = Image.open(osp.join(img_path, f'{img_id}{ext}')) + img_info = {} + img_info['filename'] = f'{img_id}{ext}' + img_info['height'] = img.height + img_info['width'] = img.width + if convert_labels: + xml_file = osp.join(xml_path, f'{img_id}.xml') + if not osp.exists(xml_file): + print(f'Annotation: {xml_file} Not Exist') + continue + tree = ET.parse(xml_file) + root = tree.getroot() + bboxes, bboxes_ignore, labels, labels_ignore = [], [], [], [] + for obj in root.findall('HRSC_Objects')[0].findall('HRSC_Object'): + label = label_ids['ship'] + bbox = [] + for key in ['mbox_cx', 'mbox_cy', 'mbox_w', 'mbox_h', 'mbox_ang']: + bbox.append(obj.find(key).text) + difficult = int(obj.find('difficult').text) + if difficult: + bboxes_ignore.append(bbox) + labels_ignore.append(label) + else: + bboxes.append(bbox) + labels.append(label) + if filter_empty_gt and (len(labels)+len(labels_ignore) == 0): + continue + ann = {} + ann['bboxes'] = np.array(bboxes, dtype=np.float32) + ann['labels'] = np.array(labels, dtype=np.int64) + ann['bboxes_ignore'] = np.array(bboxes_ignore, dtype=np.float32) + ann['labels_ignore'] = np.array(labels_ignore, dtype=np.int64) + img_info['ann'] = ann + 
data_dict.append(img_info) + print("left images:", len(data_dict)) + pickle.dump(data_dict, open(out_path, "wb")) + diff --git a/python/jdet/data/devkits/hrsc_to_dota.py b/python/jdet/data/devkits/hrsc_to_dota.py new file mode 100644 index 00000000..1490685c --- /dev/null +++ b/python/jdet/data/devkits/hrsc_to_dota.py @@ -0,0 +1,43 @@ +import os +import os.path as osp +import xml.etree.cElementTree as ET +from tqdm import tqdm +from jdet.data.devkits.conver_hrsc_to_mmdet import list_from_file +from jdet.models.boxes.box_ops import rotated_box_to_poly_single +import cv2 + +def xml2txt(xml_file, txt_file): + tree = ET.parse(xml_file) + root = tree.getroot() + out_lines = [] + for obj in root.findall('HRSC_Objects')[0].findall('HRSC_Object'): + label = 'ship' + bbox = [] + for key in ['mbox_cx', 'mbox_cy', 'mbox_w', 'mbox_h', 'mbox_ang']: + bbox.append(obj.find(key).text) + poly = rotated_box_to_poly_single(bbox) + difficult = int(obj.find('difficult').text) + temp_txt = '{} {} {} {} {} {} {} {} {} {}\n'.format( + poly[0], poly[1], poly[2], poly[3], poly[4], poly[5], poly[6], poly[7], + label, difficult + ) + out_lines.append(temp_txt) + + f = open(txt_file, "w") + f.writelines(out_lines) + f.close() + +def hrsc_to_dota(img_path, xml_path, ann_file, out_path, convert_label=True, ext='.bmp'): + out_img_path = osp.join(out_path, "images") + out_anno_path = osp.join(out_path, "labelTxt") + os.makedirs(out_img_path, exist_ok=True) + os.makedirs(out_anno_path, exist_ok=True) + img_ids = list_from_file(ann_file) + for img_id in tqdm(img_ids): + # TODO: add process, or replace with copy + img = cv2.imread(osp.join(img_path, f'{img_id}{ext}')) + cv2.imwrite(osp.join(out_img_path, f'{img_id}.png'), img) + if (convert_label): + xml_file = osp.join(xml_path, f'{img_id}.xml') + txt_file = osp.join(out_anno_path, f'{img_id}.txt') + xml2txt(xml_file, txt_file) \ No newline at end of file diff --git a/python/jdet/data/dota.py b/python/jdet/data/dota.py index 06545267..27c2801f 
100644 --- a/python/jdet/data/dota.py +++ b/python/jdet/data/dota.py @@ -22,10 +22,11 @@ def s2anet_post(result): @DATASETS.register_module() class DOTADataset(CustomDataset): - def __init__(self,*arg,balance_category=False,version='1',**kwargs): + def __init__(self,*arg,balance_category=False,version='1',use_07_metric=False,**kwargs): assert version in ['1', '1_5', '2'] self.CLASSES = get_classes_by_name('DOTA'+version) super().__init__(*arg,**kwargs) + self.use_07_metric=use_07_metric if balance_category: self.img_infos = self._balance_categories() self.total_len = len(self.img_infos) @@ -132,7 +133,7 @@ def evaluate(self,results,work_dir,epoch,logger=None,save=True): diffculty = diffculty.astype(bool) g = np.concatenate([g,dg]) classname_gts[idx] = {"box":g.copy(),"det":[False for i in range(len(g))],'difficult':diffculty.copy()} - rec, prec, ap = voc_eval_dota(c_dets,classname_gts,iou_func=iou_poly) + rec, prec, ap = voc_eval_dota(c_dets,classname_gts,iou_func=iou_poly,use_07_metric=self.use_07_metric) aps["eval/"+str(i+1)+"_"+classname+"_AP"]=ap map = sum(list(aps.values()))/len(aps) aps["eval/0_meanAP"]=map diff --git a/tools/preprocess.py b/tools/preprocess.py index 0b4036ec..54d72fce 100644 --- a/tools/preprocess.py +++ b/tools/preprocess.py @@ -1,3 +1,4 @@ +from lib2to3.pytree import convert import cv2 import argparse import os @@ -5,6 +6,7 @@ from jdet.config import init_cfg, get_cfg from jdet.data.devkits.ImgSplit_multi_process import process from jdet.data.devkits.convert_data_to_mmdet import convert_data_to_mmdet +from jdet.data.devkits.conver_hrsc_to_mmdet import convert_hrsc_to_mmdet from jdet.data.devkits.fair_to_dota import fair_to_dota from jdet.utils.general import is_win @@ -20,6 +22,31 @@ def clear(cfg): os.system(f"rm -rf {os.path.join(cfg.target_dataset_path)}") def run(cfg): + if cfg.type=='HRSC2016': + for task in cfg.tasks: + print('==============') + cfg_ = task.config + label = task.label + # TODO: support convert hrsc2016 to dota + 
convert_mmdet = True if cfg_.convert_mmdet is None else cfg_.convert_mmdet + if convert_mmdet: + print("convert to mmdet:", label) + images_path = cfg_.images_path + xml_path = cfg_.xml_path + imageset_file = cfg_.imageset_file + out_file = cfg_.out_annotation_file + assert(images_path is not None) + assert(xml_path is not None) + assert(imageset_file is not None) + assert(out_file is not None) + convert_labels = True if cfg_.convert_labels is None else cfg_.convert_labels + filter_empty_gt = label=='train' if cfg_.filter_empty_gt is None else cfg_.filter_empty_gt + convert_hrsc_to_mmdet(images_path, xml_path, imageset_file, out_file, + convert_labels=convert_labels, + filter_empty_gt=filter_empty_gt, + type=cfg.type) + return + if cfg.type=='SSDD+' or cfg.type=='SSDD': for task in cfg.convert_tasks: print('==============') From c2d6f27e2e7f4979bfaa005828358c345c7d111c Mon Sep 17 00:00:00 2001 From: 514flowey <1114811901@qq.com> Date: Sun, 20 Feb 2022 16:34:06 +0800 Subject: [PATCH 02/10] modify dataset path --- configs/preprocess/hrsc2016_preprocess_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/preprocess/hrsc2016_preprocess_config.py b/configs/preprocess/hrsc2016_preprocess_config.py index 470eed10..632cb156 100644 --- a/configs/preprocess/hrsc2016_preprocess_config.py +++ b/configs/preprocess/hrsc2016_preprocess_config.py @@ -2,7 +2,7 @@ type='HRSC2016' -source_dataset_path='/mnt/disk/flowey/dataset/HRSC2016' +source_dataset_path='/home/flowey/dataset/HRSC2016' tasks=[ dict( From f5c42372a4a69a9d696262dfe0ef2900e93bc0b9 Mon Sep 17 00:00:00 2001 From: xzk Date: Sun, 27 Feb 2022 16:53:13 +0800 Subject: [PATCH 03/10] add hrsc2016 docs --- README.md | 2 + .../preprocess/hrsc2016_preprocess_config.py | 3 - ...c2016.py => s2anet_r50_fpn_3x_hrsc2016.py} | 13 ++-- docs/hrsc2016.md | 71 +++++++++++++++++++ python/jdet/data/devkits/data_merge.py | 8 +++ 5 files changed, 88 insertions(+), 9 deletions(-) rename 
configs/s2anet/{s2anet_r50_fpn_1x_hrsc2016.py => s2anet_r50_fpn_3x_hrsc2016.py} (95%) create mode 100644 docs/hrsc2016.md diff --git a/README.md b/README.md index 20fdc551..82ee1b09 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,8 @@ FAIR Dataset: [fair.md](docs/fair.md) SSDD/SSDD+: [ssdd.md](docs/ssdd.md) +HRSC2016: [hrsc2016.md](docs/hrsc2016.md) + You can also build your own dataset by convert your datas to DOTA format. ### Config JDet defines the used model, dataset and training/testing method by `config-file`, please check the [config.md](docs/config.md) to learn how it works. diff --git a/configs/preprocess/hrsc2016_preprocess_config.py b/configs/preprocess/hrsc2016_preprocess_config.py index 632cb156..c99fba76 100644 --- a/configs/preprocess/hrsc2016_preprocess_config.py +++ b/configs/preprocess/hrsc2016_preprocess_config.py @@ -1,6 +1,3 @@ -from numpy import source - - type='HRSC2016' source_dataset_path='/home/flowey/dataset/HRSC2016' diff --git a/configs/s2anet/s2anet_r50_fpn_1x_hrsc2016.py b/configs/s2anet/s2anet_r50_fpn_3x_hrsc2016.py similarity index 95% rename from configs/s2anet/s2anet_r50_fpn_1x_hrsc2016.py rename to configs/s2anet/s2anet_r50_fpn_3x_hrsc2016.py index dc91b156..604d3620 100644 --- a/configs/s2anet/s2anet_r50_fpn_1x_hrsc2016.py +++ b/configs/s2anet/s2anet_r50_fpn_3x_hrsc2016.py @@ -80,7 +80,7 @@ debug=False)) ) ) -base_data_root = '/mnt/disk/flowey/dataset/HRSC2016' +base_data_root = '/home/flowey/dataset/HRSC2016' dataset = dict( train=dict( type="HRSC2016Dataset", @@ -111,7 +111,8 @@ val=dict( type="HRSC2016Dataset", images_dir = base_data_root + '/Test/AllImages', - annotations_file = base_data_root + '/Train/labels.pkl', + annotations_file = base_data_root + '/Test/labels.pkl', + use_07_metric=False, transforms=[ dict( type="RotatedResize", @@ -132,9 +133,9 @@ shuffle=False ), test=dict( - type="HRSC2016Dataset", + type="ImageDataset", + dataset_type="HRSC2016", images_dir = base_data_root + '/Test/AllImages', - 
annotations_file = base_data_root + '/Train/labels.pkl', transforms=[ dict( type="RotatedResize", @@ -151,7 +152,7 @@ to_bgr=False,), ], num_workers=4, - batch_size=1, + batch_size=2, ) ) @@ -176,7 +177,7 @@ type="RunLogger") # when we the trained model from cshuan, image is rgb -max_epoch = 12 +max_epoch = 36 eval_interval = 1 checkpoint_interval = 1 log_interval = 50 \ No newline at end of file diff --git a/docs/hrsc2016.md b/docs/hrsc2016.md new file mode 100644 index 00000000..4a054031 --- /dev/null +++ b/docs/hrsc2016.md @@ -0,0 +1,71 @@ +# Using JDet with HRSC2016 +Using JDet with Ship Detection Dataset (HRSC2016). +## Data Preparing +save to `$HRSC_PATH$` as: +``` +$HRSC_PATH$ +├── Train +| ├──... +| ├──train.txt +| ├──AllImages +| | ├──100000001.bmp +| | ├──100000002.bmp +| | └──... +| └──Annotations +| ├──100000001.xml +| ├──100000002.xml +| └──... +└──Test + ├──... + ├──test.txt + ├──AllImages + | ├──100000003.bmp + | ├──100000005.bmp + | └──... + └──Annotations + ├──100000003.xml + ├──100000005.xml + └──... +``` +## Data Preprocessing +We need prepare labels into pkl annotation file before training and testing. +``` +cd $JDet_PATH$ +``` +We can set how the HRSC2016 is preprocessed by editing the `configs/preprocess/hrsc2016_preprocess_config.py`: +```python +type='HRSC2016' +source_dataset_path='/home/flowey/dataset/HRSC2016' + +tasks=[ + dict( + label='train', + config=dict( + images_path=source_dataset_path+'/Train/AllImages', + xml_path=source_dataset_path+'/Train/Annotations', + imageset_file=source_dataset_path+'/Train/train.txt', + out_annotation_file=source_dataset_path+'/Train/labels.pkl', + ) + ), + dict( + label='test', + config=dict( + images_path=source_dataset_path+'/Test/AllImages', + xml_path=source_dataset_path+'/Test/Annotations', + imageset_file=source_dataset_path+'/Test/test.txt', + out_annotation_file=source_dataset_path+'/Test/labels.pkl', + ) + ) +]``` +We need to set `out_annotation_file` for output pkl annotation file. 
+Finally, run the following script for preprocessing: +``` +python tools/preprocess.py --config-file configs/preprocess/hrsc2016_preprocess_config.py +``` +For the way of configuring the processed HRSC2016 dataset in the model config file, please refer to `$JDet_PATH$/configs/s2anet/s2anet_r50_fpn_3x_hrsc2016.py` + +## Data Postprocessing +Task 'test' will generate detection results at `$JDet_PATH$/work_dirs/s2anet_r50_fpn_1x_hrsc2016/test/submit_36/ship.txt` for example, but AP is not calculated. Only task 'val' calcuates AP. For example, we can judge the model by +``` +python tools/run_net.py --config-file configs/s2anet/s2anet_r50_fpn_3x_hrsc2016.py --task=val +``` diff --git a/python/jdet/data/devkits/data_merge.py b/python/jdet/data/devkits/data_merge.py index eabb7fac..1da716fe 100644 --- a/python/jdet/data/devkits/data_merge.py +++ b/python/jdet/data/devkits/data_merge.py @@ -1,3 +1,4 @@ +from multiprocessing.spawn import prepare import shutil import jittor as jt from jdet.config.constant import get_classes_by_name @@ -54,6 +55,13 @@ def data_merge(result_pkl, save_path, final_path,dataset_type): mergebypoly(save_path,final_path) def data_merge_result(result_pkl,work_dir,epoch,name,dataset_type,images_dir=""): + if dataset_type in ["HRSC2016"]: + save_path = os.path.join(work_dir, f"test/submit_{epoch}") + if (os.path.exists(save_path)): + shutil.rmtree(save_path) + classes = get_classes_by_name(dataset_type) + prepare_data(result_pkl, save_path, classes) + return assert dataset_type in ["FAIR", "DOTA", "DOTA1_5", "DOTA2"], "need to set dataset.test.dataset_type in the config file. 
FAIR, DOTA, DOTA1_5 and DOTA2 are supported" print("Merge results...") save_path = os.path.join(work_dir, f"test/submit_{epoch}/before_nms") From caf34da6c6008fa86f5e5197f5ecc47fd5eb3ac4 Mon Sep 17 00:00:00 2001 From: 514flowey <1114811901@qq.com> Date: Thu, 10 Mar 2022 16:37:34 +0800 Subject: [PATCH 04/10] init SAN --- configs/san10_pairwise.py | 95 +++ python/jdet/data/ILSVRC2012.py | 89 +++ python/jdet/data/__init__.py | 1 + python/jdet/data/transforms.py | 1 + python/jdet/models/losses/__init__.py | 3 +- python/jdet/models/losses/san_loss.py | 60 ++ python/jdet/models/networks/__init__.py | 1 + python/jdet/models/networks/san.py | 198 ++++++ python/jdet/ops/san_aggregations.py | 466 ++++++++++++++ python/jdet/ops/san_subtractions.py | 780 ++++++++++++++++++++++++ 10 files changed, 1693 insertions(+), 1 deletion(-) create mode 100644 configs/san10_pairwise.py create mode 100644 python/jdet/data/ILSVRC2012.py create mode 100644 python/jdet/models/losses/san_loss.py create mode 100644 python/jdet/models/networks/san.py create mode 100644 python/jdet/ops/san_aggregations.py create mode 100644 python/jdet/ops/san_subtractions.py diff --git a/configs/san10_pairwise.py b/configs/san10_pairwise.py new file mode 100644 index 00000000..1ccd98a4 --- /dev/null +++ b/configs/san10_pairwise.py @@ -0,0 +1,95 @@ +model = dict( + type='SAN', + sa_type=0, + layers=[2, 1, 2, 4, 1], + kernels=[3, 7, 7, 7, 7], + num_classes=1000, + loss=dict( + type='SAMSmoothLoss', + eps=0.1 + ), + loss_prepare=False +) + +# dataset settings +dataset_type = 'ILSVRCDataset' +dataset = dict( + imgs_per_gpu=2, + workers_per_gpu=4, + train=dict( + type=dataset_type, + images_dir='/home/flowey/dataset/ILSVRC2012/train/', + transforms=[ + dict(type = "Resize", ## TODO: implement RandomRotatedCrop + min_size = 224, + max_size = 224, + ), + dict( + type = "RotatedRandomFlip", + prob = 0.5, + direction="horizontal", + ), + dict( + type = "Normalize", ## unknown normalize + mean = [123.675, 116.28, 103.53], 
+ std = [58.395, 57.12, 57.375], + to_bgr=True), + ], + batch_size=2, + ), + val=dict( + type=dataset_type, + batch_size=128, + images_dir='/home/flowey/dataset/ILSVRC2012/train/', + transforms=[ + dict(type = "Resize", + min_size = 224, + max_size = 224, + ), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=True), + ], + ), + test=dict( + type="ImageDataset", + images_dir='/mnt/disk/lxl/dataset/DOTA_1024/test_split/images/', + transforms=[ + dict(type = "Resize", + min_size = 224, + max_size = 224, + ), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=True), + ], + ) +) +# optimizer +optimizer = dict( + type='SGD', + lr=0.01, + momentum=0.9, + weight_decay=0.0001, + grad_clip=dict( ##grad_clip: not sure + max_norm=35, + norm_type=2)) + +# learning policy +scheduler = dict( + type='CosineAnnealingLR', + warmup='linear', ##warmup: not sure + warmup_iters=500, + warmup_ratio=1.0 / 3, + max_steps=100) + +logger = dict( + type="RunLogger") +max_epoch = 100 +eval_interval = 25 +checkpoint_interval = 10 +log_interval = 20 \ No newline at end of file diff --git a/python/jdet/data/ILSVRC2012.py b/python/jdet/data/ILSVRC2012.py new file mode 100644 index 00000000..b7d0d131 --- /dev/null +++ b/python/jdet/data/ILSVRC2012.py @@ -0,0 +1,89 @@ +from PIL import Image +import numpy as np +import os + +from jdet.utils.registry import DATASETS +from .transforms import Compose + +import jittor as jt +import os +from jittor.dataset import Dataset +import jdet + +@DATASETS.register_module() +class ILSVRCDataset(Dataset): + """ ILSVRCDataset + Load image for ILSVRC2012. + prepare data as format below: + + images_dir/label1/img1.png + images_dir/label1/img2.png + ... 
+ images_dir/label2/img1.png + images_dir/label2/img2.png + """ + def __init__(self,images_dir=None,transforms=None,batch_size=1,num_workers=0,shuffle=False,drop_last=False): + super(ILSVRCDataset,self).__init__(batch_size=batch_size,num_workers=num_workers,shuffle=shuffle,drop_last=drop_last) + self.classes, self.class_to_idx = self._load_labels(images_dir=images_dir) + self.images, self.labels = self._load_images(images_dir=images_dir) + self.total_len = len(self.labels) + + if isinstance(transforms,list): + transforms = Compose(transforms) + if transforms is not None and not callable(transforms): + raise TypeError("transforms must be list or callable") + self.transforms = transforms + + def _load_labels(self, images_dir): + classes = sorted([d.name for d in os.scandir(images_dir) if d.is_dir()]) + class_to_idx = {v:k for k,v in enumerate(classes)} + return classes, class_to_idx + + def _load_images(self, images_dir): + images, labels = [], [] + for label in os.listdir(images_dir): + label_dir = os.path.join(images_dir, label) + if os.path.isdir(label_dir): + if label not in self.class_to_idx.keys(): + raise ValueError("unknow class {}".format(label)) + for name in os.listdir(label_dir): + if (jdet.utils.general.is_img(name)): + images.append(os.path.join(images_dir, label, name)) + labels.append(self.class_to_idx[label]) + return images, labels + + def collate_batch(self,batch): + imgs = [] + anns = [] + max_width = 0 + max_height = 0 + for image,ann in batch: + height,width = image.shape[-2],image.shape[-1] + max_width = max(max_width,width) + max_height = max(max_height,height) + imgs.append(image) + anns.append(ann) + N = len(imgs) + batch_imgs = np.zeros((N,3,max_height,max_width),dtype=np.float32) + for i,image in enumerate(imgs): + batch_imgs[i,:,:image.shape[-2],:image.shape[-1]] = image + + return batch_imgs,anns + + + def __getitem__(self,index): + if "BATCH_IDX" in os.environ: + index = int(os.environ['BATCH_IDX']) + + img = 
Image.open(self.images[index]).convert("RGB") + targets = dict( + ori_img_size=img.size, + img_size=img.size, + scale_factor=1., + img_file = self.images[index], + img_label = self.labels[index] + ) + + if self.transforms: + img,targets = self.transforms(img,targets) + return img,targets diff --git a/python/jdet/data/__init__.py b/python/jdet/data/__init__.py index 29134bd4..22afb70a 100644 --- a/python/jdet/data/__init__.py +++ b/python/jdet/data/__init__.py @@ -6,3 +6,4 @@ from .fair import FAIRDataset from .ssdd_plus import SSDDDataset from .HRSC2016 import HRSC2016Dataset +from .ILSVRC2012 import ILSVRCDataset diff --git a/python/jdet/data/transforms.py b/python/jdet/data/transforms.py index 508c8442..967e089c 100644 --- a/python/jdet/data/transforms.py +++ b/python/jdet/data/transforms.py @@ -1,4 +1,5 @@ import random +from tokenize import Number import jittor as jt import cv2 import numpy as np diff --git a/python/jdet/models/losses/__init__.py b/python/jdet/models/losses/__init__.py index 7b27b2c5..df819fff 100644 --- a/python/jdet/models/losses/__init__.py +++ b/python/jdet/models/losses/__init__.py @@ -2,4 +2,5 @@ from .focal_loss import FocalLoss from .cross_entropy_loss import CrossEntropyLoss from .l1_loss import L1Loss -from .poly_iou_loss import PolyIoULoss \ No newline at end of file +from .poly_iou_loss import PolyIoULoss +from .san_loss import SANMixUpLoss, SAMSmoothLoss \ No newline at end of file diff --git a/python/jdet/models/losses/san_loss.py b/python/jdet/models/losses/san_loss.py new file mode 100644 index 00000000..ddeb17d0 --- /dev/null +++ b/python/jdet/models/losses/san_loss.py @@ -0,0 +1,60 @@ +import jittor as jt +import numpy as np +from jittor import nn +from jdet.utils.registry import LOSSES + +def mixup_data(x, y, alpha=0.2): + '''Returns mixed inputs, pairs of targets, and lambda''' + if alpha > 0: + lam = np.random.beta(alpha, alpha) + else: + lam = 1 + index = jt.randperm(x.shape[0]) + x = lam * x + (1 - lam) * x[index, :] + 
y_a, y_b = y, y[index] + return x, y_a, y_b, lam + + +def mixup_loss(output, target_a, target_b, lam=1.0, eps=0.0): + w = jt.zeros_like(output).scatter_(1, target_a.unsqueeze(1), jt.array(1)) + w = w * (1 - eps) + (1 - w) * eps / (output.shape[1] - 1) + log_prob = nn.log_softmax(output, dim=1) + loss_a = (-w * log_prob).sum(dim=1).mean() + + w = jt.zeros_like(output).scatter_(1, target_b.unsqueeze(1), jt.array(1)) + w = w * (1 - eps) + (1 - w) * eps / (output.shape[1] - 1) + log_prob = nn.log_softmax(output, dim=1) + loss_b = (-w * log_prob).sum(dim=1).mean() + return lam * loss_a + (1 - lam) * loss_b + + +def smooth_loss(output, target, eps=0.1): + w = jt.zeros_like(output).scatter_(1, target.unsqueeze(1), jt.array(1)) + w = w * (1 - eps) + (1 - w) * eps / (output.shape[1] - 1) + log_prob = nn.log_softmax(output, dim=1) + loss = (-w * log_prob).sum(dim=1).mean() + return loss + +@LOSSES.register_module() +class SANMixUpLoss(nn.Module): + def __init__(self, alpha=0.2, eps=0.0): + super(SANMixUpLoss, self).__init__() + self.alpha = alpha + self.eps = eps + self.target_a, self.target_b, self.lam = None, None, None + + def prepare(self, input, target): + input, self.target_a, self.target_b, self.lam = mixup_data(input, target, self.alpha) + return input + + def execute(self, output, target=None): + return mixup_loss(output, self.target_a, self.target_b, self.lam, self.eps) + +@LOSSES.register_module() +class SAMSmoothLoss(nn.Module): + def __init__(self, eps=0.1): + super(SAMSmoothLoss).__init__() + self.eps = eps + + def execute(self, output, target): + return smooth_loss(output, target, self.eps) diff --git a/python/jdet/models/networks/__init__.py b/python/jdet/models/networks/__init__.py index c52343af..16f1a67a 100644 --- a/python/jdet/models/networks/__init__.py +++ b/python/jdet/models/networks/__init__.py @@ -7,4 +7,5 @@ from .faster_rcnn_obb import FasterRCNNOBB from .roi_transformer import RoITransformer from .fcos import FCOS +from .san import SAN __all__ = 
import jittor as jt
import jittor.nn as nn

from jittor.nn import _pair
from jdet.ops.san_aggregations import aggregation
from jdet.ops.san_subtractions import subtraction, subtraction2
from jdet.utils.registry import MODELS, LOSSES, build_from_cfg
from jdet.models.losses import SANMixUpLoss


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 convolution."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def position(H, W):
    """Build a ``(1, 2, H, W)`` coordinate grid.

    Channel 0 runs from -1 to 1 along the width, channel 1 from -1 to 1
    along the height.
    """
    loc_w = jt.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)
    loc_h = jt.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)
    loc = jt.concat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)
    return loc


class _WindowOp(nn.Module):
    """Shared constructor for the parameter-free sliding-window wrappers.

    Stores the window geometry as 2-tuples (via ``_pair``) plus the padding
    mode that is forwarded to the underlying op.
    """

    def __init__(self, kernel_size, stride, padding, dilation, pad_mode):
        super(_WindowOp, self).__init__()
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.pad_mode = pad_mode


class Subtraction(_WindowOp):
    """Module wrapper around the single-input ``subtraction`` op."""

    def execute(self, input):
        return subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)


class Subtraction2(_WindowOp):
    """Module wrapper around the two-input ``subtraction2`` op."""

    def execute(self, input1, input2):
        return subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)


class Aggregation(_WindowOp):
    """Module wrapper around the ``aggregation`` op (input weighted by ``weight``)."""

    def execute(self, input, weight):
        return aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode)


class Unfold(nn.Module):
    """Module wrapper around ``nn.unfold`` with a fixed window geometry."""

    def __init__(self, kernel_size, dilation=1, padding=0, stride=1):
        super(Unfold, self).__init__()
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)

    def execute(self, x):
        return nn.unfold(x, kernel_size=self.kernel_size, dilation=self.dilation, padding=self.padding, stride=self.stride)


class SAM(nn.Module):
    """Self-attention module with a pairwise or patchwise weight branch.

    ``sa_type == 0`` selects the pairwise branch (subtraction + position
    encoding + softmax); any other value selects the patchwise branch
    (unfold + concat). Both branches feed the shared ``Aggregation`` op.
    """

    def __init__(self, sa_type, in_planes, rel_planes, out_planes, share_planes, kernel_size=3, stride=1, dilation=1):
        super(SAM, self).__init__()
        self.sa_type, self.kernel_size, self.stride = sa_type, kernel_size, stride
        self.conv1 = nn.Conv2d(in_planes, rel_planes, kernel_size=1)
        self.conv2 = nn.Conv2d(in_planes, rel_planes, kernel_size=1)
        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        # Padding that keeps the window centred for the given dilation.
        window_pad = (dilation * (kernel_size - 1) + 1) // 2
        if sa_type == 0:
            # The "+ 2" channels are the two position channels from conv_p.
            self.conv_w = nn.Sequential(nn.BatchNorm2d(rel_planes + 2), nn.ReLU(),
                                        nn.Conv2d(rel_planes + 2, rel_planes, kernel_size=1, bias=False),
                                        nn.BatchNorm2d(rel_planes), nn.ReLU(),
                                        nn.Conv2d(rel_planes, out_planes // share_planes, kernel_size=1))
            self.conv_p = nn.Conv2d(2, 2, kernel_size=1)
            self.subtraction = Subtraction(kernel_size, stride, window_pad, dilation, pad_mode=1)
            self.subtraction2 = Subtraction2(kernel_size, stride, window_pad, dilation, pad_mode=1)
            self.softmax = nn.Softmax(dim=-2)
        else:
            self.conv_w = nn.Sequential(nn.BatchNorm2d(rel_planes * (pow(kernel_size, 2) + 1)), nn.ReLU(),
                                        nn.Conv2d(rel_planes * (pow(kernel_size, 2) + 1), out_planes // share_planes, kernel_size=1, bias=False),
                                        nn.BatchNorm2d(out_planes // share_planes), nn.ReLU(),
                                        nn.Conv2d(out_planes // share_planes, pow(kernel_size, 2) * out_planes // share_planes, kernel_size=1))
            self.unfold_i = Unfold(kernel_size=1, dilation=dilation, padding=0, stride=stride)
            self.unfold_j = Unfold(kernel_size=kernel_size, dilation=dilation, padding=0, stride=stride)
            self.pad = nn.ReflectionPad2d(kernel_size // 2)
        # Used by both branches in execute().
        self.aggregation = Aggregation(kernel_size, stride, window_pad, dilation, pad_mode=1)

    def execute(self, x):
        x1, x2, x3 = self.conv1(x), self.conv2(x), self.conv3(x)
        if self.sa_type == 0:  # pairwise
            p = self.conv_p(position(x.shape[2], x.shape[3]))
            # The position term is batch-independent, so it is computed once
            # and repeated along the batch dimension.
            w = self.softmax(self.conv_w(jt.concat([self.subtraction2(x1, x2), self.subtraction(p).repeat(x.shape[0], 1, 1, 1)], 1)))
        else:  # patchwise
            if self.stride != 1:
                x1 = self.unfold_i(x1)
            # NOTE(review): the last dim is fixed to H*W of the *input*; with
            # stride != 1 the unfolded x1 has fewer positions -- confirm this
            # path is intended to work for strided use.
            x1 = x1.reshape((x.shape[0], -1, 1, x.shape[2]*x.shape[3]))
            x2 = self.unfold_j(self.pad(x2)).reshape((x.shape[0], -1, 1, x1.shape[-1]))
            w = self.conv_w(jt.concat([x1, x2], 1)).reshape((x.shape[0], -1, pow(self.kernel_size, 2), x1.shape[-1]))
        x = self.aggregation(x3, w)
        return x


class Bottleneck(nn.Module):
    """Pre-activation residual block: BN-ReLU-SAM-BN-ReLU-1x1 conv + identity."""

    def __init__(self, sa_type, in_planes, rel_planes, mid_planes, out_planes, share_planes=8, kernel_size=7, stride=1):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.sam = SAM(sa_type, in_planes, rel_planes, mid_planes, share_planes, kernel_size, stride)
        self.bn2 = nn.BatchNorm2d(mid_planes)
        self.conv = nn.Conv2d(mid_planes, out_planes, kernel_size=1)
        self.relu = nn.ReLU()
        self.stride = stride

    def execute(self, x):
        identity = x
        out = self.relu(self.bn1(x))
        out = self.relu(self.bn2(self.sam(out)))
        out = self.conv(out)
        out += identity
        return out


@MODELS.register_module()
class SAN(nn.Module):
    """Five-stage self-attention classification network.

    Args:
        sa_type: forwarded to every ``SAM`` (0 = pairwise, else patchwise).
        layers: number of blocks in each of the five stages.
        kernels: attention kernel size for each of the five stages.
        num_classes: width of the final linear layer.
        block: block class used to build the stages.
        loss: loss config for ``build_from_cfg`` or an already-built module.
        loss_prepare: when true, ``loss.prepare(x, targets)`` is applied to
            the input batch before the forward pass in training mode.
    """

    def __init__(self, sa_type, layers, kernels, num_classes, block=Bottleneck, loss=None, loss_prepare=False):
        super(SAN, self).__init__()
        # san() hands in a constructed loss module. build_from_cfg is not
        # visible from here, so use instances as-is instead of relying on it
        # to pass them through; configs still go through the registry.
        self.loss = loss if isinstance(loss, nn.Module) else build_from_cfg(loss, LOSSES)
        self.loss_prepare = loss_prepare
        c = 64
        self.conv_in, self.bn_in = conv1x1(3, c), nn.BatchNorm2d(c)
        self.conv0, self.bn0 = conv1x1(c, c), nn.BatchNorm2d(c)
        self.layer0 = self._make_layer(sa_type, block, c, layers[0], kernels[0])

        c *= 4
        self.conv1, self.bn1 = conv1x1(c // 4, c), nn.BatchNorm2d(c)
        self.layer1 = self._make_layer(sa_type, block, c, layers[1], kernels[1])

        c *= 2
        self.conv2, self.bn2 = conv1x1(c // 2, c), nn.BatchNorm2d(c)
        self.layer2 = self._make_layer(sa_type, block, c, layers[2], kernels[2])

        c *= 2
        self.conv3, self.bn3 = conv1x1(c // 2, c), nn.BatchNorm2d(c)
        self.layer3 = self._make_layer(sa_type, block, c, layers[3], kernels[3])

        c *= 2
        self.conv4, self.bn4 = conv1x1(c // 2, c), nn.BatchNorm2d(c)
        self.layer4 = self._make_layer(sa_type, block, c, layers[4], kernels[4])

        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(c, num_classes)

    def _make_layer(self, sa_type, block, planes, blocks, kernel_size=7, stride=1):
        """Stack ``blocks`` identical blocks of width ``planes``."""
        layers = []
        for _ in range(0, blocks):
            layers.append(block(sa_type, planes, planes // 16, planes // 4, planes, 8, kernel_size, stride))
        return nn.Sequential(*layers)

    def execute(self, x, targets=None):
        """Run the network.

        Args:
            x: input image batch.
            targets: list of dicts carrying an ``'img_label'`` entry each.
                Required in training mode, optional for inference.

        Returns:
            ``dict(loss=...)`` in training mode, otherwise the class scores.

        Raises:
            TypeError: in training mode when ``targets`` is ``None``.
        """
        training = self.is_training()
        if targets is not None:
            targets = jt.array([t['img_label'] for t in targets])
        elif training:
            # Previously this surfaced as "'NoneType' object is not iterable".
            raise TypeError("SAN.execute needs `targets` (a list of dicts with 'img_label') in training mode")

        if training and self.loss_prepare:
            x = self.loss.prepare(x, targets)

        x = self.relu(self.bn_in(self.conv_in(x)))
        x = self.relu(self.bn0(self.layer0(self.conv0(self.pool(x)))))
        x = self.relu(self.bn1(self.layer1(self.conv1(self.pool(x)))))
        x = self.relu(self.bn2(self.layer2(self.conv2(self.pool(x)))))
        x = self.relu(self.bn3(self.layer3(self.conv3(self.pool(x)))))
        x = self.relu(self.bn4(self.layer4(self.conv4(self.pool(x)))))

        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)

        if training:
            return dict(loss=(self.loss(x, targets)))
        return x


def san(sa_type, layers, kernels, num_classes):
    """Build a ``SAN`` that trains with a default ``SANMixUpLoss``."""
    model = SAN(sa_type, layers, kernels, num_classes, block=Bottleneck, loss=SANMixUpLoss(), loss_prepare=True)
    return model


if __name__ == '__main__':
    # Smoke test: one training-mode pass (loss dict), one eval-mode pass (scores).
    jt.flags.use_cuda=1
    net = san(sa_type=0, layers=(3, 4, 6, 8, 3), kernels=[3, 7, 7, 7, 7], num_classes=1000)
    # print(net)
    targets = [dict(img_label=1),dict(img_label=2),dict(img_label=3),dict(img_label=4)]
    y = net(jt.randn(4, 3, 224, 224), targets)
    print(y)
    net.eval()
    y = net(jt.randn(4, 3, 224, 224), targets)
    print(y.size())
w_in = -pad_w + w * stride_w + kw * dilation_w; + if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width)) { + const int offset_bottom = ((n * input_channels + c) * bottom_height + h_in) * bottom_width + w_in; + const int offset_weight = ((n * weight_channels + c % weight_channels) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h * top_width + w; + value += weight_data[offset_weight] * bottom_data[offset_bottom]; + } + } + } + top_data[index] = value; + } +} +''' + +_aggregation_zeropad_input_backward_header = _kernel_loop_head + r''' +template +__global__ void aggregation_zeropad_input_backward_kernel( + const int nthreads, const T* const top_diff, const T* const weight_data, T* bottom_diff, + const int input_channels, const int weight_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / input_channels / bottom_height / bottom_width; + const int c = (index / bottom_height / bottom_width) % input_channels; + const int h = (index / bottom_width) % bottom_height; + const int w = index % bottom_width; + T value = 0; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int h_out_s = h + pad_h - kh * dilation_h; + const int w_out_s = w + pad_w - kw * dilation_w; + if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0)) { + const int h_out = h_out_s / stride_h; + const int w_out = w_out_s / stride_w; + if ((h_out >= 0) && (h_out < top_height) && (w_out >= 0) && (w_out < top_width)) { + const int offset_top = ((n * input_channels + c) * top_height + h_out) * top_width + w_out; + const int offset_weight = ((n * weight_channels + c % weight_channels) * kernel_h * kernel_w + (kh * kernel_w + kw)) * 
top_height * top_width + h_out * top_width + w_out; + value += weight_data[offset_weight] * top_diff[offset_top]; + } + } + } + } + bottom_diff[index] = value; + } +} +''' + +_aggregation_zeropad_weight_backward_header = _kernel_loop_head + r''' +template +__global__ void aggregation_zeropad_weight_backward_kernel( + const int nthreads, const T* const top_diff, const T* const bottom_data, T* weight_diff, + const int input_channels, const int weight_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / weight_channels / top_height / top_width; + const int c = (index / top_height / top_width) % weight_channels; + const int h = (index / top_width) % top_height; + const int w = index % top_width; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int h_in = -pad_h + h * stride_h + kh * dilation_h; + const int w_in = -pad_w + w * stride_w + kw * dilation_w; + const int offset_weight = ((n * weight_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h * top_width + w; + T value = 0; + if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width)) { + for (int cc = c; cc < input_channels; cc += weight_channels) { + const int offset_bottom = ((n * input_channels + cc) * bottom_height + h_in) * bottom_width + w_in; + const int offset_top = ((n * input_channels + cc) * top_height + h) * top_width + w; + value += bottom_data[offset_bottom] * top_diff[offset_top]; + } + } + weight_diff[offset_weight] = value; + } + } + } +} +''' + +_aggregation_refpad_forward_header = _kernel_loop_head + r''' +template +__global__ void aggregation_refpad_forward_kernel( + const int nthreads, const T* bottom_data, const T* 
weight_data, T* top_data, + const int input_channels, const int weight_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / input_channels / top_height / top_width; + const int c = (index / top_height / top_width) % input_channels; + const int h = (index / top_width) % top_height; + const int w = index % top_width; + T value = 0; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + int h_in = -pad_h + h * stride_h + kh * dilation_h; + int w_in = -pad_w + w * stride_w + kw * dilation_w; + const int offset_weight = ((n * weight_channels + c % weight_channels) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h * top_width + w; + int offset_bottom; + if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width)) { + offset_bottom = ((n * input_channels + c) * bottom_height + h_in) * bottom_width + w_in; + } + else { + if (h_in < 0) h_in = -h_in; + if (h_in >= bottom_height) h_in = 2 * (bottom_height - 1) - h_in; + if (w_in < 0) w_in = -w_in; + if (w_in >= bottom_width) w_in = 2 * (bottom_width - 1) - w_in; + offset_bottom = ((n * input_channels + c) * bottom_height + h_in) * bottom_width + w_in; + } + value += weight_data[offset_weight] * bottom_data[offset_bottom]; + } + } + top_data[index] = value; + } +} +''' + +_aggregation_refpad_input_backward_header = _kernel_loop_head + r''' +template +__global__ void aggregation_refpad_input_backward_kernel( + const int nthreads, const T* const top_diff, const T* const weight_data, T* bottom_diff, + const int input_channels, const int weight_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int 
stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / input_channels / (bottom_height + 2 * pad_h) / (bottom_width + 2 * pad_w); + const int c = (index / (bottom_height + 2 * pad_h) / (bottom_width + 2 * pad_w)) % input_channels; + const int h = (index / (bottom_width + 2 * pad_w)) % (bottom_height + 2 * pad_h); + const int w = index % (bottom_width + 2 * pad_w); + T value = 0; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int h_out_s = h - kh * dilation_h; + const int w_out_s = w - kw * dilation_w; + if ((h_out_s % stride_h == 0) && (w_out_s % stride_w == 0)) { + const int h_out = h_out_s / stride_h; + const int w_out = w_out_s / stride_w; + if ((h_out >= 0) && (h_out < top_height) && (w_out >= 0) && (w_out < top_width)) { + const int offset_top = ((n * input_channels + c) * top_height + h_out) * top_width + w_out; + const int offset_weight = ((n * weight_channels + c % weight_channels) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h_out * top_width + w_out; + value += weight_data[offset_weight] * top_diff[offset_top]; + } + } + } + } + bottom_diff[index] = value; + } +} +''' + +_aggregation_refpad_weight_backward_header = _kernel_loop_head + r''' +template +__global__ void aggregation_refpad_weight_backward_kernel( + const int nthreads, const T* const top_diff, const T* const bottom_data, T* weight_diff, + const int input_channels, const int weight_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / weight_channels / top_height / top_width; + const int c = (index / 
def _tuple_numel(shape):
    """Return the element count of a 4-D shape."""
    return shape[0] * shape[1] * shape[2] * shape[3]


def _output_size(in_size, kernel, stride, padding, dilation):
    """Sliding-window output length along one axis."""
    return int((in_size + 2 * padding - (dilation * (kernel - 1) + 1)) / stride + 1)


def _geometry_args(input_channels, weight_channels, input_hw, output_hw, kernel_size, stride, padding, dilation):
    """Format the trailing integer arguments shared by every aggregation kernel.

    The order matches the kernel signatures: channels, bottom H/W, top H/W,
    then (pad, stride, kernel, dilation) for height followed by width.
    """
    values = (
        input_channels, weight_channels,
        input_hw[0], input_hw[1],
        output_hw[0], output_hw[1],
        padding[0], stride[0], kernel_size[0], dilation[0],
        padding[1], stride[1], kernel_size[1], dilation[1],
    )
    return ', '.join(str(v) for v in values)


def _launch_src(kernel_name, nthreads, in0_alias, in1_alias, geometry_args):
    """Build the ``cuda_src`` snippet that launches ``kernel_name``.

    The launch configuration uses ``GET_BLOCKS`` / ``THREADS_PER_BLOCK`` from
    ``_kernel_loop_head``; the kernels' template parameter is deduced from the
    pointer arguments.
    """
    # NOTE(review): the launch originally read ``kernel<<>>(``, which is not
    # valid CUDA; the ``<<<grid, block>>>`` part was restored from the macros
    # defined in the kernel header -- confirm against the upstream source.
    return f'''
        @alias({in0_alias},in0);
        @alias({in1_alias},in1);
        @alias(output,out0);
        {kernel_name}<<<GET_BLOCKS({nthreads}), THREADS_PER_BLOCK>>>(
            {nthreads}, {in0_alias}_p, {in1_alias}_p, output_p,
            {geometry_args}
        );
        '''


def _aggregation_forward(ctx, input, weight, kernel_size, stride, padding, dilation, kernel_name, cuda_header):
    """Forward pass shared by both padding modes.

    Saves the geometry and inputs on ``ctx`` for ``grad`` and launches
    ``kernel_name`` with one thread per output element.
    """
    kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation)
    ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation
    ctx.input_, ctx.weight_ = input, weight
    assert len(input.shape) == 4 and jt.flags.use_cuda
    batch_size, input_channels, input_height, input_width = input.size()
    _, weight_channels, _, weight_width = weight.size()
    output_height = _output_size(input_height, kernel_size[0], stride[0], padding[0], dilation[0])
    output_width = _output_size(input_width, kernel_size[1], stride[1], padding[1], dilation[1])
    # The weight's last axis enumerates the flattened output positions.
    assert output_height * output_width == weight_width
    output_shape = (batch_size, input_channels, output_height, output_width)
    nthreads = _tuple_numel(output_shape)
    geometry = _geometry_args(input_channels, weight_channels, (input_height, input_width),
                              (output_height, output_width), kernel_size, stride, padding, dilation)
    src = _launch_src(kernel_name, nthreads, 'input', 'weight', geometry)
    return jt.code(output_shape, input.dtype, [input, weight], cuda_header=cuda_header, cuda_src=src)


class AggregationZeropad(jt.Function):
    """Weighted window aggregation; out-of-range taps contribute zero (CUDA only)."""

    def execute(self, input, weight, kernel_size, stride, padding, dilation):
        return _aggregation_forward(self, input, weight, kernel_size, stride, padding, dilation,
                                    'aggregation_zeropad_forward_kernel', _aggregation_zeropad_forward_header)

    def grad(self, grad_output):
        """Return gradients for ``input`` and ``weight`` (``None`` for the rest)."""
        kernel_size, stride, padding, dilation = self.kernel_size, self.stride, self.padding, self.dilation
        input, weight = self.input_, self.weight_
        assert jt.flags.use_cuda
        _, input_channels, input_height, input_width = input.size()
        weight_channels = weight.size()[1]
        output_height, output_width = grad_output.size()[2:]
        nthreads_input = input.numel()
        # The weight kernel loops over the kernel-window axis itself, so one
        # thread handles a whole window position.
        nthreads_weight = weight.numel() // weight.shape[2]
        geometry = _geometry_args(input_channels, weight_channels, (input_height, input_width),
                                  (output_height, output_width), kernel_size, stride, padding, dilation)
        input_src = _launch_src('aggregation_zeropad_input_backward_kernel', nthreads_input, 'diff', 'weight', geometry)
        weight_src = _launch_src('aggregation_zeropad_weight_backward_kernel', nthreads_weight, 'diff', 'input', geometry)
        grad_input = jt.code(input.size(), grad_output.dtype, [grad_output, weight], cuda_header=_aggregation_zeropad_input_backward_header, cuda_src=input_src)
        grad_weight = jt.code(weight.size(), grad_output.dtype, [grad_output, input], cuda_header=_aggregation_zeropad_weight_backward_header, cuda_src=weight_src)
        return grad_input, grad_weight, None, None, None, None


class AggregationRefpad(jt.Function):
    """Weighted window aggregation; out-of-range taps are reflected (CUDA only)."""

    def execute(self, input, weight, kernel_size, stride, padding, dilation):
        return _aggregation_forward(self, input, weight, kernel_size, stride, padding, dilation,
                                    'aggregation_refpad_forward_kernel', _aggregation_refpad_forward_header)

    def grad(self, grad_output):
        """Return gradients for ``input`` and ``weight`` (``None`` for the rest)."""
        kernel_size, stride, padding, dilation = self.kernel_size, self.stride, self.padding, self.dilation
        input, weight = self.input_, self.weight_
        batch_size, input_channels, input_height, input_width = input.size()
        weight_channels = weight.size()[1]
        output_height, output_width = grad_output.shape[2:]
        # The input gradient is first computed on the padded extent ...
        grad_input_shape = (batch_size, input_channels, input_height + 2 * padding[0], input_width + 2 * padding[1])
        nthreads_input = _tuple_numel(grad_input_shape)
        nthreads_weight = weight.numel() // weight.shape[2]
        geometry = _geometry_args(input_channels, weight_channels, (input_height, input_width),
                                  (output_height, output_width), kernel_size, stride, padding, dilation)
        input_src = _launch_src('aggregation_refpad_input_backward_kernel', nthreads_input, 'diff', 'weight', geometry)
        weight_src = _launch_src('aggregation_refpad_weight_backward_kernel', nthreads_weight, 'diff', 'input', geometry)
        grad_input = jt.code(grad_input_shape, grad_output.dtype, [grad_output, weight], cuda_header=_aggregation_refpad_input_backward_header, cuda_src=input_src)
        grad_weight = jt.code(weight.size(), grad_output.dtype, [grad_output, input], cuda_header=_aggregation_refpad_weight_backward_header, cuda_src=weight_src)
        # ... then each padded border strip is flipped and added back onto the
        # interior rows/columns it mirrors, and the borders are cropped away.
        grad_input[:, :, padding[0] + 1:2 * padding[0] + 1, :] += jt.flip(grad_input[:, :, :padding[0], :], dim=2)
        grad_input[:, :, input_height - 1:input_height + padding[0] - 1, :] += jt.flip(grad_input[:, :, input_height + padding[0]:, :], dim=2)
        grad_input[:, :, :, padding[1] + 1:2 * padding[1] + 1] += jt.flip(grad_input[:, :, :, :padding[1]], dim=3)
        grad_input[:, :, :, input_width - 1:input_width + padding[1] - 1] += jt.flip(grad_input[:, :, :, input_width + padding[1]:], dim=3)
        grad_input = grad_input[:, :, padding[0]:padding[0]+input_height, padding[1]:padding[1]+input_width]
        return grad_input, grad_weight, None, None, None, None


def aggregation_zeropad(input, weight, kernel_size=3, stride=1, padding=0, dilation=1):
    """Zero-padded aggregation; raises ``NotImplementedError`` without CUDA."""
    assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0)
    if jt.flags.use_cuda != 1:
        raise NotImplementedError
    return AggregationZeropad.apply(input, weight, kernel_size, stride, padding, dilation)


def aggregation_refpad(input, weight, kernel_size=3, stride=1, padding=0, dilation=1):
    """Reflection-padded aggregation; raises ``NotImplementedError`` without CUDA."""
    assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0)
    if jt.flags.use_cuda != 1:
        raise NotImplementedError
    return AggregationRefpad.apply(input, weight, kernel_size, stride, padding, dilation)


def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1):
    """Dispatch to the zero-padded (``pad_mode=0``) or reflection-padded (``1``) op.

    ``input`` and ``weight`` must share the batch size, and the input channel
    count must be a multiple of the weight channel count.
    """
    assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1]
    if jt.flags.use_cuda != 1:
        raise NotImplementedError
    # Returning from each branch avoids an unbound result if the assert above
    # is compiled out and pad_mode is neither 0 nor 1.
    if pad_mode == 0:
        return aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation)
    if pad_mode == 1:
        return aggregation_refpad(input, weight, kernel_size, stride, padding, dilation)
    raise NotImplementedError


def _check_against_unfold(op, in_size, unfold_reference):
    """Compare ``op`` (forward and both gradients) with an unfold-based reference.

    ``unfold_reference(x, kernel_size, dilation, padding, stride)`` must
    return the unfolded windows of ``x`` for the padding mode under test.
    """
    kernel_size, stride, dilation = 5, 4, 2
    padding = (dilation * (kernel_size - 1) + 1) // 2
    n, c_x, c_w, in_height, in_width = 2, 8, 4, in_size, in_size
    out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1)
    out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1)
    x = jt.randn(n, c_x, in_height, in_width, requires_grad=True, dtype=jt.float64)
    w = jt.randn(n, c_w, pow(kernel_size, 2), out_height * out_width, requires_grad=True, dtype=jt.float64)

    y1 = op(x, w, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
    unfold_j = unfold_reference(x, kernel_size, dilation, padding, stride)
    x2 = unfold_j.reshape((n, c_x // c_w, c_w, pow(kernel_size, 2), out_height * out_width))
    y2 = (jt.unsqueeze(w, 1) * x2).sum(dim=-2).reshape((n, c_x, out_height, out_width))
    assert (y1 - y2).abs().max() < 1e-9

    gx1 = jt.grad(y1.mean(), x)[0]
    gx2 = jt.grad(y2.mean(), x)[0]
    assert (gx1 - gx2).abs().max() < 1e-9

    gw1 = jt.grad(y1.mean(), w)[0]
    gw2 = jt.grad(y2.mean(), w)[0]
    assert (gw1 - gw2).abs().max() < 1e-9


def test_aggregation_zeropad():
    """Check the zero-padded op against ``unfold`` with built-in padding."""
    def reference(x, kernel_size, dilation, padding, stride):
        return jt.nn.unfold(x, kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)

    _check_against_unfold(aggregation_zeropad, 9, reference)
    print('aggregation_zeropad passed')


def test_aggregation_refpad():
    """Check the reflection-padded op against ``ReflectionPad2d`` + ``unfold``."""
    def reference(x, kernel_size, dilation, padding, stride):
        pad = jt.nn.ReflectionPad2d(padding)
        return jt.nn.unfold(pad(x), kernel_size=kernel_size, dilation=dilation, padding=0, stride=stride)

    _check_against_unfold(aggregation_refpad, 5, reference)
    print('aggregation_refpad passed')


if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    jt.flags.use_cuda = 1
    print("start...")
    test_aggregation_zeropad()
    test_aggregation_refpad()
    print("done.")
w = index % top_width; + const int h_in_center = -pad_h + h * stride_h + (kernel_h - 1) / 2 * dilation_h; + const int w_in_center = -pad_w + w * stride_w + (kernel_w - 1) / 2 * dilation_w; + const int offset_center = ((n * input_channels + c) * bottom_height + h_in_center) * bottom_width + w_in_center; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int h_in = -pad_h + h * stride_h + kh * dilation_h; + const int w_in = -pad_w + w * stride_w + kw * dilation_w; + const int offset_top = ((n * input_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h * top_width + w; + if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width)) { + const int offset_bottom = ((n * input_channels + c) * bottom_height + h_in) * bottom_width + w_in; + top_data[offset_top] = bottom_data[offset_center] - bottom_data[offset_bottom]; + } + else + top_data[offset_top] = bottom_data[offset_center]; + } + } + } +} +''' + +_subtraction_zeropad_backward_header = _kernel_loop_head + r''' +template +__global__ void subtraction_zeropad_input_backward_kernel( + const int nthreads, const T* const top_diff, T* bottom_diff, const int input_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / input_channels / bottom_height / bottom_width; + const int c = (index / bottom_height / bottom_width) % input_channels; + const int h = (index / bottom_width) % bottom_height; + const int w = index % bottom_width; + T value = 0; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int h_out_s = h + pad_h - kh * dilation_h; + const int w_out_s = w + pad_w - kw * dilation_w; + if (((h_out_s % stride_h) 
== 0) && ((w_out_s % stride_w) == 0)) { + const int h_out = h_out_s / stride_h; + const int w_out = w_out_s / stride_w; + if ((h_out >= 0) && (h_out < top_height) && (w_out >= 0) && (w_out < top_width)) { + const int offset_top = ((n * input_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h_out * top_width + w_out; + value += -top_diff[offset_top]; + } + } + } + } + if (((h % stride_h) == 0) && ((w % stride_w) == 0)) { + const int h_out = h / stride_h; + const int w_out = w / stride_w; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int offset_top = ((n * input_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h_out * top_width + w_out; + value += top_diff[offset_top]; + } + } + } + bottom_diff[index] = value; + } +} +''' + +_subtraction_refpad_forward_header = _kernel_loop_head + r''' +template +__global__ void subtraction_refpad_forward_kernel( + const int nthreads, const T* bottom_data, T* top_data, const int input_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / input_channels / top_height / top_width; + const int c = (index / top_height / top_width) % input_channels; + const int h = (index / top_width) % top_height; + const int w = index % top_width; + const int h_in_center = -pad_h + h * stride_h + (kernel_h - 1) / 2 * dilation_h; + const int w_in_center = -pad_w + w * stride_w + (kernel_w - 1) / 2 * dilation_w; + const int offset_center = ((n * input_channels + c) * bottom_height + h_in_center) * bottom_width + w_in_center; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + int h_in = -pad_h + h * stride_h + kh * dilation_h; 
+ int w_in = -pad_w + w * stride_w + kw * dilation_w; + const int offset_top = ((n * input_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h * top_width + w; + int offset_bottom; + if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width)) { + offset_bottom = ((n * input_channels + c) * bottom_height + h_in) * bottom_width + w_in; + } + else { + if (h_in < 0) h_in = -h_in; + if (h_in >= bottom_height) h_in = 2 * (bottom_height - 1) - h_in; + if (w_in < 0) w_in = -w_in; + if (w_in >= bottom_width) w_in = 2 * (bottom_width - 1) - w_in; + offset_bottom = ((n * input_channels + c) * bottom_height + h_in) * bottom_width + w_in; + } + top_data[offset_top] = bottom_data[offset_center] - bottom_data[offset_bottom]; + } + } + } +} +''' + +_subtraction_refpad_backward_header = _kernel_loop_head + r''' +template +__global__ void subtraction_refpad_input_backward_kernel( + const int nthreads, const T* const top_diff, T* bottom_diff, const int input_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / input_channels / (bottom_height + 2 * pad_h) / (bottom_width + 2 * pad_w); + const int c = (index / (bottom_height + 2 * pad_h) / (bottom_width + 2 * pad_w)) % input_channels; + const int h = (index / (bottom_width + 2 * pad_w)) % (bottom_height + 2 * pad_h); + const int w = index % (bottom_width + 2 * pad_w); + T value = 0; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int h_out_s = h - kh * dilation_h; + const int w_out_s = w - kw * dilation_w; + if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0)) { + const int h_out = h_out_s / stride_h; + const int w_out = w_out_s / stride_w; + if 
((h_out >= 0) && (h_out < top_height) && (w_out >= 0) && (w_out < top_width)) { + const int offset_top = ((n * input_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h_out * top_width + w_out; + value += -top_diff[offset_top]; + } + } + } + } + const int hh = h - pad_h; + const int ww = w - pad_w; + if ((hh >= 0) && (hh < bottom_height) && (ww >= 0) && (ww < bottom_width)) { + if (((hh % stride_h) == 0) && ((ww % stride_w) == 0)) { + const int h_out = hh / stride_h; + const int w_out = ww / stride_w; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int offset_top = ((n * input_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h_out * top_width + w_out; + value += top_diff[offset_top]; + } + } + } + } + bottom_diff[index] = value; + } +} +''' + +_subtraction2_zeropad_forward_header = _kernel_loop_head + r''' +template +__global__ void subtraction2_zeropad_forward_kernel( + int nthreads, const T* bottom1_data, const T* bottom2_data, T* top_data, const int input_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / input_channels / top_height / top_width; + const int c = (index / top_height / top_width) % input_channels; + const int h = (index / top_width) % top_height; + const int w = index % top_width; + const int h_in_center = -pad_h + h * stride_h + (kernel_h - 1) / 2 * dilation_h; + const int w_in_center = -pad_w + w * stride_w + (kernel_w - 1) / 2 * dilation_w; + const int offset_center = ((n * input_channels + c) * bottom_height + h_in_center) * bottom_width + w_in_center; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int 
h_in = -pad_h + h * stride_h + kh * dilation_h; + const int w_in = -pad_w + w * stride_w + kw * dilation_w; + const int offset_top = ((n * input_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h * top_width + w; + if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width)) { + const int offset_bottom = ((n * input_channels + c) * bottom_height + h_in) * bottom_width + w_in; + top_data[offset_top] = bottom1_data[offset_center] - bottom2_data[offset_bottom]; + } + else + top_data[offset_top] = bottom1_data[offset_center]; + } + } + } +} +''' + +_subtraction2_zeropad_input1_backward_header = _kernel_loop_head + r''' +template +__global__ void subtraction2_zeropad_input1_backward_kernel( + const int nthreads, const T* const top_diff, T* bottom_diff, const int input_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / input_channels / bottom_height / bottom_width; + const int c = (index / bottom_height / bottom_width) % input_channels; + const int h = (index / bottom_width) % bottom_height; + const int w = index % bottom_width; + T value = 0; + if (((h % stride_h) == 0) && ((w % stride_w) == 0)) { + const int h_out = h / stride_h; + const int w_out = w / stride_w; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int offset_top = ((n * input_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h_out * top_width + w_out; + value += top_diff[offset_top]; + } + } + } + bottom_diff[index] = value; + } +} +''' + +_subtraction2_zeropad_input2_backward_header = _kernel_loop_head + r''' +template +__global__ void subtraction2_zeropad_input2_backward_kernel( + const 
int nthreads, const T* const top_diff, T* bottom_diff, const int input_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / input_channels / bottom_height / bottom_width; + const int c = (index / bottom_height / bottom_width) % input_channels; + const int h = (index / bottom_width) % bottom_height; + const int w = index % bottom_width; + T value = 0; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int h_out_s = h + pad_h - kh * dilation_h; + const int w_out_s = w + pad_w - kw * dilation_w; + if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0)) { + const int h_out = h_out_s / stride_h; + const int w_out = w_out_s / stride_w; + if ((h_out >= 0) && (h_out < top_height) && (w_out >= 0) && (w_out < top_width)) { + const int offset_top = ((n * input_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h_out * top_width + w_out; + value += -top_diff[offset_top]; + } + } + } + } + bottom_diff[index] = value; + } +} +''' + +_subtraction2_refpad_forward_header = _kernel_loop_head + r''' +template +__global__ void subtraction2_refpad_forward_kernel( + int nthreads, const T* bottom1_data, const T* bottom2_data, T* top_data, const int input_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / input_channels / top_height / top_width; + const int c = (index / top_height / top_width) % input_channels; + const int h = (index / top_width) % 
top_height; + const int w = index % top_width; + const int h_in_center = -pad_h + h * stride_h + (kernel_h - 1) / 2 * dilation_h; + const int w_in_center = -pad_w + w * stride_w + (kernel_w - 1) / 2 * dilation_w; + const int offset_center = ((n * input_channels + c) * bottom_height + h_in_center) * bottom_width + w_in_center; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + int h_in = -pad_h + h * stride_h + kh * dilation_h; + int w_in = -pad_w + w * stride_w + kw * dilation_w; + const int offset_top = ((n * input_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h * top_width + w; + int offset_bottom; + if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width)) { + offset_bottom = ((n * input_channels + c) * bottom_height + h_in) * bottom_width + w_in; + } + else { + if (h_in < 0) h_in = -h_in; + if (h_in >= bottom_height) h_in = 2 * (bottom_height - 1) - h_in; + if (w_in < 0) w_in = -w_in; + if (w_in >= bottom_width) w_in = 2 * (bottom_width - 1) - w_in; + offset_bottom = ((n * input_channels + c) * bottom_height + h_in) * bottom_width + w_in; + } + top_data[offset_top] = bottom1_data[offset_center] - bottom2_data[offset_bottom]; + } + } + } +} +''' + +_subtraction2_refpad_input1_backward_header = _kernel_loop_head + r''' +template +__global__ void subtraction2_refpad_input1_backward_kernel( + const int nthreads, const T* const top_diff, T* bottom_diff, const int input_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / input_channels / bottom_height / bottom_width; + const int c = (index / bottom_height / bottom_width) % input_channels; + const int h = (index / bottom_width) % 
bottom_height; + const int w = index % bottom_width; + T value = 0; + if (((h % stride_h) == 0) && ((w % stride_w) == 0)) { + const int h_out = h / stride_h; + const int w_out = w / stride_w; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int offset_top = ((n * input_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h_out * top_width + w_out; + value += top_diff[offset_top]; + } + } + } + bottom_diff[index] = value; + } +} +''' + +_subtraction2_refpad_input2_backward_header = _kernel_loop_head + r''' +template +__global__ void subtraction2_refpad_input2_backward_kernel( + const int nthreads, const T* const top_diff, T* bottom_diff, const int input_channels, + const int bottom_height, const int bottom_width, + const int top_height, const int top_width, + const int pad_h, const int stride_h, const int kernel_h, const int dilation_h, + const int pad_w, const int stride_w, const int kernel_w, const int dilation_w + ) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / input_channels / (bottom_height + 2 * pad_h) / (bottom_width + 2 * pad_w); + const int c = (index / (bottom_height + 2 * pad_h) / (bottom_width + 2 * pad_w)) % input_channels; + const int h = (index / (bottom_width + 2 * pad_w)) % (bottom_height + 2 * pad_h); + const int w = index % (bottom_width + 2 * pad_w); + T value = 0; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int h_out_s = h - kh * dilation_h; + const int w_out_s = w - kw * dilation_w; + if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0)) { + const int h_out = h_out_s / stride_h; + const int w_out = w_out_s / stride_w; + if ((h_out >= 0) && (h_out < top_height) && (w_out >= 0) && (w_out < top_width)) { + const int offset_top = ((n * input_channels + c) * kernel_h * kernel_w + (kh * kernel_w + kw)) * top_height * top_width + h_out * top_width + w_out; + value += -top_diff[offset_top]; + } + } + } 
+ } + bottom_diff[index] = value; + } +} +''' + + +# classes +# TODO: inplace codes like f'''...''' into r'''...''' +# TODO: check backward. +def _tuple_numel(shape): + return shape[0] * shape[1] * shape[2] * shape[3] + +class SubtractionZeropad(jt.Function): + def execute(self, input, kernel_size, stride, padding, dilation): + kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) + self.kernel_size, self.stride, self.padding, self.dilation = kernel_size, stride, padding, dilation + self.input = input + assert len(input.shape) == 4 and jt.flags.use_cuda + batch_size, input_channels, input_height, input_width = input.size() + output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) + output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) + output_shape = (batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) + nthreads = batch_size * input_channels * output_height * output_width + subtraction_zeropad_src = f''' + @alias(input,in0); + @alias(output,out0); + subtraction_zeropad_forward_kernel<<>>( + {nthreads}, input_p, output_p, {input_channels}, + {input_height}, {input_width}, + {output_height}, {output_width}, + {padding[0]}, {stride[0]}, {kernel_size[0]}, {dilation[0]}, + {padding[1]}, {stride[1]}, {kernel_size[1]}, {dilation[1]} + ); + ''' + return jt.code(output_shape, input.dtype, [input], cuda_header=_subtraction_zeropad_forward_header, cuda_src=subtraction_zeropad_src) + + def grad(self, grad_output): + kernel_size, stride, padding, dilation = self.kernel_size, self.stride, self.padding, self.dilation + input = self.input + assert jt.flags.use_cuda + batch_size, input_channels, input_height, input_width = input.size() + output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) + output_width = 
int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) + nthreads = input.numel() + subtraction_zeropad_backward_src = f''' + @alias(input,in0); + @alias(output,out0); + subtraction_zeropad_input_backward_kernel<<>>( + {nthreads}, input_p, output_p, {input_channels}, + {input_height}, {input_width}, + {output_height}, {output_width}, + {padding[0]}, {stride[0]}, {kernel_size[0]}, {dilation[0]}, + {padding[1]}, {stride[1]}, {kernel_size[1]}, {dilation[1]} + ); + ''' + grad_input = jt.code(input.size(), grad_output.dtype, [grad_output], cuda_header=_subtraction_zeropad_backward_header, cuda_src=subtraction_zeropad_backward_src) + return grad_input, None, None, None, None + +class SubtractionRefpad(jt.Function): + def execute(self, input, kernel_size, stride, padding, dilation): + kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) + self.kernel_size, self.stride, self.padding, self.dilation = kernel_size, stride, padding, dilation + self.input = input + assert len(input.shape) == 4 and jt.flags.use_cuda + batch_size, input_channels, input_height, input_width = input.size() + output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) + output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) + output_shape = (batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) + nthreads = batch_size * input_channels * output_height * output_width + subtraction_refpad_src = f''' + @alias(input,in0); + @alias(output,out0); + subtraction_refpad_forward_kernel<<>>( + {nthreads}, input_p, output_p, {input_channels}, + {input_height}, {input_width}, + {output_height}, {output_width}, + {padding[0]}, {stride[0]}, {kernel_size[0]}, {dilation[0]}, + {padding[1]}, {stride[1]}, {kernel_size[1]}, {dilation[1]} + ); + ''' + return jt.code(output_shape, 
input.dtype, [input], cuda_header=_subtraction_refpad_forward_header, cuda_src=subtraction_refpad_src) + + def grad(self, grad_output): + kernel_size, stride, padding, dilation = self.kernel_size, self.stride, self.padding, self.dilation + input = self.input + assert jt.flags.use_cuda + batch_size, input_channels, input_height, input_width = input.size() + output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) + output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) + grad_shape = (batch_size, input_channels, input_height + 2 * padding[0], input_width + 2 * padding[1]) + nthreads = _tuple_numel(grad_shape) + subtraction_refpad_backward_src = f''' + @alias(input,in0); + @alias(output,out0); + subtraction_refpad_input_backward_kernel<<>>( + {nthreads}, input_p, output_p, {input_channels}, + {input_height}, {input_width}, + {output_height}, {output_width}, + {padding[0]}, {stride[0]}, {kernel_size[0]}, {dilation[0]}, + {padding[1]}, {stride[1]}, {kernel_size[1]}, {dilation[1]} + ); + ''' + grad_input = jt.code(grad_shape, grad_output.dtype, [grad_output], cuda_header=_subtraction_refpad_backward_header, cuda_src=subtraction_refpad_backward_src) + grad_input[:, :, padding[0] + 1:2 * padding[0] + 1, :] += jt.flip(grad_input[:, :, :padding[0], :], dim=2) + grad_input[:, :, input_height - 1:input_height + padding[0] - 1, :] += jt.flip(grad_input[:, :, input_height + padding[0]:, :], dim=2) + grad_input[:, :, :, padding[1] + 1:2 * padding[1] + 1] += jt.flip(grad_input[:, :, :, :padding[1]], dim=3) + grad_input[:, :, :, input_width - 1:input_width + padding[1] - 1] += jt.flip(grad_input[:, :, :, input_width + padding[1]:], dim=3) + grad_input = grad_input[:, :, padding[0]:padding[0] + input_height, padding[1]:padding[1] + input_width] + return grad_input, None, None, None, None + + + +class Subtraction2Zeropad(jt.Function): + def execute(self, input1, input2, 
kernel_size, stride, padding, dilation): + kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) + self.kernel_size, self.stride, self.padding, self.dilation = kernel_size, stride, padding, dilation + self.input1, self.input2 = input1, input2 + assert len(input1.shape) == 4 and jt.flags.use_cuda + assert input1.size() == input2.size() + batch_size, input_channels, input_height, input_width = input1.size() + output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) + output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) + output_shape = (batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) + nthreads = batch_size * input_channels * output_height * output_width + subtraction2_zeropad_src = f''' + @alias(input1,in0); + @alias(input2,in1); + @alias(output,out0); + subtraction2_zeropad_forward_kernel<<>>( + {nthreads}, input1_p, input2_p, output_p, {input_channels}, + {input_height}, {input_width}, + {output_height}, {output_width}, + {padding[0]}, {stride[0]}, {kernel_size[0]}, {dilation[0]}, + {padding[1]}, {stride[1]}, {kernel_size[1]}, {dilation[1]} + ); + ''' + return jt.code(output_shape, input1.dtype, [input1, input2], cuda_header=_subtraction2_zeropad_forward_header, cuda_src=subtraction2_zeropad_src) + + def grad(self, grad_output): + kernel_size, stride, padding, dilation = self.kernel_size, self.stride, self.padding, self.dilation + input1, input2 = self.input1, self.input2 + assert jt.flags.use_cuda + batch_size, input_channels, input_height, input_width = input1.size() + output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) + output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) + nthreads1, nthreads2 = input1.numel(), input2.numel() + 
subtraction2_zeropad_input1_backward_src = f''' + @alias(input,in0); + @alias(output,out0); + subtraction2_zeropad_input1_backward_kernel<<>>( + {nthreads1}, input_p, output_p, {input_channels}, + {input_height}, {input_width}, + {output_height}, {output_width}, + {padding[0]}, {stride[0]}, {kernel_size[0]}, {dilation[0]}, + {padding[1]}, {stride[1]}, {kernel_size[1]}, {dilation[1]} + ); + ''' + subtraction2_zeropad_input2_backward_src = f''' + @alias(input,in0); + @alias(output,out0); + subtraction2_zeropad_input2_backward_kernel<<>>( + {nthreads2}, input_p, output_p, {input_channels}, + {input_height}, {input_width}, + {output_height}, {output_width}, + {padding[0]}, {stride[0]}, {kernel_size[0]}, {dilation[0]}, + {padding[1]}, {stride[1]}, {kernel_size[1]}, {dilation[1]} + ); + ''' + grad_input1 = jt.code(input1.size(), grad_output.dtype, [grad_output], cuda_header=_subtraction2_zeropad_input1_backward_header, cuda_src=subtraction2_zeropad_input1_backward_src) + grad_input2 = jt.code(input2.size(), grad_output.dtype, [grad_output], cuda_header=_subtraction2_zeropad_input2_backward_header, cuda_src=subtraction2_zeropad_input2_backward_src) + return grad_input1, grad_input2, None, None, None, None + +class Subtraction2Refpad(jt.Function): + def execute(self, input1, input2, kernel_size, stride, padding, dilation): + kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) + self.kernel_size, self.stride, self.padding, self.dilation = kernel_size, stride, padding, dilation + self.input1, self.input2 = input1, input2 + assert len(input1.shape) == 4 and jt.flags.use_cuda + assert input1.size() == input2.size() + batch_size, input_channels, input_height, input_width = input1.size() + output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) + output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) + output_shape 
= (batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) + nthreads = batch_size * input_channels * output_height * output_width + subtraction2_refpad_src = f''' + @alias(input1,in0); + @alias(input2,in1); + @alias(output,out0); + subtraction2_refpad_forward_kernel<<>>( + {nthreads}, input1_p, input2_p, output_p, {input_channels}, + {input_height}, {input_width}, + {output_height}, {output_width}, + {padding[0]}, {stride[0]}, {kernel_size[0]}, {dilation[0]}, + {padding[1]}, {stride[1]}, {kernel_size[1]}, {dilation[1]} + ); + ''' + return jt.code(output_shape, input1.dtype, [input1, input2], cuda_header=_subtraction2_refpad_forward_header, cuda_src=subtraction2_refpad_src) + + def grad(self, grad_output): + kernel_size, stride, padding, dilation = self.kernel_size, self.stride, self.padding, self.dilation + input1, input2 = self.input1, self.input2 + assert jt.flags.use_cuda + batch_size, input_channels, input_height, input_width = input1.size() + output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) + output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) + grad_shape2 = (batch_size, input_channels, input_height + 2 * padding[0], input_width + 2 * padding[1]) + nthreads2 = _tuple_numel(grad_shape2) + nthreads1 = input1.numel() + subtraction2_refpad_input1_backward_src = f''' + @alias(input,in0); + @alias(output,out0); + subtraction2_refpad_input1_backward_kernel<<>>( + {nthreads1}, input_p, output_p, {input_channels}, + {input_height}, {input_width}, + {output_height}, {output_width}, + {padding[0]}, {stride[0]}, {kernel_size[0]}, {dilation[0]}, + {padding[1]}, {stride[1]}, {kernel_size[1]}, {dilation[1]} + ); + ''' + subtraction2_refpad_input2_backward_src = f''' + @alias(input,in0); + @alias(output,out0); + subtraction2_refpad_input2_backward_kernel<<>>( + {nthreads2}, input_p, output_p, {input_channels}, + 
{input_height}, {input_width}, + {output_height}, {output_width}, + {padding[0]}, {stride[0]}, {kernel_size[0]}, {dilation[0]}, + {padding[1]}, {stride[1]}, {kernel_size[1]}, {dilation[1]} + ); + ''' + grad_input1 = jt.code(input1.size(), grad_output.dtype, [grad_output], cuda_header=_subtraction2_refpad_input1_backward_header, cuda_src=subtraction2_refpad_input1_backward_src) + grad_input2 = jt.code(grad_shape2, grad_output.dtype, [grad_output], cuda_header=_subtraction2_refpad_input2_backward_header, cuda_src=subtraction2_refpad_input2_backward_src) + grad_input2[:, :, padding[0] + 1:2 * padding[0] + 1, :] += jt.flip(grad_input2[:, :, :padding[0], :], dim=2) + grad_input2[:, :, input_height - 1:input_height + padding[0] - 1, :] += jt.flip(grad_input2[:, :, input_height + padding[0]:, :], dim=2) + grad_input2[:, :, :, padding[1] + 1:2 * padding[1] + 1] += jt.flip(grad_input2[:, :, :, :padding[1]], dim=3) + grad_input2[:, :, :, input_width - 1:input_width + padding[1] - 1] += jt.flip(grad_input2[:, :, :, input_width + padding[1]:], dim=3) + grad_input2 = grad_input2[:, :, padding[0]:padding[0] + input_height, padding[1]:padding[1] + input_width] + return grad_input1, grad_input2, None, None, None, None + + +# functions +def subtraction_zeropad(input, kernel_size=3, stride=1, padding=0, dilation=1): + assert len(input.size()) == 4 + if jt.flags.use_cuda == 1: + out = SubtractionZeropad.apply(input, kernel_size, stride, padding, dilation) + else: + raise NotImplementedError + return out + +def subtraction_refpad(input, kernel_size=3, stride=1, padding=0, dilation=1): + assert len(input.size()) == 4 + if jt.flags.use_cuda == 1: + out = SubtractionRefpad.apply(input, kernel_size, stride, padding, dilation) + else: + raise NotImplementedError + return out + +def subtraction2_zeropad(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1): + assert len(input1.size()) == 4 + if jt.flags.use_cuda == 1: + out = Subtraction2Zeropad.apply(input1, input2, kernel_size, 
stride, padding, dilation) + else: + raise NotImplementedError + return out + +def subtraction2_refpad(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1): + assert len(input1.size()) == 4 + if jt.flags.use_cuda == 1: + out = Subtraction2Refpad.apply(input1, input2, kernel_size, stride, padding, dilation) + else: + raise NotImplementedError + return out + +def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): + assert len(input.size()) == 4 and pad_mode in [0, 1] + if jt.flags.use_cuda == 1: + if pad_mode == 0: + out = subtraction_zeropad(input, kernel_size, stride, padding, dilation) + elif pad_mode == 1: + out = subtraction_refpad(input, kernel_size, stride, padding, dilation) + else: + raise NotImplementedError + return out + + +def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): + assert len(input1.size()) == 4 and len(input2.size()) == 4 and pad_mode in [0, 1] + if jt.flags.use_cuda == 1: + if pad_mode == 0: + out = subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation) + elif pad_mode == 1: + out = subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation) + else: + raise NotImplementedError + return out + +# unit tests +def test_subtraction_zeropad(): + kernel_size, stride, dilation = 5, 4, 2 + padding = (dilation * (kernel_size - 1) + 1) // 2 + n, c, in_height, in_width = 2, 8, 5, 5 + out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) + out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) + x = jt.randn(n, c, in_height, in_width, requires_grad=True, dtype=jt.float64) + + y1 = subtraction_zeropad(x, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) + unfold_i = jt.nn.unfold(x, kernel_size=1, dilation=dilation, padding=0, stride=stride) + unfold_j = jt.nn.unfold(x, kernel_size=kernel_size, dilation=dilation, padding=padding, 
stride=stride) + y2 = unfold_i.reshape((n, c, 1, out_height * out_width)) - unfold_j.reshape((n, c, pow(kernel_size, 2), out_height * out_width)) + assert (y1 - y2).abs().max() < 1e-9 + + gx1 = jt.grad(y1.mean(), x)[0] + gx2 = jt.grad(y2.mean(), x)[0] + assert (gx1 - gx2).abs().max() < 1e-9 + + print('subtraction_zeropad passed') + +def test_subtraction_refpad(): + kernel_size, stride, dilation = 5, 4, 2 + padding = (dilation * (kernel_size - 1) + 1) // 2 + n, c, in_height, in_width = 2, 8, 5, 5 + out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) + out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) + x = jt.randn(n, c, in_height, in_width, requires_grad=True, dtype=jt.float64) + + y1 = subtraction_refpad(x, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) + pad = jt.nn.ReflectionPad2d(padding) + unfold_i = jt.nn.unfold(x, kernel_size=1, dilation=dilation, padding=0, stride=stride) + unfold_j = jt.nn.unfold(pad(x), kernel_size=kernel_size, dilation=dilation, padding=0, stride=stride) + y2 = unfold_i.reshape((n, c, 1, out_height * out_width)) - unfold_j.reshape((n, c, pow(kernel_size, 2), out_height * out_width)) + assert (y1 - y2).abs().max() < 1e-9 + + gx1 = jt.grad(y1.mean(), x)[0] + gx2 = jt.grad(y2.mean(), x)[0] + assert (gx1 - gx2).abs().max() < 1e-9 + + print('subtraction_refpad passed') + +def test_subtraction2_zeropad(): + kernel_size, stride, dilation = 5, 4, 2 + padding = (dilation * (kernel_size - 1) + 1) // 2 + n, c, in_height, in_width = 2, 8, 9, 9 + out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) + out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) + x1 = jt.randn(n, c, in_height, in_width, requires_grad=True, dtype=jt.float64) + x2 = jt.randn(n, c, in_height, in_width, requires_grad=True, dtype=jt.float64) + + y1 = subtraction2_zeropad(x1, x2, 
kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) + unfold_i = jt.nn.unfold(x1, kernel_size=1, dilation=dilation, padding=0, stride=stride) + unfold_j = jt.nn.unfold(x2, kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) + y2 = unfold_i.reshape((n, c, 1, out_height * out_width)) - unfold_j.reshape((n, c, pow(kernel_size, 2), out_height * out_width)) + assert (y1 - y2).abs().max() < 1e-9 + + gx11 = jt.grad(y1.mean(), x1)[0] + gx12 = jt.grad(y1.mean(), x2)[0] + gx21 = jt.grad(y2.mean(), x1)[0] + gx22 = jt.grad(y2.mean(), x2)[0] + assert (gx11 - gx21).abs().max() < 1e-9 + assert (gx12 - gx22).abs().max() < 1e-9 + + print('subtraction2_zeropad passed') + +def test_subtraction2_refpad(): + kernel_size, stride, dilation = 5, 4, 2 # 3, 1, 1 + padding = (dilation * (kernel_size - 1) + 1) // 2 + n, c, in_height, in_width = 2, 8, 9, 9 + out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) + out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) + x1 = jt.randn(n, c, in_height, in_width, requires_grad=True, dtype=jt.float64) + x2 = jt.randn(n, c, in_height, in_width, requires_grad=True, dtype=jt.float64) + + y1 = subtraction2_refpad(x1, x2, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) + pad = jt.nn.ReflectionPad2d(padding) + unfold_i = jt.nn.unfold(x1, kernel_size=1, dilation=dilation, padding=0, stride=stride) + unfold_j = jt.nn.unfold(pad(x2), kernel_size=kernel_size, dilation=dilation, padding=0, stride=stride) + y2 = unfold_i.reshape((n, c, 1, out_height * out_width)) - unfold_j.reshape((n, c, pow(kernel_size, 2), out_height * out_width)) + assert (y1 - y2).abs().max() < 1e-9 + + gx11 = jt.grad(y1.mean(), x1)[0] + gx12 = jt.grad(y1.mean(), x2)[0] + gx21 = jt.grad(y2.mean(), x1)[0] + gx22 = jt.grad(y2.mean(), x2)[0] + assert (gx11 - gx21).abs().max() < 1e-9 + assert (gx12 - gx22).abs().max() < 1e-9 + + 
print('subtraction2_refpad passed') + + +if __name__ == '__main__': + os.environ["CUDA_VISIBLE_DEVICES"] = '0' + jt.flags.use_cuda = 1 + print("start...") + test_subtraction_zeropad() + test_subtraction_refpad() + test_subtraction2_zeropad() + test_subtraction2_refpad() + print("done.") \ No newline at end of file From 42af857b3e8626ab92dd4e710b8ed40fc565a47e Mon Sep 17 00:00:00 2001 From: 514flowey <1114811901@qq.com> Date: Thu, 10 Mar 2022 19:40:25 +0800 Subject: [PATCH 05/10] add forward test --- projects/san/san_forward_test.py | 96 ++++++++++++++++++++++++++++++ python/jdet/models/networks/san.py | 3 +- 2 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 projects/san/san_forward_test.py diff --git a/projects/san/san_forward_test.py b/projects/san/san_forward_test.py new file mode 100644 index 00000000..90761561 --- /dev/null +++ b/projects/san/san_forward_test.py @@ -0,0 +1,96 @@ +import jittor as jt +jt.set_global_seed(0) +from jdet.config import init_cfg, get_cfg +from jdet.utils.general import parse_losses +from jdet.utils.registry import build_from_cfg,MODELS,DATASETS,OPTIMS +import numpy as np +import random +import jdet +import argparse +import os +import pickle as pk + +import numpy as np + +def main(): + parser = argparse.ArgumentParser(description="Jittor Object Detection Training") + parser.add_argument( + "--set_data", + action='store_true' + ) + args = parser.parse_args() + + jt.flags.use_cuda=1 + jt.set_global_seed(0) + np.random.seed(0) + random.seed(0) + init_cfg("/home/flowey/remote2/JDet/configs/san10_pairwise.py") + cfg = get_cfg() + + model = build_from_cfg(cfg.model,MODELS) + numpy_save_dir = '/home/flowey/remote2/JDet/projects/san/test_datas_san/models_numpy.pkl' + numpy_dict = pk.load(open(numpy_save_dir, 'rb')) + jittor_dict = dict() + for k, v in numpy_dict.items(): + jittor_dict[k] = jt.array(v) + model.load_state_dict(jittor_dict) + optimizer = build_from_cfg(cfg.optimizer,OPTIMS,params=model.parameters()) + + 
model.eval() + if (args.set_data): + + imagess = [] + targetss = [] + correct_loss = [] + + train_dataset = build_from_cfg(cfg.dataset.train,DATASETS) + for batch_idx,(images,targets) in enumerate(train_dataset): + print(batch_idx) + if (batch_idx > 10): + break + imagess.append(jdet.utils.general.sync(images)) + targetss.append(jdet.utils.general.sync(targets)) + losses = model(images,targets) + #all_loss, losses = parse_losses(losses) + #jt.sync_all(True) + #correct_loss.append(all_loss.item()) + correct_loss.append(losses.numpy()) + #optimizer.step(all_loss) + data = { + "imagess": imagess, + "targetss": targetss, + "correct_loss": correct_loss, + } + pk.dump(data, open("test_datas_san/test_data_jittor.pkl", "wb")) + print(correct_loss) + correct_loss = [jdet.utils.general.sync(i) for i in correct_loss] + data_numpy = { + "imagess": imagess, + "targetss": targetss, + "correct_loss": correct_loss, + } + pk.dump(data_numpy, open("test_datas_san/test_data_numpy.pkl", "wb")) + else: + data = pk.load(open("test_datas_faster_rcnn/test_data.pk", "rb")) + imagess = jdet.utils.general.to_jt_var(data["imagess"]) + targetss = jdet.utils.general.to_jt_var(data["targetss"]) + correct_loss = data["correct_loss"] + thr = 0.5 #TODO: fix thr + for batch_idx in range(len(imagess)): + images = imagess[batch_idx] + targets = targetss[batch_idx] + + losses = model(images,targets) + all_loss, losses = parse_losses(losses) + jt.sync_all(True) + l = all_loss.item() + optimizer.step(all_loss) + c_l = correct_loss[batch_idx] + err_rate = float(abs(c_l - l)/np.minimum(c_l, l)) + print(f"correct loss is {float(c_l):.4f}, runtime loss is {float(l):.4f}, err rate is {err_rate*100:.2f}%") + assert err_rate Date: Wed, 16 Mar 2022 17:22:46 +0800 Subject: [PATCH 06/10] update forward --- projects/san/san_backward_test.py | 127 +++++++++++++++++++++++++++ projects/san/san_forward_test.py | 16 ++-- projects/san/san_tot_forward_test.py | 71 +++++++++++++++ 3 files changed, 206 insertions(+), 8 
deletions(-) create mode 100644 projects/san/san_backward_test.py create mode 100644 projects/san/san_tot_forward_test.py diff --git a/projects/san/san_backward_test.py b/projects/san/san_backward_test.py new file mode 100644 index 00000000..85083228 --- /dev/null +++ b/projects/san/san_backward_test.py @@ -0,0 +1,127 @@ +import jittor as jt +jt.set_global_seed(0) +from jdet.config import init_cfg, get_cfg +from jdet.utils.general import parse_losses +from jdet.utils.registry import build_from_cfg,MODELS,DATASETS,OPTIMS +import numpy as np +import random +import jdet +import argparse +import os +import pickle as pk + +import numpy as np + +def get_deep_attr(obj, name): + sls = name.split('.') + m = obj + for s in sls: + if s in ['0', '1', '2', '3', '4', '5']: + m = m[int(s)] + elif hasattr(m, s): + m = getattr(m, s) + else: + m = None + break + return m + + +def main(): + parser = argparse.ArgumentParser(description="Jittor Object Detection Training") + parser.add_argument( + "--set_data", + action='store_true' + ) + args = parser.parse_args() + + jt.flags.use_cuda=1 + jt.set_global_seed(0) + np.random.seed(0) + random.seed(0) + init_cfg("/home/flowey/remote2/JDet/configs/san10_pairwise.py") + cfg = get_cfg() + + model = build_from_cfg(cfg.model,MODELS) + numpy_save_dir = '/home/flowey/remote2/JDet/projects/san/test_datas_san/models_numpy.pkl' + numpy_dict = pk.load(open(numpy_save_dir, 'rb')) + jittor_dict = dict() + for k, v in numpy_dict.items(): + jittor_dict[k] = jt.array(v) + model.load_state_dict(jittor_dict) + optimizer = build_from_cfg(cfg.optimizer,OPTIMS,params=model.parameters()) + + model.train() + if (args.set_data): + + imagess = [] + targetss = [] + correct_loss = [] + + train_dataset = build_from_cfg(cfg.dataset.train,DATASETS) + for batch_idx,(images,targets) in enumerate(train_dataset): + print(batch_idx) + if (batch_idx > 10): + break + imagess.append(jdet.utils.general.sync(images)) + targetss.append(jdet.utils.general.sync(targets)) + losses = 
model(images,targets) + all_loss, losses = parse_losses(losses) + jt.sync_all(True) + correct_loss.append(all_loss.item()) + optimizer.step(all_loss) + data = { + "imagess": imagess, + "targetss": targetss, + "correct_loss": correct_loss, + } + pk.dump(data, open("test_datas_san/test_data_jittor.pkl", "wb")) + print(correct_loss) + correct_loss = [jdet.utils.general.sync(i) for i in correct_loss] + data_numpy = { + "imagess": imagess, + "targetss": targetss, + "correct_loss": correct_loss, + } + pk.dump(data_numpy, open("test_datas_san/test_data_numpy.pkl", "wb")) + else: + data = pk.load(open("test_datas_san/test_data_numpy.pkl", "rb")) + imagess = jdet.utils.general.to_jt_var(data["imagess"]) + targetss = jdet.utils.general.to_jt_var(data["targetss"]) + correct_loss = data["correct_loss"] + thr = 0.5 #TODO: fix thr + for batch_idx in range(len(imagess)): + images = imagess[batch_idx] + targets = targetss[batch_idx] + + losses = model(images,targets) + loss = losses['loss'] + grads = pk.load(open('test_datas_san/data_grad.pkl', 'rb')) + loss_value = grads.pop('loss_value') + print(loss.numpy() - loss_value) + max_diff, max_name = .0, 'None' + for name, value in grads.items(): + w = get_deep_attr(model, name) + g = jt.grad(loss, w).numpy() + d = np.max(np.abs((g - value)))/np.max(np.abs(value)) + print(name, d) + if d > max_diff: + max_diff = d + max_name = name + print(max_name, max_diff) + + return + all_loss, losses = parse_losses(losses) + jt.sync_all(True) + l = all_loss.item() + optimizer.step(all_loss) + state_dict2_path = '/home/flowey/remote2/JDet/projects/san/test_datas_san/models_numpy2.pth' + torch_dict = pk.load(open(state_dict2_path, 'rb')) + for k, v in model.state_dict().items(): + v = v.numpy() + print(k, np.max(np.abs(v - torch_dict[k]))) + return + #print(f"Loss is correct with err_rate<{thr}") + print("success!") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/projects/san/san_forward_test.py 
b/projects/san/san_forward_test.py index 90761561..7a20ecb1 100644 --- a/projects/san/san_forward_test.py +++ b/projects/san/san_forward_test.py @@ -36,7 +36,7 @@ def main(): model.load_state_dict(jittor_dict) optimizer = build_from_cfg(cfg.optimizer,OPTIMS,params=model.parameters()) - model.eval() + model.train() if (args.set_data): imagess = [] @@ -51,11 +51,10 @@ def main(): imagess.append(jdet.utils.general.sync(images)) targetss.append(jdet.utils.general.sync(targets)) losses = model(images,targets) - #all_loss, losses = parse_losses(losses) - #jt.sync_all(True) - #correct_loss.append(all_loss.item()) - correct_loss.append(losses.numpy()) - #optimizer.step(all_loss) + all_loss, losses = parse_losses(losses) + jt.sync_all(True) + correct_loss.append(all_loss.item()) + optimizer.step(all_loss) data = { "imagess": imagess, "targetss": targetss, @@ -71,7 +70,7 @@ def main(): } pk.dump(data_numpy, open("test_datas_san/test_data_numpy.pkl", "wb")) else: - data = pk.load(open("test_datas_faster_rcnn/test_data.pk", "rb")) + data = pk.load(open("test_datas_san/test_data_numpy.pkl", "rb")) imagess = jdet.utils.general.to_jt_var(data["imagess"]) targetss = jdet.utils.general.to_jt_var(data["targetss"]) correct_loss = data["correct_loss"] @@ -84,12 +83,13 @@ def main(): all_loss, losses = parse_losses(losses) jt.sync_all(True) l = all_loss.item() + print(l) optimizer.step(all_loss) c_l = correct_loss[batch_idx] err_rate = float(abs(c_l - l)/np.minimum(c_l, l)) print(f"correct loss is {float(c_l):.4f}, runtime loss is {float(l):.4f}, err rate is {err_rate*100:.2f}%") assert err_rate Date: Wed, 16 Mar 2022 18:07:32 +0800 Subject: [PATCH 07/10] keep update --- configs/san10_pairwise.py | 17 ++++++----- python/jdet/data/ILSVRC2012.py | 55 ++++++++++++++++++++++++++++++++++ python/jdet/data/transforms.py | 15 ++++++++++ 3 files changed, 80 insertions(+), 7 deletions(-) diff --git a/configs/san10_pairwise.py b/configs/san10_pairwise.py index 1ccd98a4..e17031ad 100644 --- 
a/configs/san10_pairwise.py +++ b/configs/san10_pairwise.py @@ -33,24 +33,27 @@ type = "Normalize", ## unknown normalize mean = [123.675, 116.28, 103.53], std = [58.395, 57.12, 57.375], - to_bgr=True), + to_bgr=False), ], batch_size=2, ), val=dict( type=dataset_type, - batch_size=128, - images_dir='/home/flowey/dataset/ILSVRC2012/train/', + batch_size=100, + images_dir='/home/flowey/dataset/ILSVRC2012/val/', transforms=[ dict(type = "Resize", - min_size = 224, - max_size = 224, + min_size = 256, + max_size = None, + ), + dict(type = "CenterCropJt", + size = 224, ), dict( type = "Normalize", mean = [123.675, 116.28, 103.53], std = [58.395, 57.12, 57.375], - to_bgr=True), + to_bgr=False), ], ), test=dict( @@ -65,7 +68,7 @@ type = "Normalize", mean = [123.675, 116.28, 103.53], std = [58.395, 57.12, 57.375], - to_bgr=True), + to_bgr=False), ], ) ) diff --git a/python/jdet/data/ILSVRC2012.py b/python/jdet/data/ILSVRC2012.py index b7d0d131..678fe3e7 100644 --- a/python/jdet/data/ILSVRC2012.py +++ b/python/jdet/data/ILSVRC2012.py @@ -3,13 +3,34 @@ import os from jdet.utils.registry import DATASETS +from jdet.utils.general import check_dir, to_jt_var from .transforms import Compose +from tqdm import tqdm import jittor as jt import os from jittor.dataset import Dataset import jdet +def cal_topk_accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k + + Parameters: + output: jt.Var [N, K] + target: jt.Var [K] + """ + with jt.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = jt.equal(pred, target.view(1, -1).expand_as(pred)) + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdims=True) + res.append(correct_k.sum().item() * (100.0 / batch_size)) + return res + @DATASETS.register_module() class ILSVRCDataset(Dataset): """ ILSVRCDataset @@ -87,3 +108,37 @@ def __getitem__(self,index): if 
self.transforms: img,targets = self.transforms(img,targets) return img,targets + + def evaluate(self,results,work_dir,epoch,logger=None, save=True): + print("Calculating mAP......") + if save: + save_path = os.path.join(work_dir,f"detections/val_{epoch}") + check_dir(save_path) + jt.save(results,save_path+"/val.pkl") + sum_top1, sum_top5, count = 0, 0, 0 + num_classes = len(self.classes) + sum_intersection, sum_output, sum_target = jt.zeros((num_classes)), jt.zeros((num_classes)), jt.zeros((num_classes)) + for img_idx,(result,target) in tqdm(enumerate(results)): + target = jt.array([target['img_label']]) + result = to_jt_var(result).unsqueeze(0) + top1, top5 = cal_topk_accuracy(result, target, topk=(1,5)) + + output, confidence = jt.argmax(result, dim=1, keepdims=False) + intersection = output[output == target] + sum_intersection = jt.scatter(sum_intersection, 0, intersection, jt.array([1]), reduce='add') + sum_output = jt.scatter(sum_output, 0, output, jt.array([1]), reduce='add') + sum_target = jt.scatter(sum_target, 0, target, jt.array([1]), reduce='add') + + sum_top1 = sum_top1 + top1 + sum_top5 = sum_top5 + top5 + count = count + 1 + iou_classes = sum_intersection / (sum_output + sum_target - sum_intersection + 1e-10) + accuracy_class = sum_intersection / (sum_target + 1e-10) + aps = dict( + mIoU = jt.mean(iou_classes, 0).item(), + mAcc = jt.mean(accuracy_class, 0).item(), + allAcc = sum_intersection.sum() / (sum_target.sum() + 1e-10), + mtop1 = sum_top1 / count, + mtop5 = sum_top5 / count, + ) + return aps diff --git a/python/jdet/data/transforms.py b/python/jdet/data/transforms.py index 967e089c..3e748949 100644 --- a/python/jdet/data/transforms.py +++ b/python/jdet/data/transforms.py @@ -9,6 +9,7 @@ from jdet.models.boxes.box_ops import rotated_box_to_poly_np,poly_to_rotated_box_np,norm_angle from jdet.models.boxes.iou_calculator import bbox_overlaps_np from numpy import random as nprandom +from jittor.transform import CenterCrop 
@TRANSFORMS.register_module() class Compose: @@ -152,6 +153,20 @@ def __call__(self, image, target=None): target["keep_ratio"] = self.keep_ratio return image, target +# Warning: DO NOT USE THIS +@TRANSFORMS.register_module() +class CenterCropJt: + def __init__(self, size): + if not isinstance(size, (list, tuple)): + size = (size,) + self.transformer = CenterCrop(size) + + def __call__(self, image, target=None): + image = self.transformer(image) + if target is not None: + target["img_size"] = image.size + return image, target + @TRANSFORMS.register_module() class MinIoURandomCrop: def __init__(self, From fdacdd47fd7842b64359ca9dd71f510b9aaeaa6d Mon Sep 17 00:00:00 2001 From: 514flowey <1114811901@qq.com> Date: Sat, 19 Mar 2022 21:43:56 +0800 Subject: [PATCH 08/10] tot forward + backward --- projects/san/configs/san10_pairwise.py | 94 ++++++++++++++++++++++++++ projects/san/san_backward_test.py | 84 ++++++++++++++++------- projects/san/san_forward_test.py | 10 +-- projects/san/san_tot_forward_test.py | 46 ++++++++++--- 4 files changed, 198 insertions(+), 36 deletions(-) create mode 100644 projects/san/configs/san10_pairwise.py diff --git a/projects/san/configs/san10_pairwise.py b/projects/san/configs/san10_pairwise.py new file mode 100644 index 00000000..def585db --- /dev/null +++ b/projects/san/configs/san10_pairwise.py @@ -0,0 +1,94 @@ +model = dict( + type='SAN', + sa_type=0, + layers=[2, 1, 2, 4, 1], + kernels=[3, 7, 7, 7, 7], + num_classes=1000, + loss=dict( + type='SAMSmoothLoss', + eps=0.1 + ), + loss_prepare=False +) + +# dataset settings +dataset_type = 'ILSVRCDataset' +dataset = dict( + imgs_per_gpu=2, + workers_per_gpu=4, + train=dict( + type=dataset_type, + images_dir='/home/flowey/dataset/ILSVRC2012/val/', + transforms=[ + dict(type = "Resize", ## TODO: implement RandomRotatedCrop + min_size = 224, + max_size = 224, + ), + dict( + type = "RotatedRandomFlip", + prob = 0.5, + direction="horizontal", + ), + dict( + type = "Normalize", ## unknown normalize + 
mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False), + ], + batch_size=2, + ), + val=dict( + type=dataset_type, + batch_size=100, + images_dir='/home/flowey/dataset/ILSVRC2012/val/', + transforms=[ + dict(type = "Resize", + min_size = 256, + max_size = None, + clip_min_size = False + ), + dict(type = "CenterCropJt", + size = 224, + ), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False), + ], + ), + test=dict( + type="ImageDataset", + images_dir='/mnt/disk/lxl/dataset/DOTA_1024/test_split/images/', + transforms=[ + dict(type = "Resize", + min_size = 224, + max_size = 224, + ), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False), + ], + ) +) +# optimizer +optimizer = dict( + type='SGD', + lr=0.1, + momentum=0.9, + weight_decay=0.0001, + ) + +# learning policy +scheduler = dict( + type='CosineAnnealingLR', + max_steps=100) + +logger = dict( + type="RunLogger") +max_epoch = 100 +eval_interval = 25 +checkpoint_interval = 10 +log_interval = 20 diff --git a/projects/san/san_backward_test.py b/projects/san/san_backward_test.py index 85083228..bed8ceff 100644 --- a/projects/san/san_backward_test.py +++ b/projects/san/san_backward_test.py @@ -25,6 +25,11 @@ def get_deep_attr(obj, name): break return m +def calc_relative(x, y): + return np.max(np.abs(x - y)/(np.abs(x)+1e-10)) + +def calc_absolute(x, y): + return np.max(np.abs(x - y)) def main(): parser = argparse.ArgumentParser(description="Jittor Object Detection Training") @@ -38,11 +43,11 @@ def main(): jt.set_global_seed(0) np.random.seed(0) random.seed(0) - init_cfg("/home/flowey/remote2/JDet/configs/san10_pairwise.py") + init_cfg("configs/san10_pairwise.py") cfg = get_cfg() model = build_from_cfg(cfg.model,MODELS) - numpy_save_dir = '/home/flowey/remote2/JDet/projects/san/test_datas_san/models_numpy.pkl' + numpy_save_dir = 
'/home/flowey/remote/JDet/projects/san/test_datas_san/models_numpy.pkl' numpy_dict = pk.load(open(numpy_save_dir, 'rb')) jittor_dict = dict() for k, v in numpy_dict.items(): @@ -84,6 +89,33 @@ def main(): } pk.dump(data_numpy, open("test_datas_san/test_data_numpy.pkl", "wb")) else: + if False: + save_path_j = '/home/flowey/remote/JDet/projects/san/test_datas_san/conv4_jittor.pkl' + save_path_t = '/home/flowey/remote/JDet/projects/san/test_datas_san/conv4_torch.pkl' + jittor_dic = pk.load(open(save_path_j, 'rb')) + torch_dic = pk.load(open(save_path_t, 'rb')) + dic = dict() + save_path = '/home/flowey/remote/JDet/projects/san/test_datas_san/conv4_jittor.pkl' + dic['conv_in'] = jittor_dic['conv_in'] + x = model.layer4[0].conv(jt.array(torch_dic['conv_in'])) + dic['conv_out'] = x.numpy() + dic['conv_weight'] = model.layer4[0].conv.weight.numpy() + dic['conv_bias'] = model.layer4[0].conv.bias.numpy() + # x = x.sum() + x = model.bn4(x)[0].sum() + dic['grad_weight'] = jt.grad(x, model.layer4[0].conv.weight).numpy() + dic['grad_bias'] = jt.grad(x, model.layer4[0].conv.bias).numpy() + dic['x'] = x.numpy() + print(calc_relative(dic['conv_in'], torch_dic['conv_in'])) + print(calc_relative(dic['conv_out'], torch_dic['conv_out'])) + print(calc_relative(dic['grad_weight'], torch_dic['grad_weight'])) + print(calc_relative(dic['grad_bias'], torch_dic['grad_bias'])) + print(calc_relative(dic['x'], torch_dic['x'])) + print(model.bn4) + print(model.layer4[0].conv) + #pk.dump(dic, open(save_path, 'wb')) + return + data = pk.load(open("test_datas_san/test_data_numpy.pkl", "rb")) imagess = jdet.utils.general.to_jt_var(data["imagess"]) targetss = jdet.utils.general.to_jt_var(data["targetss"]) @@ -92,35 +124,41 @@ def main(): for batch_idx in range(len(imagess)): images = imagess[batch_idx] targets = targetss[batch_idx] - losses = model(images,targets) - loss = losses['loss'] - grads = pk.load(open('test_datas_san/data_grad.pkl', 'rb')) - loss_value = grads.pop('loss_value') - 
print(loss.numpy() - loss_value) - max_diff, max_name = .0, 'None' - for name, value in grads.items(): - w = get_deep_attr(model, name) - g = jt.grad(loss, w).numpy() - d = np.max(np.abs((g - value)))/np.max(np.abs(value)) - print(name, d) - if d > max_diff: - max_diff = d - max_name = name - print(max_name, max_diff) - - return + # loss = losses['loss'] + # grads = pk.load(open('test_datas_san/data_grad.pkl', 'rb')) + # loss_value = grads.pop('loss_value') + # print(loss.numpy() - loss_value) + # max_diff, max_name = .0, 'None' + # for name, value in grads.items(): + # w = get_deep_attr(model, name) + # g = jt.grad(loss, w).numpy() + # # d = np.mean(np.abs((g - value)/(np.abs(value)+1e-10))) + # d = np.max(np.abs((g - value))) + # print(name, d) + # if d > max_diff: + # max_diff = d + # max_name = name + # print(max_name, max_diff) + # return all_loss, losses = parse_losses(losses) jt.sync_all(True) - l = all_loss.item() + fc_bias_grad = jt.grad(all_loss, model.fc.bias).numpy() optimizer.step(all_loss) - state_dict2_path = '/home/flowey/remote2/JDet/projects/san/test_datas_san/models_numpy2.pth' + state_dict2_path = 'test_datas_san/backward_torch.pkl' torch_dict = pk.load(open(state_dict2_path, 'rb')) + max_diff, max_name = .0, 'None' for k, v in model.state_dict().items(): v = v.numpy() - print(k, np.max(np.abs(v - torch_dict[k]))) + d = np.max(np.abs(v - torch_dict[k])) + if d > max_diff: + max_diff = d + max_name = k + print(k, d) + print(max_name, max_diff) + print(np.max(np.abs(fc_bias_grad - torch_dict['fc.bias.grad']))) return - #print(f"Loss is correct with err_rate<{thr}") + print(f"Loss is correct with err_rate<{thr}") print("success!") if __name__ == "__main__": diff --git a/projects/san/san_forward_test.py b/projects/san/san_forward_test.py index 7a20ecb1..009af3e7 100644 --- a/projects/san/san_forward_test.py +++ b/projects/san/san_forward_test.py @@ -24,11 +24,11 @@ def main(): jt.set_global_seed(0) np.random.seed(0) random.seed(0) - 
init_cfg("/home/flowey/remote2/JDet/configs/san10_pairwise.py") + init_cfg("configs/san10_pairwise.py") cfg = get_cfg() model = build_from_cfg(cfg.model,MODELS) - numpy_save_dir = '/home/flowey/remote2/JDet/projects/san/test_datas_san/models_numpy.pkl' + numpy_save_dir = 'test_datas_san/models_numpy.pkl' numpy_dict = pk.load(open(numpy_save_dir, 'rb')) jittor_dict = dict() for k, v in numpy_dict.items(): @@ -36,7 +36,7 @@ def main(): model.load_state_dict(jittor_dict) optimizer = build_from_cfg(cfg.optimizer,OPTIMS,params=model.parameters()) - model.train() + model.eval() if (args.set_data): imagess = [] @@ -83,8 +83,8 @@ def main(): all_loss, losses = parse_losses(losses) jt.sync_all(True) l = all_loss.item() - print(l) - optimizer.step(all_loss) + # print(l) + # optimizer.step(all_loss) c_l = correct_loss[batch_idx] err_rate = float(abs(c_l - l)/np.minimum(c_l, l)) print(f"correct loss is {float(c_l):.4f}, runtime loss is {float(l):.4f}, err rate is {err_rate*100:.2f}%") diff --git a/projects/san/san_tot_forward_test.py b/projects/san/san_tot_forward_test.py index c7e79ac6..f481c57e 100644 --- a/projects/san/san_tot_forward_test.py +++ b/projects/san/san_tot_forward_test.py @@ -1,9 +1,9 @@ import jittor as jt import copy -jt.set_global_seed(0) +#jt.set_global_seed(0) from jdet.config import init_cfg, get_cfg -from jdet.utils.general import parse_losses +from jdet.utils.general import parse_losses, sync from jdet.utils.registry import build_from_cfg,MODELS,DATASETS,OPTIMS import numpy as np import random @@ -13,7 +13,8 @@ import pickle as pk from jdet.runner import Runner from PIL import Image -from jittor.transform import Resize, CenterCrop, ImageNormalize, to_tensor +from tqdm import tqdm +import copy import numpy as np @@ -29,11 +30,11 @@ def main(): jt.set_global_seed(0) np.random.seed(0) random.seed(0) - init_cfg("/home/flowey/remote2/JDet/configs/san10_pairwise.py") + init_cfg("configs/san10_pairwise.py") runner = Runner() - # image_file = 
'/home/flowey/dataset/ILSVRC2012/img_test/n01440764/ILSVRC2012_val_00009379.JPEG' - # image_torch_file = '/home/flowey/remote2/JDet/projects/san/test_datas_san/single_image.pkl' + # image_file = '/home/flowey/dataset/ILSVRC2012/val/n02106166/ILSVRC2012_val_00001900.JPEG' + # image_torch_file = '/home/flowey/remote/JDet/projects/san/test_datas_san/single_image.pkl' # val_trans = runner.val_dataset.transforms.transforms # image = Image.open(image_file).convert('RGB') # images = [] @@ -58,13 +59,42 @@ def main(): # print(i,j,k) # print(np.max(np.abs((images_torch[4] - images[4])/(images[4]+1e-10)))) - numpy_save_dir = '/home/flowey/remote2/JDet/projects/san/test_datas_san/models_numpy.pth' + numpy_save_dir = 'test_datas_san/models_numpy.pth' numpy_dict = pk.load(open(numpy_save_dir, 'rb')) jittor_dict = dict() for k, v in numpy_dict.items(): jittor_dict[k] = jt.array(v) + # print(len(jittor_dict.keys())) + # print(len(runner.model.state_dict().keys())) + # for k in runner.model.state_dict().keys(): + # if k not in numpy_dict.keys(): + # print("fjkdsalfjklsda") + # return runner.model.load_state_dict(jittor_dict) - runner.val() + + runner.logger.print_log("Validating....") + # TODO: need move eval into this function + runner.model.eval() + # if runner.model.is_training(): + # runner.model.eval() + results = [] + for batch_idx,(images,targets) in tqdm(enumerate(runner.val_dataset),total=len(runner.val_dataset)): + # if batch_idx == 0: + # l, result = runner.model(images, targets, show=True) + # pk.dump(sync(l), open('/home/flowey/remote/JDet/projects/san/test_datas_san/layer_j.pkl', 'wb')) + # else: + # continue + result = runner.model(images,targets) + # iimages.append(sync(images)) + # itarget.append(sync(targets)) + # results.append(sync(result)) + # layerss.append(sync(layers)) + results.extend([(r,t) for r,t in zip(sync(result),sync(targets))]) + # save_path = 'test_datas_san/eval_results.pkl' + # pk.dump(saved, open(save_path, 'wb')) + + eval_results = 
runner.val_dataset.evaluate(results,runner.work_dir,runner.epoch,logger=runner.logger) + runner.logger.log(eval_results,iter=runner.iter) if __name__ == "__main__": From 91db63993f4e08604157d262449fb1c23c6ab44f Mon Sep 17 00:00:00 2001 From: 514flowey <1114811901@qq.com> Date: Sat, 19 Mar 2022 21:44:21 +0800 Subject: [PATCH 09/10] check backward --- python/jdet/data/ILSVRC2012.py | 28 +++++++++++++++++---------- python/jdet/data/transforms.py | 12 +++++++----- python/jdet/models/losses/san_loss.py | 13 +++++++++---- python/jdet/models/networks/san.py | 2 +- 4 files changed, 35 insertions(+), 20 deletions(-) diff --git a/python/jdet/data/ILSVRC2012.py b/python/jdet/data/ILSVRC2012.py index 678fe3e7..8b19c714 100644 --- a/python/jdet/data/ILSVRC2012.py +++ b/python/jdet/data/ILSVRC2012.py @@ -9,7 +9,7 @@ import jittor as jt import os -from jittor.dataset import Dataset +from jittor.dataset import Dataset, ImageFolder import jdet def cal_topk_accuracy(output, target, topk=(1,)): @@ -62,15 +62,23 @@ def _load_labels(self, images_dir): def _load_images(self, images_dir): images, labels = [], [] - for label in os.listdir(images_dir): - label_dir = os.path.join(images_dir, label) - if os.path.isdir(label_dir): - if label not in self.class_to_idx.keys(): - raise ValueError("unknow class {}".format(label)) - for name in os.listdir(label_dir): - if (jdet.utils.general.is_img(name)): - images.append(os.path.join(images_dir, label, name)) - labels.append(self.class_to_idx[label]) + for i, class_name in enumerate(self.classes): + class_dir = os.path.join(images_dir, class_name) + for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)): + for fname in sorted(fnames): + if (jdet.utils.general.is_img(fname)): + images.append(os.path.join(class_dir, fname)) + labels.append(self.class_to_idx[class_name]) + + # for label in os.listdir(images_dir): + # label_dir = os.path.join(images_dir, label) + # if os.path.isdir(label_dir): + # if label not in 
self.class_to_idx.keys(): + # raise ValueError("unknow class {}".format(label)) + # for name in os.listdir(label_dir): + # if (jdet.utils.general.is_img(name)): + # images.append(os.path.join(images_dir, label, name)) + # labels.append(self.class_to_idx[label]) return images, labels def collate_batch(self,batch): diff --git a/python/jdet/data/transforms.py b/python/jdet/data/transforms.py index 3e748949..c2c44347 100644 --- a/python/jdet/data/transforms.py +++ b/python/jdet/data/transforms.py @@ -80,12 +80,13 @@ def __call__( self, image, target=None ): @TRANSFORMS.register_module() class Resize: - def __init__(self, min_size, max_size, keep_ratio=True): + def __init__(self, min_size, max_size, keep_ratio=True, clip_min_size=True): if not isinstance(min_size, (list, tuple)): min_size = (min_size,) self.min_size = min_size self.max_size = max_size self.keep_ratio = keep_ratio + self.clip_min_size = clip_min_size # modified from torchvision to add support for max size def get_size(self, image_size): @@ -95,10 +96,11 @@ def get_size(self, image_size): if self.keep_ratio: # NOTE Mingtao - if w <= h: - size = np.clip( size, int(w / 1.5), int(w * 1.5) ) - else: - size = np.clip( size, int(h / 1.5), int(h * 1.5) ) + if self.clip_min_size: + if w <= h: + size = np.clip( size, int(w / 1.5), int(w * 1.5) ) + else: + size = np.clip( size, int(h / 1.5), int(h * 1.5) ) if max_size is not None: min_original_size = float(min((w, h))) diff --git a/python/jdet/models/losses/san_loss.py b/python/jdet/models/losses/san_loss.py index ddeb17d0..dfc3d5d0 100644 --- a/python/jdet/models/losses/san_loss.py +++ b/python/jdet/models/losses/san_loss.py @@ -3,6 +3,11 @@ from jittor import nn from jdet.utils.registry import LOSSES +def log_softmax(x, dim): + v_max = jt.expand(jt.max(x, dim, keepdims=True), x.shape) + log_bias = jt.expand(jt.sum(jt.exp(x - v_max), dim, keepdims=True), x.shape) + return x - v_max - jt.log(log_bias) + def mixup_data(x, y, alpha=0.2): '''Returns mixed inputs, 
pairs of targets, and lambda''' if alpha > 0: @@ -18,20 +23,20 @@ def mixup_data(x, y, alpha=0.2): def mixup_loss(output, target_a, target_b, lam=1.0, eps=0.0): w = jt.zeros_like(output).scatter_(1, target_a.unsqueeze(1), jt.array(1)) w = w * (1 - eps) + (1 - w) * eps / (output.shape[1] - 1) - log_prob = nn.log_softmax(output, dim=1) + log_prob = log_softmax(output, dim=1) loss_a = (-w * log_prob).sum(dim=1).mean() w = jt.zeros_like(output).scatter_(1, target_b.unsqueeze(1), jt.array(1)) w = w * (1 - eps) + (1 - w) * eps / (output.shape[1] - 1) - log_prob = nn.log_softmax(output, dim=1) + log_prob = log_softmax(output, dim=1) loss_b = (-w * log_prob).sum(dim=1).mean() return lam * loss_a + (1 - lam) * loss_b def smooth_loss(output, target, eps=0.1): - w = jt.zeros_like(output).scatter_(1, target.unsqueeze(1), jt.array(1)) + w = jt.zeros_like(output).scatter_(1, target.unsqueeze(1), jt.array(1, dtype=output.dtype)) w = w * (1 - eps) + (1 - w) * eps / (output.shape[1] - 1) - log_prob = nn.log_softmax(output, dim=1) + log_prob = log_softmax(output, dim=1) loss = (-w * log_prob).sum(dim=1).mean() return loss diff --git a/python/jdet/models/networks/san.py b/python/jdet/models/networks/san.py index 161d60e1..a041aa4d 100644 --- a/python/jdet/models/networks/san.py +++ b/python/jdet/models/networks/san.py @@ -161,7 +161,7 @@ def _make_layer(self, sa_type, block, planes, blocks, kernel_size=7, stride=1): return nn.Sequential(*layers) def execute(self, x, targets=None): - targets = jt.array([t['img_label'] for t in targets]) + targets = jt.array([t['img_label'] for t in targets], dtype = jt.int32) if self.is_training() and self.loss_prepare: x = self.loss.prepare(x, targets) From 8c25328dad93961b79e86b054bbb79173d2a7fe8 Mon Sep 17 00:00:00 2001 From: 514flowey Date: Tue, 17 May 2022 21:59:02 -0400 Subject: [PATCH 10/10] fix bugs --- configs/faster_rcnn_RoITrans_r50_fpn_1x_dota.py | 7 ++++--- .../configs/faster_rcnn_RoITrans_r50_fpn_1x_dota.py | 1 + 
.../configs/faster_rcnn_RoITrans_r50_fpn_1x_dota_test.py | 1 + python/jdet/models/boxes/assigner.py | 2 ++ python/jdet/models/roi_heads/anchor_target.py | 2 +- 5 files changed, 9 insertions(+), 4 deletions(-) diff --git a/configs/faster_rcnn_RoITrans_r50_fpn_1x_dota.py b/configs/faster_rcnn_RoITrans_r50_fpn_1x_dota.py index 379d08f6..c4e80e0c 100644 --- a/configs/faster_rcnn_RoITrans_r50_fpn_1x_dota.py +++ b/configs/faster_rcnn_RoITrans_r50_fpn_1x_dota.py @@ -144,7 +144,7 @@ workers_per_gpu=2, train=dict( type=dataset_type, - dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0', + dataset_dir='/home/flowey/dataset/processed_DOTA/trainval_1024_200_1.0/', version='1', filter_min_size=32, transforms=[ @@ -163,10 +163,11 @@ to_bgr=True), ], batch_size=2, + shuffle=True, ), val=dict( type=dataset_type, - dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0', + dataset_dir='/home/flowey/dataset/processed_DOTA/trainval_1024_200_1.0/', version='1', filter_min_size=32, transforms=[ @@ -182,7 +183,7 @@ ), test=dict( type="ImageDataset", - images_dir='/mnt/disk/lxl/dataset/DOTA_1024/test_split/images/', + images_dir='/home/flowey/dataset/processed_DOTA/test_1024_200_1.0/images', transforms=[ dict( type = "Pad", diff --git a/projects/roi_transformer/configs/faster_rcnn_RoITrans_r50_fpn_1x_dota.py b/projects/roi_transformer/configs/faster_rcnn_RoITrans_r50_fpn_1x_dota.py index 379d08f6..9395c21e 100644 --- a/projects/roi_transformer/configs/faster_rcnn_RoITrans_r50_fpn_1x_dota.py +++ b/projects/roi_transformer/configs/faster_rcnn_RoITrans_r50_fpn_1x_dota.py @@ -163,6 +163,7 @@ to_bgr=True), ], batch_size=2, + shuffle=True, ), val=dict( type=dataset_type, diff --git a/projects/roi_transformer/configs/faster_rcnn_RoITrans_r50_fpn_1x_dota_test.py b/projects/roi_transformer/configs/faster_rcnn_RoITrans_r50_fpn_1x_dota_test.py index b633183f..d1c67b8f 100644 --- 
a/projects/roi_transformer/configs/faster_rcnn_RoITrans_r50_fpn_1x_dota_test.py +++ b/projects/roi_transformer/configs/faster_rcnn_RoITrans_r50_fpn_1x_dota_test.py @@ -164,6 +164,7 @@ to_bgr=True), ], batch_size=2, + shuffle=True, ), val=dict( type=dataset_type, diff --git a/python/jdet/models/boxes/assigner.py b/python/jdet/models/boxes/assigner.py index 81e75de0..d6187f97 100644 --- a/python/jdet/models/boxes/assigner.py +++ b/python/jdet/models/boxes/assigner.py @@ -156,6 +156,8 @@ def assign_wrt_overlaps(self, overlaps, gt_labels=None): assigned_gt_inds[max_iou_inds] = i + 1 else: assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1 + if i % 100 == 99: + jt.sync_all() if gt_labels is not None: assigned_labels = jt.full((num_bboxes, ), self.assigned_labels_filled, dtype=assigned_gt_inds.dtype) diff --git a/python/jdet/models/roi_heads/anchor_target.py b/python/jdet/models/roi_heads/anchor_target.py index 2f19340d..b734b267 100644 --- a/python/jdet/models/roi_heads/anchor_target.py +++ b/python/jdet/models/roi_heads/anchor_target.py @@ -192,7 +192,7 @@ def anchor_inside_flags(flat_anchors, valid_flags, img_shape, (flat_anchors[:, 3] < img_h + allowed_border) else: inside_flags = valid_flags - return inside_flags + return inside_flags.bool() def unmap(data, count, inds, fill=0):