Add torchcodec in docstrings/tests for datasets 4.0 (#39156)

* fix dataset run_object_detection

* bump version

* keep same dataset actually

* torchcodec in docstrings and testing utils

* torchcodec in dockerfiles and requirements

* remove duplicate

* add torchocodec to all the remaining docker files

* fix tests

* support torchcodec in audio classification and ASR

* [commit to revert] build ci-dev images

* [commit to revert] trigger circleci

* [commit to revert] build ci-dev images

* fix

* fix modeling_hubert

* backward compatible run_object_detection

* revert ci trigger commits

* fix mono conversion and support torch tensor as input

* revert map_to_array docs + fix it

* revert mono

* nit in docstring

* style

* fix modular

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Quentin Lhoest 2025-07-08 17:06:12 +02:00 committed by GitHub
parent 1255480fd2
commit 1ecd52e50a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
78 changed files with 448 additions and 350 deletions

View File

@ -2,10 +2,10 @@ FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git g++ cmake pkg-config openssh-client git
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git g++ cmake pkg-config openssh-client git ffmpeg
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
RUN uv pip install --no-cache-dir 'torch' 'torchaudio' 'torchvision' --index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir 'torch' 'torchaudio' 'torchvision' 'torchcodec' --index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing]" seqeval albumentations jiwer
RUN uv pip uninstall transformers

View File

@ -2,10 +2,10 @@ FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git pkg-config openssh-client git
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git pkg-config openssh-client git ffmpeg
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
RUN uv pip install --no-cache-dir 'torch' 'torchaudio' 'torchvision' --index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir 'torch' 'torchaudio' 'torchvision' 'torchcodec' --index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing]"
RUN uv pip uninstall transformers

View File

@ -2,10 +2,10 @@ FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git g++ cmake pkg-config openssh-client git git-lfs
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git g++ cmake pkg-config openssh-client git git-lfs ffmpeg
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
RUN uv pip install --no-cache-dir 'torch' 'torchaudio' 'torchvision' --index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir 'torch' 'torchaudio' 'torchvision' 'torchcodec' --index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing,tiktoken,num2words,video]"
RUN uv pip uninstall transformers

View File

@ -26,7 +26,7 @@ RUN git clone https://github.com/huggingface/transformers && cd transformers &&
# 1. Put several commands in a single `RUN` to avoid image/layer exporting issue. Could be revised in the future.
# 2. Regarding `torch` part, We might need to specify proper versions for `torchvision` and `torchaudio`.
# Currently, let's not bother to specify their versions explicitly (so installed with their latest release versions).
RUN python3 -m pip install --no-cache-dir -e ./transformers[dev,onnxruntime] && [ ${#PYTORCH} -gt 0 -a "$PYTORCH" != "pre" ] && VERSION='torch=='$PYTORCH'.*' || VERSION='torch'; echo "export VERSION='$VERSION'" >> ~/.profile && echo torch=$VERSION && [ "$PYTORCH" != "pre" ] && python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA || python3 -m pip install --no-cache-dir -U --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/$CUDA && python3 -m pip uninstall -y tensorflow tensorflow_text tensorflow_probability
RUN python3 -m pip install --no-cache-dir -e ./transformers[dev,onnxruntime] && [ ${#PYTORCH} -gt 0 -a "$PYTORCH" != "pre" ] && VERSION='torch=='$PYTORCH'.*' || VERSION='torch'; echo "export VERSION='$VERSION'" >> ~/.profile && echo torch=$VERSION && [ "$PYTORCH" != "pre" ] && python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/$CUDA || python3 -m pip install --no-cache-dir -U --pre torch torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/nightly/$CUDA && python3 -m pip uninstall -y tensorflow tensorflow_text tensorflow_probability
RUN python3 -m pip uninstall -y flax jax

View File

@ -21,7 +21,7 @@ RUN python3 -m pip install --no-cache-dir './transformers[deepspeed-testing]' 'p
# Install latest release PyTorch
# (PyTorch must be installed before pre-compiling any DeepSpeed c++/cuda ops.)
# (https://www.deepspeed.ai/tutorials/advanced-install/#pre-install-deepspeed-ops)
RUN python3 -m pip uninstall -y torch torchvision torchaudio && python3 -m pip install --no-cache-dir -U torch==$PYTORCH torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA
RUN python3 -m pip uninstall -y torch torchvision torchaudio && python3 -m pip install --no-cache-dir -U torch==$PYTORCH torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/$CUDA
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate

View File

@ -19,7 +19,7 @@ RUN python3 -m pip uninstall -y torch torchvision torchaudio
# Install **nightly** release PyTorch (flag `--pre`)
# (PyTorch must be installed before pre-compiling any DeepSpeed c++/cuda ops.)
# (https://www.deepspeed.ai/tutorials/advanced-install/#pre-install-deepspeed-ops)
RUN python3 -m pip install --no-cache-dir -U --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/$CUDA
RUN python3 -m pip install --no-cache-dir -U --pre torch torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/nightly/$CUDA
# `datasets` requires pandas, pandas has some modules compiled with numpy=1.x causing errors
RUN python3 -m pip install --no-cache-dir './transformers[deepspeed-testing]' 'pandas<2' 'numpy<2'

View File

@ -26,7 +26,7 @@ RUN [ ${#PYTORCH} -gt 0 ] && VERSION='torch=='$PYTORCH'.*' || VERSION='torch';
RUN echo torch=$VERSION
# `torchvision` and `torchaudio` should be installed along with `torch`, especially for nightly build.
# Currently, let's just use their latest releases (when `torch` is installed with a release version)
RUN python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA
RUN python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/$CUDA
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate

View File

@ -61,19 +61,16 @@ predicted token ids.
- Step-by-step Speech Translation
```python
>>> import torch
>>> from transformers import Speech2Text2Processor, SpeechEncoderDecoderModel
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> model = SpeechEncoderDecoderModel.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
>>> processor = Speech2Text2Processor.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> def map_to_array(example):
... example["speech"] = example["audio"]["array"]
... return example
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

View File

@ -172,9 +172,9 @@ Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower
>>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
>>> def map_to_array(batch):
... batch["speech"] = batch["audio"]["array"]
... return batch
>>> def map_to_array(example):
... example["speech"] = example["audio"]["array"]
... return example
>>> # prepare speech data for batch inference

View File

@ -22,6 +22,7 @@ protobuf
torch
torchvision
torchaudio
torchcodec
jiwer
librosa
evaluate >= 0.2.0

View File

@ -1,5 +1,5 @@
albumentations >= 1.4.16
timm
datasets
datasets>=4.0
torchmetrics
pycocotools

View File

@ -399,7 +399,10 @@ def main():
dataset["validation"] = split["test"]
# Get dataset categories and prepare mappings for label_name <-> label_id
categories = dataset["train"].features["objects"].feature["category"].names
if isinstance(dataset["train"].features["objects"], dict):
categories = dataset["train"].features["objects"]["category"].feature.names
else: # (for old versions of `datasets` that used Sequence({...}) of the objects)
categories = dataset["train"].features["objects"].feature["category"].names
id2label = dict(enumerate(categories))
label2id = {v: k for k, v in id2label.items()}

View File

@ -460,7 +460,10 @@ def main():
dataset["validation"] = split["test"]
# Get dataset categories and prepare mappings for label_name <-> label_id
categories = dataset["train"].features["objects"].feature["category"].names
if isinstance(dataset["train"].features["objects"], dict):
categories = dataset["train"].features["objects"]["category"].feature.names
else: # (for old versions of `datasets` that used Sequence({...}) of the objects)
categories = dataset["train"].features["objects"].feature["category"].names
id2label = dict(enumerate(categories))
label2id = {v: k for k, v in id2label.items()}

View File

@ -435,10 +435,12 @@ class ASTModel(ASTPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, max_length, num_mel_bins)`):
Float values mel features extracted from the raw audio waveform. Raw audio waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~ASTFeatureExtractor.__call__`]
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a
`torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or the soundfile library
(`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the
mel features, padding and conversion into a tensor of type `torch.FloatTensor`.
See [`~ASTFeatureExtractor.__call__`]
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -525,10 +527,11 @@ class ASTForAudioClassification(ASTPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, max_length, num_mel_bins)`):
Float values mel features extracted from the raw audio waveform. Raw audio waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~ASTFeatureExtractor.__call__`]
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the
mel features, padding and conversion into a tensor of type `torch.FloatTensor`.
See [`~ASTFeatureExtractor.__call__`]
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the audio classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -1653,14 +1653,15 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel, GenerationMixin):
>>> text = "This is an example text."
>>> ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
>>> _, audio, sr = ds.sort("id").select(range(1))[:1]["audio"][0].values()
>>> audio = ds.sort("id")["audio"][0]
>>> audio_sample, sr = audio["array"], audio["sampling_rate"]
>>> # Define processor and model
>>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
>>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")
>>> # Generate processor output and model output
>>> processor_output = processor(raw_speech=audio, sampling_rate=sr, text=text, return_tensors="pt")
>>> processor_output = processor(raw_speech=audio_sample, sampling_rate=sr, text=text, return_tensors="pt")
>>> speech_embeds = model.get_speech_features(
... input_ids=processor_output["input_ids"], input_features=processor_output["input_features"]
... )
@ -1732,14 +1733,15 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel, GenerationMixin):
>>> ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
>>> _, audio, sr = ds.sort("id").select(range(1))[:1]["audio"][0].values()
>>> audio = ds.sort("id")["audio"][0]
>>> audio_sample, sr = audio["array"], audio["sampling_rate"]
>>> # Define processor and model
>>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
>>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")
>>> # processor outputs and model outputs
>>> processor_output = processor(raw_speech=audio, sampling_rate=sr, text=text, return_tensors="pt")
>>> processor_output = processor(raw_speech=audio_sample, sampling_rate=sr, text=text, return_tensors="pt")
>>> outputs = model(
... input_ids=processor_output["input_ids"],
... input_features=processor_output["input_features"],

View File

@ -1022,9 +1022,10 @@ class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -1136,9 +1137,10 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -1318,9 +1320,10 @@ class Data2VecAudioForXVector(Data2VecAudioPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -954,16 +954,14 @@ class HubertModel(HubertPreTrainedModel):
```python
>>> from transformers import AutoProcessor, HubertModel
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
>>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> def map_to_array(example):
... example["speech"] = example["audio"]["array"]
... return example
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
@ -1230,9 +1228,10 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`HubertProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`HubertProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -1459,16 +1459,14 @@ class TFHubertModel(TFHubertPreTrainedModel):
```python
>>> from transformers import AutoProcessor, TFHubertModel
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
>>> model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> def map_to_array(example):
... example["speech"] = example["audio"]["array"]
... return example
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
@ -1571,16 +1569,14 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
>>> import tensorflow as tf
>>> from transformers import AutoProcessor, TFHubertForCTC
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
>>> model = TFHubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> def map_to_array(example):
... example["speech"] = example["audio"]["array"]
... return example
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

View File

@ -239,16 +239,14 @@ class HubertModel(Wav2Vec2Model, HubertPreTrainedModel):
```python
>>> from transformers import AutoProcessor, HubertModel
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
>>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> def map_to_array(example):
... example["speech"] = example["audio"]["array"]
... return example
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

View File

@ -540,8 +540,9 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
Float values of the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -922,8 +923,9 @@ class MoonshineModel(MoonshinePreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
Float values of the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
Example:
@ -1039,8 +1041,9 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
Float values of the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.

View File

@ -575,8 +575,9 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
Float values of the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -751,8 +752,9 @@ class MoonshineModel(WhisperModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
Float values of the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
Example:
@ -852,8 +854,9 @@ class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixi
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
Float values of the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.

View File

@ -810,8 +810,9 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
r"""
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
@ -1830,10 +1831,11 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, feature_sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the
mel features, padding and conversion into a tensor of type `torch.FloatTensor`.
See [`~WhisperFeatureExtractor.__call__`]
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size), *optional*):
The tensors corresponding to the input videos. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`NewTaskModelProcessor`] uses

View File

@ -1795,8 +1795,9 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
r"""
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
@ -2276,10 +2277,11 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, feature_sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the
mel features, padding and conversion into a tensor of type `torch.FloatTensor`.
See [`~WhisperFeatureExtractor.__call__`]
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size), *optional*):
The tensors corresponding to the input videos. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`NewTaskModelProcessor`] uses

View File

@ -362,8 +362,9 @@ class Qwen2AudioEncoder(Qwen2AudioPreTrainedModel):
Args:
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`torch.Tensor`)`, *optional*):
@ -742,10 +743,11 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, feature_sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the
mel features, padding and conversion into a tensor of type `torch.FloatTensor`.
See [`~WhisperFeatureExtractor.__call__`]
feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:

View File

@ -1046,9 +1046,10 @@ class SEWForSequenceClassification(SEWPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`SEWProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`SEWProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -1597,9 +1597,10 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`SEWDProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`SEWDProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -86,8 +86,9 @@ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Args:
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
Float values of input raw speech waveform or speech features. Values can be obtained by loading a `.flac`
or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile
library (`pip install soundfile`). To prepare the array into `inputs`, either the [`Wav2Vec2Processor`] or
or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.*
via the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `inputs`, either the [`Wav2Vec2Processor`] or
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
`torch.FloatTensor`.
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
@ -128,10 +129,10 @@ SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
Args:
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
or *.wav* audio file into an array of type *list[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
*torch.FloatTensor*.
or *.wav* audio file into an array of type *list[float]* or a *numpy.ndarray*, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or [`Speech2TextProcessor`] should be used
for padding and conversion into a tensor of type *torch.FloatTensor*.
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

View File

@ -339,8 +339,9 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin):
r"""
inputs (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
Float values of input raw speech waveform or speech features. Values can be obtained by loading a `.flac`
or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile
library (`pip install soundfile`). To prepare the array into `inputs`, either the [`Wav2Vec2Processor`] or
or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.*
via the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `inputs`, either the [`Wav2Vec2Processor`] or
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
`torch.FloatTensor`.
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@ -369,15 +370,17 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin):
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
into an array of type *list[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
and conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
into an array of type *list[float]* or a *numpy.ndarray*, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding and conversion
into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`, *optional*):
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.*
via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`Speech2TextFeatureExtractor`] should be used for extracting the fbank features, padding and conversion
into a tensor of type `torch.FloatTensor`. See [`~Speech2TextFeatureExtractor.__call__`]
by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.*
via the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`Speech2TextFeatureExtractor`] should be used for extracting
the fbank features, padding and conversion into a tensor of type `torch.FloatTensor`.
See [`~Speech2TextFeatureExtractor.__call__`]
Examples:

View File

@ -619,8 +619,9 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
Args:
input_features (`torch.LongTensor` of shape `(batch_size, sequence_length, feature_size)`):
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the fbank features,
padding and conversion into a tensor of type `torch.FloatTensor`. See
[`~Speech2TextFeatureExtractor.__call__`]
@ -1096,10 +1097,12 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.*
via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the fbank features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~Speech2TextFeatureExtractor.__call__`]
by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a
`torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or the soundfile library
(`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting
the fbank features, padding and conversion into a tensor of type `torch.FloatTensor`.
See [`~Speech2TextFeatureExtractor.__call__`]
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
@ -1258,10 +1261,12 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel, Generation
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.*
via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the fbank features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~Speech2TextFeatureExtractor.__call__`]
by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a
`torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or the soundfile library
(`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting
the fbank features, padding and conversion into a tensor of type `torch.FloatTensor`.
See [`~Speech2TextFeatureExtractor.__call__`]
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.

View File

@ -691,10 +691,12 @@ SPEECH_TO_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_features (`tf.Tensor` of shape `(batch_size, sequence_length, feature_size)`):
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.*
via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the fbank features, padding and conversion into a
tensor of floats. See [`~Speech2TextFeatureExtractor.__call__`]
by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray or a
`torch.Tensor``, *e.g.* via the torchcodec library (`pip install torchcodec`) or the soundfile library
(`pip install soundfile`).
To prepare the arrayinto `input_features`, the [`AutoFeatureExtractor`] should be used for extracting
the fbank features, padding and conversion into a tensor of floats.
See [`~Speech2TextFeatureExtractor.__call__`]
attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@ -847,8 +849,9 @@ class TFSpeech2TextEncoder(keras.layers.Layer):
Args:
input_features (`tf.Tensor` of shape `(batch_size, sequence_length, feature_size)`):
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the fbank features,
padding and conversion into a tensor of floats. See [`~Speech2TextFeatureExtractor.__call__`]
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1469,7 +1472,6 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
>>> import tensorflow as tf
>>> from transformers import Speech2TextProcessor, TFSpeech2TextForConditionalGeneration
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> model = TFSpeech2TextForConditionalGeneration.from_pretrained(
... "facebook/s2t-small-librispeech-asr", from_pt=True
@ -1477,10 +1479,9 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
>>> processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> def map_to_array(example):
... example["speech"] = example["audio"]["array"]
... return example
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

View File

@ -2156,8 +2156,9 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (*pip install
soundfile*). To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
@ -2841,9 +2842,10 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (*pip install
soundfile*). To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding and conversion into
a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`):
Float values of input mel spectrogram.
@ -2966,10 +2968,11 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform.
Values can be obtained by loading a *.flac* or *.wav* audio file into an array of type `list[float]` or
a `numpy.ndarray`, *e.g.* via the soundfile library (*pip install soundfile*). To prepare the array
into `input_values`, the [`SpeechT5Processor`] should be used for padding and conversion into a tensor
of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
Values can be obtained by loading a *.flac* or *.wav* audio file into an array of type `list[float]`,
a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`)
or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
Tensor containing the speaker embeddings.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

View File

@ -1455,9 +1455,10 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`UniSpeechProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`UniSpeechProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -1450,9 +1450,10 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`UniSpeechSatProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`UniSpeechSatProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -1564,9 +1565,10 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`UniSpeechSatProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`UniSpeechSatProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -1746,9 +1748,10 @@ class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`UniSpeechSatProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`UniSpeechSatProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -255,9 +255,10 @@ WAV2VEC2_INPUTS_DOCSTRING = r"""
Args:
input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `jnp.ndarray`. See [`Wav2Vec2Processor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `jnp.ndarray`. See [`Wav2Vec2Processor.__call__`] for details.
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1]`:
@ -1064,16 +1065,14 @@ FLAX_WAV2VEC2_MODEL_DOCSTRING = """
```python
>>> from transformers import AutoProcessor, FlaxWav2Vec2Model
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-lv60")
>>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> def map_to_array(example):
... example["speech"] = example["audio"]["array"]
... return example
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
@ -1183,16 +1182,14 @@ FLAX_WAV2VEC2_FOR_CTC_DOCSTRING = """
>>> import jax.numpy as jnp
>>> from transformers import AutoProcessor, FlaxWav2Vec2ForCTC
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-960h-lv60")
>>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> def map_to_array(example):
... example["speech"] = example["audio"]["array"]
... return example
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
@ -1384,16 +1381,14 @@ FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """
>>> from transformers import AutoFeatureExtractor, FlaxWav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-large-lv60")
>>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> def map_to_array(example):
... example["speech"] = example["audio"]["array"]
... return example
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

View File

@ -1530,16 +1530,14 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
```python
>>> from transformers import AutoProcessor, TFWav2Vec2Model
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> def map_to_array(example):
... example["speech"] = example["audio"]["array"]
... return example
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
@ -1642,16 +1640,15 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
>>> import tensorflow as tf
>>> from transformers import AutoProcessor, TFWav2Vec2ForCTC
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> from torchcodec.decoders import AudioDecoder
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> def map_to_array(example):
... example["speech"] = example["audio"]["array"]
... return example
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

View File

@ -1981,9 +1981,10 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -2095,9 +2096,10 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -2277,9 +2279,10 @@ class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -997,9 +997,10 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
masked extracted features in *config.proj_codevector_dim* space.
@ -1094,9 +1095,10 @@ class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel):
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
@ -1205,9 +1207,10 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -1300,9 +1303,10 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel):
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -1463,9 +1467,10 @@ class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel):
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -708,9 +708,10 @@ class Wav2Vec2BertModel(Wav2Vec2Model, Wav2Vec2BertPreTrainedModel):
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
masked extracted features in *config.proj_codevector_dim* space.
@ -779,9 +780,10 @@ class Wav2Vec2BertForCTC(Wav2Vec2ConformerForCTC):
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
@ -875,9 +877,10 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2ForSequenceClassification):
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -950,9 +953,10 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2ConformerForAudioFrameClas
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -1016,9 +1020,10 @@ class Wav2Vec2BertForXVector(Wav2Vec2ConformerForXVector):
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -1579,9 +1579,10 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -1681,9 +1682,10 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -1851,9 +1853,10 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2ConformerProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -1328,9 +1328,10 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -1442,9 +1443,10 @@ class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@ -1624,9 +1626,10 @@ class WavLMForXVector(WavLMPreTrainedModel):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -429,10 +429,11 @@ class WhisperGenerationMixin(GenerationMixin):
Parameters:
input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`,
*e.g.* via the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel
features, padding and conversion into a tensor of type `torch.FloatTensor`.
See [`~WhisperFeatureExtractor.__call__`] for details.
generation_config ([`~generation.GenerationConfig`], *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
@ -1598,7 +1599,7 @@ class WhisperGenerationMixin(GenerationMixin):
Parameters:
input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.

View File

@ -101,10 +101,12 @@ WHISPER_INPUTS_DOCSTRING = r"""
Args:
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a
`torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or the soundfile library
(`pip install soundfile`).
To prepare the array into `input_features`, the [`WhisperFeatureExtractor`] should be used for extracting
the features, padding and conversion into a tensor of type `numpy.ndarray`.
See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
is not used. By default the silence in the input log mel spectrogram are ignored.
@ -138,10 +140,11 @@ WHISPER_ENCODE_INPUTS_DOCSTRING = r"""
Args:
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`].
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`WhisperFeatureExtractor`] should be used for extracting
the mel features, padding and conversion into a tensor of type `numpy.ndarray`.
See [`~WhisperFeatureExtractor.__call__`].
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
is not used. By default the silence in the input log mel spectrogram are ignored.

View File

@ -600,10 +600,12 @@ WHISPER_INPUTS_DOCSTRING = r"""
Args:
input_features (`tf.Tensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.*
via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the fbank features, padding and conversion into a
tensor of type `tf.Tensor`. See [`~WhisperFeatureExtractor.__call__`]
by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a
`torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or the soundfile library
(`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the
fbank features, padding and conversion into a tensor of type `tf.Tensor`.
See [`~WhisperFeatureExtractor.__call__`]
decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
@ -728,8 +730,9 @@ class TFWhisperEncoder(keras.layers.Layer):
Args:
input_features (`tf.Tensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the fbank features,
padding and conversion into a tensor of type `tf.Tensor`. See [`~WhisperFeatureExtractor.__call__`]
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):

View File

@ -650,8 +650,9 @@ class WhisperEncoder(WhisperPreTrainedModel):
Args:
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec libary (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`torch.Tensor`)`, *optional*):
@ -1096,10 +1097,11 @@ class WhisperModel(WhisperPreTrainedModel):
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the
mel features, padding and conversion into a tensor of type `torch.FloatTensor`.
See [`~WhisperFeatureExtractor.__call__`]
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
@ -1266,10 +1268,11 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the
mel features, padding and conversion into a tensor of type `torch.FloatTensor`.
See [`~WhisperFeatureExtractor.__call__`]
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
@ -1600,10 +1603,11 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
r"""
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the
mel features, padding and conversion into a tensor of type `torch.FloatTensor`.
See [`~WhisperFeatureExtractor.__call__`]
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@ -17,7 +17,7 @@ from typing import Any, Union
import numpy as np
import requests
from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, logging
from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, is_torchcodec_available, logging
from .base import Pipeline, build_pipeline_init_args
@ -174,6 +174,21 @@ class AudioClassificationPipeline(Pipeline):
if isinstance(inputs, bytes):
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
if is_torch_available():
import torch
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu().numpy()
if is_torchcodec_available():
import torch
import torchcodec
if isinstance(inputs, torchcodec.decoders.AudioDecoder):
_audio_samples = inputs.get_all_samples()
_array = _audio_samples.data
inputs = {"array": _array, "sampling_rate": _audio_samples.sample_rate}
if isinstance(inputs, dict):
inputs = inputs.copy() # So we don't mutate the original dictionary outside the pipeline
# Accepting `"array"` which is the key defined in `datasets` for
@ -181,7 +196,7 @@ class AudioClassificationPipeline(Pipeline):
if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
raise ValueError(
"When passing a dictionary to AudioClassificationPipeline, the dict needs to contain a "
'"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
'"raw" key containing the numpy array or torch tensor representing the audio and a "sampling_rate" key, '
"containing the sampling_rate associated with that array"
)
@ -204,11 +219,13 @@ class AudioClassificationPipeline(Pipeline):
)
inputs = F.resample(
torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
torch.from_numpy(inputs) if isinstance(inputs, np.ndarray) else inputs,
in_sampling_rate,
self.feature_extractor.sampling_rate,
).numpy()
if not isinstance(inputs, np.ndarray):
raise TypeError("We expect a numpy ndarray as input")
raise TypeError("We expect a numpy ndarray or torch tensor as input")
if len(inputs.shape) != 1:
raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")

View File

@ -19,7 +19,7 @@ import requests
from ..generation import GenerationConfig
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import is_torch_available, is_torchaudio_available, logging
from ..utils import is_torch_available, is_torchaudio_available, is_torchcodec_available, logging
from .audio_utils import ffmpeg_read
from .base import ChunkPipeline
@ -364,6 +364,21 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
stride = None
extra = {}
if is_torch_available():
import torch
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu().numpy()
if is_torchcodec_available():
import torchcodec
if isinstance(inputs, torchcodec.decoders.AudioDecoder):
_audio_samples = inputs.get_all_samples()
_array = _audio_samples.data
inputs = {"array": _array, "sampling_rate": _audio_samples.sample_rate}
if isinstance(inputs, dict):
stride = inputs.pop("stride", None)
# Accepting `"array"` which is the key defined in `datasets` for
@ -371,7 +386,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
raise ValueError(
"When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
'"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
'"raw" key containing the numpy array or torch tensor representing the audio and a "sampling_rate" key, '
"containing the sampling_rate associated with that array"
)
@ -393,7 +408,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
)
inputs = F.resample(
torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
torch.from_numpy(inputs) if isinstance(inputs, np.ndarray) else inputs,
in_sampling_rate,
in_sampling_rate,
self.feature_extractor.sampling_rate,
).numpy()
ratio = self.feature_extractor.sampling_rate / in_sampling_rate
else:
@ -408,7 +426,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# of the original length in the stride so we can cut properly.
stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
if not isinstance(inputs, np.ndarray):
raise TypeError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
raise TypeError(f"We expect a numpy ndarray or torch tensor as input, got `{type(inputs)}`")
if len(inputs.shape) != 1:
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

View File

@ -130,7 +130,6 @@ from .utils import (
is_scipy_available,
is_sentencepiece_available,
is_seqio_available,
is_soundfile_available,
is_spacy_available,
is_speech_available,
is_spqr_available,
@ -656,7 +655,7 @@ def require_torchcodec(test_case):
These tests are skipped when Torchcodec isn't installed.
"""
return unittest.skipUnless(is_torchcodec_available(), "test requires Torchvision")(test_case)
return unittest.skipUnless(is_torchcodec_available(), "test requires Torchcodec")(test_case)
def require_torch_or_tf(test_case):
@ -1268,16 +1267,6 @@ def require_clearml(test_case):
return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case)
def require_soundfile(test_case):
"""
Decorator marking a test that requires soundfile
These tests are skipped when soundfile isn't installed.
"""
return unittest.skipUnless(is_soundfile_available(), "test requires soundfile")(test_case)
def require_deepspeed(test_case):
"""
Decorator marking a test that requires deepspeed

View File

@ -248,9 +248,10 @@ class ModelArgs:
input_values = {
"description": """
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`{processor_class}.__call__`] for details.
into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
(`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
into a tensor of type `torch.FloatTensor`. See [`{processor_class}.__call__`] for details.
""",
"shape": "of shape `(batch_size, sequence_length)`",
}

View File

@ -154,7 +154,7 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]

View File

@ -165,7 +165,7 @@ class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]

View File

@ -215,7 +215,7 @@ class ClvpFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=22050))
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples], [x["sampling_rate"] for x in speech_samples]

View File

@ -373,10 +373,12 @@ class ClvpModelForConditionalGenerationTester:
ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
_, audio, sr = ds.sort("id").select(range(1))[:1]["audio"][0].values()
audio = ds.sort("id")[0]["audio"]
audio_sample = audio["array"]
sr = audio["sampling_rate"]
feature_extractor = ClvpFeatureExtractor()
input_features = feature_extractor(raw_speech=audio, sampling_rate=sr, return_tensors="pt")[
input_features = feature_extractor(raw_speech=audio_sample, sampling_rate=sr, return_tensors="pt")[
"input_features"
].to(torch_device)
@ -562,7 +564,8 @@ class ClvpIntegrationTest(unittest.TestCase):
self.text = "This is an example text."
ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
_, self.speech_samples, self.sr = ds.sort("id").select(range(1))[:1]["audio"][0].values()
audio = ds.sort("id")["audio"][0]
self.speech_samples, self.sr = audio["array"], audio["sampling_rate"]
self.model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev").to(torch_device)
self.model.eval()

View File

@ -143,7 +143,7 @@ class DacFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
audio_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
audio_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in audio_samples]

View File

@ -21,7 +21,7 @@ from datasets import load_dataset
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import Data2VecAudioConfig, is_torch_available
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from transformers.testing_utils import require_torch, require_torchcodec, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init
@ -656,7 +656,7 @@ class Data2VecAudioUtilsTest(unittest.TestCase):
@require_torch
@require_soundfile
@require_torchcodec
@slow
class Data2VecAudioModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):

View File

@ -145,7 +145,7 @@ class DiaFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
audio_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
audio_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in audio_samples]

View File

@ -665,8 +665,12 @@ class DiaForConditionalGenerationIntegrationTest(unittest.TestCase):
@require_torch_accelerator
def test_dia_model_integration_generate_audio_context(self):
text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
audio_sample_1 = torchaudio.load(self.audio_prompt_1_path, channels_first=True)[0].squeeze().numpy()
audio_sample_2 = torchaudio.load(self.audio_prompt_2_path, channels_first=True)[0].squeeze().numpy()
audio_sample_1 = (
torchaudio.load(self.audio_prompt_1_path, channels_first=True, backend="soundfile")[0].squeeze().numpy()
)
audio_sample_2 = (
torchaudio.load(self.audio_prompt_2_path, channels_first=True, backend="soundfile")[0].squeeze().numpy()
)
audio = [audio_sample_1, audio_sample_2]
processor = DiaProcessor.from_pretrained(self.model_checkpoint)

View File

@ -139,7 +139,7 @@ class EnCodecFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
audio_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
audio_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in audio_samples]

View File

@ -340,7 +340,7 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]

View File

@ -22,7 +22,7 @@ import unittest
import pytest
from transformers import HubertConfig, is_torch_available
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from transformers.testing_utils import require_torch, require_torchcodec, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
@ -750,7 +750,7 @@ class HubertUtilsTest(unittest.TestCase):
@require_torch
@require_soundfile
@require_torchcodec
@slow
class HubertModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):

View File

@ -713,7 +713,7 @@ class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCa
def _load_datasamples(self, num_samples):
self._load_dataset()
ds = self._dataset
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
@slow

View File

@ -443,7 +443,7 @@ class MoonshineModelIntegrationTests(unittest.TestCase):
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]

View File

@ -207,7 +207,7 @@ class Phi4MultimodalFeatureExtractionTest(SequenceFeatureExtractionTestMixin, un
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import requests
@ -33,13 +32,13 @@ from transformers import (
from transformers.testing_utils import (
Expectations,
cleanup,
require_soundfile,
require_torch,
require_torch_large_accelerator,
require_torchcodec,
slow,
torch_device,
)
from transformers.utils import is_soundfile_available
from transformers.utils import is_torchcodec_available
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@ -54,8 +53,8 @@ if is_vision_available():
from PIL import Image
if is_soundfile_available():
import soundfile
if is_torchcodec_available():
import torchcodec
class Phi4MultimodalModelTester:
@ -296,11 +295,9 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase):
self.assistant_token = "<|assistant|>"
self.end_token = "<|end|>"
self.image = Image.open(requests.get(self.image_url, stream=True).raw)
with tempfile.NamedTemporaryFile(mode="w+b", suffix=".wav") as tmp:
tmp.write(requests.get(self.audio_url, stream=True).raw.data)
tmp.flush()
tmp.seek(0)
self.audio, self.sampling_rate = soundfile.read(tmp.name)
audio_bytes = requests.get(self.audio_url, stream=True).raw.data
samples = torchcodec.decoders.AudioDecoder(audio_bytes).get_all_samples()
self.audio, self.sampling_rate = samples.data, samples.sample_rate
cleanup(torch_device, gc_collect=True)
@ -378,7 +375,7 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase):
self.assertEqual(response, EXPECTED_RESPONSE)
@require_soundfile
@require_torchcodec
def test_audio_text_generation(self):
model = AutoModelForCausalLM.from_pretrained(
self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device

View File

@ -19,7 +19,7 @@ import unittest
import pytest
from transformers import SEWConfig, is_torch_available
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from transformers.testing_utils import require_torch, require_torchcodec, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
@ -453,7 +453,7 @@ class SEWUtilsTest(unittest.TestCase):
@require_torch
@require_soundfile
@require_torchcodec
@slow
class SEWModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):

View File

@ -19,7 +19,7 @@ import unittest
import pytest
from transformers import SEWDConfig, is_torch_available
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from transformers.testing_utils import require_torch, require_torchcodec, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
@ -464,7 +464,7 @@ class SEWDUtilsTest(unittest.TestCase):
@require_torch
@require_soundfile
@require_torchcodec
@slow
class SEWDModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):

View File

@ -294,7 +294,7 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]

View File

@ -381,7 +381,7 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]

View File

@ -764,7 +764,7 @@ class SpeechT5ForSpeechToTextIntegrationTests(unittest.TestCase):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
@ -1792,7 +1792,7 @@ class SpeechT5ForSpeechToSpeechIntegrationTests(unittest.TestCase):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]

View File

@ -21,7 +21,7 @@ import pytest
from datasets import load_dataset
from transformers import UniSpeechConfig, is_torch_available
from transformers.testing_utils import is_flaky, require_soundfile, require_torch, slow, torch_device
from transformers.testing_utils import is_flaky, require_torch, require_torchcodec, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
@ -553,7 +553,7 @@ class UniSpeechRobustModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
@require_torch
@require_soundfile
@require_torchcodec
@slow
class UniSpeechModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):

View File

@ -21,7 +21,7 @@ import pytest
from datasets import load_dataset
from transformers import UniSpeechSatConfig, is_torch_available
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from transformers.testing_utils import require_torch, require_torchcodec, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
@ -807,7 +807,7 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
@require_soundfile
@require_torchcodec
@slow
class UniSpeechSatModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):

View File

@ -330,7 +330,7 @@ class UnivNetFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=self.feat_extract_tester.sampling_rate))
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples], [x["sampling_rate"] for x in speech_samples]

View File

@ -216,7 +216,7 @@ class UnivNetModelIntegrationTests(unittest.TestCase):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=sampling_rate))
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples], [x["sampling_rate"] for x in speech_samples]

View File

@ -34,10 +34,10 @@ from transformers.testing_utils import (
is_torchaudio_available,
require_flash_attn,
require_pyctcdecode,
require_soundfile,
require_torch,
require_torch_gpu,
require_torchaudio,
require_torchcodec,
run_test_in_subprocess,
slow,
torch_device,
@ -1444,7 +1444,7 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
@require_torch
@require_soundfile
@require_torchcodec
@slow
class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
def tearDown(self):

View File

@ -254,7 +254,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]

View File

@ -1460,7 +1460,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
def _load_datasamples(self, num_samples):
self._load_dataset()
ds = self._dataset
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
@slow

View File

@ -1190,7 +1190,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
num_beams=1,
)
transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
transcription_non_ass = pipe(sample, generate_kwargs={"assistant_model": assistant_model})["text"]
transcription_ass = pipe(sample)["text"]
self.assertEqual(transcription_ass, transcription_non_ass)

View File

@ -278,7 +278,7 @@ class AudioUtilsFunctionTester(unittest.TestCase):
if self._dataset is None:
self._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
speech_samples = self._dataset.sort("id").select(range(num_samples))[:num_samples]["audio"]
speech_samples = self._dataset.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
def test_spectrogram_impulse(self):

View File

@ -72,3 +72,14 @@ try:
print("Number of TF GPUs available:", len(tf.config.list_physical_devices("GPU")))
except ImportError:
print("TensorFlow version:", None)
try:
import torchcodec
versions = torchcodec._core.get_ffmpeg_library_versions()
print("FFmpeg version:", versions["ffmpeg_version"])
except ImportError:
print("FFmpeg version:", None)
except (AttributeError, KeyError):
print("Failed to get FFmpeg version")