mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
T5ForConditionalGeneration: enabling using past_key_values and labels in training (#13805)
* enabling using past_key_values together with labels when training in T5ForConditionalGeneration * test * Enable past_key_values in T5ForConditionalGeneration while training. * delete comments
This commit is contained in:
parent
dac7798144
commit
aea7c5b0c8
@ -1593,15 +1593,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
# get decoder inputs from shifting lm labels to the right
|
# get decoder inputs from shifting lm labels to the right
|
||||||
decoder_input_ids = self._shift_right(labels)
|
decoder_input_ids = self._shift_right(labels)
|
||||||
|
|
||||||
# If decoding with past key value states, only the last tokens
|
|
||||||
# should be given as an input
|
|
||||||
if past_key_values is not None:
|
|
||||||
assert labels is None, "Decoder should not use cached key value states when training."
|
|
||||||
if decoder_input_ids is not None:
|
|
||||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
|
||||||
if decoder_inputs_embeds is not None:
|
|
||||||
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
|
|
||||||
|
|
||||||
# Set device for model parallelism
|
# Set device for model parallelism
|
||||||
if self.model_parallel:
|
if self.model_parallel:
|
||||||
torch.cuda.set_device(self.decoder.first_device)
|
torch.cuda.set_device(self.decoder.first_device)
|
||||||
|
Loading…
Reference in New Issue
Block a user