-
Notifications
You must be signed in to change notification settings - Fork 9
/
utils.py
101 lines (78 loc) · 3.46 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import kaolin as kal
import kaolin.ops.mesh
import clip
import numpy as np
from torchvision import transforms
from pathlib import Path
from collections import Counter
from Normalization import MeshNormalizer
# Select the module-level compute device once at import time: the first CUDA
# GPU when available, otherwise CPU. The helpers below allocate tensors on
# this `device`.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
def get_camera_from_view2(elev, azim, r=3.0):
    """Build a kaolin camera transformation matrix from spherical coordinates.

    The camera sits on a sphere of radius `r` at elevation `elev` and azimuth
    `azim` (tensors, in radians), with +Y as the up direction, aimed back
    through the origin (the look-at point is the camera position mirrored
    through the origin).

    Returns the matrix produced by
    ``kal.render.camera.generate_transformation_matrix``.
    """
    # Spherical -> Cartesian camera position.
    cam_x = r * torch.cos(elev) * torch.cos(azim)
    cam_y = r * torch.sin(elev)
    cam_z = r * torch.cos(elev) * torch.sin(azim)

    position = torch.tensor([cam_x, cam_y, cam_z]).unsqueeze(0)
    target = -position
    up = torch.tensor([0.0, 1.0, 0.0]).unsqueeze(0)

    return kal.render.camera.generate_transformation_matrix(position, target, up)
def get_texture_map_from_color(mesh, color, H=224, W=224):
    """Return a constant-color texture map of shape (1, 3, H, W).

    Args:
        mesh: unused; kept for interface compatibility with the sibling
            attribute-construction helpers.
        color: RGB value broadcastable into the trailing 3-channel dimension.
        H, W: texture resolution in pixels.

    Returns:
        A (1, 3, H, W) tensor on the module-level `device`, every pixel set
        to `color`.
    """
    # NOTE: the original computed `num_faces = mesh.faces.shape[0]` here but
    # never used it; the dead local has been removed.
    texture_map = torch.zeros(1, H, W, 3).to(device)
    texture_map[:, :, :] = color
    # Convert channels-last (N, H, W, C) to channels-first (N, C, H, W).
    return texture_map.permute(0, 3, 1, 2)
def get_face_attributes_from_color(mesh, color):
    """Build per-face vertex attributes filled with one uniform color.

    Produces a (1, num_faces, 3, 3) tensor on the module-level `device`:
    for each face, one RGB triple per face vertex, all set to `color`.
    """
    n_faces = mesh.faces.shape[0]
    attrs = torch.zeros(1, n_faces, 3, 3).to(device)
    attrs[:, :, :] = color
    return attrs
# ================== POSITIONAL ENCODERS =============================
class FourierFeatureTransform(torch.nn.Module):
    """Gaussian Fourier feature mapping.

    "Fourier Features Let Networks Learn High Frequency Functions in Low
    Dimensional Domains": https://arxiv.org/abs/2006.10739
    https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html

    Given an input of shape [batch, num_input_channels], returns a tensor of
    shape [batch, num_input_channels + 2 * mapping_size]: the raw input
    concatenated with sin and cos of its random Gaussian projections.
    (Commented-out remnants of an earlier 4-D [B, C, W, H] variant have been
    removed.)
    """

    def __init__(self, num_input_channels, mapping_size=256, scale=10):
        super().__init__()
        self._num_input_channels = num_input_channels
        self._mapping_size = mapping_size
        # Random projection matrix, (num_input_channels, mapping_size), drawn
        # from N(0, scale^2). Its rows are sorted by L2 norm — the original
        # marks this "for sape" (progressive-encoding schemes); preserved
        # as-is. NOTE(review): `_B` is a plain attribute, not a registered
        # buffer, so it does not move with .to(device) and is absent from
        # state_dict — kept that way for checkpoint compatibility.
        B = torch.randn((num_input_channels, mapping_size)) * scale
        B_sort = sorted(B, key=lambda row: torch.norm(row, p=2))
        self._B = torch.stack(B_sort)

    def forward(self, x):
        """Map x of shape [batch, C] to [batch, C + 2 * mapping_size]."""
        batches, channels = x.shape
        assert channels == self._num_input_channels, \
            "Expected input to have {} channels (got {} channels)".format(
                self._num_input_channels, channels)
        res = 2 * np.pi * (x @ self._B.to(x.device))
        return torch.cat([x, torch.sin(res), torch.cos(res)], dim=1)
# mesh coloring helpers
def color_mesh(pred_class, sampled_mesh, colors):
    """Color `sampled_mesh` in place from soft class assignments.

    Converts `pred_class` ((N, num_classes) soft weights — presumably one row
    per vertex; confirm against callers) to RGB via `segment2rgb`, scatters
    the result to per-face vertex attributes, and re-normalizes the mesh.

    Args:
        pred_class: (N, num_classes) soft class-probability tensor.
        sampled_mesh: mesh object whose `face_attributes` are overwritten.
        colors: iterable of per-class RGB tensors, passed to `segment2rgb`.
    """
    pred_rgb = segment2rgb(pred_class, colors)
    # Use the `kal` alias for consistency with the rest of this module
    # (the original spelled out `kaolin.ops.mesh...` only here).
    sampled_mesh.face_attributes = kal.ops.mesh.index_vertices_by_faces(
        pred_rgb.unsqueeze(0),
        sampled_mesh.faces)
    MeshNormalizer(sampled_mesh)()
def segment2rgb(pred_class, colors):
    """Blend per-class colors by soft class probabilities.

    Args:
        pred_class: (N, num_classes) soft assignment weights.
        colors: sequence of num_classes 1-D RGB tensors (size 3 each).

    Returns:
        (N, 3) tensor on `pred_class`'s device: the weighted color blend.
    """
    # One (N, K) @ (K, 3) product replaces the original Python loop of K
    # rank-1 outer-product accumulations — identical result, single kernel.
    # The original accumulated onto a zeros tensor on the module-level
    # `device`; since its `+=` only works when all operands share that
    # device, following `pred_class.device` is equivalent in every
    # non-erroring case.
    palette = torch.stack(list(colors)).to(pred_class.device)
    return torch.matmul(pred_class, palette)