[ASR Pipe] Update init to set model type and subsequently call parent init method (#28486)

* add image processor arg

* super

* rm args
This commit is contained in:
Sanchit Gandhi 2024-01-18 16:11:49 +00:00 committed by GitHub
parent c662c78c71
commit 0eaa5ea38e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -17,11 +17,10 @@ from typing import TYPE_CHECKING, Dict, Optional, Union
import numpy as np
import requests
from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import is_torch_available, is_torchaudio_available, logging
from .audio_utils import ffmpeg_read
from .base import ArgumentHandler, ChunkPipeline, infer_framework_load_model
from .base import ChunkPipeline
if TYPE_CHECKING:
@ -35,7 +34,7 @@ logger = logging.get_logger(__name__)
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
from ..models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
def rescale_stride(stride, ratio):
@ -155,11 +154,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
feature_extractor ([`SequenceFeatureExtractor`]):
The feature extractor that will be used by the pipeline to encode waveform for the model.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
[`PreTrainedTokenizer`].
feature_extractor ([`SequenceFeatureExtractor`]):
The feature extractor that will be used by the pipeline to encode waveform for the model.
decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
[PyCTCDecode's
BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
chunk_length_s (`float`, *optional*, defaults to 0):
The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default).
@ -190,10 +193,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
device (Union[`int`, `torch.device`], *optional*):
Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive will run the
model on the associated CUDA device id.
decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
[PyCTCDecode's
BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
torch_dtype (Union[`int`, `torch.dtype`], *optional*):
The data-type (dtype) of the computation. Setting this to `None` will use float32 precision. Set to
`torch.float16` or `torch.bfloat16` to use half-precision in the respective dtypes.
"""
@ -203,77 +205,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
feature_extractor: Union["SequenceFeatureExtractor", str] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
task: str = "",
args_parser: ArgumentHandler = None,
device: Union[int, "torch.device"] = None,
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
binary_output: bool = False,
**kwargs,
):
if framework is None:
framework, model = infer_framework_load_model(model, config=model.config)
self.task = task
self.model = model
self.tokenizer = tokenizer
self.feature_extractor = feature_extractor
self.modelcard = modelcard
self.framework = framework
# `accelerate` device map
hf_device_map = getattr(self.model, "hf_device_map", None)
if hf_device_map is not None and device is not None:
raise ValueError(
"The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
"discard the `device` argument when creating your pipeline object."
)
if self.framework == "tf":
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
# We shouldn't call `model.to()` for models loaded with accelerate
if device is not None and not (isinstance(device, int) and device < 0):
self.model.to(device)
if device is None:
if hf_device_map is not None:
# Take the first device used by `accelerate`.
device = next(iter(hf_device_map.values()))
else:
device = -1
if is_torch_available() and self.framework == "pt":
if isinstance(device, torch.device):
self.device = device
elif isinstance(device, str):
self.device = torch.device(device)
elif device < 0:
self.device = torch.device("cpu")
else:
self.device = torch.device(f"cuda:{device}")
else:
self.device = device if device is not None else -1
self.torch_dtype = torch_dtype
self.binary_output = binary_output
# Update config and generation_config with task specific parameters
task_specific_params = self.model.config.task_specific_params
if task_specific_params is not None and task in task_specific_params:
self.model.config.update(task_specific_params.get(task))
if self.model.can_generate():
self.model.generation_config.update(**task_specific_params.get(task))
self.call_count = 0
self._batch_size = kwargs.pop("batch_size", None)
self._num_workers = kwargs.pop("num_workers", None)
# set the model type so we can check we have the right pre- and post-processing parameters
if self.model.config.model_type == "whisper":
if model.config.model_type == "whisper":
self.type = "seq2seq_whisper"
elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
self.type = "seq2seq"
elif (
feature_extractor._processor_class
@ -285,11 +224,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
else:
self.type = "ctc"
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
self.check_model_type(mapping)
super().__init__(model, tokenizer, feature_extractor, device=device, torch_dtype=torch_dtype, **kwargs)
def __call__(
self,