[Docs] Fix Speech Encoder Decoder doc sample (#18346)

* [Docs] Fix Speech Encoder Decoder doc sample

* improve pre-processing comment

* make style
This commit is contained in:
Sanchit Gandhi 2022-07-29 09:11:28 +01:00 committed by GitHub
parent da503ea02f
commit a4ee463d95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -85,25 +85,26 @@ As you can see, only 2 inputs are required for the model in order to compute a l
speech inputs) and `labels` (which are the `input_ids` of the encoded target sequence).
```python
>>> from transformers import Wav2Vec2Processor, SpeechEncoderDecoderModel
>>> from transformers import AutoTokenizer, AutoFeatureExtractor, SpeechEncoderDecoderModel
>>> from datasets import load_dataset
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
>>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
>>> model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
... "facebook/wav2vec2-base-960h", "bert-base-uncased"
... )
>>> encoder_id = "facebook/wav2vec2-base-960h" # acoustic model encoder
>>> decoder_id = "bert-base-uncased" # text decoder
>>> model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
>>> model.config.pad_token_id = processor.tokenizer.pad_token_id
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
>>> tokenizer = AutoTokenizer.from_pretrained(decoder_id)
>>> # Combine pre-trained encoder and pre-trained decoder to form a Seq2Seq model
>>> model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id)
>>> # load a speech input
>>> model.config.decoder_start_token_id = tokenizer.cls_token_id
>>> model.config.pad_token_id = tokenizer.pad_token_id
>>> # load an audio input and pre-process (normalise mean/std to 0/1)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values
>>> # load its corresponding transcription
>>> with processor.as_target_processor():
... labels = processor(ds[0]["text"], return_tensors="pt").input_ids
>>> # load its corresponding transcription and tokenize to generate labels
>>> labels = tokenizer(ds[0]["text"], return_tensors="pt").input_ids
>>> # the forward function automatically creates the correct decoder_input_ids
>>> loss = model(input_values, labels=labels).loss