fix pegasus init weights and other copied models (#36844)

* fix pegasus init weights

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix the rest of models

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix test

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix informer init

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* init weight before checking

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix roformer tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix roformer tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng 2025-03-27 21:14:30 +08:00 committed by GitHub
parent 7e813f9cf0
commit 0e56fb69a2
9 changed files with 27 additions and 12 deletions

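For context: each model touched here defines a sinusoidal positional embedding whose weight table is computed deterministically rather than learned. The table used to be built in __init__; when loading with from_pretrained(..., device_map="auto"), the module is first constructed on the meta device, so values written during __init__ never materialize and the table was left uninitialized. Moving the call into each model's _init_weights hook makes the table get (re)built at the point where weights are actually materialized. For reference, a sketch of the _init_weight helper these hunks call, following the shape of the transformers implementation (the per-model copies may differ in small details):

    import numpy as np
    import torch
    from torch import nn

    class SinusoidalPositionalEmbedding(nn.Embedding):
        # Stand-in for the per-model classes touched in this commit.

        @staticmethod
        def _init_weight(out: nn.Parameter) -> nn.Parameter:
            # Fill the table in place: sin features in the first half of each
            # row, cos features in the second half (not interleaved).
            n_pos, dim = out.shape
            position_enc = np.array(
                [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
            )
            out.requires_grad = False  # deterministic table, never trained
            sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
            out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
            out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
            out.detach_()
            return out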

@@ -360,7 +360,6 @@ class AutoformerSinusoidalPositionalEmbedding(nn.Embedding):
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
         super().__init__(num_positions, embedding_dim)
-        self.weight = self._init_weight(self.weight)
 
     @staticmethod
     def _init_weight(out: nn.Parameter) -> nn.Parameter:
@@ -904,7 +903,7 @@ class AutoformerPreTrainedModel(PreTrainedModel):
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, AutoformerSinusoidalPositionalEmbedding):
-            pass
+            module.weight = module._init_weight(module.weight)
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:


@@ -233,7 +233,6 @@ class InformerSinusoidalPositionalEmbedding(nn.Embedding):
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
         super().__init__(num_positions, embedding_dim)
-        self.weight = self._init_weight(self.weight)
 
     @staticmethod
     def _init_weight(out: nn.Parameter) -> nn.Parameter:
@@ -887,7 +886,9 @@ class InformerPreTrainedModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
                 module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding) and not isinstance(module, InformerSinusoidalPositionalEmbedding):
+        elif isinstance(module, InformerSinusoidalPositionalEmbedding):
+            module.weight = module._init_weight(module.weight)
+        elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

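One subtlety in the Informer hunk above: InformerSinusoidalPositionalEmbedding subclasses nn.Embedding, so the generic nn.Embedding branch of _init_weights matches it too. The old code excluded it with an "and not isinstance(...)" guard and then never initialized it at all; the new code tests the subclass first and rebuilds its table. A minimal, self-contained illustration of why the specific check must precede the generic one in the elif chain:

    import torch.nn as nn

    class Sinusoidal(nn.Embedding):  # stand-in for InformerSinusoidalPositionalEmbedding
        pass

    module = Sinusoidal(10, 4)
    # Both checks succeed for the subclass, so an isinstance chain must test
    # the more specific type first or the generic branch shadows it.
    print(isinstance(module, Sinusoidal))    # True
    print(isinstance(module, nn.Embedding))  # True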

@@ -73,7 +73,6 @@ class MarianSinusoidalPositionalEmbedding(nn.Embedding):
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
         super().__init__(num_positions, embedding_dim)
-        self.weight = self._init_weight(self.weight)
 
     @staticmethod
     def _init_weight(out: nn.Parameter) -> nn.Parameter:
@@ -468,7 +467,7 @@ class MarianPreTrainedModel(PreTrainedModel):
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, MarianSinusoidalPositionalEmbedding):
-            pass
+            module.weight = module._init_weight(module.weight)
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:


@@ -74,7 +74,6 @@ class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
         super().__init__(num_positions, embedding_dim)
-        self.weight = self._init_weight(self.weight)
 
     @staticmethod
     def _init_weight(out: nn.Parameter) -> nn.Parameter:
@@ -469,7 +468,7 @@ class PegasusPreTrainedModel(PreTrainedModel):
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, PegasusSinusoidalPositionalEmbedding):
-            pass
+            module.weight = module._init_weight(module.weight)
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
@@ -665,6 +664,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
             self.config.d_model,
             self.padding_idx,
         )
+        self.embed_positions.weight = self.embed_positions._init_weight(self.embed_positions.weight)
         self.embed_positions.to(self.device)
 
     def get_position_embeddings(self) -> nn.Embedding:
@@ -868,6 +868,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
             self.config.d_model,
             self.padding_idx,
         )
+        self.embed_positions.weight = self.embed_positions._init_weight(self.embed_positions.weight)
         self.embed_positions.to(self.device)
 
     def get_position_embeddings(self) -> nn.Embedding:

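Because construction no longer initializes the table, the two Pegasus hunks above re-apply _init_weight inside resize_position_embeddings for both the encoder and the decoder; without that, a resized model would keep nn.Embedding's random default values in its new position table. A hypothetical usage sketch (checkpoint name assumed) exercising that path:

    from transformers import PegasusForConditionalGeneration

    model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
    # Rebuilds the encoder/decoder position tables at the new size; with this
    # fix the fresh tables are sinusoidal again rather than randomly initialized.
    model.resize_position_embeddings(new_num_position_embeddings=1024)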

@@ -59,7 +59,6 @@ class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
         super().__init__(num_positions, embedding_dim)
-        self.weight = self._init_weight(self.weight)
 
     @staticmethod
     def _init_weight(out: nn.Parameter) -> nn.Parameter:
@@ -694,7 +693,7 @@ class RoFormerPreTrainedModel(PreTrainedModel):
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, RoFormerSinusoidalPositionalEmbedding):
-            pass
+            module.weight = module._init_weight(module.weight)
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:


@@ -233,7 +233,6 @@ class TimeSeriesSinusoidalPositionalEmbedding(nn.Embedding):
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
         super().__init__(num_positions, embedding_dim)
-        self.weight = self._init_weight(self.weight)
 
     @staticmethod
     def _init_weight(out: nn.Parameter) -> nn.Parameter:
@@ -641,7 +640,7 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel):
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, TimeSeriesSinusoidalPositionalEmbedding):
-            pass
+            module.weight = module._init_weight(module.weight)
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:


@@ -171,6 +171,7 @@ class InformerModelTester:
         embed_positions = InformerSinusoidalPositionalEmbedding(
             config.context_length + config.prediction_length, config.d_model
         ).to(torch_device)
+        embed_positions.weight = embed_positions._init_weight(embed_positions.weight)
         self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight))
         self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight))


@@ -348,6 +348,19 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
     def model(self):
         return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint_name).to(torch_device)
 
+    @slow
+    def test_device_map(self):
+        model_no_device_map = AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint_name).to(torch_device)
+        model_with_device_map = AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint_name, device_map="auto")
+        assert torch.equal(
+            model_no_device_map.model.decoder.embed_positions.weight,
+            model_with_device_map.model.decoder.embed_positions.weight,
+        )
+        assert torch.equal(
+            model_no_device_map.model.encoder.embed_positions.weight,
+            model_with_device_map.model.encoder.embed_positions.weight,
+        )
+
     @slow
     @require_torch_fp16
     def test_pegasus_xsum_summary(self):

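The new test_device_map above pins down the original symptom end to end: a device_map="auto" load goes through meta-device construction, so before this fix its positional tables disagreed with those of a plain .to(device) load. The same check can be run standalone; a sketch, assuming the XSUM checkpoint this test class uses:

    import torch
    from transformers import AutoModelForSeq2SeqLM

    name = "google/pegasus-xsum"  # assumed value of self.checkpoint_name
    plain = AutoModelForSeq2SeqLM.from_pretrained(name)
    mapped = AutoModelForSeq2SeqLM.from_pretrained(name, device_map="auto")
    # With the fix, both load paths yield identical sinusoidal tables.
    assert torch.equal(
        plain.model.encoder.embed_positions.weight.cpu(),
        mapped.model.encoder.embed_positions.weight.cpu(),
    )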

@@ -534,6 +534,7 @@ class RoFormerSinusoidalPositionalEmbeddingTest(unittest.TestCase):
     def test_basic(self):
         input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device)
         emb1 = RoFormerSinusoidalPositionalEmbedding(num_positions=6, embedding_dim=6).to(torch_device)
+        emb1.weight = emb1._init_weight(emb1.weight)
         emb = emb1(input_ids.shape)
         desired_weights = torch.tensor(
             [[0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 1.0000], [0.8415, 0.0464, 0.0022, 0.5403, 0.9989, 1.0000]]
@@ -552,6 +553,7 @@
             ]
         ).to(torch_device)
         emb1 = RoFormerSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512).to(torch_device)
+        emb1.weight = emb1._init_weight(emb1.weight)
         weights = emb1.weight.data[:3, :5].to(torch_device)
         self.assertTrue(
@@ -573,6 +575,7 @@ class RoFormerSelfAttentionRotaryPositionEmbeddingTest(unittest.TestCase):
             -torch.arange(2 * 12 * 16 * 64, dtype=torch.float, device=torch_device).reshape(2, 12, 16, 64) / 100
         ).to(torch_device)
         embed_positions = RoFormerSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=64).to(torch_device)
+        embed_positions.weight = embed_positions._init_weight(embed_positions.weight)
         sinusoidal_pos = embed_positions([2, 16, 768])[None, None, :, :]
         query_layer, key_layer = RoFormerSelfAttention.apply_rotary_position_embeddings(