This repository has been archived by the owner on Jan 12, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 438
/
Copy path bev_pool.py
97 lines (75 loc) · 2.58 KB
/
bev_pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from . import bev_pool_ext
# Public API: only bev_pool is exported by `from <this module> import *`.
__all__ = ["bev_pool"]
class QuickCumsum(torch.autograd.Function):
    """Pure-PyTorch segment sum over runs of equal, adjacent ``ranks``.

    Rows of ``x`` whose ``ranks`` are equal AND adjacent are merged into a
    single output row. The sum is obtained as the difference of running
    totals taken at the end of each run. Only adjacent equal ranks are
    merged, so ``ranks`` should be sorted beforehand.

    ``forward(ctx, x, geom_feats, ranks)`` returns ``(pooled, pooled_geom)``:

    - ``pooled`` has one row per run.
    - ``pooled_geom`` is the ``geom_feats`` row of the last element of each
      run, and is marked non-differentiable.

    NOTE(review): the per-run sums are differences of running totals, so on
    long float inputs they may lose precision compared with a direct sum.
    """

    @staticmethod
    def forward(ctx, x, geom_feats, ranks):
        # Running totals along dim 0. They are taken before masking, so each
        # run's sum is the gap between consecutive run-end totals.
        totals = torch.cumsum(x, dim=0)

        # is_last[i] is True when row i closes a run: either the next rank
        # differs, or i is the final row.
        is_last = torch.ones(totals.shape[0], device=totals.device, dtype=torch.bool)
        is_last[:-1] = ranks[1:] != ranks[:-1]

        run_totals = totals[is_last]
        pooled_geom = geom_feats[is_last]
        pooled = torch.cat((run_totals[:1], run_totals[1:] - run_totals[:-1]))

        # The mask is all that backward needs to route gradients to rows.
        ctx.save_for_backward(is_last)
        ctx.mark_non_differentiable(pooled_geom)
        return pooled, pooled_geom

    @staticmethod
    def backward(ctx, gradx, gradgeom):
        (is_last,) = ctx.saved_tensors
        # The run a row belongs to equals the number of run ends strictly
        # before it: the inclusive count minus the row's own flag.
        run_of_row = torch.cumsum(is_last, 0) - is_last.long()
        # Each row receives its run's gradient; geom_feats and ranks get none.
        return gradx[run_of_row], None, None
class QuickCumsumCuda(torch.autograd.Function):
    """Autograd wrapper around the ``bev_pool_ext`` forward/backward kernels.

    ``forward`` splits the rows of ``x`` into runs of equal, adjacent
    ``ranks``. Each run is described by an int32 start offset and length.
    ``forward`` also casts ``geom_feats`` to int32, then returns whatever
    ``bev_pool_ext.bev_pool_forward`` produces for the sizes ``B, D, H, W``.
    Only adjacent equal ranks form a run, so ``ranks`` should be sorted.

    NOTE(review): the kernel presumably sums each run into its grid cell.
    Confirm this against the extension source.

    Only ``x`` receives a gradient; every other input gets ``None``.
    """

    @staticmethod
    def forward(ctx, x, geom_feats, ranks, B, D, H, W):
        num_rows = x.shape[0]

        # is_first[i] is True when row i opens a new run; row 0 always does.
        is_first = torch.ones(num_rows, device=x.device, dtype=torch.bool)
        is_first[1:] = ranks[1:] != ranks[:-1]

        starts = is_first.nonzero(as_tuple=True)[0].int()
        lengths = torch.zeros_like(starts)
        # Each run ends where the next one begins; the last run ends at
        # num_rows.
        # NOTE(review): with zero rows, `starts` is empty and the [-1] write
        # below raises IndexError before the kernel is reached.
        lengths[:-1] = starts[1:] - starts[:-1]
        lengths[-1] = num_rows - starts[-1]

        coords = geom_feats.int()  # the kernel is handed int32 coordinates
        out = bev_pool_ext.bev_pool_forward(
            x, coords, lengths, starts, B, D, H, W
        )

        ctx.save_for_backward(starts, lengths, coords)
        ctx.saved_shapes = B, D, H, W
        return out

    @staticmethod
    def backward(ctx, out_grad):
        starts, lengths, coords = ctx.saved_tensors
        B, D, H, W = ctx.saved_shapes
        x_grad = bev_pool_ext.bev_pool_backward(
            out_grad.contiguous(), coords, lengths, starts, B, D, H, W
        )
        # Gradient for x only; geom_feats, ranks, B, D, H and W get None.
        return (x_grad,) + (None,) * 6
def bev_pool(feats, coords, B, D, H, W):
    """Group rows of ``feats`` by coordinate and pool them with QuickCumsumCuda.

    Args:
        feats: Tensor whose first dimension is the number of rows. It must
            equal ``coords.shape[0]``.
        coords: Tensor with (at least) 4 columns that are combined into a
            linear key. Column 0 is the most significant and column 3 the
            least. The key is unique per row combination as long as
            column 1 < W, column 2 < D and column 3 < B. TODO confirm the
            meaning of each column (column 3 presumably indexes the batch).
        B, D, H, W: Grid sizes, forwarded to the kernel. H is not used in
            the key.

    Returns:
        The kernel output, with its dimensions reordered by
        ``permute(0, 4, 1, 2, 3)`` and made contiguous. Presumably this is
        (B, C, D, H, W); confirm against the extension.

    Raises:
        AssertionError: if ``feats`` and ``coords`` have different row counts.
    """
    assert feats.shape[0] == coords.shape[0]

    # Linear key per row: column 0 is the most significant, column 3 the least.
    stride0, stride1, stride2 = W * D * B, D * B, B
    ranks = (
        coords[:, 0] * stride0
        + coords[:, 1] * stride1
        + coords[:, 2] * stride2
        + coords[:, 3]
    )

    # Sort so that rows sharing a key are contiguous. QuickCumsumCuda only
    # groups adjacent equal ranks.
    order = ranks.argsort()
    sorted_feats = feats[order]
    sorted_coords = coords[order]
    sorted_ranks = ranks[order]

    grid = QuickCumsumCuda.apply(
        sorted_feats, sorted_coords, sorted_ranks, B, D, H, W
    )
    # Move the last axis to position 1.
    return grid.permute(0, 4, 1, 2, 3).contiguous()