Non-OOM error raised by PyTorch during find_batch_size() #286

Open
OriKovacsiKatz opened this issue Jul 30, 2024 · 4 comments
Labels: bug (Something isn't working), Legacy (Legacy GUI related)

Comments

@OriKovacsiKatz

OriKovacsiKatz commented Jul 30, 2024

Running the PlantSeg example col-0_20161116, I get the following crash:

(plant-seg-dev) user-name@kitchen_computer:/home/user-name/plant-seg# python -m run_plantseg --config=/home/user-name/data/plantseg_config_2016.yaml
You are using the latest version of PlantSeg: 1.8.1.
2024-07-29 12:00:02,666 [MainThread] INFO PlantSeg - Running the pipeline on: ['/home/user-name/data/col-0_20161116/20161116.tif']
2024-07-29 12:00:02,666 [MainThread] INFO PlantSeg - Executing pipeline, see terminal for verbose logs.
2024-07-29 12:00:05,357 [MainThread] INFO PlantSeg - Executing pipeline step: 'preprocessing'. Parameters: '{'state': False, 'save_directory': 'PreProcessing', 'factor': [1.0, 1.0, 1.0], 'order': 2, 'crop_volume': '[:, :, :]', 'filter': {'state': False, 'type': 'gaussian', 'filter_param': 1.0}}'. Files ['/home/user-name/data/col-0_20161116/20161116.tif'].
2024-07-29 12:00:05,357 [MainThread] INFO PlantSeg - Skipping 'DataPreProcessing3D'. Disabled by the user.
2024-07-29 12:00:05,357 [MainThread] INFO PlantSeg - Executing pipeline step: 'cnn_prediction'. Parameters: '{'state': False, 'model_name': 'generic_confocal_3D_unet', 'device': 'cuda', 'patch': [80, 160, 160], 'stride_ratio': 0.75, 'patch_halo': [4, 8, 8], 'model_update': True, 'num_workers': 8}'. Files ['/home/user-name/data/col-0_20161116/20161116.tif'].
2024-07-29 12:00:05,389 [MainThread] INFO PlantSeg - File config_train.yml already exists. Skipping download.
2024-07-29 12:00:05,389 [MainThread] INFO PlantSeg - File best_checkpoint.pytorch already exists. Skipping download.
2024-07-29 12:00:05,389 [MainThread] INFO PlantSeg Zoo - Loaded model from PlantSeg zoo: generic_confocal_3D_unet
2024-07-29 12:00:05,496 [MainThread] INFO PlantSeg Zoo - Loaded model from user specified weights: /user-name/.plantseg_models/generic_confocal_3D_unet/best_checkpoint.pytorch
/home/user-name/plant-seg/plantseg/predictions/predict.py:80: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state = torch.load(model_path, map_location='cpu')
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/user-name/plant-seg/run_plantseg.py", line 4, in <module>
    main()
  File "/home/user-name/plant-seg/plantseg/run_plantseg.py", line 96, in main
    process_config(args.config)
  File "/home/user-name/plant-seg/plantseg/run_plantseg.py", line 77, in process_config
    raw2seg(config)
  File "/home/user-name/plant-seg/plantseg/pipeline/raw2seg.py", line 148, in raw2seg
    pipeline_step = pipeline_step_setup(input_paths, config[pipeline_step_name])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/pipeline/raw2seg.py", line 55, in configure_cnn_step
    return UnetPredictions(
           ^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/predictions/predict.py", line 93, in __init__
    self.predictor = ArrayPredictor(
                     ^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/predictions/functional/array_predictor.py", line 181, in __init__
    self.batch_size = find_batch_size(model, in_channels, patch, patch_halo, device)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/predictions/functional/array_predictor.py", line 51, in find_batch_size
    _ = model(x)
        ^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/training/model.py", line 500, in forward
    x = decoder(encoder_features, x)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/training/model.py", line 294, in forward
    x = self.upsampling(encoder_features=encoder_features, x=x)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/training/model.py", line 377, in forward
    return self.upsample(x, output_size)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/training/model.py", line 394, in _interpolate
    return F.interpolate(x, size=size, mode=mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/functional.py", line 4052, in interpolate
    return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected output.numel() <= std::numeric_limits<int32_t>::max() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
(plant-seg-dev) user-name@kitchen_computer:/home/user-name/plant-seg# 

I modified the code to print debugging details:

  File "/home/user-name/plant-seg/plantseg/predictions/functional/array_predictor.py", line 181, in __init__
    self.batch_size = find_batch_size(model, in_channels, patch, patch_halo, device)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/predictions/functional/array_predictor.py", line 51, in find_batch_size

Here is find_batch_size() with the added print-debugging statements:
    with torch.no_grad():
        for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
            try:
                # ====================================================================
                print("# [Ori Kovacsi-Katz]: attempt setting batch_size ",sep='=')
                print(batch_size, sep=',')
                print("in_channels", sep='=')
                print(in_channels, sep=',')
                print("actual_patch_shape", sep='=')
                print(actual_patch_shape, sep=',')
                print("device", sep='=')
                print(device)
                # ====================================================================            
                x = torch.randn((batch_size, in_channels) + actual_patch_shape).to(device)
                _ = model(x)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print("# [Ori Kovacsi-Katz]: out of memory exception while attempt setting batch_size ")
                    batch_size //= 2
                    break
                else:
                    print("# [Ori Kovacsi-Katz]: Other then out of memory exception while attempt setting batch_size ")                
                    print(e)
                    raise
            finally:
                del x
                torch.cuda.empty_cache()

It crashed at batch_size=16:

(plant-seg-dev) root@n311:/home/lahavt/plant-seg# python -m run_plantseg --config=/home/lahavt/data/plantseg_config_2016.yaml
You are using the latest version of PlantSeg: 1.8.1.
2024-07-30 10:26:52,559 [MainThread] INFO PlantSeg - Running the pipeline on: ['/home/lahavt/data/col-0_20161116/20161116.tif']
2024-07-30 10:26:52,559 [MainThread] INFO PlantSeg - Executing pipeline, see terminal for verbose logs.
2024-07-30 10:26:55,362 [MainThread] INFO PlantSeg - Executing pipeline step: 'preprocessing'. Parameters: '{'state': True, 'save_directory': 'PreProcessing', 'factor': [1.0, 1.0, 1.0], 'order': 2, 'crop_volume': '[:, :, :]', 'filter': {'state': False, 'type': 'gaussian', 'filter_param': 1.0}}'. Files ['/home/lahavt/data/col-0_20161116/20161116.tif'].
# [Ori Kovacsi-Katz DEBUGGING]
Executing pipeline step: 'preprocessing'. Parameters: '{'state': True, 'save_directory': 'PreProcessing', 'factor': [1.0, 1.0, 1.0], 'order': 2, 'crop_volume': '[:, :, :]', 'filter': {'state': False, 'type': 'gaussian', 'filter_param': 1.0}}'. Files ['/home/lahavt/data/col-0_20161116/20161116.tif'].
2024-07-30 10:26:55,363 [MainThread] INFO PlantSeg - Loading stack from /home/lahavt/data/col-0_20161116/20161116.tif
2024-07-30 10:27:04,145 [MainThread] INFO PlantSeg - Preprocessing files...
2024-07-30 10:27:04,146 [MainThread] INFO PlantSeg - Cropping input image to: [:, :, :]
2024-07-30 10:27:04,147 [MainThread] INFO PlantSeg - Saving results in /home/lahavt/data/col-0_20161116/PreProcessing/20161116.h5
2024-07-30 10:27:33,122 [MainThread] INFO PlantSeg - Executing pipeline step: 'cnn_prediction'. Parameters: '{'state': False, 'model_name': 'generic_confocal_3D_unet', 'device': 'cuda', 'patch': [80, 160, 160], 'stride_ratio': 0.75, 'patch_halo': [4, 8, 8], 'model_update': True, 'num_workers': 8}'. Files ['/home/lahavt/data/col-0_20161116/PreProcessing/20161116.h5'].
# [Ori Kovacsi-Katz DEBUGGING]
Executing pipeline step: 'cnn_prediction'. Parameters: '{'state': False, 'model_name': 'generic_confocal_3D_unet', 'device': 'cuda', 'patch': [80, 160, 160], 'stride_ratio': 0.75, 'patch_halo': [4, 8, 8], 'model_update': True, 'num_workers': 8}'. Files ['/home/lahavt/data/col-0_20161116/PreProcessing/20161116.h5'].
2024-07-30 10:27:33,153 [MainThread] INFO PlantSeg - File config_train.yml already exists. Skipping download.
2024-07-30 10:27:33,154 [MainThread] INFO PlantSeg - File best_checkpoint.pytorch already exists. Skipping download.
2024-07-30 10:27:33,154 [MainThread] INFO PlantSeg Zoo - Loaded model from PlantSeg zoo: generic_confocal_3D_unet
2024-07-30 10:27:33,239 [MainThread] INFO PlantSeg Zoo - Loaded model from user specified weights: /home/lahavt/tmp/.plantseg_models/generic_confocal_3D_unet/best_checkpoint.pytorch
# [Ori Kovacsi-Katz]: modified to be weights_only=True
# [Ori Kovacsi-Katz]: before setting self.batch_size... using find_batch_size(...)  
# [Ori Kovacsi-Katz]: attempt setting batch_size 
1
in_channels
1
actual_patch_shape
(88, 176, 176)
device
cuda
# [Ori Kovacsi-Katz]: attempt setting batch_size 
2
in_channels
1
actual_patch_shape
(88, 176, 176)
device
cuda
# [Ori Kovacsi-Katz]: attempt setting batch_size 
4
in_channels
1
actual_patch_shape
(88, 176, 176)
device
cuda
# [Ori Kovacsi-Katz]: attempt setting batch_size 
8
in_channels
1
actual_patch_shape
(88, 176, 176)
device
cuda
# [Ori Kovacsi-Katz]: attempt setting batch_size 
16
in_channels
1
actual_patch_shape
(88, 176, 176)
device
cuda
# [Ori Kovacsi-Katz]: Other then out of memory exception while attempt setting batch_size 
Expected output.numel() <= std::numeric_limits<int32_t>::max() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/lahavt/plant-seg/run_plantseg.py", line 4, in <module>
    main()
  File "/home/lahavt/plant-seg/plantseg/run_plantseg.py", line 96, in main
    process_config(args.config)
  File "/home/lahavt/plant-seg/plantseg/run_plantseg.py", line 77, in process_config
    raw2seg(config)
  File "/home/lahavt/plant-seg/plantseg/pipeline/raw2seg.py", line 150, in raw2seg
    pipeline_step = pipeline_step_setup(input_paths, config[pipeline_step_name])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When I changed the list to stop at a maximum of 8, it didn't crash:

    with torch.no_grad():
        for batch_size in [1, 2, 4, 8]:

How can I fix plant-seg/plantseg/predictions/functional/array_predictor.py line 51
so that the PlantSeg execution does not crash with the full list of batch sizes:

       for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: ?

thanks
Ori

@qin-yu
Collaborator

qin-yu commented Jul 31, 2024

I see. Thanks for reporting this @OriKovacsiKatz

It failed because PyTorch did not raise an OOM error but instead raised RuntimeError: Expected output.numel() <= std::numeric_limits<int32_t>::max() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.), while I decided to only handle OOM errors in find_batch_size():

    for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
        try:
            x = torch.randn((batch_size, in_channels) + actual_patch_shape).to(device)
            _ = model(x)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                batch_size //= 2
                break
            else:
                raise

I do not understand why it raised something else. For now, yes, changing the list to a maximum of 8 would work. But removing lines 53, 54, 56 and 57 would be the real solution if no OOM actually happened and the error message is correct. If what really happened was an OOM that PyTorch reports this way, then we should report an enhancement request to PyTorch, as the message suggests.
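For reference, the int32 limit is plausible here even without any OOM: with batch size 16 and the padded patch (88, 176, 176) seen in the logs, a full-resolution feature map with 64 channels (an assumed channel count; the real value depends on the model config) has 16 × 64 × 88 × 176 × 176 ≈ 2.79e9 elements, above the 2,147,483,647 limit that upsample_nearest3d checks, while batch size 8 stays below it. A rough sketch of handling both failure modes in find_batch_size() could look like this (illustration only, not a committed fix; the padded shape patch + 2 * halo and the candidate list follow the code above):

    import torch

    def find_batch_size(model, in_channels, patch, patch_halo, device):
        # Sketch: treat both CUDA OOM and the int32 tensor-size limit as "batch too large".
        actual_patch_shape = tuple(p + 2 * h for p, h in zip(patch, patch_halo))
        found = 1
        with torch.no_grad():
            for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
                x = None
                try:
                    x = torch.randn((batch_size, in_channels) + actual_patch_shape).to(device)
                    _ = model(x)
                    found = batch_size  # this size worked, remember it
                except RuntimeError as e:
                    if 'out of memory' in str(e) or 'numeric_limits' in str(e):
                        break  # stop and keep the last size that worked
                    raise
                finally:
                    if x is not None:
                        del x
                    torch.cuda.empty_cache()
        return found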

I'll keep this issue open until we figure this out.

@qin-yu qin-yu changed the title from "pre-processing col-0_20161116/PreProcessing/20161116.h5 example with plantseg failed on batch_size=16" to "Non-OOM error raised by PyTorch during find_batch_size()" on Jul 31, 2024
@qin-yu qin-yu added the bug (Something isn't working) and Legacy (Legacy GUI related) labels on Jul 31, 2024
@qin-yu
Collaborator

qin-yu commented Jul 31, 2024

"Legacy" tag because Napari GUI has workaround (single patch mode).

@qin-yu
Collaborator

qin-yu commented Jul 31, 2024

Just formatted the issue for readability.

@qin-yu
Collaborator

qin-yu commented Jul 31, 2024

Just in case I didn't sound encouraging: @OriKovacsiKatz, you are very welcome to check whether an OOM really happens on your device and then open a PR for PlantSeg and/or an issue for PyTorch. The easy way is just to watch the PlantSeg terminal and nvidia-smi side by side.
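As a small helper (a hypothetical debugging aid, not part of PlantSeg), something like this could be called right before and after _ = model(x) in find_batch_size() to see how close the GPU actually gets to its memory limit:

    import torch

    def log_cuda_memory(tag):
        # Debugging aid: report allocated/reserved GPU memory against the device total.
        if not torch.cuda.is_available():
            print(f"[{tag}] CUDA not available")
            return
        total = torch.cuda.get_device_properties(0).total_memory
        allocated = torch.cuda.memory_allocated()
        reserved = torch.cuda.memory_reserved()
        print(f"[{tag}] allocated {allocated / 1e9:.2f} GB, "
              f"reserved {reserved / 1e9:.2f} GB, total {total / 1e9:.2f} GB")

If the allocated memory stays far below the total when batch_size=16 fails, the error is genuinely the int32 size limit rather than a disguised OOM.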
