Skip to content

Commit

Permalink
feature: implement reg_noise
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed May 5, 2024
1 parent 416d91b commit b585827
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions pytorch_optimizer/optimizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,19 @@ def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
if d != dim:
x = x.max(dim=d, keepdim=True).values
return x


def reg_noise(
network1: nn.Module, network2: nn.Module, num_data: int, lr: float, eta: float = 8e-3, temperature: float = 1e-4
) -> torch.Tensor | float:
reg_coef: float = 0.5 / (eta * num_data)
noise_coef: float = math.sqrt(2.0 / lr / num_data * temperature)

loss = 0
for param1, param2 in zip(network1.parameters(), network2.parameters(), strict=True):
reg = torch.sub(param1, param2).pow_(2) * reg_coef
noise1 = param1 * torch.randn_like(param1) * noise_coef
noise2 = param2 * torch.randn_like(param2) * noise_coef
loss += torch.sum(reg - noise1 - noise2)

return loss

0 comments on commit b585827

Please sign in to comment.