Fix model weight load bug with multigpu (#40)
Fixes # .

### Description

Fix the bug where pre-trained model weights are not loaded when multi-GPU training is used: the checkpoint is now loaded into the model with `load_state_dict` before the model is wrapped in `DistributedDataParallel`, replacing the previous `copy_model_state` call.
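
For context, the likely mechanism (inferred from the diff, not stated in the original description): once a model is wrapped in `DistributedDataParallel`, every key in its state dict gains a `module.` prefix, so a checkpoint saved from the unwrapped model no longer matches, and a non-strict copy such as `copy_model_state` silently updates nothing. Below is a minimal sketch of the ordering, using a stand-in `nn.Linear` model and an in-memory checkpoint rather than the repository's VISTA3D model and config:

```python
import torch.nn as nn

# Stand-in model and checkpoint; illustration only, not the repository code.
model = nn.Linear(4, 2)
pretrained_ckpt = model.state_dict()  # keys: "weight", "bias"

# Buggy ordering (sketched in comments, since DDP needs an initialised
# process group to run): wrapping first renames every parameter, e.g.
#   model = DistributedDataParallel(model, device_ids=[device])
#   model.state_dict() -> keys "module.weight", "module.bias"
# so a non-strict copy of pretrained_ckpt matches nothing and the randomly
# initialised weights are silently kept.

# Fixed ordering (what this commit does): load into the bare model first,
model.load_state_dict(pretrained_ckpt)  # strict by default: raises on mismatch
# then wrap it for multi-GPU training:
#   model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)
```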

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] In-line docstrings updated.

---------

Signed-off-by: heyufan1995 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
heyufan1995 and pre-commit-ci[bot] authored Sep 11, 2024
1 parent f6308df commit aeff47e
Showing 2 changed files with 10 additions and 17 deletions.
vista3d/scripts/train.py (14 changes: 5 additions & 9 deletions)
@@ -40,7 +40,6 @@
 from monai.bundle.scripts import _pop_args, _update_args
 from monai.data import DataLoader, DistributedSampler, DistributedWeightedRandomSampler
 from monai.metrics import compute_dice
-from monai.networks.utils import copy_model_state
 from monai.utils import optional_import, set_determinism
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data.sampler import RandomSampler, WeightedRandomSampler
@@ -216,10 +215,6 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
     optimizer = optimizer_part.instantiate(params=model.parameters())
     lr_scheduler_part = parser.get_parsed_content("lr_scheduler", instantiate=False)
     lr_scheduler = lr_scheduler_part.instantiate(optimizer=optimizer)
-    if world_size > 1:
-        model = DistributedDataParallel(
-            model, device_ids=[device], find_unused_parameters=True
-        )
     if finetune["activate"] and os.path.isfile(finetune["pretrained_ckpt_name"]):
         logger.debug(
             "Fine-tuning pre-trained checkpoint {:s}".format(
@@ -229,13 +224,14 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
         pretrained_ckpt = torch.load(
             finetune["pretrained_ckpt_name"], map_location=device
         )
-        copy_model_state(
-            model, pretrained_ckpt, exclude_vars=finetune.get("exclude_vars")
-        )
+        model.load_state_dict(pretrained_ckpt)
         del pretrained_ckpt
     else:
         logger.debug("Training from scratch")
-
+    if world_size > 1:
+        model = DistributedDataParallel(
+            model, device_ids=[device], find_unused_parameters=True
+        )
     # training hyperparameters - sample
     num_images_per_batch = parser.get_parsed_content("num_images_per_batch")
     num_patches_per_iter = parser.get_parsed_content("num_patches_per_iter")
vista3d/scripts/train_finetune.py (13 changes: 5 additions & 8 deletions)
@@ -35,7 +35,6 @@
 from monai.bundle.scripts import _pop_args, _update_args
 from monai.data import DataLoader, DistributedSampler
 from monai.metrics import compute_dice
-from monai.networks.utils import copy_model_state
 from monai.utils import set_determinism
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.tensorboard import SummaryWriter
@@ -149,10 +148,6 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
     optimizer = optimizer_part.instantiate(params=model.parameters())
     lr_scheduler_part = parser.get_parsed_content("lr_scheduler", instantiate=False)
     lr_scheduler = lr_scheduler_part.instantiate(optimizer=optimizer)
-    if world_size > 1:
-        model = DistributedDataParallel(
-            model, device_ids=[device], find_unused_parameters=True
-        )
     if finetune["activate"] and os.path.isfile(finetune["pretrained_ckpt_name"]):
         logger.debug(
             "Fine-tuning pre-trained checkpoint {:s}".format(
@@ -162,13 +157,15 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
         pretrained_ckpt = torch.load(
             finetune["pretrained_ckpt_name"], map_location=device
         )
-        copy_model_state(
-            model, pretrained_ckpt, exclude_vars=finetune.get("exclude_vars")
-        )
+        model.load_state_dict(pretrained_ckpt)
         del pretrained_ckpt
     else:
         logger.debug("Training from scratch")
 
+    if world_size > 1:
+        model = DistributedDataParallel(
+            model, device_ids=[device], find_unused_parameters=True
+        )
     # training hyperparameters - sample
     num_images_per_batch = parser.get_parsed_content("num_images_per_batch")
     num_patches_per_iter = parser.get_parsed_content("num_patches_per_iter")
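
Editorial note (not part of the commit): `copy_model_state` copies whatever checkpoint keys it can match and skips the rest, whereas `torch.nn.Module.load_state_dict` is strict by default and raises on missing or unexpected keys, so a mismatched checkpoint now fails loudly instead of silently training from random weights. A small, hypothetical pre-flight check in the same spirit, with an illustrative stand-in model:

```python
import torch.nn as nn

# Illustration only: a stand-in model and a checkpoint dict taken from it.
model = nn.Linear(4, 2)
state = {k: v.clone() for k, v in model.state_dict().items()}

# Compare key sets before loading; both sets are empty when everything lines up.
missing = set(model.state_dict()) - set(state)
unexpected = set(state) - set(model.state_dict())
print(sorted(missing), sorted(unexpected))

model.load_state_dict(state)  # strict by default: RuntimeError on any mismatch
```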
