-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
65 lines (53 loc) · 1.84 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import torch
import torch
import torch.nn as nn
class PEncoding(nn.Module):
    """Sinusoidal positional encoding of low-dimensional coordinates.

    Every coordinate component is multiplied by each frequency band and passed
    through ``sin`` and ``cos``; the raw coordinates are optionally prepended.

    Args:
        num_freq: Number of frequency bands.
        max_freq_log2: Log2 of the highest frequency band.
        log_sampling: If True, bands are ``2 ** linspace(0, max_freq_log2,
            num_freq)`` (log-spaced); otherwise ``linspace(1,
            2 ** max_freq_log2, num_freq)`` (linearly spaced).
        include_input: If True, the raw coordinates are prepended to the
            sinusoidal features.
        input_dim: Size of the last dimension of the coordinates.

    Attributes:
        out_dim: Size of the last dimension of ``forward``'s output, i.e.
            ``input_dim * (int(include_input) + 2 * num_freq)``.
        bands: Frequency bands of shape ``[num_freq]``, held as a
            non-trainable ``nn.Parameter`` so it follows ``.to(device)`` and
            is saved in ``state_dict``.
    """

    def __init__(
        self,
        num_freq,
        max_freq_log2,
        log_sampling=True,
        include_input=True,
        input_dim=1,
    ):
        super().__init__()
        self.num_freq = num_freq
        self.max_freq_log2 = max_freq_log2
        self.log_sampling = log_sampling
        self.include_input = include_input

        if self.log_sampling:
            bands = 2.0 ** torch.linspace(0.0, max_freq_log2, steps=num_freq)
        else:
            bands = torch.linspace(1, 2.0**max_freq_log2, steps=num_freq)

        # Optional raw input, plus one sin and one cos feature per band per
        # coordinate component.
        self.out_dim = (input_dim if include_input else 0) + bands.shape[0] * input_dim * 2
        # Kept as a (frozen) Parameter rather than a buffer so that
        # parameters() / state_dict() contents stay as callers expect.
        self.bands = nn.Parameter(bands, requires_grad=False)

    def forward(self, coords):
        """Embed the coordinates.

        Args:
            coords (torch.Tensor): Coordinates of shape ``[..., input_dim]``
                (typically ``[N, input_dim]``); any leading batch dims are
                preserved.

        Returns:
            torch.Tensor: Embeddings of shape ``[..., out_dim]``. When
            ``include_input`` is True the raw coordinates come first, followed
            by all sine features and then all cosine features.
        """
        # [..., 1, D] * [F, 1] -> [..., F, D]; flattening the last two dims
        # yields band-major order (all D components for band 0, then band 1, ...).
        winded = (coords[..., None, :] * self.bands[:, None]).flatten(-2)
        encoded = torch.cat([torch.sin(winded), torch.cos(winded)], dim=-1)
        if self.include_input:
            encoded = torch.cat([coords, encoded], dim=-1)
        return encoded
def to_numpy(x):
    """Convert a torch tensor to a NumPy array, detached from autograd.

    Works for tensors on any device: non-CPU tensors are copied to host
    memory first. For CPU tensors ``.cpu()`` is a no-op, so the returned
    array shares memory with ``x``.

    Args:
        x (torch.Tensor): Tensor to convert.

    Returns:
        numpy.ndarray: Array with the same values as ``x``.
    """
    # Tensor.numpy() only supports CPU tensors; calling it on a CUDA tensor
    # raises, so move to host memory first.
    return x.detach().cpu().numpy()
def place_objects(min_x, max_x):
    """Return the objects to place, each as an ``(x, y)`` pair.

    Currently always returns a single fixed object ``(3, 5)``.

    Args:
        min_x: Lower bound of the placement range. Currently unused.
        max_x: Upper bound of the placement range. Currently unused.

    Returns:
        list[tuple]: A new list containing the single pair ``(3, 5)``.
    """
    # NOTE(review): min_x/max_x are ignored. An earlier version drew one
    # object with x ~ U(min_x, max_x) and y ~ U(x, max_x); the fixed pair was
    # presumably substituted for reproducibility — confirm before relying on it.
    return [(3, 5)]