Skip to content

Commit

Permalink
Ensure each model has the same seed
Browse files: browse the repository at this point in the history
  • Loading branch information
cathalobrien committed Jan 22, 2025
1 parent 5dd8a55 commit 861161d
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/anemoi/inference/runners/parallel.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -49,6 +49,16 @@ def __init__(self, context):
model_comm_group = self.__init_parallel(self.device, self.global_rank, self.world_size)
self.model_comm_group = model_comm_group

# Ensure each parallel model instance uses the same seed
if self.global_rank == 0:
seed = torch.initial_seed()
torch.distributed.broadcast_object_list([seed], src=0, group=model_comm_group)
else:
msg_buffer = np.array([1], dtype=np.uint64)
torch.distributed.broadcast_object_list(msg_buffer, src=0, group=model_comm_group)
seed = msg_buffer[0]
torch.manual_seed(seed)

def predict_step(self, model, input_tensor_torch, fcstep, **kwargs):
if self.model_comm_group is None:
return model.predict_step(input_tensor_torch)
Expand Down

0 comments on commit 861161d

Please sign in to comment.