TF MT5 embeddings resize (#15567)

* Fix TF MT5 vocab resize
* more assertive testing
Commit 2f40c728c9 (parent 8c03df1010)
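In user terms, the change makes embedding resizing work for TF MT5 when the tokenizer length differs from the checkpoint's vocab size. A rough sketch of that flow, using the public resize_token_embeddings API rather than the private helper the new tests call (checkpoint names as in the tests below):

from transformers import T5Tokenizer, TFMT5ForConditionalGeneration

model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

# Resize the token embedding matrix to match the tokenizer
model.resize_token_embeddings(len(tokenizer))

# The input embeddings now have one row per tokenizer entry
assert model.get_input_embeddings().weight.shape[0] == len(tokenizer)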
@@ -1135,6 +1135,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         return model_embeds
 
     def _get_word_embedding_weight(model, embedding_layer):
+        # If the variable holds the weights themselves, return them
+        if isinstance(embedding_layer, tf.Tensor):
+            return embedding_layer
+        # Otherwise, try to get them from the layer's attributes
+
         embeds = getattr(embedding_layer, "weight", None)
         if embeds is not None:
             return embeds
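The hunk above makes the weight lookup in _get_word_embedding_weight tolerate being handed the embedding matrix directly as a tensor instead of a layer object, which is presumably what happens for TF MT5 during resizing. A minimal standalone sketch of that lookup order (get_embedding_matrix is a hypothetical name, not the library helper):

import tensorflow as tf

def get_embedding_matrix(embedding_layer):
    # A plain tensor already is the embedding matrix, so return it unchanged
    if isinstance(embedding_layer, tf.Tensor):
        return embedding_layer
    # Otherwise the matrix is wrapped in a layer, typically under `weight`
    return getattr(embedding_layer, "weight", None)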
@@ -22,7 +22,24 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
 if is_tf_available():
     import tensorflow as tf
 
-    from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
+    from transformers import AutoTokenizer, T5Tokenizer, TFAutoModelForSeq2SeqLM, TFMT5ForConditionalGeneration
 
 
+@require_tf
+class TFMT5ModelTest(unittest.TestCase):  # no mixin with common tests -> most cases are already covered in the TF T5
+    @slow
+    def test_resize_embeddings(self):
+        model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
+        original_vocab_size = model.get_input_embeddings().weight.shape[0]
+        # the vocab size is defined in the model config
+        self.assertEqual(original_vocab_size, model.config.vocab_size)
+
+        tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
+        tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""})
+        model._resize_token_embeddings(len(tokenizer))
+        # the vocab size is now resized to the length of the tokenizer, which is different from the original size
+        self.assertEqual(model.get_input_embeddings().weight.shape[0], len(tokenizer))
+        self.assertNotEqual(model.get_input_embeddings().weight.shape[0], original_vocab_size)
+
+
 @require_tf
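The new MT5 test is decorated with @slow, so it is skipped unless slow tests are enabled; in the transformers test suite that is controlled by the RUN_SLOW environment variable. A tiny self-contained illustration of the decorator's behaviour (DemoTest is made up for this example, not part of the PR):

import os
os.environ["RUN_SLOW"] = "1"  # must be set before transformers.testing_utils is imported

import unittest
from transformers.testing_utils import slow

class DemoTest(unittest.TestCase):
    @slow  # without RUN_SLOW=1 this test would be reported as skipped
    def test_placeholder(self):
        self.assertTrue(True)

if __name__ == "__main__":
    unittest.main()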
@@ -314,6 +314,20 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         # TODO: Fix head-masking according to PyTorch T5 model
         pass
 
+    @slow
+    def test_resize_embeddings(self):
+        model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
+        original_vocab_size = model.get_input_embeddings().weight.shape[0]
+        # the vocab size is defined in the model config
+        self.assertEqual(original_vocab_size, model.config.vocab_size)
+
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+        tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""})
+        model._resize_token_embeddings(len(tokenizer))
+        # the vocab size is now resized to the length of the tokenizer, which is different from the original size
+        self.assertEqual(model.get_input_embeddings().weight.shape[0], len(tokenizer))
+        self.assertNotEqual(model.get_input_embeddings().weight.shape[0], original_vocab_size)
+
 
 class TFT5EncoderOnlyModelTester:
     def __init__(
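Both resize tests rely on len(tokenizer) differing from model.config.vocab_size for these checkpoints (the checkpoint embedding matrices are padded beyond the tokenizer vocabulary), which is what makes the final assertNotEqual meaningful. A quick way to see this; the sizes in the comments are assumptions, not taken from the diff:

from transformers import T5Tokenizer, TFT5ForConditionalGeneration

model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

print(model.config.vocab_size)  # expected 32128 for t5-small (padded matrix)
print(len(tokenizer))           # expected 32100 (32000 SentencePiece pieces + 100 extra_ids)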