import torch
from torch import nn
from torch.nn import functional as F
from mapping_network import MappingNetowrk, AdaptiveInstanceNorm, NoiseInjection
from helpers.imle_helpers import get_1x1, get_3x3, draw_gaussian_diag_samples, gaussian_analytical_kl
from collections import defaultdict
import numpy as np
import itertools


class Block(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 3x3 -> 1x1 convolutions (all 1x1 when
    use_3x3 is False) with pre-activation GELUs, an optional residual connection,
    and optional average-pool downsampling by `down_rate`."""

    def __init__(self, in_width, middle_width, out_width, down_rate=None, residual=False, use_3x3=True, zero_last=False):
        super().__init__()
        self.down_rate = down_rate
        self.residual = residual
        self.c1 = get_1x1(in_width, middle_width)
        self.c2 = get_3x3(middle_width, middle_width) if use_3x3 else get_1x1(middle_width, middle_width)
        self.c3 = get_3x3(middle_width, middle_width) if use_3x3 else get_1x1(middle_width, middle_width)
        self.c4 = get_1x1(middle_width, out_width, zero_weights=zero_last)

    def forward(self, x):
        xhat = self.c1(F.gelu(x))
        xhat = self.c2(F.gelu(xhat))
        xhat = self.c3(F.gelu(xhat))
        xhat = self.c4(F.gelu(xhat))
        out = x + xhat if self.residual else xhat
        if self.down_rate is not None:
            out = F.avg_pool2d(out, kernel_size=self.down_rate, stride=self.down_rate)
        return out


def parse_layer_string(s):
    """Parse a comma-separated block spec into (resolution, second value) pairs.

    'RxN' -> N copies of (R, None)
    'RmM' -> (R, M), where M is the mix-in (source) resolution to upsample from
    'RdD' -> (R, D), where D is a down-rate
    'R'   -> (R, None)
    """
    layers = []
    for ss in s.split(','):
        if 'x' in ss:
            res, num = ss.split('x')
            count = int(num)
            layers += [(int(res), None) for _ in range(count)]
        elif 'm' in ss:
            res, mixin = [int(a) for a in ss.split('m')]
            layers.append((res, mixin))
        elif 'd' in ss:
            res, down_rate = [int(a) for a in ss.split('d')]
            layers.append((res, down_rate))
        else:
            res = int(ss)
            layers.append((res, None))
    return layers
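
# Illustrative only (not part of the original file): a dec_blocks spec such as
# "1x2,4m1,8m4" expands to
#   [(1, None), (1, None), (4, 1), (8, 4)]
# i.e. two blocks at 1x1, then a block at 4x4 mixing in from 1x1, then one at
# 8x8 mixing in from 4x4.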


def pad_channels(t, width):
    d1, d2, d3, d4 = t.shape
    empty = torch.zeros(d1, width, d3, d4, device=t.device)
    empty[:, :d2, :, :] = t
    return empty


def get_width_settings(width, s):
    mapping = defaultdict(lambda: width)
    if s:
        s = s.split(',')
        for ss in s:
            k, v = ss.split(':')
            mapping[int(k)] = int(v)
    return mapping
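
# Illustrative only (not part of the original file): get_width_settings(384, "8:512,16:256")
# yields a defaultdict with widths[8] == 512, widths[16] == 256, and every other
# resolution falling back to the default width of 384.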


class DecBlock(nn.Module):
    """One decoder stage: optional upsampling from a lower `mixin` resolution,
    noise injection and AdaIN conditioning on the style vector w (only at
    resolutions up to H.max_hierarchy), followed by a residual Block."""

    def __init__(self, H, res, mixin, n_blocks):
        super().__init__()
        self.base = res
        self.mixin = mixin
        self.H = H
        self.widths = get_width_settings(H.width, H.custom_width_str)
        width = self.widths[res]
        if res <= H.max_hierarchy:
            self.noise = NoiseInjection(width)
            self.adaIN = AdaptiveInstanceNorm(width, H.latent_dim)
        use_3x3 = res > 2
        cond_width = int(width * H.bottleneck_multiple)
        self.resnet = Block(width, cond_width, width, residual=True, use_3x3=use_3x3)
        # Down-weight the residual branch's last conv by sqrt(1 / n_blocks).
        self.resnet.c4.weight.data *= np.sqrt(1 / n_blocks)

    def forward(self, x, w, spatial_noise):
        if self.mixin is not None:
            x = F.interpolate(x, scale_factor=self.base // self.mixin)
        if self.base <= self.H.max_hierarchy:
            x = self.noise(x, spatial_noise)
            x = self.adaIN(x, w)
        x = self.resnet(x)
        return x


class Decoder(nn.Module):
    """Stack of DecBlocks that grows a learned constant up to image resolution,
    conditioned on a style vector w produced by the mapping network."""

    def __init__(self, H):
        super().__init__()
        self.H = H
        self.mapping_network = MappingNetowrk(code_dim=H.latent_dim, n_mlp=H.n_mpl)
        resos = set()
        cond_width = int(H.width * H.bottleneck_multiple)
        dec_blocks = []
        self.widths = get_width_settings(H.width, H.custom_width_str)
        blocks = parse_layer_string(H.dec_blocks)
        for idx, (res, mixin) in enumerate(blocks):
            dec_blocks.append(DecBlock(H, res, mixin, n_blocks=len(blocks)))
            resos.add(res)
        self.resolutions = sorted(resos)
        self.dec_blocks = nn.ModuleList(dec_blocks)
        first_res = self.resolutions[0]
        self.constant = nn.Parameter(torch.randn(1, self.widths[first_res], first_res, first_res))
        self.resnet = get_1x1(H.width, H.image_channels)
        self.gain = nn.Parameter(torch.ones(1, H.image_channels, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, H.image_channels, 1, 1))

    def forward(self, latent_code, spatial_noise, input_is_w=False):
        # Map the latent code to a style vector w unless it is already in w-space.
        if not input_is_w:
            w = self.mapping_network(latent_code)[0]
        else:
            w = latent_code
        x = self.constant.repeat(latent_code.shape[0], 1, 1, 1)
        # Index the supplied noise maps by spatial size; every resolution at or
        # below H.max_hierarchy must have a matching map or the lookup below fails.
        if spatial_noise:
            res_to_noise = {noise_map.shape[3]: noise_map for noise_map in spatial_noise}
        for idx, block in enumerate(self.dec_blocks):
            noise = None
            if block.base <= self.H.max_hierarchy:
                noise = res_to_noise[block.base]
            x = block(x, w, noise)
        x = self.resnet(x)
        x = self.gain * x + self.bias
        return x


class IMLE(nn.Module):
    """Top-level model: a thin wrapper around the Decoder. `dci_db` is a placeholder
    for the DCI nearest-neighbour index that is set up outside this module."""

    def __init__(self, H):
        super().__init__()
        self.dci_db = None
        self.decoder = Decoder(H)

    def forward(self, latents, spatial_noise=None, input_is_w=False):
        return self.decoder.forward(latents, spatial_noise, input_is_w)
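

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original file. The hyperparameter
    # values below are made up; only the attribute names (width, latent_dim, n_mpl,
    # dec_blocks, ...) come from the code above. The noise shape [N, 1, res, res]
    # is an assumption borrowed from the common StyleGAN noise-injection
    # convention; adjust it to whatever mapping_network.NoiseInjection expects.
    from types import SimpleNamespace

    H = SimpleNamespace(
        width=64,                    # base channel width at every resolution
        custom_width_str='',         # no per-resolution width overrides
        latent_dim=128,              # dimensionality of the latent / style vectors
        n_mpl=4,                     # number of layers in the mapping network MLP
        bottleneck_multiple=0.25,    # middle width of each Block = width * this
        max_hierarchy=16,            # resolutions <= this get noise + AdaIN
        dec_blocks='1x2,4m1,8m4,16m8,32m16',  # spec parsed by parse_layer_string
        image_channels=3,
    )

    model = IMLE(H)
    n = 2
    latents = torch.randn(n, H.latent_dim)
    # The decoder looks noise maps up by spatial size (shape[3]), so one map is
    # needed for every resolution at or below max_hierarchy.
    spatial_noise = [torch.randn(n, 1, r, r) for r in (1, 4, 8, 16)]
    images = model(latents, spatial_noise)
    print(images.shape)  # expected: torch.Size([2, 3, 32, 32])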