Seq2SeqTrainer: use unwrapped model to retrieve the generation config (#22584)

This commit is contained in:
Joao Gante 2023-04-06 13:29:58 +01:00 committed by GitHub
parent 0aa1153ffb
commit 48706c7178
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -277,7 +277,7 @@ class Seq2SeqTrainer(Trainer):
self.model.generation_config._from_model_config = False
# Retrieves GenerationConfig from model.generation_config
gen_config = model.generation_config
gen_config = self.model.generation_config
# in case the batch is shorter than max length, the output should be padded
if generated_tokens.shape[-1] < gen_config.max_length:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)