diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index 7feb7790dc3..86eb7b8dc5a 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -1421,7 +1421,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
         ```python
         >>> import torch
         >>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
-        >>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
+        >>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
         >>> from datasets import load_dataset

         >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
@@ -1432,9 +1432,19 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):

         >>> # compute masked indices
         >>> batch_size, raw_sequence_length = input_values.shape
-        >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
-        >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
-        >>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
+        >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
+        >>> mask_time_indices = _compute_mask_indices(
+        ...     shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
+        ... )
+        >>> sampled_negative_indices = _sample_negative_indices(
+        ...     features_shape=(batch_size, sequence_length),
+        ...     num_negatives=model.config.num_negatives,
+        ...     mask_time_indices=mask_time_indices,
+        ... )
+        >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
+        >>> sampled_negative_indices = torch.tensor(
+        ...     data=sampled_negative_indices, device=input_values.device, dtype=torch.long
+        ... )

         >>> with torch.no_grad():
         ...     outputs = model(input_values, mask_time_indices=mask_time_indices)
@@ -1448,7 +1458,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):

         >>> # for contrastive loss training model should be put into train mode
         >>> model = model.train()
-        >>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
+        >>> loss = model(
+        ...     input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
+        ... ).loss
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
index 8723c6338d2..05502c9878d 100644
--- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -1469,7 +1469,10 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
         ```python
         >>> import torch
         >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
-        >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
+        >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
+        ...     _compute_mask_indices,
+        ...     _sample_negative_indices,
+        ... )
         >>> from datasets import load_dataset

         >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
@@ -1480,9 +1483,19 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):

         >>> # compute masked indices
         >>> batch_size, raw_sequence_length = input_values.shape
-        >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
-        >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
-        >>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
+        >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
+        >>> mask_time_indices = _compute_mask_indices(
+        ...     shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
+        ... )
+        >>> sampled_negative_indices = _sample_negative_indices(
+        ...     features_shape=(batch_size, sequence_length),
+        ...     num_negatives=model.config.num_negatives,
+        ...     mask_time_indices=mask_time_indices,
+        ... )
+        >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
+        >>> sampled_negative_indices = torch.tensor(
+        ...     data=sampled_negative_indices, device=input_values.device, dtype=torch.long
+        ... )

         >>> with torch.no_grad():
         ...     outputs = model(input_values, mask_time_indices=mask_time_indices)
@@ -1496,7 +1509,9 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):

         >>> # for contrastive loss training model should be put into train mode
         >>> model = model.train()
-        >>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
+        >>> loss = model(
+        ...     input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
+        ... ).loss
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
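For reference, here is the updated wav2vec2 pre-training example assembled end to end from the `+` lines of the hunks above, stripped of doctest prompts so it can be run as a script. The model-loading and dataset-loading lines live in unchanged context between the hunks and are not shown in this diff, so they are an assumption based on the surrounding docstring; everything from the masking step onward mirrors the diff. This is a minimal sketch, not a substitute for the docstring itself.

```python
import torch
from datasets import load_dataset
from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")

# Assumption: dataset loading follows the unchanged docstring context between the hunks.
ds = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values  # batch size 1

# compute masked indices; .item() unwraps the 0-dim tensor returned by
# _get_feat_extract_output_lengths so the shape tuple holds a plain int
batch_size, raw_sequence_length = input_values.shape
sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
mask_time_indices = _compute_mask_indices(shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2)

# sample negatives only from the masked positions, as the contrastive loss requires
sampled_negative_indices = _sample_negative_indices(
    features_shape=(batch_size, sequence_length),
    num_negatives=model.config.num_negatives,
    mask_time_indices=mask_time_indices,
)
mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
sampled_negative_indices = torch.tensor(
    data=sampled_negative_indices, device=input_values.device, dtype=torch.long
)

# for contrastive loss training the model should be put into train mode;
# without sampled_negative_indices the forward pass cannot compute the loss
model = model.train()
loss = model(
    input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
).loss
```

The conformer example is identical apart from the imports and the `facebook/wav2vec2-conformer-rel-pos-large` checkpoint, which is why both files receive the same change.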