Skip to content

Commit

Permalink
Feature IP Adapter Xformers Attention Processor (#9881)
Browse files Browse the repository at this point in the history
* Feature IP Adapter Xformers Attention Processor: this fix error loading incorrect attention processor when setting Xformers attn after load ip adapter scale, issues: #8863 #8872
  • Loading branch information
elismasilva authored Nov 9, 2024
1 parent 8d6dc2b commit dac623b
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 11 deletions.
14 changes: 8 additions & 6 deletions src/diffusers/loaders/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,14 @@


if is_transformers_available():
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
)
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
)

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -284,7 +282,9 @@ def set_ip_adapter_scale(self, scale):
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)

for attn_name, attn_processor in unet.attn_processors.items():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
if isinstance(
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to "
Expand Down Expand Up @@ -342,7 +342,9 @@ def unload_ip_adapter(self):
)
attn_procs[name] = (
attn_processor_class
if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
if isinstance(
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
)
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
13 changes: 9 additions & 4 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
from ..models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
)

if low_cpu_mem_usage:
Expand Down Expand Up @@ -804,11 +805,15 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = self.attn_processors[name].__class__
attn_procs[name] = attn_processor_class()

else:
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
if "XFormers" in str(self.attn_processors[name].__class__):
attn_processor_class = IPAdapterXFormersAttnProcessor
else:
attn_processor_class = (
IPAdapterAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else IPAdapterAttnProcessor
)
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
Expand Down
262 changes: 261 additions & 1 deletion src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,10 @@ def set_use_memory_efficient_attention_xformers(
XFormersAttnAddedKVProcessor,
),
)

is_ip_adapter = hasattr(self, "processor") and isinstance(
self.processor,
(IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
)
if use_memory_efficient_attention_xformers:
if is_added_kv_processor and is_custom_diffusion:
raise NotImplementedError(
Expand Down Expand Up @@ -368,6 +371,19 @@ def set_use_memory_efficient_attention_xformers(
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
)
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
elif is_ip_adapter:
processor = IPAdapterXFormersAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
num_tokens=self.processor.num_tokens,
scale=self.processor.scale,
attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_ip"):
processor.to(
device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
)
else:
processor = XFormersAttnProcessor(attention_op=attention_op)
else:
Expand All @@ -386,6 +402,18 @@ def set_use_memory_efficient_attention_xformers(
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_custom_diffusion"):
processor.to(self.processor.to_k_custom_diffusion.weight.device)
elif is_ip_adapter:
processor = IPAdapterAttnProcessor2_0(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
num_tokens=self.processor.num_tokens,
scale=self.processor.scale,
)
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_ip"):
processor.to(
device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
)
else:
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
Expand Down Expand Up @@ -4542,6 +4570,238 @@ def __call__(
return hidden_states


class IPAdapterXFormersAttnProcessor(torch.nn.Module):
r"""
Attention processor for IP-Adapter using xFormers.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
operator.
"""

def __init__(
self,
hidden_size,
cross_attention_dim=None,
num_tokens=(4,),
scale=1.0,
attention_op: Optional[Callable] = None,
):
super().__init__()

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.attention_op = attention_op

if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
self.num_tokens = num_tokens

if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale

self.to_k_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
self.to_v_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
)

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.FloatTensor] = None,
):
residual = hidden_states

# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()

hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)

if ip_hidden_states:
if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, List):
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
raise ValueError(
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
f"({len(ip_hidden_states)})"
)
else:
for index, (mask, scale, ip_state) in enumerate(
zip(ip_adapter_masks, self.scale, ip_hidden_states)
):
if mask is None:
continue
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape "
"[1, num_images_for_ip_adapter, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if mask.shape[1] != ip_state.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of ip images ({ip_state.shape[1]}) at index {index}"
)
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of scales ({len(scale)}) at index {index}"
)
else:
ip_adapter_masks = [None] * len(self.scale)

# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
skip = False
if isinstance(scale, list):
if all(s == 0 for s in scale):
skip = True
elif scale == 0:
skip = True
if not skip:
if mask is not None:
mask = mask.to(torch.float16)
if not isinstance(scale, list):
scale = [scale] * mask.shape[1]

current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

ip_key = attn.head_to_batch_dim(ip_key).contiguous()
ip_value = attn.head_to_batch_dim(ip_value).contiguous()

_current_ip_hidden_states = xformers.ops.memory_efficient_attention(
query, ip_key, ip_value, op=self.attention_op
)
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)

mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)

mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)

ip_key = attn.head_to_batch_dim(ip_key).contiguous()
ip_value = attn.head_to_batch_dim(ip_value).contiguous()

current_ip_hidden_states = xformers.ops.memory_efficient_attention(
query, ip_key, ip_value, op=self.attention_op
)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

hidden_states = hidden_states + scale * current_ip_hidden_states

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states


class PAGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
Expand Down

0 comments on commit dac623b

Please sign in to comment.