mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
[Flax Whisper] Update decode docstring (#23908)
This commit is contained in:
parent
fabe17a726
commit
9603ef890a
@ -1017,16 +1017,17 @@ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
```python
|
```python
|
||||||
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
|
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
|
||||||
>>> from datasets import load_dataset
|
>>> from datasets import load_dataset
|
||||||
|
>>> import jax.numpy as jnp
|
||||||
|
|
||||||
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||||
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
|
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
|
||||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
|
>>> input_features = processor(ds[0]["audio"]["array"], return_tensors="np").input_features
|
||||||
>>> input_features = inputs.input_features
|
|
||||||
>>> encoder_outputs = model.encode(input_features=input_features)
|
>>> encoder_outputs = model.encode(input_features=input_features)
|
||||||
>>> decoder_start_token_id = model.config.decoder_start_token_id
|
>>> decoder_start_token_id = model.config.decoder_start_token_id
|
||||||
|
|
||||||
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
>>> decoder_input_ids = jnp.ones((input_features.shape[0], 1), dtype="i4") * decoder_start_token_id
|
||||||
|
|
||||||
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
|
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
|
||||||
>>> last_decoder_hidden_states = outputs.last_hidden_state
|
>>> last_decoder_hidden_states = outputs.last_hidden_state
|
||||||
|
Loading…
Reference in New Issue
Block a user