From 780ffed9f3736fedadf18b51266ecbf521e64cf6 Mon Sep 17 00:00:00 2001
From: RRaoyzee <162255573+RRaoyzee@users.noreply.github.com>
Date: Wed, 13 Mar 2024 15:15:46 +0800
Subject: [PATCH] Add multi_scale_deform_attn_grad op adapter for NPU (#3042)

---
 .../csrc/pytorch/npu/ms_deform_attn_npu.cpp | 59 ++++++++++++++++-
 tests/test_ops/test_ms_deformable_attn.py   | 64 +++++++++++++++++++
 2 files changed, 122 insertions(+), 1 deletion(-)

diff --git a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
index fa83f17547..da6f291048 100644
--- a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
@@ -55,7 +55,7 @@ Tensor ms_deform_attn_forward_npu(const Tensor &value,
   c10::SmallVector output_size = {
       value.size(0), sampling_locations.size(1),
       value.size(2) * value.size(3)};
-  at::Tensor output = at::empty(output_size, value_fp32.options());
+  at::Tensor output = at::zeros(output_size, value_fp32.options());
 
   OpCommand cmd;
   cmd.Name("MultiScaleDeformableAttnFunction")
@@ -75,3 +75,60 @@ Tensor ms_deform_attn_forward_npu(const Tensor &value,
 }
 
 REGISTER_NPU_IMPL(ms_deform_attn_impl_forward, ms_deform_attn_forward_npu);
+
+void ms_deform_attn_impl_backward(
+    const Tensor &value, const Tensor &spatial_shapes,
+    const Tensor &level_start_index, const Tensor &sampling_loc,
+    const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
+    Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
+    const int im2col_step);
+
+void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shapes,
+                                 const Tensor &level_start_index,
+                                 const Tensor &sampling_loc,
+                                 const Tensor &attn_weight,
+                                 const Tensor &grad_output, Tensor &grad_value,
+                                 Tensor &grad_sampling_loc,
+                                 Tensor &grad_attn_weight, const int im2col_step) {
+  check_support(value, attn_weight);
+  at::Tensor value_fp32 = value;
+  at::Tensor spatial_shapes_int32 = spatial_shapes;
+  at::Tensor level_start_index_int32 = level_start_index;
+  at::Tensor sampling_loc_fp32 = sampling_loc.transpose(4, 5).contiguous();
+  at::Tensor attn_weight_fp32 = attn_weight;
+  at::Tensor grad_output_fp32 = grad_output;
+  if (value.scalar_type() != at::kFloat) {
+    value_fp32 = value.to(at::kFloat);
+  }
+  if (spatial_shapes.scalar_type() != at::kInt) {
+    spatial_shapes_int32 = spatial_shapes.to(at::kInt);
+  }
+  if (level_start_index.scalar_type() != at::kInt) {
+    level_start_index_int32 = level_start_index.to(at::kInt);
+  }
+  if (sampling_loc.scalar_type() != at::kFloat) {
+    sampling_loc_fp32 = sampling_loc_fp32.to(at::kFloat);
+  }
+  if (attn_weight.scalar_type() != at::kFloat) {
+    attn_weight_fp32 = attn_weight.to(at::kFloat);
+  }
+  if (grad_output.scalar_type() != at::kFloat) {
+    grad_output_fp32 = grad_output.to(at::kFloat);
+  }
+
+  OpCommand cmd;
+  cmd.Name("MultiScaleDeformableAttentionGrad")
+      .Input(value_fp32)
+      .Input(spatial_shapes_int32)
+      .Input(level_start_index_int32)
+      .Input(sampling_loc_fp32)
+      .Input(attn_weight_fp32)
+      .Input(grad_output_fp32)
+      .Output(grad_value)
+      .Output(grad_sampling_loc)
+      .Output(grad_attn_weight)
+      .Run();
+  grad_sampling_loc = grad_sampling_loc.transpose(4, 5).contiguous();
+}
+
+REGISTER_NPU_IMPL(ms_deform_attn_impl_backward, ms_deform_attn_backward_npu);
diff --git a/tests/test_ops/test_ms_deformable_attn.py b/tests/test_ops/test_ms_deformable_attn.py
index eda4bb80bd..06859dfe43 100644
--- a/tests/test_ops/test_ms_deformable_attn.py
+++ b/tests/test_ops/test_ms_deformable_attn.py
@@ -337,3 +337,67 @@ def test_gradient_numerical(channels,
                im2col_step),
         eps=eps,
         atol=1e-2)
+
+
+@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
+def test_backward_equal_with_pytorch_npu():
+    N, M, D = 6, 4, 8
+    Lq, L, P = 10000, 4, 8
+    shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)],
+                             dtype=torch.int32)
+    level_start_index = torch.cat((shapes.new_zeros(
+        (1, )), shapes.prod(1).cumsum(0)[:-1]))
+    S = sum((H * W).item() for H, W in shapes)
+
+    torch.manual_seed(3)
+    value = torch.rand(N, S, M, D) * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
+    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
+    attention_weights /= attention_weights.sum(
+        -1, keepdim=True).sum(
+            -2, keepdim=True)
+    im2col_step = 2
+    value.requires_grad = True
+    sampling_locations.requires_grad = True
+    attention_weights.requires_grad = True
+    output_pytorch = multi_scale_deformable_attn_pytorch(
+        value.float(), shapes, sampling_locations.float(),
+        attention_weights.float())
+    grad_output_pytorch = torch.ones_like(output_pytorch)
+    output_pytorch.backward(grad_output_pytorch)
+    grad_value = value.grad.detach().cpu()
+    grad_location = sampling_locations.grad.detach().cpu()
+    grad_attn_weight = attention_weights.grad.detach().cpu()
+
+    value_npu = value.npu()
+    shapes_npu = shapes.npu()
+    level_start_index_npu = level_start_index.npu()
+    sampling_locations_npu = sampling_locations.npu()
+    attention_weights_npu = attention_weights.npu()
+    output_npu = MultiScaleDeformableAttnFunction.apply(
+        value_npu.float(), shapes_npu, level_start_index_npu,
+        sampling_locations_npu.float(), attention_weights_npu.float(),
+        im2col_step)
+    grad_output_npu = torch.ones_like(output_npu)
+    output_npu.backward(grad_output_npu)
+    grad_value_npu = value_npu.grad.detach().cpu()
+    grad_location_npu = sampling_locations_npu.grad.detach().cpu()
+    grad_attn_weight_npu = attention_weights_npu.grad.detach().cpu()
+    assert torch.allclose(grad_value_npu, grad_value)
+    max_abs_err_1 = (grad_value_npu - grad_value).abs().max()
+    max_rel_err_1 = ((grad_value_npu - grad_value).abs() /
+                     grad_value.abs()).max()
+    assert max_abs_err_1 < 1e-5
+    assert max_rel_err_1 < 1e-4
+    assert torch.allclose(grad_location_npu, grad_location)
+    max_abs_err_2 = (grad_location_npu - grad_location).abs().max()
+    max_rel_err_2 = ((grad_location_npu - grad_location).abs() /
+                     grad_location.abs()).max()
+    assert max_abs_err_2 < 1e-5
+    assert max_rel_err_2 < 1e-4
+    assert torch.allclose(grad_attn_weight_npu, grad_attn_weight)
+    max_abs_err_3 = (grad_attn_weight_npu - grad_attn_weight).abs().max()
+    max_rel_err_3 = ((grad_attn_weight_npu - grad_attn_weight).abs() /
+                     grad_attn_weight.abs()).max()
+    assert max_abs_err_3 < 1e-5
+    assert max_rel_err_3 < 1e-4