-
Notifications
You must be signed in to change notification settings - Fork 0
/
meta_modules.py
79 lines (63 loc) · 2.74 KB
/
meta_modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
'''Define Hypernet
'''
import torch
from torch import nn
from collections import OrderedDict
import modules
'''Adapted from the SIREN repository https://github.com/vsitzmann/siren
'''
class HyperNetwork(nn.Module):
    '''Hypernetwork that predicts all parameters of a hypo-network.

    For every meta-parameter of ``hypo_module`` a dedicated FC head maps an
    embedding ``z`` to a flat vector, which is reshaped to that parameter's
    shape. The resulting OrderedDict can be passed as the ``params`` argument
    of a MetaModule's forward pass.
    '''
    def __init__(self, hyper_in_features, hyper_hidden_layers, hyper_hidden_features, hypo_module):
        '''
        Args:
            hyper_in_features: In features of hypernetwork
            hyper_hidden_layers: Number of hidden layers in hypernetwork
            hyper_hidden_features: Number of hidden units in hypernetwork
            hypo_module: MetaModule. The module whose parameters are predicted.
        '''
        super().__init__()
        self.names = []
        self.nets = nn.ModuleList()
        self.param_shapes = []
        for name, param in hypo_module.meta_named_parameters():
            self.names.append(name)
            self.param_shapes.append(param.size())
            # One FC head per hypo-parameter; its flat output has exactly as
            # many entries as the target parameter. numel() replaces the
            # equivalent int(torch.prod(torch.tensor(param.size()))).
            hn = modules.FCBlock(in_features=hyper_in_features,
                                 out_features=param.numel(),
                                 num_hidden_layers=hyper_hidden_layers,
                                 hidden_features=hyper_hidden_features,
                                 outermost_linear=True, nonlinearity='relu')
            self.nets.append(hn)
            # Re-initialize the head's last layer so the predicted parameters
            # start out close to a standard init of the hypo-network.
            # (The lambdas are applied immediately, so capturing the loop
            # variable `param` is safe.)
            if 'weight' in name:
                hn.net[-1].apply(lambda m: hyper_weight_init(m, param.size()[-1]))
            elif 'bias' in name:
                hn.net[-1].apply(lambda m: hyper_bias_init(m))

    def forward(self, z):
        '''
        Args:
            z: Embedding. Input to hypernetwork. Could be output of "Autodecoder"
        Returns:
            params: OrderedDict. Can be directly passed as the "params" parameter of a MetaModule.
        '''
        params = OrderedDict()
        for name, net, param_shape in zip(self.names, self.nets, self.param_shapes):
            # Leading -1 keeps whatever batch dimension z carries.
            params[name] = net(z).reshape((-1,) + tuple(param_shape))
        return params
'''Initialization schemes
'''
def hyper_weight_init(m, in_features_main_net):
    '''Initialize a hypernetwork head that predicts a weight tensor.

    Kaiming-normal init, then the weight is shrunk by a factor of 1e2 so the
    head's initial output is near zero; the bias is drawn uniform in
    +-1/in_features_main_net so the *predicted* weight resembles a standard
    linear-layer initialization of the main network.

    Args:
        m: Layer to initialize (used via nn.Module.apply on the head's last layer).
        in_features_main_net: Fan-in of the main-net layer whose weight this
            head predicts.
    '''
    if hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
        with torch.no_grad():
            # In-place shrink instead of reassigning .data (discouraged API).
            m.weight.div_(1.e2)
    # Guard against layers built with bias=False: the attribute exists but is
    # None, so the original hasattr check would crash on .uniform_.
    if getattr(m, 'bias', None) is not None:
        with torch.no_grad():
            m.bias.uniform_(-1 / in_features_main_net, 1 / in_features_main_net)
def hyper_bias_init(m):
    '''Initialize a hypernetwork head that predicts a bias vector.

    Kaiming-normal init shrunk by 1e2 (near-zero initial output); the head's
    own bias is drawn uniform in +-1/fan_in of the head's weight, mimicking
    the default bias init of a linear layer.

    Args:
        m: Layer to initialize (used via nn.Module.apply on the head's last layer).
    '''
    if hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
        with torch.no_grad():
            # In-place shrink instead of reassigning .data (discouraged API).
            m.weight.div_(1.e2)
    # Guard against layers built with bias=False: the attribute exists but is
    # None, so the original hasattr check would crash on .uniform_.
    if getattr(m, 'bias', None) is not None:
        # NOTE: private torch helper, kept to match nn.init's own fan-in math.
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
        with torch.no_grad():
            m.bias.uniform_(-1 / fan_in, 1 / fan_in)