[FlaxRoberta] Add FlaxRobertaModels & adapt run_mlm_flax.py (#11470)

* add flax roberta

* make style

* correct initialiazation

* modify model to save weights

* fix copied from

* fix copied from

* correct some more code

* add more roberta models

* Apply suggestions from code review

* merge from master

* finish

* finish docs

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen 2021-05-04 19:57:59 +02:00 committed by GitHub
parent 2ce0fb84cc
commit 084a187da3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 695 additions and 47 deletions

View File

@ -166,3 +166,38 @@ FlaxRobertaModel
.. autoclass:: transformers.FlaxRobertaModel
:members: __call__
FlaxRobertaForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxRobertaForMaskedLM
:members: __call__
FlaxRobertaForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxRobertaForSequenceClassification
:members: __call__
FlaxRobertaForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxRobertaForMultipleChoice
:members: __call__
FlaxRobertaForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxRobertaForTokenClassification
:members: __call__
FlaxRobertaForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxRobertaForQuestionAnswering
:members: __call__

View File

@ -45,7 +45,7 @@ from transformers import (
MODEL_FOR_MASKED_LM_MAPPING,
AutoConfig,
AutoTokenizer,
FlaxBertForMaskedLM,
FlaxAutoModelForMaskedLM,
HfArgumentParser,
PreTrainedTokenizerBase,
TensorType,
@ -105,6 +105,12 @@ class ModelArguments:
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
},
)
@dataclass
@ -162,6 +168,10 @@ class DataTrainingArguments:
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
},
)
line_by_line: bool = field(
default=False,
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@ -537,27 +547,76 @@ if __name__ == "__main__":
column_names = datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
padding = "max_length" if data_args.pad_to_max_length else False
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
def tokenize_function(examples):
# Remove empty lines
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
return tokenizer(
examples,
return_special_tokens_mask=True,
padding=padding,
truncation=True,
max_length=data_args.max_seq_length,
if data_args.line_by_line:
# When using line_by_line, we just tokenize each nonempty line.
padding = "max_length" if data_args.pad_to_max_length else False
def tokenize_function(examples):
# Remove empty lines
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
return tokenizer(
examples,
return_special_tokens_mask=True,
padding=padding,
truncation=True,
max_length=max_seq_length,
)
tokenized_datasets = datasets.map(
tokenize_function,
input_columns=[text_column_name],
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
tokenized_datasets = datasets.map(
tokenize_function,
input_columns=[text_column_name],
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
else:
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
# efficient when it receives the `special_tokens_mask`.
def tokenize_function(examples):
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
total_length = (total_length // max_seq_length) * max_seq_length
# Split by chunks of max_len.
result = {
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
for k, t in concatenated_examples.items()
}
return result
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
# might be slower to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
tokenized_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)
# Enable tensorboard only on the master node
if has_tensorboard and jax.host_id() == 0:
@ -571,13 +630,7 @@ if __name__ == "__main__":
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
model = FlaxBertForMaskedLM.from_pretrained(
"bert-base-cased",
dtype=jnp.float32,
input_shape=(training_args.train_batch_size, config.max_position_embeddings),
seed=training_args.seed,
dropout_rate=0.1,
)
model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
# Setup optimizer
optimizer = Adam(
@ -602,8 +655,8 @@ if __name__ == "__main__":
# Store some constant
nb_epochs = int(training_args.num_train_epochs)
batch_size = int(training_args.train_batch_size)
eval_batch_size = int(training_args.eval_batch_size)
batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
for epoch in epochs:
@ -657,3 +710,8 @@ if __name__ == "__main__":
if has_tensorboard and jax.host_id() == 0:
for name, value in eval_summary.items():
summary_writer.scalar(name, value, epoch)
# save last checkpoint
if jax.host_id() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], optimizer.target))
model.save_pretrained(training_args.output_dir, params=params)

View File

@ -1403,7 +1403,17 @@ if is_flax_available():
"FlaxBertPreTrainedModel",
]
)
_import_structure["models.roberta"].append("FlaxRobertaModel")
_import_structure["models.roberta"].extend(
[
"FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering",
"FlaxRobertaForSequenceClassification",
"FlaxRobertaForTokenClassification",
"FlaxRobertaModel",
"FlaxRobertaPreTrainedModel",
]
)
else:
from .utils import dummy_flax_objects
@ -2575,7 +2585,15 @@ if TYPE_CHECKING:
FlaxBertModel,
FlaxBertPreTrainedModel,
)
from .models.roberta import FlaxRobertaModel
from .models.roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
FlaxRobertaPreTrainedModel,
)
else:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.

View File

@ -1608,9 +1608,9 @@ def is_tensor(x):
if is_flax_available():
import jaxlib.xla_extension as jax_xla
from jax.interpreters.partial_eval import DynamicJaxprTracer
from jax.core import Tracer
if isinstance(x, (jax_xla.DeviceArray, DynamicJaxprTracer)):
if isinstance(x, (jax_xla.DeviceArray, Tracer)):
return True
return isinstance(x, np.ndarray)

View File

@ -388,7 +388,7 @@ class FlaxPreTrainedModel(PushToHubMixin):
return model
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub=False, **kwargs):
def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
`:func:`~transformers.FlaxPreTrainedModel.from_pretrained`` class method
@ -416,7 +416,8 @@ class FlaxPreTrainedModel(PushToHubMixin):
# save model
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
with open(output_model_file, "wb") as f:
model_bytes = to_bytes(self.params)
params = params if params is not None else self.params
model_bytes = to_bytes(params)
f.write(model_bytes)
logger.info(f"Model weights saved in {output_model_file}")

View File

@ -28,7 +28,14 @@ from ..bert.modeling_flax_bert import (
FlaxBertForTokenClassification,
FlaxBertModel,
)
from ..roberta.modeling_flax_roberta import FlaxRobertaModel
from ..roberta.modeling_flax_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
)
from .auto_factory import auto_class_factory
from .configuration_auto import BertConfig, RobertaConfig
@ -47,6 +54,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
# Model for pre-training mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForPreTraining),
]
)
@ -54,6 +62,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
[
# Model for Masked LM mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForMaskedLM),
]
)
@ -61,6 +70,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Sequence Classification mapping
(RobertaConfig, FlaxRobertaForSequenceClassification),
(BertConfig, FlaxBertForSequenceClassification),
]
)
@ -68,6 +78,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
[
# Model for Question Answering mapping
(RobertaConfig, FlaxRobertaForQuestionAnswering),
(BertConfig, FlaxBertForQuestionAnswering),
]
)
@ -75,6 +86,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Token Classification mapping
(RobertaConfig, FlaxRobertaForTokenClassification),
(BertConfig, FlaxBertForTokenClassification),
]
)
@ -82,6 +94,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
[
# Model for Multiple Choice mapping
(RobertaConfig, FlaxRobertaForMultipleChoice),
(BertConfig, FlaxBertForMultipleChoice),
]
)

View File

@ -61,7 +61,15 @@ if is_tf_available():
]
if is_flax_available():
_import_structure["modeling_flax_roberta"] = ["FlaxRobertaModel"]
_import_structure["modeling_flax_roberta"] = [
"FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering",
"FlaxRobertaForSequenceClassification",
"FlaxRobertaForTokenClassification",
"FlaxRobertaModel",
"FlaxRobertaPreTrainedModel",
]
if TYPE_CHECKING:
@ -97,7 +105,15 @@ if TYPE_CHECKING:
)
if is_flax_available():
from .modeling_flax_roberta import FlaxRobertaModel
from .modeling_tf_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
FlaxRobertaPreTrainedModel,
)
else:
import importlib

View File

@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import Callable, Optional, Tuple
import numpy as np
import flax.linen as nn
import jax
@ -23,8 +25,16 @@ from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPooling,
FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput,
FlaxQuestionAnsweringModelOutput,
FlaxSequenceClassifierOutput,
FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring
from ...utils import logging
from .configuration_roberta import RobertaConfig
@ -49,7 +59,14 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = (input_ids != padding_idx).astype("i4")
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
if mask.ndim > 2:
mask = mask.reshape((-1, mask.shape[-1]))
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
incremental_indices = incremental_indices.reshape(input_ids.shape)
else:
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
return incremental_indices.astype("i4") + padding_idx
@ -436,6 +453,67 @@ class FlaxRobertaPooler(nn.Module):
return nn.tanh(cls_hidden_state)
class FlaxRobertaLMHead(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.decoder = nn.Dense(
self.config.vocab_size,
dtype=self.dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
def __call__(self, hidden_states, shared_embedding=None):
hidden_states = self.dense(hidden_states)
hidden_states = ACT2FN["gelu"](hidden_states)
hidden_states = self.layer_norm(hidden_states)
if shared_embedding is not None:
hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
hidden_states = self.decoder(hidden_states)
hidden_states += self.bias
return hidden_states
class FlaxRobertaClassificationHead(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.out_proj = nn.Dense(
self.config.num_labels,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
def __call__(self, hidden_states, deterministic=True):
hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.dense(hidden_states)
hidden_states = nn.tanh(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.out_proj(hidden_states)
return hidden_states
class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@ -585,3 +663,347 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
append_call_sample_docstring(
FlaxRobertaModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC
)
class FlaxRobertaForMaskedLMModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.roberta(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
shared_embedding = None
# Compute the prediction scores
logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxMaskedLMOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel):
module_class = FlaxRobertaForMaskedLMModule
append_call_sample_docstring(
FlaxRobertaForMaskedLM,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxBaseModelOutputWithPooling,
_CONFIG_FOR_DOC,
mask="<mask>",
)
class FlaxRobertaForSequenceClassificationModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.roberta(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output, deterministic=deterministic)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
ROBERTA_START_DOCSTRING,
)
class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel):
module_class = FlaxRobertaForSequenceClassificationModule
append_call_sample_docstring(
FlaxRobertaForSequenceClassification,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxSequenceClassifierOutput,
_CONFIG_FOR_DOC,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->Roberta, with self.bert->self.roberta
class FlaxRobertaForMultipleChoiceModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(1, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
num_choices = input_ids.shape[1]
input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
# Model
outputs = self.roberta(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
logits = self.classifier(pooled_output)
reshaped_logits = logits.reshape(-1, num_choices)
if not return_dict:
return (reshaped_logits,) + outputs[2:]
return FlaxMultipleChoiceModelOutput(
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
ROBERTA_START_DOCSTRING,
)
class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel):
module_class = FlaxRobertaForMultipleChoiceModule
overwrite_call_docstring(
FlaxRobertaForMultipleChoice, ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
FlaxRobertaForMultipleChoice,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxMultipleChoiceModelOutput,
_CONFIG_FOR_DOC,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->Roberta, with self.bert->self.roberta
class FlaxRobertaForTokenClassificationModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.roberta(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.classifier(hidden_states)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxTokenClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
ROBERTA_START_DOCSTRING,
)
class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel):
module_class = FlaxRobertaForTokenClassificationModule
append_call_sample_docstring(
FlaxRobertaForTokenClassification,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxTokenClassifierOutput,
_CONFIG_FOR_DOC,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->Roberta, with self.bert->self.roberta
class FlaxRobertaForQuestionAnsweringModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.roberta(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if not return_dict:
return (start_logits, end_logits) + outputs[1:]
return FlaxQuestionAnsweringModelOutput(
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
ROBERTA_START_DOCSTRING,
)
class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel):
module_class = FlaxRobertaForQuestionAnsweringModule
append_call_sample_docstring(
FlaxRobertaForQuestionAnswering,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
)

View File

@ -180,6 +180,51 @@ class FlaxBertPreTrainedModel:
requires_backends(self, ["flax"])
class FlaxRobertaForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRobertaForMultipleChoice:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRobertaForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRobertaForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRobertaForTokenClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRobertaModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@ -187,3 +232,12 @@ class FlaxRobertaModel:
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRobertaPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])

View File

@ -150,7 +150,7 @@ class FlaxModelTesterMixin:
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
@ -161,7 +161,7 @@ class FlaxModelTesterMixin:
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
@ -191,7 +191,7 @@ class FlaxModelTesterMixin:
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
@ -204,7 +204,7 @@ class FlaxModelTesterMixin:
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -219,6 +219,7 @@ class FlaxModelTesterMixin:
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**prepared_inputs_dict).to_tuple()
# verify that normal save_pretrained works as expected
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_loaded = model_class.from_pretrained(tmpdirname)
@ -227,6 +228,16 @@ class FlaxModelTesterMixin:
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 1e-3)
# verify that save_pretrained for distributed training
# with `params=params` works as expected
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, params=model.params)
model_loaded = model_class.from_pretrained(tmpdirname)
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 1e-3)
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -23,7 +23,14 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
if is_flax_available():
from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModel
from transformers.models.roberta.modeling_flax_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
)
class FlaxRobertaModelTester(unittest.TestCase):
@ -48,6 +55,7 @@ class FlaxRobertaModelTester(unittest.TestCase):
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_choices=4,
):
self.parent = parent
self.batch_size = batch_size
@ -68,6 +76,7 @@ class FlaxRobertaModelTester(unittest.TestCase):
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_choices = num_choices
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@ -107,7 +116,18 @@ class FlaxRobertaModelTester(unittest.TestCase):
@require_flax
class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxRobertaModel,) if is_flax_available() else ()
all_model_classes = (
(
FlaxRobertaModel,
FlaxRobertaForMaskedLM,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
)
if is_flax_available()
else ()
)
def setUp(self):
self.model_tester = FlaxRobertaModelTester(self)