T5ForConditionalGeneration: enable using past_key_values and labels together in training (#13805)

* Enable using past_key_values together with labels when training T5ForConditionalGeneration

* test

* Enable past_key_values in T5ForConditionalGeneration while training.

* delete comments
yssjtu 2021-10-06 15:20:41 +08:00 committed by GitHub
parent dac7798144
commit aea7c5b0c8

@@ -1593,15 +1593,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             # get decoder inputs from shifting lm labels to the right
             decoder_input_ids = self._shift_right(labels)
 
-        # If decoding with past key value states, only the last tokens
-        # should be given as an input
-        if past_key_values is not None:
-            assert labels is None, "Decoder should not use cached key value states when training."
-            if decoder_input_ids is not None:
-                decoder_input_ids = decoder_input_ids[:, -1:]
-            if decoder_inputs_embeds is not None:
-                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
-
         # Set device for model parallelism
         if self.model_parallel:
             torch.cuda.set_device(self.decoder.first_device)
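
To make the behavioral change in this hunk concrete, here is a minimal, framework-free sketch of the decoder-input preparation before and after the commit. It is an illustration under stated assumptions, not the library code: nested lists stand in for tensors, `shift_right`, `prepare_decoder_input_ids`, and the `legacy` flag are hypothetical names, the guard on `labels` is assumed because the hunk does not show it, and `decoder_inputs_embeds` is left out. The real `_shift_right` may also do more than prepend a start token.

```python
# Simplified sketch: nested lists stand in for tensors.
# Function names and the `legacy` flag are hypothetical, not transformers APIs.

def shift_right(labels, decoder_start_token_id=0):
    """Prepend the start token to each row and drop the row's last label."""
    return [[decoder_start_token_id] + row[:-1] for row in labels]


def prepare_decoder_input_ids(labels=None, decoder_input_ids=None,
                              past_key_values=None, legacy=False):
    """Return the decoder inputs the forward pass would go on to use."""
    if labels is not None and decoder_input_ids is None:
        # get decoder inputs from shifting lm labels to the right
        decoder_input_ids = shift_right(labels)

    if legacy and past_key_values is not None:
        # Pre-commit behavior: a cache and labels could not be combined,
        # and only the last decoder token was kept.
        assert labels is None, "Decoder should not use cached key value states when training."
        if decoder_input_ids is not None:
            decoder_input_ids = [row[-1:] for row in decoder_input_ids]

    return decoder_input_ids


cache = object()  # placeholder for real cached key/value states
labels = [[5, 6, 7]]

# After the commit: labels and a cache can be passed together.
new_inputs = prepare_decoder_input_ids(labels=labels, past_key_values=cache)

# Before the commit: the same call tripped the assertion.
try:
    prepare_decoder_input_ids(labels=labels, past_key_values=cache, legacy=True)
    legacy_raised = False
except AssertionError:
    legacy_raised = True

# Before the commit, without labels: inputs were cut to the last token.
legacy_inputs = prepare_decoder_input_ids(
    decoder_input_ids=[[0, 5, 6]], past_key_values=cache, legacy=True
)
```

One consequence visible in the sketch: with the block removed, the forward pass no longer trims decoder inputs itself, so a caller that passes a cache has to supply decoder inputs of the appropriate length.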