diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py
index 0f4d0feef91..38a31e0ca8b 100644
--- a/src/transformers/models/led/modeling_tf_led.py
+++ b/src/transformers/models/led/modeling_tf_led.py
@@ -227,10 +227,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             query_vectors, key_vectors, self.one_sided_attn_window_size
         )
 
+        # values to pad for attention probs
+        remove_from_windowed_attention_mask = attention_mask != 0
+        # cast to fp32/fp16 then replace 1's with -inf
+        float_mask = tf.cast(remove_from_windowed_attention_mask, dtype=query_vectors.dtype) * LARGE_NEGATIVE
+
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
             tf.ones(shape_list(attention_mask)),
-            attention_mask,
+            float_mask,
             self.one_sided_attn_window_size,
         )
 
@@ -1726,7 +1731,9 @@ class TFLEDEncoder(tf.keras.layers.Layer):
 
         # merge `global_attention_mask` and `attention_mask`
         if inputs["global_attention_mask"] is not None:
-            inputs["attention_mask"] = inputs["global_attention_mask"] + 1
+            inputs["attention_mask"] = inputs["attention_mask"] * tf.cast(
+                (inputs["global_attention_mask"] + 1), dtype=inputs["attention_mask"].dtype
+            )
 
         (
             padding_len,
diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py
index 4995c9dbc7d..cf661a60810 100644
--- a/src/transformers/models/longformer/modeling_tf_longformer.py
+++ b/src/transformers/models/longformer/modeling_tf_longformer.py
@@ -50,6 +50,8 @@ _CHECKPOINT_FOR_DOC = "allenai/longformer-base-4096"
 _CONFIG_FOR_DOC = "LongformerConfig"
 _TOKENIZER_FOR_DOC = "LongformerTokenizer"
 
+LARGE_NEGATIVE = -1e8
+
 TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "allenai/longformer-base-4096",
     "allenai/longformer-large-4096",
@@ -755,10 +757,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             query_vectors, key_vectors, self.one_sided_attn_window_size
         )
 
+        # values to pad for attention probs
+        remove_from_windowed_attention_mask = attention_mask != 0
+        # cast to fp32/fp16 then replace 1's with -inf
+        float_mask = tf.cast(remove_from_windowed_attention_mask, dtype=query_vectors.dtype) * LARGE_NEGATIVE
+
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
             tf.ones(shape_list(attention_mask)),
-            attention_mask,
+            float_mask,
             self.one_sided_attn_window_size,
         )
 
diff --git a/tests/test_modeling_tf_led.py b/tests/test_modeling_tf_led.py
index 41d132c80b3..b42e8b538cd 100644
--- a/tests/test_modeling_tf_led.py
+++ b/tests/test_modeling_tf_led.py
@@ -17,13 +17,14 @@ import unittest
 
 from transformers import LEDConfig, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import is_pt_tf_cross_test, require_tf, slow
 
 from .test_configuration_common import ConfigTester
 from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
 
 
 if is_tf_available():
+    import numpy as np
     import tensorflow as tf
 
     from transformers import TFLEDForConditionalGeneration, TFLEDModel
@@ -362,6 +363,128 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
             self.assertEqual(model.config.output_hidden_states, True)
             check_encoder_attentions_output(outputs)
 
+    # TODO: Remove this once a more thorough pt/tf equivalence could be implemented in `test_modeling_tf_common.py`.
+    # (Currently, such a test will fail some other model tests: it requires some time to fix them.)
+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence_extra(self):
+        import torch
+
+        import transformers
+
+        def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict):
+
+            pt_inputs_dict = {}
+            for name, key in tf_inputs_dict.items():
+                if type(key) == bool:
+                    pt_inputs_dict[name] = key
+                elif name == "input_values":
+                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
+                elif name == "pixel_values":
+                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
+                else:
+                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
+
+            return pt_inputs_dict
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
+            pt_model_class = getattr(transformers, pt_model_class_name)
+
+            config.output_hidden_states = True
+
+            tf_model = model_class(config)
+            pt_model = pt_model_class(config)
+
+            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            tf_inputs_dict_maybe_with_labels = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+
+            # Check we can load pt model in tf and vice-versa with model => model functions
+
+            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
+            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
+
+            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
+            pt_model.eval()
+
+            pt_inputs_dict = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
+            pt_inputs_dict_maybe_with_labels = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict_maybe_with_labels)
+
+            # need to rename encoder-decoder "inputs" for PyTorch
+            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
+                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
+
+            with torch.no_grad():
+                pto = pt_model(**pt_inputs_dict)
+            tfo = tf_model(tf_inputs_dict, training=False)
+
+            tf_hidden_states = tfo[0].numpy()
+            pt_hidden_states = pto[0].numpy()
+
+            tf_nans = np.isnan(tf_hidden_states)
+            pt_nans = np.isnan(pt_hidden_states)
+
+            pt_hidden_states[tf_nans] = 0
+            tf_hidden_states[tf_nans] = 0
+            pt_hidden_states[pt_nans] = 0
+            tf_hidden_states[pt_nans] = 0
+
+            max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
+            self.assertLessEqual(max_diff, 1e-4)
+
+            has_labels = any(
+                x in tf_inputs_dict_maybe_with_labels for x in ["labels", "next_sentence_label", "start_positions"]
+            )
+            if has_labels:
+
+                with torch.no_grad():
+                    pto = pt_model(**pt_inputs_dict_maybe_with_labels)
+                tfo = tf_model(tf_inputs_dict_maybe_with_labels, training=False)
+
+                # Some models' output class don't have `loss` attribute despite `labels` is used.
+                tf_loss = getattr(tfo, "loss", None)
+                pt_loss = getattr(pto, "loss", None)
+
+                # Some models require extra condition to return loss. For example, `BertForPreTraining` requires both
+                # `labels` and `next_sentence_label`.
+                # Moreover, some PT models return loss while the corresponding TF/Flax models don't.
+                if tf_loss is not None and pt_loss is not None:
+
+                    tf_loss = tf.math.reduce_mean(tf_loss).numpy()
+                    pt_loss = pt_loss.numpy()
+
+                    tf_nans = np.isnan(tf_loss)
+                    pt_nans = np.isnan(pt_loss)
+                    # the 2 losses need to be both nan or both not nan
+                    # (`TapasForQuestionAnswering` gives nan loss here)
+                    self.assertEqual(tf_nans, pt_nans)
+
+                    if not tf_nans:
+                        max_diff = np.amax(np.abs(tf_loss - pt_loss))
+                        # `TFFunnelForTokenClassification` (and potentially other TF token classification models) give
+                        # large difference (up to 0.1x). PR #15294 addresses this issue.
+                        # There is also an inconsistency between PT/TF `XLNetLMHeadModel`.
+                        # Before these issues are fixed & merged, set a higher threshold here to pass the test.
+                        self.assertLessEqual(max_diff, 1e-4)
+
+                tf_logits = tfo[1].numpy()
+                pt_logits = pto[1].numpy()
+
+                # check on the shape
+                self.assertEqual(tf_logits.shape, pt_logits.shape)
+
+                tf_nans = np.isnan(tf_logits)
+                pt_nans = np.isnan(pt_logits)
+
+                pt_logits[tf_nans] = 0
+                tf_logits[tf_nans] = 0
+                pt_logits[pt_nans] = 0
+                tf_logits[pt_nans] = 0
+
+                max_diff = np.amax(np.abs(tf_logits - pt_logits))
+                self.assertLessEqual(max_diff, 1e-4)
+
     def test_xla_mode(self):
         # TODO JP: Make LED XLA compliant
         pass
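Reviewer note (not part of the patch): a minimal sketch of the two expressions the diff introduces, with toy tensor values chosen purely for illustration. In the real models the attention mask is preprocessed further before it reaches self-attention, so the exact values there will differ; only the shape of the computation is shown here.

```python
import tensorflow as tf

LARGE_NEGATIVE = -1e8  # same constant the patch adds to modeling_tf_longformer.py

# TFLEDEncoder change: keep padding information when merging the masks.
# attention_mask: 1 = real token, 0 = padding; global_attention_mask: 1 = global-attention token.
attention_mask = tf.constant([[1, 1, 1, 1, 0, 0]])
global_attention_mask = tf.constant([[1, 0, 0, 0, 0, 0]])

old_merge = global_attention_mask + 1  # [[2 1 1 1 1 1]] -- padding looks like a normal token
new_merge = attention_mask * tf.cast(global_attention_mask + 1, dtype=attention_mask.dtype)
print(new_merge.numpy())  # [[2 1 1 1 0 0]] -> 0 = padding, 1 = local attention, 2 = global attention

# Self-attention change: build an additive float mask instead of feeding the raw mask
# into _sliding_chunks_query_key_matmul. Toy mask: nonzero entries mark positions that
# should be dropped from the sliding-window attention scores.
toy_mask = tf.constant([[0.0, 0.0, 0.0, 0.0, 5.0, 5.0]])
remove_from_windowed_attention_mask = toy_mask != 0
float_mask = tf.cast(remove_from_windowed_attention_mask, dtype=tf.float32) * LARGE_NEGATIVE
print(float_mask.numpy())  # zeros for kept positions, -1e8 for dropped ones
```

The additive large-negative bias mirrors what the PyTorch Longformer/LED self-attention already does with `masked_fill`, which is what the new `test_pt_tf_model_equivalence_extra` cross-test above is checking.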