[Flax] Align jax flax device name (#12987)

* [Flax] Align device name in docs

* make style

* fix import error
Patrick von Platen authored on 2021-08-04 16:00:09 +02:00, committed by GitHub
parent 07df5578d9
commit da9754a3a0
9 changed files with 365 additions and 387 deletions
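Note on the substance of the change: jaxlib.xla_extension.DeviceArray is a private jaxlib class, while jnp.ndarray is the public alias JAX exposes for type annotations; the alias also matches the tracer objects seen inside jit-transformed code, so isinstance checks keep working under tracing. A minimal sketch of the before/after annotation style (the embed function is a made-up stand-in, not code from this commit):

import jax.numpy as jnp

# Before: import jaxlib.xla_extension as jax_xla
#         def embed(input_ids: jax_xla.DeviceArray) -> jax_xla.DeviceArray: ...
# After, only the public alias appears in the signature:
def embed(input_ids: jnp.ndarray) -> jnp.ndarray:
    # toy lookup into a (vocab_size=32, dim=8) table of ones
    return jnp.take(jnp.ones((32, 8)), input_ids, axis=0)

out = embed(jnp.array([1, 2, 3]))
assert isinstance(out, jnp.ndarray)  # concrete device arrays satisfy the alias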

examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py

@@ -224,7 +224,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
`What are input IDs? <../glossary.html#input-ids>`__
Returns:
- text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings
+ text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The text embeddings
obtained by applying the projection layer to the pooled output of text model.
"""
if position_ids is None:
@@ -273,7 +273,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
:meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
Returns:
- image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings
+ image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The image embeddings
obtained by applying the projection layer to the pooled output of vision model.
"""

src/transformers/generation_flax_logits_process.py

@@ -19,7 +19,6 @@ from abc import ABC
import jax
import jax.lax as lax
import jax.numpy as jnp
- import jaxlib.xla_extension as jax_xla
from .file_utils import add_start_docstrings
from .utils.logging import get_logger
@@ -30,7 +29,7 @@ logger = get_logger(__name__)
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
- input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
+ input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
@@ -38,14 +37,14 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
details.
`What are input IDs? <../glossary.html#input-ids>`__
- scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`):
+ scores (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
search or log softmax for each vocabulary token when using beam search
kwargs:
Additional logits processor specific kwargs.
Return:
- :obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
+ :obj:`jnp.ndarray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
"""
@@ -54,7 +53,7 @@ class FlaxLogitsProcessor(ABC):
"""Abstract base class for all logit processors that can be applied during generation."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
- def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
"""Flax method for processing logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
@@ -65,7 +64,7 @@ class FlaxLogitsWarper(ABC):
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
- def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
"""Flax method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
@@ -81,9 +80,7 @@ class FlaxLogitsProcessorList(list):
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
- def __call__(
- self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int, **kwargs
- ) -> jax_xla.DeviceArray:
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int, **kwargs) -> jnp.ndarray:
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 3:
@@ -111,9 +108,7 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
self.temperature = temperature
- def __call__(
- self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
- ) -> jax_xla.DeviceArray:
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
scores = scores / self.temperature
return scores
@@ -141,9 +136,7 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
- def __call__(
- self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
- ) -> jax_xla.DeviceArray:
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
mask_scores = jnp.full_like(scores, self.filter_value)
@@ -183,9 +176,7 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
- def __call__(
- self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
- ) -> jax_xla.DeviceArray:
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
batch_size, vocab_size = scores.shape
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
@@ -212,9 +203,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
def __init__(self, bos_token_id: int):
self.bos_token_id = bos_token_id
- def __call__(
- self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
- ) -> jax_xla.DeviceArray:
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
new_scores = jnp.full(scores.shape, -float("inf"))
apply_penalty = 1 - jnp.bool_(cur_len - 1)
@@ -242,9 +231,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
self.max_length = max_length
self.eos_token_id = eos_token_id
- def __call__(
- self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
- ) -> jax_xla.DeviceArray:
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
new_scores = jnp.full(scores.shape, -float("inf"))
apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
@@ -277,9 +264,7 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
self.min_length = min_length
self.eos_token_id = eos_token_id
- def __call__(
- self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
- ) -> jax_xla.DeviceArray:
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
# create boolean flag to decide if min length penalty should be applied
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
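Every processor and warper now shares the jnp.ndarray-annotated __call__(input_ids, scores, cur_len) interface shown above, so a custom one is a small subclass. A hedged sketch (the class name FlaxFixedTemperatureWarper is invented here, and the import path assumes the module shown in this file's diff):

import jax.numpy as jnp

from transformers.generation_flax_logits_process import FlaxLogitsWarper

class FlaxFixedTemperatureWarper(FlaxLogitsWarper):
    # Hypothetical warper: rescales logits by a constant temperature.
    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        return scores / self.temperature

warper = FlaxFixedTemperatureWarper(2.0)
scaled = warper(jnp.zeros((1, 4), dtype="i4"), jnp.ones((1, 4)), cur_len=1)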

src/transformers/generation_flax_utils.py

@@ -23,7 +23,6 @@ import numpy as np
import flax
import jax
import jax.numpy as jnp
- import jaxlib.xla_extension as jax_xla
from jax import lax
from .file_utils import ModelOutput
@@ -49,11 +48,11 @@ class FlaxGreedySearchOutput(ModelOutput):
Args:
- sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
+ sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`):
The generated sequences.
"""
- sequences: jax_xla.DeviceArray = None
+ sequences: jnp.ndarray = None
@flax.struct.dataclass
@@ -63,11 +62,11 @@ class FlaxSampleOutput(ModelOutput):
Args:
- sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
+ sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`):
The generated sequences.
"""
- sequences: jax_xla.DeviceArray = None
+ sequences: jnp.ndarray = None
@flax.struct.dataclass
@@ -77,44 +76,44 @@ class FlaxBeamSearchOutput(ModelOutput):
Args:
- sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
+ sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`):
The generated sequences.
- scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size,)`):
+ scores (:obj:`jnp.ndarray` of shape :obj:`(batch_size,)`):
The scores (log probabilites) of the generated sequences.
"""
- sequences: jax_xla.DeviceArray = None
- scores: jax_xla.DeviceArray = None
+ sequences: jnp.ndarray = None
+ scores: jnp.ndarray = None
@flax.struct.dataclass
class GreedyState:
- cur_len: jax_xla.DeviceArray
- sequences: jax_xla.DeviceArray
- running_token: jax_xla.DeviceArray
- is_sent_finished: jax_xla.DeviceArray
- model_kwargs: Dict[str, jax_xla.DeviceArray]
+ cur_len: jnp.ndarray
+ sequences: jnp.ndarray
+ running_token: jnp.ndarray
+ is_sent_finished: jnp.ndarray
+ model_kwargs: Dict[str, jnp.ndarray]
@flax.struct.dataclass
class SampleState:
- cur_len: jax_xla.DeviceArray
- sequences: jax_xla.DeviceArray
- running_token: jax_xla.DeviceArray
- is_sent_finished: jax_xla.DeviceArray
- prng_key: jax_xla.DeviceArray
- model_kwargs: Dict[str, jax_xla.DeviceArray]
+ cur_len: jnp.ndarray
+ sequences: jnp.ndarray
+ running_token: jnp.ndarray
+ is_sent_finished: jnp.ndarray
+ prng_key: jnp.ndarray
+ model_kwargs: Dict[str, jnp.ndarray]
@flax.struct.dataclass
class BeamSearchState:
- cur_len: jax_xla.DeviceArray
- running_sequences: jax_xla.DeviceArray
- running_scores: jax_xla.DeviceArray
- sequences: jax_xla.DeviceArray
- scores: jax_xla.DeviceArray
- is_sent_finished: jax_xla.DeviceArray
- model_kwargs: Dict[str, jax_xla.DeviceArray]
+ cur_len: jnp.ndarray
+ running_sequences: jnp.ndarray
+ running_scores: jnp.ndarray
+ sequences: jnp.ndarray
+ scores: jnp.ndarray
+ is_sent_finished: jnp.ndarray
+ model_kwargs: Dict[str, jnp.ndarray]
class FlaxGenerationMixin:
@@ -156,14 +155,14 @@ class FlaxGenerationMixin:
def generate(
self,
- input_ids: jax_xla.DeviceArray,
+ input_ids: jnp.ndarray,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
do_sample: Optional[bool] = None,
- prng_key: Optional[jax_xla.DeviceArray] = None,
+ prng_key: Optional[jnp.ndarray] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
@@ -175,7 +174,7 @@ class FlaxGenerationMixin:
length_penalty: Optional[float] = None,
early_stopping: Optional[bool] = None,
trace: bool = True,
- params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
+ params: Optional[Dict[str, jnp.ndarray]] = None,
**model_kwargs,
):
r"""
@@ -191,7 +190,7 @@ class FlaxGenerationMixin:
Parameters:
- input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
@@ -217,7 +216,7 @@ class FlaxGenerationMixin:
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
a considerably slower runtime.
- params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
+ params (:obj:`Dict[str, jnp.ndarray]`, `optional`):
Optionally the model parameters can be passed. Can be useful for parallelized generation.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
@@ -395,8 +394,8 @@ class FlaxGenerationMixin:
eos_token_id: Optional[int] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True,
- params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
- model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
+ params: Optional[Dict[str, jnp.ndarray]] = None,
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
):
# init values
max_length = max_length if max_length is not None else self.config.max_length
@@ -479,12 +478,12 @@ class FlaxGenerationMixin:
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
- prng_key: Optional[jax_xla.DeviceArray] = None,
+ prng_key: Optional[jnp.ndarray] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
logits_warper: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True,
- params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
- model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
+ params: Optional[Dict[str, jnp.ndarray]] = None,
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
):
# init values
max_length = max_length if max_length is not None else self.config.max_length
@@ -580,8 +579,8 @@ class FlaxGenerationMixin:
early_stopping: Optional[bool] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True,
- params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
- model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
+ params: Optional[Dict[str, jnp.ndarray]] = None,
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
):
"""
This beam search function is heavily inspired by Flax's official example:
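For callers of generate() none of this changes behavior: prng_key was always an ordinary JAX PRNG key array and params an ordinary parameter pytree; only the documented type names move to the public alias. A usage sketch, assuming a Flax GPT-2 checkpoint is available from the Hub:

import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")

input_ids = jnp.asarray(tokenizer("Hello", return_tensors="np").input_ids)
output = model.generate(
    input_ids,
    max_length=20,
    do_sample=True,
    prng_key=jax.random.PRNGKey(0),  # a jnp.ndarray under the new annotation
)
print(tokenizer.batch_decode(output.sequences, skip_special_tokens=True))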

src/transformers/modeling_flax_outputs.py

@@ -14,7 +14,7 @@
from typing import Dict, Optional, Tuple
import flax
- import jaxlib.xla_extension as jax_xla
+ import jax.numpy as jnp
from .file_utils import ModelOutput
@@ -25,24 +25,24 @@ class FlaxBaseModelOutput(ModelOutput):
Base class for model's outputs, with potential hidden states and attentions.
Args:
- last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+ last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
- hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
- last_hidden_state: jax_xla.DeviceArray = None
- hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+ last_hidden_state: jnp.ndarray = None
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
@@ -51,28 +51,28 @@ class FlaxBaseModelOutputWithPast(ModelOutput):
Base class for model's outputs, with potential hidden states and attentions.
Args:
- last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+ last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
- past_key_values (:obj:`Dict[str, jax_xla.DeviceArray]`):
+ past_key_values (:obj:`Dict[str, jnp.ndarray]`):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
- hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
- last_hidden_state: jax_xla.DeviceArray = None
- past_key_values: Optional[Dict[str, jax_xla.DeviceArray]] = None
- hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+ last_hidden_state: jnp.ndarray = None
+ past_key_values: Optional[Dict[str, jnp.ndarray]] = None
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
@@ -81,29 +81,29 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
- last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+ last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
- pooler_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`):
+ pooler_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
prediction (classification) objective during pretraining.
- hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
- last_hidden_state: jax_xla.DeviceArray = None
- pooler_output: jax_xla.DeviceArray = None
- hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+ last_hidden_state: jnp.ndarray = None
+ pooler_output: jnp.ndarray = None
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
@@ -112,44 +112,44 @@ class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
Args:
- last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+ last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
1, hidden_size)` is output.
- past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
- Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
- tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+ past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+ Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
+ shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
``config.is_encoder_decoder=True`` 2 additional tensors of shape :obj:`(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
``config.is_encoder_decoder=True`` in the cross-attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding.
- hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
- cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
"""
- last_hidden_state: jax_xla.DeviceArray = None
- past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
- hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
- cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+ last_hidden_state: jnp.ndarray = None
+ past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ attentions: Optional[Tuple[jnp.ndarray]] = None
+ cross_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
@@ -159,58 +159,58 @@ class FlaxSeq2SeqModelOutput(ModelOutput):
decoding.
Args:
- last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+ last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
1, hidden_size)` is output.
- past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
- Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
- tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
- tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+ Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
+ shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
- decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
- decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
- cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
- encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
- encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
- encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
- last_hidden_state: jax_xla.DeviceArray = None
- past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
- decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
- cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
- encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
- encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+ last_hidden_state: jnp.ndarray = None
+ past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
+ decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
+ cross_attentions: Optional[Tuple[jnp.ndarray]] = None
+ encoder_last_hidden_state: Optional[jnp.ndarray] = None
+ encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
@@ -219,39 +219,39 @@ class FlaxCausalLMOutputWithCrossAttentions(ModelOutput):
Base class for causal language model (or autoregressive) outputs.
Args:
- logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+ logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
- cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Cross attentions weights after the attention softmax, used to compute the weighted average in the
cross-attention heads.
- past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
- Tuple of :obj:`jax_xla.DeviceArray` tuples of length :obj:`config.n_layers`, with each tuple containing the
- cached key, value states of the self-attention and the cross-attention layers if model is used in
- encoder-decoder setting. Only relevant if ``config.is_decoder = True``.
+ past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+ Tuple of :obj:`jnp.ndarray` tuples of length :obj:`config.n_layers`, with each tuple containing the cached
+ key, value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
+ setting. Only relevant if ``config.is_decoder = True``.
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding.
"""
- logits: jax_xla.DeviceArray = None
- past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
- hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
- cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+ logits: jnp.ndarray = None
+ past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ attentions: Optional[Tuple[jnp.ndarray]] = None
+ cross_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
@@ -260,24 +260,24 @@ class FlaxMaskedLMOutput(ModelOutput):
Base class for masked language models outputs.
Args:
- logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+ logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
- logits: jax_xla.DeviceArray = None
- hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+ logits: jnp.ndarray = None
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ attentions: Optional[Tuple[jnp.ndarray]] = None
FlaxCausalLMOutput = FlaxMaskedLMOutput
@@ -289,55 +289,55 @@ class FlaxSeq2SeqLMOutput(ModelOutput):
Base class for sequence-to-sequence language models outputs.
Args:
- logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+ logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
- Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
- tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
- tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+ Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
+ shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
- decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
- decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
- cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
- encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
- encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
- encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
- logits: jax_xla.DeviceArray = None
- past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
- decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
- cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
- encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
- encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+ logits: jnp.ndarray = None
+ past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
+ decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
+ cross_attentions: Optional[Tuple[jnp.ndarray]] = None
+ encoder_last_hidden_state: Optional[jnp.ndarray] = None
+ encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
@@ -346,25 +346,25 @@ class FlaxNextSentencePredictorOutput(ModelOutput):
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
- logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
+ logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
- hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
- logits: jax_xla.DeviceArray = None
- hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+ logits: jnp.ndarray = None
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
@@ -373,24 +373,24 @@ class FlaxSequenceClassifierOutput(ModelOutput):
Base class for outputs of sentence classification models.
Args:
- logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`):
+ logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
- hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
- logits: jax_xla.DeviceArray = None
- hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+ logits: jnp.ndarray = None
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
@@ -399,55 +399,55 @@ class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput):
Base class for outputs of sequence-to-sequence sentence classification models.
Args:
- logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`):
+ logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
- past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
- Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
- tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
- tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+ Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
+ shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
- decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
- decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
- cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
- encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
- encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
- encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
- logits: jax_xla.DeviceArray = None
- past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
- decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
- cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
- encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
- encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
- encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+ logits: jnp.ndarray = None
+ past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
+ decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
+ cross_attentions: Optional[Tuple[jnp.ndarray]] = None
+ encoder_last_hidden_state: Optional[jnp.ndarray] = None
+ encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
+ encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
@ -456,26 +456,26 @@ class FlaxMultipleChoiceModelOutput(ModelOutput):
Base class for outputs of multiple choice models.
Args:
- logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, num_choices)`):
+ logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
- hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+ shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
- sequence_length, sequence_length)`.
+ attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+ sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
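Consuming this output only takes an argmax over the choice dimension; a sketch with made-up scores::

    import jax.numpy as jnp

    logits = jnp.array([[0.2, 1.3, -0.5],   # (batch_size=2, num_choices=3)
                        [2.0, 0.1,  0.4]])
    predicted_choice = jnp.argmax(logits, axis=-1)  # [1, 0]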
@flax.struct.dataclass
@ -484,24 +484,24 @@ class FlaxTokenClassifierOutput(ModelOutput):
Base class for outputs of token classification models.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
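A sketch of mapping these per-token scores to label ids (the label set here is hypothetical)::

    import jax.numpy as jnp

    logits = jnp.zeros((1, 4, 3)).at[:, :, 1].set(2.0)  # (batch, seq_len, num_labels)
    label_ids = jnp.argmax(logits, axis=-1)             # (batch, seq_len), all 1 here
    id2label = {0: "O", 1: "B-ENT", 2: "I-ENT"}         # hypothetical mapping
    labels = [id2label[int(i)] for i in label_ids[0]]   # ['B-ENT', 'B-ENT', 'B-ENT', 'B-ENT']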
@flax.struct.dataclass
@ -510,27 +510,27 @@ class FlaxQuestionAnsweringModelOutput(ModelOutput):
Base class for outputs of question answering models.
Args:
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-start scores (before SoftMax).
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
start_logits: jax_xla.DeviceArray = None
end_logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
start_logits: jnp.ndarray = None
end_logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
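A sketch of decoding a greedy answer span from these two fields (scores are made up; real decoding also enforces start <= end)::

    import jax.numpy as jnp

    start_logits = jnp.array([[0.1, 3.0, 0.2, 0.1]])   # (batch, seq_len)
    end_logits = jnp.array([[0.0, 0.1, 0.2, 2.5]])
    start = int(jnp.argmax(start_logits, axis=-1)[0])  # 1
    end = int(jnp.argmax(end_logits, axis=-1)[0])      # 3
    answer_span = (start, end)                         # token indices, inclusive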
@flax.struct.dataclass
@ -539,55 +539,55 @@ class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
Base class for outputs of sequence-to-sequence question answering models.
Args:
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-start scores (before SoftMax).
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-end scores (before SoftMax).
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
start_logits: jax_xla.DeviceArray = None
end_logits: jax_xla.DeviceArray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
start_logits: jnp.ndarray = None
end_logits: jnp.ndarray = None
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
encoder_last_hidden_state: Optional[jnp.ndarray] = None
encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
View File
@ -21,7 +21,6 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
@ -61,28 +60,28 @@ class FlaxBertForPreTrainingOutput(ModelOutput):
Output type of :class:`~transformers.BertForPreTraining`.
Args:
prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
prediction_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
seq_relationship_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
prediction_logits: jax_xla.DeviceArray = None
seq_relationship_logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
prediction_logits: jnp.ndarray = None
seq_relationship_logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
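Both heads can be read off directly; a sketch with dummy arrays standing in for real model outputs (shapes follow the docstring above)::

    import jax.numpy as jnp

    prediction_logits = jnp.zeros((1, 6, 30522))       # (batch, seq_len, vocab_size)
    seq_relationship_logits = jnp.array([[0.3, 1.2]])  # (batch, 2)

    mlm_token_ids = jnp.argmax(prediction_logits, axis=-1)            # vocab id per position
    nsp_class = int(jnp.argmax(seq_relationship_logits, axis=-1)[0])  # whichever class scored higher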
BERT_START_DOCSTRING = r"""
View File
@ -21,7 +21,6 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
@ -59,28 +58,28 @@ class FlaxBigBirdForPreTrainingOutput(ModelOutput):
Output type of :class:`~transformers.BigBirdForPreTraining`.
Args:
prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
prediction_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
seq_relationship_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
prediction_logits: jax_xla.DeviceArray = None
seq_relationship_logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
prediction_logits: jnp.ndarray = None
seq_relationship_logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
@ -89,30 +88,30 @@ class FlaxBigBirdForQuestionAnsweringModelOutput(ModelOutput):
Base class for outputs of question answering models.
Args:
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-start scores (before SoftMax).
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-end scores (before SoftMax).
pooled_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`):
pooled_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`):
The pooled output returned by :class:`~transformers.FlaxBigBirdModel`.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
start_logits: jax_xla.DeviceArray = None
end_logits: jax_xla.DeviceArray = None
pooled_output: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
start_logits: jnp.ndarray = None
end_logits: jnp.ndarray = None
pooled_output: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
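The extra pooled_output field is one vector per example; a sketch of reusing it downstream (the probe vector is hypothetical)::

    import jax.numpy as jnp

    pooled_output = jnp.ones((2, 768))        # (batch_size, hidden_size)
    probe = jnp.zeros((768,)).at[0].set(1.0)  # hypothetical probe direction
    scores = pooled_output @ probe            # (batch_size,) similarity scores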
BIG_BIRD_START_DOCSTRING = r"""
View File
@ -19,7 +19,6 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
@ -156,16 +155,16 @@ CLIP_INPUTS_DOCSTRING = r"""
class FlaxCLIPOutput(ModelOutput):
"""
Args:
logits_per_image:(:obj:`jax_xla.DeviceArray` of shape :obj:`(image_batch_size, text_batch_size)`):
logits_per_image (:obj:`jnp.ndarray` of shape :obj:`(image_batch_size, text_batch_size)`):
The scaled dot product scores between :obj:`image_embeds` and :obj:`text_embeds`. This represents the
image-text similarity scores.
logits_per_text:(:obj:`jax_xla.DeviceArray` of shape :obj:`(text_batch_size, image_batch_size)`):
logits_per_text (:obj:`jnp.ndarray` of shape :obj:`(text_batch_size, image_batch_size)`):
The scaled dot product scores between :obj:`text_embeds` and :obj:`image_embeds`. This represents the
text-image similarity scores.
text_embeds(:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`):
text_embeds (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`):
The text embeddings obtained by applying the projection layer to the pooled output of
:class:`~transformers.FlaxCLIPTextModel`.
image_embeds(:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`):
image_embeds (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`):
The image embeddings obtained by applying the projection layer to the pooled output of
:class:`~transformers.FlaxCLIPVisionModel`.
text_model_output (:obj:`FlaxBaseModelOutputWithPooling`):
@ -174,10 +173,10 @@ class FlaxCLIPOutput(ModelOutput):
The output of the :class:`~transformers.FlaxCLIPVisionModel`.
"""
logits_per_image: jax_xla.DeviceArray = None
logits_per_text: jax_xla.DeviceArray = None
text_embeds: jax_xla.DeviceArray = None
image_embeds: jax_xla.DeviceArray = None
logits_per_image: jnp.ndarray = None
logits_per_text: jnp.ndarray = None
text_embeds: jnp.ndarray = None
image_embeds: jnp.ndarray = None
text_model_output: FlaxBaseModelOutputWithPooling = None
vision_model_output: FlaxBaseModelOutputWithPooling = None
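A sketch of turning these similarity scores into per-image probabilities over the text batch (the scores are made up)::

    import jax
    import jax.numpy as jnp

    logits_per_image = jnp.array([[21.3, 18.0, 15.7]])  # (image_batch, text_batch)
    probs = jax.nn.softmax(logits_per_image, axis=-1)   # rows sum to 1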
@ -801,8 +800,8 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
`What are input IDs? <../glossary.html#input-ids>`__
Returns:
text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings
obtained by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`.
text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`): The text embeddings obtained by
applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`.
Examples::
@ -855,9 +854,8 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
:meth:`transformers.CLIPFeatureExtractor.__call__` for details.
Returns:
image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings
obtained by applying the projection layer to the pooled output of
:class:`~transformers.FlaxCLIPVisionModel`
image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`): The image embeddings obtained
by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPVisionModel`.
Examples::
View File
@ -21,7 +21,6 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
@ -60,24 +59,24 @@ class FlaxElectraForPreTrainingOutput(ModelOutput):
Output type of :class:`~transformers.ElectraForPreTraining`.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
ELECTRA_START_DOCSTRING = r"""
View File
@ -44,7 +44,6 @@ if is_flax_available():
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict
from transformers import (
@ -127,7 +126,7 @@ class FlaxModelTesterMixin:
if "ForMultipleChoice" in model_class.__name__:
inputs_dict = {
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
if isinstance(v, (jnp.ndarray, np.ndarray))
else v
for k, v in inputs_dict.items()
}
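The private jax_xla import can be dropped here because concrete jax arrays already satisfy the public :obj:`jnp.ndarray` check that replaces it; a standalone sketch of the same broadcast, assuming only jax and numpy::

    import jax.numpy as jnp
    import numpy as np

    num_choices = 4
    v = jnp.ones((2, 7))  # (batch, seq_len); isinstance(v, jnp.ndarray) is True
    expanded = jnp.broadcast_to(v[:, None], (v.shape[0], num_choices, v.shape[-1]))
    assert expanded.shape == (2, 4, 7)
    assert isinstance(np.ones(3), np.ndarray)  # numpy inputs take the same branch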