
Commit

Merge branch 'Project-MONAI:dev' into dev
surajpaib authored Aug 6, 2024
2 parents 123c778 + 6c23fd0 commit c936162
Showing 12 changed files with 244 additions and 69 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pythonapp.yml
@@ -149,6 +149,7 @@ jobs:
          key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
      - name: Install dependencies
        run: |
+         find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
          python -m pip install --user --upgrade pip setuptools wheel twine
          # install the latest pytorch for testing
          # however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -37,7 +37,7 @@ Features
Getting started
---------------

-`MedNIST demo <https://colab.research.google.com/drive/1wy8XUSnNWlhDNazFdvGBHLfdkGvOHBKe>`_ and `MONAI for PyTorch Users <https://colab.research.google.com/drive/1boqy7ENpKrqaJoxFlbHIBnIODAs1Ih1T>`_ are available on Colab.
+`MedNIST demo <https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/2d_classification/mednist_tutorial.ipynb>`_ and `MONAI for PyTorch Users <https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/developer_guide.ipynb>`_ are available on Colab.

Examples and notebook tutorials are located at `Project-MONAI/tutorials <https://github.com/Project-MONAI/tutorials>`_.

21 changes: 16 additions & 5 deletions monai/bundle/scripts.py
@@ -217,10 +217,15 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:


def _download_from_ngc(
-    download_path: Path, filename: str, version: str, remove_prefix: str | None, progress: bool
+    download_path: Path,
+    filename: str,
+    version: str,
+    prefix: str = "monai_",
+    remove_prefix: str | None = "monai_",
+    progress: bool = True,
) -> None:
    # ensure prefix is contained
-    filename = _add_ngc_prefix(filename)
+    filename = _add_ngc_prefix(filename, prefix=prefix)
    url = _get_ngc_bundle_url(model_name=filename, version=version)
    filepath = download_path / f"{filename}_v{version}.zip"
    if remove_prefix:
@@ -231,10 +236,16 @@ def _download_from_ngc(


def _download_from_ngc_private(
-    download_path: Path, filename: str, version: str, remove_prefix: str | None, repo: str, headers: dict | None = None
+    download_path: Path,
+    filename: str,
+    version: str,
+    repo: str,
+    prefix: str = "monai_",
+    remove_prefix: str | None = "monai_",
+    headers: dict | None = None,
) -> None:
    # ensure prefix is contained
-    filename = _add_ngc_prefix(filename)
+    filename = _add_ngc_prefix(filename, prefix=prefix)
    request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
    if has_requests:
        headers = {} if headers is None else headers
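
For orientation, the prefix helpers that both functions call can be sketched as follows. This is a minimal reconstruction inferred from the signatures visible in this diff (_add_ngc_prefix and _remove_ngc_prefix(name, prefix="monai_")), not the verbatim module code:

    def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str:
        # ensure the bundle name carries the NGC prefix exactly once
        return name if name.startswith(prefix) else f"{prefix}{name}"

    def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:
        # strip the prefix so local folder names match the model-zoo naming
        return name[len(prefix):] if name.startswith(prefix) else name
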
@@ -491,7 +502,7 @@ def download(
url: url to download the data. If not `None`, data will be downloaded directly
and `source` will not be checked.
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
-        remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
+        remove_prefix: This argument is used when `source` is "ngc" or "ngc_private". Currently, all ngc bundles
have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to
maintain the consistency between these two sources, remove prefix is necessary.
Therefore, if specified, downloaded folder name will remove the prefix.
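
In practice the documented behavior amounts to the following usage sketch; the bundle name and version are illustrative placeholders, not values taken from this commit:

    from monai.bundle import download

    # On NGC the artifact carries the "monai_" prefix (e.g. "monai_spleen_ct_segmentation");
    # remove_prefix="monai_" (the default) strips it from the downloaded folder name
    # so the local name matches the model zoo.
    download(name="spleen_ct_segmentation", source="ngc", version="0.5.3", remove_prefix="monai_")
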
73 changes: 53 additions & 20 deletions monai/networks/blocks/crossattention.py
@@ -17,7 +17,7 @@
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import
+from monai.utils import optional_import, pytorch_after

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -44,6 +44,7 @@ def __init__(
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
+        use_flash_attention: bool = False,
) -> None:
"""
Args:
@@ -55,13 +56,16 @@
dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
-            causal: whether to use causal attention.
-            sequence_length: if causal is True, it is necessary to specify the sequence length.
-            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
-                For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
-            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
-                positional parameter size.
+            causal (bool, optional): whether to use causal attention.
+            sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
+            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only
+                "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
+                parameter size.
            attention_dtype: cast attention operations to this dtype.
+            use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient attention
+                mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

super().__init__()
@@ -81,6 +85,20 @@ def __init__(
if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

+        if use_flash_attention and not pytorch_after(major=2, minor=0):
+            raise ValueError(
+                "use_flash_attention is only supported for PyTorch versions >= 2.0. "
+                "Upgrade your PyTorch or set the flag to False."
+            )
+        if use_flash_attention and save_attn:
+            raise ValueError(
+                "save_attn has been set to True, but use_flash_attention is also set "
+                "to True. save_attn can only be used if use_flash_attention is False."
+            )
+
+        if use_flash_attention and rel_pos_embedding is not None:
+            raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.context_input_size = context_input_size if context_input_size else hidden_size
@@ -94,13 +112,15 @@
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
+        self.dropout_rate = dropout_rate

self.scale = self.head_dim**-0.5
self.save_attn = save_attn
self.attention_dtype = attention_dtype

self.causal = causal
self.sequence_length = sequence_length
+        self.use_flash_attention = use_flash_attention

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
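
The hunk is truncated here, but the buffer it goes on to register is the standard lower-triangular mask; in sketch form (shape conventions assumed, not copied from this diff):

    # position i may attend only to positions j <= i
    causal_mask = torch.tril(torch.ones(sequence_length, sequence_length))
    self.register_buffer("causal_mask", causal_mask.view(1, 1, sequence_length, sequence_length))
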
@@ -142,26 +162,39 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

        q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
-        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
-
-        # apply relative positional embedding if defined
-        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+        if self.use_flash_attention:
+            x = torch.nn.functional.scaled_dot_product_attention(
+                query=q.transpose(1, 2),
+                key=k.transpose(1, 2),
+                value=v.transpose(1, 2),
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
+            ).transpose(1, 2)  # Back to (b, nh, t, hs)
+        else:
+            att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+            # apply relative positional embedding if defined
+            if self.rel_positional_embedding is not None:
+                att_mat = self.rel_positional_embedding(x, att_mat, q)

-        if self.causal:
-            att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
-
-        att_mat = att_mat.softmax(dim=-1)
-
-        if self.save_attn:
-            # no gradients and new tensor;
-            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-            self.att_mat = att_mat.detach()
-
-        att_mat = self.drop_weights(att_mat)
-        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+            if self.causal:
+                att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
+
+            att_mat = att_mat.softmax(dim=-1)
+
+            if self.save_attn:
+                # no gradients and new tensor;
+                # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+                self.att_mat = att_mat.detach()
+
+            att_mat = self.drop_weights(att_mat)
+            x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
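
Taken together, the new flag can be exercised as in this sketch (sizes are illustrative, and a sufficiently recent PyTorch is assumed per the version check added above):

    import torch
    from monai.networks.blocks.crossattention import CrossAttentionBlock

    block = CrossAttentionBlock(hidden_size=64, num_heads=4, use_flash_attention=True)
    x = torch.randn(2, 16, 64)    # (batch, seq_len, hidden_size)
    ctx = torch.randn(2, 32, 64)  # context sequence for cross-attention
    out = block(x, context=ctx)   # same shape as x: (2, 16, 64)
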
58 changes: 45 additions & 13 deletions monai/networks/blocks/selfattention.py
@@ -15,9 +15,10 @@

import torch
import torch.nn as nn
+import torch.nn.functional as F

from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import
+from monai.utils import optional_import, pytorch_after

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -42,6 +43,7 @@ def __init__(
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
+        use_flash_attention: bool = False,
) -> None:
"""
Args:
@@ -59,6 +61,9 @@
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
+            use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient attention
+                mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

@@ -82,6 +87,20 @@ def __init__(
if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

+        if use_flash_attention and not pytorch_after(major=2, minor=0):
+            raise ValueError(
+                "use_flash_attention is only supported for PyTorch versions >= 2.0. "
+                "Upgrade your PyTorch or set the flag to False."
+            )
+        if use_flash_attention and save_attn:
+            raise ValueError(
+                "save_attn has been set to True, but use_flash_attention is also set "
+                "to True. save_attn can only be used if use_flash_attention is False."
+            )
+
+        if use_flash_attention and rel_pos_embedding is not None:
+            raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
@@ -91,12 +110,14 @@
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
+        self.dropout_rate = dropout_rate
self.scale = self.dim_head**-0.5
self.save_attn = save_attn
self.att_mat = torch.Tensor()
self.attention_dtype = attention_dtype
self.causal = causal
self.sequence_length = sequence_length
+        self.use_flash_attention = use_flash_attention

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
@@ -130,23 +151,34 @@ def forward(self, x):
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

-        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
-
-        # apply relative positional embedding if defined
-        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+        if self.use_flash_attention:
+            x = F.scaled_dot_product_attention(
+                query=q.transpose(1, 2),
+                key=k.transpose(1, 2),
+                value=v.transpose(1, 2),
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
+            ).transpose(1, 2)
+        else:
+            att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+            # apply relative positional embedding if defined
+            if self.rel_positional_embedding is not None:
+                att_mat = self.rel_positional_embedding(x, att_mat, q)

-        if self.causal:
-            att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))
-
-        att_mat = att_mat.softmax(dim=-1)
-
-        if self.save_attn:
-            # no gradients and new tensor;
-            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-            self.att_mat = att_mat.detach()
-
-        att_mat = self.drop_weights(att_mat)
-        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+            if self.causal:
+                att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
+
+            att_mat = att_mat.softmax(dim=-1)
+
+            if self.save_attn:
+                # no gradients and new tensor;
+                # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+                self.att_mat = att_mat.detach()
+
+            att_mat = self.drop_weights(att_mat)
+            x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
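
With dropout at zero and no relative positional embedding, the two branches are intended to be numerically equivalent; a quick sanity check along those lines (illustrative, assuming a PyTorch recent enough for the scale= keyword of scaled_dot_product_attention):

    import torch
    from monai.networks.blocks.selfattention import SABlock

    flash = SABlock(hidden_size=64, num_heads=4, use_flash_attention=True)
    manual = SABlock(hidden_size=64, num_heads=4, use_flash_attention=False)
    manual.load_state_dict(flash.state_dict())  # identical weights in both blocks

    flash.eval(); manual.eval()
    x = torch.randn(2, 16, 64)
    with torch.no_grad():
        print(torch.allclose(flash(x), manual(x), atol=1e-5))  # expected: True
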
8 changes: 7 additions & 1 deletion monai/networks/blocks/spatialattention.py
@@ -33,6 +33,7 @@ class SpatialAttentionBlock(nn.Module):
num_channels: number of input channels. Must be divisible by num_head_channels.
num_head_channels: number of channels per head.
attention_dtype: cast attention operations to this dtype.
+        use_flash_attention: if True, use flash attention for a memory-efficient attention mechanism.
"""

@@ -44,6 +45,7 @@ def __init__(
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
attention_dtype: Optional[torch.dtype] = None,
+        use_flash_attention: bool = False,
) -> None:
super().__init__()

@@ -54,7 +56,11 @@
raise ValueError("num_channels must be divisible by num_head_channels")
num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
self.attn = SABlock(
-            hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype
+            hidden_size=num_channels,
+            num_heads=num_heads,
+            qkv_bias=True,
+            attention_dtype=attention_dtype,
+            use_flash_attention=use_flash_attention,
)

def forward(self, x: torch.Tensor):
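
An illustrative call follows; the channel and spatial sizes are assumptions, and num_channels must stay divisible by both num_head_channels and norm_num_groups:

    import torch
    from monai.networks.blocks.spatialattention import SpatialAttentionBlock

    attn = SpatialAttentionBlock(spatial_dims=2, num_channels=32, num_head_channels=8, use_flash_attention=True)
    x = torch.randn(1, 32, 16, 16)  # (batch, channels, H, W)
    out = attn(x)                   # attention over the flattened spatial grid; same shape as x
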
13 changes: 11 additions & 2 deletions monai/networks/blocks/transformerblock.py
@@ -36,15 +36,18 @@ def __init__(
causal: bool = False,
sequence_length: int | None = None,
with_cross_attention: bool = False,
+        use_flash_attention: bool = False,
) -> None:
"""
Args:
hidden_size (int): dimension of hidden layer.
mlp_dim (int): dimension of feedforward layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient attention
+                mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

@@ -66,13 +69,19 @@
save_attn=save_attn,
causal=causal,
sequence_length=sequence_length,
+            use_flash_attention=use_flash_attention,
)
self.norm2 = nn.LayerNorm(hidden_size)
self.with_cross_attention = with_cross_attention

self.norm_cross_attn = nn.LayerNorm(hidden_size)
self.cross_attn = CrossAttentionBlock(
-            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            dropout_rate=dropout_rate,
+            qkv_bias=qkv_bias,
+            causal=False,
+            use_flash_attention=use_flash_attention,
)

def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
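
An end-to-end sketch with the flag threaded through both attention layers (sizes illustrative):

    import torch
    from monai.networks.blocks.transformerblock import TransformerBlock

    blk = TransformerBlock(
        hidden_size=64, mlp_dim=256, num_heads=4, with_cross_attention=True, use_flash_attention=True
    )
    x = torch.randn(2, 16, 64)
    ctx = torch.randn(2, 32, 64)
    out = blk(x, context=ctx)  # (2, 16, 64)
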
[Diff for the remaining 6 changed files did not load on this page.]
