This repository has been archived by the owner on Jul 25, 2024. It is now read-only.
Commit 08570cf
Haotian Tang committed on Jun 18, 2023 (1 parent: f30f0da)
Showing 7 changed files with 158 additions and 129 deletions.
@@ -1,100 +1,127 @@
+import numpy as np
 import torch
-import torchsparse.nn.functional as F
-from torchsparse import PointTensor, SparseTensor
-from torchsparse.nn.utils import get_kernel_offsets
+import torchsparse
+import torchsparse.nn.functional as spf
+from torchsparse import SparseTensor
+from torchsparse.nn.functional.devoxelize import calc_ti_weights
+from torchsparse.nn.utils import *
+from torchsparse.utils import *
+from torchsparse.utils.tensor_cache import TensorCache
+import torch_scatter
+from typing import Union, Tuple

-__all__ = ['initial_voxelize', 'point_to_voxel', 'voxel_to_point']
+__all__ = ["initial_voxelize", "point_to_voxel", "voxel_to_point", "PointTensor"]


+class PointTensor(SparseTensor):
+    def __init__(
+        self,
+        feats: torch.Tensor,
+        coords: torch.Tensor,
+        stride: Union[int, Tuple[int, ...]] = 1,
+    ):
+        super().__init__(feats=feats, coords=coords, stride=stride)
+        self._caches.idx_query = dict()
+        self._caches.idx_query_devox = dict()
+        self._caches.weights_devox = dict()
+
+
+def sphashquery(query, target, kernel_size=1):
+    hashmap_keys = torch.zeros(
+        2 * target.shape[0], dtype=torch.int64, device=target.device
+    )
+    hashmap_vals = torch.zeros(
+        2 * target.shape[0], dtype=torch.int32, device=target.device
+    )
+    hashmap = torchsparse.backend.GPUHashTable(hashmap_keys, hashmap_vals)
+    hashmap.insert_coords(target[:, [1, 2, 3, 0]])
+    kernel_size = make_ntuple(kernel_size, 3)
+    kernel_volume = np.prod(kernel_size)
+    kernel_size = make_tensor(kernel_size, device=target.device, dtype=torch.int32)
+    stride = make_tensor((1, 1, 1), device=target.device, dtype=torch.int32)
+    results = (
+        hashmap.lookup_coords(
+            query[:, [1, 2, 3, 0]], kernel_size, stride, kernel_volume
+        )
+        - 1
+    )[: query.shape[0]]
+    return results
+
+
 # z: PointTensor
 # return: SparseTensor
 def initial_voxelize(z, init_res, after_res):
     new_float_coord = torch.cat(
-        [(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1)
-
-    pc_hash = F.sphash(torch.floor(new_float_coord).int())
-    sparse_hash = torch.unique(pc_hash)
-    idx_query = F.sphashquery(pc_hash, sparse_hash)
-    counts = F.spcount(idx_query.int(), len(sparse_hash))
-
-    inserted_coords = F.spvoxelize(torch.floor(new_float_coord), idx_query,
-                                   counts)
-    inserted_coords = torch.round(inserted_coords).int()
-    inserted_feat = F.spvoxelize(z.F, idx_query, counts)
-
-    new_tensor = SparseTensor(inserted_feat, inserted_coords, 1)
-    new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords)
-    z.additional_features['idx_query'][1] = idx_query
-    z.additional_features['counts'][1] = counts
-    z.C = new_float_coord
+        [z.C[:, 0].view(-1, 1), (z.C[:, 1:] * init_res) / after_res], 1
+    )
+    # optimization TBD: init_res = after_res
+    new_int_coord = torch.floor(new_float_coord).int()
+    sparse_coord = torch.unique(new_int_coord, dim=0)
+    idx_query = sphashquery(new_int_coord, sparse_coord).reshape(-1)
+
+    sparse_feat = torch_scatter.scatter_mean(z.F, idx_query.long(), dim=0)
+    new_tensor = SparseTensor(sparse_feat, sparse_coord, 1)
+    z._caches.idx_query[z.s] = idx_query
+    z.C = new_float_coord
     return new_tensor


 # x: SparseTensor, z: PointTensor
 # return: SparseTensor
 def point_to_voxel(x, z):
-    if z.additional_features is None or z.additional_features.get(
-            'idx_query') is None or z.additional_features['idx_query'].get(
-                x.s) is None:
-        pc_hash = F.sphash(
-            torch.cat([
-                torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
-                z.C[:, -1].int().view(-1, 1)
-            ], 1))
-        sparse_hash = F.sphash(x.C)
-        idx_query = F.sphashquery(pc_hash, sparse_hash)
-        counts = F.spcount(idx_query.int(), x.C.shape[0])
-        z.additional_features['idx_query'][x.s] = idx_query
-        z.additional_features['counts'][x.s] = counts
+    if z._caches.idx_query.get(x.s) is None:
+        # Note: x.C has a smaller range after downsampling.
+        new_int_coord = torch.cat(
+            [
+                z.C[:, 0].int().view(-1, 1),
+                torch.floor(z.C[:, 1:] / x.s[0]).int(),
+            ],
+            1,
+        )
+        idx_query = sphashquery(new_int_coord, x.C)
+        z._caches.idx_query[x.s] = idx_query
     else:
-        idx_query = z.additional_features['idx_query'][x.s]
-        counts = z.additional_features['counts'][x.s]
-
-    inserted_feat = F.spvoxelize(z.F, idx_query, counts)
-    new_tensor = SparseTensor(inserted_feat, x.C, x.s)
-    new_tensor.cmaps = x.cmaps
-    new_tensor.kmaps = x.kmaps
+        idx_query = z._caches.idx_query[x.s]
+    # Haotian: This impl. is not elegant
+    idx_query = idx_query.clamp_(0)
+    sparse_feat = torch_scatter.scatter_mean(z.F, idx_query.long(), dim=0)
+    new_tensor = SparseTensor(sparse_feat, x.C, x.s)
+    new_tensor._caches = x._caches

     return new_tensor


 # x: SparseTensor, z: PointTensor
 # return: PointTensor
 def voxel_to_point(x, z, nearest=False):
-    if z.idx_query is None or z.weights is None or z.idx_query.get(
-            x.s) is None or z.weights.get(x.s) is None:
-        off = get_kernel_offsets(2, x.s, 1, device=z.F.device)
-        old_hash = F.sphash(
-            torch.cat([
-                torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
-                z.C[:, -1].int().view(-1, 1)
-            ], 1), off)
-        pc_hash = F.sphash(x.C.to(z.F.device))
-        idx_query = F.sphashquery(old_hash, pc_hash)
-        weights = F.calc_ti_weights(z.C, idx_query,
-                                    scale=x.s[0]).transpose(0, 1).contiguous()
-        idx_query = idx_query.transpose(0, 1).contiguous()
+    if (
+        z._caches.idx_query_devox.get(x.s) is None
+        or z._caches.weights_devox.get(x.s) is None
+    ):
+        point_coords_float = torch.cat(
+            [z.C[:, 0].int().view(-1, 1), z.C[:, 1:] / x.s[0]],
+            1,
+        )
+        point_coords_int = torch.floor(point_coords_float).int()
+        idx_query = sphashquery(point_coords_int, x.C, kernel_size=2)
+        weights = calc_ti_weights(point_coords_float[:, 1:], idx_query, scale=1)
+
         if nearest:
-            weights[:, 1:] = 0.
+            weights[:, 1:] = 0.0
             idx_query[:, 1:] = -1
-        new_feat = F.spdevoxelize(x.F, idx_query, weights)
-        new_tensor = PointTensor(new_feat,
-                                 z.C,
-                                 idx_query=z.idx_query,
-                                 weights=z.weights)
-        new_tensor.additional_features = z.additional_features
-        new_tensor.idx_query[x.s] = idx_query
-        new_tensor.weights[x.s] = weights
-        z.idx_query[x.s] = idx_query
-        z.weights[x.s] = weights
+        new_feat = spf.spdevoxelize(x.F, idx_query, weights)
+        new_tensor = PointTensor(new_feat, z.C)
+        new_tensor._caches = z._caches
+        new_tensor._caches.idx_query_devox[x.s] = idx_query
+        new_tensor._caches.weights_devox[x.s] = weights
+        z._caches.idx_query_devox[x.s] = idx_query
+        z._caches.weights_devox[x.s] = weights
+
     else:
-        new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s), z.weights.get(x.s))
-        new_tensor = PointTensor(new_feat,
-                                 z.C,
-                                 idx_query=z.idx_query,
-                                 weights=z.weights)
-        new_tensor.additional_features = z.additional_features
+        new_feat = spf.spdevoxelize(
+            x.F, z._caches.idx_query_devox.get(x.s), z._caches.weights_devox.get(x.s)
+        )
+        new_tensor = PointTensor(new_feat, z.C)
+        new_tensor._caches = z._caches

     return new_tensor
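
Context for the change above (not part of the commit itself): these helpers implement the point-voxel round trip used by SPVCNN-style models, where point features are averaged into voxels, processed by sparse layers, and interpolated back to the points. The sketch below shows how the three functions might be wired together after this commit; the tensor shapes, the CUDA device, and the (batch_idx, x, y, z) coordinate layout are illustrative assumptions, and the sparse convolution blocks that would normally sit between the calls are omitted.

import torch

# 10,000 points with 4 feature channels; coordinates are (batch_idx, x, y, z)
# in voxel units, batch index first, matching the updated coordinate layout.
feats = torch.randn(10000, 4).cuda()
coords = torch.cat([torch.zeros(10000, 1), torch.rand(10000, 3) * 100], dim=1).cuda()

z = PointTensor(feats, coords)       # per-point features; caches start empty

x0 = initial_voxelize(z, 1.0, 1.0)   # average point features into voxels (SparseTensor)
# ... x0 would normally pass through sparse conv blocks here ...
z0 = voxel_to_point(x0, z)           # trilinear-interpolate voxel features back to points
x1 = point_to_voxel(x0, z0)          # re-voxelize the refined point features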
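
A note on the main substitution (again, not part of the commit): the old path counted points per voxel and voxelized features with torchsparse's spcount/spvoxelize kernels keyed by coordinate hashes, while the new path computes a point-to-voxel index once with the GPU hash table (sphashquery) and averages features with torch_scatter.scatter_mean. The toy example below, with made-up numbers, isolates that averaging step; it assumes only that the torch_scatter package is installed.

import torch
import torch_scatter

# three points, one feature channel
point_feats = torch.tensor([[1.0], [3.0], [10.0]])
# points 0 and 1 fall into voxel 0, point 2 falls into voxel 1
idx_query = torch.tensor([0, 0, 1])

# same call pattern as in the updated initial_voxelize / point_to_voxel
voxel_feats = torch_scatter.scatter_mean(point_feats, idx_query, dim=0)
print(voxel_feats)  # tensor([[ 2.], [10.]])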