[Flax] Add docstrings & model outputs (#11498)

* add attentions & hidden states

* add model outputs + docs

* finish docs

* finish tests

* finish impl

* del @

* finish

* finish

* correct test

* apply Sylvain's suggestions

* Update src/transformers/models/bert/modeling_flax_bert.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* simplify more

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Patrick von Platen 2021-04-29 12:04:51 +02:00 committed by GitHub
parent 3f6add8bab
commit f748bd4242
7 changed files with 1130 additions and 90 deletions
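
Taken together, the diff below makes the Flax models return typed ModelOutput dataclasses instead of bare tuples. A rough usage sketch, not part of the commit, assuming a BERT checkpoint with Flax weights is reachable:

from transformers import BertTokenizer, FlaxBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

# New default: a FlaxBaseModelOutputWithPooling dataclass with named fields.
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
print(outputs.pooler_output.shape)

# return_dict=False keeps the previous tuple-style return value.
last_hidden_state, pooler_output = model(**inputs, return_dict=False)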


@ -794,6 +794,17 @@ PT_CAUSAL_LM_SAMPLE = r"""
>>> logits = outputs.logits
"""
PT_SAMPLE_DOCSTRINGS = {
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
"TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
"MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
"MaskedLM": PT_MASKED_LM_SAMPLE,
"LMHead": PT_CAUSAL_LM_SAMPLE,
"BaseModel": PT_BASE_MODEL_SAMPLE,
}
TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
@ -915,30 +926,148 @@ TF_CAUSAL_LM_SAMPLE = r"""
>>> logits = outputs.logits
"""
TF_SAMPLE_DOCSTRINGS = {
"SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE,
"TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE,
"MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE,
"MaskedLM": TF_MASKED_LM_SAMPLE,
"LMHead": TF_CAUSAL_LM_SAMPLE,
"BaseModel": TF_BASE_MODEL_SAMPLE,
}
FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
FLAX_QUESTION_ANSWERING_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors='jax')
>>> outputs = model(**inputs)
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits
"""
FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
FLAX_MASKED_LM_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors='jax')
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
FLAX_BASE_MODEL_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
"""
FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='jax', padding=True)
>>> outputs = model(**{{k: v[None, :] for k,v in encoding.items()}})
>>> logits = outputs.logits
"""
FLAX_SAMPLE_DOCSTRINGS = {
"SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
"TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE,
"MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
"MaskedLM": FLAX_MASKED_LM_SAMPLE,
"BaseModel": FLAX_BASE_MODEL_SAMPLE,
}
def add_code_sample_docstrings(
*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None
*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None, model_cls=None
):
def docstring_decorator(fn):
model_class = fn.__qualname__.split(".")[0]
is_tf_class = model_class[:2] == "TF"
# model_class defaults to function's class if not specified otherwise
model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
if model_class[:2] == "TF":
sample_docstrings = TF_SAMPLE_DOCSTRINGS
elif model_class[:4] == "Flax":
sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
else:
sample_docstrings = PT_SAMPLE_DOCSTRINGS
doc_kwargs = dict(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
if "SequenceClassification" in model_class:
code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE
code_sample = sample_docstrings["SequenceClassification"]
elif "QuestionAnswering" in model_class:
code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE
code_sample = sample_docstrings["QuestionAnswering"]
elif "TokenClassification" in model_class:
code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE
code_sample = sample_docstrings["TokenClassification"]
elif "MultipleChoice" in model_class:
code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE
code_sample = sample_docstrings["MultipleChoice"]
elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
doc_kwargs["mask"] = "[MASK]" if mask is None else mask
code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
code_sample = sample_docstrings["MaskedLM"]
elif "LMHead" in model_class or "CausalLM" in model_class:
code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
code_sample = sample_docstrings["LMHead"]
elif "Model" in model_class or "Encoder" in model_class:
code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
code_sample = sample_docstrings["BaseModel"]
else:
raise ValueError(f"Docstring can't be built for model {model_class}")
@ -1462,7 +1591,10 @@ def tf_required(func):
def is_tensor(x):
"""Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`."""
"""
Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, :obj:`jaxlib.xla_extension.DeviceArray` or
:obj:`np.ndarray`.
"""
if is_torch_available():
import torch
@ -1473,6 +1605,14 @@ def is_tensor(x):
if isinstance(x, tf.Tensor):
return True
if is_flax_available():
import jaxlib.xla_extension as jax_xla
from jax.interpreters.partial_eval import DynamicJaxprTracer
if isinstance(x, (jax_xla.DeviceArray, DynamicJaxprTracer)):
return True
return isinstance(x, np.ndarray)
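
A quick check of the extended helper, assuming jax and numpy are installed (the torch/TF branches stay guarded by is_torch_available()/is_tf_available()):

import numpy as np
import jax.numpy as jnp

from transformers.file_utils import is_tensor

print(is_tensor(jnp.ones((2, 3))))  # True: jax DeviceArrays (and tracers under jit) are now recognised
print(is_tensor(np.zeros((2, 3))))  # True: unchanged numpy fallback
print(is_tensor([1, 2, 3]))         # False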


@ -0,0 +1,239 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple
import jaxlib.xla_extension as jax_xla
from .file_utils import ModelOutput
@dataclass
class FlaxBaseModelOutput(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
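
Because these classes subclass ModelOutput, each instance can be read as a dataclass, a dict, or a tuple, and fields left at None are dropped from the tuple view; a small illustration with dummy arrays (imports assume the module path introduced by this commit):

import jax.numpy as jnp

from transformers.modeling_flax_outputs import FlaxBaseModelOutput

# Dummy value only; the shape mirrors (batch_size, sequence_length, hidden_size).
output = FlaxBaseModelOutput(last_hidden_state=jnp.zeros((1, 5, 8)))

output.last_hidden_state     # attribute access
output["last_hidden_state"]  # dict-style access
output[0]                    # tuple-style access
output.to_tuple()            # (last_hidden_state,) -- hidden_states/attentions are None and omitted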
@dataclass
class FlaxBaseModelOutputWithPooling(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
prediction (classification) objective during pretraining.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: jax_xla.DeviceArray = None
pooler_output: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxMaskedLMOutput(ModelOutput):
"""
Base class for masked language models outputs.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxNextSentencePredictorOutput(ModelOutput):
"""
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxMultipleChoiceModelOutput(ModelOutput):
"""
Base class for outputs of multiple choice models.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxTokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxQuestionAnsweringModelOutput(ModelOutput):
"""
Base class for outputs of question answering models.
Args:
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
Span-start scores (before SoftMax).
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
start_logits: jax_xla.DeviceArray = None
end_logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None


@ -32,12 +32,14 @@ from .file_utils import (
FLAX_WEIGHTS_NAME,
WEIGHTS_NAME,
PushToHubMixin,
add_code_sample_docstrings,
add_start_docstrings_to_model_forward,
cached_path,
copy_func,
hf_bucket_url,
is_offline_mode,
is_remote_url,
replace_return_docstrings,
)
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import logging
@ -432,3 +434,22 @@ def overwrite_call_docstring(model_class, docstring):
model_class.__call__.__doc__ = None
# set correct docstring
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
def append_call_sample_docstring(model_class, tokenizer_class, checkpoint, output_type, config_class, mask=None):
model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = add_code_sample_docstrings(
tokenizer_class=tokenizer_class,
checkpoint=checkpoint,
output_type=output_type,
config_class=config_class,
model_cls=model_class.__name__,
)(model_class.__call__)
def append_replace_return_docstrings(model_class, output_type, config_class):
model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = replace_return_docstrings(
output_type=output_type,
config_class=config_class,
)(model_class.__call__)


@ -13,30 +13,79 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Tuple
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict
from flax.linen import dot_product_attention
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring
from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPooling,
FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput,
FlaxNextSentencePredictorOutput,
FlaxQuestionAnsweringModelOutput,
FlaxSequenceClassifierOutput,
FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import logging
from .configuration_bert import BertConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"
@dataclass
class FlaxBertForPreTrainingOutput(ModelOutput):
"""
Output type of :class:`~transformers.FlaxBertForPreTraining`.
Args:
prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
prediction_logits: jax_xla.DeviceArray = None
seq_relationship_logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
BERT_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
@ -166,7 +215,7 @@ class FlaxBertSelfAttention(nn.Module):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
def __call__(self, hidden_states, attention_mask, deterministic=True):
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
@ -208,7 +257,12 @@ class FlaxBertSelfAttention(nn.Module):
precision=None,
)
return attn_output.reshape(attn_output.shape[:2] + (-1,))
outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
# TODO: at the moment it's not possible to retrieve attn_weights from
# dot_product_attention, but should be in the future -> add functionality then
return outputs
class FlaxBertSelfOutput(nn.Module):
@ -239,13 +293,22 @@ class FlaxBertAttention(nn.Module):
self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype)
self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic=True):
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
attn_outputs = self.self(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
attn_output = attn_outputs[0]
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
return hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += attn_outputs[1]
return outputs
class FlaxBertIntermediate(nn.Module):
@ -295,11 +358,20 @@ class FlaxBertLayer(nn.Module):
self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
self.output = FlaxBertOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
attention_outputs = self.attention(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
attention_output = attention_outputs[0]
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
return hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attention_outputs[1],)
return outputs
class FlaxBertLayerCollection(nn.Module):
@ -311,10 +383,40 @@ class FlaxBertLayerCollection(nn.Module):
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
return hidden_states
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states,)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class FlaxBertEncoder(nn.Module):
@ -324,8 +426,23 @@ class FlaxBertEncoder(nn.Module):
def setup(self):
self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.layer(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class FlaxBertPooler(nn.Module):
@ -456,7 +573,21 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if output_attentions:
raise NotImplementedError(
"Currently attention scores cannot be returned. Please set `output_attentions` to False for now."
)
# init input tensors if not passed
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
@ -479,6 +610,9 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
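
A short sketch of the new call-time switches (same checkpoint as the code samples above): hidden states can be requested per call or through the config, while attentions are rejected until Flax's dot_product_attention exposes the weights:

from transformers import BertTokenizer, FlaxBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

outputs = model(**inputs, output_hidden_states=True)
print(len(outputs.hidden_states))  # embedding output + one entry per layer

try:
    model(**inputs, output_attentions=True)
except NotImplementedError as err:
    print(err)  # "Currently attention scores cannot be returned..."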
@ -493,17 +627,43 @@ class FlaxBertModule(nn.Module):
self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)
hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
outputs = self.encoder(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
if not self.add_pooling_layer:
return hidden_states
if not return_dict:
# if pooled is None, don't return it
if pooled is None:
return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:]
pooled = self.pooler(hidden_states)
return hidden_states, pooled
return FlaxBaseModelOutputWithPooling(
last_hidden_state=hidden_states,
pooler_output=pooled,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -514,6 +674,11 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
module_class = FlaxBertModule
append_call_sample_docstring(
FlaxBertModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC
)
class FlaxBertForPreTrainingModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
@ -523,11 +688,27 @@ class FlaxBertForPreTrainingModule(nn.Module):
self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
hidden_states, pooled_output = self.bert(
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
outputs = self.bert(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.tie_word_embeddings:
@ -535,11 +716,22 @@ class FlaxBertForPreTrainingModule(nn.Module):
else:
shared_embedding = None
hidden_states = outputs[0]
pooled_output = outputs[1]
prediction_scores, seq_relationship_score = self.cls(
hidden_states, pooled_output, shared_embedding=shared_embedding
)
return (prediction_scores, seq_relationship_score)
if not return_dict:
return (prediction_scores, seq_relationship_score) + outputs[2:]
return FlaxBertForPreTrainingOutput(
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -553,6 +745,32 @@ class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
module_class = FlaxBertForPreTrainingModule
FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """
Returns:
Example::
>>> from transformers import BertTokenizer, FlaxBertForPreTraining
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = FlaxBertForPreTraining.from_pretrained('bert-base-uncased')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.prediction_logits
>>> seq_relationship_logits = outputs.seq_relationship_logits
"""
overwrite_call_docstring(
FlaxBertForPreTraining,
BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
FlaxBertForPreTraining, output_type=FlaxBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)
class FlaxBertForMaskedLMModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
@ -562,11 +780,29 @@ class FlaxBertForMaskedLMModule(nn.Module):
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
outputs = self.bert(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
@ -575,7 +811,14 @@ class FlaxBertForMaskedLMModule(nn.Module):
# Compute the prediction scores
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
return (logits,)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxMaskedLMOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
@ -583,6 +826,11 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
module_class = FlaxBertForMaskedLMModule
append_call_sample_docstring(
FlaxBertForMaskedLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC
)
class FlaxBertForNextSentencePredictionModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
@ -592,15 +840,41 @@ class FlaxBertForNextSentencePredictionModule(nn.Module):
self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Model
_, pooled_output = self.bert(
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
outputs = self.bert(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
seq_relationship_scores = self.cls(pooled_output)
return (seq_relationship_scores,)
if not return_dict:
return (seq_relationship_scores,) + outputs[2:]
return FlaxNextSentencePredictorOutput(
logits=seq_relationship_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -611,6 +885,35 @@ class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
module_class = FlaxBertForNextSentencePredictionModule
FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING = """
Returns:
Example::
>>> from transformers import BertTokenizer, FlaxBertForNextSentencePrediction
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = FlaxBertForNextSentencePrediction.from_pretrained('bert-base-uncased')
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='jax')
>>> outputs = model(**encoding)
>>> logits = outputs.logits
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
"""
overwrite_call_docstring(
FlaxBertForNextSentencePrediction,
BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING,
)
append_replace_return_docstrings(
FlaxBertForNextSentencePrediction, output_type=FlaxNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
)
class FlaxBertForSequenceClassificationModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
@ -624,17 +927,40 @@ class FlaxBertForSequenceClassificationModule(nn.Module):
)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
_, pooled_output = self.bert(
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
outputs = self.bert(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
logits = self.classifier(pooled_output)
return (logits,)
if not return_dict:
return (logits,) + outputs[2:]
return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -648,6 +974,15 @@ class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
module_class = FlaxBertForSequenceClassificationModule
append_call_sample_docstring(
FlaxBertForSequenceClassification,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxSequenceClassifierOutput,
_CONFIG_FOR_DOC,
)
class FlaxBertForMultipleChoiceModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
@ -658,7 +993,15 @@ class FlaxBertForMultipleChoiceModule(nn.Module):
self.classifier = nn.Dense(1, dtype=self.dtype)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
num_choices = input_ids.shape[1]
input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
@ -667,16 +1010,31 @@ class FlaxBertForMultipleChoiceModule(nn.Module):
position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
# Model
_, pooled_output = self.bert(
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
outputs = self.bert(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
logits = self.classifier(pooled_output)
reshaped_logits = logits.reshape(-1, num_choices)
return (reshaped_logits,)
if not return_dict:
return (reshaped_logits,) + outputs[2:]
return FlaxMultipleChoiceModelOutput(
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -690,10 +1048,12 @@ class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
module_class = FlaxBertForMultipleChoiceModule
# adapt docstring slightly for FlaxBertForMultipleChoice
overwrite_call_docstring(
FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
FlaxBertForMultipleChoice, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMultipleChoiceModelOutput, _CONFIG_FOR_DOC
)
class FlaxBertForTokenClassificationModule(nn.Module):
@ -706,15 +1066,40 @@ class FlaxBertForTokenClassificationModule(nn.Module):
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
outputs = self.bert(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.classifier(hidden_states)
return (logits,)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxTokenClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -728,6 +1113,11 @@ class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
module_class = FlaxBertForTokenClassificationModule
append_call_sample_docstring(
FlaxBertForTokenClassification, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxTokenClassifierOutput, _CONFIG_FOR_DOC
)
class FlaxBertForQuestionAnsweringModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
@ -737,17 +1127,44 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
outputs = self.bert(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
return (start_logits, end_logits)
if not return_dict:
return (start_logits, end_logits) + outputs[1:]
return FlaxQuestionAnsweringModelOutput(
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -759,3 +1176,12 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
)
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
module_class = FlaxBertForQuestionAnsweringModule
append_call_sample_docstring(
FlaxBertForQuestionAnswering,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
)


@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from typing import Optional, Tuple
import flax.linen as nn
import jax
@ -23,13 +23,15 @@ from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import logging
from .configuration_roberta import RobertaConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "roberta-base"
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
@ -181,7 +183,7 @@ class FlaxRobertaSelfAttention(nn.Module):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
def __call__(self, hidden_states, attention_mask, deterministic=True):
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
@ -223,7 +225,12 @@ class FlaxRobertaSelfAttention(nn.Module):
precision=None,
)
return attn_output.reshape(attn_output.shape[:2] + (-1,))
outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
# TODO: at the moment it's not possible to retrieve attn_weights from
# dot_product_attention, but should be in the future -> add functionality then
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
@ -256,13 +263,22 @@ class FlaxRobertaAttention(nn.Module):
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic=True):
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
attn_outputs = self.self(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
attn_output = attn_outputs[0]
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
return hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += attn_outputs[1]
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
@ -315,11 +331,20 @@ class FlaxRobertaLayer(nn.Module):
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
attention_outputs = self.attention(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
attention_output = attention_outputs[0]
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
return hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attention_outputs[1],)
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
@ -332,10 +357,40 @@ class FlaxRobertaLayerCollection(nn.Module):
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
return hidden_states
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states,)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
@ -346,8 +401,23 @@ class FlaxRobertaEncoder(nn.Module):
def setup(self):
self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.layer(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
@ -412,7 +482,21 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if output_attentions:
raise NotImplementedError(
"Currently attention scores cannot be returned." "Please set `output_attentions` to False for now."
)
# init input tensors if not passed
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
@ -435,6 +519,9 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
@ -450,17 +537,43 @@ class FlaxRobertaModule(nn.Module):
self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)
hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
outputs = self.encoder(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
if not self.add_pooling_layer:
return hidden_states
if not return_dict:
# if pooled is None, don't return it
if pooled is None:
return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:]
pooled = self.pooler(hidden_states)
return hidden_states, pooled
return FlaxBaseModelOutputWithPooling(
last_hidden_state=hidden_states,
pooler_output=pooled,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -469,3 +582,8 @@ class FlaxRobertaModule(nn.Module):
)
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
module_class = FlaxRobertaModule
append_call_sample_docstring(
FlaxRobertaModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC
)


@ -998,7 +998,6 @@ class ModelTesterMixin:
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):


@ -13,8 +13,10 @@
# limitations under the License.
import copy
import inspect
import random
import tempfile
from typing import List, Tuple
import numpy as np
@ -28,6 +30,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
@ -77,6 +80,7 @@ class FlaxModelTesterMixin:
inputs_dict = {
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
for k, v in inputs_dict.items()
if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
}
return inputs_dict
@ -85,6 +89,41 @@ class FlaxModelTesterMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assert_almost_equals(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -108,7 +147,7 @@ class FlaxModelTesterMixin:
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict)
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
@ -117,7 +156,7 @@ class FlaxModelTesterMixin:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
@ -149,7 +188,7 @@ class FlaxModelTesterMixin:
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict)
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
@ -171,17 +210,20 @@ class FlaxModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if model_class.__name__ != "FlaxBertModel":
continue
with self.subTest(model_class.__name__):
model = model_class(config)
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**prepared_inputs_dict)
outputs = model(**prepared_inputs_dict).to_tuple()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_loaded = model_class.from_pretrained(tmpdirname)
outputs_loaded = model_loaded(**prepared_inputs_dict)
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 1e-3)
@ -195,19 +237,47 @@ class FlaxModelTesterMixin:
@jax.jit
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
return model(input_ids, attention_mask, token_type_ids)
return model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
).to_tuple()
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict)
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict)
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@jax.jit
def model_jitted_return_dict(input_ids, attention_mask=None, token_type_ids=None):
return model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
# jitted function cannot return OrderedDict
with self.assertRaises(TypeError):
model_jitted_return_dict(**prepared_inputs_dict)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.__call__)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_ids", "attention_mask"]
self.assertListEqual(arg_names[:2], expected_arg_names)
def test_naming_convention(self):
for model_class in self.all_model_classes:
model_class_name = model_class.__name__
@ -218,3 +288,30 @@ class FlaxModelTesterMixin:
module_cls = getattr(bert_modeling_flax_module, module_class_name)
self.assertIsNotNone(module_cls)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
seq_length = self.model_tester.seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)