Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Fix for the issue of device-id getting hardcoded for token_type_ids during Tracing [WIP] (#11252)
* registering a buffer for token_type_ids, to fix the error of the device id getting hardcoded when tracing
* style format
* adding the persistent flag to the registered buffers, which keeps them out of the state_dict and addresses the backward-compatibility issue
* adding the try/catch to the fix, as the persistent flag is only available from PT > 1.6
* adding version check
* added the condition to only use the token_type_ids buffer when it is auto-generated, not passed by the user
* adding comments and making the condition where token_type_ids are None use the registered buffer
* taking position-embedding out of the if block
* adding comments
* handling the case where the buffer for position_ids was not registered
* reverted the changes on position_ids, fixed the issue with the size of the token_type_ids buffer, moved the modification for generated token_type_ids to BertModel instead of Embeddings
* reverting the token_type_ids in the None case to the previous version
* reverting changes on position_ids, adding back the if block
* changes added by running make fix-copies
* changes added by running make fix-copies and added the version import as it was getting used
* changes added by running make fix-copies
* changes added by running make fix-copies
* fixing the import format
* fixing the import format
* modified to use a temp tensor for the trimmed and expanded token_type_ids buffer
* changes made by fix-copies after temp tensor modifications
* changes made by fix-copies after temp tensor modifications
* changes made by fix-copies after temp tensor modifications
* clean up
* clean up
* clean up
* clean up
* Nit
* Nit
* Nit
* modified to support device conversion on traced models
* modified to support device conversion on traced models
* modified to support device conversion on traced models
* modified to support device conversion on traced models
* changes based on latest in master
* Adapt templates
* Add version import

Co-authored-by: Ubuntu <ubuntu@ip-172-31-32-81.us-west-2.compute.internal>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
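For orientation, here is the pattern the diff below applies to each model's embeddings module, condensed into a minimal, self-contained sketch (the class name and simplified forward are illustrative, not the actual transformers code): a non-persistent token_type_ids buffer is registered next to position_ids, and the forward pass slices and broadcasts that buffer instead of materializing zeros with an eagerly captured device, which is what previously baked a fixed device id into traced graphs.

import torch
from packaging import version
from torch import nn


class EmbeddingsSketch(nn.Module):
    # Minimal sketch of the PR's approach; names and the simplified forward are illustrative.
    def __init__(self, max_position_embeddings=512):
        super().__init__()
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))
        # persistent=False keeps the buffer out of the state_dict (so older checkpoints still load);
        # the flag only exists in PyTorch > 1.6, hence the version guard.
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long),
                persistent=False,
            )

    def forward(self, input_ids, token_type_ids=None):
        batch_size, seq_length = input_ids.shape
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # Slice and broadcast the registered buffer; because the buffer moves with the
                # module, no concrete device id is captured when the model is traced.
                token_type_ids = self.token_type_ids[:, :seq_length].expand(batch_size, seq_length)
            else:
                # Fallback for PyTorch <= 1.6, where the non-persistent buffer is not registered.
                token_type_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=self.position_ids.device)
        return token_type_ids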
This commit is contained in:
parent
0d97ba8a98
commit
af6e01c5bc
@@ -20,6 +20,7 @@ from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -216,6 +217,12 @@ class AlbertEmbeddings(nn.Module):
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
                persistent=False,
            )

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
    def forward(

@@ -231,8 +238,16 @@ class AlbertEmbeddings(nn.Module):
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

@@ -687,6 +702,7 @@ class AlbertModel(AlbertPreTrainedModel):
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:

@@ -697,7 +713,12 @@ class AlbertModel(AlbertPreTrainedModel):
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
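A quick way to see why the persistent=False flag in the hunk above preserves backward compatibility (an illustrative check with a made-up module, not repository code): non-persistent buffers are excluded from the state_dict, so checkpoints saved before this change keep loading cleanly and newly saved checkpoints gain no extra keys.

import torch
from torch import nn


class Embeddings(nn.Module):
    # Made-up module for the check; sizes are arbitrary.
    def __init__(self):
        super().__init__()
        self.register_buffer("position_ids", torch.arange(512).expand((1, -1)))  # persistent (default)
        self.register_buffer("token_type_ids", torch.zeros((1, 512), dtype=torch.long), persistent=False)


module = Embeddings()
print("position_ids" in module.state_dict())    # True  -> serialized as before
print("token_type_ids" in module.state_dict())  # False -> absent from checkpoints, old ones still load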
@@ -24,6 +24,7 @@ from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -176,10 +177,15 @@ class BertEmbeddings(nn.Module):
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
                persistent=False,
            )

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0

@@ -194,8 +200,16 @@ class BertEmbeddings(nn.Module):
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

@@ -936,8 +950,14 @@ class BertModel(BertPreTrainedModel):

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
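The tracing problem behind the replaced torch.zeros(..., device=...) lines can be reproduced with a toy module (illustrative code, not from the repository): torch.jit.trace records the eagerly read device as a constant in the graph, so a model traced on CPU kept creating the auto-generated token_type_ids on CPU even after the traced module was moved to another device.

import torch
from torch import nn


class OldStyle(nn.Module):
    # Toy reproduction of the pre-fix behaviour.
    def __init__(self):
        super().__init__()
        self.register_buffer("position_ids", torch.arange(16).expand((1, -1)))

    def forward(self, input_ids):
        # The device is read eagerly here, so tracing bakes it into the graph.
        return torch.zeros(input_ids.shape, dtype=torch.long, device=self.position_ids.device)


traced = torch.jit.trace(OldStyle(), torch.ones(2, 8, dtype=torch.long))
print(traced.graph)  # the zeros call carries the trace-time device as a constant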
@@ -23,6 +23,7 @@ from typing import Optional, Tuple
import numpy as np
import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -254,10 +255,15 @@ class BigBirdEmbeddings(nn.Module):
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
                persistent=False,
            )
        # End copy

        self.rescale_embeddings = config.rescale_embeddings

@@ -276,8 +282,16 @@ class BigBirdEmbeddings(nn.Module):
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

@@ -2025,7 +2039,12 @@ class BigBirdModel(BigBirdPreTrainedModel):
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # in order to use block_sparse attention, sequence_length has to be at least
        # bigger than all global attentions: 2 * block_size
@@ -21,6 +21,7 @@ from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -169,6 +170,12 @@ class ElectraEmbeddings(nn.Module):
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
                persistent=False,
            )

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
    def forward(

@@ -184,8 +191,16 @@ class ElectraEmbeddings(nn.Module):
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

@@ -839,6 +854,7 @@ class ElectraModel(ElectraPreTrainedModel):
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:

@@ -849,7 +865,12 @@ class ElectraModel(ElectraPreTrainedModel):
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
@@ -19,6 +19,7 @@ import math

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -82,10 +83,15 @@ class RobertaEmbeddings(nn.Module):
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
                persistent=False,
            )

        # End copy
        self.padding_idx = config.pad_token_id

@@ -99,9 +105,7 @@ class RobertaEmbeddings(nn.Module):
        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(
                    input_ids, self.padding_idx, past_key_values_length
                ).to(input_ids.device)
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

@@ -110,8 +114,18 @@ class RobertaEmbeddings(nn.Module):
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

@@ -780,8 +794,14 @@ class RobertaModel(RobertaPreTrainedModel):

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
@@ -22,6 +22,7 @@ import os

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

@@ -156,6 +157,12 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
                persistent=False,
            )

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0

@@ -170,9 +177,17 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

@@ -846,8 +861,14 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
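An illustrative end-to-end check of what the change enables (it assumes the usual torchscript workflow and the bert-base-uncased checkpoint; this snippet is not part of the commit): a model traced on CPU without token_type_ids can now be converted to another device, because the auto-generated token_type_ids follow the module's buffer instead of a hardcoded device.

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True).eval()

inputs = tokenizer("Hello, world!", return_tensors="pt")
# Trace on CPU without passing token_type_ids.
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))

if torch.cuda.is_available():
    traced = traced.to("cuda")  # previously failed because the generated token_type_ids kept the trace-time device
    outputs = traced(inputs["input_ids"].to("cuda"), inputs["attention_mask"].to("cuda"))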