From eb1493b15db2019c93e365219b517fb44e313aaf Mon Sep 17 00:00:00 2001
From: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com>
Date: Wed, 29 Jun 2022 23:54:26 +0200
Subject: [PATCH] Fix #17893, removed dead code (#17917)

* Removed dead position_id code, fix #17893

* Removed unused var

* Now ignores removed (dead) dict key for backward comp
---
 .../models/longformer/modeling_longformer.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py
index 797e744bbf7..0fefb93501a 100755
--- a/src/transformers/models/longformer/modeling_longformer.py
+++ b/src/transformers/models/longformer/modeling_longformer.py
@@ -447,8 +447,6 @@ class LongformerEmbeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
 
         self.padding_idx = config.pad_token_id
@@ -469,13 +467,8 @@ class LongformerEmbeddings(nn.Module):
         else:
             input_shape = inputs_embeds.size()[:-1]
 
-        seq_length = input_shape[1]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=position_ids.device)
 
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -1392,7 +1385,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
     config_class = LongformerConfig
     base_model_prefix = "longformer"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"position_ids"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
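
Editor note (not part of the patch): the removed register_buffer("position_ids", ...) fallback was dead because LongformerEmbeddings.forward fills position_ids earlier from input_ids and the padding index (via create_position_ids_from_input_ids in the same file), and switching from _keys_to_ignore_on_load_missing to _keys_to_ignore_on_load_unexpected lets older checkpoints that still serialize the buffer load without warnings about the now-unexpected key. Below is a minimal standalone sketch of that position-id logic, using a hypothetical input batch and pad_token_id = 1; it is an illustration of the helper's behavior, not code from the patch.

import torch

def create_position_ids_from_input_ids(input_ids: torch.Tensor, padding_idx: int) -> torch.Tensor:
    # Non-pad tokens get positions counting up from padding_idx + 1;
    # padded slots keep padding_idx, mirroring the helper of the same
    # name in modeling_longformer.py.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    return incremental_indices.long() + padding_idx

input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])  # hypothetical batch; 1 is the pad token
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 5, 1, 1]])

Because position_ids is always populated this way before the embeddings look anything up, the buffer-based branch removed by this patch could never be reached, which is why dropping it is safe.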