[T5] make decoder input ids optional for t5 training (#3521)

* make decoder input ids optional for t5 training

* lm_labels should not be shifted in t5

* add tests

* finish shift right functionality for PT T5

* move shift right to correct class

* cleaner code

* replace -100 values with pad token id

* add assert statement

* remove unnecessary for loop

* make style
Patrick von Platen 2020-03-30 13:45:26 +02:00 committed by GitHub
parent 5b44e0a31b
commit 75ec6c9e3a
6 changed files with 79 additions and 18 deletions


@@ -502,6 +502,27 @@ class T5PreTrainedModel(PreTrainedModel):
             if module.has_relative_attention_bias:
                 module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
 
+    def _shift_right(self, input_ids):
+        decoder_start_token_id = self.config.decoder_start_token_id
+        pad_token_id = self.config.pad_token_id
+
+        assert (
+            decoder_start_token_id is not None
+        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
+
+        # shift inputs to the right
+        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+        shifted_input_ids[..., 0] = decoder_start_token_id
+
+        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        # replace possible -100 values in lm_labels by `pad_token_id`
+        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `lm_labels` has only positive values and -100"
+
+        return shifted_input_ids
+
 
 class T5Stack(T5PreTrainedModel):
     def __init__(self, config, embed_tokens=None):
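Illustrative sketch (not part of the commit): assuming pad_token_id=0 and decoder_start_token_id=0 (placeholder values, not taken from a real config), this is what the shift in `_shift_right` produces for a batch of `lm_labels`:

# minimal sketch of the shift-right behaviour on a toy label tensor
import torch

lm_labels = torch.tensor([[5, 6, 2, 1], [8, 1, -100, -100]])  # -100 marks positions ignored by the loss
shifted = lm_labels.new_zeros(lm_labels.shape)
shifted[..., 1:] = lm_labels[..., :-1].clone()  # drop the last label, shift everything one step right
shifted[..., 0] = 0                             # prepend decoder_start_token_id
shifted.masked_fill_(shifted == -100, 0)        # replace any remaining -100 with pad_token_id
print(shifted)  # tensor([[0, 5, 6, 2], [0, 8, 1, 0]])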
@@ -923,6 +944,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
 
         hidden_states = encoder_outputs[0]
 
+        if lm_labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+            # get decoder inputs from shifting lm labels to the right
+            decoder_input_ids = self._shift_right(lm_labels)
+
         # Decode
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
@@ -941,10 +966,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         decoder_outputs = (lm_logits,) + decoder_outputs[1:]  # Add hidden states and attention if they are here
 
         if lm_labels is not None:
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-100)
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
             decoder_outputs = (
                 loss,
             ) + decoder_outputs  # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
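Illustrative sketch (not part of the commit): with this change a T5 training step only needs encoder input_ids and lm_labels; decoder_input_ids are built internally via `_shift_right`, and the loss is computed against the unshifted labels. The checkpoint name and sentences below are placeholders, and the call assumes the transformers API of this era (lm_labels argument, tuple outputs):

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer.encode("translate English to German: The house is wonderful.", return_tensors="pt")
lm_labels = tokenizer.encode("Das Haus ist wunderbar.", return_tensors="pt")

outputs = model(input_ids=input_ids, lm_labels=lm_labels)  # no decoder_input_ids passed
loss = outputs[0]  # cross-entropy loss over lm_labels
loss.backward()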


@@ -523,8 +523,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             pad_token_id: (`optional`) int
                 Pad token. Defaults to pad_token_id as defined in the models config.
 
-            eos_token_ids: (`optional`) int or list of int
-                End of sequence token or list of tokens to stop the generation. Default to 0.
+            eos_token_id: (`optional`) int
+                EOS token. Defaults to eos_token_id as defined in the models config.
 
             length_penalty: (`optional`) float
                 Exponential penalty to the length. Default to 1.


@@ -721,13 +721,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 Padding token. Default to specific model pad_token_id or None if it does not exist.
 
             bos_token_id: (`optional`) int
-                BOS token. Defaults to bos_token_id as defined in the models config.
+                BOS token. Defaults to `bos_token_id` as defined in the models config.
 
             pad_token_id: (`optional`) int
                 Pad token. Defaults to pad_token_id as defined in the models config.
 
-            eos_token_ids: (`optional`) int or list of int
-                End of sequence token or list of tokens to stop the generation. Default to eos_token_ids as defined in the models config.
+            eos_token_id: (`optional`) int
+                EOS token. Defaults to `eos_token_id` as defined in the models config.
 
             length_penalty: (`optional`) float
                 Exponential penalty to the length. Default to 1.
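Illustrative sketch (not part of the commit): generate() now takes a single integer eos_token_id instead of an eos_token_ids list; the tensor and values below are placeholders:

# before this commit the same call passed eos_token_ids=[...] (a list of ints)
generated = model.generate(
    input_ids,
    max_length=40,
    eos_token_id=model.config.eos_token_id,
    pad_token_id=model.config.pad_token_id,
)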


@@ -468,7 +468,7 @@ class BartModelIntegrationTests(unittest.TestCase):
             length_penalty=1.0,
             no_repeat_ngram_size=3,
             early_stopping=True,
-            decoder_start_token_id=model.config.eos_token_ids[0],
+            decoder_start_token_id=model.config.eos_token_id,
         )
         decoded = [


@@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
 
 if is_torch_available():
+    import torch
 
     from transformers import T5Config, T5Model, T5ForConditionalGeneration
     from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP
@@ -57,8 +58,9 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
         relative_attention_num_buckets=8,
         dropout_rate=0.1,
         initializer_factor=0.002,
-        eos_token_ids=[1],
+        eos_token_id=1,
         pad_token_id=0,
+        decoder_start_token_id=0,
         scope=None,
     ):
         self.parent = parent
self.parent = parent
@@ -78,8 +80,9 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
         self.dropout_rate = dropout_rate
         self.initializer_factor = initializer_factor
         self.scope = scope
-        self.eos_token_ids = eos_token_ids
+        self.eos_token_id = eos_token_id
         self.pad_token_id = pad_token_id
+        self.decoder_start_token_id = decoder_start_token_id
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
@@ -106,9 +109,10 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
             relative_attention_num_buckets=self.relative_attention_num_buckets,
             dropout_rate=self.dropout_rate,
             initializer_factor=self.initializer_factor,
-            eos_token_ids=self.eos_token_ids,
+            eos_token_id=self.eos_token_id,
             bos_token_id=self.pad_token_id,
             pad_token_id=self.pad_token_id,
+            decoder_start_token_id=self.decoder_start_token_id,
         )
 
         return (
return (
@@ -123,6 +127,39 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
     def check_loss_output(self, result):
         self.parent.assertListEqual(list(result["loss"].size()), [])
 
+    def check_prepare_lm_labels_via_shift_left(
+        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
+    ):
+        model = T5Model(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        # make sure that lm_labels are correctly padded from the right
+        lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id)
+
+        # add causal pad token mask
+        triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not()
+        lm_labels.masked_fill_(triangular_mask, self.pad_token_id)
+        decoder_input_ids = model._shift_right(lm_labels)
+
+        for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)):
+            # first item
+            self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id)
+            if i < decoder_input_ids_slice.shape[-1]:
+                if i < decoder_input_ids.shape[-1] - 1:
+                    # items before diagonal
+                    self.parent.assertListEqual(
+                        decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist()
+                    )
+                # pad items after diagonal
+                if i < decoder_input_ids.shape[-1] - 2:
+                    self.parent.assertListEqual(
+                        decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist()
+                    )
+            else:
+                # all items after square
+                self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
+
     def create_and_check_t5_model(
         self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
     ):
@@ -197,6 +234,10 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
     def test_config(self):
         self.config_tester.run_common_tests()
 
+    def test_shift_right(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)
+
     def test_t5_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_t5_model(*config_and_inputs)


@@ -52,7 +52,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         relative_attention_num_buckets=8,
         dropout_rate=0.1,
         initializer_factor=0.002,
-        eos_token_ids=[1],
+        eos_token_id=1,
         pad_token_id=0,
         scope=None,
     ):
@@ -71,7 +71,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         self.relative_attention_num_buckets = relative_attention_num_buckets
         self.dropout_rate = dropout_rate
         self.initializer_factor = initializer_factor
-        self.eos_token_ids = eos_token_ids
+        self.eos_token_id = eos_token_id
         self.pad_token_id = pad_token_id
         self.scope = scope
@@ -97,7 +97,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
             relative_attention_num_buckets=self.relative_attention_num_buckets,
             dropout_rate=self.dropout_rate,
             initializer_factor=self.initializer_factor,
-            eos_token_ids=self.eos_token_ids,
+            eos_token_id=self.eos_token_id,
             bos_token_id=self.pad_token_id,
             pad_token_id=self.pad_token_id,
         )