diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py index 2539f99c9..ce172a0d4 100644 --- a/pytorch_optimizer/optimizer/utils.py +++ b/pytorch_optimizer/optimizer/utils.py @@ -278,3 +278,19 @@ def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor: if d != dim: x = x.max(dim=d, keepdim=True).values return x + + +def reg_noise( + network1: nn.Module, network2: nn.Module, num_data: int, lr: float, eta: float = 8e-3, temperature: float = 1e-4 +) -> torch.Tensor | float: + reg_coef: float = 0.5 / (eta * num_data) + noise_coef: float = math.sqrt(2.0 / lr / num_data * temperature) + + loss = 0 + for param1, param2 in zip(network1.parameters(), network2.parameters(), strict=True): + reg = torch.sub(param1, param2).pow_(2) * reg_coef + noise1 = param1 * torch.randn_like(param1) * noise_coef + noise2 = param2 * torch.randn_like(param2) * noise_coef + loss += torch.sum(reg - noise1 - noise2) + + return loss