diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py
index 81ca97ab7be..fdd4c05d60e 100755
--- a/src/transformers/models/albert/modeling_albert.py
+++ b/src/transformers/models/albert/modeling_albert.py
@@ -20,6 +20,7 @@ from dataclasses import dataclass
 from typing import Optional, Tuple
 
 import torch
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -216,6 +217,12 @@ class AlbertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )
 
     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
@@ -231,8 +238,16 @@ class AlbertEmbeddings(nn.Module):
         if position_ids is None:
             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
 
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
 
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -687,6 +702,7 @@ class AlbertModel(AlbertPreTrainedModel):
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
             input_shape = input_ids.size()
+            batch_size, seq_length = input_shape
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
         else:
@@ -697,7 +713,12 @@ class AlbertModel(AlbertPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
         extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
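The same two-step pattern is applied to every model below: the constructor registers an all-zeros `token_type_ids` buffer next to `position_ids`, and `forward` falls back to a slice of that buffer, broadcast over the batch, whenever the caller does not pass `token_type_ids`. A minimal, self-contained sketch of the idiom (the `ToyEmbeddings` module and its sizes are illustrative, not taken from this PR):

```python
import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    """Illustrative sketch of the buffered token_type_ids idiom used throughout this PR."""

    def __init__(self, vocab_size=100, type_vocab_size=2, hidden_size=8, max_position_embeddings=512):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))
        # All-zeros default segment ids; persistent=False keeps the buffer out of the state_dict.
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(self, input_ids, token_type_ids=None):
        batch_size, seq_length = input_ids.size()
        if token_type_ids is None:
            # Slice the buffer to the current sequence length and broadcast it over the batch
            # instead of allocating a fresh torch.zeros tensor on every call.
            token_type_ids = self.token_type_ids[:, :seq_length].expand(batch_size, seq_length)
        return self.word_embeddings(input_ids) + self.token_type_embeddings(token_type_ids)


emb = ToyEmbeddings()
print(emb(torch.randint(0, 100, (2, 5))).shape)  # torch.Size([2, 5, 8])
```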
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 5c135da7efc..9606af37670 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -24,6 +24,7 @@ from typing import Optional, Tuple
 
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -176,10 +177,15 @@ class BertEmbeddings(nn.Module):
         # any TensorFlow checkpoint file
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )
 
     def forward(
         self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
@@ -194,8 +200,16 @@ class BertEmbeddings(nn.Module):
         if position_ids is None:
             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
 
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
 
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -936,8 +950,14 @@ class BertModel(BertPreTrainedModel):
 
         if attention_mask is None:
             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
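The `version.parse(torch.__version__) > version.parse("1.6.0")` guard is there because the `persistent` keyword argument of `register_buffer` does not exist in older PyTorch releases (it arrived around 1.6). A non-persistent buffer follows `.to(device)` and is visible to the tracer, but it is excluded from the `state_dict`, so existing pretrained checkpoints, which never contained a `token_type_ids` entry, keep loading without missing-key warnings. A small sketch of that behaviour (module name and sizes are illustrative):

```python
import torch
from torch import nn


class WithBuffers(nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer: written to and expected in the state_dict.
        self.register_buffer("position_ids", torch.arange(10).expand((1, -1)))
        # Non-persistent buffer: moves with .to()/.cuda() and is seen by tracing,
        # but is never serialized (requires a PyTorch version that supports persistent=False).
        self.register_buffer("token_type_ids", torch.zeros((1, 10), dtype=torch.long), persistent=False)


m = WithBuffers()
print(sorted(m.state_dict().keys()))  # ['position_ids'] -- no 'token_type_ids' key
m.load_state_dict({"position_ids": torch.arange(10).expand((1, -1))})  # loads cleanly, no missing keys
```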
diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py
index 8e11594cb1b..429ac39f86e 100755
--- a/src/transformers/models/big_bird/modeling_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_big_bird.py
@@ -23,6 +23,7 @@ from typing import Optional, Tuple
 
 import numpy as np
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -254,10 +255,15 @@ class BigBirdEmbeddings(nn.Module):
         # any TensorFlow checkpoint file
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )
         # End copy
 
         self.rescale_embeddings = config.rescale_embeddings
@@ -276,8 +282,16 @@ class BigBirdEmbeddings(nn.Module):
         if position_ids is None:
             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
 
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
 
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -2025,7 +2039,12 @@ class BigBirdModel(BigBirdPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
         # in order to use block_sparse attention, sequence_length has to be at least
         # bigger than all global attentions: 2 * block_size
diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py
index 84084d26b75..aa41b456763 100644
--- a/src/transformers/models/electra/modeling_electra.py
+++ b/src/transformers/models/electra/modeling_electra.py
@@ -21,6 +21,7 @@ from typing import Optional, Tuple
 
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -169,6 +170,12 @@ class ElectraEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )
 
     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
@@ -184,8 +191,16 @@ class ElectraEmbeddings(nn.Module):
         if position_ids is None:
             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
 
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
 
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -839,6 +854,7 @@ class ElectraModel(ElectraPreTrainedModel):
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
             input_shape = input_ids.size()
+            batch_size, seq_length = input_shape
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
         else:
@@ -849,7 +865,12 @@ class ElectraModel(ElectraPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index c1a22259ad4..787ae588ed6 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -19,6 +19,7 @@ import math
 
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -82,10 +83,15 @@ class RobertaEmbeddings(nn.Module):
         # any TensorFlow checkpoint file
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )
 
         # End copy
         self.padding_idx = config.pad_token_id
@@ -99,9 +105,7 @@ class RobertaEmbeddings(nn.Module):
         if position_ids is None:
             if input_ids is not None:
                 # Create the position ids from the input token ids. Any padded tokens remain padded.
-                position_ids = create_position_ids_from_input_ids(
-                    input_ids, self.padding_idx, past_key_values_length
-                ).to(input_ids.device)
+                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
             else:
                 position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
 
@@ -110,8 +114,18 @@ class RobertaEmbeddings(nn.Module):
         else:
             input_shape = inputs_embeds.size()[:-1]
 
+        seq_length = input_shape[1]
+
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
 
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -780,8 +794,14 @@ class RobertaModel(RobertaPreTrainedModel):
 
         if attention_mask is None:
             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
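A detail worth noting about the repeated `buffered_token_type_ids.expand(batch_size, seq_length)` line: `expand` returns a broadcasted view of the sliced buffer rather than copying it, so the fallback path no longer allocates a fresh `torch.zeros(input_shape)` tensor on every forward call. A quick check of that behaviour (shapes are arbitrary):

```python
import torch

buf = torch.zeros((1, 512), dtype=torch.long)  # stand-in for the registered buffer
view = buf[:, :7].expand(4, 7)                 # slice to seq_length, broadcast over the batch

print(view.shape)                              # torch.Size([4, 7])
print(view.data_ptr() == buf.data_ptr())       # True: same storage, no new allocation
```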
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
index 87a95e6b3b0..2b7bab9d689 100755
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
@@ -22,6 +22,7 @@ import os
 
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 
@@ -156,6 +157,12 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )
 
     def forward(
         self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
@@ -170,9 +177,17 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
         if position_ids is None:
             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
 
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
-
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -846,8 +861,14 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
 
         if attention_mask is None:
             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
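The user-visible effect of the whole change is that these models can be traced without hand-building `token_type_ids`: the default segment ids now come from a buffer that lives on the module (and therefore on the right device), instead of a `torch.zeros(..., device=device)` tensor created inside `forward` (issue #5664). A hedged usage sketch, assuming the standard `transformers` TorchScript flow; the checkpoint name is illustrative and not part of this diff:

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True).eval()

inputs = tokenizer("tracing without token_type_ids", return_tensors="pt")
# Only input_ids and attention_mask are passed; token_type_ids falls back to the
# all-zeros buffer registered on the embeddings module.
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))
traced.save("traced_bert.pt")
```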