Add Speech AutoModels (#13655)

* upload

* correct

* correct

* correct

* finish

* up

* up

* up again
This commit is contained in:
Patrick von Platen 2021-09-21 08:50:33 +02:00 committed by GitHub
parent ea92136597
commit 48fa42e5d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 91 additions and 8 deletions

View File

@ -142,6 +142,20 @@ AutoModelForAudioClassification
:members:
AutoModelForCTC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoModelForCTC
:members:
AutoModelForSpeechSeq2Seq
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoModelForSpeechSeq2Seq
:members:
AutoModelForObjectDetection
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -557,6 +557,7 @@ if is_torch_available():
"AutoModel",
"AutoModelForAudioClassification",
"AutoModelForCausalLM",
"AutoModelForCTC",
"AutoModelForImageClassification",
"AutoModelForMaskedLM",
"AutoModelForMultipleChoice",
@ -566,6 +567,7 @@ if is_torch_available():
"AutoModelForQuestionAnswering",
"AutoModelForSeq2SeqLM",
"AutoModelForSequenceClassification",
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelWithLMHead",
@ -2320,6 +2322,7 @@ if TYPE_CHECKING:
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
@ -2329,6 +2332,7 @@ if TYPE_CHECKING:
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelWithLMHead,

View File

@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["modeling_auto"] = [
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_CTC_MAPPING",
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
@ -41,6 +42,7 @@ if is_torch_available():
"MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_MAPPING",
@ -48,6 +50,7 @@ if is_torch_available():
"AutoModel",
"AutoModelForAudioClassification",
"AutoModelForCausalLM",
"AutoModelForCTC",
"AutoModelForImageClassification",
"AutoModelForMaskedLM",
"AutoModelForMultipleChoice",
@ -57,6 +60,7 @@ if is_torch_available():
"AutoModelForQuestionAnswering",
"AutoModelForSeq2SeqLM",
"AutoModelForSequenceClassification",
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelWithLMHead",
@ -124,6 +128,7 @@ if TYPE_CHECKING:
from .modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
@ -133,6 +138,7 @@ if TYPE_CHECKING:
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
@ -140,6 +146,7 @@ if TYPE_CHECKING:
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
@ -149,6 +156,7 @@ if TYPE_CHECKING:
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelWithLMHead,

View File

@ -291,6 +291,13 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
]
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
@ -462,6 +469,14 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
[
# Model for Connectionist temporal classification (CTC) mapping
("wav2vec2", "Wav2Vec2ForCTC"),
("hubert", "HubertForCTC"),
]
)
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
@ -493,6 +508,8 @@ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
class AutoModel(_BaseAutoModelClass):
@ -611,6 +628,22 @@ class AutoModelForAudioClassification(_BaseAutoModelClass):
AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")
class AutoModelForCTC(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CTC_MAPPING
AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
AutoModelForSpeechSeq2Seq = auto_class_update(
AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeing"
)
class AutoModelWithLMHead(_AutoModelWithLMHead):
@classmethod
def from_config(cls, config):

View File

@ -90,12 +90,14 @@ if is_torch_available():
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForObjectDetection,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
)
@ -121,9 +123,7 @@ SUPPORTED_TASKS = {
"automatic-speech-recognition": {
"impl": AutomaticSpeechRecognitionPipeline,
"tf": (),
# Only load from `config.architectures`, AutoModelForCTC and AutoModelForConditionalGeneration
# do not exist yet.
"pt": () if is_torch_available() else (),
"pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
"default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
},
"feature-extraction": {

View File

@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Union
import numpy as np
from ..file_utils import is_torch_available
from ..utils import logging
from .base import Pipeline
@ -25,6 +26,9 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
"""
@ -102,6 +106,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
if self.framework == "tf":
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
self.check_model_type(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items())
def __call__(
self,
inputs: Union[np.ndarray, bytes, str],
@ -149,8 +155,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
return processed
def _forward(self, model_inputs):
name = self.model.__class__.__name__
if name.endswith("ForConditionalGeneration") or name.endswith("EncoderDecoderModel"):
model_class = self.model.__class__
if model_class in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
encoder = self.model.get_encoder()
# we need to pass `processed.get("attention_mask")` here since audio encoder
# attention mask length is different from expected text decoder `encoder_attention_mask` length
@ -160,7 +166,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
encoder_outputs=encoder(**model_inputs), attention_mask=model_inputs.get("attention_mask")
)
tokens = tokens.squeeze(0)
elif name.endswith("ForCTC"):
elif model_class in MODEL_FOR_CTC_MAPPING.values():
outputs = self.model(**model_inputs)
tokens = outputs.logits.squeeze(0).argmax(dim=-1)
return tokens

View File

@ -379,6 +379,15 @@ class AutoModelForCausalLM:
requires_backends(cls, ["torch"])
class AutoModelForCTC:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoModelForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@ -460,6 +469,15 @@ class AutoModelForSequenceClassification:
requires_backends(cls, ["torch"])
class AutoModelForSpeechSeq2Seq:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoModelForTableQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

View File

@ -49,10 +49,10 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
@require_torch
def test_torch_small_no_tokenizer_files(self):
# test that model without tokenizer file cannot be loaded
with pytest.raises(ValueError):
with pytest.raises(OSError):
pipeline(
task="automatic-speech-recognition",
model="hf-internal-testing/tiny-random-wav2vec2",
model="patrickvonplaten/tiny-wav2vec2-no-tokenizer",
framework="pt",
)