Seq2SeqTrainer: use unwrapped model to retrieve the generation config (#22584)
commit 48706c7178
parent 0aa1153ffb
@@ -277,7 +277,7 @@ class Seq2SeqTrainer(Trainer):
             self.model.generation_config._from_model_config = False
 
         # Retrieves GenerationConfig from model.generation_config
-        gen_config = model.generation_config
+        gen_config = self.model.generation_config
         # in case the batch is shorter than max length, the output should be padded
         if generated_tokens.shape[-1] < gen_config.max_length:
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
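Why the unwrapped model: in `prediction_step`, the `model` argument may be wrapped (for example by DistributedDataParallel or a DeepSpeed engine), and such wrappers generally proxy only `forward()`, not attributes like `generation_config`. `self.model` always holds the unwrapped model, hence the one-line fix above. Below is a minimal sketch of the failure mode; `TinyModel` and `Wrapper` are hypothetical stand-ins, not code from this commit.

# Minimal sketch of why attribute lookup must go through the unwrapped
# model. Wrapper mimics DDP/DeepSpeed-style engines that proxy forward()
# to an inner module but do not forward arbitrary attribute access.
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Stands in for a PreTrainedModel's GenerationConfig.
        self.generation_config = {"max_length": 20}

    def forward(self, x):
        return self.linear(x)


class Wrapper(nn.Module):
    """Proxies forward() to an inner module, as DDP does."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


model = Wrapper(TinyModel())

try:
    model.generation_config  # fails: the wrapper has no such attribute
except AttributeError as err:
    print("wrapped model:", err)

# The inner module exposes it; this is what self.model gives the Trainer.
print("unwrapped model:", model.module.generation_config)

Reading `gen_config` from `self.model` keeps the padding logic working regardless of whether the model passed into `prediction_step` has been wrapped for distributed training.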