Skip to content

Commit

Permalink
modify check for broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu authored and Ubuntu committed Dec 12, 2024
1 parent daf360d commit 414a45f
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions msccl/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,19 @@ def check(self, prog):
correct = True
buf = Buffer.output
for r in range(self.num_ranks):
output = prog.buffers[r][buf]
for i in range(self.num_ranks):
for ch in range(self.chunk_factor):
index = ch
chunk = output[index]
if chunk is None:
print(f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given None")
correct = False
elif chunk.origin_rank != i or chunk.origin_index != ch:
print(
f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})"
)
correct = False
output = prog.buffers[0][buf]
for ch in range(self.chunk_factor):
index = ch
chunk = output[index]
if chunk is None:
print(f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given None")
correct = False
elif chunk.origin_rank != self.root or chunk.origin_index != ch:
print(f"Rank {r} chunk {index} is incorrect should be ({self.root}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})")
correct = False
return correct


def get_buffer_index(self, rank, buffer, index):
# For inplace Broadcast, the input buffer points into the output buffer
return buffer, index
Expand Down

0 comments on commit 414a45f

Please sign in to comment.