
Invoke assertRaises* in non-context-manager fashion
Summary: It was discovered that the internal and external Ruff formatters format code with multiple context managers differently, so this diff converts the `assertRaises*` usages into non-context-manager fashion. [This caused some OSS CI complaints before](#65), and this diff should reduce the chance of discrepancies between different Ruff formatter settings on multiple context managers.

Differential Revision: D67681063
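
For illustration, here is a minimal, self-contained sketch of the before/after pattern this diff applies. The `resource` context manager, the `parse` helper, and the test class are hypothetical stand-ins rather than code from this repository; only the `assertRaisesRegex` usage mirrors the conversion.

```python
import unittest
from contextlib import contextmanager


@contextmanager
def resource(name: str):
    """Hypothetical context manager, used only to motivate the multi-context `with`."""
    yield name


def parse(value: str) -> int:
    """Hypothetical function under test; raises ValueError on non-numeric input."""
    return int(value)


class AssertRaisesStyleTest(unittest.TestCase):
    def test_context_manager_style(self) -> None:
        # Before: assertRaisesRegex as one of several context managers inside a
        # parenthesized `with` (Python 3.10+). Different Ruff formatter settings
        # may wrap such multi-context `with` statements differently.
        with (
            resource("config"),
            self.assertRaisesRegex(ValueError, "invalid literal"),
        ):
            parse("not-a-number")

    def test_callable_style(self) -> None:
        # After: the callable and its arguments are passed straight to
        # assertRaisesRegex, leaving a single context manager in the `with`.
        with resource("config"):
            self.assertRaisesRegex(
                ValueError,
                "invalid literal",
                parse,
                "not-a-number",
            )


if __name__ == "__main__":
    unittest.main()
```

Passing the callable and its arguments directly keeps a single context manager per `with` statement, which sidesteps the multi-context-manager wrapping that the two Ruff formatter settings handle differently.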
tsunghsienlee authored and facebook-github-bot committed Dec 27, 2024
1 parent 3efeacf commit b2fa69a
Showing 9 changed files with 196 additions and 211 deletions.
151 changes: 69 additions & 82 deletions distributed_shampoo/tests/distributed_shampoo_test.py
@@ -40,39 +40,33 @@ def setUp(self) -> None:
)

def test_invalid_preconditioner_config(self) -> None:
with (
mock.patch.object(
distributed_shampoo,
"type",
side_effect=lambda object: {
ShampooPreconditionerConfig: PreconditionerConfig
}.get(type(object), type(object)),
),
with mock.patch.object(
distributed_shampoo,
"type",
side_effect=lambda object: {
ShampooPreconditionerConfig: PreconditionerConfig
}.get(type(object), type(object)),
):
self.assertRaisesRegex(
NotImplementedError,
re.escape("group[PRECONDITIONER_CONFIG]=ShampooPreconditionerConfig"),
),
):
DistributedShampoo(
DistributedShampoo,
self._model.parameters(),
preconditioner_config=DefaultShampooConfig,
)

def test_invalid_grafting_config(self) -> None:
with (
mock.patch.object(
distributed_shampoo,
"type",
side_effect=lambda object: {SGDGraftingConfig: GraftingConfig}.get(
type(object), type(object)
),
with mock.patch.object(
distributed_shampoo,
"type",
side_effect=lambda object: {SGDGraftingConfig: GraftingConfig}.get(
type(object), type(object)
),
):
self.assertRaisesRegex(
NotImplementedError,
re.escape("group[GRAFTING_CONFIG]=SGDGraftingConfig"),
),
):
DistributedShampoo(
DistributedShampoo,
self._model.parameters(),
grafting_config=SGDGraftingConfig(), # type: ignore[abstract]
)
@@ -143,29 +137,26 @@ def test_invalid_with_incorrect_hyperparameter_setting(self) -> None:
incorrect_hyperparameter_setting,
expected_error_msg,
) in incorrect_hyperparameter_setting_and_expected_error_msg:
with (
self.subTest(
incorrect_hyperparameter_setting=incorrect_hyperparameter_setting,
expected_error_msg=expected_error_msg,
),
self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)),
with self.subTest(
incorrect_hyperparameter_setting=incorrect_hyperparameter_setting,
expected_error_msg=expected_error_msg,
):
DistributedShampoo(
self.assertRaisesRegex(
ValueError,
re.escape(expected_error_msg),
DistributedShampoo,
self._model.parameters(),
**incorrect_hyperparameter_setting,
)

def test_invalid_cuda_pytorch_compile_setting(self) -> None:
with (
mock.patch.object(torch.cuda, "is_available", return_value=False),
with mock.patch.object(torch.cuda, "is_available", return_value=False):
self.assertRaisesRegex(
ValueError,
re.escape(
"Backend does NOT support Pytorch 2.0 compile. Switch to shampoo_pt2_compile_config=None."
),
),
):
DistributedShampoo(
DistributedShampoo,
self._model.parameters(),
shampoo_pt2_compile_config=ShampooPT2CompileConfig(),
)
@@ -187,21 +178,18 @@ def test_nesterov_and_zero_momentum(self) -> None:
)

def test_invalid_distributed_config(self) -> None:
with (
with mock.patch.object(
distributed_shampoo,
"type",
side_effect=lambda object: DistributedConfig,
):
self.assertRaisesRegex(
NotImplementedError,
re.escape(
"distributed_config=DDPShampooConfig(communication_dtype=<CommunicationDType.DEFAULT: 0>, "
"num_trainers_per_group=-1, communicate_params=False) not supported!"
),
),
mock.patch.object(
distributed_shampoo,
"type",
side_effect=lambda object: DistributedConfig,
),
):
DistributedShampoo(
DistributedShampoo,
params=self._model.parameters(),
distributed_config=DDPShampooConfig(),
)
@@ -450,22 +438,23 @@ def setUp(self) -> None:
}

def test_state_dict(self) -> None:
with self.assertRaisesRegex(
self.assertRaisesRegex(
NotImplementedError,
re.escape(
"Distributed Shampoo does not support the standard state_dict() method for checkpointing!"
),
):
self._optimizer.state_dict()
self._optimizer.state_dict,
)

def test_load_state_dict(self) -> None:
with self.assertRaisesRegex(
self.assertRaisesRegex(
NotImplementedError,
re.escape(
"Distributed Shampoo does not support the standard load_state_dict() method for checkpointing!"
),
):
self._optimizer.load_state_dict(state_dict={})
self._optimizer.load_state_dict,
state_dict={},
)

def test_distributed_state_dict(self) -> None:
state_dict_with_param_groups = self._optimizer.distributed_state_dict(
@@ -523,41 +512,39 @@ def test_load_distributed_state_dict_with_mismatch_param_groups(self) -> None:
# but param_groups only needs one (i.e., "0.weight").
self._distributed_state_dict["param_groups"]["1.weight"] = {}

with self.assertRaisesRegex(
ValueError, re.escape("Different param_groups count: 1 vs 2")
):
self._optimizer.load_distributed_state_dict(
state_dict=self._distributed_state_dict,
key_to_param=self._model.named_parameters(),
save_param_groups=True,
)
self.assertRaisesRegex(
ValueError,
re.escape("Different param_groups count: 1 vs 2"),
self._optimizer.load_distributed_state_dict,
state_dict=self._distributed_state_dict,
key_to_param=self._model.named_parameters(),
save_param_groups=True,
)

# Remove "0.weight" so param_groups_to_load has "1.weight" only but param_groups needs "0.weight".
del self._distributed_state_dict["param_groups"]["0.weight"]

with self.assertRaisesRegex(
self.assertRaisesRegex(
ValueError,
re.escape("Param group 0.weight not found in param_groups_to_load!"),
):
self._optimizer.load_distributed_state_dict(
state_dict=self._distributed_state_dict,
key_to_param=self._model.named_parameters(),
save_param_groups=True,
)
self._optimizer.load_distributed_state_dict,
state_dict=self._distributed_state_dict,
key_to_param=self._model.named_parameters(),
save_param_groups=True,
)

def test_load_distributed_state_dict_with_missing_param_key(self) -> None:
with self.assertRaisesRegex(
self.assertRaisesRegex(
KeyError,
re.escape("Parameter key 0.weight not found in key_to_param mapping!"),
):
self._optimizer.load_distributed_state_dict(
state_dict=self._distributed_state_dict,
# Instead of providing self._model.named_parameters(), we provide an empty list
# to trigger the missing key check error.
key_to_param=iter([]),
save_param_groups=False,
enable_missing_key_check=True,
)
self._optimizer.load_distributed_state_dict,
state_dict=self._distributed_state_dict,
# Instead of providing self._model.named_parameters(), we provide an empty list
# to trigger the missing key check error.
key_to_param=iter([]),
save_param_groups=False,
enable_missing_key_check=True,
)

with self.assertLogs(
level="WARNING",
@@ -584,15 +571,15 @@ def test_load_distributed_state_dict_with_missing_param(self) -> None:
key_to_param_copy = chain(
self._model.named_parameters(), iter([("1.weight", torch.tensor(1))])
)
with self.assertRaisesRegex(
KeyError, re.escape("Parameter 1 not found in state!")
):
self._optimizer.load_distributed_state_dict(
state_dict=state_dict_to_load_copy,
key_to_param=key_to_param_copy,
save_param_groups=False,
enable_missing_key_check=True,
)
self.assertRaisesRegex(
KeyError,
re.escape("Parameter 1 not found in state!"),
self._optimizer.load_distributed_state_dict,
state_dict=state_dict_to_load_copy,
key_to_param=key_to_param_copy,
save_param_groups=False,
enable_missing_key_check=True,
)

# Re-populate key_to_param_copy because it is an iterator that was consumed by the previous call.
key_to_param_copy = chain(
@@ -291,19 +291,18 @@ def test_number_of_trainers_per_group_out_of_range(self) -> None:
num_trainers_per_group=3,
)

with self.assertRaisesRegex(
self.assertRaisesRegex(
ValueError,
re.escape(
"Invalid number of trainers per group: 3. Must be between [1, 2] or set to -1."
),
):
ShampooHSDPDistributorTest._train_model(
optim_factory=ShampooHSDPDistributorTest._shampoo_optim_factory(
distributed_config=hsdp_config,
),
model_factory=ShampooHSDPDistributorTest._model_factory(hsdp_config),
device=torch.device("cuda"),
)
ShampooHSDPDistributorTest._train_model,
optim_factory=ShampooHSDPDistributorTest._shampoo_optim_factory(
distributed_config=hsdp_config,
),
model_factory=ShampooHSDPDistributorTest._model_factory(hsdp_config),
device=torch.device("cuda"),
)

@skip_if_lt_x_gpu(4)
def test_dist_is_initialized(self) -> None:
@@ -313,13 +312,11 @@ def test_dist_is_initialized(self) -> None:
device_mesh=mesh_2d,
)

with mock.patch.object(
torch.distributed, "is_initialized", return_value=False
), self.assertRaisesRegex(
RuntimeError,
re.escape("HSDPDistributor needs torch.distributed to be initialized!"),
):
ShampooHSDPDistributorTest._train_model(
with mock.patch.object(torch.distributed, "is_initialized", return_value=False):
self.assertRaisesRegex(
RuntimeError,
re.escape("HSDPDistributor needs torch.distributed to be initialized!"),
ShampooHSDPDistributorTest._train_model,
optim_factory=ShampooHSDPDistributorTest._shampoo_optim_factory(
distributed_config=hsdp_config,
),
@@ -341,13 +338,13 @@ def test_incompatible_replicated_group_size_and_num_trainers_per_group(
# Hijack the DeviceMesh.size() method to return 4 instead of 2 to bypass the check of num_trainers_per_group.
with mock.patch.object(
torch.distributed.device_mesh.DeviceMesh, "size", return_value=4
), self.assertRaisesRegex(
ValueError,
re.escape(
"distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!"
),
):
ShampooHSDPDistributorTest._train_model(
self.assertRaisesRegex(
ValueError,
re.escape(
"distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!"
),
ShampooHSDPDistributorTest._train_model,
optim_factory=ShampooHSDPDistributorTest._shampoo_optim_factory(
distributed_config=hsdp_config,
),
@@ -330,21 +330,20 @@ def test_number_of_trainers_per_group_out_of_range(self) -> None:
num_trainers_per_group=3,
)

with self.assertRaisesRegex(
self.assertRaisesRegex(
ValueError,
re.escape(
"Invalid number of trainers per group: 3. Must be between [1, 2] or set to -1."
),
):
ShampooHybridShardDistributorTest._train_model(
optim_factory=ShampooHybridShardDistributorTest._shampoo_optim_factory(
distributed_config=hybrid_shard_config,
),
model_factory=ShampooHybridShardDistributorTest._model_factory(
hybrid_shard_config
),
device=torch.device("cuda"),
)
ShampooHybridShardDistributorTest._train_model,
optim_factory=ShampooHybridShardDistributorTest._shampoo_optim_factory(
distributed_config=hybrid_shard_config,
),
model_factory=ShampooHybridShardDistributorTest._model_factory(
hybrid_shard_config
),
device=torch.device("cuda"),
)

@with_comms
@skip_if_lt_x_gpu(4)
@@ -356,15 +355,13 @@ def test_dist_is_initialized(self) -> None:
device_mesh=mesh_2d,
)

with mock.patch.object(
torch.distributed, "is_initialized", return_value=False
), self.assertRaisesRegex(
RuntimeError,
re.escape(
"HybridShardDistributor needs torch.distributed to be initialized!"
),
):
ShampooHybridShardDistributorTest._train_model(
with mock.patch.object(torch.distributed, "is_initialized", return_value=False):
self.assertRaisesRegex(
RuntimeError,
re.escape(
"HybridShardDistributor needs torch.distributed to be initialized!"
),
ShampooHybridShardDistributorTest._train_model,
optim_factory=ShampooHybridShardDistributorTest._shampoo_optim_factory(
distributed_config=hybrid_shard_config,
),
@@ -390,13 +387,13 @@ def test_incompatible_replicated_group_size_and_num_trainers_per_group(
# Hijack the DeviceMesh.size() method to return 4 instead of 2 to bypass the check of num_trainers_per_group.
with mock.patch.object(
torch.distributed.device_mesh.DeviceMesh, "size", return_value=4
), self.assertRaisesRegex(
ValueError,
re.escape(
"distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!"
),
):
ShampooHybridShardDistributorTest._train_model(
self.assertRaisesRegex(
ValueError,
re.escape(
"distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!"
),
ShampooHybridShardDistributorTest._train_model,
optim_factory=ShampooHybridShardDistributorTest._shampoo_optim_factory(
distributed_config=hybrid_shard_config,
),
