Mirror of https://github.com/huggingface/transformers.git
Fix TFLEDModel (#15356)
* fix tf led
* fix
* fix
* Add test_pt_tf_model_equivalence_extra for TFLED
* add a (temporary) test

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 87918d3221
commit 5a70987301
@@ -227,10 +227,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             query_vectors, key_vectors, self.one_sided_attn_window_size
         )

+        # values to pad for attention probs
+        remove_from_windowed_attention_mask = attention_mask != 0
+        # cast to fp32/fp16 then replace 1's with -inf
+        float_mask = tf.cast(remove_from_windowed_attention_mask, dtype=query_vectors.dtype) * LARGE_NEGATIVE
+
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
             tf.ones(shape_list(attention_mask)),
-            attention_mask,
+            float_mask,
             self.one_sided_attn_window_size,
         )

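For context, the windowed-attention change above builds a float mask in the score dtype (0 for positions that stay, `LARGE_NEGATIVE` for positions to drop) and feeds that, rather than the raw `attention_mask`, into `_sliding_chunks_query_key_matmul`. A minimal standalone sketch of that masking step (the toy mask values below are made up for illustration, not taken from the model):

```python
# Standalone sketch of the float-mask construction used in the hunk above.
# The toy `attention_mask` values are assumptions for illustration only.
import tensorflow as tf

LARGE_NEGATIVE = -1e8

# 0 marks positions that take part in windowed attention; non-zero marks
# positions that must be excluded (e.g. padding).
attention_mask = tf.constant([[0.0, 0.0, 0.0, -10000.0, -10000.0]])

remove_from_windowed_attention_mask = attention_mask != 0
# cast the boolean mask to the score dtype, then scale: kept positions get 0.0,
# excluded positions get LARGE_NEGATIVE, so adding the mask to attention scores
# drives their softmax probabilities to (numerically) zero
float_mask = tf.cast(remove_from_windowed_attention_mask, dtype=tf.float32) * LARGE_NEGATIVE

print(float_mask.numpy())  # 0.0 for the first three positions, -1e8 for the last two
```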
@@ -1726,7 +1731,9 @@ class TFLEDEncoder(tf.keras.layers.Layer):

         # merge `global_attention_mask` and `attention_mask`
         if inputs["global_attention_mask"] is not None:
-            inputs["attention_mask"] = inputs["global_attention_mask"] + 1
+            inputs["attention_mask"] = inputs["attention_mask"] * tf.cast(
+                (inputs["global_attention_mask"] + 1), dtype=inputs["attention_mask"].dtype
+            )

         (
             padding_len,
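The merge above keeps all three token states in a single mask instead of overwriting it: padding stays 0, locally attended tokens stay 1, and globally attended tokens become 2. A short sketch with toy masks (the example values are assumptions, not the model's own inputs):

```python
# Sketch of the global/local mask merge from the hunk above, with made-up masks.
import tensorflow as tf

attention_mask = tf.constant([[1, 1, 1, 1, 0, 0]])         # 1 = real token, 0 = padding
global_attention_mask = tf.constant([[1, 0, 0, 0, 0, 0]])  # 1 = token attends globally

# multiplying keeps padding at 0, leaves local tokens at 1, and marks global
# tokens with 2, instead of discarding the padding information
merged = attention_mask * tf.cast(global_attention_mask + 1, dtype=attention_mask.dtype)

print(merged.numpy())  # [[2 1 1 1 0 0]]
```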
@@ -50,6 +50,8 @@ _CHECKPOINT_FOR_DOC = "allenai/longformer-base-4096"
 _CONFIG_FOR_DOC = "LongformerConfig"
 _TOKENIZER_FOR_DOC = "LongformerTokenizer"

+LARGE_NEGATIVE = -1e8
+
 TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "allenai/longformer-base-4096",
     "allenai/longformer-large-4096",
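`LARGE_NEGATIVE` is the same constant the LED hunk above already relies on; it works as an additive mask because a score shifted by -1e8 ends up with an effectively zero probability after softmax. A tiny illustration (the scores are arbitrary toy values):

```python
# Why an additive LARGE_NEGATIVE bias behaves like masking: positions shifted by
# -1e8 end up with (numerically) zero softmax probability. Toy scores only.
import tensorflow as tf

LARGE_NEGATIVE = -1e8

scores = tf.constant([2.0, 1.0, 0.5, 3.0])
bias = tf.constant([0.0, 0.0, LARGE_NEGATIVE, LARGE_NEGATIVE])  # mask the last two positions

probs = tf.nn.softmax(scores + bias)
print(probs.numpy())  # approximately [0.73, 0.27, 0.0, 0.0]
```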
@@ -755,10 +757,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             query_vectors, key_vectors, self.one_sided_attn_window_size
         )

+        # values to pad for attention probs
+        remove_from_windowed_attention_mask = attention_mask != 0
+        # cast to fp32/fp16 then replace 1's with -inf
+        float_mask = tf.cast(remove_from_windowed_attention_mask, dtype=query_vectors.dtype) * LARGE_NEGATIVE
+
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
             tf.ones(shape_list(attention_mask)),
-            attention_mask,
+            float_mask,
             self.one_sided_attn_window_size,
         )

@@ -17,13 +17,14 @@
 import unittest

 from transformers import LEDConfig, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import is_pt_tf_cross_test, require_tf, slow

 from .test_configuration_common import ConfigTester
 from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor


 if is_tf_available():
+    import numpy as np
     import tensorflow as tf

     from transformers import TFLEDForConditionalGeneration, TFLEDModel
@@ -362,6 +363,128 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
             self.assertEqual(model.config.output_hidden_states, True)
             check_encoder_attentions_output(outputs)

+    # TODO: Remove this once a more thorough pt/tf equivalence could be implemented in `test_modeling_tf_common.py`.
+    # (Currently, such a test will fail some other model tests: it requires some time to fix them.)
+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence_extra(self):
+        import torch
+
+        import transformers
+
+        def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict):
+
+            pt_inputs_dict = {}
+            for name, key in tf_inputs_dict.items():
+                if type(key) == bool:
+                    pt_inputs_dict[name] = key
+                elif name == "input_values":
+                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
+                elif name == "pixel_values":
+                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
+                else:
+                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
+
+            return pt_inputs_dict
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
+            pt_model_class = getattr(transformers, pt_model_class_name)
+
+            config.output_hidden_states = True
+
+            tf_model = model_class(config)
+            pt_model = pt_model_class(config)
+
+            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            tf_inputs_dict_maybe_with_labels = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+
+            # Check we can load pt model in tf and vice-versa with model => model functions
+
+            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
+            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
+
+            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
+            pt_model.eval()
+
+            pt_inputs_dict = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
+            pt_inputs_dict_maybe_with_labels = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict_maybe_with_labels)
+
+            # need to rename encoder-decoder "inputs" for PyTorch
+            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
+                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
+
+            with torch.no_grad():
+                pto = pt_model(**pt_inputs_dict)
+            tfo = tf_model(tf_inputs_dict, training=False)
+
+            tf_hidden_states = tfo[0].numpy()
+            pt_hidden_states = pto[0].numpy()
+
+            tf_nans = np.isnan(tf_hidden_states)
+            pt_nans = np.isnan(pt_hidden_states)
+
+            pt_hidden_states[tf_nans] = 0
+            tf_hidden_states[tf_nans] = 0
+            pt_hidden_states[pt_nans] = 0
+            tf_hidden_states[pt_nans] = 0
+
+            max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
+            self.assertLessEqual(max_diff, 1e-4)
+
+            has_labels = any(
+                x in tf_inputs_dict_maybe_with_labels for x in ["labels", "next_sentence_label", "start_positions"]
+            )
+            if has_labels:
+
+                with torch.no_grad():
+                    pto = pt_model(**pt_inputs_dict_maybe_with_labels)
+                tfo = tf_model(tf_inputs_dict_maybe_with_labels, training=False)
+
+                # Some models' output class don't have `loss` attribute despite `labels` is used.
+                tf_loss = getattr(tfo, "loss", None)
+                pt_loss = getattr(pto, "loss", None)
+
+                # Some models require extra condition to return loss. For example, `BertForPreTraining` requires both
+                # `labels` and `next_sentence_label`.
+                # Moreover, some PT models return loss while the corresponding TF/Flax models don't.
+                if tf_loss is not None and pt_loss is not None:
+
+                    tf_loss = tf.math.reduce_mean(tf_loss).numpy()
+                    pt_loss = pt_loss.numpy()
+
+                    tf_nans = np.isnan(tf_loss)
+                    pt_nans = np.isnan(pt_loss)
+                    # the 2 losses need to be both nan or both not nan
+                    # (`TapasForQuestionAnswering` gives nan loss here)
+                    self.assertEqual(tf_nans, pt_nans)
+
+                    if not tf_nans:
+                        max_diff = np.amax(np.abs(tf_loss - pt_loss))
+                        # `TFFunnelForTokenClassification` (and potentially other TF token classification models) give
+                        # large difference (up to 0.1x). PR #15294 addresses this issue.
+                        # There is also an inconsistency between PT/TF `XLNetLMHeadModel`.
+                        # Before these issues are fixed & merged, set a higher threshold here to pass the test.
+                        self.assertLessEqual(max_diff, 1e-4)
+
+                    tf_logits = tfo[1].numpy()
+                    pt_logits = pto[1].numpy()
+
+                    # check on the shape
+                    self.assertEqual(tf_logits.shape, pt_logits.shape)
+
+                    tf_nans = np.isnan(tf_logits)
+                    pt_nans = np.isnan(pt_logits)
+
+                    pt_logits[tf_nans] = 0
+                    tf_logits[tf_nans] = 0
+                    pt_logits[pt_nans] = 0
+                    tf_logits[pt_nans] = 0
+
+                    max_diff = np.amax(np.abs(tf_logits - pt_logits))
+                    self.assertLessEqual(max_diff, 1e-4)
+
     def test_xla_mode(self):
         # TODO JP: Make LED XLA compliant
         pass
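Because the new test is decorated with `@is_pt_tf_cross_test`, it is skipped unless PT/TF cross tests are enabled. A hypothetical way to run only this test (the `RUN_PT_TF_CROSS_TESTS` switch is an assumption about how the decorator is gated in `transformers.testing_utils`):

```python
# Hypothetical invocation sketch: enable the PT/TF cross-test gate and run only
# the new test. The RUN_PT_TF_CROSS_TESTS variable name is an assumption.
import os

import pytest

os.environ["RUN_PT_TF_CROSS_TESTS"] = "1"
pytest.main(["tests/test_modeling_tf_led.py", "-k", "test_pt_tf_model_equivalence_extra", "-q"])
```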