Remove boilerplate code (#11340)

* remove boilerplate code

* adapt roberta

* correct docs

* finish refactor
Patrick von Platen 2021-04-21 18:34:38 +02:00 committed by GitHub
parent ac588594e2
commit 50595a3336
5 changed files with 133 additions and 363 deletions

View File

@@ -15,9 +15,9 @@
Utilities for working with the local dataset cache. Parts of this file are adapted from the AllenNLP library at
https://github.com/allenai/allennlp.
"""
import copy
import fnmatch
import functools
import importlib.util
import io
import json
@@ -27,6 +27,7 @@ import shutil
import sys
import tarfile
import tempfile
import types
from collections import OrderedDict, UserDict
from contextlib import contextmanager
from dataclasses import fields
@@ -1674,3 +1675,12 @@ class _BaseLazyModule(ModuleType):
def _get_module(self, module_name: str) -> ModuleType:
raise NotImplementedError
def copy_func(f):
""" Returns a copy of a function f."""
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = f.__kwdefaults__
return g
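A minimal sketch of how this helper behaves (the import path is taken from this diff; the toy function is hypothetical): the copy is an independent function object, so changing its docstring leaves the original untouched, which is what the docstring helpers below rely on.

from transformers.file_utils import copy_func  # location introduced by this change

def greet(name="world"):
    """Original docstring."""
    return f"hello {name}"

greet_copy = copy_func(greet)
greet_copy.__doc__ = "Docstring for the copy only."

assert greet.__doc__ == "Original docstring."   # the original function is unchanged
assert greet_copy("flax") == greet("flax")      # behaviour, defaults and closure are preserved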

View File

@@ -28,7 +28,16 @@ from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from .configuration_utils import PretrainedConfig
from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
from .file_utils import (
FLAX_WEIGHTS_NAME,
WEIGHTS_NAME,
add_start_docstrings_to_model_forward,
cached_path,
copy_func,
hf_bucket_url,
is_offline_mode,
is_remote_url,
)
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import logging
@@ -85,13 +94,13 @@ class FlaxPreTrainedModel(ABC):
self.dtype = dtype
# randomly initialized parameters
random_params = self.init(self.key, input_shape)
random_params = self.init_weights(self.key, input_shape)
# save required_params as set
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
self.params = random_params
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
@property
@@ -394,3 +403,12 @@ class FlaxPreTrainedModel(ABC):
with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
model_bytes = to_bytes(self.params)
f.write(model_bytes)
def overwrite_call_docstring(model_class, docstring):
# copy __call__ function to be sure docstring is changed only for this function
model_class.__call__ = copy_func(model_class.__call__)
# delete existing docstring
model_class.__call__.__doc__ = None
# set correct docstring
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
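Because the Flax head classes now inherit a single __call__ from their base class, writing a new __doc__ directly onto that shared function would change the docstring of every sibling class as well; copying first confines the change to one class. A minimal usage sketch (the toy class and docstring are hypothetical; the import location follows this diff and assumes jax/flax are installed, since this module imports them):

from transformers.modeling_flax_utils import overwrite_call_docstring

class FlaxToyModel:
    """Stand-in for a Flax model class; only __call__ matters here."""

    def __call__(self, input_ids, attention_mask=None):
        return input_ids

TOY_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.
"""

# rewrites only FlaxToyModel.__call__'s docstring; a parent class sharing the
# same function object would keep its own documentation
overwrite_call_docstring(FlaxToyModel, TOY_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
print(FlaxToyModel.__call__.__doc__)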

View File

@@ -14,10 +14,10 @@
# limitations under the License.
"""Factory function to build auto-model classes."""
import functools
import types
from ...configuration_utils import PretrainedConfig
from ...file_utils import copy_func
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
@@ -385,15 +385,6 @@ class _BaseAutoModelClass:
)
def copy_func(f):
""" Returns a copy of a function f."""
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = f.__kwdefaults__
return g
def insert_head_doc(docstring, head_doc=""):
if len(head_doc) > 0:
return docstring.replace(

View File

@@ -26,7 +26,7 @@ from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring
from ...utils import logging
from .configuration_bert import BertConfig
@@ -91,6 +91,7 @@ BERT_INPUTS_DOCSTRING = r"""
config.max_position_embeddings - 1]``.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
@@ -477,49 +478,26 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
config_class = BertConfig
base_model_prefix = "bert"
module_class: nn.Module = None
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
if position_ids is None:
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
return input_ids, attention_mask, token_type_ids, position_ids
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
jnp.zeros(input_shape, dtype="i4"), None, None, None
)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.ones_like(input_ids)
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
attention_mask = jnp.ones_like(input_ids)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
)
class FlaxBertModel(FlaxBertPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
"""
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertModule(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
@@ -531,9 +509,15 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# init input tensors if not passed
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNG if needed
rngs = {}
@@ -576,49 +560,11 @@ class FlaxBertModule(nn.Module):
@add_start_docstrings(
"""
Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
sentence prediction (classification)` head.
""",
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
)
class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForPreTrainingModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
class FlaxBertModel(FlaxBertPreTrainedModel):
module_class = FlaxBertModule
class FlaxBertForPreTrainingModule(nn.Module):
@@ -641,44 +587,15 @@ class FlaxBertForPreTrainingModule(nn.Module):
return (prediction_scores, seq_relationship_score)
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForMaskedLMModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
@add_start_docstrings(
"""
Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
sentence prediction (classification)` head.
""",
BERT_START_DOCSTRING,
)
class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
module_class = FlaxBertForPreTrainingModule
class FlaxBertForMaskedLMModule(nn.Module):
@@ -701,46 +618,9 @@ class FlaxBertForMaskedLMModule(nn.Module):
return (logits,)
@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING,
)
class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForNextSentencePredictionModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
module_class = FlaxBertForMaskedLMModule
class FlaxBertForNextSentencePredictionModule(nn.Module):
@@ -764,48 +644,11 @@ class FlaxBertForNextSentencePredictionModule(nn.Module):
@add_start_docstrings(
"""
Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING,
)
class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForSequenceClassificationModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
module_class = FlaxBertForNextSentencePredictionModule
class FlaxBertForSequenceClassificationModule(nn.Module):
@@ -836,47 +679,13 @@ class FlaxBertForSequenceClassificationModule(nn.Module):
@add_start_docstrings(
"""
Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
BERT_START_DOCSTRING,
)
class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForMultipleChoiceModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
module_class = FlaxBertForSequenceClassificationModule
class FlaxBertForMultipleChoiceModule(nn.Module):
@@ -912,47 +721,19 @@ class FlaxBertForMultipleChoiceModule(nn.Module):
@add_start_docstrings(
"""
Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
BERT_START_DOCSTRING,
)
class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForTokenClassificationModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
module_class = FlaxBertForMultipleChoiceModule
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
# adapt docstring slightly for FlaxBertForMultipleChoice
overwrite_call_docstring(
FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
class FlaxBertForTokenClassificationModule(nn.Module):
@@ -978,47 +759,13 @@ class FlaxBertForTokenClassificationModule(nn.Module):
@add_start_docstrings(
"""
Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
BERT_START_DOCSTRING,
)
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForQuestionAnsweringModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
module_class = FlaxBertForTokenClassificationModule
class FlaxBertForQuestionAnsweringModule(nn.Module):
@@ -1041,3 +788,14 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
end_logits = end_logits.squeeze(-1)
return (start_logits, end_logits)
@add_start_docstrings(
"""
Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
BERT_START_DOCSTRING,
)
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
module_class = FlaxBertForQuestionAnsweringModule
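With the constructor, init_weights, and __call__ logic now living on FlaxBertPreTrainedModel, adding a further Bert head class reduces to defining its nn.Module and pointing module_class at it. A hypothetical sketch, written as if appended to this file (so nn, jnp, BertConfig, and FlaxBertPreTrainedModel are already in scope); the toy head itself is not part of this commit:

class FlaxBertForToyTaskModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embed = nn.Embed(self.config.vocab_size, self.config.hidden_size, dtype=self.dtype)
        self.score = nn.Dense(1, dtype=self.dtype)

    def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
        # same positional signature as the other Bert head modules, so the
        # inherited init_weights and __call__ can drive it unchanged
        hidden_states = self.embed(input_ids.astype("i4"))
        return (self.score(hidden_states[:, 0]),)

class FlaxBertForToyTask(FlaxBertPreTrainedModel):
    # __init__, init_weights and __call__ are inherited from the base class
    module_class = FlaxBertForToyTaskModule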

View File

@@ -441,40 +441,7 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
config_class = RobertaConfig
base_model_prefix = "roberta"
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
jnp.zeros(input_shape, dtype="i4"), None, None, None
)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
return input_ids, attention_mask, token_type_ids, position_ids
@add_start_docstrings(
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
ROBERTA_START_DOCSTRING,
)
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
Kaiser and Illia Polosukhin.
"""
module_class: nn.Module = None
def __init__(
self,
@@ -484,23 +451,41 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
dtype: jnp.dtype = jnp.float32,
**kwargs
):
module = FlaxRobertaModule(config, dtype=dtype, **kwargs)
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.ones_like(input_ids)
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
attention_mask = jnp.ones_like(input_ids)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
token_type_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# init input tensors if not passed
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNG if needed
rngs = {}
@@ -541,3 +526,11 @@ class FlaxRobertaModule(nn.Module):
pooled = self.pooler(hidden_states)
return hidden_states, pooled
@add_start_docstrings(
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
ROBERTA_START_DOCSTRING,
)
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
module_class = FlaxRobertaModule
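End-to-end, the refactor leaves the public behaviour unchanged: constructing a model runs the shared base __init__, which calls init_weights, and __call__ fills in any inputs that are not passed. A rough usage sketch with random weights (assumes this version of transformers plus jax and flax are installed; the tiny config values are arbitrary):

from transformers import FlaxRobertaModel, RobertaConfig
import jax.numpy as jnp

config = RobertaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                       num_attention_heads=2, intermediate_size=64,
                       max_position_embeddings=64)
model = FlaxRobertaModel(config)           # base __init__ calls init_weights and stores the params
input_ids = jnp.zeros((1, 8), dtype="i4")
hidden_states, pooled = model(input_ids)   # attention_mask / token_type_ids / position_ids default inside __call__
print(hidden_states.shape, pooled.shape)   # expected: (1, 8, 32) and (1, 32)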