Merge pull request #45 from HazyResearch/centered-tikhonov
Centered Tikhonov regularization
ajratner authored Sep 8, 2018
2 parents efa4c38 + 0e24b07 commit d6bcc1a
Showing 3 changed files with 86 additions and 50 deletions.
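The change replaces the flat centered L2 term l2 * ||mu - mu_init||^2 with a diagonally weighted, centered Tikhonov penalty ||D(mu - mu_init)||_F^2, and initializes mu_init from per-LF precision priors. A minimal sketch of the centering idea, with hypothetical shapes and values (not this repo's API):

    import torch

    # A centered Tikhonov penalty ||D(mu - mu_init)||_F^2 penalizes distance
    # from the prior mu_init, so its gradient pushes mu toward mu_init rather
    # than toward zero, as a plain (uncentered) L2 penalty would.
    d, k = 4, 2                                  # hypothetical: d rows, k classes
    mu_init = torch.full((d, k), 0.5)
    mu = torch.zeros(d, k, requires_grad=True)   # start away from the prior

    D = 0.5 * torch.eye(d)                       # scalar strength on the diagonal
    penalty = torch.norm(D @ (mu - mu_init)) ** 2
    penalty.backward()

    print(mu.grad)  # negative everywhere: gradient descent raises mu toward 0.5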
54 changes: 46 additions & 8 deletions metal/label_model/label_model.py
@@ -183,15 +183,36 @@ def _init_params(self):
         - Z is the inverse form version of \mu.
         """
         train_config = self.config["train_config"]
 
         # Initialize mu so as to break basic reflective symmetry
+        # Note that we are given either a single or per-LF initial precision
+        # value, prec_i = P(Y=y|\lf=y), and use:
+        #   mu_init = P(\lf=y|Y=y) = P(\lf=y) * prec_i / P(Y=y)
+
+        # Handle single or per-LF values
+        if isinstance(train_config["prec_init"], (int, float)):
+            prec_init = train_config["prec_init"] * torch.ones(self.m)
+        else:
+            prec_init = torch.from_numpy(train_config["prec_init"])
+            if prec_init.shape[0] != self.m:
+                raise ValueError(f"prec_init must have shape {self.m}.")
+
+        # Get the per-value labeling propensities
+        # Note that self.O must have been computed already!
+        lps = torch.diag(self.O).numpy()
+
+        # TODO: Update for higher-order cliques!
         self.mu_init = torch.zeros(self.d, self.k)
         for i in range(self.m):
             for y in range(self.k):
-                self.mu_init[i * self.k + y, y] += (
-                    train_config["mu_init"] * np.random.random()
-                )
-        self.mu = nn.Parameter(self.mu_init.clone()).float()
+                idx = i * self.k + y
+                mu_init = torch.clamp(lps[idx] * prec_init[i] / self.p[y], 0, 1)
+                self.mu_init[idx, y] += mu_init
+
+        # Initialize randomly based on self.mu_init
+        self.mu = nn.Parameter(
+            self.mu_init.clone() * np.random.random()
+        ).float()
 
         if self.inv_form:
             self.Z = nn.Parameter(torch.randn(self.d, self.k)).float()
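The initialization above is Bayes' rule applied per source and label value: mu_init = P(\lf=y|Y=y) = P(\lf=y) * prec_i / P(Y=y), clamped to [0, 1], with the propensities P(\lf=y) read off diag(O). A self-contained sketch of the loop's arithmetic, with all numbers hypothetical:

    import numpy as np
    import torch

    m, k = 2, 2                               # hypothetical: 2 LFs, 2 classes
    p = torch.tensor([0.6, 0.4])              # class balance P(Y=y)
    prec_init = torch.tensor([0.7, 0.7])      # per-LF precision P(Y=y|\lf=y)
    lps = np.array([0.30, 0.20, 0.25, 0.25])  # propensities P(\lf=y) = diag(O)

    mu_init = torch.zeros(m * k, k)
    for i in range(m):
        for y in range(k):
            idx = i * k + y
            # Bayes' rule: P(\lf=y|Y=y) = P(\lf=y) * P(Y=y|\lf=y) / P(Y=y)
            mu_init[idx, y] = torch.clamp(lps[idx] * prec_init[i] / p[y], 0, 1)

    # e.g. mu_init[0, 0] = 0.30 * 0.7 / 0.6 = 0.35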
@@ -280,6 +301,25 @@ def get_Q(self):
     # (for better or worse). The unused *args make these compatible with the
     # Classifier._train() method, which expects loss functions to accept an input.
 
+    def loss_l2(self, l2=0):
+        """L2 loss centered around mu_init, scaled optionally per-source.
+
+        In other words, diagonal Tikhonov regularization,
+            ||D(\mu-\mu_{init})||_2^2
+        where D is diagonal.
+
+        Args:
+            - l2: A float or np.array representing the per-source regularization
+                strengths to use
+        """
+        if isinstance(l2, (int, float)):
+            D = l2 * torch.eye(self.d)
+        else:
+            D = torch.diag(torch.from_numpy(l2))
+
+        # Note that mu is a matrix and this is the *Frobenius norm*
+        return torch.norm(D @ (self.mu - self.mu_init)) ** 2
+
     def loss_inv_Z(self, *args):
         return torch.norm((self.O_inv + self.Z @ self.Z.t())[self.mask]) ** 2
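When l2 is an np.array, D applies one strength per source-value row of mu (so the array has length self.d), and since mu is a matrix the squared Frobenius norm decomposes row-wise. A sketch with hypothetical values (note torch.from_numpy yields float64, hence the cast to match mu):

    import numpy as np
    import torch

    d, k = 4, 2
    mu = torch.rand(d, k)
    mu_init = torch.full((d, k), 0.35)

    l2s = np.array([0.0, 0.0, 1.0, 1.0])  # hypothetical: regularize rows 2-3 only
    D = torch.diag(torch.from_numpy(l2s)).float()
    penalty = torch.norm(D @ (mu - mu_init)) ** 2

    # Row-wise decomposition of the squared Frobenius norm:
    #   ||D(mu - mu_init)||_F^2 = sum_i l2s[i]**2 * ||mu[i] - mu_init[i]||^2
    check = (torch.from_numpy(l2s).float() ** 2
             * ((mu - mu_init) ** 2).sum(dim=1)).sum()
    assert torch.allclose(penalty, check)

One nuance: with a scalar strength, D = l2 * I makes the penalty l2**2 * ||mu - mu_init||_F^2, so l2 now enters squared relative to the replaced l2 * ||mu - mu_init||^2 form in the loss functions below.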

@@ -288,8 +328,7 @@ def loss_inv_mu(self, *args, l2=0):
         loss_2 = (
             torch.norm(torch.sum(self.mu @ self.P, 1) - torch.diag(self.O)) ** 2
         )
-        loss_l2 = torch.norm(self.mu - self.mu_init) ** 2
-        return loss_1 + loss_2 + l2 * loss_l2
+        return loss_1 + loss_2 + self.loss_l2(l2=l2)
 
     def loss_mu(self, *args, l2=0):
         loss_1 = (
@@ -299,8 +338,7 @@ def loss_mu(self, *args, l2=0):
         loss_2 = (
             torch.norm(torch.sum(self.mu @ self.P, 1) - torch.diag(self.O)) ** 2
         )
-        loss_l2 = torch.norm(self.mu - self.mu_init) ** 2
-        return loss_1 + loss_2 + l2 * loss_l2
+        return loss_1 + loss_2 + self.loss_l2(l2=l2)
 
     def _set_class_balance(self, class_balance, Y_dev):
         """Set a prior for the class balance
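Both loss_inv_mu and loss_mu now delegate their regularizer to loss_l2, so the centered penalty is defined in one place. A toy sketch, with stand-in tensors and hypothetical shapes, of how the visible loss_2 moment-matching term composes with that penalty:

    import torch

    d, k = 4, 2
    mu = torch.rand(d, k, requires_grad=True)
    mu_init = torch.full((d, k), 0.35)
    P = torch.diag(torch.tensor([0.6, 0.4]))  # diagonal class-balance matrix
    O = torch.rand(d, d)                      # stand-in for the overlaps matrix

    # loss_2 as in the diff: row sums of mu @ P should match diag(O)
    loss_2 = torch.norm(torch.sum(mu @ P, 1) - torch.diag(O)) ** 2
    # Centered penalty, as loss_l2 computes it with a scalar strength
    penalty = torch.norm(0.1 * torch.eye(d) @ (mu - mu_init)) ** 2

    (loss_2 + penalty).backward()  # both terms contribute gradients to mu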
8 changes: 3 additions & 5 deletions metal/label_model/lm_defaults.py
@@ -8,11 +8,9 @@
     # Classifier
     # Class balance (if learn_class_balance=False, fix to class_balance)
     "learn_class_balance": False,
-    # Class balance initialization / prior
-    "class_balance_init": None,  # (array) If None, assume uniform
-    # Model params initialization / priors
-    "mu_init": 0.5,
-    # Centered L2 regularization
+    # LF precision initializations / priors (float or np.array)
+    "prec_init": 0.7,
+    # Centered L2 regularization strength (int, float, or np.array)
     "l2": 0.0,
     # Optimizer
     "optimizer_config": {
[Diff for the third changed file not loaded in this view.]
