[Flax] Big FlaxBert Refactor (#11364)

* improve flax

* refactor

* typos

* Update src/transformers/modeling_flax_utils.py

* Apply suggestions from code review

* Update src/transformers/modeling_flax_utils.py

* fix typo

* improve error tolerance

* typo

* correct nasty saving bug

* fix from pretrained

* correct tree map

* add note

* correct weight tying
Patrick von Platen 2021-04-23 09:53:09 +02:00 committed by GitHub
parent 3ed5e97ba0
commit 8c9b5fcbaf
6 changed files with 306 additions and 197 deletions

View File

@ -12,12 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - TF 2.0 general utilities."""
""" PyTorch - Flax general utilities."""
import os
from pickle import UnpicklingError
from flax.core.frozen_dict import unfreeze
import numpy as np
import jax.numpy as jnp
import transformers
from flax.serialization import from_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from .utils import logging
@ -37,7 +42,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa
import torch # noqa: F401
except ImportError:
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see "
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see "
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
)
raise
@ -57,7 +62,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# convert pytorch tensor to numpy
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
random_flax_state_dict = flatten_dict(unfreeze(flax_model.params))
random_flax_state_dict = flatten_dict(flax_model.params)
flax_state_dict = {}
remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
@ -80,7 +85,12 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
elif add_base_model_prefix and require_base_model_prefix:
pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
if pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
# Correctly rename weight parameters
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
pt_tensor = pt_tensor.T
elif pt_tuple_key[-1] == "gamma":
@ -89,12 +99,128 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
if pt_tuple_key in random_flax_state_dict:
if random_flax_state_dict[pt_tuple_key].shape != pt_tensor.shape:
if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
raise ValueError(
"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape {random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
)
# add unexpected weight so that warning is thrown
flax_state_dict[pt_tuple_key] = pt_tensor
# also add unexpected weight so that warning is thrown
flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
return unflatten_dict(flax_state_dict)
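For illustration only (not part of the diff), the converter above can be exercised the same way the new cross-equivalence test does, assuming torch, jax and flax are all installed:

from transformers import BertConfig, BertModel, FlaxBertModel
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

# small random config so the sketch runs quickly; the values are arbitrary
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
pt_model = BertModel(config).eval()
fx_model = FlaxBertModel(config)
# "weight" entries are renamed to "kernel"/"scale"/"embedding" (and transposed where needed) above
fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)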
#####################
# Flax => PyTorch #
#####################
def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
"""Load flax checkpoints in a PyTorch model"""
flax_checkpoint_path = os.path.abspath(flax_checkpoint_path)
logger.info(f"Loading Flax weights from {flax_checkpoint_path}")
# import correct flax class
flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)
# load flax weight dict
with open(flax_checkpoint_path, "rb") as state_f:
try:
flax_state_dict = from_bytes(flax_cls, state_f.read())
except UnpicklingError:
raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")
return load_flax_weights_in_pytorch_model(model, flax_state_dict)
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
"""Load flax checkpoints in a PyTorch model"""
try:
import torch # noqa: F401
except ImportError:
logger.error(
"Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see "
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
)
raise
flax_state_dict = flatten_dict(flax_state)
pt_model_dict = pt_model.state_dict()
remove_base_model_prefix = (pt_model.base_model_prefix in flax_state) and (
pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()])
)
add_base_model_prefix = (pt_model.base_model_prefix not in flax_state) and (
pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()])
)
# keep track of unexpected & missing keys
unexpected_keys = []
missing_keys = set(pt_model_dict.keys())
for flax_key_tuple, flax_tensor in flax_state_dict.items():
has_base_model_prefix = flax_key_tuple[0] == pt_model.base_model_prefix
require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict
# adapt flax_key to prepare for loading from/to base model only
if remove_base_model_prefix and has_base_model_prefix:
flax_key_tuple = flax_key_tuple[1:]
elif add_base_model_prefix and require_base_model_prefix:
flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple
# rename flax weights to PyTorch format
if flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict:
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
flax_tensor = flax_tensor.T
elif flax_key_tuple[-1] in ["scale", "embedding"]:
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
flax_key = ".".join(flax_key_tuple)
if flax_key in pt_model_dict:
if flax_tensor.shape != pt_model_dict[flax_key].shape:
raise ValueError(
f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected"
f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
)
else:
# add weight to pytorch dict
flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
# remove from missing keys
missing_keys.remove(flax_key)
else:
# weight is not expected by PyTorch model
unexpected_keys.append(flax_key)
pt_model.load_state_dict(pt_model_dict)
# re-transform missing_keys to list
missing_keys = list(missing_keys)
if len(unexpected_keys) > 0:
logger.warning(
"Some weights of the Flax model were not used when "
f"initializing the PyTorch model {pt_model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {pt_model.__class__.__name__} from a Flax model trained on another task "
"or with another architecture (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n"
f"- This IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect "
"to be exactly identical (e.g. initializing a BertForSequenceClassification model from a FlaxBertForSequenceClassification model)."
)
else:
logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model "
f"and are newly initialized: {missing_keys}\n"
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
logger.warning(
f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n"
"If your task is similar to the task the model of the checkpoint was trained on, "
f"you can already use {pt_model.__class__.__name__} for predictions without further training."
)
return pt_model
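The reverse direction mirrors the snippet above (again an illustrative sketch, not part of the diff); tie_weights() re-ties shared embeddings after loading, as the new cross tests do:

from transformers import BertConfig, BertModel, FlaxBertModel
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
fx_model = FlaxBertModel(config)
pt_model = BertModel(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
pt_model.tie_weights()  # re-tie word embeddings and LM head in the PyTorch model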

View File

@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -22,7 +22,7 @@ from typing import Dict, Set, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
@ -46,7 +46,7 @@ logger = logging.get_logger(__name__)
ACT2FN = {
"gelu": nn.gelu,
"gelu": partial(nn.gelu, approximate=False),
"relu": nn.relu,
"silu": nn.swish,
"swish": nn.swish,
@ -129,7 +129,7 @@ class FlaxPreTrainedModel(ABC):
"Some parameters are missing. Make sure that `params` include the following "
f"parameters {self.required_params - param_keys}"
)
self._params = freeze(params)
self._params = params
@classmethod
def from_pretrained(
@ -330,6 +330,10 @@ class FlaxPreTrainedModel(ABC):
state = from_bytes(cls, state_f.read())
except UnpicklingError:
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
# make sure all arrays are stored as jnp.arrays
# NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261
state = jax.tree_util.tree_map(jnp.array, state)
# if model is base model only use model_prefix key
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
@ -337,6 +341,7 @@ class FlaxPreTrainedModel(ABC):
# flatten dicts
state = flatten_dict(state)
random_state = flatten_dict(unfreeze(model.params))
missing_keys = model.required_params - set(state.keys())
@ -377,6 +382,7 @@ class FlaxPreTrainedModel(ABC):
# set correct parameters
model.params = unflatten_dict(state)
return model
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
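A quick round trip shows the (de)serialization path this hunk touches (illustrative sketch only; the reload goes through from_bytes and the jnp.array tree_map workaround noted above):

import tempfile
from transformers import BertConfig, FlaxBertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = FlaxBertModel(config)
with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)                     # writes flax_model.msgpack via to_bytes
    model_reloaded = FlaxBertModel.from_pretrained(tmpdirname)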

View File

@ -30,6 +30,7 @@ from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .file_utils import (
DUMMY_INPUTS,
FLAX_WEIGHTS_NAME,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
WEIGHTS_NAME,
@ -875,6 +876,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- A path or url to a model folder containing a `flax checkpoint file` in `.msgpack` format (e.g.,
``./flax_model/`` containing ``flax_model.msgpack``). In this case, ``from_flax`` should be set
to :obj:`True`.
- :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
arguments ``config`` and ``state_dict``).
model_args (sequence of positional arguments, `optional`):
@ -907,6 +911,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a TensorFlow checkpoint save file (see docstring of
``pretrained_model_name_or_path`` argument).
from_flax (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a Flax checkpoint save file (see docstring of
``pretrained_model_name_or_path`` argument).
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@ -968,11 +975,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
>>> config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
>>> model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
>>> model = BertModel.from_pretrained('bert-base-uncased', from_flax=True)
"""
config = kwargs.pop("config", None)
state_dict = kwargs.pop("state_dict", None)
cache_dir = kwargs.pop("cache_dir", None)
from_tf = kwargs.pop("from_tf", False)
from_flax = kwargs.pop("from_flax", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
@ -1023,13 +1034,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
# Load from a TF 2.0 checkpoint in priority if from_tf
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
elif from_flax and os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint in priority if from_flax
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
raise EnvironmentError(
f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index']} found in "
f"directory {pretrained_model_name_or_path} or `from_tf` set to False."
f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index', FLAX_WEIGHTS_NAME]} found in "
f"directory {pretrained_model_name_or_path} or `from_tf` and `from_flax` set to False."
)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
@ -1041,9 +1055,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
)
archive_file = pretrained_model_name_or_path + ".index"
else:
# set correct filename
if from_tf:
filename = TF2_WEIGHTS_NAME
elif from_flax:
filename = FLAX_WEIGHTS_NAME
else:
filename = WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
filename=filename,
revision=revision,
mirror=mirror,
)
@ -1090,7 +1112,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
else:
model = cls(config, *model_args, **model_kwargs)
if state_dict is None and not from_tf:
if state_dict is None and not (from_tf or from_flax):
try:
state_dict = torch.load(resolved_archive_file, map_location="cpu")
except Exception:
@ -1120,6 +1142,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise
elif from_flax:
try:
from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model
model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
except ImportError:
logger.error(
"Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see "
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
)
raise
else:
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []

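Sketch of the new from_flax branch (illustrative only, mirroring the docstring example and the new cross test; requires both torch and flax):

import tempfile
from transformers import BertConfig, BertModel, FlaxBertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
with tempfile.TemporaryDirectory() as tmpdirname:
    FlaxBertModel(config).save_pretrained(tmpdirname)     # directory now contains flax_model.msgpack
    pt_model = BertModel.from_pretrained(tmpdirname, from_flax=True)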
View File

@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -95,68 +95,6 @@ BERT_INPUTS_DOCSTRING = r"""
"""
class FlaxBertLayerNorm(nn.Module):
"""
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
"""
hidden_size: int
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32
use_bias: bool = True
scale: bool = True
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
self.bias = self.param("bias", self.scale_init, (self.hidden_size,))
def __call__(self, x):
"""
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that
maintains the mean activation within each example close to 0 and the activation standard deviation close to 1
Args:
x: the inputs
Returns:
Normalized inputs (the same shape as inputs).
"""
mean = jnp.mean(x, axis=-1, keepdims=True)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
var = mean2 - jax.lax.square(mean)
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.weight)
y = (x - mean) * mul
if self.use_bias:
y = y + jnp.asarray(self.bias)
return y
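For reference, the hand-rolled normalization removed here matches flax.linen.LayerNorm, which replaces it below (a quick illustrative check, not part of the diff):

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8))
ln = nn.LayerNorm(epsilon=1e-12)
variables = ln.init(jax.random.PRNGKey(1), x)             # scale -> ones, bias -> zeros
manual = (x - x.mean(-1, keepdims=True)) * jax.lax.rsqrt(x.var(-1, keepdims=True) + 1e-12)
assert jnp.allclose(ln.apply(variables, x), manual, atol=1e-4)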
class FlaxBertEmbedding(nn.Module):
"""
Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch
use 'weight'
"""
vocab_size: int
hidden_size: int
initializer_range: float
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))
def __call__(self, input_ids):
return jnp.take(self.embeddings, input_ids, axis=0)
class FlaxBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
@ -164,35 +102,37 @@ class FlaxBertEmbeddings(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.word_embeddings = FlaxBertEmbedding(
self.word_embeddings = nn.Embed(
self.config.vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = FlaxBertEmbedding(
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = FlaxBertEmbedding(
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
batch_size, sequence_length = input_ids.shape
# Embed
inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
position_embeds = self.position_embeddings(position_ids.astype("i4"))
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
# Sum all embeddings
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
hidden_states = inputs_embeds + token_type_embeddings + position_embeds
# hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))
# Layer Norm
hidden_states = self.LayerNorm(hidden_states)
@ -281,7 +221,7 @@ class FlaxBertSelfOutput(nn.Module):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
@ -337,7 +277,7 @@ class FlaxBertOutput(nn.Module):
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
@ -372,7 +312,7 @@ class FlaxBertLayerCollection(nn.Module):
]
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
for layer in self.layers:
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
return hidden_states
@ -412,7 +352,7 @@ class FlaxBertPredictionHeadTransform(nn.Module):
def setup(self):
self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
self.activation = ACT2FN[self.config.hidden_act]
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
@ -423,14 +363,22 @@ class FlaxBertPredictionHeadTransform(nn.Module):
class FlaxBertLMPredictionHead(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype)
self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
def __call__(self, hidden_states):
def __call__(self, hidden_states, shared_embedding=None):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
if shared_embedding is not None:
hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
hidden_states = self.decoder(hidden_states)
hidden_states += self.bias
return hidden_states
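The tying trick above, seen in isolation (a flax.linen-only sketch with arbitrary BERT-base-like shapes): a Dense layer can be applied with an externally supplied kernel, so the LM head reuses the transposed word-embedding matrix instead of keeping its own vocab-sized projection.

import jax.numpy as jnp
import flax.linen as nn

decoder = nn.Dense(features=30522, use_bias=False)
shared_embedding = jnp.zeros((30522, 768))                # stand-in for the word-embedding matrix
hidden_states = jnp.zeros((1, 5, 768))
logits = decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
assert logits.shape == (1, 5, 30522)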
@ -441,8 +389,8 @@ class FlaxBertOnlyMLMHead(nn.Module):
def setup(self):
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = self.predictions(hidden_states)
def __call__(self, hidden_states, shared_embedding=None):
hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding)
return hidden_states
@ -464,8 +412,8 @@ class FlaxBertPreTrainingHeads(nn.Module):
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
self.seq_relationship = nn.Dense(2, dtype=self.dtype)
def __call__(self, hidden_states, pooled_output):
prediction_scores = self.predictions(hidden_states)
def __call__(self, hidden_states, pooled_output, shared_embedding=None):
prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
@ -490,7 +438,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.ones_like(input_ids)
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
attention_mask = jnp.ones_like(input_ids)
params_rng, dropout_rng = jax.random.split(rng)
@ -514,7 +462,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
@ -546,7 +494,6 @@ class FlaxBertModule(nn.Module):
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)
@ -582,7 +529,15 @@ class FlaxBertForPreTrainingModule(nn.Module):
hidden_states, pooled_output = self.bert(
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
)
prediction_scores, seq_relationship_score = self.cls(hidden_states, pooled_output)
if self.config.tie_word_embeddings:
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
shared_embedding = None
prediction_scores, seq_relationship_score = self.cls(
hidden_states, pooled_output, shared_embedding=shared_embedding
)
return (prediction_scores, seq_relationship_score)
@ -612,8 +567,13 @@ class FlaxBertForMaskedLMModule(nn.Module):
# Model
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
if self.config.tie_word_embeddings:
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
shared_embedding = None
# Compute the prediction scores
logits = self.cls(hidden_states)
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
return (logits,)
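End to end, the refactored masked-LM head can be exercised like this (illustrative; a tiny random config, so the logits are meaningless, but the shared-embedding path is taken because tie_word_embeddings defaults to True):

import jax.numpy as jnp
from transformers import BertConfig, FlaxBertForMaskedLM

config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = FlaxBertForMaskedLM(config)
input_ids = jnp.ones((1, 8), dtype="i4")
logits = model(input_ids)[0]
assert logits.shape == (1, 8, config.vocab_size)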

View File

@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,9 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Tuple
import numpy as np
from typing import Tuple
import flax.linen as nn
import jax
@ -110,70 +108,6 @@ ROBERTA_INPUTS_DOCSTRING = r"""
"""
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta
class FlaxRobertaLayerNorm(nn.Module):
"""
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
"""
hidden_size: int
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32
use_bias: bool = True
scale: bool = True
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
self.bias = self.param("bias", self.scale_init, (self.hidden_size,))
def __call__(self, x):
"""
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that
maintains the mean activation within each example close to 0 and the activation standard deviation close to 1
Args:
x: the inputs
Returns:
Normalized inputs (the same shape as inputs).
"""
mean = jnp.mean(x, axis=-1, keepdims=True)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
var = mean2 - jax.lax.square(mean)
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.weight)
y = (x - mean) * mul
if self.use_bias:
y = y + jnp.asarray(self.bias)
return y
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbedding with Bert->Roberta
class FlaxRobertaEmbedding(nn.Module):
"""
Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch
use 'weight'
"""
vocab_size: int
hidden_size: int
initializer_range: float
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))
def __call__(self, input_ids):
return jnp.take(self.embeddings, input_ids, axis=0)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
class FlaxRobertaEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
@ -182,35 +116,37 @@ class FlaxRobertaEmbeddings(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.word_embeddings = FlaxRobertaEmbedding(
self.word_embeddings = nn.Embed(
self.config.vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = FlaxRobertaEmbedding(
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = FlaxRobertaEmbedding(
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
batch_size, sequence_length = input_ids.shape
# Embed
inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
position_embeds = self.position_embeddings(position_ids.astype("i4"))
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
# Sum all embeddings
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
hidden_states = inputs_embeds + token_type_embeddings + position_embeds
# hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))
# Layer Norm
hidden_states = self.LayerNorm(hidden_states)
@ -301,7 +237,7 @@ class FlaxRobertaSelfOutput(nn.Module):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
@ -360,7 +296,7 @@ class FlaxRobertaOutput(nn.Module):
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
@ -397,7 +333,7 @@ class FlaxRobertaLayerCollection(nn.Module):
]
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
for layer in self.layers:
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
return hidden_states
@ -515,7 +451,6 @@ class FlaxRobertaModule(nn.Module):
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)

View File

@ -28,7 +28,10 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
@ -83,29 +86,32 @@ class FlaxModelTesterMixin:
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@is_pt_flax_cross_test
def test_equivalence_flax_pytorch(self):
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict)
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
@ -116,7 +122,50 @@ class FlaxModelTesterMixin:
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict)
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
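These cross tests are opt-in; one hedged way to trigger them locally (the RUN_PT_FLAX_CROSS_TESTS flag name is assumed from transformers.testing_utils, and the test path assumed to be the BERT flax test file):

import os
import subprocess

os.environ["RUN_PT_FLAX_CROSS_TESTS"] = "1"               # flag name assumed, see testing_utils
subprocess.run(["python", "-m", "pytest", "tests/test_modeling_flax_bert.py", "-k", "equivalence"], check=False)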
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -134,7 +183,7 @@ class FlaxModelTesterMixin:
outputs_loaded = model_loaded(**prepared_inputs_dict)
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 5e-3)
self.assert_almost_equals(output_loaded, output, 1e-3)
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()