Streamlined Rearrange in SpatialAttentionBlock (#8130)
The Rearrange code failed dynamo export in the 24.09 container:
pytorch/pytorch#137629
While we still can't use dynamo export with TRT in 23.09, I also noticed
that my workaround improved end-to-end runtime by about 1 second for a
100-second run.
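
For context, a minimal self-contained sketch (not part of this commit) of pushing the same reshape/transpose flattening through the dynamo-based torch.export path; the toy module name, shapes, and arguments below are illustrative assumptions, not MONAI code:

```python
import torch
import torch.nn as nn
from torch.export import export


class Flatten3D(nn.Module):
    """Toy module using the same reshape/transpose pattern as this change."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        x = x.reshape(*shape[:2], -1).transpose(1, 2)  # "b c h w d -> b (h w d) c"
        return x.transpose(1, 2).reshape(shape)  # "b (h w d) c -> b c h w d"


# torch.export uses dynamo under the hood; the einops Rearrange layers were
# what tripped this path up in the 24.09 container.
ep = export(Flatten3D(), (torch.randn(1, 4, 8, 8, 8),))
print(ep)
```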

### Description

Replaced the einops `Rearrange` layers with `reshape`/`transpose` calls.
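
A quick way to convince yourself the two forms agree (a standalone sketch, not part of the diff; the shapes are arbitrary):

```python
import torch
from einops.layers.torch import Rearrange

# Arbitrary 3D example: batch=2, channels=4, spatial dims 3x5x6.
x = torch.randn(2, 4, 3, 5, 6)
shape = x.shape

# Flatten the spatial dims and move channels last, both ways.
flat_einops = Rearrange("b c h w d -> b (h w d) c")(x)
flat_manual = x.reshape(*shape[:2], -1).transpose(1, 2)
assert torch.equal(flat_einops, flat_manual)

# Restore the original layout, both ways.
back_einops = Rearrange("b (h w d) c -> b c h w d", h=3, w=5, d=6)(flat_einops)
back_manual = flat_manual.transpose(1, 2).reshape(shape)
assert torch.equal(back_einops, back_manual) and torch.equal(back_manual, x)
```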

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).

---------

Signed-off-by: Boris Fomitchev <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
borisfom and pre-commit-ci[bot] authored Oct 22, 2024
1 parent 35b3894 commit 052dbb4
Showing 1 changed file with 3 additions and 20 deletions: monai/networks/blocks/spatialattention.py
@@ -17,9 +17,6 @@
 import torch.nn as nn
 
 from monai.networks.blocks import SABlock
-from monai.utils import optional_import
-
-Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
 
 class SpatialAttentionBlock(nn.Module):
@@ -74,24 +71,10 @@ def __init__(
 
     def forward(self, x: torch.Tensor):
         residual = x
-
-        if self.spatial_dims == 1:
-            h = x.shape[2]
-            rearrange_input = Rearrange("b c h -> b h c")
-            rearrange_output = Rearrange("b h c -> b c h", h=h)
-        if self.spatial_dims == 2:
-            h, w = x.shape[2], x.shape[3]
-            rearrange_input = Rearrange("b c h w -> b (h w) c")
-            rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w)
-        else:
-            h, w, d = x.shape[2], x.shape[3], x.shape[4]
-            rearrange_input = Rearrange("b c h w d -> b (h w d) c")
-            rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d)
-
+        shape = x.shape
         x = self.norm(x)
-        x = rearrange_input(x)  # B x (x_dim * y_dim [ * z_dim]) x C
-
+        x = x.reshape(*shape[:2], -1).transpose(1, 2)  # "b c h w d -> b (h w d) c"
         x = self.attn(x)
-        x = rearrange_output(x)  # B x x C x x_dim * y_dim * [z_dim]
+        x = x.transpose(1, 2).reshape(shape)  # "b (h w d) c -> b c h w d"
         x = x + residual
         return x
