mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add Speech AutoModels (#13655)
* upload * correct * correct * correct * finish * up * up * up again
This commit is contained in:
parent
ea92136597
commit
48fa42e5d5
@ -142,6 +142,20 @@ AutoModelForAudioClassification
|
||||
:members:
|
||||
|
||||
|
||||
AutoModelForCTC
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.AutoModelForCTC
|
||||
:members:
|
||||
|
||||
|
||||
AutoModelForSpeechSeq2Seq
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.AutoModelForSpeechSeq2Seq
|
||||
:members:
|
||||
|
||||
|
||||
AutoModelForObjectDetection
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@ -557,6 +557,7 @@ if is_torch_available():
|
||||
"AutoModel",
|
||||
"AutoModelForAudioClassification",
|
||||
"AutoModelForCausalLM",
|
||||
"AutoModelForCTC",
|
||||
"AutoModelForImageClassification",
|
||||
"AutoModelForMaskedLM",
|
||||
"AutoModelForMultipleChoice",
|
||||
@ -566,6 +567,7 @@ if is_torch_available():
|
||||
"AutoModelForQuestionAnswering",
|
||||
"AutoModelForSeq2SeqLM",
|
||||
"AutoModelForSequenceClassification",
|
||||
"AutoModelForSpeechSeq2Seq",
|
||||
"AutoModelForTableQuestionAnswering",
|
||||
"AutoModelForTokenClassification",
|
||||
"AutoModelWithLMHead",
|
||||
@ -2320,6 +2322,7 @@ if TYPE_CHECKING:
|
||||
AutoModel,
|
||||
AutoModelForAudioClassification,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForCTC,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForMultipleChoice,
|
||||
@ -2329,6 +2332,7 @@ if TYPE_CHECKING:
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoModelForTableQuestionAnswering,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelWithLMHead,
|
||||
|
@ -32,6 +32,7 @@ if is_torch_available():
|
||||
_import_structure["modeling_auto"] = [
|
||||
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"MODEL_FOR_CTC_MAPPING",
|
||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
@ -41,6 +42,7 @@ if is_torch_available():
|
||||
"MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
|
||||
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
||||
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"MODEL_MAPPING",
|
||||
@ -48,6 +50,7 @@ if is_torch_available():
|
||||
"AutoModel",
|
||||
"AutoModelForAudioClassification",
|
||||
"AutoModelForCausalLM",
|
||||
"AutoModelForCTC",
|
||||
"AutoModelForImageClassification",
|
||||
"AutoModelForMaskedLM",
|
||||
"AutoModelForMultipleChoice",
|
||||
@ -57,6 +60,7 @@ if is_torch_available():
|
||||
"AutoModelForQuestionAnswering",
|
||||
"AutoModelForSeq2SeqLM",
|
||||
"AutoModelForSequenceClassification",
|
||||
"AutoModelForSpeechSeq2Seq",
|
||||
"AutoModelForTableQuestionAnswering",
|
||||
"AutoModelForTokenClassification",
|
||||
"AutoModelWithLMHead",
|
||||
@ -124,6 +128,7 @@ if TYPE_CHECKING:
|
||||
from .modeling_auto import (
|
||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_CTC_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
@ -133,6 +138,7 @@ if TYPE_CHECKING:
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
@ -140,6 +146,7 @@ if TYPE_CHECKING:
|
||||
AutoModel,
|
||||
AutoModelForAudioClassification,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForCTC,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForMultipleChoice,
|
||||
@ -149,6 +156,7 @@ if TYPE_CHECKING:
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoModelForTableQuestionAnswering,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelWithLMHead,
|
||||
|
@ -291,6 +291,13 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
|
||||
("speech_to_text", "Speech2TextForConditionalGeneration"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Sequence Classification mapping
|
||||
@ -462,6 +469,14 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Connectionist temporal classification (CTC) mapping
|
||||
("wav2vec2", "Wav2Vec2ForCTC"),
|
||||
("hubert", "HubertForCTC"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
|
||||
@ -493,6 +508,8 @@ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
|
||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
|
||||
|
||||
|
||||
class AutoModel(_BaseAutoModelClass):
|
||||
@ -611,6 +628,22 @@ class AutoModelForAudioClassification(_BaseAutoModelClass):
|
||||
AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")
|
||||
|
||||
|
||||
class AutoModelForCTC(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_CTC_MAPPING
|
||||
|
||||
|
||||
AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")
|
||||
|
||||
|
||||
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
|
||||
|
||||
|
||||
AutoModelForSpeechSeq2Seq = auto_class_update(
|
||||
AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeing"
|
||||
)
|
||||
|
||||
|
||||
class AutoModelWithLMHead(_AutoModelWithLMHead):
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
|
@ -90,12 +90,14 @@ if is_torch_available():
|
||||
AutoModel,
|
||||
AutoModelForAudioClassification,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForCTC,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForObjectDetection,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoModelForTableQuestionAnswering,
|
||||
AutoModelForTokenClassification,
|
||||
)
|
||||
@ -121,9 +123,7 @@ SUPPORTED_TASKS = {
|
||||
"automatic-speech-recognition": {
|
||||
"impl": AutomaticSpeechRecognitionPipeline,
|
||||
"tf": (),
|
||||
# Only load from `config.architectures`, AutoModelForCTC and AutoModelForConditionalGeneration
|
||||
# do not exist yet.
|
||||
"pt": () if is_torch_available() else (),
|
||||
"pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
|
||||
},
|
||||
"feature-extraction": {
|
||||
|
@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..file_utils import is_torch_available
|
||||
from ..utils import logging
|
||||
from .base import Pipeline
|
||||
|
||||
@ -25,6 +26,9 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_torch_available():
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
|
||||
|
||||
|
||||
def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
|
||||
"""
|
||||
@ -102,6 +106,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
if self.framework == "tf":
|
||||
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
|
||||
|
||||
self.check_model_type(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items())
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[np.ndarray, bytes, str],
|
||||
@ -149,8 +155,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
return processed
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
name = self.model.__class__.__name__
|
||||
if name.endswith("ForConditionalGeneration") or name.endswith("EncoderDecoderModel"):
|
||||
model_class = self.model.__class__
|
||||
if model_class in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
|
||||
encoder = self.model.get_encoder()
|
||||
# we need to pass `processed.get("attention_mask")` here since audio encoder
|
||||
# attention mask length is different from expected text decoder `encoder_attention_mask` length
|
||||
@ -160,7 +166,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
encoder_outputs=encoder(**model_inputs), attention_mask=model_inputs.get("attention_mask")
|
||||
)
|
||||
tokens = tokens.squeeze(0)
|
||||
elif name.endswith("ForCTC"):
|
||||
elif model_class in MODEL_FOR_CTC_MAPPING.values():
|
||||
outputs = self.model(**model_inputs)
|
||||
tokens = outputs.logits.squeeze(0).argmax(dim=-1)
|
||||
return tokens
|
||||
|
@ -379,6 +379,15 @@ class AutoModelForCausalLM:
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForCTC:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForImageClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
@ -460,6 +469,15 @@ class AutoModelForSequenceClassification:
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForSpeechSeq2Seq:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForTableQuestionAnswering:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
@ -49,10 +49,10 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_torch_small_no_tokenizer_files(self):
|
||||
# test that model without tokenizer file cannot be loaded
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(OSError):
|
||||
pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="hf-internal-testing/tiny-random-wav2vec2",
|
||||
model="patrickvonplaten/tiny-wav2vec2-no-tokenizer",
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user