[ASR Pipe] Update init to set model type and subsequently call parent init method (#28486)

* add image processor arg

* super

* rm args
Sanchit Gandhi 2024-01-18 16:11:49 +00:00 committed by GitHub
parent c662c78c71
commit 0eaa5ea38e
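
For context on what changed for callers: nothing. `device` and `torch_dtype` are still accepted, but are now forwarded to the parent `Pipeline.__init__` instead of being handled locally in the ASR pipeline. A minimal sketch of exercising both arguments through the public `pipeline()` factory (the checkpoint name and audio path are placeholders, not part of this commit):

import torch
from transformers import pipeline

# Placeholder checkpoint and audio file -- any ASR checkpoint works the same way.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    device=0,                   # forwarded to Pipeline.__init__ by this commit
    torch_dtype=torch.float16,  # likewise forwarded; enables half-precision inference
)
print(asr("sample.flac")["text"])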


@@ -17,11 +17,10 @@ from typing import TYPE_CHECKING, Dict, Optional, Union
 import numpy as np
 import requests
 
-from ..modelcard import ModelCard
 from ..tokenization_utils import PreTrainedTokenizer
 from ..utils import is_torch_available, is_torchaudio_available, logging
 from .audio_utils import ffmpeg_read
-from .base import ArgumentHandler, ChunkPipeline, infer_framework_load_model
+from .base import ChunkPipeline
 
 
 if TYPE_CHECKING:
@@ -35,7 +34,7 @@ logger = logging.get_logger(__name__)
 if is_torch_available():
     import torch
 
-    from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
+    from ..models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
 
 
 def rescale_stride(stride, ratio):
@@ -155,11 +154,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
             The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
             [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
+        feature_extractor ([`SequenceFeatureExtractor`]):
+            The feature extractor that will be used by the pipeline to encode waveform for the model.
         tokenizer ([`PreTrainedTokenizer`]):
             The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
             [`PreTrainedTokenizer`].
-        feature_extractor ([`SequenceFeatureExtractor`]):
-            The feature extractor that will be used by the pipeline to encode waveform for the model.
+        decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
+            [PyCTCDecode's
+            BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
+            can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
         chunk_length_s (`float`, *optional*, defaults to 0):
             The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default).
 
@@ -190,10 +193,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         device (Union[`int`, `torch.device`], *optional*):
             Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive will run the
             model on the associated CUDA device id.
-        decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
-            [PyCTCDecode's
-            BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
-            can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
+        torch_dtype (Union[`int`, `torch.dtype`], *optional*):
+            The data-type (dtype) of the computation. Setting this to `None` will use float32 precision. Set to
+            `torch.float16` or `torch.bfloat16` to use half-precision in the respective dtypes.
     """
 
     def __init__(
@@ -203,77 +205,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         feature_extractor: Union["SequenceFeatureExtractor", str] = None,
         tokenizer: Optional[PreTrainedTokenizer] = None,
         decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
-        modelcard: Optional[ModelCard] = None,
-        framework: Optional[str] = None,
-        task: str = "",
-        args_parser: ArgumentHandler = None,
         device: Union[int, "torch.device"] = None,
         torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
-        binary_output: bool = False,
         **kwargs,
     ):
-        if framework is None:
-            framework, model = infer_framework_load_model(model, config=model.config)
-
-        self.task = task
-        self.model = model
-        self.tokenizer = tokenizer
-        self.feature_extractor = feature_extractor
-        self.modelcard = modelcard
-        self.framework = framework
-
-        # `accelerate` device map
-        hf_device_map = getattr(self.model, "hf_device_map", None)
-
-        if hf_device_map is not None and device is not None:
-            raise ValueError(
-                "The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
-                "discard the `device` argument when creating your pipeline object."
-            )
-
-        if self.framework == "tf":
-            raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
-
-        # We shouldn't call `model.to()` for models loaded with accelerate
-        if device is not None and not (isinstance(device, int) and device < 0):
-            self.model.to(device)
-
-        if device is None:
-            if hf_device_map is not None:
-                # Take the first device used by `accelerate`.
-                device = next(iter(hf_device_map.values()))
-            else:
-                device = -1
-
-        if is_torch_available() and self.framework == "pt":
-            if isinstance(device, torch.device):
-                self.device = device
-            elif isinstance(device, str):
-                self.device = torch.device(device)
-            elif device < 0:
-                self.device = torch.device("cpu")
-            else:
-                self.device = torch.device(f"cuda:{device}")
-        else:
-            self.device = device if device is not None else -1
-
-        self.torch_dtype = torch_dtype
-        self.binary_output = binary_output
-
-        # Update config and generation_config with task specific parameters
-        task_specific_params = self.model.config.task_specific_params
-        if task_specific_params is not None and task in task_specific_params:
-            self.model.config.update(task_specific_params.get(task))
-            if self.model.can_generate():
-                self.model.generation_config.update(**task_specific_params.get(task))
-
-        self.call_count = 0
-        self._batch_size = kwargs.pop("batch_size", None)
-        self._num_workers = kwargs.pop("num_workers", None)
         # set the model type so we can check we have the right pre- and post-processing parameters
-        if self.model.config.model_type == "whisper":
+        if model.config.model_type == "whisper":
             self.type = "seq2seq_whisper"
-        elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
+        elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
             self.type = "seq2seq"
         elif (
             feature_extractor._processor_class
@@ -285,11 +224,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         else:
             self.type = "ctc"
 
-        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
-
-        mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
-        mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
-        self.check_model_type(mapping)
+        super().__init__(model, tokenizer, feature_extractor, device=device, torch_dtype=torch_dtype, **kwargs)
 
     def __call__(
         self,
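
Stepping back, the shape of the refactor is: compute the subclass-specific state that the shared setup depends on, then delegate everything else to the parent. The `self.type` assignment has to happen before `super().__init__()` because parameter sanitization, which the last hunk moves from this class into the parent constructor, reads the model type to validate pre- and post-processing kwargs. A minimal, self-contained sketch of that pattern (class and attribute names are illustrative stand-ins, not the library's internals):

from types import SimpleNamespace


class BasePipeline:
    """Stand-in for the shared Pipeline base class (illustrative only)."""

    def __init__(self, model, device=None, torch_dtype=None, **kwargs):
        # Shared setup the subclass used to duplicate: bookkeeping, then
        # parameter sanitization, which may consult subclass state.
        self.model = model
        self.device = device
        self.torch_dtype = torch_dtype
        self._params = self._sanitize_parameters(**kwargs)

    def _sanitize_parameters(self, **kwargs):
        return kwargs


class ASRPipeline(BasePipeline):
    def __init__(self, model, device=None, torch_dtype=None, **kwargs):
        # Set subclass state *before* delegating, mirroring the diff: the
        # parent's _sanitize_parameters may need self.type to validate kwargs.
        self.type = "seq2seq_whisper" if model.config.model_type == "whisper" else "ctc"
        super().__init__(model, device=device, torch_dtype=torch_dtype, **kwargs)


# Tiny dummy model object, just enough for the sketch to run.
dummy = SimpleNamespace(config=SimpleNamespace(model_type="whisper"))
print(ASRPipeline(dummy).type)  # -> seq2seq_whisper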