From 6cb624fa184903868a620c665c93e69b5aa1a12f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20P=C3=B6schel?= Date: Thu, 8 Aug 2024 18:10:29 +0200 Subject: [PATCH] Fix forgotten ranks issue --- src/binding/python/openpmd_api/pipe/__main__.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/binding/python/openpmd_api/pipe/__main__.py b/src/binding/python/openpmd_api/pipe/__main__.py index f942b60948..b186d8be76 100644 --- a/src/binding/python/openpmd_api/pipe/__main__.py +++ b/src/binding/python/openpmd_api/pipe/__main__.py @@ -154,12 +154,14 @@ def __init__( self, granularity_in, granularity_out, - inner_distribution=io.ByHostname(io.RoundRobin()), + inner_distribution, + local_rank, ): super().__init__() self.inner_distribution = inner_distribution self.granularity_in = granularity_in self.granularity_out = granularity_out + self.local_rank = local_rank def assign(self, assignment, in_ranks, out_ranks): if "in_ranks_inner" in dir(self): @@ -209,6 +211,14 @@ def inner_rank_assignment(outer_assignment, hostname_to_hostgroup): out_ranks, out_hostname_to_hostgroup ) + # we only care about the local host + local_host = self.out_ranks_inner[self.local_rank] + self.out_ranks_inner = { + rank: host + for rank, host in self.out_ranks_inner.items() + if host == local_host + } + return self.inner_distribution.assign( assignment, self.in_ranks_inner, self.out_ranks_inner ) @@ -293,7 +303,7 @@ def distribution_strategy(dataset_extent, return IncreaseGranularity( granularity, 1, io.FromPartialStrategy(io.ByHostname(io.RoundRobin()), - io.DiscardingStrategy())) + io.DiscardingStrategy()), mpi_rank) elif strategy_identifier == 'all': return io.FromPartialStrategy(IncreaseGranularity(5), LoadAll(mpi_rank)) elif strategy_identifier == 'roundrobin': @@ -327,7 +337,7 @@ def __init__(self, infile, outfile, inconfig, outconfig, comm): hostinfo = io.HostInfo.MPI_PROCESSOR_NAME self.outranks = hostinfo.get_collective(self.comm) my_hostname = self.outranks[self.comm.rank] - self.outranks = {i: rank for i, rank in self.outranks.items() if rank == my_hostname} + self.outranks = {i: rank for i, rank in self.outranks.items()} else: self.outranks = {i: str(i) for i in range(self.comm.size)}