mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
[FlaxRoberta] Add FlaxRobertaModels & adapt run_mlm_flax.py (#11470)
* add flax roberta * make style * correct initialiazation * modify model to save weights * fix copied from * fix copied from * correct some more code * add more roberta models * Apply suggestions from code review * merge from master * finish * finish docs Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
parent
2ce0fb84cc
commit
084a187da3
@ -166,3 +166,38 @@ FlaxRobertaModel
|
||||
|
||||
.. autoclass:: transformers.FlaxRobertaModel
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxRobertaForMaskedLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxRobertaForMaskedLM
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxRobertaForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxRobertaForSequenceClassification
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxRobertaForMultipleChoice
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxRobertaForMultipleChoice
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxRobertaForTokenClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxRobertaForTokenClassification
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxRobertaForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxRobertaForQuestionAnswering
|
||||
:members: __call__
|
||||
|
@ -45,7 +45,7 @@ from transformers import (
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
FlaxBertForMaskedLM,
|
||||
FlaxAutoModelForMaskedLM,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerBase,
|
||||
TensorType,
|
||||
@ -105,6 +105,12 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
dtype: Optional[str] = field(
|
||||
default="float32",
|
||||
metadata={
|
||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -162,6 +168,10 @@ class DataTrainingArguments:
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||
},
|
||||
)
|
||||
line_by_line: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
@ -537,27 +547,76 @@ if __name__ == "__main__":
|
||||
column_names = datasets["validation"].column_names
|
||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
def tokenize_function(examples):
|
||||
# Remove empty lines
|
||||
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
|
||||
return tokenizer(
|
||||
examples,
|
||||
return_special_tokens_mask=True,
|
||||
padding=padding,
|
||||
truncation=True,
|
||||
max_length=data_args.max_seq_length,
|
||||
if data_args.line_by_line:
|
||||
# When using line_by_line, we just tokenize each nonempty line.
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
def tokenize_function(examples):
|
||||
# Remove empty lines
|
||||
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
|
||||
return tokenizer(
|
||||
examples,
|
||||
return_special_tokens_mask=True,
|
||||
padding=padding,
|
||||
truncation=True,
|
||||
max_length=max_seq_length,
|
||||
)
|
||||
|
||||
tokenized_datasets = datasets.map(
|
||||
tokenize_function,
|
||||
input_columns=[text_column_name],
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
tokenized_datasets = datasets.map(
|
||||
tokenize_function,
|
||||
input_columns=[text_column_name],
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
else:
|
||||
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
||||
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
|
||||
# efficient when it receives the `special_tokens_mask`.
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
|
||||
|
||||
tokenized_datasets = datasets.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
||||
# max_seq_length.
|
||||
def group_texts(examples):
|
||||
# Concatenate all texts.
|
||||
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
||||
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
||||
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
||||
# customize this part to your needs.
|
||||
total_length = (total_length // max_seq_length) * max_seq_length
|
||||
# Split by chunks of max_len.
|
||||
result = {
|
||||
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
return result
|
||||
|
||||
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
|
||||
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
|
||||
# might be slower to preprocess.
|
||||
#
|
||||
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
||||
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
||||
|
||||
tokenized_datasets = tokenized_datasets.map(
|
||||
group_texts,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
# Enable tensorboard only on the master node
|
||||
if has_tensorboard and jax.host_id() == 0:
|
||||
@ -571,13 +630,7 @@ if __name__ == "__main__":
|
||||
rng = jax.random.PRNGKey(training_args.seed)
|
||||
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||
|
||||
model = FlaxBertForMaskedLM.from_pretrained(
|
||||
"bert-base-cased",
|
||||
dtype=jnp.float32,
|
||||
input_shape=(training_args.train_batch_size, config.max_position_embeddings),
|
||||
seed=training_args.seed,
|
||||
dropout_rate=0.1,
|
||||
)
|
||||
model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||
|
||||
# Setup optimizer
|
||||
optimizer = Adam(
|
||||
@ -602,8 +655,8 @@ if __name__ == "__main__":
|
||||
|
||||
# Store some constant
|
||||
nb_epochs = int(training_args.num_train_epochs)
|
||||
batch_size = int(training_args.train_batch_size)
|
||||
eval_batch_size = int(training_args.eval_batch_size)
|
||||
batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||
|
||||
epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
@ -657,3 +710,8 @@ if __name__ == "__main__":
|
||||
if has_tensorboard and jax.host_id() == 0:
|
||||
for name, value in eval_summary.items():
|
||||
summary_writer.scalar(name, value, epoch)
|
||||
|
||||
# save last checkpoint
|
||||
if jax.host_id() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], optimizer.target))
|
||||
model.save_pretrained(training_args.output_dir, params=params)
|
||||
|
@ -1403,7 +1403,17 @@ if is_flax_available():
|
||||
"FlaxBertPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.roberta"].append("FlaxRobertaModel")
|
||||
_import_structure["models.roberta"].extend(
|
||||
[
|
||||
"FlaxRobertaForMaskedLM",
|
||||
"FlaxRobertaForMultipleChoice",
|
||||
"FlaxRobertaForQuestionAnswering",
|
||||
"FlaxRobertaForSequenceClassification",
|
||||
"FlaxRobertaForTokenClassification",
|
||||
"FlaxRobertaModel",
|
||||
"FlaxRobertaPreTrainedModel",
|
||||
]
|
||||
)
|
||||
else:
|
||||
from .utils import dummy_flax_objects
|
||||
|
||||
@ -2575,7 +2585,15 @@ if TYPE_CHECKING:
|
||||
FlaxBertModel,
|
||||
FlaxBertPreTrainedModel,
|
||||
)
|
||||
from .models.roberta import FlaxRobertaModel
|
||||
from .models.roberta import (
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
FlaxRobertaForSequenceClassification,
|
||||
FlaxRobertaForTokenClassification,
|
||||
FlaxRobertaModel,
|
||||
FlaxRobertaPreTrainedModel,
|
||||
)
|
||||
else:
|
||||
# Import the same objects as dummies to get them in the namespace.
|
||||
# They will raise an import error if the user tries to instantiate / use them.
|
||||
|
@ -1608,9 +1608,9 @@ def is_tensor(x):
|
||||
|
||||
if is_flax_available():
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
from jax.interpreters.partial_eval import DynamicJaxprTracer
|
||||
from jax.core import Tracer
|
||||
|
||||
if isinstance(x, (jax_xla.DeviceArray, DynamicJaxprTracer)):
|
||||
if isinstance(x, (jax_xla.DeviceArray, Tracer)):
|
||||
return True
|
||||
|
||||
return isinstance(x, np.ndarray)
|
||||
|
@ -388,7 +388,7 @@ class FlaxPreTrainedModel(PushToHubMixin):
|
||||
|
||||
return model
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub=False, **kwargs):
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
`:func:`~transformers.FlaxPreTrainedModel.from_pretrained`` class method
|
||||
@ -416,7 +416,8 @@ class FlaxPreTrainedModel(PushToHubMixin):
|
||||
# save model
|
||||
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
|
||||
with open(output_model_file, "wb") as f:
|
||||
model_bytes = to_bytes(self.params)
|
||||
params = params if params is not None else self.params
|
||||
model_bytes = to_bytes(params)
|
||||
f.write(model_bytes)
|
||||
|
||||
logger.info(f"Model weights saved in {output_model_file}")
|
||||
|
@ -28,7 +28,14 @@ from ..bert.modeling_flax_bert import (
|
||||
FlaxBertForTokenClassification,
|
||||
FlaxBertModel,
|
||||
)
|
||||
from ..roberta.modeling_flax_roberta import FlaxRobertaModel
|
||||
from ..roberta.modeling_flax_roberta import (
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
FlaxRobertaForSequenceClassification,
|
||||
FlaxRobertaForTokenClassification,
|
||||
FlaxRobertaModel,
|
||||
)
|
||||
from .auto_factory import auto_class_factory
|
||||
from .configuration_auto import BertConfig, RobertaConfig
|
||||
|
||||
@ -47,6 +54,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
|
||||
FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for pre-training mapping
|
||||
(RobertaConfig, FlaxRobertaForMaskedLM),
|
||||
(BertConfig, FlaxBertForPreTraining),
|
||||
]
|
||||
)
|
||||
@ -54,6 +62,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Masked LM mapping
|
||||
(RobertaConfig, FlaxRobertaForMaskedLM),
|
||||
(BertConfig, FlaxBertForMaskedLM),
|
||||
]
|
||||
)
|
||||
@ -61,6 +70,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Sequence Classification mapping
|
||||
(RobertaConfig, FlaxRobertaForSequenceClassification),
|
||||
(BertConfig, FlaxBertForSequenceClassification),
|
||||
]
|
||||
)
|
||||
@ -68,6 +78,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Question Answering mapping
|
||||
(RobertaConfig, FlaxRobertaForQuestionAnswering),
|
||||
(BertConfig, FlaxBertForQuestionAnswering),
|
||||
]
|
||||
)
|
||||
@ -75,6 +86,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Token Classification mapping
|
||||
(RobertaConfig, FlaxRobertaForTokenClassification),
|
||||
(BertConfig, FlaxBertForTokenClassification),
|
||||
]
|
||||
)
|
||||
@ -82,6 +94,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Multiple Choice mapping
|
||||
(RobertaConfig, FlaxRobertaForMultipleChoice),
|
||||
(BertConfig, FlaxBertForMultipleChoice),
|
||||
]
|
||||
)
|
||||
|
@ -61,7 +61,15 @@ if is_tf_available():
|
||||
]
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_roberta"] = ["FlaxRobertaModel"]
|
||||
_import_structure["modeling_flax_roberta"] = [
|
||||
"FlaxRobertaForMaskedLM",
|
||||
"FlaxRobertaForMultipleChoice",
|
||||
"FlaxRobertaForQuestionAnswering",
|
||||
"FlaxRobertaForSequenceClassification",
|
||||
"FlaxRobertaForTokenClassification",
|
||||
"FlaxRobertaModel",
|
||||
"FlaxRobertaPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -97,7 +105,15 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_roberta import FlaxRobertaModel
|
||||
from .modeling_tf_roberta import (
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
FlaxRobertaForSequenceClassification,
|
||||
FlaxRobertaForTokenClassification,
|
||||
FlaxRobertaModel,
|
||||
FlaxRobertaPreTrainedModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import importlib
|
||||
|
@ -12,7 +12,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
@ -23,8 +25,16 @@ from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
|
||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
FlaxBaseModelOutputWithPooling,
|
||||
FlaxMaskedLMOutput,
|
||||
FlaxMultipleChoiceModelOutput,
|
||||
FlaxQuestionAnsweringModelOutput,
|
||||
FlaxSequenceClassifierOutput,
|
||||
FlaxTokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring
|
||||
from ...utils import logging
|
||||
from .configuration_roberta import RobertaConfig
|
||||
|
||||
@ -49,7 +59,14 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = (input_ids != padding_idx).astype("i4")
|
||||
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
|
||||
|
||||
if mask.ndim > 2:
|
||||
mask = mask.reshape((-1, mask.shape[-1]))
|
||||
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
|
||||
incremental_indices = incremental_indices.reshape(input_ids.shape)
|
||||
else:
|
||||
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
|
||||
|
||||
return incremental_indices.astype("i4") + padding_idx
|
||||
|
||||
|
||||
@ -436,6 +453,67 @@ class FlaxRobertaPooler(nn.Module):
|
||||
return nn.tanh(cls_hidden_state)
|
||||
|
||||
|
||||
class FlaxRobertaLMHead(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.decoder = nn.Dense(
|
||||
self.config.vocab_size,
|
||||
dtype=self.dtype,
|
||||
use_bias=False,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
|
||||
|
||||
def __call__(self, hidden_states, shared_embedding=None):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = ACT2FN["gelu"](hidden_states)
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
if shared_embedding is not None:
|
||||
hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
|
||||
else:
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
|
||||
hidden_states += self.bias
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxRobertaClassificationHead(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.out_proj = nn.Dense(
|
||||
self.config.num_labels,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, deterministic=True):
|
||||
hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = nn.tanh(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
hidden_states = self.out_proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
@ -585,3 +663,347 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
|
||||
append_call_sample_docstring(
|
||||
FlaxRobertaModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC
|
||||
)
|
||||
|
||||
|
||||
class FlaxRobertaForMaskedLMModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.tie_word_embeddings:
|
||||
shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||
else:
|
||||
shared_embedding = None
|
||||
|
||||
# Compute the prediction scores
|
||||
logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
|
||||
|
||||
if not return_dict:
|
||||
return (logits,) + outputs[1:]
|
||||
|
||||
return FlaxMaskedLMOutput(
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
||||
class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel):
|
||||
module_class = FlaxRobertaForMaskedLMModule
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxRobertaForMaskedLM,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxBaseModelOutputWithPooling,
|
||||
_CONFIG_FOR_DOC,
|
||||
mask="<mask>",
|
||||
)
|
||||
|
||||
|
||||
class FlaxRobertaForSequenceClassificationModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
logits = self.classifier(sequence_output, deterministic=deterministic)
|
||||
|
||||
if not return_dict:
|
||||
return (logits,) + outputs[1:]
|
||||
|
||||
return FlaxSequenceClassifierOutput(
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
||||
pooled output) e.g. for GLUE tasks.
|
||||
""",
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel):
|
||||
module_class = FlaxRobertaForSequenceClassificationModule
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxRobertaForSequenceClassification,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxSequenceClassifierOutput,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->Roberta, with self.bert->self.roberta
|
||||
class FlaxRobertaForMultipleChoiceModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(1, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
num_choices = input_ids.shape[1]
|
||||
input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
|
||||
attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
|
||||
token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
|
||||
position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
|
||||
|
||||
# Model
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
reshaped_logits = logits.reshape(-1, num_choices)
|
||||
|
||||
if not return_dict:
|
||||
return (reshaped_logits,) + outputs[2:]
|
||||
|
||||
return FlaxMultipleChoiceModelOutput(
|
||||
logits=reshaped_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
||||
softmax) e.g. for RocStories/SWAG tasks.
|
||||
""",
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel):
|
||||
module_class = FlaxRobertaForMultipleChoiceModule
|
||||
|
||||
|
||||
overwrite_call_docstring(
|
||||
FlaxRobertaForMultipleChoice, ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
||||
)
|
||||
append_call_sample_docstring(
|
||||
FlaxRobertaForMultipleChoice,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxMultipleChoiceModelOutput,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->Roberta, with self.bert->self.roberta
|
||||
class FlaxRobertaForTokenClassificationModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
logits = self.classifier(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (logits,) + outputs[1:]
|
||||
|
||||
return FlaxTokenClassifierOutput(
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
||||
Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel):
|
||||
module_class = FlaxRobertaForTokenClassificationModule
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxRobertaForTokenClassification,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxTokenClassifierOutput,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->Roberta, with self.bert->self.roberta
|
||||
class FlaxRobertaForQuestionAnsweringModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(hidden_states)
|
||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
if not return_dict:
|
||||
return (start_logits, end_logits) + outputs[1:]
|
||||
|
||||
return FlaxQuestionAnsweringModelOutput(
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
||||
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||
""",
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel):
|
||||
module_class = FlaxRobertaForQuestionAnsweringModule
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxQuestionAnsweringModelOutput,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
@ -180,6 +180,51 @@ class FlaxBertPreTrainedModel:
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxRobertaForMaskedLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxRobertaForMultipleChoice:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxRobertaForQuestionAnswering:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxRobertaForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxRobertaForTokenClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxRobertaModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
@ -187,3 +232,12 @@ class FlaxRobertaModel:
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxRobertaPreTrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
@ -150,7 +150,7 @@ class FlaxModelTesterMixin:
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
@ -161,7 +161,7 @@ class FlaxModelTesterMixin:
|
||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3)
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
@ -191,7 +191,7 @@ class FlaxModelTesterMixin:
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
@ -204,7 +204,7 @@ class FlaxModelTesterMixin:
|
||||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -219,6 +219,7 @@ class FlaxModelTesterMixin:
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
outputs = model(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
# verify that normal save_pretrained works as expected
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||
@ -227,6 +228,16 @@ class FlaxModelTesterMixin:
|
||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 1e-3)
|
||||
|
||||
# verify that save_pretrained for distributed training
|
||||
# with `params=params` works as expected
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, params=model.params)
|
||||
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
|
||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 1e-3)
|
||||
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -23,7 +23,14 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModel
|
||||
from transformers.models.roberta.modeling_flax_roberta import (
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
FlaxRobertaForSequenceClassification,
|
||||
FlaxRobertaForTokenClassification,
|
||||
FlaxRobertaModel,
|
||||
)
|
||||
|
||||
|
||||
class FlaxRobertaModelTester(unittest.TestCase):
|
||||
@ -48,6 +55,7 @@ class FlaxRobertaModelTester(unittest.TestCase):
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_choices=4,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@ -68,6 +76,7 @@ class FlaxRobertaModelTester(unittest.TestCase):
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_choices = num_choices
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@ -107,7 +116,18 @@ class FlaxRobertaModelTester(unittest.TestCase):
|
||||
@require_flax
|
||||
class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (FlaxRobertaModel,) if is_flax_available() else ()
|
||||
all_model_classes = (
|
||||
(
|
||||
FlaxRobertaModel,
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForSequenceClassification,
|
||||
FlaxRobertaForTokenClassification,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
)
|
||||
if is_flax_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxRobertaModelTester(self)
|
||||
|
Loading…
Reference in New Issue
Block a user