Skip to content

Commit

Permalink
Fix forgotten ranks issue
Browse files Browse the repository at this point in the history
  • Loading branch information
franzpoeschel committed Dec 11, 2024
1 parent af94e5c commit 6cb624f
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/binding/python/openpmd_api/pipe/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,14 @@ def __init__(
self,
granularity_in,
granularity_out,
inner_distribution=io.ByHostname(io.RoundRobin()),
inner_distribution,
local_rank,
):
super().__init__()
self.inner_distribution = inner_distribution
self.granularity_in = granularity_in
self.granularity_out = granularity_out
self.local_rank = local_rank

def assign(self, assignment, in_ranks, out_ranks):
if "in_ranks_inner" in dir(self):
Expand Down Expand Up @@ -209,6 +211,14 @@ def inner_rank_assignment(outer_assignment, hostname_to_hostgroup):
out_ranks, out_hostname_to_hostgroup
)

# we only care about the local host
local_host = self.out_ranks_inner[self.local_rank]
self.out_ranks_inner = {
rank: host
for rank, host in self.out_ranks_inner.items()
if host == local_host
}

return self.inner_distribution.assign(
assignment, self.in_ranks_inner, self.out_ranks_inner
)
Expand Down Expand Up @@ -293,7 +303,7 @@ def distribution_strategy(dataset_extent,
return IncreaseGranularity(
granularity, 1,
io.FromPartialStrategy(io.ByHostname(io.RoundRobin()),
io.DiscardingStrategy()))
io.DiscardingStrategy()), mpi_rank)
elif strategy_identifier == 'all':
return io.FromPartialStrategy(IncreaseGranularity(5), LoadAll(mpi_rank))
elif strategy_identifier == 'roundrobin':
Expand Down Expand Up @@ -327,7 +337,7 @@ def __init__(self, infile, outfile, inconfig, outconfig, comm):
hostinfo = io.HostInfo.MPI_PROCESSOR_NAME
self.outranks = hostinfo.get_collective(self.comm)
my_hostname = self.outranks[self.comm.rank]
self.outranks = {i: rank for i, rank in self.outranks.items() if rank == my_hostname}
self.outranks = {i: rank for i, rank in self.outranks.items()}
else:
self.outranks = {i: str(i) for i in range(self.comm.size)}

Expand Down

0 comments on commit 6cb624f

Please sign in to comment.