From e88b672742f79613fee9230eb0c9926c8587d922 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 12 Dec 2024 00:05:57 +0000 Subject: [PATCH] add broadcast to collective.py --- msccl/language/collectives.py | 65 ++++++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/msccl/language/collectives.py b/msccl/language/collectives.py index 6d9a108..28fdaae 100755 --- a/msccl/language/collectives.py +++ b/msccl/language/collectives.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field from msccl.language.ir import Buffer from msccl.language import * -#test class Collective: @@ -71,6 +70,70 @@ def check(self, prog): correct = False return correct +class Broadcast(Collective): + def __init__(self, num_ranks, root, chunk_factor, inplace, create_all_chunks=False): + Collective.__init__(self, num_ranks, root, chunk_factor, inplace) + self.name = "broadcast" + # This flag is a temporary solution, which initialize all the chuncks only for inputbuffer + # In this future we need to remove this flag and always initialize all the chunks + self.create_all_chunks = create_all_chunks + + # Initializes input buffer for an broadcast + def init_buffers(self): + rank_buffers = [] + if self.inplace: + # Inplace broadcast only uses the output buffer + for r in range(self.num_ranks): + input_buffer = [None] * (self.chunk_factor) + #if not self.create_all_chunks: + # for ch in range(self.chunk_factor): + # output_buffer[ch] = Chunk(r, ch, -1, ch) + #else: + for ch in range(self.chunk_factor): + input_buffer[ch] = Chunk(root, ch, -1, ch) + buffers = { + Buffer.input: input_buffer, #this only needs to be set for the root + Buffer.output: input_buffer, + } + rank_buffers.append(buffers) + else: + for r in range(self.num_ranks): + input_buffer = [None] * self.chunk_factor + output_buffer = [None] * (self.chunk_factor) + if r==root: + for ch in range(self.chunk_factor): + input_buffer[ch] = Chunk(root, ch, -1, ch) + buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer} # add if statement + rank_buffers.append(buffers) + return rank_buffers + + # Expected output buffer for broadcast + def check(self, prog): + correct = True + buf = Buffer.output + for r in range(self.num_ranks): + output = prog.buffers[r][buf] + for i in range(self.num_ranks): + for ch in range(self.chunk_factor): + index = ch + chunk = output[index] + if chunk is None: + print(f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given None") + correct = False + elif chunk.origin_rank != i or chunk.origin_index != ch: + print( + f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})" + ) + correct = False + return correct + + def get_buffer_index(self, rank, buffer, index): + # For inplace Broadcast, the input buffer points into the output buffer + return buffer, index + #if self.inplace and buffer == Buffer.input: + # return Buffer.output, index + #else: + # return buffer, index class AllGather(Collective): def __init__(self, num_ranks, chunk_factor, inplace, create_all_chunks=False):