Fix the device id getting hardcoded in position ids during tracing for DistilBERT (#12290)

* Registered a buffer for position_ids to address issues similar to issue #5664 (a minimal sketch of the pattern follows the commit message)

* Added an explanatory comment

* Added the persistent=False flag to prevent the buffer from being added to the state_dict
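
A minimal sketch of the pattern, not taken from the commit (module name, sizes, and inputs below are illustrative): registering position_ids as a non-persistent buffer creates the tensor once in the constructor so it moves with the module on .to(device), instead of being created inside forward() with a device argument that a trace would freeze; persistent=False keeps the buffer out of the state_dict so existing checkpoints still load.

import torch
from torch import nn


class PositionIdsSketch(nn.Module):
    def __init__(self, max_position_embeddings=512, dim=768):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_position_embeddings, dim)
        # Non-persistent buffer: it follows .to(device) / .cuda() with the module
        # but is excluded from the state_dict.
        self.register_buffer(
            "position_ids",
            torch.arange(max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        # Slice the pre-built buffer instead of creating a tensor with
        # device=input_ids.device inside forward, which tracing would pin
        # to whatever device was used at trace time.
        position_ids = self.position_ids[:, :seq_length]
        return self.position_embeddings(position_ids)
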
Hamid Shojanazeri, 2021-09-01 01:47:25 -07:00 (committed by GitHub)
parent 5d1a3d135c
commit 5adf5cab2f


@@ -22,6 +22,7 @@ import math
 import numpy as np
 import torch
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -101,6 +102,10 @@ class Embeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+            )

     def forward(self, input_ids):
         """
@@ -111,8 +116,15 @@
             embeddings)
         """
         seq_length = input_ids.size(1)
-        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
-        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
+
+        # Using the position_ids buffer registered in the constructor helps when
+        # tracing the model without passing position ids, and solves issues
+        # similar to issue #5664
+        if hasattr(self, "position_ids"):
+            position_ids = self.position_ids[:, :seq_length]
+        else:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)

         word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
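
For context, a hedged usage sketch of the scenario this commit targets (the checkpoint name and example inputs are illustrative, not from the commit): trace DistilBERT on CPU, then move the traced module to GPU. With the registered buffer, the position ids follow the module instead of staying pinned to the tracing device.

import torch
from transformers import DistilBertModel, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased", torchscript=True).eval()

inputs = tokenizer("Tracing should not pin the device", return_tensors="pt")
# Trace on CPU with example inputs.
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))

# Move the traced module to GPU and run it there; without the registered buffer,
# the position ids created inside forward stayed on the tracing device.
if torch.cuda.is_available():
    traced.to("cuda")
    output = traced(inputs["input_ids"].to("cuda"), inputs["attention_mask"].to("cuda"))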