TF MT5 embeddings resize (#15567)

* Fix TF MT5 vocab resize

* more assertive testing
Joao Gante 2022-02-11 17:35:10 +00:00 committed by GitHub
parent 8c03df1010
commit 2f40c728c9
3 changed files with 37 additions and 1 deletion

src/transformers/modeling_tf_utils.py

@@ -1135,6 +1135,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         return model_embeds

     def _get_word_embedding_weight(model, embedding_layer):
+        # If the variable holds the weights themselves, return them
+        if isinstance(embedding_layer, tf.Tensor):
+            return embedding_layer
+
+        # Otherwise, try to get them from the layer's attributes
         embeds = getattr(embedding_layer, "weight", None)
         if embeds is not None:
             return embeds
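As a minimal sketch of the probing order the added lines establish: the helper first checks whether it was handed the weight tensor itself, and only then falls back to attribute lookup. ToySharedEmbedding below is a hypothetical stand-in for a layer that exposes a `weight` attribute; the real method continues with further fallbacks not visible in this hunk.

import tensorflow as tf

def get_word_embedding_weight(embedding_layer):
    # If the variable holds the weights themselves, return them
    if isinstance(embedding_layer, tf.Tensor):
        return embedding_layer
    # Otherwise, try to get them from the layer's attributes
    return getattr(embedding_layer, "weight", None)

# Hypothetical stand-in for a shared-embedding layer exposing `.weight`
class ToySharedEmbedding(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.weight = self.add_weight(name="weight", shape=(8, 4))

layer = ToySharedEmbedding()
layer.build(None)

print(get_word_embedding_weight(tf.convert_to_tensor(layer.weight)).shape)  # (8, 4), tensor path
print(get_word_embedding_weight(layer).shape)  # (8, 4), attribute path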

tests/test_modeling_tf_mt5.py

@@ -22,7 +22,24 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
 if is_tf_available():
     import tensorflow as tf

-    from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
+    from transformers import AutoTokenizer, T5Tokenizer, TFAutoModelForSeq2SeqLM, TFMT5ForConditionalGeneration


 @require_tf
 class TFMT5ModelTest(unittest.TestCase):  # no mixin with common tests -> most cases are already covered in the TF T5
+    @slow
+    def test_resize_embeddings(self):
+        model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
+        original_vocab_size = model.get_input_embeddings().weight.shape[0]
+        # the vocab size is defined in the model config
+        self.assertEqual(original_vocab_size, model.config.vocab_size)
+
+        tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
+        tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""})
+        model._resize_token_embeddings(len(tokenizer))
+        # the vocab size is now resized to the length of the tokenizer, which is different from the original size
+        self.assertEqual(model.get_input_embeddings().weight.shape[0], len(tokenizer))
+        self.assertNotEqual(model.get_input_embeddings().weight.shape[0], original_vocab_size)
+

 @require_tf

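The tests call the private _resize_token_embeddings directly; outside the test suite the public entry point is resize_token_embeddings, which performs the same resize and additionally keeps config.vocab_size in sync. A typical usage sketch (the added token is an arbitrary example):

from transformers import T5Tokenizer, TFMT5ForConditionalGeneration

model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
tokenizer.add_tokens(["<my_new_token>"])  # arbitrary example token
# public wrapper around _resize_token_embeddings; also updates config.vocab_size
model.resize_token_embeddings(len(tokenizer))
assert model.get_input_embeddings().weight.shape[0] == len(tokenizer)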
tests/test_modeling_tf_t5.py

@@ -314,6 +314,20 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         # TODO: Fix head-masking according to PyTorch T5 model
         pass

+    @slow
+    def test_resize_embeddings(self):
+        model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
+        original_vocab_size = model.get_input_embeddings().weight.shape[0]
+        # the vocab size is defined in the model config
+        self.assertEqual(original_vocab_size, model.config.vocab_size)
+
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+        tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""})
+        model._resize_token_embeddings(len(tokenizer))
+        # the vocab size is now resized to the length of the tokenizer, which is different from the original size
+        self.assertEqual(model.get_input_embeddings().weight.shape[0], len(tokenizer))
+        self.assertNotEqual(model.get_input_embeddings().weight.shape[0], original_vocab_size)
+

 class TFT5EncoderOnlyModelTester:
     def __init__(
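Both new tests are marked @slow, so the transformers test runner skips them by default and only executes them when the RUN_SLOW=1 environment variable is set. A sketch of invoking them locally, assuming the test file paths from the repository layout at the time of this commit:

import os
import pytest

# @slow tests in transformers are skipped unless RUN_SLOW=1 is set
os.environ["RUN_SLOW"] = "1"
pytest.main(["-k", "test_resize_embeddings",
             "tests/test_modeling_tf_mt5.py", "tests/test_modeling_tf_t5.py"])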