Diffusers upgraded to 0.26.3, along with MLPerf SD XL support
ANSHUMAN87 committed Jul 13, 2024
1 parent 27a8f27 commit 89235a0
Showing 4 changed files with 1,194 additions and 361 deletions.
188 changes: 188 additions & 0 deletions optimum/habana/diffusers/models/attention_processor.py
@@ -0,0 +1,188 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
from typing import Optional, Union

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND, logging
from diffusers.utils.import_utils import is_xformers_available
from torch import nn


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class Softmax(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, dim=None, invAttnHead=None):
        # HPU-specific fp8 softmax kernel; the inverse attention-head scale is fused into the op.
        return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead)


class Matmul(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs):
        return torch.matmul(*args, **kwargs)


# ScaledDotProductAttention is based on torch.nn.functional.scaled_dot_product_attention
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.bmm1 = Matmul()
        self.bmm2 = Matmul()
        self.softmax = Softmax()

    def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
        # Efficient implementation:
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        invAttnHead = torch.tensor(scale_factor, dtype=torch.float32).to("hpu")
        attn_bias = torch.zeros(L, S, dtype=query.dtype)

        if is_causal:
            assert attn_mask is None
            temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias.to(query.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        if S < 128:
            # Transposed variant for short key/value sequences: compute K @ Q^T and softmax
            # over dim=-2, which is equivalent to the standard Q @ K^T path below.
            attn_weight = self.bmm1(key, query.transpose(-2, -1))
            attn_weight = self.softmax(attn_weight, dim=-2, invAttnHead=invAttnHead)
            attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
            return self.bmm2(attn_weight.transpose(-2, -1), value)
        else:
            attn_weight = self.bmm1(query, key.transpose(-2, -1))
            attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead)
            attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
            return self.bmm2(attn_weight, value)
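# Illustrative usage (a sketch, not part of this module): ScaledDotProductAttention follows
# the torch.nn.functional.scaled_dot_product_attention calling convention, i.e. tensors of
# shape (batch, heads, seq_len, head_dim); the shapes below are hypothetical and an HPU
# device is assumed to be available.
#
#     sdpa = ScaledDotProductAttention()
#     q = torch.randn(2, 8, 64, 40, device="hpu", dtype=torch.bfloat16)
#     k = torch.randn(2, 8, 77, 40, device="hpu", dtype=torch.bfloat16)
#     v = torch.randn(2, 8, 77, 40, device="hpu", dtype=torch.bfloat16)
#     out = sdpa(q, k, v)  # -> shape (2, 8, 64, 40)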


# Copied from diffusers.models.attention_processor.AttnProcessor2_0
class AttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""

def __init__(self, attention_module=None):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.attention_module = attention_module

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
# hidden_states = F.scaled_dot_product_attention(
# query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
# )
if os.environ.get('PATCH_SDPA') is not None:
hidden_states = self.attention_module(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
else:
import habana_frameworks.torch.hpu as ht
from habana_frameworks.torch.hpex.kernels import FusedSDPA
with ht.sdp_kernel(enable_recompute = True):
hidden_states = FusedSDPA.apply(query, key, value, attention_mask, 0.0, False)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states

AttentionProcessor = Union[
    AttnProcessor2_0,
]
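For context, a minimal sketch of how this processor might be attached to a diffusers UNet. The model id, dtype, and availability of an HPU device are assumptions for illustration; the actual wiring inside the optimum-habana pipelines may differ.

import os

import torch
from diffusers import UNet2DConditionModel

from optimum.habana.diffusers.models.attention_processor import AttnProcessor2_0, ScaledDotProductAttention

# Route attention through the Python-level ScaledDotProductAttention module;
# without PATCH_SDPA set, the fused HPU kernel (FusedSDPA) is used instead.
os.environ["PATCH_SDPA"] = "1"

# Placeholder model id; any diffusers UNet with standard attention blocks would work.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.bfloat16
)
unet.set_attn_processor(AttnProcessor2_0(attention_module=ScaledDotProductAttention()))
unet.to("hpu")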