Skip to content

Commit

Permalink
Fix and refactor get_all_subclasses() util in shampoo_types_test.py (
Browse files Browse the repository at this point in the history
…#75)

Summary:
Pull Request resolved: #75

1. Fixes the bug that recursive function did not include that `cls` itself into the return result.
2. Adds a new argument `include_cls_self` to the function to control whether to include the class itself in the result. It also includes docstrings to explain how this function works.

Reviewed By: anana10c

Differential Revision: D67802249

fbshipit-source-id: cb44a9b51e620f404cec30a5e49313e4e8a11b9c
  • Loading branch information
tsunghsienlee authored and facebook-github-bot committed Jan 3, 2025
1 parent fe7cd45 commit c51e4e6
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions distributed_shampoo/tests/shampoo_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,37 @@
SubclassesType = TypeVar("SubclassesType")


def get_all_subclasses(cls: SubclassesType) -> list[SubclassesType]:
def get_all_subclasses(
cls: SubclassesType, include_cls_self: bool = True
) -> list[SubclassesType]:
"""
Retrieves all subclasses of a given class, optionally including the class itself.
This function uses a helper function to recursively find all unique subclasses
of the specified class.
Args:
cls (SubclassesType): The class for which to find subclasses.
include_cls_self (bool): Whether to include the class itself in the result. (Default: True)
Returns:
list[SubclassesType]: A list of all unique subclasses of the given class.
"""

def get_all_unique_subclasses(cls: SubclassesType) -> set[SubclassesType]:
"""Gets all unique subclasses of a given class recursively."""
assert (
subclasses := getattr(cls, "__subclasses__", lambda: None)()
) is not None, f"{cls} does not have __subclasses__."
return reduce(or_, map(get_all_unique_subclasses, subclasses), set())
return reduce(or_, map(get_all_unique_subclasses, subclasses), {cls})

return list(get_all_unique_subclasses(cls))
return list(get_all_unique_subclasses(cls) - (set() if include_cls_self else {cls}))


class AdaGradGraftingConfigSubclassesTest(unittest.TestCase):
def test_illegal_epsilon(self) -> None:
epsilon = 0.0
for cls in [AdaGradGraftingConfig] + get_all_subclasses(AdaGradGraftingConfig):
for cls in get_all_subclasses(AdaGradGraftingConfig):
with self.subTest(cls=cls):
self.assertRaisesRegex(
ValueError,
Expand All @@ -53,7 +69,7 @@ def test_illegal_beta2(
self,
) -> None:
for cls, beta2 in itertools.product(
[RMSpropGraftingConfig] + get_all_subclasses(RMSpropGraftingConfig),
get_all_subclasses(RMSpropGraftingConfig),
(-1.0, 0.0, 1.3),
):
with self.subTest(cls=cls, beta2=beta2):
Expand All @@ -70,7 +86,8 @@ def test_illegal_beta2(
class PreconditionerConfigSubclassesTest(unittest.TestCase):
def test_illegal_num_tolerated_failed_amortized_computations(self) -> None:
num_tolerated_failed_amortized_computations = -1
for cls in get_all_subclasses(PreconditionerConfig):
# Not testing for the base class PreconditionerConfig because it is an abstract class.
for cls in get_all_subclasses(PreconditionerConfig, include_cls_self=False):
with self.subTest(cls=cls):
self.assertRaisesRegex(
ValueError,
Expand Down

0 comments on commit c51e4e6

Please sign in to comment.