part1_code.py
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import imageio.v2 as imageio
import time


def positional_encoding(x, num_frequencies=6, incl_input=True):
    """
    Apply positional encoding to the input.

    Args:
        x (torch.Tensor): Input tensor to be positionally encoded.
            The dimension of x is [N, D], where N is the number of input coordinates,
            and D is the dimension of the input coordinate.
        num_frequencies (optional, int): The number of frequencies used in
            the positional encoding (default: 6).
        incl_input (optional, bool): If True, concatenate the input with the
            computed positional encoding (default: True).

    Returns:
        (torch.Tensor): Positional encoding of the input tensor.
    """
    results = []
    if incl_input:
        results.append(x)
    ############################# TODO 1(a) BEGIN ############################
    # Encode the input with sin and cos at frequencies 2^i * pi (i = 0 .. L-1)
    # and append each encoded tensor to the list of results.
    for i in range(num_frequencies):
        for f in [torch.sin, torch.cos]:
            results.append(f((2.0 ** i) * np.pi * x))
    ############################# TODO 1(a) END ##############################
    return torch.cat(results, dim=-1)
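
# Illustrative sketch (not part of the assignment skeleton): a quick shape check
# for positional_encoding. With D = 2 input dimensions and L = 6 frequencies,
# the output has 2 + 2 * 2 * 6 = 26 channels when incl_input=True.
# example_coords = torch.rand(4, 2)                                # [N=4, D=2]
# example_enc = positional_encoding(example_coords, num_frequencies=6)
# print(example_enc.shape)                                         # torch.Size([4, 26])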


class model_2d(nn.Module):
    """
    Define a 2D model comprising three fully connected layers,
    two ReLU activations and one sigmoid activation.
    """

    def __init__(self, filter_size=128, num_frequencies=6):
        super().__init__()
        ############################# TODO 1(b) BEGIN ############################
        # Input dimension: the raw (u, v) coordinate (D = 2) plus its positional
        # encoding with 2*D*L channels (sin and cos for each of L frequencies).
        D = 2
        L = num_frequencies
        self.layer_1 = nn.Linear(2 + 2 * D * L, filter_size)
        self.layer_2 = nn.Linear(filter_size, filter_size)
        self.layer_3 = nn.Linear(filter_size, 3)
        ############################# TODO 1(b) END ##############################

    def forward(self, x):
        ############################# TODO 1(b) BEGIN ############################
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = torch.sigmoid(self.layer_3(x))
        ############################# TODO 1(b) END ##############################
        return x
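
# Illustrative sketch (assumption: coordinates are positionally encoded before
# being fed to the network, as in train_2d_model below):
# model = model_2d(filter_size=128, num_frequencies=6)
# enc = positional_encoding(torch.rand(4, 2), num_frequencies=6)   # [4, 26]
# rgb = model(enc)                                                 # [4, 3], values in (0, 1)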


def train_2d_model(test_img, num_frequencies, device, model=model_2d,
                   positional_encoding=positional_encoding, show=True):
    # Optimizer parameters
    lr = 5e-4
    iterations = 10000
    height, width = test_img.shape[:2]

    # Number of iters after which stats are displayed
    display = 2000

    # Define the model and initialize its weights.
    model2d = model(num_frequencies=num_frequencies)
    model2d.to(device)

    def weights_init(m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
    model2d.apply(weights_init)

    ############################# TODO 1(c) BEGIN ############################
    # Define the optimizer
    optimizer = torch.optim.Adam(model2d.parameters(), lr=lr)
    ############################# TODO 1(c) END ############################

    # Seed RNG, for repeatability
    seed = 5670
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Lists to log metrics etc.
    psnrs = []
    iternums = []

    t = time.time()
    t0 = time.time()

    ############################# TODO 1(c) BEGIN ############################
    # Create the 2D normalized coordinates, and apply positional encoding to them.
    # indexing='xy' gives grids of shape (height, width), so flattening here and
    # reshaping the predictions back to (height, width, 3) use the same pixel order.
    u, v = torch.meshgrid(torch.linspace(-1, 1, width),
                          torch.linspace(-1, 1, height), indexing='xy')
    img_uv = torch.stack([u, v], dim=-1).to(device)       # (height, width, 2)
    img_uv = img_uv.view(-1, 2)                           # (height * width, 2)
    pos_enc_img_uv = positional_encoding(img_uv, num_frequencies=num_frequencies)
    ############################# TODO 1(c) END ############################

    for i in range(iterations + 1):
        optimizer.zero_grad()
        ############################# TODO 1(c) BEGIN ############################
        # Run one iteration
        pred = model2d(pos_enc_img_uv).view(height, width, 3)

        # Compute mean-squared error between the predicted and target images. Backprop!
        loss = F.mse_loss(pred.view(-1, 3), test_img.view(-1, 3))
        loss.backward()
        optimizer.step()
        ############################# TODO 1(c) END ############################

        # Display images/plots/stats
        if i % display == 0 and show:
            ############################# TODO 1(c) BEGIN ############################
            # Calculate PSNR (pixels are in [0, 1], so PSNR = -10 * log10(MSE))
            psnr = -10 * torch.log10(loss)
            ############################# TODO 1(c) END ############################
print("Iteration %d " % i, "Loss: %.4f " % loss.item(), "PSNR: %.2f" % psnr.item(), \
"Time: %.2f secs per iter" % ((time.time() - t) / display), "%.2f secs in total" % (time.time() - t0))
t = time.time()
psnrs.append(psnr.item())
iternums.append(i)
plt.figure(figsize=(13, 4))
plt.subplot(131)
plt.imshow(pred.detach().cpu().numpy())
plt.title(f"Iteration {i}")
plt.subplot(132)
plt.imshow(test_img.cpu().numpy())
plt.title("Target image")
plt.subplot(133)
plt.plot(iternums, psnrs)
plt.title("PSNR")
plt.show()
#if i==iterations:
# np.save('result_'+str(num_frequencies)+'.npz',pred.detach().cpu().numpy())
print('Done!')
return pred.detach().cpu()
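

# Illustrative usage sketch (assumptions: a local 3-channel RGB image named
# 'test_image.png' and a CUDA device if one is available; neither is specified
# by the skeleton above):
# if __name__ == "__main__":
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     img = imageio.imread('test_image.png') / 255.0             # normalize to [0, 1]
#     test_img = torch.from_numpy(img).float().to(device)        # (H, W, 3)
#     pred = train_2d_model(test_img, num_frequencies=6, device=device)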