diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py index 7c943a74d..50cdb09ef 100644 --- a/pytorch_optimizer/optimizer/utils.py +++ b/pytorch_optimizer/optimizer/utils.py @@ -282,7 +282,7 @@ def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor: def reg_noise( network1: nn.Module, network2: nn.Module, num_data: int, lr: float, eta: float = 8e-3, temperature: float = 1e-4 -) -> torch.Tensor | float: +) -> Union[torch.Tensor, float]: r"""Entropy-MCMC: Sampling from flat basins with ease. usage: https://github.com/lblaoke/EMCMC/blob/master/exp/cifar10_emcmc.py