Commit

add altclip & ast model (#944)
lvyufeng authored Mar 20, 2024
1 parent 2d0f1bf commit c56d826
Showing 25 changed files with 4,581 additions and 5 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -100,6 +100,9 @@ The table below represents the current support in the library for each of those
| Model | Pynative support | Graph Support |
|-------------------------------|---------------------|---------------|
| ALBERT |||
| ALIGN |||
| AltCLIP |||
| Audio Spectrogram Transformer |||
| Autoformer | ✅ (Inference only)||
| BaiChuan |||
| Bark |||
9 changes: 9 additions & 0 deletions mindnlp/transformers/models/__init__.py
@@ -18,7 +18,10 @@
from . import (
albert,
align,
altclip,
audio_spectrogram_transformer,
auto,
autoformer,
bark,
bart,
bert,
@@ -83,7 +86,10 @@

from .albert import *
from .align import *
from .altclip import *
from .audio_spectrogram_transformer import *
from .auto import *
from .autoformer import *
from .bark import *
from .bart import *
from .bert import *
@@ -148,7 +154,10 @@
__all__ = []
__all__.extend(albert.__all__)
__all__.extend(align.__all__)
__all__.extend(altclip.__all__)
__all__.extend(audio_spectrogram_transformer.__all__)
__all__.extend(auto.__all__)
__all__.extend(autoformer.__all__)
__all__.extend(bart.__all__)
__all__.extend(bark.__all__)
__all__.extend(bert.__all__)
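With `altclip` and `audio_spectrogram_transformer` wired into `mindnlp/transformers/models/__init__.py` through `from .altclip import *` and `__all__.extend(...)`, their public names become importable from the `models` package. A minimal sketch of the resulting import surface, using only names exported elsewhere in this diff:

```python
# Both names travel through the __all__ re-export chain added in this commit:
# AltCLIPProcessor from processing_altclip, ASTConfig from the AST config module.
from mindnlp.transformers.models import AltCLIPProcessor, ASTConfig

config = ASTConfig()
print(config.model_type)  # "audio-spectrogram-transformer"
```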
2 changes: 1 addition & 1 deletion mindnlp/transformers/models/align/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
+# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
24 changes: 24 additions & 0 deletions mindnlp/transformers/models/altclip/__init__.py
@@ -0,0 +1,24 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Align model."""
from . import configuration_altclip, modeling_altclip, processing_altclip
from .configuration_altclip import *
from .modeling_altclip import *
from .processing_altclip import *

__all__ = []
__all__.extend(configuration_altclip.__all__)
__all__.extend(modeling_altclip.__all__)
__all__.extend(processing_altclip.__all__)
409 changes: 409 additions & 0 deletions mindnlp/transformers/models/altclip/configuration_altclip.py

Large diffs are not rendered by default.

1,610 changes: 1,610 additions & 0 deletions mindnlp/transformers/models/altclip/modeling_altclip.py

Large diffs are not rendered by default.

134 changes: 134 additions & 0 deletions mindnlp/transformers/models/altclip/processing_altclip.py
@@ -0,0 +1,134 @@
# coding=utf-8
# Copyright 2022 WenXiang ZhongzhiCheng LedellWu LiuGuang BoWenZhang The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Image/Text processor class for AltCLIP
"""
import warnings

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding


class AltCLIPProcessor(ProcessorMixin):
r"""
Constructs an AltCLIP processor, which wraps a CLIP image processor and an XLM-Roberta tokenizer into a single
processor.
[`AltCLIPProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`XLMRobertaTokenizerFast`]. See
the [`~AltCLIPProcessor.__call__`] and [`~AltCLIPProcessor.decode`] for more information.
Args:
image_processor ([`CLIPImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`XLMRobertaTokenizerFast`], *optional*):
The tokenizer is a required input.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "CLIPImageProcessor"
tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")

def __init__(self, image_processor=None, tokenizer=None, **kwargs):
feature_extractor = None
if "feature_extractor" in kwargs:
warnings.warn(
"The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
" instead.",
FutureWarning,
)
feature_extractor = kwargs.pop("feature_extractor")

image_processor = image_processor if image_processor is not None else feature_extractor
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")

super().__init__(image_processor, tokenizer)

def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
"""
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
and `kwargs` arguments to XLMRobertaTokenizerFast's [`~XLMRobertaTokenizerFast.__call__`] if `text` is not
`None` to encode the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""

if text is None and images is None:
raise ValueError("You have to specify either text or images. Both cannot be none.")

if text is not None:
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)

if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)

if text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
if text is not None:
return encoding
return BatchEncoding(data={**image_features}, tensor_type=return_tensors)

def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to XLMRobertaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`].
Please refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)

def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to XLMRobertaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)

@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

__all__ = ["AltCLIPProcessor"]
24 changes: 24 additions & 0 deletions mindnlp/transformers/models/audio_spectrogram_transformer/__init__.py
@@ -0,0 +1,24 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Audio Spectrgram Transformer model."""
from . import configuration_audio_spectrogram_transformer, feature_extraction_audio_spectrogram_transformer, modeling_audio_spectrogram_transformer
from .configuration_audio_spectrogram_transformer import *
from .feature_extraction_audio_spectrogram_transformer import *
from .modeling_audio_spectrogram_transformer import *

__all__ = []
__all__.extend(configuration_audio_spectrogram_transformer.__all__)
__all__.extend(feature_extraction_audio_spectrogram_transformer.__all__)
__all__.extend(modeling_audio_spectrogram_transformer.__all__)
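A similarly hedged sketch for the AST feature extractor re-exported above, assuming it mirrors the upstream `ASTFeatureExtractor` API (16 kHz mono input, padded or truncated to `max_length` frames of `num_mel_bins` mel filterbank features):

```python
import numpy as np

# Class name assumed from upstream transformers; the checkpoint id comes from
# the archive map in the configuration file below.
from mindnlp.transformers.models import ASTFeatureExtractor

extractor = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence at 16 kHz
inputs = extractor(audio, sampling_rate=16000, return_tensors="np")
# Spectrogram padded to (batch, max_length, num_mel_bins) = (1, 1024, 128).
print(inputs["input_values"].shape)
```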
133 changes: 133 additions & 0 deletions mindnlp/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py
@@ -0,0 +1,133 @@
# coding=utf-8
# Copyright 2022 Google AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Audio Spectogram Transformer (AST) model configuration"""


from mindnlp.utils import logging
from ...configuration_utils import PretrainedConfig


logger = logging.get_logger(__name__)

AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"MIT/ast-finetuned-audioset-10-10-0.4593": (
"https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593/resolve/main/config.json"
),
}


class ASTConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`ASTModel`]. It is used to instantiate an AST
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the AST
[MIT/ast-finetuned-audioset-10-10-0.4593](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593)
architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys and values.
frequency_stride (`int`, *optional*, defaults to 10):
Frequency stride to use when patchifying the spectrograms.
time_stride (`int`, *optional*, defaults to 10):
Temporal stride to use when patchifying the spectrograms.
max_length (`int`, *optional*, defaults to 1024):
Temporal dimension of the spectrograms.
num_mel_bins (`int`, *optional*, defaults to 128):
Frequency dimension of the spectrograms (number of Mel-frequency bins).
Example:
```python
>>> from transformers import ASTConfig, ASTModel
>>> # Initializing an AST MIT/ast-finetuned-audioset-10-10-0.4593 style configuration
>>> configuration = ASTConfig()
>>> # Initializing a model (with random weights) from the MIT/ast-finetuned-audioset-10-10-0.4593 style configuration
>>> model = ASTModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "audio-spectrogram-transformer"

def __init__(
self,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
layer_norm_eps=1e-12,
patch_size=16,
qkv_bias=True,
frequency_stride=10,
time_stride=10,
max_length=1024,
num_mel_bins=128,
**kwargs,
):
super().__init__(**kwargs)

self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.patch_size = patch_size
self.qkv_bias = qkv_bias
self.frequency_stride = frequency_stride
self.time_stride = time_stride
self.max_length = max_length
self.num_mel_bins = num_mel_bins

__all__ = [
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"ASTConfig",
]
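To make the stride parameters concrete: AST patchifies a `(max_length, num_mel_bins)` spectrogram with a convolution of kernel size `patch_size` and strides `(frequency_stride, time_stride)`. A small illustrative helper (not part of this diff) that mirrors the shape arithmetic of the upstream embedding layer:

```python
def ast_patch_grid(config):
    """Patch grid AST extracts from one spectrogram, per the standard
    no-padding conv output-size formula."""
    freq_out = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1
    time_out = (config.max_length - config.patch_size) // config.time_stride + 1
    return freq_out, time_out, freq_out * time_out

# Defaults: (128 - 16) // 10 + 1 = 12 frequency positions and
# (1024 - 16) // 10 + 1 = 101 time positions, i.e. 12 * 101 = 1212 patches.
print(ast_patch_grid(ASTConfig()))  # (12, 101, 1212)
```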