Add TF<>PT and Flax<>PT everywhere (#14047)

* up

* up

* up

* up

* up

* up

* up

* add clip

* fix clip PyTorch

* fix clip PyTorch

* up

* up

* up

* up

* up

* up

* up
This commit is contained in:
Patrick von Platen 2021-10-25 23:55:08 +02:00 committed by GitHub
parent 8560b55b5e
commit 0c3174c758
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 589 additions and 45 deletions

View File

@ -314,35 +314,13 @@ class FlaxAlbertLayer(nn.Module):
return outputs
class FlaxAlbertLayers(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
layer_index: Optional[str] = None
def setup(self):
self.albert_layers = FlaxAlbertLayer(self.config, name=self.layer_index, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
):
outputs = self.albert_layers(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
return outputs
class FlaxAlbertLayerCollection(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
FlaxAlbertLayers(self.config, name="albert_layers", layer_index=str(i), dtype=self.dtype)
for i in range(self.config.inner_group_num)
FlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num)
]
def __call__(
@ -385,7 +363,7 @@ class FlaxAlbertLayerCollections(nn.Module):
layer_index: Optional[str] = None
def setup(self):
self.albert_layer_groups = FlaxAlbertLayerCollection(self.config, name=self.layer_index, dtype=self.dtype)
self.albert_layers = FlaxAlbertLayerCollection(self.config, dtype=self.dtype)
def __call__(
self,
@ -395,7 +373,7 @@ class FlaxAlbertLayerCollections(nn.Module):
output_attentions: bool = False,
output_hidden_states: bool = False,
):
outputs = self.albert_layer_groups(
outputs = self.albert_layers(
hidden_states,
attention_mask,
deterministic=deterministic,
@ -405,19 +383,13 @@ class FlaxAlbertLayerCollections(nn.Module):
return outputs
class FlaxAlbertEncoder(nn.Module):
class FlaxAlbertLayerGroups(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
self.embedding_hidden_mapping_in = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
self.layers = [
FlaxAlbertLayerCollections(self.config, name="albert_layer_groups", layer_index=str(i), dtype=self.dtype)
FlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_groups)
]
@ -430,7 +402,6 @@ class FlaxAlbertEncoder(nn.Module):
output_hidden_states: bool = False,
return_dict: bool = True,
):
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
all_attentions = () if output_attentions else None
all_hidden_states = (hidden_states,) if output_hidden_states else None
@ -459,6 +430,37 @@ class FlaxAlbertEncoder(nn.Module):
)
class FlaxAlbertEncoder(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.embedding_hidden_mapping_in = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
return self.albert_layer_groups(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
class FlaxAlbertOnlyMLMHead(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32

View File

@ -1222,7 +1222,10 @@ class TFHubertMainLayer(tf.keras.layers.Layer):
if inputs["attention_mask"] is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
attention_mask = tf.sequence_mask(output_lengths, dtype=hidden_states.dtype)
attention_mask = tf.sequence_mask(
output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
)
hidden_states = self.feature_projection(hidden_states, training=inputs["training"])

View File

@ -59,7 +59,7 @@ class BigBirdModelTester:
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu_fast",
hidden_act="gelu_new",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=256,

View File

@ -23,9 +23,17 @@ import unittest
import numpy as np
import requests
import transformers
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import (
is_flax_available,
is_pt_flax_cross_test,
require_torch,
require_vision,
slow,
torch_device,
)
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, random_attention_mask
@ -45,6 +53,14 @@ if is_vision_available():
from transformers import CLIPProcessor
if is_flax_available():
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
class CLIPVisionModelTester:
def __init__(
self,
@ -330,6 +346,13 @@ class CLIPTextModelTester:
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
if input_mask is not None:
batch_size, seq_length = input_mask.shape
rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
input_mask[batch_idx, :start_index] = 1
input_mask[batch_idx, start_index:] = 0
config = self.get_config()
return config, input_ids, input_mask
@ -558,6 +581,125 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
self.assertTrue(models_equal)
# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
return
fx_model_class = getattr(transformers, fx_model_class_name)
# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
# convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load corresponding PyTorch class
pt_model = model_class(config).eval()
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
# no flax model exists for this class
return
fx_model_class = getattr(transformers, fx_model_class_name)
# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
@slow
def test_model_from_pretrained(self):
for model_name in CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:

View File

@ -25,10 +25,13 @@ import unittest
import warnings
from typing import Dict, List, Tuple
import numpy as np
import transformers
from huggingface_hub import HfApi, Repository
from requests.exceptions import HTTPError
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
ENDPOINT_STAGING,
@ -36,6 +39,8 @@ from transformers.testing_utils import (
USER,
CaptureLogger,
TestCasePlus,
is_pt_flax_cross_test,
is_pt_tf_cross_test,
is_staging_test,
require_torch,
require_torch_multi_gpu,
@ -45,7 +50,6 @@ from transformers.testing_utils import (
if is_torch_available():
import numpy as np
import torch
from torch import nn
@ -70,6 +74,13 @@ if is_torch_available():
T5ForConditionalGeneration,
)
if is_flax_available():
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
if is_torch_fx_available():
from transformers.utils.fx import symbolic_trace
@ -1417,6 +1428,241 @@ class ModelTesterMixin:
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
)
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
import numpy as np
import tensorflow as tf
import transformers
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
if not hasattr(transformers, tf_model_class_name):
# transformers does not have TF version yet
return
tf_model_class = getattr(transformers, tf_model_class_name)
config.output_hidden_states = True
tf_model = tf_model_class(config)
pt_model = model_class(config)
# make sure only tf inputs are forward that actually exist in function args
tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())
# remove all head masks
tf_input_keys.discard("head_mask")
tf_input_keys.discard("cross_attn_head_mask")
tf_input_keys.discard("decoder_head_mask")
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: v for k, v in pt_inputs.items() if k in tf_input_keys}
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
tf_inputs_dict = {}
for key, tensor in pt_inputs.items():
# skip key that does not exist in tf
if type(tensor) == bool:
tf_inputs_dict[key] = tensor
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
# need to rename encoder-decoder "inputs" for PyTorch
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
with torch.no_grad():
pto = pt_model(**pt_inputs)
tfo = tf_model(tf_inputs_dict, training=False)
tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].numpy()
tf_nans = np.copy(np.isnan(tf_hidden_states))
pt_nans = np.copy(np.isnan(pt_hidden_states))
pt_hidden_states[tf_nans] = 0
tf_hidden_states[tf_nans] = 0
pt_hidden_states[pt_nans] = 0
tf_hidden_states[pt_nans] = 0
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
self.assertLessEqual(max_diff, 4e-2)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
tf_inputs_dict = {}
for key, tensor in pt_inputs.items():
# skip key that does not exist in tf
if type(tensor) == bool:
tensor = np.array(tensor, dtype=bool)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
# need to rename encoder-decoder "inputs" for PyTorch
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
with torch.no_grad():
pto = pt_model(**pt_inputs)
tfo = tf_model(tf_inputs_dict)
tfo = tfo[0].numpy()
pto = pto[0].numpy()
tf_nans = np.copy(np.isnan(tfo))
pt_nans = np.copy(np.isnan(pto))
pto[tf_nans] = 0
tfo[tf_nans] = 0
pto[pt_nans] = 0
tfo[pt_nans] = 0
max_diff = np.amax(np.abs(tfo - pto))
self.assertLessEqual(max_diff, 4e-2)
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
return
fx_model_class = getattr(transformers, fx_model_class_name)
# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
# convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load corresponding PyTorch class
pt_model = model_class(config).eval()
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
# no flax model exists for this class
return
fx_model_class = getattr(transformers, fx_model_class_name)
# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -72,7 +72,7 @@ class LongformerModelTester:
# because its local attention only attends to `self.attention_window + 1` locations
# (assuming no token with global attention, otherwise the last dimension of attentions
# is x + self.attention_window + 1, where x is the number of tokens with global attention)
self.key_length = self.attention_window + 1
self.key_length = self.attention_window + 2
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
# the `test_attention_outputs` and `test_hidden_states_output` tests
@ -243,6 +243,8 @@ class LongformerModelTester:
choice_labels,
) = config_and_inputs
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, -1] = 1
inputs_dict = {
"input_ids": input_ids,
"token_type_ids": token_type_ids,

View File

@ -15,13 +15,16 @@
import copy
import os
import tempfile
import unittest
import numpy as np
from transformers import LxmertConfig, is_torch_available
import transformers
from transformers import LxmertConfig, is_tf_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import is_pt_tf_cross_test, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
@ -40,6 +43,10 @@ if is_torch_available():
from transformers.models.lxmert.modeling_lxmert import LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_tf_available():
import tensorflow as tf
class LxmertModelTester:
"""You can also import this e.g from .test_modeling_bart import BartModelTester"""
@ -496,7 +503,7 @@ class LxmertModelTester:
result_pretrain_more.question_answering_score.shape, (self.batch_size, num_large_labels)
)
def prepare_config_and_inputs_for_common(self):
def prepare_config_and_inputs_for_common(self, return_obj_labels=False):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
@ -520,6 +527,9 @@ class LxmertModelTester:
"attention_mask": input_mask,
}
if return_obj_labels:
inputs_dict["obj_labels"] = obj_labels
return config, inputs_dict
@ -732,6 +742,128 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(hidden_states_vision.grad)
self.assertIsNotNone(attentions_vision.grad)
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
return_obj_labels="PreTraining" in model_class.__name__
)
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
if not hasattr(transformers, tf_model_class_name):
# transformers does not have TF version yet
return
tf_model_class = getattr(transformers, tf_model_class_name)
config.output_hidden_states = True
config.task_obj_predict = False
pt_model = model_class(config)
tf_model = tf_model_class(config)
# Check we can load pt model in tf and vice-versa with model => model functions
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
def recursive_numpy_convert(iterable):
return_dict = {}
for key, value in iterable.items():
if type(value) == bool:
return_dict[key] = value
if isinstance(value, dict):
return_dict[key] = recursive_numpy_convert(value)
else:
if isinstance(value, (list, tuple)):
return_dict[key] = (
tf.convert_to_tensor(iter_value.numpy(), dtype=tf.int32) for iter_value in value
)
else:
return_dict[key] = tf.convert_to_tensor(value.numpy(), dtype=tf.int32)
return return_dict
tf_inputs_dict = recursive_numpy_convert(pt_inputs)
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
# Delete obj labels as we want to compute the hidden states and not the loss
if "obj_labels" in inputs_dict:
del inputs_dict["obj_labels"]
def torch_type(key):
if key in ("visual_feats", "visual_pos"):
return torch.float32
else:
return torch.long
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
tf_inputs_dict = recursive_numpy_convert(pt_inputs)
with torch.no_grad():
pto = pt_model(**pt_inputs)
tfo = tf_model(tf_inputs_dict, training=False)
tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].numpy()
tf_nans = np.copy(np.isnan(tf_hidden_states))
pt_nans = np.copy(np.isnan(pt_hidden_states))
pt_hidden_states[tf_nans] = 0
tf_hidden_states[tf_nans] = 0
pt_hidden_states[pt_nans] = 0
tf_hidden_states[pt_nans] = 0
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
# Debug info (remove when fixed)
if max_diff >= 2e-2:
print("===")
print(model_class)
print(config)
print(inputs_dict)
print(pt_inputs)
self.assertLessEqual(max_diff, 6e-2)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
for key, value in pt_inputs.items():
if key in ("visual_feats", "visual_pos"):
pt_inputs[key] = value.to(torch.float32)
else:
pt_inputs[key] = value.to(torch.long)
with torch.no_grad():
pto = pt_model(**pt_inputs)
tfo = tf_model(tf_inputs_dict)
tfo = tfo[0].numpy()
pto = pto[0].numpy()
tf_nans = np.copy(np.isnan(tfo))
pt_nans = np.copy(np.isnan(pto))
pto[tf_nans] = 0
tfo[tf_nans] = 0
pto[pt_nans] = 0
tfo[pt_nans] = 0
max_diff = np.amax(np.abs(tfo - pto))
self.assertLessEqual(max_diff, 6e-2)
@require_torch
class LxmertModelIntegrationTest(unittest.TestCase):

View File

@ -431,7 +431,6 @@ class TFModelTesterMixin:
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
import torch
import transformers

View File

@ -22,7 +22,14 @@ import pytest
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import Wav2Vec2Config, is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
from transformers.testing_utils import (
is_pt_flax_cross_test,
require_datasets,
require_soundfile,
require_torch,
slow,
torch_device,
)
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init
@ -131,6 +138,7 @@ class Wav2Vec2ModelTester:
hidden_dropout_prob=self.hidden_dropout_prob,
intermediate_size=self.intermediate_size,
layer_norm_eps=self.layer_norm_eps,
do_stable_layer_norm=self.do_stable_layer_norm,
hidden_act=self.hidden_act,
initializer_range=self.initializer_range,
vocab_size=self.vocab_size,
@ -357,6 +365,16 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
def test_model_common_attributes(self):
pass
@is_pt_flax_cross_test
# non-robust architecture does not exist in Flax
def test_equivalence_flax_to_pt(self):
pass
@is_pt_flax_cross_test
# non-robust architecture does not exist in Flax
def test_equivalence_pt_to_flax(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True