mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
remove final_logits_bias (#10606)
parent 6f52fce673
commit 44f64132a5
@@ -1153,7 +1153,6 @@ class M2M100Model(M2M100PreTrainedModel):
 class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = [
-        r"final_logits_bias",
         r"encoder\.version",
         r"decoder\.version",
         r"lm_head\.weight",
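With the buffer gone, `final_logits_bias` no longer needs to be whitelisted in `_keys_to_ignore_on_load_missing`. As an illustrative sketch of that mechanism (a simplified stand-in, not the actual `from_pretrained()` loading code): keys matching any of these regexes are simply not reported as missing when a checkpoint lacks them.

```python
import re

# Illustrative only -- not the real transformers loading logic.
keys_to_ignore_on_load_missing = [r"encoder\.version", r"decoder\.version", r"lm_head\.weight"]

missing_keys = ["lm_head.weight", "model.encoder.layers.0.self_attn.k_proj.weight"]
# Suppress "missing key" reports for keys that match an ignore pattern.
reported = [k for k in missing_keys if not any(re.search(p, k) for p in keys_to_ignore_on_load_missing)]
print(reported)  # ['model.encoder.layers.0.self_attn.k_proj.weight']
```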
@@ -1168,7 +1167,6 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
     def __init__(self, config: M2M100Config):
         super().__init__(config)
         self.model = M2M100Model(config)
-        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
         self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

         self.init_weights()
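For context, `register_buffer` puts a tensor into the module's `state_dict` without making it a parameter, so the removed bias was saved in every checkpoint but never touched by the optimizer. A minimal, self-contained sketch (a toy module with made-up sizes, not M2M100 itself):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    # Toy module for illustration; d_model/vocab_size are arbitrary, not M2M100's config values.
    def __init__(self, d_model=16, vocab_size=32):
        super().__init__()
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.register_buffer("final_logits_bias", torch.zeros((1, vocab_size)))

toy = Toy()
print("final_logits_bias" in toy.state_dict())              # True: stored in checkpoints
print("final_logits_bias" in dict(toy.named_parameters()))  # False: never trained
```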
@@ -1181,18 +1179,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):

     def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
         new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
         return new_embeddings
-
-    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
-        old_num_tokens = self.final_logits_bias.shape[-1]
-        if new_num_tokens <= old_num_tokens:
-            new_bias = self.final_logits_bias[:, :new_num_tokens]
-        else:
-            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
-            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
-        self.register_buffer("final_logits_bias", new_bias)

     def get_output_embeddings(self):
         return self.lm_head

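The deleted helper only kept the zero bias in step with `resize_token_embeddings`: truncate it when the vocabulary shrinks, zero-pad it when the vocabulary grows. A standalone sketch of that behaviour (function and variable names here are illustrative):

```python
import torch

def resize_final_logits_bias(bias: torch.Tensor, new_num_tokens: int) -> torch.Tensor:
    # bias has shape (1, old_num_tokens); return a (1, new_num_tokens) version.
    old_num_tokens = bias.shape[-1]
    if new_num_tokens <= old_num_tokens:
        return bias[:, :new_num_tokens]                    # shrink: drop trailing entries
    extra = torch.zeros((1, new_num_tokens - old_num_tokens), device=bias.device)
    return torch.cat([bias, extra], dim=1)                 # grow: zero-pad new tokens

bias = torch.zeros((1, 10))
print(resize_final_logits_bias(bias, 12).shape)  # torch.Size([1, 12])
print(resize_final_logits_bias(bias, 8).shape)   # torch.Size([1, 8])
```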
@@ -1266,7 +1254,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+        lm_logits = self.lm_head(outputs[0])

         masked_lm_loss = None
         if labels is not None:
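Why this is safe to drop: the buffer was created as zeros and nothing ever wrote to it, so adding it to the logits was a no-op. A quick check with made-up shapes:

```python
import torch
import torch.nn as nn

d_model, vocab_size = 16, 32                      # arbitrary illustrative sizes
lm_head = nn.Linear(d_model, vocab_size, bias=False)
final_logits_bias = torch.zeros((1, vocab_size))  # the removed buffer: always all zeros

hidden = torch.randn(2, 5, d_model)               # (batch, seq_len, d_model)
# Adding an all-zero bias leaves the logits bitwise identical.
assert torch.equal(lm_head(hidden) + final_logits_bias, lm_head(hidden))
```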