remove final_logits_bias (#10606)

This commit is contained in:
Suraj Patil 2021-03-10 09:52:31 +05:30 committed by GitHub
parent 6f52fce673
commit 44f64132a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1153,7 +1153,6 @@ class M2M100Model(M2M100PreTrainedModel):
class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"encoder\.version",
r"decoder\.version",
r"lm_head\.weight",
@ -1168,7 +1167,6 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
def __init__(self, config: M2M100Config):
super().__init__(config)
self.model = M2M100Model(config)
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
self.init_weights()
@ -1181,18 +1179,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
old_num_tokens = self.final_logits_bias.shape[-1]
if new_num_tokens <= old_num_tokens:
new_bias = self.final_logits_bias[:, :new_num_tokens]
else:
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias)
def get_output_embeddings(self):
return self.lm_head
@ -1266,7 +1254,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
lm_logits = self.lm_head(outputs[0])
masked_lm_loss = None
if labels is not None: