diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 5e37a1818df..428af92b2b3 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -1135,6 +1135,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         return model_embeds
 
     def _get_word_embedding_weight(model, embedding_layer):
+        # If the variable holds the weights themselves, return them
+        if isinstance(embedding_layer, tf.Tensor):
+            return embedding_layer
+        # Otherwise, try to get them from the layer's attributes
+
         embeds = getattr(embedding_layer, "weight", None)
         if embeds is not None:
             return embeds
diff --git a/tests/test_modeling_tf_mt5.py b/tests/test_modeling_tf_mt5.py
index 9b23e05f752..1ab1a635b39 100644
--- a/tests/test_modeling_tf_mt5.py
+++ b/tests/test_modeling_tf_mt5.py
@@ -22,7 +22,24 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
 if is_tf_available():
     import tensorflow as tf
 
-    from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
+    from transformers import AutoTokenizer, T5Tokenizer, TFAutoModelForSeq2SeqLM, TFMT5ForConditionalGeneration
+
+
+@require_tf
+class TFMT5ModelTest(unittest.TestCase):  # no mixin with common tests -> most cases are already covered in the TF T5 tests
+    @slow
+    def test_resize_embeddings(self):
+        model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
+        original_vocab_size = model.get_input_embeddings().weight.shape[0]
+        # the vocab size is defined in the model config
+        self.assertEqual(original_vocab_size, model.config.vocab_size)
+
+        tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
+        tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""})
+        model._resize_token_embeddings(len(tokenizer))
+        # the vocab size is now resized to the length of the tokenizer, which is different from the original size
+        self.assertEqual(model.get_input_embeddings().weight.shape[0], len(tokenizer))
+        self.assertNotEqual(model.get_input_embeddings().weight.shape[0], original_vocab_size)
 
 
 @require_tf
diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py
index 59ee70c53ec..67e780f24c2 100644
--- a/tests/test_modeling_tf_t5.py
+++ b/tests/test_modeling_tf_t5.py
@@ -314,6 +314,20 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         # TODO: Fix head-masking according to PyTorch T5 model
         pass
 
+    @slow
+    def test_resize_embeddings(self):
+        model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
+        original_vocab_size = model.get_input_embeddings().weight.shape[0]
+        # the vocab size is defined in the model config
+        self.assertEqual(original_vocab_size, model.config.vocab_size)
+
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+        tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""})
+        model._resize_token_embeddings(len(tokenizer))
+        # the vocab size is now resized to the length of the tokenizer, which is different from the original size
+        self.assertEqual(model.get_input_embeddings().weight.shape[0], len(tokenizer))
+        self.assertNotEqual(model.get_input_embeddings().weight.shape[0], original_vocab_size)
+
 
 class TFT5EncoderOnlyModelTester:
     def __init__(
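
For context, the change to `_get_word_embedding_weight` simply returns the embeddings directly when they are already a `tf.Tensor`, instead of looking for a `.weight` attribute, which is what lets embedding resizing work for TF T5/mT5. Below is a minimal usage sketch (not part of the PR) of the behaviour the new tests exercise, using the public `resize_token_embeddings` API instead of the private `_resize_token_embeddings`; the `<extra_marker>` token is a made-up example, and the snippet assumes a working TF + `transformers` install with access to the `t5-small` checkpoint:

```python
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

# Before resizing, the embedding matrix matches the config's vocab size.
assert model.get_input_embeddings().weight.shape[0] == model.config.vocab_size

# Add a hypothetical extra token, then resize the embedding matrix to the tokenizer's length.
tokenizer.add_special_tokens({"additional_special_tokens": ["<extra_marker>"]})
model.resize_token_embeddings(len(tokenizer))

# After resizing, the embedding matrix tracks the (now larger) tokenizer vocabulary.
assert model.get_input_embeddings().weight.shape[0] == len(tokenizer)
```

The new tests are marked `@slow`, so they only run when slow tests are enabled, e.g. `RUN_SLOW=1 pytest tests/test_modeling_tf_t5.py -k test_resize_embeddings`.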