[Flax] Correct flax docs (#12782)

* fix_torch_device_generate_test

* remove @

* fix flax docs

* correct more docs in flax

* another correction

* fix flax docs

* Apply suggestions from code review
Patrick von Platen 2021-08-04 16:31:23 +02:00 committed by GitHub
parent a317e6c3be
commit fbf468b057
12 changed files with 403 additions and 224 deletions

docs/source/main_classes/output.rst

@@ -299,3 +299,93 @@ TFSeq2SeqQuestionAnsweringModelOutput
.. autoclass:: transformers.modeling_tf_outputs.TFSeq2SeqQuestionAnsweringModelOutput
:members:
FlaxBaseModelOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutput
    :members:

FlaxBaseModelOutputWithPast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPast
    :members:

FlaxBaseModelOutputWithPooling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPooling
    :members:

FlaxBaseModelOutputWithPastAndCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions
    :members:

FlaxSeq2SeqModelOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqModelOutput
    :members:

FlaxCausalLMOutputWithCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions
    :members:

FlaxMaskedLMOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxMaskedLMOutput
    :members:

FlaxSeq2SeqLMOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput
    :members:

FlaxNextSentencePredictorOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxNextSentencePredictorOutput
    :members:

FlaxSequenceClassifierOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxSequenceClassifierOutput
    :members:

FlaxSeq2SeqSequenceClassifierOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqSequenceClassifierOutput
    :members:

FlaxMultipleChoiceModelOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxMultipleChoiceModelOutput
    :members:

FlaxTokenClassifierOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxTokenClassifierOutput
    :members:

FlaxQuestionAnsweringModelOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxQuestionAnsweringModelOutput
    :members:

FlaxSeq2SeqQuestionAnsweringModelOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqQuestionAnsweringModelOutput
    :members:
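All of these output classes behave the same way at runtime: fields are reachable by attribute or by tuple-style index. A minimal sketch for illustration (the ``bert-base-uncased`` checkpoint is only an example)::

>>> from transformers import BertTokenizerFast, FlaxBertModel
>>> tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
>>> model = FlaxBertModel.from_pretrained('bert-base-uncased')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
>>> outputs = model(**inputs)  # a FlaxBaseModelOutputWithPooling
>>> last_hidden_state = outputs.last_hidden_state  # attribute access
>>> last_hidden_state = outputs[0]                 # index access works too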

docs/source/model_doc/bert.rst

@@ -76,6 +76,9 @@ Bert specific outputs
.. autoclass:: transformers.models.bert.modeling_tf_bert.TFBertForPreTrainingOutput
:members:
.. autoclass:: transformers.models.bert.modeling_flax_bert.FlaxBertForPreTrainingOutput
:members:
BertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

docs/source/model_doc/wav2vec2.rst

@@ -67,6 +67,22 @@ Wav2Vec2Processor
:members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor
Wav2Vec2 specific outputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput
:members:
.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput
:members:
.. autoclass:: transformers.models.wav2vec2.modeling_flax_wav2vec2.FlaxWav2Vec2BaseModelOutput
:members:
.. autoclass:: transformers.models.wav2vec2.modeling_flax_wav2vec2.FlaxWav2Vec2ForPreTrainingOutput
:members:
Wav2Vec2Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

src/transformers/file_utils.py

@@ -1063,7 +1063,7 @@ FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
>>> outputs = model(**inputs, labels=labels)
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
@@ -1122,9 +1122,10 @@ FLAX_CAUSAL_LM_SAMPLE = r"""
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
>>> outputs = model(**inputs, labels=inputs["input_ids"])
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # retrieve logits for next token
>>> next_token_logits = outputs.logits[:, -1]
"""
FLAX_SAMPLE_DOCSTRINGS = {

src/transformers/models/bart/modeling_flax_bart.py

@@ -30,7 +30,7 @@ from flax.linen.attention import dot_product_attention_weights
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, replace_return_docstrings
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
@@ -1167,6 +1167,7 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
return outputs
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
def __call__(
self,
input_ids: jnp.ndarray,
@@ -1520,7 +1521,7 @@ FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING = """
>>> input_ids = tokenizer([TXT], return_tensors='jax')['input_ids']
>>> logits = model(input_ids).logits
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
>>> values, predictions = jax.lax.top_k(probs, k=5)
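The added ``[0]`` is needed because ``.nonzero()`` on a JAX array returns a tuple of index arrays, one per dimension, rather than a single tensor; a quick sketch::

>>> import jax.numpy as jnp
>>> mask = jnp.array([False, True, False])
>>> idx = mask.nonzero()[0]    # the tuple's first entry holds the indices
>>> masked_index = idx.item()  # -> 1, a plain Python int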

src/transformers/models/clip/modeling_flax_clip.py

@@ -941,7 +941,7 @@ FLAX_CLIP_TEXT_MODEL_DOCSTRING = """
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooled_output # pooled (EOS token) states
>>> pooler_output = outputs.pooler_output # pooled (EOS token) states
"""
overwrite_call_docstring(FlaxCLIPTextModel, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_DOCSTRING)
@@ -997,7 +997,7 @@ FLAX_CLIP_VISION_MODEL_DOCSTRING = """
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooled_output # pooled CLS states
>>> pooler_output = outputs.pooler_output # pooled CLS states
"""
overwrite_call_docstring(FlaxCLIPVisionModel, CLIP_VISION_INPUTS_DOCSTRING + FLAX_CLIP_VISION_MODEL_DOCSTRING)
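overwrite_call_docstring, used here and throughout this commit, swaps in a new docstring on the class's __call__ without affecting sibling classes that inherit the same method. A sketch of the pattern (not the exact library code):

import functools
import types


def copy_func(f):
    # shallow-copy the function object so editing __doc__ below does not
    # leak into other classes sharing the inherited __call__
    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
    g = functools.update_wrapper(g, f)
    g.__kwdefaults__ = f.__kwdefaults__
    return g


def overwrite_call_docstring(model_class, docstring):
    model_class.__call__ = copy_func(model_class.__call__)
    model_class.__call__.__doc__ = docstring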

src/transformers/models/marian/modeling_flax_marian.py

@@ -30,7 +30,7 @@ from flax.linen.attention import dot_product_attention_weights
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, replace_return_docstrings
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
@@ -45,7 +45,7 @@ from .configuration_marian import MarianConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de'"
_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de"
_CONFIG_FOR_DOC = "MarianConfig"
_TOKENIZER_FOR_DOC = "MarianTokenizer"
@@ -1125,6 +1125,7 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
return outputs
@add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
def __call__(
self,
input_ids: jnp.ndarray,

src/transformers/models/mbart/modeling_flax_mbart.py

@@ -30,7 +30,7 @@ from flax.linen.attention import dot_product_attention_weights
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, replace_return_docstrings
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
@@ -1192,6 +1192,7 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
return outputs
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
def __call__(
self,
input_ids: jnp.ndarray,
@@ -1517,36 +1518,37 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
return model_kwargs
FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING = """
FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING = r"""
Returns:
Summarization example::
>>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration
>>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration, MBartConfig
>>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25')
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25')
>>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='jax')
>>> ARTICLE_TO_SUMMARIZE = "Meine Freunde sind cool, aber sie essen zu viel Kuchen."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np')
>>> # Generate Summary
>>> summary_ids = model.generate(inputs['input_ids'], decoder_start_token_id=tokenizer.lang_code_to_id[tgt_lang]).sequences
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
>>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True).sequences
>>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
Mask filling example::
>>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25')
>>> TXT = "My friends are <mask> but they eat too many carbs."
>>> # de_DE is the language symbol id <LID> for German
>>> TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"
>>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25')
>>> input_ids = tokenizer([TXT], return_tensors='jax')['input_ids']
>>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors='np')['input_ids']
>>> logits = model(input_ids).logits
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
>>> values, predictions = jax.lax.top_k(probs)
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
>>> values, predictions = jax.lax.top_k(probs, k=5)
>>> tokenizer.decode(predictions).split()
"""

src/transformers/models/t5/modeling_flax_t5.py

@@ -36,13 +36,20 @@ from ...modeling_flax_outputs import (
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
)
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import logging
from .configuration_t5 import T5Config
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "t5-small"
_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
@@ -844,6 +851,69 @@ T5_DECODE_INPUTS_DOCSTRING = r"""
"""
T5_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
should be able to pad the inputs on both the right and the left.
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
To know more on how to prepare :obj:`input_ids` for pretraining take a look at `T5 Training
<./t5.html#training>`__.
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are decoder input IDs? <../glossary.html#decoder-input-ids>`__
T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If
:obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
:obj:`past_key_values`).
To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
<./t5.html#training>`__.
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
encoder_outputs (:obj:`tuple(tuple(jnp.ndarray))`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
the decoder.
past_key_values (:obj:`tuple(tuple(jnp.ndarray))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -884,6 +954,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
decoder_attention_mask,
)["params"]
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
def __call__(
self,
input_ids: jnp.ndarray,
@@ -1155,71 +1226,6 @@ T5_START_DOCSTRING = r"""
model weights.
"""
T5_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
should be able to pad the inputs on both the right and the left.
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
detail.
`What are input IDs? <../glossary.html#input-ids>`__
To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
<./t5.html#training>`__.
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are decoder input IDs? <../glossary.html#decoder-input-ids>`__
T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If
:obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
:obj:`past_key_values`).
To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
<./t5.html#training>`__.
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
the decoder.
past_key_values (:obj:`tuple(tuple(jnp.ndarray))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
@add_start_docstrings(
"The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
@@ -1252,8 +1258,6 @@ class FlaxT5Module(nn.Module):
decoder_config.num_layers = self.config.num_decoder_layers
self.decoder = FlaxT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def __call__(
self,
input_ids=None,
@@ -1266,22 +1270,6 @@ class FlaxT5Module(nn.Module):
return_dict=None,
deterministic: bool = True,
):
r"""
Returns:
Example::
>>> from transformers import T5Tokenizer, FlaxT5Model
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = FlaxT5Model.from_pretrained('t5-small')
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np").input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Encode if needed (training, first prediction pass)
@@ -1325,6 +1313,32 @@ class FlaxT5Model(FlaxT5PreTrainedModel):
module_class = FlaxT5Module
append_call_sample_docstring(
FlaxT5Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC
)
FLAX_T5_MODEL_DOCSTRING = """
Returns:
Example::
>>> from transformers import T5Tokenizer, FlaxT5Model
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = FlaxT5Model.from_pretrained('t5-small')
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np").input_ids
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
"""
overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING)
append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
class FlaxT5ForConditionalGenerationModule(nn.Module):
config: T5Config
@@ -1364,8 +1378,6 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
dtype=self.dtype,
)
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def __call__(
self,
input_ids=None,
@@ -1378,24 +1390,6 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
return_dict=None,
deterministic: bool = True,
):
r"""
Returns:
Examples::
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
>>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='np').input_ids
>>> decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='np').input_ids
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> logits = outputs.logits
>>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="np").input_ids
>>> outputs = model.generate(input_ids)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Encode
@@ -1479,7 +1473,7 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> text = "My friends are cool but they eat too many carbs."
>>> text = "summarize: My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
>>> encoder_outputs = model.encode(**inputs)
@@ -1614,3 +1608,30 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
return model_kwargs
FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """
Returns:
Example::
>>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=512, return_tensors='jax')
>>> # Generate Summary
>>> summary_ids = model.generate(inputs['input_ids']).sequences
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
"""
overwrite_call_docstring(
FlaxT5ForConditionalGeneration, T5_INPUTS_DOCSTRING + FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING
)
append_replace_return_docstrings(
FlaxT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)

src/transformers/models/vit/modeling_flax_vit.py

@@ -581,7 +581,7 @@ FLAX_VISION_CLASSIF_DOCSTRING = """
Example::
>>> from transformers import FlaxViTFeatureExtractor, ViTForImageClassification
>>> from transformers import ViTFeatureExtractor, FlaxViTForImageClassification
>>> from PIL import Image
>>> import jax
>>> import requests
@@ -595,9 +595,10 @@ FLAX_VISION_CLASSIF_DOCSTRING = """
>>> inputs = feature_extractor(images=image, return_tensors="jax")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
"""
overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
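The new ``.item()`` call converts the size-1 array returned by ``jax.numpy.argmax`` into a plain Python int, which is what the ``id2label`` dict expects as a key; compactly::

>>> import jax.numpy as jnp
>>> idx = jnp.argmax(jnp.array([[0.1, 2.0, 0.3]]), axis=-1)  # Array of shape (1,)
>>> key = idx.item()  # 1, a Python int usable as a dict key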

src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py

@@ -29,7 +29,12 @@ from jax import lax
from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import logging
from .configuration_wav2vec2 import Wav2Vec2Config
@@ -853,31 +858,6 @@ class FlaxWav2Vec2Module(nn.Module):
output_hidden_states=None,
return_dict=None,
):
"""
Returns:
Example::
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2Model
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
>>> def map_to_array(batch):
>>> speech, _ = sf.read(batch["file"])
>>> batch["speech"] = speech
>>> return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
>>> hidden_states = model(input_values).last_hidden_state
"""
extract_features = self.feature_extractor(input_values)
# make sure that no loss is computed on padded inputs
@@ -947,6 +927,39 @@ class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2Module
FLAX_WAV2VEC2_MODEL_DOCSTRING = """
Returns:
Example::
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2Model
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-lv60")
>>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60")
>>> def map_to_array(batch):
...     speech, _ = sf.read(batch["file"])
...     batch["speech"] = speech
...     return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="np").input_values # Batch size 1
>>> hidden_states = model(input_values).last_hidden_state
"""
overwrite_call_docstring(
FlaxWav2Vec2Model,
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING,
)
append_replace_return_docstrings(
FlaxWav2Vec2Model, output_type=FlaxWav2Vec2BaseModelOutput, config_class=Wav2Vec2Config
)
class FlaxWav2Vec2ForCTCModule(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
@@ -970,36 +983,6 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Example::
>>> import jax.numpy as jnp
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
>>> def map_to_array(batch):
>>> speech, _ = sf.read(batch["file"])
>>> batch["speech"] = speech
>>> return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
>>> logits = model(input_values).logits
>>> predicted_ids = jnp.argmax(logits, axis=-1)
>>> transcription = processor.decode(predicted_ids[0])
>>> # should give: "A MAN SAID TO THE UNIVERSE SIR I EXIST"
"""
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
@@ -1044,6 +1027,46 @@ class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2ForCTCModule
FLAX_WAV2VEC2_FOR_CTC_DOCSTRING = """
Returns:
Example::
>>> import jax.numpy as jnp
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60")
>>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60")
>>> def map_to_array(batch):
...     speech, _ = sf.read(batch["file"])
...     batch["speech"] = speech
...     return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="np").input_values # Batch size 1
>>> logits = model(input_values).logits
>>> predicted_ids = jnp.argmax(logits, axis=-1)
>>> transcription = processor.decode(predicted_ids[0])
>>> # should give: "A MAN SAID TO THE UNIVERSE SIR I EXIST"
"""
overwrite_call_docstring(
FlaxWav2Vec2ForCTC,
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING,
)
append_replace_return_docstrings(FlaxWav2Vec2ForCTC, output_type=FlaxCausalLMOutput, config_class=Wav2Vec2Config)
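The same pipeline extends to batches; a sketch reusing ``processor``, ``model`` and ``ds`` from the example above::

>>> import jax.numpy as jnp
>>> batch = processor(ds["speech"][:2], sampling_rate=16_000, return_tensors="np", padding=True)
>>> logits = model(batch.input_values, attention_mask=batch.attention_mask).logits
>>> predicted_ids = jnp.argmax(logits, axis=-1)
>>> transcriptions = processor.batch_decode(predicted_ids)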
class FlaxWav2Vec2ForCTCModule(nn.Module):
config: Wav2Vec2Config
class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
@@ -1080,43 +1103,6 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
Example::
>>> import optax
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from transformers import Wav2Vec2FeatureExtractor, FlaxWav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
>>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> outputs = model(input_values, mask_time_indices=mask_time_indices)
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
>>> cosine_sim = optax.cosine_similarity(
... outputs.projected_states, outputs.projected_quantized_states, axis=-1
... )
>>> # show that cosine similarity is much higher than random
>>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5
"""
@@ -1222,3 +1208,60 @@ class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
return_dict,
rngs=rngs,
)
FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """
Returns:
Example::
>>> import optax
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from transformers import Wav2Vec2FeatureExtractor, FlaxWav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-lv60")
>>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> outputs = model(input_values, mask_time_indices=mask_time_indices)
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
>>> cosine_sim = optax.cosine_similarity(
... outputs.projected_states, outputs.projected_quantized_states
... )
>>> # show that cosine similarity is much higher than random
>>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5
"""
overwrite_call_docstring(
FlaxWav2Vec2ForPreTraining,
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
FlaxWav2Vec2ForPreTraining, output_type=FlaxWav2Vec2ForPreTrainingOutput, config_class=Wav2Vec2Config
)
class FlaxWav2Vec2ForCTCModule(nn.Module):
config: Wav2Vec2Config

src/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -1183,7 +1183,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
return logits
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=Wav2Vec2ForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
@@ -1338,7 +1338,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
self.init_weights()
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=Wav2Vec2BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
@@ -1420,7 +1420,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
self.wav2vec2.feature_extractor._freeze_parameters()
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,