Fix doc examples: unexpected keyword argument (#14689)

* Fix doc examples: unexpected keyword argument

* Don't delete token_type_ids from inputs

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2021-12-10 17:44:08 +01:00 committed by GitHub
parent 5b00400198
commit ae82ee6a48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 3 deletions

View File

@ -1262,7 +1262,7 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2, device=model.device)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> with torch.no_grad():
... outputs = model(input_values, mask_time_indices=mask_time_indices)

View File

@ -1260,7 +1260,7 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2, device=model.device)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> with torch.no_grad():
... outputs = model(input_values, mask_time_indices=mask_time_indices)

View File

@ -1372,7 +1372,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2, device=model.device)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> with torch.no_grad():
... outputs = model(input_values, mask_time_indices=mask_time_indices)