diff --git a/src/anemoi/inference/runners/parallel.py b/src/anemoi/inference/runners/parallel.py index f77a19d..c259ca9 100644 --- a/src/anemoi/inference/runners/parallel.py +++ b/src/anemoi/inference/runners/parallel.py @@ -49,6 +49,16 @@ def __init__(self, context): model_comm_group = self.__init_parallel(self.device, self.global_rank, self.world_size) self.model_comm_group = model_comm_group + # Ensure each parallel model instance uses the same seed + if self.global_rank == 0: + seed = torch.initial_seed() + torch.distributed.broadcast_object_list([seed], src=0, group=model_comm_group) + else: + msg_buffer = np.array([1], dtype=np.uint64) + torch.distributed.broadcast_object_list(msg_buffer, src=0, group=model_comm_group) + seed = msg_buffer[0] + torch.manual_seed(seed) + def predict_step(self, model, input_tensor_torch, fcstep, **kwargs): if self.model_comm_group is None: return model.predict_step(input_tensor_torch)