mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Flax] Big FlaxBert Refactor (#11364)
* improve flax * refactor * typos * Update src/transformers/modeling_flax_utils.py * Apply suggestions from code review * Update src/transformers/modeling_flax_utils.py * fix typo * improve error tolerance * typo * correct nasty saving bug * fix from pretrained * correct tree map * add note * correct weight tying
This commit is contained in:
parent
3ed5e97ba0
commit
8c9b5fcbaf
@ -12,12 +12,17 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch - TF 2.0 general utilities."""
|
||||
""" PyTorch - Flax general utilities."""
|
||||
|
||||
|
||||
import os
|
||||
from pickle import UnpicklingError
|
||||
|
||||
from flax.core.frozen_dict import unfreeze
|
||||
import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
import transformers
|
||||
from flax.serialization import from_bytes
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
|
||||
from .utils import logging
|
||||
@ -37,7 +42,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa
|
||||
import torch # noqa: F401
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see "
|
||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see "
|
||||
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
|
||||
)
|
||||
raise
|
||||
@ -57,7 +62,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||
# convert pytorch tensor to numpy
|
||||
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||
|
||||
random_flax_state_dict = flatten_dict(unfreeze(flax_model.params))
|
||||
random_flax_state_dict = flatten_dict(flax_model.params)
|
||||
flax_state_dict = {}
|
||||
|
||||
remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
|
||||
@ -80,7 +85,12 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||
elif add_base_model_prefix and require_base_model_prefix:
|
||||
pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
|
||||
|
||||
if pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
|
||||
# Correctly rename weight parameters
|
||||
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
|
||||
pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
||||
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
|
||||
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
|
||||
elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
|
||||
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
||||
pt_tensor = pt_tensor.T
|
||||
elif pt_tuple_key[-1] == "gamma":
|
||||
@ -89,12 +99,128 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||
pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
||||
|
||||
if pt_tuple_key in random_flax_state_dict:
|
||||
if random_flax_state_dict[pt_tuple_key].shape != pt_tensor.shape:
|
||||
if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
|
||||
raise ValueError(
|
||||
"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape {random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
|
||||
)
|
||||
|
||||
# add unexpected weight so that warning is thrown
|
||||
flax_state_dict[pt_tuple_key] = pt_tensor
|
||||
# also add unexpected weight so that warning is thrown
|
||||
flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
|
||||
|
||||
return unflatten_dict(flax_state_dict)
|
||||
|
||||
|
||||
#####################
|
||||
# Flax => PyTorch #
|
||||
#####################
|
||||
|
||||
|
||||
def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
|
||||
"""Load flax checkpoints in a PyTorch model"""
|
||||
flax_checkpoint_path = os.path.abspath(flax_checkpoint_path)
|
||||
logger.info(f"Loading Flax weights from {flax_checkpoint_path}")
|
||||
|
||||
# import correct flax class
|
||||
flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)
|
||||
|
||||
# load flax weight dict
|
||||
with open(flax_checkpoint_path, "rb") as state_f:
|
||||
try:
|
||||
flax_state_dict = from_bytes(flax_cls, state_f.read())
|
||||
except UnpicklingError:
|
||||
raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")
|
||||
|
||||
return load_flax_weights_in_pytorch_model(model, flax_state_dict)
|
||||
|
||||
|
||||
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
||||
"""Load flax checkpoints in a PyTorch model"""
|
||||
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see "
|
||||
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
|
||||
)
|
||||
raise
|
||||
|
||||
flax_state_dict = flatten_dict(flax_state)
|
||||
pt_model_dict = pt_model.state_dict()
|
||||
|
||||
remove_base_model_prefix = (pt_model.base_model_prefix in flax_state) and (
|
||||
pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()])
|
||||
)
|
||||
add_base_model_prefix = (pt_model.base_model_prefix not in flax_state) and (
|
||||
pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()])
|
||||
)
|
||||
|
||||
# keep track of unexpected & missing keys
|
||||
unexpected_keys = []
|
||||
missing_keys = set(pt_model_dict.keys())
|
||||
|
||||
for flax_key_tuple, flax_tensor in flax_state_dict.items():
|
||||
has_base_model_prefix = flax_key_tuple[0] == pt_model.base_model_prefix
|
||||
require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict
|
||||
|
||||
# adapt flax_key to prepare for loading from/to base model only
|
||||
if remove_base_model_prefix and has_base_model_prefix:
|
||||
flax_key_tuple = flax_key_tuple[1:]
|
||||
elif add_base_model_prefix and require_base_model_prefix:
|
||||
flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple
|
||||
|
||||
# rename flax weights to PyTorch format
|
||||
if flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict:
|
||||
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
|
||||
flax_tensor = flax_tensor.T
|
||||
elif flax_key_tuple[-1] in ["scale", "embedding"]:
|
||||
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
|
||||
|
||||
flax_key = ".".join(flax_key_tuple)
|
||||
|
||||
if flax_key in pt_model_dict:
|
||||
if flax_tensor.shape != pt_model_dict[flax_key].shape:
|
||||
raise ValueError(
|
||||
f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected"
|
||||
f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
||||
)
|
||||
else:
|
||||
# add weight to pytorch dict
|
||||
flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
|
||||
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
|
||||
# remove from missing keys
|
||||
missing_keys.remove(flax_key)
|
||||
else:
|
||||
# weight is not expected by PyTorch model
|
||||
unexpected_keys.append(flax_key)
|
||||
|
||||
pt_model.load_state_dict(pt_model_dict)
|
||||
|
||||
# re-transform missing_keys to list
|
||||
missing_keys = list(missing_keys)
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
"Some weights of the Flax model were not used when "
|
||||
f"initializing the PyTorch model {pt_model.__class__.__name__}: {unexpected_keys}\n"
|
||||
f"- This IS expected if you are initializing {pt_model.__class__.__name__} from a Flax model trained on another task "
|
||||
"or with another architecture (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n"
|
||||
f"- This IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect "
|
||||
"to be exactly identical (e.g. initializing a BertForSequenceClassification model from a FlaxBertForSequenceClassification model)."
|
||||
)
|
||||
else:
|
||||
logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n")
|
||||
if len(missing_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model "
|
||||
f"and are newly initialized: {missing_keys}\n"
|
||||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n"
|
||||
"If your task is similar to the task the model of the checkpoint was trained on, "
|
||||
f"you can already use {pt_model.__class__.__name__} for predictions without further training."
|
||||
)
|
||||
|
||||
return pt_model
|
||||
|
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -22,7 +22,7 @@ from typing import Dict, Set, Tuple, Union
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.serialization import from_bytes, to_bytes
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax.random import PRNGKey
|
||||
@ -46,7 +46,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
ACT2FN = {
|
||||
"gelu": nn.gelu,
|
||||
"gelu": partial(nn.gelu, approximate=False),
|
||||
"relu": nn.relu,
|
||||
"silu": nn.swish,
|
||||
"swish": nn.swish,
|
||||
@ -129,7 +129,7 @@ class FlaxPreTrainedModel(ABC):
|
||||
"Some parameters are missing. Make sure that `params` include the following "
|
||||
f"parameters {self.required_params - param_keys}"
|
||||
)
|
||||
self._params = freeze(params)
|
||||
self._params = params
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
@ -330,6 +330,10 @@ class FlaxPreTrainedModel(ABC):
|
||||
state = from_bytes(cls, state_f.read())
|
||||
except UnpicklingError:
|
||||
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
|
||||
# make sure all arrays are stored as jnp.arrays
|
||||
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
|
||||
# https://github.com/google/flax/issues/1261
|
||||
state = jax.tree_util.tree_map(jnp.array, state)
|
||||
|
||||
# if model is base model only use model_prefix key
|
||||
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
|
||||
@ -337,6 +341,7 @@ class FlaxPreTrainedModel(ABC):
|
||||
|
||||
# flatten dicts
|
||||
state = flatten_dict(state)
|
||||
|
||||
random_state = flatten_dict(unfreeze(model.params))
|
||||
|
||||
missing_keys = model.required_params - set(state.keys())
|
||||
@ -377,6 +382,7 @@ class FlaxPreTrainedModel(ABC):
|
||||
|
||||
# set correct parameters
|
||||
model.params = unflatten_dict(state)
|
||||
|
||||
return model
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
|
@ -30,6 +30,7 @@ from .activations import get_activation
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import (
|
||||
DUMMY_INPUTS,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
TF_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
@ -875,6 +876,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
|
||||
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
- A path or url to a model folder containing a `flax checkpoint file` in `.msgpack` format (e.g,
|
||||
``./flax_model/`` containing ``flax_model.msgpack``). In this case, ``from_flax`` should be set
|
||||
to :obj:`True`.
|
||||
- :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
|
||||
arguments ``config`` and ``state_dict``).
|
||||
model_args (sequence of positional arguments, `optional`):
|
||||
@ -907,6 +911,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Load the model weights from a TensorFlow checkpoint save file (see docstring of
|
||||
``pretrained_model_name_or_path`` argument).
|
||||
from_flax (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Load the model weights from a Flax checkpoint save file (see docstring of
|
||||
``pretrained_model_name_or_path`` argument).
|
||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
@ -968,11 +975,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
|
||||
>>> config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
|
||||
>>> model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
|
||||
>>> model = BertModel.from_pretrained('bert-base-uncased', from_flax=True)
|
||||
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
state_dict = kwargs.pop("state_dict", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
from_tf = kwargs.pop("from_tf", False)
|
||||
from_flax = kwargs.pop("from_flax", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@ -1023,13 +1034,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||
# Load from a TF 2.0 checkpoint in priority if from_tf
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||
elif from_flax and os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
||||
# Load from a Flax checkpoint in priority if from_flax
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index']} found in "
|
||||
f"directory {pretrained_model_name_or_path} or `from_tf` set to False."
|
||||
f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index', FLAX_WEIGHTS_NAME]} found in "
|
||||
f"directory {pretrained_model_name_or_path} or `from_tf` and `from_flax` set to False."
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
@ -1041,9 +1055,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
)
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
# set correct filename
|
||||
if from_tf:
|
||||
filename = TF2_WEIGHTS_NAME
|
||||
elif from_flax:
|
||||
filename = FLAX_WEIGHTS_NAME
|
||||
else:
|
||||
filename = WEIGHTS_NAME
|
||||
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
|
||||
filename=filename,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
)
|
||||
@ -1090,7 +1112,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
else:
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
if state_dict is None and not from_tf:
|
||||
if state_dict is None and not (from_tf or from_flax):
|
||||
try:
|
||||
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
||||
except Exception:
|
||||
@ -1120,6 +1142,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise
|
||||
elif from_flax:
|
||||
try:
|
||||
from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model
|
||||
|
||||
model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see "
|
||||
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
|
||||
)
|
||||
raise
|
||||
else:
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
old_keys = []
|
||||
|
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -95,68 +95,6 @@ BERT_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
class FlaxBertLayerNorm(nn.Module):
|
||||
"""
|
||||
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
|
||||
"""
|
||||
|
||||
hidden_size: int
|
||||
epsilon: float = 1e-6
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
use_bias: bool = True
|
||||
scale: bool = True
|
||||
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||
|
||||
def setup(self):
|
||||
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
|
||||
self.bias = self.param("bias", self.scale_init, (self.hidden_size,))
|
||||
|
||||
def __call__(self, x):
|
||||
"""
|
||||
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
|
||||
a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that
|
||||
maintains the mean activation within each example close to 0 and the activation standard deviation close to 1
|
||||
|
||||
Args:
|
||||
x: the inputs
|
||||
|
||||
Returns:
|
||||
Normalized inputs (the same shape as inputs).
|
||||
"""
|
||||
mean = jnp.mean(x, axis=-1, keepdims=True)
|
||||
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
||||
var = mean2 - jax.lax.square(mean)
|
||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||
|
||||
if self.scale:
|
||||
mul = mul * jnp.asarray(self.weight)
|
||||
y = (x - mean) * mul
|
||||
|
||||
if self.use_bias:
|
||||
y = y + jnp.asarray(self.bias)
|
||||
return y
|
||||
|
||||
|
||||
class FlaxBertEmbedding(nn.Module):
|
||||
"""
|
||||
Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch
|
||||
use 'weight'
|
||||
"""
|
||||
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
initializer_range: float
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
|
||||
self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))
|
||||
|
||||
def __call__(self, input_ids):
|
||||
return jnp.take(self.embeddings, input_ids, axis=0)
|
||||
|
||||
|
||||
class FlaxBertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
@ -164,35 +102,37 @@ class FlaxBertEmbeddings(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.word_embeddings = FlaxBertEmbedding(
|
||||
self.word_embeddings = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
initializer_range=self.config.initializer_range,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.position_embeddings = FlaxBertEmbedding(
|
||||
self.position_embeddings = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.config.hidden_size,
|
||||
initializer_range=self.config.initializer_range,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.token_type_embeddings = FlaxBertEmbedding(
|
||||
self.token_type_embeddings = nn.Embed(
|
||||
self.config.type_vocab_size,
|
||||
self.config.hidden_size,
|
||||
initializer_range=self.config.initializer_range,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
# Embed
|
||||
inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
|
||||
position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
|
||||
token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))
|
||||
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
|
||||
position_embeds = self.position_embeddings(position_ids.astype("i4"))
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
|
||||
|
||||
# Sum all embeddings
|
||||
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
|
||||
hidden_states = inputs_embeds + token_type_embeddings + position_embeds
|
||||
# hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))
|
||||
|
||||
# Layer Norm
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
@ -281,7 +221,7 @@ class FlaxBertSelfOutput(nn.Module):
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
|
||||
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
||||
@ -337,7 +277,7 @@ class FlaxBertOutput(nn.Module):
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
@ -372,7 +312,7 @@ class FlaxBertLayerCollection(nn.Module):
|
||||
]
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
for layer in self.layers:
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||
return hidden_states
|
||||
|
||||
@ -412,7 +352,7 @@ class FlaxBertPredictionHeadTransform(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
@ -423,14 +363,22 @@ class FlaxBertPredictionHeadTransform(nn.Module):
|
||||
class FlaxBertLMPredictionHead(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||
|
||||
def setup(self):
|
||||
self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
|
||||
self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype)
|
||||
self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
|
||||
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
def __call__(self, hidden_states, shared_embedding=None):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
|
||||
if shared_embedding is not None:
|
||||
hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
|
||||
else:
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
|
||||
hidden_states += self.bias
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -441,8 +389,8 @@ class FlaxBertOnlyMLMHead(nn.Module):
|
||||
def setup(self):
|
||||
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
hidden_states = self.predictions(hidden_states)
|
||||
def __call__(self, hidden_states, shared_embedding=None):
|
||||
hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -464,8 +412,8 @@ class FlaxBertPreTrainingHeads(nn.Module):
|
||||
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
|
||||
self.seq_relationship = nn.Dense(2, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, pooled_output):
|
||||
prediction_scores = self.predictions(hidden_states)
|
||||
def __call__(self, hidden_states, pooled_output, shared_embedding=None):
|
||||
prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
|
||||
seq_relationship_score = self.seq_relationship(pooled_output)
|
||||
return prediction_scores, seq_relationship_score
|
||||
|
||||
@ -490,7 +438,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
@ -514,7 +462,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
@ -546,7 +494,6 @@ class FlaxBertModule(nn.Module):
|
||||
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
||||
)
|
||||
@ -582,7 +529,15 @@ class FlaxBertForPreTrainingModule(nn.Module):
|
||||
hidden_states, pooled_output = self.bert(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
||||
)
|
||||
prediction_scores, seq_relationship_score = self.cls(hidden_states, pooled_output)
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||
else:
|
||||
shared_embedding = None
|
||||
|
||||
prediction_scores, seq_relationship_score = self.cls(
|
||||
hidden_states, pooled_output, shared_embedding=shared_embedding
|
||||
)
|
||||
|
||||
return (prediction_scores, seq_relationship_score)
|
||||
|
||||
@ -612,8 +567,13 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
||||
# Model
|
||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||
else:
|
||||
shared_embedding = None
|
||||
|
||||
# Compute the prediction scores
|
||||
logits = self.cls(hidden_states)
|
||||
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
|
||||
|
||||
return (logits,)
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,9 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
@ -110,70 +108,6 @@ ROBERTA_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta
|
||||
class FlaxRobertaLayerNorm(nn.Module):
|
||||
"""
|
||||
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
|
||||
"""
|
||||
|
||||
hidden_size: int
|
||||
epsilon: float = 1e-6
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
use_bias: bool = True
|
||||
scale: bool = True
|
||||
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||
|
||||
def setup(self):
|
||||
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
|
||||
self.bias = self.param("bias", self.scale_init, (self.hidden_size,))
|
||||
|
||||
def __call__(self, x):
|
||||
"""
|
||||
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
|
||||
a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that
|
||||
maintains the mean activation within each example close to 0 and the activation standard deviation close to 1
|
||||
|
||||
Args:
|
||||
x: the inputs
|
||||
|
||||
Returns:
|
||||
Normalized inputs (the same shape as inputs).
|
||||
"""
|
||||
mean = jnp.mean(x, axis=-1, keepdims=True)
|
||||
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
||||
var = mean2 - jax.lax.square(mean)
|
||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||
|
||||
if self.scale:
|
||||
mul = mul * jnp.asarray(self.weight)
|
||||
y = (x - mean) * mul
|
||||
|
||||
if self.use_bias:
|
||||
y = y + jnp.asarray(self.bias)
|
||||
return y
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbedding with Bert->Roberta
|
||||
class FlaxRobertaEmbedding(nn.Module):
|
||||
"""
|
||||
Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch
|
||||
use 'weight'
|
||||
"""
|
||||
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
initializer_range: float
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
|
||||
self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))
|
||||
|
||||
def __call__(self, input_ids):
|
||||
return jnp.take(self.embeddings, input_ids, axis=0)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
|
||||
class FlaxRobertaEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
@ -182,35 +116,37 @@ class FlaxRobertaEmbeddings(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.word_embeddings = FlaxRobertaEmbedding(
|
||||
self.word_embeddings = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
initializer_range=self.config.initializer_range,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.position_embeddings = FlaxRobertaEmbedding(
|
||||
self.position_embeddings = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.config.hidden_size,
|
||||
initializer_range=self.config.initializer_range,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.token_type_embeddings = FlaxRobertaEmbedding(
|
||||
self.token_type_embeddings = nn.Embed(
|
||||
self.config.type_vocab_size,
|
||||
self.config.hidden_size,
|
||||
initializer_range=self.config.initializer_range,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
# Embed
|
||||
inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
|
||||
position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
|
||||
token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))
|
||||
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
|
||||
position_embeds = self.position_embeddings(position_ids.astype("i4"))
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
|
||||
|
||||
# Sum all embeddings
|
||||
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
|
||||
hidden_states = inputs_embeds + token_type_embeddings + position_embeds
|
||||
# hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))
|
||||
|
||||
# Layer Norm
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
@ -301,7 +237,7 @@ class FlaxRobertaSelfOutput(nn.Module):
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
|
||||
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
||||
@ -360,7 +296,7 @@ class FlaxRobertaOutput(nn.Module):
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
@ -397,7 +333,7 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
]
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
for layer in self.layers:
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||
return hidden_states
|
||||
|
||||
@ -515,7 +451,6 @@ class FlaxRobertaModule(nn.Module):
|
||||
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
||||
)
|
||||
|
@ -28,7 +28,10 @@ if is_flax_available():
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
@ -83,29 +86,32 @@ class FlaxModelTesterMixin:
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_pytorch(self):
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
pt_model = pt_model_class(config).eval()
|
||||
|
||||
pt_model = pt_model_class(config).eval()
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
@ -116,7 +122,50 @@ class FlaxModelTesterMixin:
|
||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
pt_model = pt_model_class(config).eval()
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -134,7 +183,7 @@ class FlaxModelTesterMixin:
|
||||
|
||||
outputs_loaded = model_loaded(**prepared_inputs_dict)
|
||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 5e-3)
|
||||
self.assert_almost_equals(output_loaded, output, 1e-3)
|
||||
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
Loading…
Reference in New Issue
Block a user