mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
[Flax] Add docstrings & model outputs (#11498)
* add attentions & hidden states * add model outputs + docs * finish docs * finish tests * finish impl * del @ * finish * finish * correct test * apply sylvains suggestions * Update src/transformers/models/bert/modeling_flax_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * simplify more Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
3f6add8bab
commit
f748bd4242
@ -794,6 +794,17 @@ PT_CAUSAL_LM_SAMPLE = r"""
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
|
||||
PT_SAMPLE_DOCSTRINGS = {
|
||||
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
|
||||
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
|
||||
"TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
|
||||
"MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
|
||||
"MaskedLM": PT_MASKED_LM_SAMPLE,
|
||||
"LMHead": PT_CAUSAL_LM_SAMPLE,
|
||||
"BaseModel": PT_BASE_MODEL_SAMPLE,
|
||||
}
|
||||
|
||||
|
||||
TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
|
||||
Example::
|
||||
|
||||
@ -915,30 +926,148 @@ TF_CAUSAL_LM_SAMPLE = r"""
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
|
||||
TF_SAMPLE_DOCSTRINGS = {
|
||||
"SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE,
|
||||
"QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE,
|
||||
"TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE,
|
||||
"MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE,
|
||||
"MaskedLM": TF_MASKED_LM_SAMPLE,
|
||||
"LMHead": TF_CAUSAL_LM_SAMPLE,
|
||||
"BaseModel": TF_BASE_MODEL_SAMPLE,
|
||||
}
|
||||
|
||||
|
||||
FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
|
||||
Example::
|
||||
|
||||
>>> from transformers import {tokenizer_class}, {model_class}
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
|
||||
FLAX_QUESTION_ANSWERING_SAMPLE = r"""
|
||||
Example::
|
||||
|
||||
>>> from transformers import {tokenizer_class}, {model_class}
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
|
||||
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
||||
>>> inputs = tokenizer(question, text, return_tensors='jax')
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> start_scores = outputs.start_logits
|
||||
>>> end_scores = outputs.end_logits
|
||||
"""
|
||||
|
||||
FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
|
||||
Example::
|
||||
|
||||
>>> from transformers import {tokenizer_class}, {model_class}
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
|
||||
|
||||
>>> outputs = model(**inputs, labels=labels)
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
|
||||
FLAX_MASKED_LM_SAMPLE = r"""
|
||||
Example::
|
||||
|
||||
>>> from transformers import {tokenizer_class}, {model_class}
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
|
||||
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors='jax')
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
|
||||
FLAX_BASE_MODEL_SAMPLE = r"""
|
||||
Example::
|
||||
|
||||
>>> from transformers import {tokenizer_class}, {model_class}
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
|
||||
>>> outputs = model(**inputs)
|
||||
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
|
||||
FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
|
||||
Example::
|
||||
|
||||
>>> from transformers import {tokenizer_class}, {model_class}
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
|
||||
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||
>>> choice0 = "It is eaten with a fork and a knife."
|
||||
>>> choice1 = "It is eaten while held in the hand."
|
||||
|
||||
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='jax', padding=True)
|
||||
>>> outputs = model(**{{k: v[None, :] for k,v in encoding.items()}})
|
||||
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
|
||||
FLAX_SAMPLE_DOCSTRINGS = {
|
||||
"SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
|
||||
"QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
|
||||
"TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE,
|
||||
"MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
|
||||
"MaskedLM": FLAX_MASKED_LM_SAMPLE,
|
||||
"BaseModel": FLAX_BASE_MODEL_SAMPLE,
|
||||
}
|
||||
|
||||
|
||||
def add_code_sample_docstrings(
|
||||
*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None
|
||||
*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None, model_cls=None
|
||||
):
|
||||
def docstring_decorator(fn):
|
||||
model_class = fn.__qualname__.split(".")[0]
|
||||
is_tf_class = model_class[:2] == "TF"
|
||||
# model_class defaults to function's class if not specified otherwise
|
||||
model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
|
||||
|
||||
if model_class[:2] == "TF":
|
||||
sample_docstrings = TF_SAMPLE_DOCSTRINGS
|
||||
elif model_class[:4] == "Flax":
|
||||
sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
|
||||
else:
|
||||
sample_docstrings = PT_SAMPLE_DOCSTRINGS
|
||||
|
||||
doc_kwargs = dict(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
|
||||
|
||||
if "SequenceClassification" in model_class:
|
||||
code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE
|
||||
code_sample = sample_docstrings["SequenceClassification"]
|
||||
elif "QuestionAnswering" in model_class:
|
||||
code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE
|
||||
code_sample = sample_docstrings["QuestionAnswering"]
|
||||
elif "TokenClassification" in model_class:
|
||||
code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE
|
||||
code_sample = sample_docstrings["TokenClassification"]
|
||||
elif "MultipleChoice" in model_class:
|
||||
code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE
|
||||
code_sample = sample_docstrings["MultipleChoice"]
|
||||
elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
|
||||
doc_kwargs["mask"] = "[MASK]" if mask is None else mask
|
||||
code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
|
||||
code_sample = sample_docstrings["MaskedLM"]
|
||||
elif "LMHead" in model_class or "CausalLM" in model_class:
|
||||
code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
|
||||
code_sample = sample_docstrings["LMHead"]
|
||||
elif "Model" in model_class or "Encoder" in model_class:
|
||||
code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
|
||||
code_sample = sample_docstrings["BaseModel"]
|
||||
else:
|
||||
raise ValueError(f"Docstring can't be built for model {model_class}")
|
||||
|
||||
@ -1462,7 +1591,10 @@ def tf_required(func):
|
||||
|
||||
|
||||
def is_tensor(x):
|
||||
"""Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`."""
|
||||
"""
|
||||
Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, obj:`jaxlib.xla_extension.DeviceArray` or
|
||||
:obj:`np.ndarray`.
|
||||
"""
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
@ -1473,6 +1605,14 @@ def is_tensor(x):
|
||||
|
||||
if isinstance(x, tf.Tensor):
|
||||
return True
|
||||
|
||||
if is_flax_available():
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
from jax.interpreters.partial_eval import DynamicJaxprTracer
|
||||
|
||||
if isinstance(x, (jax_xla.DeviceArray, DynamicJaxprTracer)):
|
||||
return True
|
||||
|
||||
return isinstance(x, np.ndarray)
|
||||
|
||||
|
||||
|
239
src/transformers/modeling_flax_outputs.py
Normal file
239
src/transformers/modeling_flax_outputs.py
Normal file
@ -0,0 +1,239 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
|
||||
from .file_utils import ModelOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxBaseModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs, with potential hidden states and attentions.
|
||||
|
||||
Args:
|
||||
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
last_hidden_state: jax_xla.DeviceArray = None
|
||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxBaseModelOutputWithPooling(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs that also contains a pooling of the last hidden states.
|
||||
|
||||
Args:
|
||||
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`):
|
||||
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
|
||||
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
|
||||
prediction (classification) objective during pretraining.
|
||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
last_hidden_state: jax_xla.DeviceArray = None
|
||||
pooler_output: jax_xla.DeviceArray = None
|
||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxMaskedLMOutput(ModelOutput):
|
||||
"""
|
||||
Base class for masked language models outputs.
|
||||
|
||||
Args:
|
||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
logits: jax_xla.DeviceArray = None
|
||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxNextSentencePredictorOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||
|
||||
Args:
|
||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
|
||||
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
||||
before SoftMax).
|
||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
logits: jax_xla.DeviceArray = None
|
||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxSequenceClassifierOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of sentence classification models.
|
||||
|
||||
Args:
|
||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
logits: jax_xla.DeviceArray = None
|
||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxMultipleChoiceModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of multiple choice models.
|
||||
|
||||
Args:
|
||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, num_choices)`):
|
||||
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
||||
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
logits: jax_xla.DeviceArray = None
|
||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxTokenClassifierOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of token classification models.
|
||||
|
||||
Args:
|
||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
logits: jax_xla.DeviceArray = None
|
||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxQuestionAnsweringModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of question answering models.
|
||||
|
||||
Args:
|
||||
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Span-start scores (before SoftMax).
|
||||
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Span-end scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
start_logits: jax_xla.DeviceArray = None
|
||||
end_logits: jax_xla.DeviceArray = None
|
||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
@ -32,12 +32,14 @@ from .file_utils import (
|
||||
FLAX_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
PushToHubMixin,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
cached_path,
|
||||
copy_func,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
|
||||
from .utils import logging
|
||||
@ -432,3 +434,22 @@ def overwrite_call_docstring(model_class, docstring):
|
||||
model_class.__call__.__doc__ = None
|
||||
# set correct docstring
|
||||
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
|
||||
|
||||
|
||||
def append_call_sample_docstring(model_class, tokenizer_class, checkpoint, output_type, config_class, mask=None):
|
||||
model_class.__call__ = copy_func(model_class.__call__)
|
||||
model_class.__call__ = add_code_sample_docstrings(
|
||||
tokenizer_class=tokenizer_class,
|
||||
checkpoint=checkpoint,
|
||||
output_type=output_type,
|
||||
config_class=config_class,
|
||||
model_cls=model_class.__name__,
|
||||
)(model_class.__call__)
|
||||
|
||||
|
||||
def append_replace_return_docstrings(model_class, output_type, config_class):
|
||||
model_class.__call__ = copy_func(model_class.__call__)
|
||||
model_class.__call__ = replace_return_docstrings(
|
||||
output_type=output_type,
|
||||
config_class=config_class,
|
||||
)(model_class.__call__)
|
||||
|
@ -13,30 +13,79 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.linen import dot_product_attention
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring
|
||||
from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
FlaxBaseModelOutputWithPooling,
|
||||
FlaxMaskedLMOutput,
|
||||
FlaxMultipleChoiceModelOutput,
|
||||
FlaxNextSentencePredictorOutput,
|
||||
FlaxQuestionAnsweringModelOutput,
|
||||
FlaxSequenceClassifierOutput,
|
||||
FlaxTokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_flax_utils import (
|
||||
ACT2FN,
|
||||
FlaxPreTrainedModel,
|
||||
append_call_sample_docstring,
|
||||
append_replace_return_docstrings,
|
||||
overwrite_call_docstring,
|
||||
)
|
||||
from ...utils import logging
|
||||
from .configuration_bert import BertConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
|
||||
_CONFIG_FOR_DOC = "BertConfig"
|
||||
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxBertForPreTrainingOutput(ModelOutput):
|
||||
"""
|
||||
Output type of :class:`~transformers.BertForPreTraining`.
|
||||
|
||||
Args:
|
||||
prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
|
||||
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
||||
before SoftMax).
|
||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
prediction_logits: jax_xla.DeviceArray = None
|
||||
seq_relationship_logits: jax_xla.DeviceArray = None
|
||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
BERT_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
||||
@ -166,7 +215,7 @@ class FlaxBertSelfAttention(nn.Module):
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
|
||||
query_states = self.query(hidden_states).reshape(
|
||||
@ -208,7 +257,12 @@ class FlaxBertSelfAttention(nn.Module):
|
||||
precision=None,
|
||||
)
|
||||
|
||||
return attn_output.reshape(attn_output.shape[:2] + (-1,))
|
||||
outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
|
||||
|
||||
# TODO: at the moment it's not possible to retrieve attn_weights from
|
||||
# dot_product_attention, but should be in the future -> add functionality then
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class FlaxBertSelfOutput(nn.Module):
|
||||
@ -239,13 +293,22 @@ class FlaxBertAttention(nn.Module):
|
||||
self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype)
|
||||
self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
|
||||
attn_outputs = self.self(
|
||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
||||
)
|
||||
attn_output = attn_outputs[0]
|
||||
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||
return hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += attn_outputs[1]
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class FlaxBertIntermediate(nn.Module):
|
||||
@ -295,11 +358,20 @@ class FlaxBertLayer(nn.Module):
|
||||
self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
|
||||
self.output = FlaxBertOutput(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
|
||||
attention_outputs = self.attention(
|
||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
|
||||
hidden_states = self.intermediate(attention_output)
|
||||
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||
return hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attention_outputs[1],)
|
||||
return outputs
|
||||
|
||||
|
||||
class FlaxBertLayerCollection(nn.Module):
|
||||
@ -311,10 +383,40 @@ class FlaxBertLayerCollection(nn.Module):
|
||||
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||
return hidden_states
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions += (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return FlaxBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertEncoder(nn.Module):
|
||||
@ -324,8 +426,23 @@ class FlaxBertEncoder(nn.Module):
|
||||
def setup(self):
|
||||
self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
return self.layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertPooler(nn.Module):
|
||||
@ -456,7 +573,21 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
train: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
if output_attentions:
|
||||
raise NotImplementedError(
|
||||
"Currently attention scores cannot be returned. Please set `output_attentions` to False for now."
|
||||
)
|
||||
|
||||
# init input tensors if not passed
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
@ -479,6 +610,9 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
@ -493,17 +627,43 @@ class FlaxBertModule(nn.Module):
|
||||
self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
|
||||
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
hidden_states = self.embeddings(
|
||||
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
||||
)
|
||||
hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
|
||||
outputs = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
|
||||
|
||||
if not self.add_pooling_layer:
|
||||
return hidden_states
|
||||
if not return_dict:
|
||||
# if pooled is None, don't return it
|
||||
if pooled is None:
|
||||
return (hidden_states,) + outputs[1:]
|
||||
return (hidden_states, pooled) + outputs[1:]
|
||||
|
||||
pooled = self.pooler(hidden_states)
|
||||
return hidden_states, pooled
|
||||
return FlaxBaseModelOutputWithPooling(
|
||||
last_hidden_state=hidden_states,
|
||||
pooler_output=pooled,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -514,6 +674,11 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
|
||||
module_class = FlaxBertModule
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxBertModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForPreTrainingModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
@ -523,11 +688,27 @@ class FlaxBertForPreTrainingModule(nn.Module):
|
||||
self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
|
||||
# Model
|
||||
hidden_states, pooled_output = self.bert(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
@ -535,11 +716,22 @@ class FlaxBertForPreTrainingModule(nn.Module):
|
||||
else:
|
||||
shared_embedding = None
|
||||
|
||||
hidden_states = outputs[0]
|
||||
pooled_output = outputs[1]
|
||||
|
||||
prediction_scores, seq_relationship_score = self.cls(
|
||||
hidden_states, pooled_output, shared_embedding=shared_embedding
|
||||
)
|
||||
|
||||
return (prediction_scores, seq_relationship_score)
|
||||
if not return_dict:
|
||||
return (prediction_scores, seq_relationship_score) + outputs[2:]
|
||||
|
||||
return FlaxBertForPreTrainingOutput(
|
||||
prediction_logits=prediction_scores,
|
||||
seq_relationship_logits=seq_relationship_score,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -553,6 +745,32 @@ class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
|
||||
module_class = FlaxBertForPreTrainingModule
|
||||
|
||||
|
||||
FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import BertTokenizer, FlaxBertForPreTraining
|
||||
|
||||
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
>>> model = FlaxBertForPreTraining.from_pretrained('bert-base-uncased')
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
||||
>>> outputs = model(**inputs)
|
||||
|
||||
>>> prediction_logits = outputs.prediction_logits
|
||||
>>> seq_relationship_logits = outputs.seq_relationship_logits
|
||||
"""
|
||||
|
||||
overwrite_call_docstring(
|
||||
FlaxBertForPreTraining,
|
||||
BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_PRETRAINING_DOCSTRING,
|
||||
)
|
||||
append_replace_return_docstrings(
|
||||
FlaxBertForPreTraining, output_type=FlaxBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForMaskedLMModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
@ -562,11 +780,29 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
||||
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.tie_word_embeddings:
|
||||
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||
else:
|
||||
@ -575,7 +811,14 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
||||
# Compute the prediction scores
|
||||
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
|
||||
|
||||
return (logits,)
|
||||
if not return_dict:
|
||||
return (logits,) + outputs[1:]
|
||||
|
||||
return FlaxMaskedLMOutput(
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
||||
@ -583,6 +826,11 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
|
||||
module_class = FlaxBertForMaskedLMModule
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxBertForMaskedLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForNextSentencePredictionModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
@ -592,15 +840,41 @@ class FlaxBertForNextSentencePredictionModule(nn.Module):
|
||||
self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
# Model
|
||||
_, pooled_output = self.bert(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
seq_relationship_scores = self.cls(pooled_output)
|
||||
return (seq_relationship_scores,)
|
||||
|
||||
if not return_dict:
|
||||
return (seq_relationship_scores,) + outputs[2:]
|
||||
|
||||
return FlaxNextSentencePredictorOutput(
|
||||
logits=seq_relationship_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -611,6 +885,35 @@ class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
|
||||
module_class = FlaxBertForNextSentencePredictionModule
|
||||
|
||||
|
||||
FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING = """
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import BertTokenizer, FlaxBertForNextSentencePrediction
|
||||
|
||||
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
>>> model = FlaxBertForNextSentencePrediction.from_pretrained('bert-base-uncased')
|
||||
|
||||
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
||||
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='jax')
|
||||
|
||||
>>> outputs = model(**encoding)
|
||||
>>> logits = outputs.logits
|
||||
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
|
||||
"""
|
||||
|
||||
|
||||
overwrite_call_docstring(
|
||||
FlaxBertForNextSentencePrediction,
|
||||
BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING,
|
||||
)
|
||||
append_replace_return_docstrings(
|
||||
FlaxBertForNextSentencePrediction, output_type=FlaxNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForSequenceClassificationModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
@ -624,17 +927,40 @@ class FlaxBertForSequenceClassificationModule(nn.Module):
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
_, pooled_output = self.bert(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
return (logits,)
|
||||
if not return_dict:
|
||||
return (logits,) + outputs[2:]
|
||||
|
||||
return FlaxSequenceClassifierOutput(
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -648,6 +974,15 @@ class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
|
||||
module_class = FlaxBertForSequenceClassificationModule
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxBertForSequenceClassification,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxSequenceClassifierOutput,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForMultipleChoiceModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
@ -658,7 +993,15 @@ class FlaxBertForMultipleChoiceModule(nn.Module):
|
||||
self.classifier = nn.Dense(1, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
num_choices = input_ids.shape[1]
|
||||
input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
|
||||
@ -667,16 +1010,31 @@ class FlaxBertForMultipleChoiceModule(nn.Module):
|
||||
position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
|
||||
|
||||
# Model
|
||||
_, pooled_output = self.bert(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
reshaped_logits = logits.reshape(-1, num_choices)
|
||||
|
||||
return (reshaped_logits,)
|
||||
if not return_dict:
|
||||
return (reshaped_logits,) + outputs[2:]
|
||||
|
||||
return FlaxMultipleChoiceModelOutput(
|
||||
logits=reshaped_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -690,10 +1048,12 @@ class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
|
||||
module_class = FlaxBertForMultipleChoiceModule
|
||||
|
||||
|
||||
# adapt docstring slightly for FlaxBertForMultipleChoice
|
||||
overwrite_call_docstring(
|
||||
FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
||||
)
|
||||
append_call_sample_docstring(
|
||||
FlaxBertForMultipleChoice, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMultipleChoiceModelOutput, _CONFIG_FOR_DOC
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForTokenClassificationModule(nn.Module):
|
||||
@ -706,15 +1066,40 @@ class FlaxBertForTokenClassificationModule(nn.Module):
|
||||
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
logits = self.classifier(hidden_states)
|
||||
|
||||
return (logits,)
|
||||
if not return_dict:
|
||||
return (logits,) + outputs[1:]
|
||||
|
||||
return FlaxTokenClassifierOutput(
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -728,6 +1113,11 @@ class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
|
||||
module_class = FlaxBertForTokenClassificationModule
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxBertForTokenClassification, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxTokenClassifierOutput, _CONFIG_FOR_DOC
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForQuestionAnsweringModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
@ -737,17 +1127,44 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
|
||||
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(hidden_states)
|
||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
return (start_logits, end_logits)
|
||||
if not return_dict:
|
||||
return (start_logits, end_logits) + outputs[1:]
|
||||
|
||||
return FlaxQuestionAnsweringModelOutput(
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -759,3 +1176,12 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
|
||||
)
|
||||
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
|
||||
module_class = FlaxBertForQuestionAnsweringModule
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxBertForQuestionAnswering,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxQuestionAnsweringModelOutput,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
@ -23,13 +23,15 @@ from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
|
||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
|
||||
from ...utils import logging
|
||||
from .configuration_roberta import RobertaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "roberta-base"
|
||||
_CONFIG_FOR_DOC = "RobertaConfig"
|
||||
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
||||
|
||||
@ -181,7 +183,7 @@ class FlaxRobertaSelfAttention(nn.Module):
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
|
||||
query_states = self.query(hidden_states).reshape(
|
||||
@ -223,7 +225,12 @@ class FlaxRobertaSelfAttention(nn.Module):
|
||||
precision=None,
|
||||
)
|
||||
|
||||
return attn_output.reshape(attn_output.shape[:2] + (-1,))
|
||||
outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
|
||||
|
||||
# TODO: at the moment it's not possible to retrieve attn_weights from
|
||||
# dot_product_attention, but should be in the future -> add functionality then
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
|
||||
@ -256,13 +263,22 @@ class FlaxRobertaAttention(nn.Module):
|
||||
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
|
||||
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
|
||||
attn_outputs = self.self(
|
||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
||||
)
|
||||
attn_output = attn_outputs[0]
|
||||
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||
return hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += attn_outputs[1]
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
||||
@ -315,11 +331,20 @@ class FlaxRobertaLayer(nn.Module):
|
||||
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
|
||||
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
|
||||
attention_outputs = self.attention(
|
||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
|
||||
hidden_states = self.intermediate(attention_output)
|
||||
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||
return hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attention_outputs[1],)
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
|
||||
@ -332,10 +357,40 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||
return hidden_states
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions += (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return FlaxBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
|
||||
@ -346,8 +401,23 @@ class FlaxRobertaEncoder(nn.Module):
|
||||
def setup(self):
|
||||
self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
return self.layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
||||
@ -412,7 +482,21 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
train: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
if output_attentions:
|
||||
raise NotImplementedError(
|
||||
"Currently attention scores cannot be returned." "Please set `output_attentions` to False for now."
|
||||
)
|
||||
|
||||
# init input tensors if not passed
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
@ -435,6 +519,9 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
@ -450,17 +537,43 @@ class FlaxRobertaModule(nn.Module):
|
||||
self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
|
||||
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
hidden_states = self.embeddings(
|
||||
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
||||
)
|
||||
hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
|
||||
outputs = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
|
||||
|
||||
if not self.add_pooling_layer:
|
||||
return hidden_states
|
||||
if not return_dict:
|
||||
# if pooled is None, don't return it
|
||||
if pooled is None:
|
||||
return (hidden_states,) + outputs[1:]
|
||||
return (hidden_states, pooled) + outputs[1:]
|
||||
|
||||
pooled = self.pooler(hidden_states)
|
||||
return hidden_states, pooled
|
||||
return FlaxBaseModelOutputWithPooling(
|
||||
last_hidden_state=hidden_states,
|
||||
pooler_output=pooled,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -469,3 +582,8 @@ class FlaxRobertaModule(nn.Module):
|
||||
)
|
||||
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
|
||||
module_class = FlaxRobertaModule
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxRobertaModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC
|
||||
)
|
||||
|
@ -998,7 +998,6 @@ class ModelTesterMixin:
|
||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def set_nan_tensor_to_zero(t):
|
||||
|
@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import random
|
||||
import tempfile
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -28,6 +30,7 @@ if is_flax_available():
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
@ -77,6 +80,7 @@ class FlaxModelTesterMixin:
|
||||
inputs_dict = {
|
||||
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
|
||||
for k, v in inputs_dict.items()
|
||||
if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
|
||||
}
|
||||
|
||||
return inputs_dict
|
||||
@ -85,6 +89,41 @@ class FlaxModelTesterMixin:
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def set_nan_tensor_to_zero(t):
|
||||
t[t != t] = 0
|
||||
return t
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
self.assert_almost_equals(
|
||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
|
||||
)
|
||||
|
||||
recursive_check(tuple_output, dict_output)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -108,7 +147,7 @@ class FlaxModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||
@ -117,7 +156,7 @@ class FlaxModelTesterMixin:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
|
||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(
|
||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
@ -149,7 +188,7 @@ class FlaxModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||
@ -171,17 +210,20 @@ class FlaxModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ != "FlaxBertModel":
|
||||
continue
|
||||
|
||||
with self.subTest(model_class.__name__):
|
||||
model = model_class(config)
|
||||
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
outputs = model(**prepared_inputs_dict)
|
||||
outputs = model(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_loaded = model_loaded(**prepared_inputs_dict)
|
||||
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
|
||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 1e-3)
|
||||
|
||||
@ -195,19 +237,47 @@ class FlaxModelTesterMixin:
|
||||
|
||||
@jax.jit
|
||||
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
|
||||
return model(input_ids, attention_mask, token_type_ids)
|
||||
return model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
).to_tuple()
|
||||
|
||||
with self.subTest("JIT Enabled"):
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict)
|
||||
|
||||
with self.subTest("JIT Disabled"):
|
||||
with jax.disable_jit():
|
||||
outputs = model_jitted(**prepared_inputs_dict)
|
||||
|
||||
with self.subTest("JIT Enabled"):
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict)
|
||||
|
||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted_return_dict(input_ids, attention_mask=None, token_type_ids=None):
|
||||
return model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
)
|
||||
|
||||
# jitted function cannot return OrderedDict
|
||||
with self.assertRaises(TypeError):
|
||||
model_jitted_return_dict(**prepared_inputs_dict)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.__call__)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["input_ids", "attention_mask"]
|
||||
self.assertListEqual(arg_names[:2], expected_arg_names)
|
||||
|
||||
def test_naming_convention(self):
|
||||
for model_class in self.all_model_classes:
|
||||
model_class_name = model_class.__name__
|
||||
@ -218,3 +288,30 @@ class FlaxModelTesterMixin:
|
||||
module_cls = getattr(bert_modeling_flax_module, module_class_name)
|
||||
|
||||
self.assertIsNotNone(module_cls)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
hidden_states = outputs.hidden_states
|
||||
|
||||
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||
seq_length = self.model_tester.seq_length
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
Loading…
Reference in New Issue
Block a user