Big TF test cleanup (#24282)

* Fix one BLIP arg not being optional, remove misspelled arg

* Remove the lxmert test overrides and just use the base test_saved_model_creation

* saved_model_creation fixes and re-enabling tests across the board

* Remove unnecessary skip

* Stop caching sinusoidal embeddings in speech_to_text

* Fix transfo_xl compilation

* Fix transfo_xl compilation

* Fix the conditionals in xglm

* Set the save spec only when building

* Clarify comment

* Move comment correctly

* Correct embeddings generation for speech2text

* Mark RAG generation tests as @slow

* Remove redundant else:

* Add comment to clarify the save_spec line in build()

* Fix size tests for XGLM at last!

* make fixup

* Remove one band_part operation

* Mark test_keras_fit as @slow
Matt authored on 2023-06-16 15:40:49 +01:00 (committed by GitHub)
parent 896a58de15
commit 3403712958
31 changed files with 68 additions and 217 deletions


@@ -1157,6 +1157,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
self.built = True
else:
self.built = True
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
# Setting it in build() allows users to override the shape when loading a non-pretrained model from config
self._set_save_spec(self._prune_signature(self.input_signature))
self(self.dummy_inputs, training=False)
def __init__(self, config, *inputs, **kwargs):
@@ -1171,8 +1174,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
self.config = config
self.name_or_path = config.name_or_path
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
self._set_save_spec(self._prune_signature(self.input_signature))
def get_config(self):
return self.config.to_dict()
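Note on the two hunks above: the save spec now gets set inside build(), immediately before the dummy forward pass, instead of in __init__. A minimal sketch of the trap this avoids, using a toy model rather than the real TFPreTrainedModel (the Toy class, its shapes, and the bare TensorSpec are illustrative assumptions; _set_save_spec is the private Keras API the diff itself calls, so its behavior can vary across TF versions):

import tensorflow as tf

class Toy(tf.keras.Model):
    # Stand-in for a TFPreTrainedModel whose first real call uses dummy inputs.
    def call(self, input_ids):
        return tf.cast(input_ids, tf.float32) * 2.0

model = Toy()
# Pin a shape-agnostic serving spec *before* the dummy call; otherwise Keras would
# record the concrete (3, 5) dummy shape as the SavedModel signature and exports
# would reject other batch sizes and sequence lengths.
model._set_save_spec(tf.TensorSpec(shape=(None, None), dtype=tf.int32, name="input_ids"))
model(tf.zeros((3, 5), dtype=tf.int32), training=False)

Doing this in build() rather than __init__ also means a user loading a non-pretrained model from a config can still override input_signature before the model is ever built, as the new comment says.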


@@ -1216,12 +1216,11 @@ class TFBlipForQuestionAnswering(TFBlipPreTrainedModel):
def call(
self,
input_ids: tf.Tensor,
pixel_values: tf.Tensor,
pixel_values: tf.Tensor | None = None,
decoder_input_ids: tf.Tensor | None = None,
decoder_attention_mask: tf.Tensor | None = None,
attention_mask: tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
foutput_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: tf.Tensor | None = None,
return_dict: Optional[bool] = None,


@@ -618,7 +618,7 @@ class TFOPTDecoder(tf.keras.layers.Layer):
attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
else:
tf.debugging.assert_equal(
attention_mask.shape[1],
tf.shape(attention_mask)[1],
past_key_values_length + input_shape[1],
message=(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "


@@ -195,30 +195,19 @@ class TFSpeech2TextSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
emb = tf.concat([emb[:padding_idx, :], tf.zeros((1, tf.shape(emb)[1])), emb[padding_idx + 1 :, :]], axis=0)
return emb
def build(self, input_shape: tf.TensorShape):
"""
Build shared token embedding layer Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
self.embeddings = self.add_weight(
name="weights", # name also used in PT
shape=tf.shape(self.embedding_weights),
trainable=False,
)
self.embeddings.assign(self.embedding_weights)
super().build(input_shape)
def call(self, input_ids: tf.Tensor, past_key_values_length: int = 0) -> tf.Tensor:
bsz, seq_len = shape_list(input_ids)
# Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
# expand embeddings if needed
max_pos = self.padding_idx + 1 + seq_len
if max_pos > shape_list(self.embeddings)[0]:
self.embedding_weights = self._get_embedding(max_pos + self.offset, self.embedding_dim, self.padding_idx)
self.embeddings.assign(self.embedding_weights)
return tf.reshape(tf.gather(self.embeddings, tf.reshape(position_ids, (-1,)), axis=0), (bsz, seq_len, -1))
# Matt: The PyTorch code does a lot of work to cache the embeddings, setting the cached values as a
# model attribute in the forward pass. This is extremely forbidden in TF, which wants forward calls to be
# idempotent. TF doesn't need that caching anyway, since it can just store constants during compilation,
# so we just remove all of that code.
embeddings = self._get_embedding(
self.padding_idx + 1 + seq_len + self.offset + past_key_values_length, self.embedding_dim, self.padding_idx
)
return tf.reshape(tf.gather(embeddings, tf.reshape(position_ids, (-1,)), axis=0), (bsz, seq_len, -1))
@staticmethod
def create_position_ids_from_input_ids(
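As Matt's comment in the hunk explains, the TF version simply rebuilds the table on every call. A hedged sketch of why that is cheap (this is a generic sinusoidal table, not the exact _get_embedding, which additionally applies the offset and zeroes the padding row; dim is assumed even):

import tensorflow as tf

def sinusoidal_table(num_positions: int, dim: int) -> tf.Tensor:
    # A pure function of its sizes: when those are known at trace time, tf.function
    # folds the whole table into a graph constant, so no cached attribute (and no
    # state mutation inside call()) is ever needed.
    position = tf.cast(tf.range(num_positions), tf.float32)[:, None]
    inv_freq = 1.0 / (10000.0 ** (tf.cast(tf.range(0, dim, 2), tf.float32) / dim))
    angles = position * inv_freq[None, :]
    return tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)  # (num_positions, dim)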
@@ -562,6 +551,7 @@ class TFSpeech2TextPreTrainedModel(TFPreTrainedModel):
config_class = Speech2TextConfig
base_model_prefix = "model"
main_input_name = "input_features"
_keys_to_ignore_on_load_unexpected = [r"encoder.embed_positions.weights"]
def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor):
"""


@@ -588,35 +588,19 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
klen = mlen + qlen
# Compute decoder attention mask
# ::: PyTorch masking code for reference :::
# if self.same_length:
# all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
# mask_len = klen - self.mem_len
# if mask_len > 0:
# mask_shift_len = qlen - mask_len
# else:
# mask_shift_len = qlen
# dec_attn_mask = (torch.triu(all_ones, 1+mlen)
# + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
# else:
# dec_attn_mask = torch.triu(
# word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]
# TensorFlow version
dec_attn_mask = 1 - tf.linalg.band_part(
tf.ones([qlen, klen], dtype=tf.int32), -1, mlen
) # (q, q): diagonal with 1's
all_ones = tf.ones([qlen, klen], dtype=tf.int32)
upper_mask = 1 - tf.linalg.band_part(tf.ones([qlen, klen], dtype=tf.int32), -1, mlen)
if self.same_length:
mask_len = klen - self.mem_len
if mask_len > 0:
mask_shift_len = qlen - mask_len
else:
mask_shift_len = qlen
if mask_shift_len >= 1:
dec_attn_mask += 1 - tf.linalg.band_part(tf.ones([qlen, klen], dtype=tf.int32), mask_shift_len - 1, -1)
else:
dec_attn_mask += tf.linalg.band_part(tf.ones([qlen, klen], dtype=tf.int32), -1, -mask_shift_len)
mask_shift_len = qlen - tf.nn.relu(mask_len) # Lazy clamping of negatives to zero
# Use an indicator variable instead of a conditional to keep the compiler happy
lower_mask = tf.linalg.band_part(all_ones, -1, 0) - (
tf.linalg.band_part(all_ones, mask_shift_len - 1, 0) * tf.cast(mask_shift_len != 0, tf.int32)
)
dec_attn_mask = upper_mask + lower_mask
else:
dec_attn_mask = upper_mask
hids = []
attentions = [] if output_attentions else None
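A worked example of the indicator trick above with concrete sizes (the numbers are hypothetical, not from the commit), showing how it avoids Python control flow on a symbolic tensor where the old code had `if mask_shift_len >= 1`:

import tensorflow as tf

qlen, klen, mlen, mem_len = 4, 6, 2, 3  # illustrative sizes
all_ones = tf.ones([qlen, klen], dtype=tf.int32)
upper_mask = 1 - tf.linalg.band_part(all_ones, -1, mlen)
mask_shift_len = qlen - tf.nn.relu(klen - mem_len)  # relu clamps negative mask_len to 0
# band_part(x, n, 0) keeps the main diagonal plus n subdiagonals, and any negative n
# keeps the full lower triangle. So when mask_shift_len == 0, the second band_part
# term would wrongly keep everything below the diagonal; the cast indicator zeroes
# that term out instead of branching on it.
lower_mask = tf.linalg.band_part(all_ones, -1, 0) - (
    tf.linalg.band_part(all_ones, mask_shift_len - 1, 0)
    * tf.cast(mask_shift_len != 0, tf.int32)
)
dec_attn_mask = upper_mask + lower_mask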


@@ -463,19 +463,14 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
) -> tf.Tensor:
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask: tf.Tensor | None = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length)
if attention_mask is not None:
expand_attention_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
combined_attention_mask = (
expand_attention_mask
if combined_attention_mask is None
else expand_attention_mask + combined_attention_mask
)
return combined_attention_mask
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length)
combined_attention_mask = tf.cond(
input_shape[-1] > 1, lambda: combined_attention_mask, lambda: tf.ones_like(combined_attention_mask)
)
if attention_mask is None:
return combined_attention_mask
expand_attention_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
return expand_attention_mask + combined_attention_mask
def embed_positions(self, position_ids: np.ndarray | tf.Tensor | None = None) -> tf.Tensor:
position_ids += self.offset
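Same theme as the OPT fix earlier: once input_shape comes from tf.shape() (second hunk of this file, below), input_shape[-1] > 1 is a symbolic scalar, so the branch must live in the graph. A hedged toy version of the pattern (the function name and shapes are invented):

import tensorflow as tf

@tf.function
def causal_mask_or_trivial(x):
    seq_len = tf.shape(x)[-1]
    causal = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
    # tf.cond traces both branches into the graph; they must agree in shape and dtype,
    # which is why the causal mask is computed unconditionally and merely selected.
    return tf.cond(seq_len > 1, lambda: causal, lambda: tf.ones_like(causal))

print(causal_mask_or_trivial(tf.zeros((1, 3))))  # lower-triangular 3x3 mask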
@@ -512,10 +507,10 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_shape = tf.shape(input_ids)
input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
input_shape = tf.shape(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")


@@ -2676,7 +2676,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
self.model._set_save_spec(self._prune_signature(self.input_signature))
self.use_cache = config.use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self.bias_layer = BiasLayer(


@@ -22,7 +22,7 @@ import unittest
import numpy as np
from transformers import BartConfig, BartTokenizer, is_tf_available
from transformers.testing_utils import require_tf, slow, tooslow
from transformers.testing_utils import require_tf, slow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -225,10 +225,6 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
@tooslow
def test_saved_model_creation(self):
pass
# TODO (Joao): fix me
@unittest.skip("Onnx compliancy broke with TF 2.10")
def test_onnx_compliancy(self):


@@ -19,7 +19,7 @@ from __future__ import annotations
import unittest
from transformers import BlenderbotConfig, BlenderbotTokenizer, is_tf_available
from transformers.testing_utils import require_tf, require_tokenizers, slow, tooslow
from transformers.testing_utils import require_tf, require_tokenizers, slow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -207,10 +207,6 @@ class TFBlenderbotModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
@tooslow
def test_saved_model_creation(self):
pass
@require_tokenizers
@require_tf


@@ -19,7 +19,7 @@ from __future__ import annotations
import unittest
from transformers import BlenderbotSmallConfig, BlenderbotSmallTokenizer, is_tf_available
from transformers.testing_utils import require_tf, require_tokenizers, slow, tooslow
from transformers.testing_utils import require_tf, require_tokenizers, slow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -209,10 +209,6 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, PipelineTesterMixin, unitte
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
@tooslow
def test_saved_model_creation(self):
pass
@require_tokenizers
@require_tf


@@ -156,6 +156,7 @@ class TFConvNextModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF does not support backprop for grouped convolutions on CPU.",
)
@slow
def test_keras_fit(self):
super().test_keras_fit()


@@ -185,6 +185,7 @@ class TFCvtModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF does not support backprop for grouped convolutions on CPU.",
)
@slow
def test_keras_fit(self):
super().test_keras_fit()


@@ -347,6 +347,7 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, PipelineTesterMixin, unittes
check_hidden_states_output(inputs_dict, config, model_class)
# Overriding this method since the base method won't be compatible with Data2VecVision.
@slow
def test_keras_fit(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:


@@ -19,7 +19,7 @@ from __future__ import annotations
import unittest
from transformers import FunnelConfig, is_tf_available
from transformers.testing_utils import require_tf, tooslow
from transformers.testing_utils import require_tf
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
@@ -386,10 +386,6 @@ class TFFunnelModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
@tooslow
def test_saved_model_creation(self):
pass
@require_tf
class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
@@ -417,7 +413,3 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
@tooslow
def test_saved_model_creation(self):
pass


@@ -601,6 +601,7 @@ class TFGroupViTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
pass
@require_tensorflow_probability
@slow
def test_keras_fit(self):
super().test_keras_fit()
@@ -692,11 +693,6 @@ class TFGroupViTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
def test_saved_model_creation(self):
pass
@unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.")
@slow
def test_saved_model_creation_extended(self):
pass
@unittest.skip(reason="`saved_model` doesn't work with nested outputs so no preparation happens.")
@slow
def test_prepare_serving_output(self):


@@ -19,7 +19,7 @@ from __future__ import annotations
import unittest
from transformers import LEDConfig, is_tf_available
from transformers.testing_utils import require_tf, slow, tooslow
from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
@@ -292,11 +292,7 @@ class TFLEDModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
self.assertEqual(model.config.output_hidden_states, True)
check_encoder_attentions_output(outputs)
def test_xla_mode(self):
# TODO JP: Make LED XLA compliant
pass
@tooslow
@unittest.skip("LED keeps using potentially symbolic tensors in conditionals and breaks tracing.")
def test_saved_model_creation(self):
pass


@@ -19,7 +19,7 @@ from __future__ import annotations
import unittest
from transformers import is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
@@ -356,14 +356,10 @@ class TFLongformerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
@tooslow
@unittest.skip("Longformer keeps using potentially symbolic tensors in conditionals and breaks tracing.")
def test_saved_model_creation(self):
pass
def test_xla_mode(self):
# TODO JP: Make Longformer XLA compliant
pass
@require_tf
@require_sentencepiece


@@ -15,14 +15,13 @@
from __future__ import annotations
import os
import tempfile
import unittest
import numpy as np
from transformers import LxmertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow, tooslow
from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
@@ -532,73 +531,6 @@ class TFLxmertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa
self.assert_outputs_same(after_outputs, outputs)
@tooslow
def test_saved_model_creation(self):
pass
@slow
def test_saved_model_creation_extended(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
if hasattr(config, "use_cache"):
config.use_cache = True
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
num_out = len(model(class_inputs_dict))
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
model = tf.keras.models.load_model(saved_model_dir)
outputs = model(class_inputs_dict)
language_hidden_states = outputs["language_hidden_states"]
vision_hidden_states = outputs["vision_hidden_states"]
language_attentions = outputs["language_attentions"]
vision_attentions = outputs["vision_attentions"]
cross_encoder_attentions = outputs["cross_encoder_attentions"]
self.assertEqual(len(outputs), num_out)
self.assertEqual(len(language_hidden_states), self.model_tester.num_hidden_layers["language"] + 1)
self.assertEqual(len(vision_hidden_states), self.model_tester.num_hidden_layers["vision"] + 1)
seq_length = self.model_tester.seq_length
num_visual_features = self.model_tester.num_visual_features
self.assertListEqual(
list(language_hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
self.assertListEqual(
list(vision_hidden_states[0].shape[-2:]),
[num_visual_features, self.model_tester.hidden_size],
)
self.assertEqual(len(language_attentions), self.model_tester.num_hidden_layers["language"])
self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers["vision"])
self.assertEqual(len(cross_encoder_attentions), self.model_tester.num_hidden_layers["cross_encoder"])
attentions = [language_attentions, vision_attentions, cross_encoder_attentions]
attention_shapes = [
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
[
self.model_tester.num_attention_heads,
self.model_tester.num_visual_features,
self.model_tester.num_visual_features,
],
[self.model_tester.num_attention_heads, encoder_key_length, self.model_tester.num_visual_features],
]
for attention, attention_shape in zip(attentions, attention_shapes):
self.assertListEqual(list(attention[0].shape[-3:]), attention_shape)
@require_tf
class TFLxmertModelIntegrationTest(unittest.TestCase):


@@ -20,7 +20,7 @@ import unittest
import warnings
from transformers import AutoTokenizer, MarianConfig, MarianTokenizer, TranslationPipeline, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -208,10 +208,6 @@ class TFMarianModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
@tooslow
def test_saved_model_creation(self):
pass
@require_tf
class AbstractMarianIntegrationTest(unittest.TestCase):


@@ -18,7 +18,7 @@ from __future__ import annotations
import unittest
from transformers import AutoTokenizer, MBartConfig, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -195,10 +195,6 @@ class TFMBartModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
@tooslow
def test_saved_model_creation(self):
pass
@require_sentencepiece
@require_tokenizers


@@ -20,7 +20,7 @@ import unittest
from transformers import MobileBertConfig, is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow, tooslow
from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
@@ -311,15 +311,6 @@ class TFMobileBertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mobilebert_for_token_classification(*config_and_inputs)
@slow
def test_keras_fit(self):
# Override as it is a slow test on this model
super().test_keras_fit()
@tooslow
def test_saved_model_creation(self):
pass
@slow
def test_model_from_pretrained(self):
# for model_name in TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:


@@ -20,7 +20,7 @@ import unittest
import numpy as np
from transformers import OPTConfig, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, slow, tooslow
from transformers.testing_utils import require_sentencepiece, require_tf, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
@@ -219,10 +219,6 @@ class TFOPTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
models_equal = False
self.assertTrue(models_equal)
@tooslow
def test_saved_model_creation(self):
pass
def _long_tensor(tok_lst):
return tf.constant(tok_lst, dtype=tf.int32)


@@ -18,7 +18,7 @@ from __future__ import annotations
import unittest
from transformers import AutoTokenizer, PegasusConfig, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -206,10 +206,6 @@ class TFPegasusModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
@tooslow
def test_saved_model_creation(self):
pass
@require_sentencepiece
@require_tokenizers


@@ -490,6 +490,7 @@ class TFRagTestMixin:
inputs_dict = self.config_and_inputs
self.check_model_without_retriever(**inputs_dict)
@slow
def test_model_generate_from_context_input_ids(self):
inputs_dict = self.config_and_inputs
self.check_model_generate_from_context_input_ids(**inputs_dict)
@@ -498,6 +499,7 @@ class TFRagTestMixin:
inputs_dict = self.config_and_inputs
self.check_model_with_encoder_outputs(**inputs_dict)
@slow
def test_model_generate(self):
inputs_dict = self.config_and_inputs
self.check_model_generate(**inputs_dict)


@@ -148,6 +148,7 @@ class TFRegNetModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF does not support backprop for grouped convolutions on CPU.",
)
@slow
def test_keras_fit(self):
super().test_keras_fit()


@@ -347,6 +347,7 @@ class TFSegformerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Tes
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF does not support backprop for grouped convolutions on CPU.",
)
@slow
def test_keras_fit(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()


@@ -722,6 +722,10 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
self.assertTrue(models_equal)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
# Allow missing keys since TF doesn't cache the sinusoidal embeddings in an attribute
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
@require_torch
@require_torchaudio


@@ -558,6 +558,10 @@ class TFSpeech2TextModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.T
]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
# Allow missing keys since TF doesn't cache the sinusoidal embeddings in an attribute
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
@require_tf
@require_sentencepiece


@@ -23,7 +23,7 @@ import unittest
import numpy as np
from transformers import SwinConfig
from transformers.testing_utils import require_tf, require_vision, slow, to_2tuple, tooslow
from transformers.testing_utils import require_tf, require_vision, slow, to_2tuple
from transformers.utils import cached_property, is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -232,10 +232,6 @@ class TFSwinModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
def test_inputs_embeds(self):
pass
@tooslow
def test_saved_model_creation(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()


@@ -18,7 +18,7 @@ from __future__ import annotations
import unittest
from transformers import T5Config, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -300,10 +300,6 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)
@tooslow
def test_saved_model_creation(self):
pass
@slow
def test_model_from_pretrained(self):
model = TFT5Model.from_pretrained("t5-small")


@@ -1415,6 +1415,7 @@ class TFModelTesterMixin:
def check_keras_fit_results(self, val_loss1, val_loss2, atol=1e-2, rtol=1e-3):
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))
@slow
def test_keras_fit(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: