Mirror of https://github.com/huggingface/transformers.git
Fix device-id getting hard-coded for position-ids when tracing DistilBERT (#12290)
* registered a buffer for position-ids to address issues similar to issue #5664
* added a comment
* added the persistent=False flag to prevent the buffer from being added to the state_dict
commit 5adf5cab2f
parent 5d1a3d135c
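For context, the pattern this commit applies looks roughly like the sketch below. This is a minimal illustration, not the actual DistilBERT module: the class name and default sizes are placeholders. The point is that a buffer registered on the module travels with `model.to(device)`, so `torch.jit.trace` records a read of the buffer rather than a `torch.arange(..., device=...)` call whose device gets frozen into the graph, while `persistent=False` keeps the buffer out of `state_dict()` so existing checkpoints still load.

```python
import torch
from torch import nn


class EmbeddingsSketch(nn.Module):
    # Minimal sketch of the commit's pattern; names and sizes are illustrative.
    def __init__(self, max_position_embeddings=512, dim=768):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_position_embeddings, dim)
        # A buffer moves with the module on .to(device), so a traced graph
        # reads it from the module instead of baking in a device constant.
        # persistent=False (requires torch > 1.6.0, hence the version check
        # in the diff) keeps it out of state_dict, so older checkpoints load
        # without a missing/unexpected key.
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        position_ids = self.position_ids[:, :seq_length]  # (1, seq_length)
        return self.position_embeddings(position_ids)  # (1, seq_length, dim)
```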
@@ -22,6 +22,7 @@ import math
 
 import numpy as np
 import torch
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -101,6 +102,10 @@ class Embeddings(nn.Module):
 
         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+            )
 
     def forward(self, input_ids):
         """
@@ -111,8 +116,15 @@ class Embeddings(nn.Module):
             embeddings)
         """
         seq_length = input_ids.size(1)
-        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
-        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
+
+        # Setting the position ids to the buffer registered in the constructor helps
+        # when tracing the model without passing position ids; it solves
+        # issues similar to issue #5664
+        if hasattr(self, "position_ids"):
+            position_ids = self.position_ids[:, :seq_length]
+        else:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
 
         word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
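To see the effect end to end, one could trace the model and then move the traced artifact to another device, roughly as below. This is a sketch assuming a transformers build that contains this fix; `distilbert-base-uncased` is just the standard checkpoint name, and `torchscript=True` makes the model return tuples so it can be traced.

```python
import torch
from transformers import DistilBertModel, DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased", torchscript=True).eval()

inputs = tokenizer("a short tracing example", return_tensors="pt")
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))

# Before this fix, position ids were created with torch.arange(..., device=...),
# so the tracing device was frozen into the graph and the traced model failed
# when moved elsewhere (issue #5664). With position_ids registered as a buffer,
# .to(...) relocates it together with the weights.
if torch.cuda.is_available():
    traced = traced.to("cuda")
    hidden = traced(inputs["input_ids"].to("cuda"), inputs["attention_mask"].to("cuda"))[0]
```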