Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update module.py #2

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Update module.py #2

wants to merge 1 commit into from

Conversation

johndpope
Copy link

The main improvements in this version are:

Instead of using F.pad and updating x_padded for each split, the code pre-allocates the output tensor output and directly updates it with the padded values from the previous split. This avoids the need to store the entire padded input tensor in memory.
The code uses a loop to process each split separately, which allows for more efficient memory usage. The padded input tensor x_padded is created only for the first split, and subsequent splits reuse the output from the previous split for padding.
The torch.zeros function is used to pre-allocate the output tensor output on the same device as the input tensor x. This avoids the need for additional memory allocation during the loop.

These optimizations further reduce the memory footprint of the PatchConv2d module by avoiding the storage of the entire padded input tensor and reusing the output from previous splits for padding.
Please note that these optimizations assume that the height dimension of the input tensor is divisible by the number of splits (self.splits). If this assumption is not met, the code will raise an assertion error.

@johndpope
Copy link
Author

dethwise pointwise convolution is where it's at - checkout my post here
https://www.reddit.com/r/StableDiffusion/comments/1bh970h/comment/kvclr41/

import torch
from torch import nn
from torch.nn import functional as F

class PatchPointwiseConv2d(nn.Module):
    def __init__(self, splits: int = 4, conv2d: nn.Conv2d = None, *args, **kwargs):
        super(PatchPointwiseConv2d, self).__init__()
        if conv2d is not None:
            self.conv2d = conv2d
            self.splits = splits
        else:
            self.conv2d = nn.Conv2d(*args, **kwargs, groups=kwargs['in_channels'])
            self.pointwise = nn.Conv2d(kwargs['in_channels'], kwargs['out_channels'], 1)
            self.splits = splits

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        b, c, h, w = x.shape
        if c * h * w >= (1 << 30):
            assert h % self.splits == 0
            split_size = h // self.splits
            x_permuted = x.view(b, c, self.splits, split_size, w).permute(0, 2, 1, 3, 4)
            padding_bak = self.conv2d.padding
            self.conv2d.padding = (0, self.conv2d.padding[1])
            output = torch.zeros(b, self.splits, self.conv2d.out_channels, split_size + 2 * self.conv2d.padding[0], w, device=x.device)
            for i in range(self.splits):
                if i == 0:
                    x_padded = F.pad(
                        x_permuted[:, i],
                        (0, 0, self.conv2d.padding[0], self.conv2d.padding[0]),
                        mode="constant" if self.conv2d.padding_mode == "zeros" else self.conv2d.padding_mode,
                        value=0,
                    )
                else:
                    x_padded[:, :, : self.conv2d.padding[0]] = output[:, i - 1, :, -2 * self.conv2d.padding[0] : -self.conv2d.padding[0]]
                    x_padded[:, :, -self.conv2d.padding[0] :] = x_permuted[:, i, :, : self.conv2d.padding[0]]
                output[:, i] = self.pointwise(self.conv2d(x_padded, *args, **kwargs))
            self.conv2d.padding = padding_bak
            output = output.permute(0, 2, 1, 3, 4).reshape(b, self.conv2d.out_channels, -1, w)
            return output
        else:
            return self.pointwise(self.conv2d(x, *args, **kwargs))

Sure! Let's analyze the memory footprint and computational requirements for an input tensor with a resolution of 4048 pixels at both height and width.

Given:

  • Input tensor dimensions: (batch_size, in_channels=128, height=4048, width=4048)
  • Convolution parameters: kernel_size=3, out_channels=256

Regular Convolution:

  • Number of multiplications:
    • (3 * 3 * 128 * 256) * (4046 * 4046) ≈ 1,886 billion
    • Approximate computational requirement: 1,886 billion * 4 bytes (float32) ≈ 7.5 TB
  • Number of parameters:
    • 3 * 3 * 128 * 256 ≈ 294,912
    • Memory footprint of parameters: 294,912 * 4 bytes (float32) ≈ 1.1 MB
  • Memory footprint of input tensor:
    • batch_size * 128 * 4048 * 4048 * 4 bytes (float32) ≈ batch_size * 8 GB
  • Memory footprint of output tensor:
    • batch_size * 256 * 4046 * 4046 * 4 bytes (float32) ≈ batch_size * 16 GB

Pointwise Convolution:

  • Number of multiplications:
    • (3 * 3 * 128 + 128 * 256) * (4046 * 4046) ≈ 279 billion
    • Approximate computational requirement: 279 billion * 4 bytes (float32) ≈ 1.1 TB
  • Number of parameters:
    • 3 * 3 * 128 + 128 * 256 ≈ 33,792
    • Memory footprint of parameters: 33,792 * 4 bytes (float32) ≈ 0.13 MB
  • Memory footprint of input tensor:
    • batch_size * 128 * 4048 * 4048 * 4 bytes (float32) ≈ batch_size * 8 GB
  • Memory footprint of output tensor:
    • batch_size * 256 * 4046 * 4046 * 4 bytes (float32) ≈ batch_size * 16 GB

Comparison:

  • Computational Reduction:
    • Pointwise convolution reduces the number of multiplications by approximately 85% compared to regular convolution.
    • Computational requirement reduction: From 7.5 TB to 1.1 TB.
  • Memory Footprint Reduction:
    • Pointwise convolution reduces the number of parameters by approximately 88% compared to regular convolution.
    • Parameter memory footprint reduction: From 1.1 MB to 0.13 MB.
    • The memory footprint of the input and output tensors remains the same for both convolution types.

As you can see, at a high resolution of 4048 pixels, the computational requirements and memory footprint are significantly larger compared to the previous example. The pointwise convolution still provides a substantial reduction in the number of multiplications and parameters, but the memory footprint of the input and output tensors dominates the overall memory usage.

Please note that the memory footprint calculations assume a batch size of 1. If you have a larger batch size, the memory footprint of the input and output tensors will scale accordingly.

It's important to consider the available memory and computational resources when working with high-resolution images and adjust the batch size and model architecture accordingly to ensure that the memory and computational requirements are within the limits of your hardware.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant