mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Wav2Vec2] Fix None loss in doc examples (#19218)
* pass sampled_negative_indices parameter to the model to avoid getting a None loss * concerns doc examples for Wav2Vec2ForPreTraining and Wav2Vec2ConformerForPreTraining
This commit is contained in:
parent
1a1893e5d8
commit
49d62b0178
@ -1421,7 +1421,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
|
||||
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
||||
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
|
||||
@ -1432,9 +1432,19 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
||||
|
||||
>>> # compute masked indices
|
||||
>>> batch_size, raw_sequence_length = input_values.shape
|
||||
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
|
||||
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
|
||||
>>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
|
||||
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
|
||||
>>> mask_time_indices = _compute_mask_indices(
|
||||
... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
|
||||
... )
|
||||
>>> sampled_negative_indices = _sample_negative_indices(
|
||||
... features_shape=(batch_size, sequence_length),
|
||||
... num_negatives=model.config.num_negatives,
|
||||
... mask_time_indices=mask_time_indices,
|
||||
... )
|
||||
>>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
|
||||
>>> sampled_negative_indices = torch.tensor(
|
||||
... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
|
||||
... )
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(input_values, mask_time_indices=mask_time_indices)
|
||||
@ -1448,7 +1458,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
||||
|
||||
>>> # for contrastive loss training model should be put into train mode
|
||||
>>> model = model.train()
|
||||
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
|
||||
>>> loss = model(
|
||||
... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
|
||||
... ).loss
|
||||
```"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
@ -1469,7 +1469,10 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
|
||||
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
|
||||
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
|
||||
... _compute_mask_indices,
|
||||
... _sample_negative_indices,
|
||||
... )
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
||||
@ -1480,9 +1483,19 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
|
||||
|
||||
>>> # compute masked indices
|
||||
>>> batch_size, raw_sequence_length = input_values.shape
|
||||
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
|
||||
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
|
||||
>>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
|
||||
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
|
||||
>>> mask_time_indices = _compute_mask_indices(
|
||||
... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
|
||||
... )
|
||||
>>> sampled_negative_indices = _sample_negative_indices(
|
||||
... features_shape=(batch_size, sequence_length),
|
||||
... num_negatives=model.config.num_negatives,
|
||||
... mask_time_indices=mask_time_indices,
|
||||
... )
|
||||
>>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
|
||||
>>> sampled_negative_indices = torch.tensor(
|
||||
... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
|
||||
... )
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(input_values, mask_time_indices=mask_time_indices)
|
||||
@ -1496,7 +1509,9 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
|
||||
|
||||
>>> # for contrastive loss training model should be put into train mode
|
||||
>>> model = model.train()
|
||||
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
|
||||
>>> loss = model(
|
||||
... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
|
||||
... ).loss
|
||||
```"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
Loading…
Reference in New Issue
Block a user