[Wav2Vec2] Fix None loss in doc examples (#19218)

* pass sampled_negative_indices parameter to the model to avoid getting a None loss
* concerns doc examples for Wav2Vec2ForPreTraining and Wav2Vec2ConformerForPreTraining
This commit is contained in:
rbsteinm 2022-09-29 19:23:14 +02:00 committed by GitHub
parent 1a1893e5d8
commit 49d62b0178
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 10 deletions

View File

@ -1421,7 +1421,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
>>> from datasets import load_dataset
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
@ -1432,9 +1432,19 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
>>> mask_time_indices = _compute_mask_indices(
... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
... )
>>> sampled_negative_indices = _sample_negative_indices(
... features_shape=(batch_size, sequence_length),
... num_negatives=model.config.num_negatives,
... mask_time_indices=mask_time_indices,
... )
>>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
>>> sampled_negative_indices = torch.tensor(
... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
... )
>>> with torch.no_grad():
... outputs = model(input_values, mask_time_indices=mask_time_indices)
@ -1448,7 +1458,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
>>> # for contrastive loss training model should be put into train mode
>>> model = model.train()
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
>>> loss = model(
... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
... ).loss
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

View File

@ -1469,7 +1469,10 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
... _compute_mask_indices,
... _sample_negative_indices,
... )
>>> from datasets import load_dataset
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
@ -1480,9 +1483,19 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
>>> mask_time_indices = _compute_mask_indices(
... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
... )
>>> sampled_negative_indices = _sample_negative_indices(
... features_shape=(batch_size, sequence_length),
... num_negatives=model.config.num_negatives,
... mask_time_indices=mask_time_indices,
... )
>>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
>>> sampled_negative_indices = torch.tensor(
... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
... )
>>> with torch.no_grad():
... outputs = model(input_values, mask_time_indices=mask_time_indices)
@ -1496,7 +1509,9 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
>>> # for contrastive loss training model should be put into train mode
>>> model = model.train()
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
>>> loss = model(
... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
... ).loss
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict