Fix #17893, removed dead code (#17917)

* Removed dead position_ids code, fixes #17893

* Removed unused var

* Now ignores removed (dead) dict key for backward compatibility
Clémentine Fourrier 2022-06-29 23:54:26 +02:00 committed by GitHub
parent fbc7598bab
commit eb1493b15d

src/transformers/models/longformer/modeling_longformer.py

@@ -447,8 +447,6 @@ class LongformerEmbeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
 
         self.padding_idx = config.pad_token_id
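Note: the deleted buffer held nothing learned, just a constant row of position indices that, being a registered buffer, was serialized into every checkpoint. A minimal sketch of what it contained (the size here is illustrative, not Longformer's actual config value):

    import torch

    max_position_embeddings = 8  # illustrative value only
    position_ids = torch.arange(max_position_embeddings).expand((1, -1))
    # tensor([[0, 1, 2, 3, 4, 5, 6, 7]]), shape (1, max_position_embeddings)

Because older checkpoints still carry this serialized buffer, the loading change in the last hunk below is needed.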
@@ -469,13 +467,8 @@ class LongformerEmbeddings(nn.Module):
         else:
             input_shape = inputs_embeds.size()[:-1]
 
-        seq_length = input_shape[1]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=position_ids.device)
 
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
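Note: the removed `if position_ids is None` branch was unreachable. Earlier in the same forward method (outside this hunk), position_ids is always populated before execution reaches this point; roughly, paraphrasing the surrounding code rather than quoting this diff:

    # Sketch of the earlier code path in LongformerEmbeddings.forward that
    # makes the deleted branch dead: position_ids is never None down here.
    if position_ids is None:
        if input_ids is not None:
            # Derive position ids from the tokens; padding stays at padding_idx.
            position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
        else:
            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

With the dead branch gone, seq_length had no remaining use, hence "Removed unused var" above.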
@@ -1392,7 +1385,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
     config_class = LongformerConfig
     base_model_prefix = "longformer"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"position_ids"]
 
     def _init_weights(self, module):
         """Initialize the weights"""