mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Flax] Correct flax docs (#12782)
* fix_torch_device_generate_test * remove @ * fix flax docs * correct more docs in flax * another correction * fix flax docs * Apply suggestions from code review
This commit is contained in:
parent
a317e6c3be
commit
fbf468b057
@ -299,3 +299,93 @@ TFSeq2SeqQuestionAnsweringModelOutput
|
||||
|
||||
.. autoclass:: transformers.modeling_tf_outputs.TFSeq2SeqQuestionAnsweringModelOutput
|
||||
:members:
|
||||
|
||||
|
||||
FlaxBaseModelOutput
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutput
|
||||
|
||||
|
||||
FlaxBaseModelOutputWithPast
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPast
|
||||
|
||||
|
||||
FlaxBaseModelOutputWithPooling
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPooling
|
||||
|
||||
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions
|
||||
|
||||
|
||||
FlaxSeq2SeqModelOutput
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqModelOutput
|
||||
|
||||
|
||||
FlaxCausalLMOutputWithCrossAttentions
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions
|
||||
|
||||
|
||||
FlaxMaskedLMOutput
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxMaskedLMOutput
|
||||
|
||||
|
||||
FlaxSeq2SeqLMOutput
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput
|
||||
|
||||
|
||||
FlaxNextSentencePredictorOutput
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxNextSentencePredictorOutput
|
||||
|
||||
|
||||
FlaxSequenceClassifierOutput
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxSequenceClassifierOutput
|
||||
|
||||
|
||||
FlaxSeq2SeqSequenceClassifierOutput
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqSequenceClassifierOutput
|
||||
|
||||
|
||||
FlaxMultipleChoiceModelOutput
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxMultipleChoiceModelOutput
|
||||
|
||||
|
||||
FlaxTokenClassifierOutput
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxTokenClassifierOutput
|
||||
|
||||
|
||||
FlaxQuestionAnsweringModelOutput
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxQuestionAnsweringModelOutput
|
||||
|
||||
|
||||
FlaxSeq2SeqQuestionAnsweringModelOutput
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqQuestionAnsweringModelOutput
|
||||
|
@ -76,6 +76,9 @@ Bert specific outputs
|
||||
.. autoclass:: transformers.models.bert.modeling_tf_bert.TFBertForPreTrainingOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.bert.modeling_flax_bert.FlaxBertForPreTrainingOutput
|
||||
:members:
|
||||
|
||||
|
||||
BertModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -67,6 +67,22 @@ Wav2Vec2Processor
|
||||
:members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor
|
||||
|
||||
|
||||
Wav2Vec2 specific outputs
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.wav2vec2.modeling_flax_wav2vec2.FlaxWav2Vec2BaseModelOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.wav2vec2.modeling_flax_wav2vec2.FlaxWav2Vec2ForPreTrainingOutput
|
||||
:members:
|
||||
|
||||
|
||||
Wav2Vec2Model
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@ -1063,7 +1063,7 @@ FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
|
||||
|
||||
>>> outputs = model(**inputs, labels=labels)
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
|
||||
@ -1122,9 +1122,10 @@ FLAX_CAUSAL_LM_SAMPLE = r"""
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
||||
>>> outputs = model(**inputs, labels=inputs["input_ids"])
|
||||
>>> outputs = model(**inputs)
|
||||
|
||||
>>> logits = outputs.logits
|
||||
>>> # retrieve logts for next token
|
||||
>>> next_token_logits = outputs.logits[:, -1]
|
||||
"""
|
||||
|
||||
FLAX_SAMPLE_DOCSTRINGS = {
|
||||
|
@ -30,7 +30,7 @@ from flax.linen.attention import dot_product_attention_weights
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, replace_return_docstrings
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -1167,6 +1167,7 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
||||
|
||||
return outputs
|
||||
|
||||
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: jnp.ndarray,
|
||||
@ -1520,7 +1521,7 @@ FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING = """
|
||||
>>> input_ids = tokenizer([TXT], return_tensors='jax')['input_ids']
|
||||
>>> logits = model(input_ids).logits
|
||||
|
||||
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
|
||||
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
|
||||
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
|
||||
>>> values, predictions = jax.lax.top_k(probs)
|
||||
|
||||
|
@ -941,7 +941,7 @@ FLAX_CLIP_TEXT_MODEL_DOCSTRING = """
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooled_output # pooled (EOS token) states
|
||||
>>> pooler_output = outputs.pooler_output # pooled (EOS token) states
|
||||
"""
|
||||
|
||||
overwrite_call_docstring(FlaxCLIPTextModel, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_DOCSTRING)
|
||||
@ -997,7 +997,7 @@ FLAX_CLIP_VISION_MODEL_DOCSTRING = """
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooled_output # pooled CLS states
|
||||
>>> pooler_output = outputs.pooler_output # pooled CLS states
|
||||
"""
|
||||
|
||||
overwrite_call_docstring(FlaxCLIPVisionModel, CLIP_VISION_INPUTS_DOCSTRING + FLAX_CLIP_VISION_MODEL_DOCSTRING)
|
||||
|
@ -30,7 +30,7 @@ from flax.linen.attention import dot_product_attention_weights
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, replace_return_docstrings
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -45,7 +45,7 @@ from .configuration_marian import MarianConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de'"
|
||||
_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de"
|
||||
_CONFIG_FOR_DOC = "MarianConfig"
|
||||
_TOKENIZER_FOR_DOC = "MarianTokenizer"
|
||||
|
||||
@ -1125,6 +1125,7 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
|
||||
|
||||
return outputs
|
||||
|
||||
@add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: jnp.ndarray,
|
||||
|
@ -30,7 +30,7 @@ from flax.linen.attention import dot_product_attention_weights
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, replace_return_docstrings
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -1192,6 +1192,7 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
|
||||
|
||||
return outputs
|
||||
|
||||
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: jnp.ndarray,
|
||||
@ -1517,36 +1518,37 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
|
||||
return model_kwargs
|
||||
|
||||
|
||||
FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING = """
|
||||
FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING = r"""
|
||||
Returns:
|
||||
|
||||
Summarization example::
|
||||
|
||||
>>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration
|
||||
>>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration, MBartConfig
|
||||
|
||||
>>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25')
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25')
|
||||
|
||||
>>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
||||
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='jax')
|
||||
>>> ARTICLE_TO_SUMMARIZE = "Meine Freunde sind cool, aber sie essen zu viel Kuchen."
|
||||
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np')
|
||||
|
||||
>>> # Generate Summary
|
||||
>>> summary_ids = model.generate(inputs['input_ids'], decoder_start_token_id=tokenizer.lang_code_to_id[tgt_lang]).sequences
|
||||
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
|
||||
>>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True).sequences
|
||||
>>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
|
||||
|
||||
Mask filling example::
|
||||
|
||||
>>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25')
|
||||
>>> TXT = "My friends are <mask> but they eat too many carbs."
|
||||
>>> # de_DE is the language symbol id <LID> for German
|
||||
>>> TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"
|
||||
|
||||
>>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25')
|
||||
>>> input_ids = tokenizer([TXT], return_tensors='jax')['input_ids']
|
||||
>>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors='np')['input_ids']
|
||||
>>> logits = model(input_ids).logits
|
||||
|
||||
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
|
||||
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
|
||||
>>> values, predictions = jax.lax.top_k(probs)
|
||||
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
|
||||
>>> probs = logits[0, masked_index].softmax(dim=0)
|
||||
>>> values, predictions = probs.topk(5)
|
||||
|
||||
>>> tokenizer.decode(predictions).split()
|
||||
"""
|
||||
|
@ -36,13 +36,20 @@ from ...modeling_flax_outputs import (
|
||||
FlaxSeq2SeqLMOutput,
|
||||
FlaxSeq2SeqModelOutput,
|
||||
)
|
||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
||||
from ...modeling_flax_utils import (
|
||||
ACT2FN,
|
||||
FlaxPreTrainedModel,
|
||||
append_call_sample_docstring,
|
||||
append_replace_return_docstrings,
|
||||
overwrite_call_docstring,
|
||||
)
|
||||
from ...utils import logging
|
||||
from .configuration_t5 import T5Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "t5-small"
|
||||
_CONFIG_FOR_DOC = "T5Config"
|
||||
_TOKENIZER_FOR_DOC = "T5Tokenizer"
|
||||
|
||||
@ -844,6 +851,69 @@ T5_DECODE_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
T5_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
|
||||
should be able to pad the inputs on both the right and the left.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
detail.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
|
||||
To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
|
||||
<./t5.html#training>`__.
|
||||
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Indices of decoder input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
details.
|
||||
|
||||
`What are decoder input IDs? <../glossary.html#decoder-input-ids>`__
|
||||
|
||||
T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If
|
||||
:obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
|
||||
:obj:`past_key_values`).
|
||||
|
||||
To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
|
||||
<./t5.html#training>`__.
|
||||
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||
also be used by default.
|
||||
encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
|
||||
`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
|
||||
sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
|
||||
the decoder.
|
||||
past_key_values (:obj:`tuple(tuple(jnp.ndarray))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
|
||||
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
@ -884,6 +954,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
|
||||
decoder_attention_mask,
|
||||
)["params"]
|
||||
|
||||
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: jnp.ndarray,
|
||||
@ -1155,71 +1226,6 @@ T5_START_DOCSTRING = r"""
|
||||
model weights.
|
||||
"""
|
||||
|
||||
T5_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
|
||||
should be able to pad the inputs on both the right and the left.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
detail.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
|
||||
To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
|
||||
<./t5.html#training>`__.
|
||||
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Indices of decoder input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
details.
|
||||
|
||||
`What are decoder input IDs? <../glossary.html#decoder-input-ids>`__
|
||||
|
||||
T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If
|
||||
:obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
|
||||
:obj:`past_key_values`).
|
||||
|
||||
To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
|
||||
<./t5.html#training>`__.
|
||||
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||
also be used by default.
|
||||
encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
|
||||
`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
|
||||
sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
|
||||
the decoder.
|
||||
past_key_values (:obj:`tuple(tuple(jnp.ndarray))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
|
||||
use_cache (:obj:`bool`, `optional`):
|
||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||
decoding (see :obj:`past_key_values`).
|
||||
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
|
||||
@ -1252,8 +1258,6 @@ class FlaxT5Module(nn.Module):
|
||||
decoder_config.num_layers = self.config.num_decoder_layers
|
||||
self.decoder = FlaxT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
|
||||
|
||||
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=FlaxSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def __call__(
|
||||
self,
|
||||
input_ids=None,
|
||||
@ -1266,22 +1270,6 @@ class FlaxT5Module(nn.Module):
|
||||
return_dict=None,
|
||||
deterministic: bool = True,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import T5Tokenizer, FlaxT5Model
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
>>> model = FlaxT5Model.from_pretrained('t5-small')
|
||||
|
||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np").input_ids # Batch size 1
|
||||
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids # Batch size 1
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Encode if needed (training, first prediction pass)
|
||||
@ -1325,6 +1313,32 @@ class FlaxT5Model(FlaxT5PreTrainedModel):
|
||||
module_class = FlaxT5Module
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxT5Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC
|
||||
)
|
||||
|
||||
FLAX_T5_MODEL_DOCSTRING = """
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import T5Tokenizer, FlaxT5Model
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
>>> model = FlaxT5Model.from_pretrained('t5-small')
|
||||
|
||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np").input_ids
|
||||
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
|
||||
|
||||
overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING)
|
||||
append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
|
||||
|
||||
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
|
||||
class FlaxT5ForConditionalGenerationModule(nn.Module):
|
||||
config: T5Config
|
||||
@ -1364,8 +1378,6 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def __call__(
|
||||
self,
|
||||
input_ids=None,
|
||||
@ -1378,24 +1390,6 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
|
||||
return_dict=None,
|
||||
deterministic: bool = True,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
|
||||
>>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='np').input_ids
|
||||
>>> decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='np').input_ids
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
>>> logits = outputs.logits
|
||||
|
||||
>>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="np").input_ids
|
||||
>>> outputs = model.generate(input_ids)
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Encode
|
||||
@ -1479,7 +1473,7 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
|
||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
|
||||
>>> text = "My friends are cool but they eat too many carbs."
|
||||
>>> text = "summarize: My friends are cool but they eat too many carbs."
|
||||
>>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
|
||||
>>> encoder_outputs = model.encode(**inputs)
|
||||
|
||||
@ -1614,3 +1608,30 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
|
||||
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
||||
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||
return model_kwargs
|
||||
|
||||
|
||||
FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
|
||||
|
||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
|
||||
>>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
|
||||
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=512, return_tensors='jax')
|
||||
|
||||
>>> # Generate Summary
|
||||
>>> summary_ids = model.generate(inputs['input_ids']).sequences
|
||||
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
|
||||
"""
|
||||
|
||||
|
||||
overwrite_call_docstring(
|
||||
FlaxT5ForConditionalGeneration, T5_INPUTS_DOCSTRING + FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING
|
||||
)
|
||||
append_replace_return_docstrings(
|
||||
FlaxT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
|
@ -581,7 +581,7 @@ FLAX_VISION_CLASSIF_DOCSTRING = """
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import FlaxViTFeatureExtractor, ViTForImageClassification
|
||||
>>> from transformers import ViTFeatureExtractor, FlaxViTForImageClassification
|
||||
>>> from PIL import Image
|
||||
>>> import jax
|
||||
>>> import requests
|
||||
@ -595,9 +595,10 @@ FLAX_VISION_CLASSIF_DOCSTRING = """
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="jax")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits = outputs.logits
|
||||
|
||||
>>> # model predicts one of the 1000 ImageNet classes
|
||||
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
|
||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
|
||||
"""
|
||||
|
||||
overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
|
||||
|
@ -29,7 +29,12 @@ from jax import lax
|
||||
|
||||
from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
||||
from ...modeling_flax_utils import (
|
||||
ACT2FN,
|
||||
FlaxPreTrainedModel,
|
||||
append_replace_return_docstrings,
|
||||
overwrite_call_docstring,
|
||||
)
|
||||
from ...utils import logging
|
||||
from .configuration_wav2vec2 import Wav2Vec2Config
|
||||
|
||||
@ -853,31 +858,6 @@ class FlaxWav2Vec2Module(nn.Module):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
"""
|
||||
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2Model
|
||||
>>> from datasets import load_dataset
|
||||
>>> import soundfile as sf
|
||||
|
||||
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
>>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
>>> def map_to_array(batch):
|
||||
>>> speech, _ = sf.read(batch["file"])
|
||||
>>> batch["speech"] = speech
|
||||
>>> return batch
|
||||
|
||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> ds = ds.map(map_to_array)
|
||||
|
||||
>>> input_values = processor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
|
||||
>>> hidden_states = model(input_values).last_hidden_state
|
||||
|
||||
"""
|
||||
extract_features = self.feature_extractor(input_values)
|
||||
|
||||
# make sure that no loss is computed on padded inputs
|
||||
@ -947,6 +927,39 @@ class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
|
||||
module_class = FlaxWav2Vec2Module
|
||||
|
||||
|
||||
FLAX_WAV2VEC2_MODEL_DOCSTRING = """
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2Model
|
||||
>>> from datasets import load_dataset
|
||||
>>> import soundfile as sf
|
||||
|
||||
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-lv60")
|
||||
>>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60")
|
||||
|
||||
>>> def map_to_array(batch):
|
||||
>>> speech, _ = sf.read(batch["file"])
|
||||
>>> batch["speech"] = speech
|
||||
>>> return batch
|
||||
|
||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> ds = ds.map(map_to_array)
|
||||
|
||||
>>> input_values = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="np").input_values # Batch size 1
|
||||
>>> hidden_states = model(input_values).last_hidden_state
|
||||
"""
|
||||
|
||||
overwrite_call_docstring(
|
||||
FlaxWav2Vec2Model,
|
||||
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING,
|
||||
)
|
||||
append_replace_return_docstrings(
|
||||
FlaxWav2Vec2Model, output_type=FlaxWav2Vec2BaseModelOutput, config_class=Wav2Vec2Config
|
||||
)
|
||||
|
||||
|
||||
class FlaxWav2Vec2ForCTCModule(nn.Module):
|
||||
config: Wav2Vec2Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
@ -970,36 +983,6 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC
|
||||
>>> from datasets import load_dataset
|
||||
>>> import soundfile as sf
|
||||
|
||||
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
>>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
>>> def map_to_array(batch):
|
||||
>>> speech, _ = sf.read(batch["file"])
|
||||
>>> batch["speech"] = speech
|
||||
>>> return batch
|
||||
|
||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> ds = ds.map(map_to_array)
|
||||
|
||||
>>> input_values = processor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
|
||||
>>> logits = model(input_values).logits
|
||||
>>> predicted_ids = jnp.argmax(logits, axis=-1)
|
||||
|
||||
>>> transcription = processor.decode(predicted_ids[0])
|
||||
>>> # should give: "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
||||
|
||||
"""
|
||||
|
||||
outputs = self.wav2vec2(
|
||||
input_values,
|
||||
attention_mask=attention_mask,
|
||||
@ -1044,6 +1027,46 @@ class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
|
||||
module_class = FlaxWav2Vec2ForCTCModule
|
||||
|
||||
|
||||
FLAX_WAV2VEC2_FOR_CTC_DOCSTRING = """
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC
|
||||
>>> from datasets import load_dataset
|
||||
>>> import soundfile as sf
|
||||
|
||||
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60")
|
||||
>>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60")
|
||||
|
||||
>>> def map_to_array(batch):
|
||||
>>> speech, _ = sf.read(batch["file"])
|
||||
>>> batch["speech"] = speech
|
||||
>>> return batch
|
||||
|
||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> ds = ds.map(map_to_array)
|
||||
|
||||
>>> input_values = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="np").input_values # Batch size 1
|
||||
>>> logits = model(input_values).logits
|
||||
>>> predicted_ids = jnp.argmax(logits, axis=-1)
|
||||
|
||||
>>> transcription = processor.decode(predicted_ids[0])
|
||||
>>> # should give: "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
||||
"""
|
||||
|
||||
overwrite_call_docstring(
|
||||
FlaxWav2Vec2ForCTC,
|
||||
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING,
|
||||
)
|
||||
append_replace_return_docstrings(FlaxWav2Vec2ForCTC, output_type=FlaxCausalLMOutput, config_class=Wav2Vec2Config)
|
||||
|
||||
|
||||
class FlaxWav2Vec2ForCTCModule(nn.Module):
|
||||
config: Wav2Vec2Config
|
||||
|
||||
|
||||
class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
|
||||
config: Wav2Vec2Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
@ -1080,43 +1103,6 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
|
||||
|
||||
Example::
|
||||
|
||||
>>> import optax
|
||||
>>> import numpy as np
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from transformers import Wav2Vec2FeatureExtractor, FlaxWav2Vec2ForPreTraining
|
||||
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
||||
>>> from datasets import load_dataset
|
||||
>>> import soundfile as sf
|
||||
|
||||
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
>>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
|
||||
|
||||
>>> def map_to_array(batch):
|
||||
... speech, _ = sf.read(batch["file"])
|
||||
... batch["speech"] = speech
|
||||
... return batch
|
||||
|
||||
|
||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> ds = ds.map(map_to_array)
|
||||
|
||||
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
|
||||
|
||||
>>> # compute masked indices
|
||||
>>> batch_size, raw_sequence_length = input_values.shape
|
||||
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
|
||||
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
|
||||
|
||||
>>> outputs = model(input_values, mask_time_indices=mask_time_indices)
|
||||
|
||||
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
|
||||
>>> cosine_sim = optax.cosine_similarity(
|
||||
... outputs.projected_states, outputs.projected_quantized_states, axis=-1
|
||||
... )
|
||||
|
||||
>>> # show that cosine similarity is much higher than random
|
||||
>>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5
|
||||
|
||||
"""
|
||||
|
||||
@ -1222,3 +1208,60 @@ class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
|
||||
FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> import optax
|
||||
>>> import numpy as np
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from transformers import Wav2Vec2FeatureExtractor, FlaxWav2Vec2ForPreTraining
|
||||
>>> from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices
|
||||
>>> from datasets import load_dataset
|
||||
>>> import soundfile as sf
|
||||
|
||||
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-lv60")
|
||||
>>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60")
|
||||
|
||||
|
||||
>>> def map_to_array(batch):
|
||||
... speech, _ = sf.read(batch["file"])
|
||||
... batch["speech"] = speech
|
||||
... return batch
|
||||
|
||||
|
||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> ds = ds.map(map_to_array)
|
||||
|
||||
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
|
||||
|
||||
>>> # compute masked indices
|
||||
>>> batch_size, raw_sequence_length = input_values.shape
|
||||
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
|
||||
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
|
||||
|
||||
>>> outputs = model(input_values, mask_time_indices=mask_time_indices)
|
||||
|
||||
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
|
||||
>>> cosine_sim = optax.cosine_similarity(
|
||||
... outputs.projected_states, outputs.projected_quantized_states
|
||||
... )
|
||||
|
||||
>>> # show that cosine similarity is much higher than random
|
||||
>>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5
|
||||
"""
|
||||
|
||||
overwrite_call_docstring(
|
||||
FlaxWav2Vec2ForPreTraining,
|
||||
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING,
|
||||
)
|
||||
append_replace_return_docstrings(
|
||||
FlaxWav2Vec2ForPreTraining, output_type=FlaxWav2Vec2ForPreTrainingOutput, config_class=Wav2Vec2Config
|
||||
)
|
||||
|
||||
|
||||
class FlaxWav2Vec2ForCTCModule(nn.Module):
|
||||
config: Wav2Vec2Config
|
||||
|
@ -1183,7 +1183,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
||||
return logits
|
||||
|
||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@replace_return_docstrings(output_type=Wav2Vec2ForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
@ -1338,7 +1338,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@replace_return_docstrings(output_type=Wav2Vec2BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
@ -1420,7 +1420,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
self.wav2vec2.feature_extractor._freeze_parameters()
|
||||
|
||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
|
Loading…
Reference in New Issue
Block a user