From 139b62c4aa161969ad1126c4feeec88c9833d4ef Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 1 Aug 2024 19:30:05 +0800 Subject: [PATCH 1/5] Fix outdated link in the docs (#7971) Fixes #7968 ### Description Fix outdated link in the docs ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index b6c8c22f98..85adee7e44 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -37,7 +37,7 @@ Features Getting started --------------- -`MedNIST demo `_ and `MONAI for PyTorch Users `_ are available on Colab. +`MedNIST demo `_ and `MONAI for PyTorch Users `_ are available on Colab. Examples and notebook tutorials are located at `Project-MONAI/tutorials `_. From 1ece8a5c0b187a737f02ec44e7436239e3343e43 Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:19:13 +0800 Subject: [PATCH 2/5] 7982-fix-ci-issue (#7983) Fixes #7982 ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .github/workflows/pythonapp.yml | 1 + requirements-dev.txt | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index cd6b6ccede..fe04f96a80 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -149,6 +149,7 @@ jobs: key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install dependencies run: | + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; python -m pip install --user --upgrade pip setuptools wheel twine # install the latest pytorch for testing # however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated diff --git a/requirements-dev.txt b/requirements-dev.txt index ced783443e..72ba210093 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -21,7 +21,7 @@ black>=22.12 isort>=5.1 ruff pytype>=2020.6.1; platform_system != "Windows" -types-pkg_resources +types-setuptools mypy>=1.5.0 ninja torchvision From ae5a04d685ade10d886db1918d68e292a6096a17 Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Sat, 3 Aug 2024 22:37:36 +0800 Subject: [PATCH 3/5] 7973-add-ngc-prefix (#7974) Fixes #7973 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Yiheng Wang Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/scripts.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 4967b6cf50..6dd83c1f81 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -217,10 +217,15 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str: def _download_from_ngc( - download_path: Path, filename: str, version: str, remove_prefix: str | None, progress: bool + download_path: Path, + filename: str, + version: str, + prefix: str = "monai_", + remove_prefix: str | None = "monai_", + progress: bool = True, ) -> None: # ensure prefix is contained - filename = _add_ngc_prefix(filename) + filename = _add_ngc_prefix(filename, prefix=prefix) url = _get_ngc_bundle_url(model_name=filename, version=version) filepath = download_path / f"{filename}_v{version}.zip" if remove_prefix: @@ -231,10 +236,16 @@ def _download_from_ngc( def _download_from_ngc_private( - download_path: Path, filename: str, version: str, remove_prefix: str | None, repo: str, headers: dict | None = None + download_path: Path, + filename: str, + version: str, + repo: str, + prefix: str = "monai_", + remove_prefix: str | None = "monai_", + headers: dict | None = None, ) -> None: # ensure prefix is contained - filename = _add_ngc_prefix(filename) + filename = _add_ngc_prefix(filename, prefix=prefix) request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo) if has_requests: headers = {} if headers is None else headers @@ -491,7 +502,7 @@ def download( url: url to download the data. If not `None`, data will be downloaded directly and `source` will not be checked. If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`. - remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles + remove_prefix: This argument is used when `source` is "ngc" or "ngc_private". Currently, all ngc bundles have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to maintain the consistency between these two sources, remove prefix is necessary. Therefore, if specified, downloaded folder name will remove the prefix. From 56ee32e36c5c0c7a5cb10afa4ec5589c81171e6b Mon Sep 17 00:00:00 2001 From: David Carreto Fidalgo Date: Sat, 3 Aug 2024 22:29:45 +0200 Subject: [PATCH 4/5] Fix: Small logic mistake in the `AsDiscrete.__call__` method (#7984) Hi MONAI Team! Thank you very much for this super nice framework, really appreciate it! Just found a small logic mistake in one of the transform classes. To reproduce: ```python import torch from monai.transforms.post.array import AsDiscrete transform = AsDiscrete(argmax=True) prediction = torch.rand(2, 3, 3) transform(prediction, argmax=False) # will still apply argmax ``` ### Description Proposed fix: `argmax` is explicitly checked for `None` in the `__cal__` method. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: David Carreto Fidalgo Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/post/array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index da9b23ce57..2e733c4f6c 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -211,7 +211,8 @@ def __call__( raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") img = convert_to_tensor(img, track_meta=get_track_meta()) img_t, *_ = convert_data_type(img, torch.Tensor) - if argmax or self.argmax: + argmax = self.argmax if argmax is None else argmax + if argmax: img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True)) to_onehot = self.to_onehot if to_onehot is None else to_onehot From 6c23fd06fc11667beedd0ba730d4104076a8db2d Mon Sep 17 00:00:00 2001 From: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com> Date: Tue, 6 Aug 2024 17:18:22 +0100 Subject: [PATCH 5/5] Flash attention (#7977) Fixes #7944. ### Description In response to Issue https://github.com/Project-MONAI/MONAI/issues/7944, I added the new functionality scaled_dot_product_attention from PyTorch to re-enable flash attention, present in the original MONAI Generative Models repository. This is allowed for torch >= 2.0 and when argument save_attn = False. Errors are raised otherwise. I ran quick tests and added some checks on test_selfattention and test_crossattention scripts to make sure the outputs are the same as not using flash attention. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Virginia Fernandez Co-authored-by: Virginia Fernandez Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> --- monai/networks/blocks/crossattention.py | 73 +++++++++++++++------ monai/networks/blocks/selfattention.py | 58 ++++++++++++---- monai/networks/blocks/spatialattention.py | 8 ++- monai/networks/blocks/transformerblock.py | 13 +++- monai/networks/nets/diffusion_model_unet.py | 5 ++ tests/test_crossattention.py | 66 +++++++++++++++---- tests/test_selfattention.py | 61 +++++++++++++---- 7 files changed, 223 insertions(+), 61 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index b888ea3942..daa5abdd56 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -17,7 +17,7 @@ import torch.nn as nn from monai.networks.layers.utils import get_rel_pos_embedding_layer -from monai.utils import optional_import +from monai.utils import optional_import, pytorch_after Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -44,6 +44,7 @@ def __init__( rel_pos_embedding: Optional[str] = None, input_size: Optional[Tuple] = None, attention_dtype: Optional[torch.dtype] = None, + use_flash_attention: bool = False, ) -> None: """ Args: @@ -55,13 +56,16 @@ def __init__( dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. - causal: whether to use causal attention. - sequence_length: if causal is True, it is necessary to specify the sequence length. - rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. - For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. - input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative - positional parameter size. + causal (bool, optional): whether to use causal attention. + sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only + "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional + parameter size. attention_dtype: cast attention operations to this dtype. + use_flash_attention: if True, use Pytorch's inbuilt + flash attention for a memory efficient attention mechanism (see + https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ super().__init__() @@ -81,6 +85,20 @@ def __init__( if causal and sequence_length is None: raise ValueError("sequence_length is necessary for causal attention.") + if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0): + raise ValueError( + "use_flash_attention is only supported for PyTorch versions >= 2.0." + "Upgrade your PyTorch or set the flag to False." + ) + if use_flash_attention and save_attn: + raise ValueError( + "save_attn has been set to True, but use_flash_attention is also set" + "to True. save_attn can only be used if use_flash_attention is False" + ) + + if use_flash_attention and rel_pos_embedding is not None: + raise ValueError("rel_pos_embedding must be None if you are using flash_attention.") + self.num_heads = num_heads self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.context_input_size = context_input_size if context_input_size else hidden_size @@ -94,6 +112,7 @@ def __init__( self.out_rearrange = Rearrange("b h l d -> b l (h d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) + self.dropout_rate = dropout_rate self.scale = self.head_dim**-0.5 self.save_attn = save_attn @@ -101,6 +120,7 @@ def __init__( self.causal = causal self.sequence_length = sequence_length + self.use_flash_attention = use_flash_attention if causal and sequence_length is not None: # causal mask to ensure that attention is only applied to the left in the input sequence @@ -142,26 +162,39 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): q = q.to(self.attention_dtype) k = k.to(self.attention_dtype) - q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) + q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) # k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale - # apply relative positional embedding if defined - att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + if self.use_flash_attention: + x = torch.nn.functional.scaled_dot_product_attention( + query=q.transpose(1, 2), + key=k.transpose(1, 2), + value=v.transpose(1, 2), + scale=self.scale, + dropout_p=self.dropout_rate, + is_causal=self.causal, + ).transpose( + 1, 2 + ) # Back to (b, nh, t, hs) + else: + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + # apply relative positional embedding if defined + if self.rel_positional_embedding is not None: + att_mat = self.rel_positional_embedding(x, att_mat, q) - if self.causal: - att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) - att_mat = att_mat.softmax(dim=-1) + att_mat = att_mat.softmax(dim=-1) - if self.save_attn: - # no gradients and new tensor; - # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html - self.att_mat = att_mat.detach() + if self.save_attn: + # no gradients and new tensor; + # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html + self.att_mat = att_mat.detach() - att_mat = self.drop_weights(att_mat) - x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 3ab1e1fd10..124c00acc6 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -15,9 +15,10 @@ import torch import torch.nn as nn +import torch.nn.functional as F from monai.networks.layers.utils import get_rel_pos_embedding_layer -from monai.utils import optional_import +from monai.utils import optional_import, pytorch_after Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -42,6 +43,7 @@ def __init__( rel_pos_embedding: Optional[str] = None, input_size: Optional[Tuple] = None, attention_dtype: Optional[torch.dtype] = None, + use_flash_attention: bool = False, ) -> None: """ Args: @@ -59,6 +61,9 @@ def __init__( input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. attention_dtype: cast attention operations to this dtype. + use_flash_attention: if True, use Pytorch's inbuilt + flash attention for a memory efficient attention mechanism (see + https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ @@ -82,6 +87,20 @@ def __init__( if causal and sequence_length is None: raise ValueError("sequence_length is necessary for causal attention.") + if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0): + raise ValueError( + "use_flash_attention is only supported for PyTorch versions >= 2.0." + "Upgrade your PyTorch or set the flag to False." + ) + if use_flash_attention and save_attn: + raise ValueError( + "save_attn has been set to True, but use_flash_attention is also set" + "to True. save_attn can only be used if use_flash_attention is False." + ) + + if use_flash_attention and rel_pos_embedding is not None: + raise ValueError("rel_pos_embedding must be None if you are using flash_attention.") + self.num_heads = num_heads self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) @@ -91,12 +110,14 @@ def __init__( self.out_rearrange = Rearrange("b h l d -> b l (h d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) + self.dropout_rate = dropout_rate self.scale = self.dim_head**-0.5 self.save_attn = save_attn self.att_mat = torch.Tensor() self.attention_dtype = attention_dtype self.causal = causal self.sequence_length = sequence_length + self.use_flash_attention = use_flash_attention if causal and sequence_length is not None: # causal mask to ensure that attention is only applied to the left in the input sequence @@ -130,23 +151,34 @@ def forward(self, x): q = q.to(self.attention_dtype) k = k.to(self.attention_dtype) - att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + if self.use_flash_attention: + x = F.scaled_dot_product_attention( + query=q.transpose(1, 2), + key=k.transpose(1, 2), + value=v.transpose(1, 2), + scale=self.scale, + dropout_p=self.dropout_rate, + is_causal=self.causal, + ).transpose(1, 2) + else: + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale - # apply relative positional embedding if defined - att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + # apply relative positional embedding if defined + if self.rel_positional_embedding is not None: + att_mat = self.rel_positional_embedding(x, att_mat, q) - if self.causal: - att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf")) + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf")) - att_mat = att_mat.softmax(dim=-1) + att_mat = att_mat.softmax(dim=-1) - if self.save_attn: - # no gradients and new tensor; - # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html - self.att_mat = att_mat.detach() + if self.save_attn: + # no gradients and new tensor; + # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html + self.att_mat = att_mat.detach() - att_mat = self.drop_weights(att_mat) - x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 75319853d9..1cfafb1585 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -33,6 +33,7 @@ class SpatialAttentionBlock(nn.Module): num_channels: number of input channels. Must be divisible by num_head_channels. num_head_channels: number of channels per head. attention_dtype: cast attention operations to this dtype. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ @@ -44,6 +45,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, attention_dtype: Optional[torch.dtype] = None, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -54,7 +56,11 @@ def __init__( raise ValueError("num_channels must be divisible by num_head_channels") num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 self.attn = SABlock( - hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype + hidden_size=num_channels, + num_heads=num_heads, + qkv_bias=True, + attention_dtype=attention_dtype, + use_flash_attention=use_flash_attention, ) def forward(self, x: torch.Tensor): diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 0aa1697479..28d9c563ac 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -36,6 +36,7 @@ def __init__( causal: bool = False, sequence_length: int | None = None, with_cross_attention: bool = False, + use_flash_attention: bool = False, ) -> None: """ Args: @@ -43,8 +44,10 @@ def __init__( mlp_dim (int): dimension of feedforward layer. num_heads (int): number of attention heads. dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. - qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False. + qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ @@ -66,13 +69,19 @@ def __init__( save_attn=save_attn, causal=causal, sequence_length=sequence_length, + use_flash_attention=use_flash_attention, ) self.norm2 = nn.LayerNorm(hidden_size) self.with_cross_attention = with_cross_attention self.norm_cross_attn = nn.LayerNorm(hidden_size) self.cross_attn = CrossAttentionBlock( - hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + causal=False, + use_flash_attention=use_flash_attention, ) def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 8a9ac859a3..a885339d0d 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -66,6 +66,8 @@ class DiffusionUNetTransformerBlock(nn.Module): dropout: dropout probability to use. cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ @@ -77,6 +79,7 @@ def __init__( dropout: float = 0.0, cross_attention_dim: int | None = None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.attn1 = SABlock( @@ -86,6 +89,7 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + use_flash_attention=use_flash_attention, ) self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) self.attn2 = CrossAttentionBlock( @@ -96,6 +100,7 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + use_flash_attention=use_flash_attention, ) self.norm1 = nn.LayerNorm(num_channels) self.norm2 = nn.LayerNorm(num_channels) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 4ab0ab1823..44458147d6 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -22,6 +22,7 @@ from monai.networks.blocks.crossattention import CrossAttentionBlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion einops, has_einops = optional_import("einops") @@ -31,25 +32,29 @@ for num_heads in [4, 6, 8, 12]: for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: for input_size in [(16, 32), (8, 8, 8)]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding, - "input_size": input_size, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_CABLOCK.append(test_case) + for flash_attn in [True, False]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding if not flash_attn else None, + "input_size": input_size, + "use_flash_attention": flash_attn, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_CABLOCK.append(test_case) class TestResBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_CABLOCK) @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) def test_shape(self, input_param, input_shape, expected_shape): + # Without flash attention net = CrossAttentionBlock(**input_param) with eval_mode(net): result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"])) @@ -62,6 +67,25 @@ def test_ill_arg(self): with self.assertRaises(ValueError): CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + @SkipIfBeforePyTorchVersion((2, 0)) + def test_save_attn_with_flash_attention(self): + with self.assertRaises(ValueError): + CrossAttentionBlock( + hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True + ) + + @SkipIfBeforePyTorchVersion((2, 0)) + def test_rel_pos_embedding_with_flash_attention(self): + with self.assertRaises(ValueError): + CrossAttentionBlock( + hidden_size=128, + num_heads=3, + dropout_rate=0.1, + use_flash_attention=True, + save_attn=False, + rel_pos_embedding=RelPosEmbedding.DECOMPOSED, + ) + @skipUnless(has_einops, "Requires einops") def test_attention_dim_not_multiple_of_heads(self): with self.assertRaises(ValueError): @@ -75,6 +99,22 @@ def test_causal_no_sequence_length(self): with self.assertRaises(ValueError): CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_causal_flash_attention(self): + block = CrossAttentionBlock( + hidden_size=128, + num_heads=1, + dropout_rate=0.1, + causal=True, + sequence_length=16, + save_attn=False, + use_flash_attention=True, + ) + input_shape = (1, 16, 128) + # Check it runs correctly + block(torch.randn(input_shape)) + @skipUnless(has_einops, "Requires einops") def test_causal(self): block = CrossAttentionBlock( @@ -119,7 +159,7 @@ def test_access_attn_matrix(self): # no of elements is zero assert no_matrix_acess_blk.att_mat.nelement() == 0 - # be able to acess the attention matrix + # be able to acess the attention matrix. matrix_acess_blk = CrossAttentionBlock( hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True ) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index d069d6aa30..3e98f4c5c4 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -22,6 +22,7 @@ from monai.networks.blocks.selfattention import SABlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion einops, has_einops = optional_import("einops") @@ -31,24 +32,27 @@ for num_heads in [4, 6, 8, 12]: for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: for input_size in [(16, 32), (8, 8, 8)]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding, - "input_size": input_size, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for flash_attn in [True, False]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding if not flash_attn else None, + "input_size": input_size, + "use_flash_attention": flash_attn, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_SABLOCK) @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) def test_shape(self, input_param, input_shape, expected_shape): net = SABlock(**input_param) with eval_mode(net): @@ -62,6 +66,23 @@ def test_ill_arg(self): with self.assertRaises(ValueError): SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + @SkipIfBeforePyTorchVersion((2, 0)) + def test_rel_pos_embedding_with_flash_attention(self): + with self.assertRaises(ValueError): + SABlock( + hidden_size=128, + num_heads=3, + dropout_rate=0.1, + use_flash_attention=True, + save_attn=False, + rel_pos_embedding=RelPosEmbedding.DECOMPOSED, + ) + + @SkipIfBeforePyTorchVersion((1, 13)) + def test_save_attn_with_flash_attention(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True) + def test_attention_dim_not_multiple_of_heads(self): with self.assertRaises(ValueError): SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1) @@ -74,6 +95,22 @@ def test_causal_no_sequence_length(self): with self.assertRaises(ValueError): SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_causal_flash_attention(self): + block = SABlock( + hidden_size=128, + num_heads=1, + dropout_rate=0.1, + causal=True, + sequence_length=16, + save_attn=False, + use_flash_attention=True, + ) + input_shape = (1, 16, 128) + # Check it runs correctly + block(torch.randn(input_shape)) + @skipUnless(has_einops, "Requires einops") def test_causal(self): block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True)