mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[ASR Pipe] Update init to set model type and subsequently call parent init method (#28486)
* add image processor arg * super * rm args
This commit is contained in:
parent
c662c78c71
commit
0eaa5ea38e
@ -17,11 +17,10 @@ from typing import TYPE_CHECKING, Dict, Optional, Union
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from ..modelcard import ModelCard
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..utils import is_torch_available, is_torchaudio_available, logging
|
||||
from .audio_utils import ffmpeg_read
|
||||
from .base import ArgumentHandler, ChunkPipeline, infer_framework_load_model
|
||||
from .base import ChunkPipeline
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -35,7 +34,7 @@ logger = logging.get_logger(__name__)
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
|
||||
|
||||
|
||||
def rescale_stride(stride, ratio):
|
||||
@ -155,11 +154,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
|
||||
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
|
||||
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
|
||||
feature_extractor ([`SequenceFeatureExtractor`]):
|
||||
The feature extractor that will be used by the pipeline to encode waveform for the model.
|
||||
tokenizer ([`PreTrainedTokenizer`]):
|
||||
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
|
||||
[`PreTrainedTokenizer`].
|
||||
feature_extractor ([`SequenceFeatureExtractor`]):
|
||||
The feature extractor that will be used by the pipeline to encode waveform for the model.
|
||||
decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
|
||||
[PyCTCDecode's
|
||||
BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
|
||||
can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
|
||||
chunk_length_s (`float`, *optional*, defaults to 0):
|
||||
The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default).
|
||||
|
||||
@ -190,10 +193,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
device (Union[`int`, `torch.device`], *optional*):
|
||||
Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive will run the
|
||||
model on the associated CUDA device id.
|
||||
decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
|
||||
[PyCTCDecode's
|
||||
BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
|
||||
can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
|
||||
torch_dtype (Union[`int`, `torch.dtype`], *optional*):
|
||||
The data-type (dtype) of the computation. Setting this to `None` will use float32 precision. Set to
|
||||
`torch.float16` or `torch.bfloat16` to use half-precision in the respective dtypes.
|
||||
|
||||
"""
|
||||
|
||||
@ -203,77 +205,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
feature_extractor: Union["SequenceFeatureExtractor", str] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizer] = None,
|
||||
decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
|
||||
modelcard: Optional[ModelCard] = None,
|
||||
framework: Optional[str] = None,
|
||||
task: str = "",
|
||||
args_parser: ArgumentHandler = None,
|
||||
device: Union[int, "torch.device"] = None,
|
||||
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
|
||||
binary_output: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if framework is None:
|
||||
framework, model = infer_framework_load_model(model, config=model.config)
|
||||
|
||||
self.task = task
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.feature_extractor = feature_extractor
|
||||
self.modelcard = modelcard
|
||||
self.framework = framework
|
||||
|
||||
# `accelerate` device map
|
||||
hf_device_map = getattr(self.model, "hf_device_map", None)
|
||||
|
||||
if hf_device_map is not None and device is not None:
|
||||
raise ValueError(
|
||||
"The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
|
||||
"discard the `device` argument when creating your pipeline object."
|
||||
)
|
||||
|
||||
if self.framework == "tf":
|
||||
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
|
||||
|
||||
# We shouldn't call `model.to()` for models loaded with accelerate
|
||||
if device is not None and not (isinstance(device, int) and device < 0):
|
||||
self.model.to(device)
|
||||
|
||||
if device is None:
|
||||
if hf_device_map is not None:
|
||||
# Take the first device used by `accelerate`.
|
||||
device = next(iter(hf_device_map.values()))
|
||||
else:
|
||||
device = -1
|
||||
|
||||
if is_torch_available() and self.framework == "pt":
|
||||
if isinstance(device, torch.device):
|
||||
self.device = device
|
||||
elif isinstance(device, str):
|
||||
self.device = torch.device(device)
|
||||
elif device < 0:
|
||||
self.device = torch.device("cpu")
|
||||
else:
|
||||
self.device = torch.device(f"cuda:{device}")
|
||||
else:
|
||||
self.device = device if device is not None else -1
|
||||
self.torch_dtype = torch_dtype
|
||||
self.binary_output = binary_output
|
||||
|
||||
# Update config and generation_config with task specific parameters
|
||||
task_specific_params = self.model.config.task_specific_params
|
||||
if task_specific_params is not None and task in task_specific_params:
|
||||
self.model.config.update(task_specific_params.get(task))
|
||||
if self.model.can_generate():
|
||||
self.model.generation_config.update(**task_specific_params.get(task))
|
||||
|
||||
self.call_count = 0
|
||||
self._batch_size = kwargs.pop("batch_size", None)
|
||||
self._num_workers = kwargs.pop("num_workers", None)
|
||||
|
||||
# set the model type so we can check we have the right pre- and post-processing parameters
|
||||
if self.model.config.model_type == "whisper":
|
||||
if model.config.model_type == "whisper":
|
||||
self.type = "seq2seq_whisper"
|
||||
elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
|
||||
elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
|
||||
self.type = "seq2seq"
|
||||
elif (
|
||||
feature_extractor._processor_class
|
||||
@ -285,11 +224,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
else:
|
||||
self.type = "ctc"
|
||||
|
||||
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
|
||||
|
||||
mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
|
||||
mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
|
||||
self.check_model_type(mapping)
|
||||
super().__init__(model, tokenizer, feature_extractor, device=device, torch_dtype=torch_dtype, **kwargs)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user