Bug fix: 1764

Dom Hudson 2019-11-07 19:58:17 +00:00 committed by Julien Chaumond
parent a80778f40e
commit 228f52867c
2 changed files with 74 additions and 12 deletions


@@ -51,24 +51,45 @@ class RobertaEmbeddings(BertEmbeddings):
                                                 padding_idx=self.padding_idx)
 
     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
-        if input_ids is not None:
-            input_shape = input_ids.size()
-        else:
-            input_shape = inputs_embeds.size()[:-1]
-
-        seq_length = input_shape[1]
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
         if position_ids is None:
-            # Position numbers begin at padding_idx+1. Padding symbols are ignored.
-            # cf. fairseq's `utils.make_positions`
-            position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).expand(input_shape)
+            if input_ids is not None:
+                # Create the position ids from the input token ids. Any padded tokens remain padded.
+                position_ids = self.create_position_ids_from_input_ids(input_ids).to(input_ids.device)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
 
         return super(RobertaEmbeddings, self).forward(input_ids,
                                                       token_type_ids=token_type_ids,
                                                       position_ids=position_ids,
                                                       inputs_embeds=inputs_embeds)
+
+    def create_position_ids_from_input_ids(self, x):
+        """ Replace non-padding symbols with their position numbers. Position numbers begin at
+        padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
+        `utils.make_positions`.
+
+        :param torch.Tensor x:
+        :return torch.Tensor:
+        """
+        mask = x.ne(self.padding_idx).long()
+        incremental_indicies = torch.cumsum(mask, dim=1) * mask
+        return incremental_indicies + self.padding_idx
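The mask/cumsum arithmetic in the new helper can be checked in isolation with plain PyTorch; the padding_idx value and the token ids below are made-up illustrative values, not taken from the model:

import torch

padding_idx = 1                              # illustrative value only
x = torch.tensor([[7, 9, padding_idx, padding_idx]])

mask = x.ne(padding_idx).long()              # tensor([[1, 1, 0, 0]])
positions = torch.cumsum(mask, dim=1) * mask + padding_idx
print(positions)                             # tensor([[2, 3, 1, 1]]) -- padded slots keep padding_idx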
+
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        """ We are provided embeddings directly. We cannot infer which are padded so just generate
+        sequential position ids.
+
+        :param torch.Tensor inputs_embeds:
+        :return torch.Tensor:
+        """
+        input_shape = inputs_embeds.size()[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = torch.arange(self.padding_idx+1, sequence_length+self.padding_idx+1, dtype=torch.long,
+                                    device=inputs_embeds.device)
+        return position_ids.unsqueeze(0)
 
 ROBERTA_START_DOCSTRING = r""" The RoBERTa model was proposed in
     `RoBERTa: A Robustly Optimized BERT Pretraining Approach`_
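For contrast, here is a small sketch of what the removed forward logic computed for a padded toy batch (values again made up for illustration). The padded slots receive ordinary sequential positions instead of padding_idx, which is the behaviour the regression tests below guard against (issue 1761):

import torch

padding_idx = 1                              # illustrative value only
input_ids = torch.tensor([[7, 9, padding_idx, padding_idx]])
seq_length = input_ids.size(1)

# Old behaviour: arange + expand gives every slot a sequential position, padding included.
old_position_ids = torch.arange(padding_idx + 1, seq_length + padding_idx + 1, dtype=torch.long)
old_position_ids = old_position_ids.unsqueeze(0).expand(input_ids.size())
print(old_position_ids)                      # tensor([[2, 3, 4, 5]]) -- the padded slots no longer map to padding_idx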


@@ -25,6 +25,7 @@ if is_torch_available():
     import torch
     from transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM,
                               RobertaForSequenceClassification, RobertaForTokenClassification)
+    from transformers.modeling_roberta import RobertaEmbeddings
     from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
     from .modeling_common_test import (CommonTestCases, ids_tensor)
@@ -205,6 +206,46 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)
+
+    def test_create_position_ids_respects_padding_index(self):
+        """ Ensure that the default position ids only assign a sequential position to non-padding
+        symbols. This is a regression test for https://github.com/huggingface/transformers/issues/1761
+
+        The position ids should be masked with the embedding object's padding index. Therefore, the
+        first available non-padding position index is RobertaEmbeddings.padding_idx + 1
+        """
+        config = self.model_tester.prepare_config_and_inputs()[0]
+        model = RobertaEmbeddings(config=config)
+
+        input_ids = torch.as_tensor([[12, 31, 13, model.padding_idx]])
+        expected_positions = torch.as_tensor([[
+            0 + model.padding_idx + 1,
+            1 + model.padding_idx + 1,
+            2 + model.padding_idx + 1,
+            model.padding_idx
+        ]])
+
+        position_ids = model.create_position_ids_from_input_ids(input_ids)
+        self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
+
+    def test_create_position_ids_from_inputs_embeds(self):
+        """ Ensure that position ids created from inputs_embeds are sequential. This is a regression
+        test for https://github.com/huggingface/transformers/issues/1761
+
+        Padding cannot be inferred from embeddings, so every position is assigned sequentially,
+        starting from the first non-padding index, RobertaEmbeddings.padding_idx + 1.
+        """
+        config = self.model_tester.prepare_config_and_inputs()[0]
+        model = RobertaEmbeddings(config=config)
+
+        inputs_embeds = torch.Tensor(1, 4, 30)
+        expected_positions = torch.as_tensor([[
+            0 + model.padding_idx + 1,
+            1 + model.padding_idx + 1,
+            2 + model.padding_idx + 1,
+            3 + model.padding_idx + 1,
+        ]])
+
+        position_ids = model.create_position_ids_from_inputs_embeds(inputs_embeds)
+        self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
 
 class RobertaModelIntegrationTest(unittest.TestCase):
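Finally, a small usage sketch of the new helpers outside the test harness, using the same imports as the test file above. It assumes a default-constructed RobertaConfig is enough to build the embedding module (the tests build their config through the model tester instead):

import torch
from transformers import RobertaConfig
from transformers.modeling_roberta import RobertaEmbeddings

embeddings = RobertaEmbeddings(config=RobertaConfig())   # default config values, assumed adequate here

# Positions from token ids: padded slots keep padding_idx, the rest count up from padding_idx + 1.
input_ids = torch.tensor([[12, 31, 13, embeddings.padding_idx]])
print(embeddings.create_position_ids_from_input_ids(input_ids))

# Positions from embeddings: padding cannot be inferred, so the ids are purely sequential.
inputs_embeds = torch.empty(1, 4, 30)                    # only the shape matters here
print(embeddings.create_position_ids_from_inputs_embeds(inputs_embeds))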