Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Remove boiler plate code (#11340)
* remove boiler plate code
* adapt roberta
* correct docs
* finish refactor
Commit: 50595a3336
Parent: ac588594e2
src/transformers/file_utils.py

@@ -15,9 +15,9 @@
 Utilities for working with the local dataset cache. Parts of this file is adapted from the AllenNLP library at
 https://github.com/allenai/allennlp.
 """
-
 import copy
 import fnmatch
+import functools
 import importlib.util
 import io
 import json
@@ -27,6 +27,7 @@ import shutil
 import sys
 import tarfile
 import tempfile
+import types
 from collections import OrderedDict, UserDict
 from contextlib import contextmanager
 from dataclasses import fields
@@ -1674,3 +1675,12 @@ class _BaseLazyModule(ModuleType):
 
     def _get_module(self, module_name: str) -> ModuleType:
         raise NotImplementedError
+
+
+def copy_func(f):
+    """ Returns a copy of a function f."""
+    # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
+    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
+    g = functools.update_wrapper(g, f)
+    g.__kwdefaults__ = f.__kwdefaults__
+    return g
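For context, a minimal sketch of what the new `copy_func` helper enables: the copy shares the original function's code and behavior, but its `__doc__` can be rewritten without touching the original. The `forward` function below is illustrative, not taken from the diff.

import functools
import types


def copy_func(f):
    # same recipe as the helper added above
    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
    g = functools.update_wrapper(g, f)
    g.__kwdefaults__ = f.__kwdefaults__
    return g


def forward(x):
    """Original docstring."""
    return x


forward_copy = copy_func(forward)
forward_copy.__doc__ = "Rewritten docstring for one specific class."

assert forward(3) == forward_copy(3)             # same behavior
assert forward.__doc__ == "Original docstring."  # original left untouched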
src/transformers/modeling_flax_utils.py

@@ -28,7 +28,16 @@ from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey
 
 from .configuration_utils import PretrainedConfig
-from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
+from .file_utils import (
+    FLAX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    add_start_docstrings_to_model_forward,
+    cached_path,
+    copy_func,
+    hf_bucket_url,
+    is_offline_mode,
+    is_remote_url,
+)
 from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
 from .utils import logging
 
@@ -85,13 +94,13 @@ class FlaxPreTrainedModel(ABC):
         self.dtype = dtype
 
         # randomely initialized parameters
-        random_params = self.init(self.key, input_shape)
+        random_params = self.init_weights(self.key, input_shape)
 
         # save required_params as set
         self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
         self.params = random_params
 
-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
         raise NotImplementedError(f"init method has to be implemented for {self}")
 
     @property
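A minimal, framework-free sketch of the contract this hunk sets up (class and method names below are illustrative): the shared constructor calls `self.init_weights(...)`, and each concrete model overrides it.

class BaseModel:
    def __init__(self, seed=0, input_shape=(1, 1)):
        self.key = seed
        # parameter creation is delegated to the subclass
        self.params = self.init_weights(self.key, input_shape)

    def init_weights(self, rng, input_shape):
        raise NotImplementedError(f"init_weights has to be implemented for {self}")


class ToyModel(BaseModel):
    def init_weights(self, rng, input_shape):
        # stand-in for real random initialization
        return {"weight": [[0.0] * input_shape[1] for _ in range(input_shape[0])]}


model = ToyModel(input_shape=(2, 3))
assert set(model.params) == {"weight"}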
@@ -394,3 +403,12 @@ class FlaxPreTrainedModel(ABC):
         with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
             model_bytes = to_bytes(self.params)
             f.write(model_bytes)
+
+
+def overwrite_call_docstring(model_class, docstring):
+    # copy __call__ function to be sure docstring is changed only for this function
+    model_class.__call__ = copy_func(model_class.__call__)
+    # delete existing docstring
+    model_class.__call__.__doc__ = None
+    # set correct docstring
+    model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
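A rough, self-contained sketch of the pattern the new helper follows. `prepend_docstring` below is a simplified stand-in for `add_start_docstrings_to_model_forward`, and the `Base`/`Child` classes are illustrative, not from the diff.

import functools
import types


def prepend_docstring(docstring):
    # simplified stand-in for add_start_docstrings_to_model_forward
    def decorator(fn):
        fn.__doc__ = docstring + (fn.__doc__ or "")
        return fn
    return decorator


def overwrite_call_docstring(model_class, docstring):
    # same shape as the helper above: copy __call__, clear its doc, re-decorate
    f = model_class.__call__
    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
    g = functools.update_wrapper(g, f)
    g.__doc__ = None
    model_class.__call__ = prepend_docstring(docstring)(g)


class Base:
    def __call__(self, x):
        """Shared docstring."""
        return x


class Child(Base):
    pass


overwrite_call_docstring(Child, "Child-specific docstring.")
assert Base.__call__.__doc__ == "Shared docstring."   # untouched
assert Child.__call__.__doc__ == "Child-specific docstring."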
src/transformers/models/auto/auto_factory.py

@@ -14,10 +14,10 @@
# limitations under the License.
"""Factory function to build auto-model classes."""

import functools
import types

from ...configuration_utils import PretrainedConfig
from ...file_utils import copy_func
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
@@ -385,15 +385,6 @@ class _BaseAutoModelClass:
         )
 
 
-def copy_func(f):
-    """ Returns a copy of a function f."""
-    # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
-    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
-    g = functools.update_wrapper(g, f)
-    g.__kwdefaults__ = f.__kwdefaults__
-    return g
-
-
 def insert_head_doc(docstring, head_doc=""):
     if len(head_doc) > 0:
         return docstring.replace(
src/transformers/models/bert/modeling_flax_bert.py

@@ -26,7 +26,7 @@ from jax import lax
 from jax.random import PRNGKey
 
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring
 from ...utils import logging
 from .configuration_bert import BertConfig
 
@@ -91,6 +91,7 @@ BERT_INPUTS_DOCSTRING = r"""
            config.max_position_embeddings - 1]``.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.

"""
@@ -477,49 +478,26 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):

    config_class = BertConfig
    base_model_prefix = "bert"
    module_class: nn.Module = None

    def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
        if token_type_ids is None:
            token_type_ids = jnp.ones_like(input_ids)
    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

        if position_ids is None:
            position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        return input_ids, attention_mask, token_type_ids, position_ids

    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            jnp.zeros(input_shape, dtype="i4"), None, None, None
        )
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.ones_like(input_ids)
        position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
        attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]


@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class FlaxBertModel(FlaxBertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    """

    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = FlaxBertModule(config=config, dtype=dtype, **kwargs)

        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
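The net effect of this hunk: a concrete model no longer carries its own `__init__` boilerplate; it only points the shared base class at its Flax module via `module_class`. A framework-free sketch of that pattern (all names below are illustrative):

class TinyModule:
    # stand-in for a flax.linen nn.Module
    def __init__(self, config, dtype="float32"):
        self.config = config
        self.dtype = dtype


class TinyPreTrainedModel:
    module_class = None  # each concrete model only sets this attribute

    def __init__(self, config, dtype="float32", **kwargs):
        # the shared constructor builds the module, so subclasses need no __init__
        self.module = self.module_class(config, dtype=dtype, **kwargs)


class TinyModel(TinyPreTrainedModel):
    module_class = TinyModule


model = TinyModel(config={"hidden_size": 4})
assert isinstance(model.module, TinyModule)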
@@ -531,9 +509,15 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
         dropout_rng: PRNGKey = None,
         train: bool = False,
     ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
+        # init input tensors if not passed
+        if token_type_ids is None:
+            token_type_ids = jnp.ones_like(input_ids)
+
+        if position_ids is None:
+            position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
+
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
 
         # Handle any PRNG if needed
         rngs = {}
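The removed `_check_inputs` helper is now inlined: each `__call__` fills in default tensors itself. A small standalone snippet of the same defaulting logic (requires `jax`; the example ids are made up):

import jax.numpy as jnp

input_ids = jnp.array([[5, 6, 7, 8]])
attention_mask = None
token_type_ids = None
position_ids = None

# the same defaulting the refactored __call__ performs inline
if token_type_ids is None:
    token_type_ids = jnp.ones_like(input_ids)

if position_ids is None:
    position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])

if attention_mask is None:
    attention_mask = jnp.ones_like(input_ids)

assert attention_mask.shape == (1, 4)
assert position_ids.shape == (4,)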
@@ -576,49 +560,11 @@ class FlaxBertModule(nn.Module):


@add_start_docstrings(
    """
    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = FlaxBertForPreTrainingModule(config, **kwargs)

        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            rngs=rngs,
        )
class FlaxBertModel(FlaxBertPreTrainedModel):
    module_class = FlaxBertModule


class FlaxBertForPreTrainingModule(nn.Module):
@@ -641,44 +587,15 @@ class FlaxBertForPreTrainingModule(nn.Module):
        return (prediction_scores, seq_relationship_score)


@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = FlaxBertForMaskedLMModule(config, **kwargs)

        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            rngs=rngs,
        )
@add_start_docstrings(
    """
    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
    module_class = FlaxBertForPreTrainingModule


class FlaxBertForMaskedLMModule(nn.Module):
@@ -701,46 +618,9 @@ class FlaxBertForMaskedLMModule(nn.Module):
        return (logits,)


@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top. """,
    BERT_START_DOCSTRING,
)
class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = FlaxBertForNextSentencePredictionModule(config, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            rngs=rngs,
        )
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
    module_class = FlaxBertForMaskedLMModule


class FlaxBertForNextSentencePredictionModule(nn.Module):
@@ -764,48 +644,11 @@ class FlaxBertForNextSentencePredictionModule(nn.Module):


@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    """Bert Model with a `next sentence prediction (classification)` head on top. """,
    BERT_START_DOCSTRING,
)
class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = FlaxBertForSequenceClassificationModule(config, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            rngs=rngs,
        )
class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
    module_class = FlaxBertForNextSentencePredictionModule


class FlaxBertForSequenceClassificationModule(nn.Module):
@@ -836,47 +679,13 @@ class FlaxBertForSequenceClassificationModule(nn.Module):

@add_start_docstrings(
    """
    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = FlaxBertForMultipleChoiceModule(config, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            rngs=rngs,
        )
class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
    module_class = FlaxBertForSequenceClassificationModule


class FlaxBertForMultipleChoiceModule(nn.Module):
@@ -912,47 +721,19 @@ class FlaxBertForMultipleChoiceModule(nn.Module):

@add_start_docstrings(
    """
    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = FlaxBertForTokenClassificationModule(config, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
    module_class = FlaxBertForMultipleChoiceModule

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            rngs=rngs,
        )
# adapt docstring slightly for FlaxBertForMultipleChoice
overwrite_call_docstring(
    FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)


class FlaxBertForTokenClassificationModule(nn.Module):
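`overwrite_call_docstring` is what lets the multiple-choice head advertise a 3D input shape while every other head keeps the 2D one; the shape string is simply substituted into a placeholder in the docstring template. A toy illustration (the template text below is made up, not the real BERT_INPUTS_DOCSTRING):

TOY_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`jnp.ndarray` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.
"""

print(TOY_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
print(TOY_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))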
@@ -978,47 +759,13 @@ class FlaxBertForTokenClassificationModule(nn.Module):

@add_start_docstrings(
    """
    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = FlaxBertForQuestionAnsweringModule(config, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            rngs=rngs,
        )
class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
    module_class = FlaxBertForTokenClassificationModule


class FlaxBertForQuestionAnsweringModule(nn.Module):
@@ -1041,3 +788,14 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
         end_logits = end_logits.squeeze(-1)
 
         return (start_logits, end_logits)
+
+
+@add_start_docstrings(
+    """
+    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    BERT_START_DOCSTRING,
+)
+class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForQuestionAnsweringModule
src/transformers/models/roberta/modeling_flax_roberta.py

@@ -441,40 +441,7 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
     config_class = RobertaConfig
     base_model_prefix = "roberta"
 
-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            jnp.zeros(input_shape, dtype="i4"), None, None, None
-        )
-
-        params_rng, dropout_rng = jax.random.split(rng)
-        rngs = {"params": params_rng, "dropout": dropout_rng}
-
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
-
-    def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
-        if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
-
-        if position_ids is None:
-            position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
-
-        if attention_mask is None:
-            attention_mask = jnp.ones_like(input_ids)
-
-        return input_ids, attention_mask, token_type_ids, position_ids
-
-
-@add_start_docstrings(
-    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
-    ROBERTA_START_DOCSTRING,
-)
-class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
-    """
-    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
-    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
-    all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
-    Kaiser and Illia Polosukhin.
-    """
+    module_class: nn.Module = None
 
     def __init__(
         self,
@@ -484,23 +451,41 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
        dtype: jnp.dtype = jnp.float32,
        **kwargs
    ):
        module = FlaxRobertaModule(config, dtype=dtype, **kwargs)
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.ones_like(input_ids)
        position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
        attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )
        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.ones_like(input_ids)

        if position_ids is None:
            position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
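For RoBERTa the default position ids come from `create_position_ids_from_input_ids` rather than a plain `arange`, so that padding positions are skipped. A rough sketch of that fairseq-style convention (an illustration of the behavior, not a verbatim copy of the helper; a pad id of 1 is assumed):

import jax.numpy as jnp


def rough_position_ids_from_input_ids(input_ids, padding_idx):
    # non-pad tokens get positions padding_idx + 1, padding_idx + 2, ...;
    # pad tokens keep position padding_idx
    mask = (input_ids != padding_idx).astype("i4")
    incremental_indices = jnp.cumsum(mask, axis=1) * mask
    return incremental_indices + padding_idx


ids = jnp.array([[0, 5, 6, 1, 1]])
print(rough_position_ids_from_input_ids(ids, padding_idx=1))
# -> [[2 3 4 1 1]]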
@@ -541,3 +526,11 @@ class FlaxRobertaModule(nn.Module):
 
         pooled = self.pooler(hidden_states)
         return hidden_states, pooled
+
+
+@add_start_docstrings(
+    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
+    ROBERTA_START_DOCSTRING,
+)
+class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
+    module_class = FlaxRobertaModule