From 64486e4e25d558fe3fcf3281c8d3b7d3fd5f8c53 Mon Sep 17 00:00:00 2001 From: kozistr Date: Tue, 3 Dec 2024 21:51:20 +0900 Subject: [PATCH] fix: get_parameters --- pytorch_optimizer/optimizer/muon.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_optimizer/optimizer/muon.py b/pytorch_optimizer/optimizer/muon.py index 7d26a6f6..6725fde2 100644 --- a/pytorch_optimizer/optimizer/muon.py +++ b/pytorch_optimizer/optimizer/muon.py @@ -96,8 +96,6 @@ def __init__( params = self.get_parameters(params) adamw_params = self.get_parameters(adamw_params) if adamw_params is not None else [] params.extend(adamw_params) - # print(params) - # print(adamw_params) self.world_size: int = os.environ.get('WORLD_SIZE', 1) self.rank: int = os.environ.get('RANK', 0) @@ -128,7 +126,7 @@ def get_parameters(params: PARAMETERS) -> PARAMETERS: new_params = [] for group in params: if isinstance(group, dict) and 'params' in group: - new_params.extend(group['params']) + new_params.extend(list(group['params'])) else: new_params.append(group)