mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
T5ForConditionalGeneration: enabling using past_key_values and labels in training (#13805)
* Enable using past_key_values together with labels when training in T5ForConditionalGeneration
* test
* Enable past_key_values in T5ForConditionalGeneration while training.
* delete comments
This commit is contained in:
parent
dac7798144
commit
aea7c5b0c8
@@ -1593,15 +1593,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             # get decoder inputs from shifting lm labels to the right
             decoder_input_ids = self._shift_right(labels)
 
-        # If decoding with past key value states, only the last tokens
-        # should be given as an input
-        if past_key_values is not None:
-            assert labels is None, "Decoder should not use cached key value states when training."
-            if decoder_input_ids is not None:
-                decoder_input_ids = decoder_input_ids[:, -1:]
-            if decoder_inputs_embeds is not None:
-                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
-
         # Set device for model parallelism
         if self.model_parallel:
             torch.cuda.set_device(self.decoder.first_device)
Loading…
Reference in New Issue
Block a user