fix mono conversion and support torch tensor as input

Quentin Lhoest 2025-07-03 15:01:03 +02:00
parent ad030a85cb
commit c9d463e3b0
2 changed files with 30 additions and 12 deletions
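
In short: torchcodec's AudioDecoder.get_all_samples() returns a torch.Tensor shaped (num_channels, num_samples), so the pipelines now downmix to mono by averaging over the channel dimension with torch.mean instead of the previous np.mean call, and a plain torch.Tensor passed as input is moved to CPU and converted to a numpy array. A minimal sketch of the downmix (illustrative only, not code from this commit):

import torch

def to_mono(data: torch.Tensor) -> torch.Tensor:
    # average over the channel dimension only; 1-D input is already mono
    return torch.mean(data, 0) if data.ndim > 1 else data

stereo = torch.randn(2, 16000)  # hypothetical 1-second stereo clip at 16 kHz
assert to_mono(stereo).shape == (16000,)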


@@ -174,15 +174,22 @@ class AudioClassificationPipeline(Pipeline):
         if isinstance(inputs, bytes):
             inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
 
+        if is_torch_available():
+            import torch
+
+            if isinstance(inputs, torch.Tensor):
+                inputs = inputs.cpu().numpy()
+
         if is_torchcodec_available():
+            import torch
             import torchcodec
 
             if isinstance(inputs, torchcodec.decoders.AudioDecoder):
                 _audio_samples = inputs.get_all_samples()
-                _array = _audio_samples.data
+                _data = _audio_samples.data
                 # to mono
-                _array = np.mean(_array, axis=tuple(range(_array.ndim - 1))) if _array.ndim > 1 else _array
-                inputs = {"array": _array, "sampling_rate": _audio_samples.sample_rate}
+                _data = torch.mean(_data, 0) if _data.ndim > 1 else _data
+                inputs = {"array": _data, "sampling_rate": _audio_samples.sample_rate}
 
         if isinstance(inputs, dict):
             inputs = inputs.copy()  # So we don't mutate the original dictionary outside the pipeline
@@ -191,7 +198,7 @@ class AudioClassificationPipeline(Pipeline):
             if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                 raise ValueError(
                     "When passing a dictionary to AudioClassificationPipeline, the dict needs to contain a "
-                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
+                    '"raw" key containing the numpy array or torch tensor representing the audio and a "sampling_rate" key, '
                     "containing the sampling_rate associated with that array"
                 )
@@ -214,11 +221,13 @@ class AudioClassificationPipeline(Pipeline):
                     )
 
                 inputs = F.resample(
-                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
+                    torch.from_numpy(inputs) if isinstance(inputs, np.ndarray) else inputs,
+                    in_sampling_rate,
+                    self.feature_extractor.sampling_rate,
                 ).numpy()
 
         if not isinstance(inputs, np.ndarray):
-            raise TypeError("We expect a numpy ndarray as input")
+            raise TypeError("We expect a numpy ndarray or torch tensor as input")
         if len(inputs.shape) != 1:
             raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")


@@ -360,15 +360,21 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         stride = None
         extra = {}
+        if is_torch_available():
+            import torch
+
+            if isinstance(inputs, torch.Tensor):
+                inputs = inputs.cpu().numpy()
+
         if is_torchcodec_available():
             import torchcodec
 
             if isinstance(inputs, torchcodec.decoders.AudioDecoder):
                 _audio_samples = inputs.get_all_samples()
-                _array = _audio_samples.data
+                _data = _audio_samples.data
                 # to mono
-                _array = np.mean(_array, axis=tuple(range(_array.ndim - 1))) if _array.ndim > 1 else _array
-                inputs = {"array": _array, "sampling_rate": _audio_samples.sample_rate}
+                _data = torch.mean(_data, 0) if _data.ndim > 1 else _data
+                inputs = {"array": _data, "sampling_rate": _audio_samples.sample_rate}
 
         if isinstance(inputs, dict):
             stride = inputs.pop("stride", None)
@@ -377,7 +383,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                 raise ValueError(
                     "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
-                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
+                    '"raw" key containing the numpy array or torch tensor representing the audio and a "sampling_rate" key, '
                     "containing the sampling_rate associated with that array"
                 )
@@ -399,7 +405,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                     )
 
                 inputs = F.resample(
-                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
+                    torch.from_numpy(inputs) if isinstance(inputs, np.ndarray) else inputs,
+                    in_sampling_rate,
+                    self.feature_extractor.sampling_rate,
                 ).numpy()
                 ratio = self.feature_extractor.sampling_rate / in_sampling_rate
             else:
@@ -414,7 +423,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             # of the original length in the stride so we can cut properly.
             stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
         if not isinstance(inputs, np.ndarray):
-            raise TypeError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
+            raise TypeError(f"We expect a numpy ndarray or torch tensor as input, got `{type(inputs)}`")
         if len(inputs.shape) != 1:
             raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")