From 4a4c25129f533e275764ee41a0303d5c5dec5b63 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 14 Oct 2024 22:05:32 -0700 Subject: [PATCH] Removed CPU randn() from schedulers (#8145) Fixes performance issues due to extra CPU/GPU sync: https://nvbugswb.nvidia.com/NvBugs5/SWBug.aspx?bugid=4904446&cmtNo= --------- Signed-off-by: Boris Fomitchev Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/schedulers/ddim.py | 2 +- monai/networks/schedulers/ddpm.py | 8 ++++++-- requirements-dev.txt | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index 2a0121d063..50a680336d 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -220,7 +220,7 @@ def step( if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu") - noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator, device=device) variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise pred_prev_sample = pred_prev_sample + variance diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py index 93ad833031..d64e11d379 100644 --- a/monai/networks/schedulers/ddpm.py +++ b/monai/networks/schedulers/ddpm.py @@ -241,8 +241,12 @@ def step( variance = 0 if timestep > 0: noise = torch.randn( - model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator - ).to(model_output.device) + model_output.size(), + dtype=model_output.dtype, + layout=model_output.layout, + generator=generator, + device=model_output.device, + ) variance = (self._get_variance(timestep, 
predicted_variance=predicted_variance) ** 0.5) * noise pred_prev_sample = pred_prev_sample + variance diff --git a/requirements-dev.txt b/requirements-dev.txt index 6d0ccd378a..72654d3534 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -22,7 +22,7 @@ isort>=5.1 ruff pytype>=2020.6.1; platform_system != "Windows" types-setuptools -mypy>=1.5.0 +mypy>=1.5.0, <1.12.0 ninja torchvision psutil