Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
fix mono conversion and support torch tensor as input
Some checks failed: Secret Leaks / trufflehog (push) was cancelled
This commit is contained in:
parent ad030a85cb
commit c9d463e3b0
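In short, this commit lets the audio pipelines take a torch.Tensor directly (it is moved to CPU and converted to numpy on entry) and fixes the torchcodec mono downmix to use torch.mean over the channel dimension. A minimal usage sketch of the new tensor input; the checkpoint and waveform are illustrative, not taken from this diff, and the audio is assumed to already match the model's sampling rate:

# Sketch only: checkpoint and waveform are illustrative, not from this diff.
import torch
from transformers import pipeline

classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-ks")

# A 1-D torch tensor is now accepted directly; the pipeline converts it
# internally via inputs.cpu().numpy().
waveform = torch.randn(16000)  # one second of noise at 16 kHz
print(classifier(waveform))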
@@ -174,15 +174,22 @@ class AudioClassificationPipeline(Pipeline):
         if isinstance(inputs, bytes):
             inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
 
+        if is_torch_available():
+            import torch
+
+            if isinstance(inputs, torch.Tensor):
+                inputs = inputs.cpu().numpy()
+
         if is_torchcodec_available():
+            import torch
             import torchcodec
 
             if isinstance(inputs, torchcodec.decoders.AudioDecoder):
                 _audio_samples = inputs.get_all_samples()
-                _array = _audio_samples.data
+                _data = _audio_samples.data
                 # to mono
-                _array = np.mean(_array, axis=tuple(range(_array.ndim - 1))) if _array.ndim > 1 else _array
-                inputs = {"array": _array, "sampling_rate": _audio_samples.sample_rate}
+                _data = torch.mean(_data, 0) if _data.ndim > 1 else _data
+                inputs = {"array": _data, "sampling_rate": _audio_samples.sample_rate}
 
         if isinstance(inputs, dict):
             inputs = inputs.copy()  # So we don't mutate the original dictionary outside the pipeline
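The downmix fix above reflects that torchcodec's get_all_samples() returns its samples as a torch tensor shaped (num_channels, num_samples), so averaging channels is a torch.mean over dim 0 rather than an np.mean applied to a torch object. A standalone sketch of the new behavior:

# Standalone sketch of the downmix; shapes follow torchcodec's
# (num_channels, num_samples) layout.
import torch

stereo = torch.randn(2, 16000)  # two channels, one second at 16 kHz
mono = torch.mean(stereo, 0) if stereo.ndim > 1 else stereo
assert mono.shape == (16000,)   # single channel, as the pipeline requires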
@@ -191,7 +198,7 @@ class AudioClassificationPipeline(Pipeline):
             if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                 raise ValueError(
                     "When passing a dictionary to AudioClassificationPipeline, the dict needs to contain a "
-                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
+                    '"raw" key containing the numpy array or torch tensor representing the audio and a "sampling_rate" key, '
                     "containing the sampling_rate associated with that array"
                 )
 
@@ -214,11 +221,13 @@ class AudioClassificationPipeline(Pipeline):
                     )
 
                 inputs = F.resample(
-                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
+                    torch.from_numpy(inputs) if isinstance(inputs, np.ndarray) else inputs,
+                    in_sampling_rate,
+                    self.feature_extractor.sampling_rate,
                 ).numpy()
 
         if not isinstance(inputs, np.ndarray):
-            raise TypeError("We expect a numpy ndarray as input")
+            raise TypeError("We expect a numpy ndarray or torch tensor as input")
         if len(inputs.shape) != 1:
             raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")
 
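The resample call now handles both input types: a numpy array is wrapped with torch.from_numpy, a tensor passes straight through, and .numpy() normalizes the output either way. A standalone sketch with illustrative rates, not taken from the diff (requires torchaudio):

# Standalone sketch of the branching resample; the 8 kHz -> 16 kHz rates are
# illustrative.
import numpy as np
import torch
from torchaudio import functional as F

for inputs in (np.random.randn(8000).astype(np.float32), torch.randn(8000)):
    resampled = F.resample(
        torch.from_numpy(inputs) if isinstance(inputs, np.ndarray) else inputs,
        8000,   # in_sampling_rate
        16000,  # the feature extractor's sampling rate
    ).numpy()
    assert resampled.shape == (16000,)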
@@ -360,15 +360,21 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         stride = None
         extra = {}
 
+        if is_torch_available():
+            import torch
+
+            if isinstance(inputs, torch.Tensor):
+                inputs = inputs.cpu().numpy()
+
         if is_torchcodec_available():
             import torchcodec
 
             if isinstance(inputs, torchcodec.decoders.AudioDecoder):
                 _audio_samples = inputs.get_all_samples()
-                _array = _audio_samples.data
+                _data = _audio_samples.data
                 # to mono
-                _array = np.mean(_array, axis=tuple(range(_array.ndim - 1))) if _array.ndim > 1 else _array
-                inputs = {"array": _array, "sampling_rate": _audio_samples.sample_rate}
+                _data = torch.mean(_data, 0) if _data.ndim > 1 else _data
+                inputs = {"array": _data, "sampling_rate": _audio_samples.sample_rate}
 
         if isinstance(inputs, dict):
             stride = inputs.pop("stride", None)
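The ASR pipeline mirrors the same two entry points. A usage sketch of the tensor-in-dict path; the checkpoint is illustrative, and the 8 kHz rate is chosen so the input exercises the resampling branch (requires torchaudio):

# Sketch only: checkpoint and waveform are illustrative. 8 kHz differs from
# Whisper's 16 kHz feature extractor, so the tensor goes through F.resample
# and comes out of preprocessing as a numpy array.
import torch
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
waveform = torch.randn(8000)  # one second at 8 kHz
print(asr({"raw": waveform, "sampling_rate": 8000})["text"])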
@@ -377,7 +383,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                 raise ValueError(
                     "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
-                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
+                    '"raw" key containing the numpy array or torch tensor representing the audio and a "sampling_rate" key, '
                     "containing the sampling_rate associated with that array"
                 )
 
@@ -399,7 +405,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                     )
 
                 inputs = F.resample(
-                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
+                    torch.from_numpy(inputs) if isinstance(inputs, np.ndarray) else inputs,
+                    in_sampling_rate,
+                    self.feature_extractor.sampling_rate,
                 ).numpy()
                 ratio = self.feature_extractor.sampling_rate / in_sampling_rate
             else:
@@ -414,7 +423,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             # of the original length in the stride so we can cut properly.
             stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
         if not isinstance(inputs, np.ndarray):
-            raise TypeError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
+            raise TypeError(f"We expect a numpy ndarray or torch tensor as input, got `{type(inputs)}`")
         if len(inputs.shape) != 1:
             raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
 
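For context on the last hunk: the stride is expressed in input samples, so after resampling it has to be rescaled by the ratio of the two sampling rates before chunks are cut. A worked sketch with illustrative numbers:

# Worked sketch of the stride rescaling; all numbers are illustrative.
in_sampling_rate = 8000
target_rate = 16000
ratio = target_rate / in_sampling_rate  # 2.0

stride = (400, 400)  # (left, right) overlap in 8 kHz input samples
n_resampled = 16000  # inputs.shape[0] after resampling one second of audio
stride = (n_resampled, int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
print(stride)        # (16000, 800, 800)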