diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 89f8129e63f..7fc5bf0173e 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1193,10 +1193,12 @@ def prepare(self, *args, device_placement=None):
             )
 
         for obj in args:
+            # TODO: Look at enabling native TP training directly with a proper config
             if (
                 isinstance(obj, torch.nn.Module)
                 and self.verify_device_map(obj)
                 and self.distributed_type != DistributedType.NO
+                and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true"
             ):
                 raise ValueError(
                     "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode."
@@ -1328,7 +1330,12 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP
 
         self._models.append(model)
-        if self.verify_device_map(model) and self.distributed_type != DistributedType.NO:
+        # TODO: Look at enabling native TP training directly with a proper config
+        if (
+            self.verify_device_map(model)
+            and self.distributed_type != DistributedType.NO
+            and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true"
+        ):
             raise ValueError(
                 "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode."
                 " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`."
@@ -1401,8 +1408,14 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         ):
             if any(p.requires_grad for p in model.parameters()):
                 kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
+                # TODO: Look at enabling native TP training directly with a proper config
+                if os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true":
+                    device_ids, output_device = [self.local_process_index], self.local_process_index
+                else:
+                    device_ids, output_device = None, None
+
                 model = torch.nn.parallel.DistributedDataParallel(
-                    model, device_ids=[self.local_process_index], output_device=self.local_process_index, **kwargs
+                    model, device_ids=device_ids, output_device=output_device, **kwargs
                 )
         elif self.distributed_type == DistributedType.FSDP:
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
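
For context, here is a minimal sketch of how the new escape hatch could be exercised from user code; it is not part of the patch, and the model name and the place where the variable is set are illustrative assumptions. With ACCELERATE_BYPASS_DEVICE_MAP=true, the device_map check in prepare()/prepare_model() is skipped, and under DDP the wrapper is built with device_ids=None and output_device=None, as in the hunks above.

    import os

    # Assumption for illustration: the flag must be set before Accelerator.prepare()
    # evaluates the device_map check (it could equally be exported by the launcher).
    os.environ["ACCELERATE_BYPASS_DEVICE_MAP"] = "true"

    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM

    accelerator = Accelerator()

    # A device_map-sharded model would normally raise in any distributed mode;
    # with the bypass flag it is accepted, and on the MULTI_GPU path the DDP
    # wrapper is constructed with device_ids=None / output_device=None.
    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")  # "gpt2" is a placeholder
    model = accelerator.prepare(model)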