Skip to content

Commit

Permalink
support automatic_speech_recognition pipeline (#934)
Browse files Browse the repository at this point in the history
  • Loading branch information
lvyufeng authored Mar 18, 2024
1 parent c746607 commit 0620b7f
Show file tree
Hide file tree
Showing 19 changed files with 2,283 additions and 8 deletions.
3 changes: 2 additions & 1 deletion .github/pylint.conf
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ disable=raw-checker-failed,
too-few-public-methods,
no-member,
protected-access,
abstract-method
abstract-method,
cyclic-import

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ The table below represents the current support in the library for each of those
| Phi2 |||
| Pop2piano |||
| Qwen2 |||
| Reformer |||
| RegNet | Todo ||
| RoBERTa |||
| RWKV |||
Expand Down
30 changes: 30 additions & 0 deletions llm/pipelines/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import mindspore
from mindnlp.transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset


model_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, ms_dtype=mindspore.float16, low_cpu_mem_usage=True, use_safetensors=True
)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=30,
batch_size=16,
return_timestamps=True,
ms_dtype=mindspore.float16,
)

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
1 change: 1 addition & 0 deletions mindnlp/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
PROCESSOR_NAME = "processor_config.json"

DEFAULT_ROOT = os.path.join(os.getcwd(), ".mindnlp")
# for modelscope models
Expand Down
6 changes: 3 additions & 3 deletions mindnlp/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def _maybe_initialize_input_ids_for_generation(
encoder_outputs = model_kwargs.get("encoder_outputs")
if self.config.is_encoder_decoder and encoder_outputs is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = encoder_outputs.last_hidden_state.size()[:-1]
shape = encoder_outputs.last_hidden_state.shape[:-1]
return ops.ones(shape, dtype=mindspore.int64) * -100

if bos_token_id is None:
Expand Down Expand Up @@ -609,7 +609,7 @@ def _maybe_initialize_input_ids_for_generation(
encoder_outputs = model_kwargs.get("encoder_outputs")
if self.config.is_encoder_decoder and encoder_outputs is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = encoder_outputs.last_hidden_state.size()[:-1]
shape = encoder_outputs.last_hidden_state.shape[:-1]
return ops.ones(shape, dtype=mindspore.int64) * -100

if bos_token_id is None:
Expand Down Expand Up @@ -651,7 +651,7 @@ def _prepare_input_ids_for_generation(
) -> mindspore.Tensor:
if self.config.is_encoder_decoder and encoder_outputs is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = encoder_outputs.last_hidden_state.size()[:-1]
shape = encoder_outputs.last_hidden_state.shape[:-1]
return ops.ones(shape, dtype=mindspore.float32) * -100

if bos_token_id is None:
Expand Down
7 changes: 7 additions & 0 deletions mindnlp/transformers/models/auto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Auto class."""

from mindnlp.utils import (
Expand All @@ -28,6 +29,9 @@
)

from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from .image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor
from .processing_auto import PROCESSOR_MAPPING, AutoProcessor


from .modeling_auto import (
Expand Down Expand Up @@ -106,6 +110,9 @@
'AutoConfig',
'TOKENIZER_MAPPING',
'AutoTokenizer',
"FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor",
"IMAGE_PROCESSOR_MAPPING", "AutoImageProcessor",
"PROCESSOR_MAPPING", "AutoProcessor",
'MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING',
'MODEL_FOR_AUDIO_XVECTOR_MAPPING',
'MODEL_FOR_BACKBONE_MAPPING',
Expand Down
1 change: 1 addition & 0 deletions mindnlp/transformers/models/auto/auto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# pylint: disable=C0116
# pylint: disable=C2801
"""Factory function to build auto-model classes."""
Expand Down
2 changes: 2 additions & 0 deletions mindnlp/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# pylint: disable=C0116
# pylint: disable=C0103
""" Auto Config class."""
Expand Down Expand Up @@ -67,6 +68,7 @@
("starcoder2", "Starcoder2Config"),
('t5', 'T5Config'),
('wav2vec2', 'Wav2Vec2Config'),
("whisper", "WhisperConfig"),
('xlm-roberta', 'XLMRobertaConfig'),
]
)
Expand Down
Loading

0 comments on commit 0620b7f

Please sign in to comment.