
Commit 960b730

add static shape support in image process, replace unfold with conv2d to speed up finetuning

Signed-off-by: Wang, Yi A <[email protected]>
sywangyi committed Sep 14, 2024
1 parent 998574e commit 960b730
Showing 7 changed files with 155 additions and 14 deletions.
3 changes: 1 addition & 2 deletions examples/image-to-text/run_image2text_lora_finetune.py
@@ -382,7 +382,7 @@ def main():
do_image_splitting=model_args.do_image_splitting,
padding_side="right",
)

setattr(processor.image_processor, "pad_to_longest_edge", True)
config_kwargs = {
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
@@ -503,7 +503,6 @@ def main():
],
}
]
processor.tokenizer.padding_side = "left"
text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
text=[text.strip()],
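For orientation, a minimal sketch of how the new static-padding flag is meant to be enabled in a fine-tuning script; the checkpoint id and `do_image_splitting` value below are illustrative assumptions, the example script reads them from its model arguments:

```python
from transformers import AutoProcessor

# Assumed checkpoint id for illustration; the example script loads it from model_args.
processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    do_image_splitting=False,
    padding_side="right",
)

# Pad every image to (longest_edge, longest_edge) instead of the per-batch maximum,
# so input shapes stay static on HPU during fine-tuning.
setattr(processor.image_processor, "pad_to_longest_edge", True)
```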
4 changes: 4 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -28,6 +28,7 @@
from .models import (
DeciLMConfig,
DeciLMForCausalLM,
Gaudi2Idefics2ImageProcessor,
GaudiBloomForCausalLM,
GaudiBloomMLP,
GaudiCLIPAttention,
@@ -58,6 +59,7 @@
GaudiGPTNeoXLayer,
GaudiIdefics2ForConditionalGeneration,
GaudiIdefics2Model,
GaudiIdefics2VisionEmbeddings,
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
GaudiLlamaDynamicNTKScalingRotaryEmbedding,
@@ -387,6 +389,8 @@ def adapt_transformers_to_gaudi():
GaudiIdefics2ForConditionalGeneration
)
transformers.models.idefics2.modeling_idefics2.Idefics2Model = GaudiIdefics2Model
transformers.models.idefics2.image_processing_idefics2.Idefics2ImageProcessor = Gaudi2Idefics2ImageProcessor
transformers.models.idefics2.modeling_idefics2.Idefics2VisionEmbeddings = GaudiIdefics2VisionEmbeddings

# Optimization for Clip on Gaudi
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings = GaudiCLIPVisionEmbeddings
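These class swaps only take effect once `adapt_transformers_to_gaudi()` runs, which the library's Gaudi training entry points normally invoke; a small sketch of applying the patch manually and checking that the Idefics2 classes were replaced:

```python
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Replace the upstream transformers classes (including the Idefics2 ones added
# in this commit) with their Gaudi-optimized counterparts.
adapt_transformers_to_gaudi()

from transformers.models.idefics2 import image_processing_idefics2, modeling_idefics2

print(image_processing_idefics2.Idefics2ImageProcessor.__name__)  # Gaudi2Idefics2ImageProcessor
print(modeling_idefics2.Idefics2VisionEmbeddings.__name__)        # GaudiIdefics2VisionEmbeddings
```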
7 changes: 6 additions & 1 deletion optimum/habana/transformers/models/__init__.py
@@ -96,7 +96,12 @@
GaudiGPTJForCausalLM,
GaudiGPTJModel,
)
from .idefics2 import GaudiIdefics2ForConditionalGeneration, GaudiIdefics2Model
from .idefics2 import (
Gaudi2Idefics2ImageProcessor,
GaudiIdefics2ForConditionalGeneration,
GaudiIdefics2Model,
GaudiIdefics2VisionEmbeddings,
)
from .llama import (
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
3 changes: 2 additions & 1 deletion optimum/habana/transformers/models/idefics2/__init__.py
@@ -1 +1,2 @@
from .modeling_idefics2 import GaudiIdefics2ForConditionalGeneration, GaudiIdefics2Model
from .image_processing_idefics2 import Gaudi2Idefics2ImageProcessor
from .modeling_idefics2 import GaudiIdefics2ForConditionalGeneration, GaudiIdefics2Model, GaudiIdefics2VisionEmbeddings
84 changes: 84 additions & 0 deletions optimum/habana/transformers/models/idefics2/image_processing_idefics2.py
@@ -0,0 +1,84 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Iterable, List, Optional, Union

import numpy as np
from transformers.image_processing_utils import BatchFeature
from transformers.image_utils import ChannelDimension, infer_channel_dimension_format
from transformers.models.idefics2.image_processing_idefics2 import (
Idefics2ImageProcessor,
get_max_height_width,
make_pixel_mask,
)
from transformers.utils import TensorType


class Gaudi2Idefics2ImageProcessor(Idefics2ImageProcessor):
def pad(
self,
images: List[np.ndarray],
constant_values: Union[float, Iterable[float]] = 0,
return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
"""
Inherits from Idefics2ImageProcessor::pad https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/image_processing_idefics2.py#L314
The only differences are:
- when pad_to_longest_edge is set, the pad size uses longest_edge, so the padded image size does not change across batches, which speeds up fine-tuning
"""

if getattr(self, "pad_to_longest_edge", False):
pad_size = (self.size["longest_edge"], self.size["longest_edge"])
else:
pad_size = get_max_height_width(images, input_data_format=input_data_format)

batch_size = len(images)
max_num_images = max(len(images_) for images_ in images)
input_data_format = (
infer_channel_dimension_format(images[0][0]) if input_data_format is None else input_data_format
)
data_format = input_data_format if data_format is None else data_format

def empty_image(size, input_data_format):
if input_data_format == ChannelDimension.FIRST:
return np.zeros((3, *size), dtype=np.uint8)
elif input_data_format == ChannelDimension.LAST:
return np.zeros((*size, 3), dtype=np.uint8)
raise ValueError("Invalid channel dimension format.")

padded_images_list = [
[empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
]
padded_masks = [[np.zeros(pad_size) for _ in range(max_num_images)] for _ in range(batch_size)]

for batch_idx in range(batch_size):
for sample_idx, image in enumerate(images[batch_idx]):
padded_images_list[batch_idx][sample_idx] = self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
padded_masks[batch_idx][sample_idx] = make_pixel_mask(
image, output_size=pad_size, input_data_format=input_data_format
)

padded_masks = padded_masks if return_pixel_mask else None
return padded_images_list, padded_masks
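A rough usage sketch of the new pad behaviour (the size dict and image shapes below are illustrative values, not taken from this commit): with `pad_to_longest_edge` unset the pad target follows the largest image in the batch, with it set the target is always `(longest_edge, longest_edge)`.

```python
import numpy as np

from optimum.habana.transformers.models import Gaudi2Idefics2ImageProcessor

proc = Gaudi2Idefics2ImageProcessor(size={"shortest_edge": 378, "longest_edge": 980})
# Two samples, one channels-first image each, with different spatial sizes.
images = [
    [np.zeros((3, 300, 400), dtype=np.uint8)],
    [np.zeros((3, 500, 350), dtype=np.uint8)],
]

# Default: pad target is the per-batch maximum height/width (dynamic shape).
padded, masks = proc.pad(images, input_data_format="channels_first")
print(padded[0][0].shape)  # (3, 500, 400) -- changes with batch content

# New flag: pad target is fixed at (longest_edge, longest_edge) (static shape).
proc.pad_to_longest_edge = True
padded, masks = proc.pad(images, input_data_format="channels_first")
print(padded[0][0].shape)  # (3, 980, 980) -- constant across batches
```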
62 changes: 55 additions & 7 deletions optimum/habana/transformers/models/idefics2/modeling_idefics2.py
@@ -25,13 +25,52 @@
Idefics2CausalLMOutputWithPast,
Idefics2ForConditionalGeneration,
Idefics2Model,
Idefics2VisionEmbeddings,
)
from transformers.utils import logging


logger = logging.get_logger(__name__)


class GaudiIdefics2VisionEmbeddings(Idefics2VisionEmbeddings):
def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
"""
Inherits from Idefics2VisionEmbeddings::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L159
The only differences are:
- cast nb_patches_h and nb_patches_w to int() to avoid overflow in torch.arange, which can otherwise return nb_patches_h/nb_patches_w + 1 elements
- remove the .to("cpu") on p_attn_mask so the mask stays on the device
"""

batch_size, _, max_im_h, max_im_w = pixel_values.shape

patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)

max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
position_ids = torch.full(
size=(batch_size, max_nb_patches_h * max_nb_patches_w),
fill_value=0,
device=self.position_embedding.weight.device,
)

for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = int(p_attn_mask[:, 0].sum())
nb_patches_w = int(p_attn_mask[0].sum())

fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings


class GaudiIdefics2Model(Idefics2Model):
def forward(
self,
@@ -52,7 +91,7 @@ def forward(
Inherits from Idefics2Model::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1303
The only differences are:
- ignoring new Cache path for HPU
- unfold is not supported in HPU, fallback to cpu
- unfold is not supported on HPU, replaced with conv2d
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -120,9 +159,13 @@ def forward(
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.cpu().unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.cpu().unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
conv_kernel = torch.ones(
[1, 1, patch_size, patch_size], dtype=pixel_values.dtype, device=pixel_values.device
)
patches_subgrid = torch.nn.functional.conv2d(
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), conv_kernel, stride=patch_size
).squeeze(1)
patch_attention_mask = torch.eq(patches_subgrid, (patch_size * patch_size))

# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
@@ -345,6 +388,7 @@ def prepare_inputs_for_generation(
The only differences are:
- add new args token_idx
- add None "Cache" past_key_values support
- move vision_model to prepare_inputs_for_generation
"""
past_length = 0
token_idx = kwargs.get("token_idx", None)
@@ -430,9 +474,13 @@ def prepare_inputs_for_generation(
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.cpu().unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.cpu().unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
conv_kernel = torch.ones(
[1, 1, patch_size, patch_size], dtype=pixel_values.dtype, device=pixel_values.device
)
patches_subgrid = torch.nn.functional.conv2d(
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), conv_kernel, stride=patch_size
).squeeze(1)
patch_attention_mask = torch.eq(patches_subgrid, (patch_size * patch_size))

# Get sequence from the vision encoder
image_hidden_states = self.model.vision_model(
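The replacement computes the patch attention mask without `unfold`: convolving the binary pixel mask with an all-ones `patch_size x patch_size` kernel at stride `patch_size` counts the valid pixels per patch, and comparing the count against `patch_size**2` keeps only fully valid patches. A self-contained toy sketch (the mask and patch sizes below are made up, not read from the model config):

```python
import torch
import torch.nn.functional as F

patch_size = 14  # illustrative value
# Two padded images whose valid (unpadded) region is the top-left 700x420 pixels.
pixel_attention_mask = torch.zeros(2, 980, 980, dtype=torch.bool)
pixel_attention_mask[:, :700, :420] = True

# Ones-kernel conv2d at stride=patch_size counts the valid pixels in each patch.
conv_kernel = torch.ones(1, 1, patch_size, patch_size, dtype=torch.float32)
patches_subgrid = F.conv2d(
    pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), conv_kernel, stride=patch_size
).squeeze(1)

# A patch is attended to only if every one of its patch_size * patch_size pixels is valid.
patch_attention_mask = torch.eq(patches_subgrid, patch_size * patch_size)
print(patch_attention_mask.shape)                       # torch.Size([2, 70, 70])
print(patch_attention_mask[0, :50, :30].all().item())   # True  -> fully valid patches
print(patch_attention_mask[0, 50:, :].any().item())     # False -> padded region
```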
6 changes: 3 additions & 3 deletions tests/baselines/idefics2_8b.json
@@ -7,9 +7,9 @@
"multi_card": {
"learning_rate": 5e-5,
"train_batch_size": 2,
"train_runtime": 420,
"train_samples_per_second": 6.728,
"eval_accuracy": 0.54,
"train_runtime": 240,
"train_samples_per_second": 14,
"eval_accuracy": 0.6,
"extra_arguments": [
"--bf16",
"--gradient_accumulation_steps 8",
