diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 67e5195f6d..048fd912e4 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -87,26 +87,23 @@ def __init__( with h5py.File(systems) as file: systems = [os.path.join(systems, item) for item in file.keys()] - self.systems: list[DeepmdDataSetForLoader] = [] - if len(systems) >= 100: - log.info(f"Constructing DataLoaders from {len(systems)} systems") - def construct_dataset(system): return DeepmdDataSetForLoader( system=system, type_map=type_map, ) - with Pool( - os.cpu_count() - // ( - int(os.environ["LOCAL_WORLD_SIZE"]) - if dist.is_available() and dist.is_initialized() - else 1 - ) - ) as pool: - self.systems = pool.map(construct_dataset, systems) - + self.systems: list[DeepmdDataSetForLoader] = [] + global_rank = dist.get_rank() if dist.is_initialized() else 0 + if global_rank == 0: + log.info(f"Constructing DataLoaders from {len(systems)} systems") + with Pool(max(1, env.NUM_WORKERS)) as pool: + self.systems = pool.map(construct_dataset, systems) + else: + self.systems = [None] * len(systems) # type: ignore + if dist.is_initialized(): + dist.broadcast_object_list(self.systems) + assert self.systems[-1] is not None self.sampler_list: list[DistributedSampler] = [] self.index = [] self.total_batch = 0