
[WIP] Sam integration #7722

Draft: wants to merge 28 commits into base: gen-ai-dev

Changes from 11 commits

Commits (28)
7acfc69  causal self attention (vgrau98, Jan 24, 2024)
568013c  causal selfattention tests (vgrau98, Jan 24, 2024)
12dca86  integrate flash attention usage (vgrau98, Apr 27, 2024)
71d7c9d  transformer block local window attention (vgrau98, Dec 30, 2023)
a0325e4  fix: window partition input shapes (vgrau98, Jan 2, 2024)
02c087c  fix: error handling (vgrau98, Jan 2, 2024)
ab4a440  local window attention tests (vgrau98, Jan 2, 2024)
b47d71a  feat: 3d local window attention (vgrau98, Jan 2, 2024)
d4fec56  3d local attention window tests (vgrau98, Jan 2, 2024)
a4ee3f4  clean (vgrau98, Jan 6, 2024)
e417ffe  refacto (vgrau98, Jan 6, 2024)
e91c289  feat: layer norm 2d (vgrau98, Dec 28, 2023)
3726f57  fix: rel pos embedding with local attention (vgrau98, May 6, 2024)
e4778d1  feature: 2D sam image encoder (vgrau98, May 6, 2024)
bd5318e  Merge pull request #1 from vgrau98/sam-integration-img-encoder (vgrau98, May 6, 2024)
ce3f32c  sam prompt encoder (vgrau98, May 8, 2024)
47e0eb8  Merge pull request #2 from vgrau98/sam-integration-prompt-encoder (vgrau98, May 8, 2024)
50752a0  refacto mlp block for multilayer suppoert (vgrau98, May 9, 2024)
0250aef  mask decoder (vgrau98, May 9, 2024)
c6cd2c0  attention block as defined in sam (vgrau98, May 10, 2024)
0a6dadb  two way transformer block and two way attention block (vgrau98, May 10, 2024)
6131995  fix: mlp block with multiple layers, fix input, output and hidden dim… (vgrau98, May 11, 2024)
496bf63  fix transformer block (vgrau98, May 11, 2024)
548c3cc  Merge pull request #3 from vgrau98/sam-integration-mask-decoder (vgrau98, May 11, 2024)
3809202  sam network (vgrau98, May 11, 2024)
f0cd576  Merge pull request #4 from vgrau98/sam-integration-sam-network (vgrau98, May 11, 2024)
0c5d48f  sam weights mapping (vgrau98, May 12, 2024)
47c6a40  Merge pull request #5 from vgrau98/sam-integration-weights-mapping (vgrau98, May 12, 2024)
163 changes: 163 additions & 0 deletions monai/networks/blocks/attention_utils.py
@@ -15,6 +15,10 @@
import torch.nn.functional as F
from torch import nn

from monai.utils import optional_import

rearrange, _ = optional_import("einops", name="rearrange")


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
@@ -126,3 +130,162 @@ def add_decomposed_rel_pos(
).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)

return attn


def window_partition(x: torch.Tensor, window_size: int, input_size: Tuple = ()) -> Tuple[torch.Tensor, Tuple]:
"""
Partition into non-overlapping windows with padding if needed. Support 2D and 3D.
Args:
x (tensor): input tokens with [B, s_dim_1 * ... * s_dim_n, C]. with n = 1...len(input_size)
input_size (Tuple): input spatial dimension: (H, W) or (H, W, D)
window_size (int): window size

Returns:
windows: windows after partition with [B * num_windows, window_size_1 * ... * window_size_n, C].
with n = 1...len(input_size) and window_size_i == window_size.
(S_DIM_1p, ...,S_DIM_np): padded spatial dimensions before partition with n = 1...len(input_size)
"""
if x.shape[1] != int(torch.prod(torch.tensor(input_size))):
raise ValueError(f"Input tensor spatial dimension {x.shape[1]} should be equal to {input_size} product")

if len(input_size) == 2:
x = rearrange(x, "b (h w) c -> b h w c", h=input_size[0], w=input_size[1])
x, pad_hw = window_partition_2d(x, window_size)
x = rearrange(x, "b h w c -> b (h w) c", h=window_size, w=window_size)
return x, pad_hw
elif len(input_size) == 3:
x = rearrange(x, "b (h w d) c -> b h w d c", h=input_size[0], w=input_size[1], d=input_size[2])
x, pad_hwd = window_partition_3d(x, window_size)
x = rearrange(x, "b h w d c -> b (h w d) c", h=window_size, w=window_size, d=window_size)
return x, pad_hwd
else:
raise ValueError(f"input_size cannot be length {len(input_size)}. It can be composed of 2 or 3 elements only. ")


def window_partition_2d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed. Support only 2D.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.

Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
batch, h, w, c = x.shape

pad_h = (window_size - h % window_size) % window_size
pad_w = (window_size - w % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
hp, wp = h + pad_h, w + pad_w

x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, c)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
return windows, (hp, wp)


def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
"""
Partition into non-overlapping windows with padding if needed. 3d implementation.
Args:
x (tensor): input tokens with [B, H, W, D, C].
window_size (int): window size.

Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, window_size, C].
(Hp, Wp, Dp): padded height, width and depth before partition
"""
batch, h, w, d, c = x.shape

pad_h = (window_size - h % window_size) % window_size
pad_w = (window_size - w % window_size) % window_size
pad_d = (window_size - d % window_size) % window_size
if pad_h > 0 or pad_w > 0 or pad_d > 0:
x = F.pad(x, (0, 0, 0, pad_d, 0, pad_w, 0, pad_h))
hp, wp, dp = h + pad_h, w + pad_w, d + pad_d

x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c)
windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c)
return windows, (hp, wp, dp)


def window_unpartition(windows: torch.Tensor, window_size: int, pad: Tuple, spatial_dims: Tuple) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size_1, ..., window_size_n, C].
with n = 1...len(spatial_dims) and window_size == window_size_i
window_size (int): window size.
pad (Tuple): padded spatial dims (H, W) or (H, W, D)
spatial_dims (Tuple): original spatial dimensions - (H, W) or (H, W, D) - before padding.

Returns:
x: unpartitioned sequences with [B, s_dim_1, ..., s_dim_n, C].
"""
x: torch.Tensor
if len(spatial_dims) == 2:
x = rearrange(windows, "b (h w) c -> b h w c", h=window_size, w=window_size)
x = window_unpartition_2d(x, window_size, pad, spatial_dims)
x = rearrange(x, "b h w c -> b (h w) c", h=spatial_dims[0], w=spatial_dims[1])
return x
elif len(spatial_dims) == 3:
x = rearrange(windows, "b (h w d) c -> b h w d c", h=window_size, w=window_size, d=window_size)
x = window_unpartition_3d(x, window_size, pad, spatial_dims)
x = rearrange(x, "b h w d c -> b (h w d) c", h=spatial_dims[0], w=spatial_dims[1], d=spatial_dims[2])
return x
else:
raise ValueError(f"spatial_dims cannot have length {len(spatial_dims)}; it must contain 2 or 3 elements.")


def window_unpartition_2d(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (hp, wp).
hw (Tuple): original height and width (H, W) before padding.

Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
hp, wp = pad_hw
h, w = hw
batch = windows.shape[0] // (hp * wp // window_size // window_size)
x = windows.view(batch, hp // window_size, wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, hp, wp, -1)

if hp > h or wp > w:
x = x[:, :h, :w, :].contiguous()
return x


def window_unpartition_3d(
windows: torch.Tensor, window_size: int, pad_hwd: Tuple[int, int, int], hwd: Tuple[int, int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding. 3d implementation.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, window_size, C].
window_size (int): window size.
pad_hwd (Tuple): padded height, width and depth (hp, wp, dp).
hwd (Tuple): original height, width and depth (H, W, D) before padding.

Returns:
x: unpartitioned sequences with [B, H, W, D, C].
"""
hp, wp, dp = pad_hwd
h, w, d = hwd
batch = windows.shape[0] // (hp * wp * dp // window_size // window_size // window_size)
x = windows.view(
batch, hp // window_size, wp // window_size, dp // window_size, window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(batch, hp, wp, dp, -1)

if hp > h or wp > w or dp > d:
x = x[:, :h, :w, :d, :].contiguous()
return x
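
For reviewers: a minimal round-trip sketch of the new helpers above (not part of the diff). It assumes einops is installed and that the module path matches this PR; the shapes (batch 2, a 10 x 10 token grid, 32 channels, window size 4) are arbitrary illustrative choices.

import torch

from monai.networks.blocks.attention_utils import window_partition, window_unpartition

batch, height, width, channels = 2, 10, 10, 32
window_size = 4

# tokens flattened over the spatial dims: [B, H*W, C]
x = torch.randn(batch, height * width, channels)

# partition into padded, non-overlapping windows: [B * num_windows, window_size**2, C]
windows, pad_hw = window_partition(x, window_size, input_size=(height, width))
# windows.shape == torch.Size([18, 16, 32]); pad_hw == (12, 12) since 10 is padded up to 12

# undo the partition and strip the padding: back to [B, H*W, C]
x_restored = window_unpartition(windows, window_size, pad_hw, spatial_dims=(height, width))
assert x_restored.shape == x.shape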
109 changes: 91 additions & 18 deletions monai/networks/blocks/selfattention.py
@@ -12,21 +12,29 @@
from __future__ import annotations

from typing import Optional, Tuple
import warnings

import torch
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.networks.blocks.attention_utils import window_partition, window_unpartition
from monai.utils import optional_import

xops, has_xformers = optional_import("xformers.ops")
Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")


class SABlock(nn.Module):
"""
A self-attention block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
A self-attention block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
One can set up relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
and some additional features:
- local window attention
"""

def __init__(
@@ -38,6 +46,10 @@ def __init__(
save_attn: bool = False,
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
causal: bool = False,
sequence_length: int | None = None,
use_flash_attention: bool = False,
window_size: int = 0,
) -> None:
"""
Args:
Expand All @@ -48,9 +60,13 @@ def __init__(
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
positional parameter size. Must be set if local window attention is used.
causal (bool): whether to use causal attention. If True, `sequence_length` must be set.
sequence_length (int, optional): the sequence length; required when `causal` is True.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
use_flash_attention (bool): whether to use memory-efficient flash attention from xformers.
Cannot be combined with `rel_pos_embedding`.
window_size (int): window size for local attention as used in Segment Anything https://arxiv.org/abs/2304.02643.
If 0, global attention is used.
See https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py.
"""

super().__init__()
@@ -61,24 +77,54 @@
if hidden_size % num_heads != 0:
raise ValueError("hidden size should be divisible by num_heads.")

if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

if use_flash_attention and rel_pos_embedding is not None:
self.use_flash_attention = False
warnings.warn(
"flash attention set to `False`: flash attention can't be used with relative position embedding. Set `rel_pos_embedding` to `None` to use flash attention"
)
else:
self.use_flash_attention = use_flash_attention

if use_flash_attention and not has_xformers:
raise ValueError("use_flash_attention is True but xformers is not installed.")
if window_size > 0 and (input_size is None or len(input_size) not in [2, 3]):
raise ValueError(
"If local window attention is used (window_size > 0), input_size should be specified: (h, w) or (h, w, d)"
)

self.num_heads = num_heads
self.out_proj = nn.Linear(hidden_size, hidden_size)
self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.dropout_rate = dropout_rate
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5
self.causal = causal
self.sequence_length = sequence_length
self.save_attn = save_attn
self.att_mat = torch.Tensor()
self.rel_positional_embedding = (
get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads)
if rel_pos_embedding is not None
else None
)
self.window_size = window_size
self.input_size = input_size

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer(
"causal_mask",
torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
)
self.causal_mask: torch.Tensor

def forward(self, x: torch.Tensor):
"""
Args:
@@ -87,23 +133,50 @@ def forward(self, x: torch.Tensor):
Return:
torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
"""
output = self.input_rearrange(self.qkv(x))
q, k, v = output[0], output[1], output[2]
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat

att_mat = att_mat.softmax(dim=-1)
if self.window_size > 0:
x, pad = window_partition(x, self.window_size, self.input_size)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()
_, t, _ = x.size()
output = self.input_rearrange(self.qkv(x)) # 3 x B x (s_dim_1 * ... * s_dim_n) x h x C/h
q, k, v = output[0], output[1], output[2]

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
if self.use_flash_attention:
x = xops.memory_efficient_attention(
query=q.contiguous(),
key=k.contiguous(),
value=v.contiguous(),
scale=self.scale,
p=self.dropout_rate,
attn_bias=xops.LowerTriangularMask() if self.causal else None,
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = (
self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
)
# apply causal mask if set
att_mat = (
att_mat.masked_fill(self.causal_mask[:, :, :t, :t] == 0, float("-inf")) if self.causal else att_mat
)

att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)

# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad, self.input_size)

return x
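
And a hedged usage sketch for the extended SABlock (again not part of the diff): local window attention over a 14 x 14 token grid, plus a causal variant. The hyper-parameters are illustrative assumptions; flash attention would additionally require xformers, so it is not enabled here.

import torch

from monai.networks.blocks.selfattention import SABlock

hidden_size, num_heads = 96, 8
height, width = 14, 14

# local window attention: input_size is mandatory once window_size > 0
block = SABlock(
    hidden_size=hidden_size,
    num_heads=num_heads,
    input_size=(height, width),
    window_size=7,  # 0 would fall back to global attention
)
tokens = torch.randn(2, height * width, hidden_size)  # [B, H*W, C]
out = block(tokens)  # [2, 196, 96], same shape as the input

# causal variant: sequence_length is mandatory once causal=True
causal_block = SABlock(hidden_size=hidden_size, num_heads=num_heads, causal=True, sequence_length=height * width)
out_causal = causal_block(tokens)  # attention restricted to earlier positions in the sequence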