mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Seq2SeqTrainer: use unwrapped model to retrieve the generation config (#22584)
This commit is contained in:
parent
0aa1153ffb
commit
48706c7178
@ -277,7 +277,7 @@ class Seq2SeqTrainer(Trainer):
|
||||
self.model.generation_config._from_model_config = False
|
||||
|
||||
# Retrieves GenerationConfig from model.generation_config
|
||||
gen_config = model.generation_config
|
||||
gen_config = self.model.generation_config
|
||||
# in case the batch is shorter than max length, the output should be padded
|
||||
if generated_tokens.shape[-1] < gen_config.max_length:
|
||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
|
||||
|
Loading…
Reference in New Issue
Block a user