Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Persist embedding type of BART and mBART models after resize (#32242)
* fix: persist embedding type of MBartForConditionalGeneration after resize
* fix: persist embedding type of BartForConditionalGeneration after resize
parent f5f1e52f6c
commit baf7e5c927
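For context, the scaled word embedding classes referenced in the hunks below wrap nn.Embedding and fold the sqrt(d_model) factor into the lookup itself. A minimal sketch of the idea, with ScaledWordEmbedding as a stand-in name for BartScaledWordEmbedding / MBartScaledWordEmbedding (this is an illustration, not the library's actual implementation):

import torch
import torch.nn as nn


class ScaledWordEmbedding(nn.Embedding):
    """Embedding lookup that multiplies its output by a fixed scale.

    Stand-in for BartScaledWordEmbedding / MBartScaledWordEmbedding: the
    scale is sqrt(d_model) when config.scale_embedding is set, else 1.0.
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx, embed_scale=1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids):
        # Scale the looked-up vectors; a plain nn.Embedding would skip this.
        return super().forward(input_ids) * self.embed_scale

Because the scale lives on the module rather than in the encoder/decoder forward passes, any code path that swaps out the embedding table (such as resize_token_embeddings) must hand back an instance of the same class, or the scaling silently disappears. That is the regression this commit guards against.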
@@ -1431,7 +1431,8 @@ class BartModel(BartPreTrainedModel):
         super().__init__(config)
 
         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
-        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        self.shared = BartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
 
         self.encoder = BartEncoder(config, self.shared)
         self.decoder = BartDecoder(config, self.shared)
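The fix targets self.shared rather than the encoder or decoder because, in the usual transformers encoder-decoder pattern, both halves share one table and set_input_embeddings rebinds them to whatever object it is handed. A simplified sketch of that wiring (a toy module under that assumption, not the actual BartModel code):

import torch.nn as nn


class ToySharedEmbeddingModel(nn.Module):
    """Toy illustration of shared-embedding wiring in a BART-style model."""

    def __init__(self, vocab_size, d_model, padding_idx):
        super().__init__()
        # If `shared` is a plain nn.Embedding, its resized replacement will
        # be plain too, and any scaled subclass is lost downstream.
        self.shared = nn.Embedding(vocab_size, d_model, padding_idx)
        self.encoder_embed_tokens = self.shared  # tied, not copied
        self.decoder_embed_tokens = self.shared

    def set_input_embeddings(self, value):
        # Both halves are rebound to `value`, so its type is what the
        # encoder and decoder see after any resize.
        self.shared = value
        self.encoder_embed_tokens = self.shared
        self.decoder_embed_tokens = self.shared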
@@ -1271,7 +1271,8 @@ class MBartModel(MBartPreTrainedModel):
         super().__init__(config)
 
         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
-        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        self.shared = MBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
 
         self.encoder = MBartEncoder(config, self.shared)
         self.decoder = MBartDecoder(config, self.shared)
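The embed_scale expression is the standard Transformer embedding scale. For a quick sense of the magnitude involved:

import math

# scale_embedding=True with d_model=1024 (the usual mBART width, used here
# as an illustrative value, not one read from this diff) gives:
embed_scale = math.sqrt(1024)  # 32.0
# With scale_embedding=False the factor is 1.0 and the lookup is unscaled.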
@@ -518,6 +518,18 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_load_save_without_tied_weights(self):
         pass
 
+    def test_resize_embeddings_persists_embeddings_type(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+
+        config.scale_embedding = True
+        model = BartForConditionalGeneration(config)
+        old_type = type(model.model.decoder.embed_tokens)
+
+        model.resize_token_embeddings(new_num_tokens=config.vocab_size)
+
+        new_type = type(model.model.decoder.embed_tokens)
+        self.assertIs(old_type, new_type)
+
 
 def assert_tensors_close(a, b, atol=1e-12, prefix=""):
     """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
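The new test can also be exercised outside the test harness. A standalone version, using an arbitrary tiny config in place of the suite's model_tester (the config values below are illustrative, chosen only to keep the model small):

from transformers import BartConfig, BartForConditionalGeneration

config = BartConfig(
    vocab_size=99,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    scale_embedding=True,  # routes construction through BartScaledWordEmbedding
)
model = BartForConditionalGeneration(config)
old_type = type(model.model.decoder.embed_tokens)

# Growing the vocabulary forces the embedding table to be rebuilt.
model.resize_token_embeddings(new_num_tokens=config.vocab_size + 8)

assert type(model.model.decoder.embed_tokens) is old_type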
@@ -375,6 +375,18 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_load_save_without_tied_weights(self):
         pass
 
+    def test_resize_embeddings_persists_embeddings_type(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+
+        config.scale_embedding = True
+        model = MBartForConditionalGeneration(config)
+        old_type = type(model.model.decoder.embed_tokens)
+
+        model.resize_token_embeddings(new_num_tokens=config.vocab_size)
+
+        new_type = type(model.model.decoder.embed_tokens)
+        self.assertIs(old_type, new_type)
+
 
 def assert_tensors_close(a, b, atol=1e-12, prefix=""):
    """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""