Mirror of https://github.com/huggingface/transformers.git
[Doctests] Fix ignore bug and add more doc tests (#15911)
* finish speech doc tests
* finish
* boom
* Update src/transformers/models/speech_to_text/modeling_speech_to_text.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent b693cbf99c
commit 6cbfa7bf4c
@@ -67,10 +67,12 @@ IGNORE_RESULT = doctest.register_optionflag('IGNORE_RESULT')
 
 OutputChecker = doctest.OutputChecker
+
+
 class CustomOutputChecker(OutputChecker):
     def check_output(self, want, got, optionflags):
-        if IGNORE_RESULT and optionflags:
+        if IGNORE_RESULT & optionflags:
             return True
         return OutputChecker.check_output(self, want, got, optionflags)
 
 
 doctest.OutputChecker = CustomOutputChecker
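Note on the fix above: doctest option flags form a bitmask, and `register_optionflag` returns a fresh power-of-two bit. `IGNORE_RESULT and optionflags` is therefore truthy whenever *any* flag is set, so the checker skipped output comparison for every example; `IGNORE_RESULT & optionflags` tests only the IGNORE_RESULT bit, which an individual doctest opts into via `# doctest: +IGNORE_RESULT`. A minimal standalone sketch of the difference (the flag values here are illustrative):

    import doctest

    IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT")

    # Simulate a doctest that only enables ELLIPSIS, not IGNORE_RESULT:
    optionflags = doctest.ELLIPSIS

    print(bool(IGNORE_RESULT and optionflags))  # True: the buggy check skips output comparison
    print(bool(IGNORE_RESULT & optionflags))    # False: the fixed check still compares output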
@@ -55,21 +55,21 @@ _EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
 
 # CTC docstring
 _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
-_CTC_EXPECTED_LOSS = 53.48
+_CTC_EXPECTED_LOSS = 66.95
 
 # Audio class docstring
 _FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-_SEQ_CLASS_CHECKPOINT = "superb/data2vec-audio-base-superb-ks"
-_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
-_SEQ_CLASS_EXPECTED_LOSS = 6.54
+_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-data2vec-seq-class"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
+_SEQ_CLASS_EXPECTED_LOSS = 0.69
 
 # Frame class docstring
-_FRAME_CLASS_CHECKPOINT = "anton-l/data2vec-audio-base-superb-sd"
-_FRAME_EXPECTED_OUTPUT = [0, 0]
+_FRAME_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-data2vec-audio-frame"
+_FRAME_EXPECTED_OUTPUT = [1, 1]
 
 # Speaker Verification docstring
-_XVECTOR_CHECKPOINT = "anton-l/data2vec-audio-base-superb-sv"
-_XVECTOR_EXPECTED_OUTPUT = 0.98
+_XVECTOR_CHECKPOINT = "hf-internal-testing/tiny-random-data2vec-xvector"
+_XVECTOR_EXPECTED_OUTPUT = 1.0
 
 
 DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST = [
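These module-level constants are interpolated into auto-generated code samples; pointing them at `hf-internal-testing` tiny-random checkpoints keeps the doc tests fast and their expected outputs stable. A sketch of how such constants are typically consumed (the exact keyword arguments are an assumption based on the `add_code_sample_docstrings` pattern visible later in this diff, and the class is a stand-in; the constants are the ones defined above):

    from transformers.file_utils import add_code_sample_docstrings
    from transformers.modeling_outputs import SequenceClassifierOutput

    class Data2VecAudioForSequenceClassification:  # hypothetical stand-in for the real model class
        @add_code_sample_docstrings(
            processor_class=_FEAT_EXTRACTOR_FOR_DOC,
            checkpoint=_SEQ_CLASS_CHECKPOINT,
            output_type=SequenceClassifierOutput,
            config_class=_CONFIG_FOR_DOC,
            modality="audio",
            expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
            expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
        )
        def forward(self, input_values=None, labels=None):
            ...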
@@ -465,22 +465,28 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
         Examples:
 
         ```python
-        >>> from transformers import SpeechEncoderDecoderModel, Speech2Text2Processor
+        >>> from transformers import SpeechEncoderDecoderModel, Wav2Vec2Processor
         >>> from datasets import load_dataset
         >>> import torch
 
-        >>> processor = Speech2Text2Processor.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
-        >>> model = SpeechEncoderDecoderModel.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
+        >>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15")
+        >>> model = SpeechEncoderDecoderModel.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15")
 
         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
 
         >>> input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values
-        >>> decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]])
-        >>> outputs = model(input_values=input_values, decoder_input_ids=decoder_input_ids)
 
-        >>> # inference (generation)
+        >>> # Inference: Translate English speech to German
         >>> generated = model.generate(input_values)
-        >>> translation = processor.batch_decode(generated)
+        >>> decoded = processor.batch_decode(generated, skip_special_tokens=True)[0]
+        >>> decoded
+        'Mr. Quilter ist der Apostel der Mittelschicht und wir freuen uns, sein Evangelium willkommen heißen zu können.'
+
+        >>> # Training: Train model on English transcription
+        >>> with processor.as_target_processor():
+        ...     labels = processor(ds[0]["text"], return_tensors="pt").input_ids
+
+        >>> loss = model(input_values, labels=labels).loss
+        >>> loss.backward()
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -24,12 +24,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
-from ...file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -44,8 +39,6 @@ from .configuration_speech_to_text import Speech2TextConfig
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "Speech2TextConfig"
-_TOKENIZER_FOR_DOC = "Speech2TextTokenizer"
-_CHECKPOINT_FOR_DOC = "facebook/s2t-small-librispeech-asr"
 
 
 SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -780,7 +773,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
             attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask)
             padding_mask = attention_mask.ne(1).long()
         else:
-            padding_mask = torch.zeros_like(inputs_embeds, dtype=torch.long)
+            padding_mask = torch.zeros(inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device)
 
         embed_pos = self.embed_positions(padding_mask)
 
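Note on the `padding_mask` fix: `attention_mask` is a 2-D `(batch, frames)` tensor, while `inputs_embeds` is 3-D `(batch, frames, hidden)`. `torch.zeros_like(inputs_embeds, ...)` therefore produced a mask with one dimension too many for `embed_positions`, which expects position-id-shaped input. A quick shape check with illustrative sizes:

    import torch

    inputs_embeds = torch.randn(2, 50, 256)  # (batch, frames, hidden); sizes are illustrative

    old_mask = torch.zeros_like(inputs_embeds, dtype=torch.long)
    new_mask = torch.zeros(inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device)

    print(old_mask.shape)  # torch.Size([2, 50, 256]); wrong rank for a padding mask
    print(new_mask.shape)  # torch.Size([2, 50]); (batch, frames), as expected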
@@ -1144,12 +1137,7 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
         return self.decoder
 
     @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=Seq2SeqModelOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
+    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_features=None,
@@ -1167,6 +1155,28 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
     ):
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import Speech2TextModel, Speech2TextFeatureExtractor
+        >>> from datasets import load_dataset
+
+        >>> model = Speech2TextModel.from_pretrained("facebook/s2t-small-librispeech-asr")
+        >>> feature_extractor = Speech2TextFeatureExtractor.from_pretrained("facebook/s2t-small-librispeech-asr")
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        >>> input_features = feature_extractor(
+        ...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt"
+        ... ).input_features
+        >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
+        >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
+        >>> list(last_hidden_state.shape)
+        [1, 2, 256]
+        ```"""
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
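The expected `[1, 2, 256]` in the new example follows directly from the inputs: one clip, two forced decoder start tokens, and the hidden size of the checkpoint. A hedged reconstruction (256 is the assumed `d_model` of facebook/s2t-small-librispeech-asr):

    batch_size = 1         # a single audio clip
    decoder_positions = 2  # decoder_input_ids is torch.tensor([[1, 1]]) * decoder_start_token_id
    d_model = 256          # assumed hidden size of the checkpoint
    print([batch_size, decoder_positions, d_model])  # [1, 2, 256]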
@@ -1305,27 +1315,22 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
         >>> import torch
         >>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
         >>> from datasets import load_dataset
-        >>> import soundfile as sf
 
         >>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
         >>> processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
 
-
-        >>> def map_to_array(batch):
-        ...     speech, _ = sf.read(batch["file"])
-        ...     batch["speech"] = speech
-        ...     return batch
-
-
         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        >>> ds = ds.map(map_to_array)
 
         >>> input_features = processor(
-        ...     ds["speech"][0], sampling_rate=16000, return_tensors="pt"
-        >>> ).input_features  # Batch size 1
+        ...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt"
+        ... ).input_features
 
         >>> generated_ids = model.generate(inputs=input_features)
 
-        >>> transcription = processor.batch_decode(generated_ids)
+        >>> transcription = processor.batch_decode(generated_ids)[0]
+        >>> transcription
+        'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel'
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -35,13 +35,12 @@ from .configuration_speech_to_text_2 import Speech2Text2Config
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "Speech2Text2Config"
-_TOKENIZER_FOR_DOC = "Speech2Text2Tokenizer"
-_CHECKPOINT_FOR_DOC = "facebook/s2t-small-librispeech-asr"
+_CHECKPOINT_FOR_DOC = "facebook/s2t-wav2vec2-large-en-de"
 
 
 SPEECH_TO_TEXT_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "facebook/s2t-small-librispeech-asr",
-    # See all Speech2Text2 models at https://huggingface.co/models?filter=speech_to_text
+    "facebook/s2t-wav2vec2-large-en-de",
+    # See all Speech2Text2 models at https://huggingface.co/models?filter=speech2text2
 ]
 
 
@@ -865,13 +864,34 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
         ...     Wav2Vec2Model,
         ...     Speech2Text2Config,
         ...     Wav2Vec2Config,
+        ...     Wav2Vec2FeatureExtractor,
+        ...     Speech2Text2Tokenizer,
         ... )
+        >>> from datasets import load_dataset
+
+        >>> feature_extractor = Wav2Vec2FeatureExtractor()
+        >>> tokenizer = Speech2Text2Tokenizer.from_pretrained(_CHECKPOINT_FOR_DOC)
+
         >>> encoder = Wav2Vec2Model(Wav2Vec2Config())
         >>> decoder = Speech2Text2ForCausalLM(Speech2Text2Config())
 
-        >>> # init speech2text model
+        >>> # init random speech2text model
         >>> model = SpeechEncoderDecoderModel(encoder=encoder, decoder=decoder)
+        >>> model.config.pad_token_id = tokenizer.pad_token_id
+        >>> model.config.decoder_start_token_id = tokenizer.bos_token_id
+
+        >>> # pre-process inputs and labels
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        >>> input_values = feature_extractor(
+        ...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt"
+        ... ).input_values  # Batch size 1
+        >>> decoder_input_ids = tokenizer(ds[0]["text"], return_tensors="pt").input_ids
+
+        >>> # compute loss
+        >>> loss = model(inputs=input_values, labels=decoder_input_ids).loss
+        >>> # backprop loss
+        >>> loss.backward()
         ```"""
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1478,17 +1478,8 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
         >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
         >>> model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
 
-
-        >>> def map_to_array(batch):
-        ...     speech, _ = sf.read(batch["file"])
-        ...     batch["speech"] = speech
-        ...     return batch
-
-
         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        >>> ds = ds.map(map_to_array)
-
-        >>> input_values = feature_extractor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
+        >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values  # Batch size 1
 
         >>> # compute masked indices
         >>> batch_size, raw_sequence_length = input_values.shape
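The recurring replacement of soundfile plus `map_to_array` with `ds[0]["audio"]["array"]` in this and the surrounding hunks relies on the `datasets` Audio feature, which decodes the file lazily on access. A minimal sketch (assumes a `datasets` version with Audio decoding, as the updated examples do):

    from datasets import load_dataset

    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    sample = ds[0]["audio"]  # decoding happens on access; no soundfile needed
    print(sample["sampling_rate"], sample["array"].shape)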
@@ -566,17 +566,15 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
         >>> word_offsets = [
         ...     {
         ...         "word": d["word"],
-        ...         "start_time": d["start_offset"] * time_offset,
-        ...         "end_time": d["end_offset"] * time_offset,
+        ...         "start_time": round(d["start_offset"] * time_offset, 2),
+        ...         "end_time": round(d["end_offset"] * time_offset, 2),
         ...     }
         ...     for d in outputs.word_offsets
         ... ]
         >>> # compare word offsets with audio `common_voice_en_100038.mp3` online on the dataset viewer:
         >>> # https://huggingface.co/datasets/common_voice/viewer/en/train
-        >>> word_offset
-        >>> # [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES',
-        >>> # 'start_time': 1.64, 'end_time': 1.90}, {'word': 'MILISANDRA',
-        >>> # 'start_time': 2.26, 'end_time': 2.9}, {'word': 'LOOK', 'start_time': 3.0, 'end_time': 3.16}, ...
+        >>> word_offsets[:3]
+        [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES', 'start_time': 1.64, 'end_time': 1.9}, {'word': 'MILISANDRA', 'start_time': 2.26, 'end_time': 2.9}]
         ```"""
         # Convert inputs to python lists
         token_ids = to_py_obj(token_ids)
@@ -401,7 +401,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
 
         ```python
         >>> # Let's see how to retrieve time steps for a model
-        >>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
+        >>> from transformers import AutoTokenizer, AutoProcessor, AutoModelForCTC
         >>> from datasets import load_dataset
         >>> import datasets
         >>> import torch
@@ -417,29 +417,27 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
         >>> sample = next(dataset_iter)
 
         >>> # forward sample through model to get greedily predicted transcription ids
-        >>> input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
+        >>> input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
         >>> with torch.no_grad():
         ...     logits = model(input_values).logits[0].cpu().numpy()
 
         >>> # retrieve word stamps (analogous commands for `output_char_offsets`)
-        >>> outputs = tokenizer.decode(logits, output_word_offsets=True)
+        >>> outputs = processor.decode(logits, output_word_offsets=True)
         >>> # compute `time_offset` in seconds as product of downsampling ratio and sampling_rate
-        >>> time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate
+        >>> time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate
 
         >>> word_offsets = [
         ...     {
         ...         "word": d["word"],
-        ...         "start_time": d["start_offset"] * time_offset,
-        ...         "end_time": d["end_offset"] * time_offset,
+        ...         "start_time": round(d["start_offset"] * time_offset, 2),
+        ...         "end_time": round(d["end_offset"] * time_offset, 2),
         ...     }
         ...     for d in outputs.word_offsets
         ... ]
         >>> # compare word offsets with audio `common_voice_en_100038.mp3` online on the dataset viewer:
         >>> # https://huggingface.co/datasets/common_voice/viewer/en/train
-        >>> word_offset
-        >>> # [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES',
-        >>> # 'start_time': 1.64, 'end_time': 1.88}, {'word': 'A',
-        >>> # 'start_time': 2.12, 'end_time': 2.14}, {'word': 'MILE', 'start_time': 2.26, 'end_time': 2.46}, ...
+        >>> word_offsets[:4]
+        [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES', 'start_time': 1.64, 'end_time': 1.88}, {'word': 'A', 'start_time': 2.12, 'end_time': 2.14}, {'word': 'MILE', 'start_time': 2.26, 'end_time': 2.46}]
         ```"""
 
         from pyctcdecode.constants import (
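The `time_offset` arithmetic used in both word-offset examples is worth spelling out: `inputs_to_logits_ratio` is the model's total input-to-output downsampling factor (320 for the base wav2vec2 feature encoder), so at 16 kHz each logit frame spans 0.02 s. A worked check, assuming those two values:

    inputs_to_logits_ratio = 320  # assumed wav2vec2 downsampling factor
    sampling_rate = 16_000        # assumed audio sampling rate in Hz

    time_offset = inputs_to_logits_ratio / sampling_rate
    print(time_offset)  # 0.02 seconds of audio per logit frame

    start_offset = 71  # illustrative offset, measured in logit frames
    print(round(start_offset * time_offset, 2))  # 1.42, matching 'WHY' above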
@@ -185,6 +185,17 @@ class Speech2TextModelTester:
 
         return input_lengths
 
+    def create_and_check_model_forward(self, config, inputs_dict):
+        model = Speech2TextModel(config=config).to(torch_device).eval()
+
+        input_features = inputs_dict["input_features"]
+        decoder_input_ids = inputs_dict["decoder_input_ids"]
+
+        # first forward pass
+        last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
+
+        self.parent.assertTrue(last_hidden_state.shape, (13, 7, 16))
+
     def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
         model = Speech2TextModel(config=config).get_decoder().to(torch_device).eval()
         input_ids = inputs_dict["decoder_input_ids"]
@@ -284,6 +295,10 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
             self.assertEqual(info["missing_keys"], [])
 
+    def test_model_forward(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model_forward(*config_and_inputs)
+
     def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
@@ -1,9 +1,15 @@
 src/transformers/models/wav2vec2/modeling_wav2vec2.py
+src/transformers/models/wav2vec2/tokenization_wav2vec2.py
+src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
 src/transformers/models/hubert/modeling_hubert.py
 src/transformers/models/wavlm/modeling_wavlm.py
 src/transformers/models/unispeech/modeling_unispeech.py
 src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
 src/transformers/models/sew/modeling_sew.py
 src/transformers/models/sew_d/modeling_sew_d.py
+src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
+src/transformers/models/speech_to_text/modeling_speech_to_text.py
+src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
+src/transformers/models/data2vec/modeling_data2vec_audio.py
 docs/source/quicktour.mdx
 docs/source/task_summary.mdx
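This last file, `utils/documentation_tests.txt`, is the allow-list of files whose docstring examples are executed in CI; the doctest checker fix at the top of this commit is what makes their `IGNORE_RESULT` handling sound. To reproduce a single file's run locally, something along these lines should work (a sketch using the standard library; the CI itself drives this through pytest's doctest integration):

    import doctest

    import transformers.models.speech_to_text.modeling_speech_to_text as mod

    # Run the docstring examples in one allow-listed module.
    results = doctest.testmod(mod, optionflags=doctest.ELLIPSIS, verbose=False)
    print(f"attempted={results.attempted} failed={results.failed}")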