# sandbox.py (from a fork of 152334H/DL-Art-School)
import torch
import torchvision
from PIL import Image
from pytorch_wavelets import DWTForward, DWTInverse
import torch.nn.functional as F
def load_img(path):
    """Load an image file and return it as an RGB tensor.

    The image is converted to 3-channel RGB, so greyscale, palette and RGBA
    inputs all yield three channels. It is then passed through torchvision's
    ``ToTensor``, which produces a ``(C, H, W)`` float tensor.

    Args:
        path: Path of the image file to read.

    Returns:
        A ``(3, H, W)`` float tensor.
    """
    # Open inside a context manager so the underlying file handle is released
    # once ``convert`` has copied the pixels into a new image. Otherwise it
    # stays open until the original Image object is garbage-collected.
    with Image.open(path) as im:
        rgb = im.convert(mode="RGB")
    return torchvision.transforms.ToTensor()(rgb)
def save_img(t, path):
    """Write the tensor ``t`` to ``path`` as an image file.

    This delegates to ``torchvision.utils.save_image`` with its default
    options: no ``normalize``, so values are presumably expected to already
    be in a displayable range (TODO confirm for raw wavelet coefficients).
    """
    writer = torchvision.utils.save_image
    writer(t, path)
img = load_img("pu.jpg")
img = img.unsqueeze(0)  # add a batch dimension -> (1, C, H, W)

# Resize, using F.interpolate's default nearest-neighbour mode, so that height
# and width are both multiples of 32. 32 is presumably 2**J_spec with the
# 5-level DWT used below, so every level halves the size exactly (TODO confirm).
# The tensor layout is (N, C, H, W), so shape[2:] is (height, width).
h, w = img.shape[2:]
if h < 32 or w < 32:
    # Rounding down would yield a zero-sized target, which F.interpolate
    # rejects with a much less helpful message.
    raise ValueError(
        "Input image must be at least 32x32 pixels, got %dx%d (HxW)" % (h, w)
    )
h = (h // 32) * 32
w = (w // 32) * 32
img = F.interpolate(img, size=(h, w))  # size is given as (H, W)
print("Input shape:", img.shape)
# Number of wavelet decomposition levels for the full-resolution image.
J_spec = 5

# Forward 2D DWT (Daubechies-3 wavelet, periodized boundaries).
# Yl is the coarsest low-pass band. Yh is a list of high-pass tensors, one per
# level, each holding three orientation sub-bands along dim 2 (see the
# pytorch_wavelets docs).
Yl, Yh = DWTForward(J=J_spec, mode='periodization', wave='db3')(img)
print(Yl.shape, [band.shape for band in Yh])

# Half-resolution copy of the input, decomposed with one level fewer.
# The input sides are multiples of 32, so its coarsest low-pass band comes out
# the same spatial size as Yl.
imgLR = F.interpolate(img, scale_factor=.5)
LQYl, LQYh = DWTForward(J=J_spec-1, mode='periodization', wave='db3')(imgLR)
print(LQYl.shape, [band.shape for band in LQYh])

# Dump each high-pass level, with its orientation sub-bands summed into a
# single image, followed by the low-pass band.
for level, band in enumerate(Yh):
    summed = torch.sum(band, dim=2).cpu()
    save_img(summed, "high_%i.png" % (level,))
save_img(Yl, "lo.png")
# Reconstruct the image several times. Each pass zeroes out exactly one
# high-pass level, to show what that level contributes.
inverse_dwt = DWTInverse(mode='periodization', wave='db3')
for level, band in enumerate(Yh):
    bands = list(Yh)  # shallow copy so Yh itself is left untouched
    bands[level] = torch.zeros_like(band)
    im = inverse_dwt((Yl, bands))
    save_img(im, "corrupt_%i.png" % (level,))

# Reconstruct with every high-pass level kept, but with the low-pass band
# flattened to its global mean (taken over all channels). The output then
# roughly shows only what the high-pass levels carry.
flat_lowpass = torch.full_like(Yl, fill_value=torch.mean(Yl).item())
im = inverse_dwt((flat_lowpass, Yh))
save_img(im, "corrupt_im.png")
# Disabled experiment: everything between the triple quotes below is a bare
# string literal, so none of it executes. If re-enabled, it would:
#   - invert a DWT built from the first (index 0) high-pass level of the
#     full-resolution image (Yh[0]) plus the low-pass band and high-pass
#     levels of the half-resolution image (LQYl, LQYh), saving the result as
#     hybrid_lrhr.png;
#   - save imgLR upscaled 2x with F.interpolate's default mode as upscaled.png.
'''
Following code reconstructs a hybrid image with the first high pass from the HR and the rest of the data from the LR.
highpass = [Yh[0]] + LQYh
im = DWTInverse(mode='periodization', wave='db3')((LQYl, highpass))
save_img(im, "hybrid_lrhr.png")
save_img(F.interpolate(imgLR, scale_factor=2), "upscaled.png")
'''