SlidingWindowInferer runtime increase if sw_batch_size is too big #6628

Open · matt3o opened this issue Jun 19, 2023 · 7 comments

matt3o (Contributor) commented Jun 19, 2023

Describe the bug
I am currently using the SlidingWindowInferer for some modified DeepEdit code. I noticed that for small sw_roi_sizes like (32,32,32) I have to set a higher sw_batch_size to make it run faster; see the data below.
However, when the sw_batch_size becomes too big, performance takes a dramatic hit, which does not make sense to me. The initial input volume shape is (1,3,344,344,284) and the inferer is created with eval_inferer = SlidingWindowInferer(roi_size=args.sw_roi_size, sw_batch_size=args.sw_batch_size, mode="gaussian").
Results of my test runs:

138 seconds for (32,32,32) on sw_batch_size 1
13.38 seconds for (32,32,32) on sw_batch_size 200 (12 iterations)
11 seconds for (32,32,32) on sw_batch_size 500 (8 iterations)
11 seconds for (32,32,32) on sw_batch_size 1000 (3 iterations)
93 seconds for (32,32,32) on sw_batch_size 2000 (2 iterations)
191 seconds for (32,32,32) on sw_batch_size 2400 (1 iteration)

I tried to debug this but I am not sure why this huge increase in runtime is happening. Of course I can always calculate the best sw_batch_size beforehand (roughly 1/4 of the actual number of windows, judging from the numbers above, but that requires knowing the size of the largest volume in advance), but an actual solution would be nice. Or maybe it is an issue in my code that I am not aware of; that would be good to know as well.
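
Something along these lines is what I mean by calculating it beforehand; this is only a rough sketch, assuming the default overlap of 0.25 and that the volume is larger than the ROI in every dimension:

# rough estimate of how many sliding windows the inferer will create,
# so sw_batch_size can be capped instead of guessed (mirrors the default
# scan interval of roi * (1 - overlap))
from monai.data.utils import dense_patch_slices

def num_windows(image_size, roi_size, overlap=0.25):
    scan_interval = tuple(max(int(r * (1.0 - overlap)), 1) for r in roi_size)
    return len(dense_patch_slices(image_size, roi_size, scan_interval))

print(num_windows((344, 344, 284), (32, 32, 32)))  # ~2350 for the volume above

That estimate is consistent with the run above where sw_batch_size=2400 finishes in a single iteration.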

To Reproduce
Use the SlidingWindowInferer and set sw_batch_size higher than the actual number of windows; performance then deteriorates heavily.

Environment

Tried it on MONAI 1.1 and also on the nightly build, no change.

================================
Printing MONAI config...
================================
MONAI version: 1.1.0
Numpy version: 1.24.3
Pytorch version: 2.0.0+cu117
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: a2ec3752f54bfc3b40e7952234fbeb5452ed63e3
MONAI __file__: /homes/mhadlich/.conda/envs/monai/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.12
Nibabel version: 5.1.0
scikit-image version: 0.20.0
Pillow version: 9.5.0
Tensorboard version: 2.13.0
gdown version: 4.7.1
TorchVision version: 0.15.1+cu117
tqdm version: 4.65.0
lmdb version: 1.4.1
psutil version: 5.9.5
pandas version: 2.0.1
einops version: 0.6.1
transformers version: 4.21.3
mlflow version: 2.3.1
pynrrd version: 1.0.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


================================
Printing system config...
================================
System: Linux
Linux version: Ubuntu 22.04.2 LTS
Platform: Linux-5.15.0-73-generic-x86_64-with-glibc2.35
Processor: x86_64
Machine: x86_64
Python version: 3.10.10
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: [popenfile(path='/projects/mhadlich_segmentation/sliding-window-based-interactive-segmentation-of-volumetric-medical-images_main/tmp.txt', fd=1, position=1040, mode='w', flags=32769)]
Num physical CPUs: 48
Num logical CPUs: 48
Num usable CPUs: 1
CPU usage (%): [100.0, 100.0, 59.9, 1.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.4, 0.2, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 0.0, 0.4, 0.0, 0.0, 2.7, 0.0, 0.0, 0.0, 0.0]
CPU freq. (MHz): 1724
Load avg. in last 1, 5, 15 mins (%): [5.1, 5.0, 5.1]
Disk usage (%): 66.3
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 1007.8
Available memory (GB): 980.8
Used memory (GB): 20.0

================================
Printing GPU config...
================================
Num GPUs: 1
Has CUDA: True
CUDA version: 11.7
cuDNN enabled: True
cuDNN version: 8500
Current device: 0
Library compiled for CUDA architectures: ['sm_37', 'sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86']
GPU 0 Name: NVIDIA RTX A6000
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 84
GPU 0 Total memory (GB): 47.5
GPU 0 CUDA capability (maj.min): 8.6

KumoLiu (Contributor) commented Jun 20, 2023

Hi @matt3o, I can't reproduce the issue. Could you please share more information, such as which model you used and whether you run inference on GPU or CPU?
Thanks!

matt3o (Contributor, Author) commented Jun 20, 2023

Hey @KumoLiu, thanks for the quick response again!
This code uses the DynUNet, though not the default one from monai.networks.nets but a separate file. I just switched to the official one from monai.networks.nets.dynunet, which did not change anything; same behaviour as reported above.
Apart from that, all calculations are done on the GPU and all the input is already there. To be sure, I recently pinned both device and sw_device to the GPU, no change there.
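
For reference, the pinned setup looks roughly like this (device and sw_device are existing SlidingWindowInferer parameters; the concrete values are just the ones from my runs above):

import torch
from monai.inferers import SlidingWindowInferer

# pin both the per-window forward passes (sw_device) and the stitched output (device) to the GPU
gpu = torch.device("cuda")
eval_inferer = SlidingWindowInferer(
    roi_size=(32, 32, 32),
    sw_batch_size=200,
    mode="gaussian",
    sw_device=gpu,
    device=gpu,
)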

matt3o (Contributor, Author) commented Jun 20, 2023

Btw, I am not sure whether 32**3 is actually allowed by the DynUNet (I pad to 64**3 in my code, so I don't think it makes much sense). This problem, however, exists independently of the sw_roi_size, e.g.:

30 seconds for (256,256,256) on sw_batch_size 1
400 seconds for (256,256,256) on sw_batch_size 8

KumoLiu (Contributor) commented Jun 20, 2023

Hi @matt3o, I cannot even run with sw_batch_size=1 using the DynUNet on 24 GB. I used the same settings as in the DeepEdit tutorial:
https://github.com/Project-MONAI/tutorials/blob/bbc4e180f5130859396a98e35523bc73fa694595/deepedit/ignite/train.py#L84
I also tried with UNet but did not see the same issue.
Thanks!

matt3o (Contributor, Author) commented Jun 20, 2023

@KumoLiu, then we will have to debug this as soon as I publish my code; I am using exactly the network config you just mentioned. I would guess your problem is related to #6626. In theory the SlidingWindowInferer on DynUNet can work just fine on 24 GB, and I got it to run on smaller crops even on 11 GB.
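
On smaller cards, one thing worth trying is keeping the stitched full-volume output on the CPU while the per-window forward passes stay on the GPU; a rough sketch (sw_device and device are standard SlidingWindowInferer parameters, the concrete values here are only illustrative):

from monai.inferers import SlidingWindowInferer

# compute each window on the GPU, but allocate and stitch the full output on the CPU
eval_inferer = SlidingWindowInferer(
    roi_size=(32, 32, 32),
    sw_batch_size=200,
    mode="gaussian",
    sw_device="cuda",
    device="cpu",
)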

KumoLiu (Contributor) commented Jun 21, 2023

Hi @matt3o, I investigated a little bit more using UNet.
Here is the result:

sw_batch_size=10, time=0.561
sw_batch_size=200, time=0.441
sw_batch_size=300, time=0.313
sw_batch_size=700, time=0.344
sw_batch_size=1000, time=0.342
sw_batch_size=2000, time=0.344
sw_batch_size=2400, time=0.339

I also found that L252-L284 take more time as the batch size increases: for example, with sw_batch_size=1000 the time spent in L252-L284 is roughly 5x that with sw_batch_size=500. So it makes sense that as sw_batch_size increases, the total inference time does not keep decreasing.

for ss in range(len(sw_device_buffer)):
    b_shape = sw_device_buffer[ss].shape
    seg_chns, seg_shape = b_shape[1], b_shape[2:]
    z_scale = None
    if not buffered and seg_shape != roi_size:
        z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)]
        w_t = F.interpolate(w_t, seg_shape, mode=_nearest_mode)
    if len(output_image_list) <= ss:
        output_shape = [batch_size, seg_chns]
        output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size)
        # allocate memory to store the full output and the count for overlapping parts
        new_tensor: Callable = torch.empty if non_blocking else torch.zeros  # type: ignore
        output_image_list.append(new_tensor(output_shape, dtype=compute_dtype, device=device))
        count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
        w_t_ = w_t.to(device)
        for __s in slices:
            if z_scale is not None:
                __s = tuple(slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale))
            count_map_list[-1][(slice(None), slice(None), *__s)] += w_t_
    if buffered:
        o_slice = [slice(None)] * len(inputs.shape)
        o_slice[buffer_dim + 2] = slice(c_start, c_end)
        img_b = b_s // n_per_batch  # image batch index
        o_slice[0] = slice(img_b, img_b + 1)
        if non_blocking:
            output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking)
        else:
            output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
    else:
        sw_device_buffer[ss] *= w_t
        sw_device_buffer[ss] = sw_device_buffer[ss].to(device)
        _compute_coords(unravel_slice, z_scale, output_image_list[ss], sw_device_buffer[ss])
sw_device_buffer = []

But I didn't see the time-increase issue. Could you please try this simple demo locally and see whether you get results similar to mine?

device = "cuda"
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm="batch",
).to(device=device)

out = mt.Compose([
    mt.LoadImaged(keys="image", image_only=True, ensure_channel_first=True),
    mt.Resized(keys="image", spatial_size=(344, 344, 284)),
    mt.ToDeviced(keys="image", device=device)
])(data[0])

sw_roi_size = (32, 32, 32)
sw_batch_size = 1000

start = time.time()
eval_inferer = SlidingWindowInferer(roi_size=sw_roi_size, sw_batch_size=sw_batch_size, mode="gaussian", progress=True)
ret = eval_inferer(out["image"].unsqueeze(0), model)

print(f'{sw_batch_size=}, time={(time.time()-start):.3f}')
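
To see where the time actually goes, the same demo could also be wrapped in the PyTorch profiler; a sketch reusing the eval_inferer, out and model defined above:

from torch.profiler import ProfilerActivity, profile

# profile one sliding-window inference and report the most expensive CUDA ops
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    ret = eval_inferer(out["image"].unsqueeze(0), model)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))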

Thanks!

matt3o (Contributor, Author) commented Jun 21, 2023

@KumoLiu I get similar results when using the UNet, but not when using the DynUNet. I will append the modified code and the runtime results. I also added AMP as in the real code and ran the examples on the 50 GB GPU server.

UNet: sw_batch_size=1, time=11.335
UNet: sw_batch_size=10, time=3.371
UNet: sw_batch_size=100, time=2.656
UNet: sw_batch_size=1000, time=2.586
UNet: sw_batch_size=10000, time=2.560
UNet: sw_batch_size=20000, time=2.535
DynUNet: sw_batch_size=1, time=12.767
DynUNet: sw_batch_size=10, time=3.573
DynUNet: sw_batch_size=100, time=2.743
DynUNet: sw_batch_size=1000, time=3.185
DynUNet: sw_batch_size=10000, time=22.952
DynUNet: sw_batch_size=20000, time=23.085

import glob
import os
import time

import torch

import monai.transforms as mt
from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.inferers import SlidingWindowInferer
from monai.networks.nets import UNet
from monai.networks.nets.dynunet import DynUNet

location = "/projects/mhadlich_segmentation/AutoPET/AutoPET"
all_images = sorted(glob.glob(os.path.join(location, "imagesTr", "*.nii.gz")))
all_labels = sorted(glob.glob(os.path.join(location, "labelsTr", "*.nii.gz")))
datalist = [{"image": image_name, "label": label_name} for image_name, label_name in
            zip(all_images, all_labels)] #if image_name not in bad_images]

datalist = datalist[0:1]
device = "cuda"

transform = mt.Compose([
    mt.LoadImaged(keys="image", image_only=True, ensure_channel_first=True),
    mt.Resized(keys="image", spatial_size=(344, 344, 284)),
    mt.ToDeviced(keys="image", device=device)
])

train_ds = Dataset(datalist, transform)

train_loader = DataLoader(
    train_ds, shuffle=True  # , num_workers=args.num_workers, batch_size=1, multiprocessing_context='spawn', persistent_workers=True,
)

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm="batch",
).to(device=device)

model2 = DynUNet(
    spatial_dims=3,
    # 1 dim for the image, the other ones for the signal per label, which is the size of the image
    in_channels=1,
    out_channels=1,
    kernel_size=[3, 3, 3, 3, 3, 3],
    strides=[1, 2, 2, 2, 2, [2, 2, 1]],
    upsample_kernel_size=[2, 2, 2, 2, [2, 2, 1]],
    norm_name="instance",
    deep_supervision=False,
    res_block=True,
    # conv1d=args.conv1d,
    # conv1s=args.conv1s,
).to(device=device)


sw_roi_size = (32, 32, 32)

chosen_model = "UNet"  # or "DynUNet"
if chosen_model == "DynUNet":
    model = model2


for item in train_loader:
    for sw_batch_size in [1, 10, 100, 1000, 10000, 20000]:
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                start = time.time()
                eval_inferer = SlidingWindowInferer(roi_size=sw_roi_size, sw_batch_size=sw_batch_size, mode="gaussian", progress=True)
                ret = eval_inferer(item["image"], model)
                print(f'{chosen_model}: {sw_batch_size=}, time={(time.time()-start):.3f}')
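
One caveat about these numbers (an assumption on my side, I have not verified it): CUDA kernels launch asynchronously, so without an explicit synchronization the measured time can partly depend on when the GPU actually finishes rather than on the sliding-window loop itself. A variation of the timing loop with explicit synchronization would look like this:

# same loop as above, but synchronizing before and after the inference call
for item in train_loader:
    for sw_batch_size in [1, 10, 100, 1000, 10000, 20000]:
        with torch.no_grad(), torch.cuda.amp.autocast():
            torch.cuda.synchronize()
            start = time.time()
            eval_inferer = SlidingWindowInferer(roi_size=sw_roi_size, sw_batch_size=sw_batch_size, mode="gaussian", progress=True)
            ret = eval_inferer(item["image"], model)
            torch.cuda.synchronize()
            print(f'{chosen_model}: {sw_batch_size=}, time={(time.time()-start):.3f}')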

