diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 3e9f25852..236b55d25 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -245,7 +245,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER: if optimizer.startswith('bnb'): if HAS_BNB and torch.cuda.is_available(): - return load_bnb_optimizer(optimizer) + return load_bnb_optimizer(optimizer) # pragma: no cover raise ImportError(f'[-] bitsandbytes and CUDA required for bnb optimizers : {optimizer}') if optimizer not in OPTIMIZERS: raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')