Skip to content

Commit

Permalink
6894 update rank filter (#6895)
Browse files Browse the repository at this point in the history
Fixes #6894

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Aug 21, 2023
1 parent 59bcad4 commit 2daabf9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
12 changes: 7 additions & 5 deletions monai/utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,13 @@ def __init__(self, rank: int | None = None, filter_fn: Callable = lambda rank: r
# NOTE(review): this span is unified-diff residue from a scraped commit page.
# The first warnings.warn(...) call below is the REMOVED (pre-change) code and
# the `if torch.cuda...`-guarded version after it is the ADDED replacement —
# they were never meant to coexist in the final file. Do not treat this span
# as runnable source; consult the actual monai/utils/dist.py for ground truth.
if dist.is_available() and dist.is_initialized():
self.rank: int = rank if rank is not None else dist.get_rank()
else:
# (removed lines) unconditional warning emitted whenever torch.distributed
# was not initialized, even in ordinary single-process runs.
warnings.warn(
"The torch.distributed is either unavailable and uninitiated when RankFilter is instantiated. "
"If torch.distributed is used, please ensure that the RankFilter() is called "
"after torch.distributed.init_process_group() in the script."
)
# (added lines) warn only when more than one CUDA device is visible —
# presumably the heuristic for "this looks like it should be distributed";
# single-GPU / CPU runs stay silent. TODO confirm against upstream file.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
warnings.warn(
"The torch.distributed is either unavailable and uninitiated when RankFilter is instantiated.\n"
"If torch.distributed is used, please ensure that the RankFilter() is called\n"
"after torch.distributed.init_process_group() in the script.\n"
)
self.rank = 0

def filter(self, *_args):
    """Decide whether a log record passes, by applying ``filter_fn`` to this process's rank.

    Extra positional arguments (the log record) are ignored; the decision
    depends only on the stored rank.
    """
    decide = self.filter_fn
    return decide(self.rank)
26 changes: 25 additions & 1 deletion tests/test_rankfilter_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,35 @@ def test_rankfilter(self):
with open(log_filename) as file:
lines = [line.rstrip() for line in file]
log_message = " ".join(lines)
assert log_message.count("test_warnings") == 1
self.assertEqual(log_message.count("test_warnings"), 1)

def tearDown(self) -> None:
    """Remove the temporary log directory created by ``setUp``."""
    tmp_dir = self.log_dir
    tmp_dir.cleanup()


class SingleRankFilterTest(unittest.TestCase):
    """Verify that RankFilter lets warnings through in a single-process (non-distributed) run."""

    def tearDown(self) -> None:
        """Remove the temporary log directory created by ``setUp``."""
        self.log_dir.cleanup()

    def setUp(self):
        # Each test writes its log file into a fresh temporary directory.
        self.log_dir = tempfile.TemporaryDirectory()

    def test_rankfilter_single_proc(self):
        """A single warning emitted through a RankFilter'd logger appears exactly once."""
        logger = logging.getLogger(__name__)
        log_filename = os.path.join(self.log_dir.name, "records_sp.log")
        file_handler = logging.FileHandler(filename=log_filename)
        file_handler.setLevel(logging.WARNING)
        logger.addHandler(file_handler)
        logger.addFilter(RankFilter())
        logger.warning("test_warnings")

        with open(log_filename) as log_file:
            recorded = [entry.rstrip() for entry in log_file]
        # Detach and close the handler before asserting so the file is flushed
        # and later tests on this shared logger are unaffected.
        logger.removeHandler(file_handler)
        file_handler.close()
        joined = " ".join(recorded)
        self.assertEqual(joined.count("test_warnings"), 1)


# Allow running this test module directly (e.g. `python tests/test_rankfilter_dist.py`).
if __name__ == "__main__":
unittest.main()

0 comments on commit 2daabf9

Please sign in to comment.