This repository has been archived by the owner on Jul 25, 2024. It is now read-only.

Commit 08570cf
[Major] Update APIs.
Haotian Tang committed Jun 18, 2023
1 parent f30f0da commit 08570cf
Showing 7 changed files with 158 additions and 129 deletions.
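In short, this commit moves the models off the torchsparse 1.x API and onto torchsparse v2.1: PointTensor is no longer imported from torchsparse but is reimplemented in core/models/utils.py on top of SparseTensor, and voxelization now goes through torch_scatter. A minimal sketch of the resulting import pattern, added here for orientation only (it assumes the repository root is on PYTHONPATH; the model files themselves use the wildcard form "from core.models.utils import *"):

# torchsparse 1.x style, removed by this commit:
#     from torchsparse import PointTensor
# torchsparse v2.1 style, after this commit (PointTensor now lives in core.models.utils):
from torchsparse import SparseTensor
from core.models.utils import PointTensor, initial_voxelize, point_to_voxel, voxel_to_point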
core/models/semantic_kitti/spvcnn.py (6 changes: 3 additions & 3 deletions)
@@ -1,9 +1,9 @@
import torchsparse
import torchsparse.nn as spnn
from torch import nn
-from torchsparse import PointTensor
+from torchsparse import SparseTensor

-from core.models.utils import initial_voxelize, point_to_voxel, voxel_to_point
+from core.models.utils import *

__all__ = ['SPVCNN']

@@ -187,7 +187,7 @@ def forward(self, x):
        z = PointTensor(x.F, x.C.float())

        x0 = initial_voxelize(z, self.pres, self.vres)

        x0 = self.stem(x0)
        z0 = voxel_to_point(x0, z, nearest=False)
        z0.F = z0.F
core/models/semantic_kitti/spvnas.py (9 changes: 5 additions & 4 deletions)
@@ -6,9 +6,9 @@
import torch.nn as nn
import torchsparse
import torchsparse.nn as spnn
-from torchsparse import PointTensor, SparseTensor
+from torchsparse import SparseTensor

-from core.models.utils import point_to_voxel, voxel_to_point
+from core.models.utils import *
from core.modules.layers import (DynamicConvolutionBlock,
                                 DynamicDeconvolutionBlock, DynamicLinear,
                                 DynamicLinearBlock, DynamicResidualBlock,
@@ -281,14 +281,15 @@ def manual_select(self, sample):
    def determinize(self, local_rank=0):
        # Get the determinized SPVNAS network by running dummy inference.
        self.eval()
+        device = next(self.parameters()).device

        sample_feat = torch.randn(1000, 4)
        sample_coord = torch.randn(1000, 4).random_(997)
-        sample_coord[:, -1] = 0
+        sample_coord[:, 0] = 0

        if torch.cuda.is_available():
            x = SparseTensor(sample_feat,
-                             sample_coord.int()).to('cuda:%d' % local_rank)
+                             sample_coord.int()).to(str(device))#'cuda:%d' % local_rank)
        else:
            x = SparseTensor(sample_feat, sample_coord.int())
        with torch.no_grad():
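The determinize() change above derives the dummy input's device from the network's own parameters instead of hard-coding 'cuda:%d' % local_rank, so the dry run follows the model whether it sits on the CPU or on a specific GPU. A minimal sketch of the same pattern, with a plain nn.Linear standing in for the SPVNAS supernet (illustrative only, not code from the repository):

import torch
import torch.nn as nn

net = nn.Linear(4, 2)                    # stand-in for the SPVNAS supernet
device = next(net.parameters()).device   # follow the module: CPU or whichever GPU it was moved to
dummy = torch.randn(1000, 4).to(device)  # dummy input built on the same device
with torch.no_grad():
    net(dummy)                           # dry run, analogous to determinize()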
core/models/utils.py (169 changes: 98 additions & 71 deletions)
@@ -1,100 +1,127 @@
+import numpy as np
import torch
-import torchsparse.nn.functional as F
-from torchsparse import PointTensor, SparseTensor
-from torchsparse.nn.utils import get_kernel_offsets
+import torchsparse
+import torchsparse.nn.functional as spf
+from torchsparse import SparseTensor
+from torchsparse.nn.functional.devoxelize import calc_ti_weights
+from torchsparse.nn.utils import *
+from torchsparse.utils import *
+from torchsparse.utils.tensor_cache import TensorCache
+import torch_scatter
+from typing import Union, Tuple

-__all__ = ['initial_voxelize', 'point_to_voxel', 'voxel_to_point']
+__all__ = ["initial_voxelize", "point_to_voxel", "voxel_to_point", "PointTensor"]


+class PointTensor(SparseTensor):
+    def __init__(
+        self,
+        feats: torch.Tensor,
+        coords: torch.Tensor,
+        stride: Union[int, Tuple[int, ...]] = 1,
+    ):
+        super().__init__(feats=feats, coords=coords, stride=stride)
+        self._caches.idx_query = dict()
+        self._caches.idx_query_devox = dict()
+        self._caches.weights_devox = dict()
+
+
+def sphashquery(query, target, kernel_size=1):
+    hashmap_keys = torch.zeros(
+        2 * target.shape[0], dtype=torch.int64, device=target.device
+    )
+    hashmap_vals = torch.zeros(
+        2 * target.shape[0], dtype=torch.int32, device=target.device
+    )
+    hashmap = torchsparse.backend.GPUHashTable(hashmap_keys, hashmap_vals)
+    hashmap.insert_coords(target[:, [1, 2, 3, 0]])
+    kernel_size = make_ntuple(kernel_size, 3)
+    kernel_volume = np.prod(kernel_size)
+    kernel_size = make_tensor(kernel_size, device=target.device, dtype=torch.int32)
+    stride = make_tensor((1, 1, 1), device=target.device, dtype=torch.int32)
+    results = (
+        hashmap.lookup_coords(
+            query[:, [1, 2, 3, 0]], kernel_size, stride, kernel_volume
+        )
+        - 1
+    )[: query.shape[0]]
+    return results
+
+
# z: PointTensor
# return: SparseTensor
def initial_voxelize(z, init_res, after_res):
    new_float_coord = torch.cat(
-        [(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1)
-
-    pc_hash = F.sphash(torch.floor(new_float_coord).int())
-    sparse_hash = torch.unique(pc_hash)
-    idx_query = F.sphashquery(pc_hash, sparse_hash)
-    counts = F.spcount(idx_query.int(), len(sparse_hash))
-
-    inserted_coords = F.spvoxelize(torch.floor(new_float_coord), idx_query,
-                                   counts)
-    inserted_coords = torch.round(inserted_coords).int()
-    inserted_feat = F.spvoxelize(z.F, idx_query, counts)
-
-    new_tensor = SparseTensor(inserted_feat, inserted_coords, 1)
-    new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords)
-    z.additional_features['idx_query'][1] = idx_query
-    z.additional_features['counts'][1] = counts
-    z.C = new_float_coord
+        [z.C[:, 0].view(-1, 1), (z.C[:, 1:] * init_res) / after_res], 1
+    )
+    # optimization TBD: init_res = after_res
+    new_int_coord = torch.floor(new_float_coord).int()
+    sparse_coord = torch.unique(new_int_coord, dim=0)
+    idx_query = sphashquery(new_int_coord, sparse_coord).reshape(-1)
+
+    sparse_feat = torch_scatter.scatter_mean(z.F, idx_query.long(), dim=0)
+    new_tensor = SparseTensor(sparse_feat, sparse_coord, 1)
+    z._caches.idx_query[z.s] = idx_query
+    z.C = new_float_coord
    return new_tensor


# x: SparseTensor, z: PointTensor
# return: SparseTensor
def point_to_voxel(x, z):
-    if z.additional_features is None or z.additional_features.get(
-            'idx_query') is None or z.additional_features['idx_query'].get(
-                x.s) is None:
-        pc_hash = F.sphash(
-            torch.cat([
-                torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
-                z.C[:, -1].int().view(-1, 1)
-            ], 1))
-        sparse_hash = F.sphash(x.C)
-        idx_query = F.sphashquery(pc_hash, sparse_hash)
-        counts = F.spcount(idx_query.int(), x.C.shape[0])
-        z.additional_features['idx_query'][x.s] = idx_query
-        z.additional_features['counts'][x.s] = counts
+    if z._caches.idx_query.get(x.s) is None:
+        # Note: x.C has a smaller range after downsampling.
+        new_int_coord = torch.cat(
+            [
+                z.C[:, 0].int().view(-1, 1),
+                torch.floor(z.C[:, 1:] / x.s[0]).int(),
+            ],
+            1,
+        )
+        idx_query = sphashquery(new_int_coord, x.C)
+        z._caches.idx_query[x.s] = idx_query
    else:
-        idx_query = z.additional_features['idx_query'][x.s]
-        counts = z.additional_features['counts'][x.s]
-
-    inserted_feat = F.spvoxelize(z.F, idx_query, counts)
-    new_tensor = SparseTensor(inserted_feat, x.C, x.s)
-    new_tensor.cmaps = x.cmaps
-    new_tensor.kmaps = x.kmaps
+        idx_query = z._caches.idx_query[x.s]
+    # Haotian: This impl. is not elegant
+    idx_query = idx_query.clamp_(0)
+    sparse_feat = torch_scatter.scatter_mean(z.F, idx_query.long(), dim=0)
+    new_tensor = SparseTensor(sparse_feat, x.C, x.s)
+    new_tensor._caches = x._caches

    return new_tensor


# x: SparseTensor, z: PointTensor
# return: PointTensor
def voxel_to_point(x, z, nearest=False):
-    if z.idx_query is None or z.weights is None or z.idx_query.get(
-            x.s) is None or z.weights.get(x.s) is None:
-        off = get_kernel_offsets(2, x.s, 1, device=z.F.device)
-        old_hash = F.sphash(
-            torch.cat([
-                torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
-                z.C[:, -1].int().view(-1, 1)
-            ], 1), off)
-        pc_hash = F.sphash(x.C.to(z.F.device))
-        idx_query = F.sphashquery(old_hash, pc_hash)
-        weights = F.calc_ti_weights(z.C, idx_query,
-                                    scale=x.s[0]).transpose(0, 1).contiguous()
-        idx_query = idx_query.transpose(0, 1).contiguous()
+    if (
+        z._caches.idx_query_devox.get(x.s) is None
+        or z._caches.weights_devox.get(x.s) is None
+    ):
+        point_coords_float = torch.cat(
+            [z.C[:, 0].int().view(-1, 1), z.C[:, 1:] / x.s[0]],
+            1,
+        )
+        point_coords_int = torch.floor(point_coords_float).int()
+        idx_query = sphashquery(point_coords_int, x.C, kernel_size=2)
+        weights = calc_ti_weights(point_coords_float[:, 1:], idx_query, scale=1)

        if nearest:
-            weights[:, 1:] = 0.
+            weights[:, 1:] = 0.0
            idx_query[:, 1:] = -1
-        new_feat = F.spdevoxelize(x.F, idx_query, weights)
-        new_tensor = PointTensor(new_feat,
-                                 z.C,
-                                 idx_query=z.idx_query,
-                                 weights=z.weights)
-        new_tensor.additional_features = z.additional_features
-        new_tensor.idx_query[x.s] = idx_query
-        new_tensor.weights[x.s] = weights
-        z.idx_query[x.s] = idx_query
-        z.weights[x.s] = weights
+        new_feat = spf.spdevoxelize(x.F, idx_query, weights)
+        new_tensor = PointTensor(new_feat, z.C)
+        new_tensor._caches = z._caches
+        new_tensor._caches.idx_query_devox[x.s] = idx_query
+        new_tensor._caches.weights_devox[x.s] = weights
+        z._caches.idx_query_devox[x.s] = idx_query
+        z._caches.weights_devox[x.s] = weights

    else:
-        new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s), z.weights.get(x.s))
-        new_tensor = PointTensor(new_feat,
-                                 z.C,
-                                 idx_query=z.idx_query,
-                                 weights=z.weights)
-        new_tensor.additional_features = z.additional_features
+        new_feat = spf.spdevoxelize(
+            x.F, z._caches.idx_query_devox.get(x.s), z._caches.weights_devox.get(x.s)
+        )
+        new_tensor = PointTensor(new_feat, z.C)
+        new_tensor._caches = z._caches

    return new_tensor
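For context, a minimal usage sketch of the updated helpers (illustrative, not part of the commit): it assumes a CUDA build of torchsparse v2.1 plus torch_scatter, since sphashquery relies on torchsparse.backend.GPUHashTable, and it mirrors the point-voxel round trip in SPVCNN.forward shown above. Coordinates follow the v2.1 layout: batch index in column 0, xyz in columns 1 to 3.

import torch
from core.models.utils import PointTensor, initial_voxelize, point_to_voxel, voxel_to_point

feats = torch.randn(1000, 4)                  # per-point features
coords = torch.zeros(1000, 4)                 # column 0: batch index, columns 1:4: float xyz
coords[:, 1:] = torch.rand(1000, 3) * 10.0
z = PointTensor(feats.cuda(), coords.cuda())  # point branch

x0 = initial_voxelize(z, 0.05, 0.05)          # points -> SparseTensor at stride 1 (pres, vres)
# ... run sparse convolution layers on x0 here ...
z0 = voxel_to_point(x0, z, nearest=False)     # voxels -> points via trilinear devoxelization
x1 = point_to_voxel(x0, z0)                   # points -> voxels again, reusing the cached idx_query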
