solov2 and cascade mask rcnn alignment code #8995

Open · wants to merge 1 commit into base: release/2.7
12 changes: 12 additions & 0 deletions eval_solov2.sh
@@ -0,0 +1,12 @@
export FLAGS_npu_storage_format=0
export FLAGS_use_stride_kernel=0
export FLAGS_npu_jit_compile=0
export CUSTOM_DEVICE_BLACK_LIST=set_value,set_value_with_tensor
export ASCEND_GLOBAL_LOG_LEVEL=3

# Select which NPU device to use
export ASCEND_RT_VISIBLE_DEVICES=14

# Launch evaluation
python tools/eval.py -c configs/solov2/solov2_r50_fpn_3x_coco.yml -o weights=output_cascade_rcnn/model_final.pdopt

142 changes: 141 additions & 1 deletion ppdet/engine/trainer.py
@@ -55,11 +55,126 @@

from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')
# import paddle.profiler as profiler

__all__ = ['Trainer']

MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack']

##############################
class _AllReduce(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, input):
        input_list = [paddle.zeros_like(input) for _ in range(dist.get_world_size())]
        # Use all_gather instead of all_reduce to avoid relying on in-place
        # semantics inside the autograd graph.
        dist.all_gather(input_list, input, sync_op=True)
        inputs = paddle.stack(input_list, axis=0)
        return paddle.sum(inputs, axis=0)

    @staticmethod
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output, sync_op=True)
        return grad_output


def differentiable_all_reduce(input):
    """
    Differentiable counterpart of `dist.all_reduce`.
    """
    if (
        not dist.is_available()
        or not dist.is_initialized()
        or dist.get_world_size() == 1
    ):
        return input
    return _AllReduce.apply(input)


class NaiveSyncBatchNorm(nn.BatchNorm2D):
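    """
    BatchNorm2D that synchronizes batch statistics across devices with a
    differentiable all_reduce, in the spirit of detectron2's NaiveSyncBatchNorm.
    With stats_mode="" every rank is assumed to see a non-empty batch of roughly
    equal size; with stats_mode="N" the statistics are weighted by each rank's
    batch size and empty batches are tolerated.
    """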
    def __init__(self, *args, stats_mode="", **kwargs):
        super().__init__(*args, **kwargs)
        assert stats_mode in ["", "N"]
        self._stats_mode = stats_mode

    def forward(self, input):
        if dist.get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        mean = paddle.mean(input, axis=[0, 2, 3])
        meansqr = paddle.mean(input * input, axis=[0, 2, 3])

        if self._stats_mode == "":
            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            vec = paddle.concat([mean, meansqr], axis=0)
            vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
            mean, meansqr = paddle.split(vec, [C, C])
            # NOTE: Paddle defines BatchNorm momentum the other way round:
            # running = momentum * running + (1 - momentum) * batch
            momentum = 1 - self._momentum
        else:
            if B == 0:
                vec = paddle.zeros([2 * C + 1], dtype=mean.dtype)
                vec = vec + input.sum()  # make sure there is gradient w.r.t input
            else:
                vec = paddle.concat(
                    [
                        mean,
                        meansqr,
                        paddle.ones([1], dtype=mean.dtype),
                    ],
                    axis=0,
                )
            vec = differentiable_all_reduce(vec * B)

            total_batch = vec[-1].detach()
            momentum = total_batch.clip(max=1) * (1 - self._momentum)  # no update if total_batch is 0
            mean, meansqr, _ = paddle.split(
                vec / total_batch.clip(min=1),
                [C, C, int(vec.shape[0] - 2 * C)])  # avoid div-by-zero

        var = meansqr - mean * mean
        invstd = paddle.rsqrt(var + self._epsilon)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape([1, -1, 1, 1])
        bias = bias.reshape([1, -1, 1, 1])

        tmp_mean = self._mean + momentum * (mean.detach() - self._mean)
        self._mean.set_value(tmp_mean)
        tmp_variance = self._variance + (momentum * (var.detach() - self._variance))
        self._variance.set_value(tmp_variance)
        ret = input * scale + bias
        return ret

    @classmethod
    def convert_sync_batchnorm(cls, layer):
        layer_output = layer
        if isinstance(layer, nn.BatchNorm2D):
            layer_output = NaiveSyncBatchNorm(
                layer._num_features,
                layer._momentum,
                layer._epsilon,
                layer._weight_attr,
                layer._bias_attr,
                layer._data_format,
                layer._name)

            if (
                layer._weight_attr is not False
                and layer._bias_attr is not False
            ):
                with paddle.no_grad():
                    layer_output.weight = layer.weight
                    layer_output.bias = layer.bias
            layer_output._mean = layer._mean
            layer_output._variance = layer._variance

        for name, sublayer in layer.named_children():
            layer_output.add_sublayer(
                name, cls.convert_sync_batchnorm(sublayer)
            )
        del layer
        return layer_output
##############################
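
# Usage sketch (illustrative only, not part of this PR): NaiveSyncBatchNorm can be
# swapped in where paddle.nn.SyncBatchNorm.convert_sync_batchnorm would normally be
# used; `resnet50` below is only an assumed example of a Layer with BatchNorm2D sublayers.
#
#     model = paddle.vision.models.resnet50()
#     model = NaiveSyncBatchNorm.convert_sync_batchnorm(model)
#     # affine parameters and running statistics of the original BN layers are carried over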

class Trainer(object):
    def __init__(self, cfg, mode='train'):
@@ -471,6 +586,11 @@ def train(self, validate=False):
        if sync_bn:
            model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        # paddle.save(model.state_dict(), "pretrained/init.pdparams")
        # model = NaiveSyncBatchNorm.convert_sync_batchnorm(model)
        # model.set_state_dict(paddle.load("pretrained/init.pdparams"))
        print(model)

        # enable auto mixed precision mode
        if self.use_amp:
            scaler = paddle.amp.GradScaler(
@@ -517,9 +637,19 @@ def train(self, validate=False):
            model.train()
            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                # paddle.save(data, "pretrained/data.pdparams")
                # data = paddle.load("pretrained/data.pdparams")
                # if step_id == 5:
                #     pf = paddle.profiler.Profiler(targets=[paddle.profiler.ProfilerTarget.CUSTOM_DEVICE], custom_device_types=['npu'],
                #                                   record_shapes=True)
                #     pf.start()
                # if step_id == 6:
                #     pf.stop()
                #     time.sleep(15)
                #     exit()
                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                profiler.add_profiler_step(profiler_options)
                # profiler.add_profiler_step(profiler_options)
                self._compose_callback.on_step_begin(self.status)
                data['epoch_id'] = epoch_id
                if self.cfg.get('to_static',
@@ -578,6 +708,16 @@ def train(self, validate=False):
                    loss = outputs['loss']
                    # model backward
                    loss.backward()
                    # for n,p in model.named_parameters():
                    #     if p.grad is None:
                    #         mean, min_v, max_v = None, None, None
                    #     else:
                    #         grad = p.grad.numpy()
                    #         mean, min_v, max_v = np.abs(grad).mean().item(), np.abs(grad).min().item(), np.abs(grad).max().item()
                    #     print(n, mean, min_v, max_v)
                    # for k,v in outputs.items():
                    #     print(k, v.item())
                    # exit()
                    self.optimizer.step()
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
3 changes: 1 addition & 2 deletions ppdet/modeling/heads/roi_extractor.py
@@ -68,7 +68,7 @@ def __init__(self,
        self.canonical_size = canonical_size
        self.start_level = start_level
        self.end_level = end_level
        self.aligned = aligned
        self.aligned = False
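        # NOTE: aligned RoIAlign is disabled here regardless of the configured `aligned` value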

    @classmethod
    def from_config(cls, cfg, input_shape):
@@ -114,5 +114,4 @@ def forward(self, feats, roi, rois_num):
                rois_feat_list.append(roi_feat)
            rois_feat_shuffle = paddle.concat(rois_feat_list)
            rois_feat = paddle.gather(rois_feat_shuffle, restore_index)

        return rois_feat
2 changes: 2 additions & 0 deletions ppdet/modeling/heads/solov2_head.py
@@ -533,7 +533,9 @@ def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size,
        seg_preds, cate_scores, cate_labels = self.mask_nms(
            seg_preds, seg_masks, cate_labels, cate_scores, sum_masks=sum_masks)
        ori_shape = im_shape[:2] / scale_factor + 0.5

        ori_shape = paddle.cast(ori_shape, 'int32')
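        # the integer ori_shape is later used as the (h, w) target size when masks are resized back to the original image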

        seg_preds = F.interpolate(
            paddle.unsqueeze(seg_preds, 0),
            size=upsampled_size_out,
Expand Down
38 changes: 38 additions & 0 deletions train.sh
@@ -0,0 +1,38 @@
export FLAGS_npu_storage_format=0
export FLAGS_use_stride_kernel=0
export FLAGS_npu_jit_compile=0
export CUSTOM_DEVICE_BLACK_LIST=set_value,set_value_with_tensor

# Select which NPU devices to use
# export ASCEND_RT_VISIBLE_DEVICES=11,12,13,14


# python -m paddle.distributed.launch --devices 0,1,2,3,4,5,6,7 --master=127.0.0.1:12345 tools/train.py -c configs/solov2/solov2_r50_fpn_3x_coco.yml -r output/23.pdparams

# python -m paddle.distributed.launch --devices 8,9,10,11,12,13,14 --master=127.0.0.1:12347 --enable_ce tools/train.py -c configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.yml

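# Kill any leftover cascade_rcnn training processes before launching a new run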
ps -x | grep configs/cascade_rcnn | awk '{print $1}' | xargs kill -9
python -m paddle.distributed.launch --devices 8,9,10,11,12,13,14,15 tools/train.py -c configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.yml


# python tools/train.py -c configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.yml



# # Collect the list of operators used by the model
# # 1) Before launching training, export the following environment variable
# export GLOG_v=6 # operator-output (verbose logging) setting for the new dygraph executor

# # 2) Run training and grep the log file; training can be stopped once 1-2 iterations have finished
# # Launch single-card training
# python tools/train.py -c configs/solov2/solov2_r50_fpn_1x_coco.yml > solov2.log 2>&1
# grep the loss lines in the log to confirm that 1-2 iterations have finished, then stop training
# grep Loss solov2.log
# # The expected output looks like:
# # [2022/07/05 10:38:20] ppcls INFO: [Train][Epoch 1/120][Iter: 0/5005]lr(PiecewiseDecay): 0.10000000, top1: 0.00000, top5: 0.00391, CELoss: 7.02506, loss: 7.02506, batch_cost: 6.29051s, reader_cost: 3.57217, ips: 40.69624 samples/s, eta: 43 days, 17:27:58
# # [2022/07/05 10:38:28] ppcls INFO: [Train][Epoch 1/120][Iter: 10/5005]lr(PiecewiseDecay): 0.10000000, top1: 0.00142, top5: 0.00604, CELoss: 8.36165, loss: 8.36165, batch_cost: 0.79642s, reader_cost: 0.02113, ips: 321.43838 samples/s, eta: 5 days, 12:52:01

# # 3) Run the following commands to extract the operator list from the log
# cat solov2.log | grep -a "API kernel key" | awk '{print $5}' > solov2_temp.log
# cat solov2.log | grep -a "API kernel key" | cut -d " " -f6 > solov2_temp.log
# sort -u solov2_temp.log > rsolov2_oplist.txt