Update module.py #2
base: main
Conversation
Depthwise pointwise convolution is where it's at - check out my post here:

```python
import torch
from torch import nn
from torch.nn import functional as F


class PatchPointwiseConv2d(nn.Module):
    """Depthwise + pointwise (separable) convolution that processes very large
    inputs in horizontal patches to bound peak memory."""

    def __init__(self, splits: int = 4, conv2d: nn.Conv2d = None, *args, **kwargs):
        super().__init__()
        self.splits = splits
        if conv2d is not None:
            # Wrap an existing depthwise conv; the pointwise stage keeps its
            # channel count.
            self.conv2d = conv2d
            self.pointwise = nn.Conv2d(conv2d.out_channels, conv2d.out_channels, 1)
        else:
            # Depthwise stage (one filter per input channel), then a 1x1
            # pointwise stage that mixes channels. in_channels / out_channels
            # must be passed as keyword arguments.
            in_ch, out_ch = kwargs["in_channels"], kwargs["out_channels"]
            self.conv2d = nn.Conv2d(*args, **{**kwargs, "out_channels": in_ch}, groups=in_ch)
            self.pointwise = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        # Small inputs: run the separable conv in one shot.
        if c * h * w < (1 << 30):
            return self.pointwise(self.conv2d(x))

        assert h % self.splits == 0, "height must be divisible by splits"
        split_size = h // self.splits
        pad_h = self.conv2d.padding[0]
        # Disable the conv's own vertical padding; patches are padded manually
        # so each seam sees its neighbouring rows. Assumes stride 1, zero
        # padding, and 'same' vertical padding (kernel height == 2 * pad_h + 1).
        padding_bak = self.conv2d.padding
        self.conv2d.padding = (0, padding_bak[1])
        x_split = x.view(b, c, self.splits, split_size, w).permute(0, 2, 1, 3, 4)
        output = torch.zeros(
            b, self.splits, self.pointwise.out_channels, split_size, w,
            device=x.device, dtype=x.dtype,
        )
        for i in range(self.splits):
            if i == 0:
                # First patch: allocate the reusable buffer, zero-padded on top.
                x_padded = F.pad(x_split[:, 0], (0, 0, pad_h, pad_h))
            else:
                # Reuse the buffer: body from the current patch, top halo from
                # the last rows of the previous patch.
                x_padded[:, :, pad_h:pad_h + split_size] = x_split[:, i]
                if pad_h:
                    x_padded[:, :, :pad_h] = x_split[:, i - 1, :, -pad_h:]
            if pad_h:
                # Bottom halo: first rows of the next patch, zeros at the border.
                x_padded[:, :, -pad_h:] = (
                    x_split[:, i + 1, :, :pad_h] if i < self.splits - 1 else 0
                )
            output[:, i] = self.pointwise(self.conv2d(x_padded))
        self.conv2d.padding = padding_bak
        return output.permute(0, 2, 1, 3, 4).reshape(b, self.pointwise.out_channels, h, w)
```
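For reference, a minimal usage sketch; the sizes here are made up for illustration, and anything at or above 2**30 input elements takes the patched path:

```python
# Hypothetical usage: a small input runs in one shot.
conv = PatchPointwiseConv2d(splits=4, in_channels=8, out_channels=16,
                            kernel_size=3, padding=1)
x = torch.randn(1, 8, 64, 64)
y = conv(x)
print(y.shape)  # torch.Size([1, 16, 64, 64])

# An input with >= 2**30 elements, e.g. (1, 64, 4096, 4096), would instead be
# processed in `splits` horizontal patches with the same result.
```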
Sure! Let's analyze the memory footprint and computational requirements for an input tensor with a resolution of 4048 pixels at both height and width.

Given:

Regular Convolution:

Pointwise Convolution:

Comparison:
As you can see, at a resolution of 4048×4048 the computational requirements and memory footprint are significantly larger than in the previous example. The pointwise convolution still provides a substantial reduction in the number of multiplications and parameters, but the memory footprint of the input and output tensors dominates the overall memory usage.

Note that the memory footprint calculations assume a batch size of 1; with a larger batch size, the memory of the input and output tensors scales accordingly. When working with high-resolution images, adjust the batch size and model architecture so that memory and compute stay within the limits of your hardware.
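The underlying arithmetic wasn't preserved above, so here is a sketch of how such estimates are computed; the 4048×4048 resolution comes from the discussion, while the 3×3 kernel and 64 input/output channels are assumed purely for illustration:

```python
H = W = 4048          # resolution from the comment
C_in = C_out = 64     # assumed channel counts
K = 3                 # assumed kernel size
BYTES_FP32 = 4

# Regular convolution
regular_mults = H * W * C_in * C_out * K * K       # ~6.0e11
regular_params = C_in * C_out * K * K              # 36,864 (plus biases)

# Depthwise + pointwise (separable) convolution
separable_mults = H * W * C_in * (K * K + C_out)   # ~7.7e10, ~7.9x fewer
separable_params = C_in * K * K + C_in * C_out     # 4,672 (plus biases)

# Activation memory dominates either way at this resolution:
input_bytes = H * W * C_in * BYTES_FP32            # ~4.2 GB at batch size 1
print(f"{regular_mults:.2e} vs {separable_mults:.2e} mults, "
      f"input ~{input_bytes / 1e9:.1f} GB")
```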
The main improvements in this version are:

- Instead of keeping the entire padded input tensor in memory, the code pre-allocates the output tensor `output` once and writes each patch's result directly into it.
- The loop processes each split separately, which keeps memory usage bounded: the padded buffer `x_padded` is allocated only for the first split, and subsequent splits reuse that buffer, filling the halo rows from the neighbouring splits of the input.
- `torch.zeros` pre-allocates `output` on the same device (and with the same dtype) as the input `x`, avoiding extra allocations inside the loop.

These optimizations reduce the memory footprint of the PatchPointwiseConv2d module by never materializing the full padded input tensor; a measurement sketch follows below. Note that they assume the height of the input tensor is divisible by the number of splits (`self.splits`); if this assumption is not met, the forward pass raises an assertion error.
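A rough way to verify the memory behaviour empirically; this assumes a CUDA device with enough headroom, and the sizes are chosen only so that the input crosses the 2**30-element threshold:

```python
import torch

# 1 x 64 x 4096 x 4096 is exactly 2**30 elements, so the patched path runs.
conv = PatchPointwiseConv2d(splits=8, in_channels=64, out_channels=64,
                            kernel_size=3, padding=1).cuda()
x = torch.randn(1, 64, 4096, 4096, device="cuda")  # ~4.3 GB in fp32

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    y = conv(x)
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.1f} GB")
```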