Do not log the generation config for each prediction step in Seq2SeqTrainer (#21385)

Do not log the generation config for each iteration
regisss 2023-01-31 15:05:22 +01:00 committed by GitHub
parent 98d40fed3a
commit d31497b196


@@ -199,6 +199,11 @@ class Seq2SeqTrainer(Trainer):
            generation_inputs,
            **gen_kwargs,
        )
        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
        # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
        # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
        if self.model.generation_config._from_model_config:
            self.model.generation_config._from_model_config = False
        # in case the batch is shorter than max length, the output should be padded
        if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
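
For context, a minimal sketch of what the hack amounts to outside the trainer, assuming the legacy path in generate() only rebuilds the generation config from the model config (and logs about it) while generation_config._from_model_config is set; the checkpoint name below is purely illustrative:

    from transformers import AutoModelForSeq2SeqLM

    # Illustrative checkpoint; any seq2seq model exposing generation_config works.
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    # The flag marks a generation config that was derived from the model config.
    # Clearing it once means later generate() calls skip the legacy
    # "re-initialize generation_config from the model config" branch,
    # so nothing is logged on every prediction step of the evaluation loop.
    if model.generation_config._from_model_config:
        model.generation_config._from_model_config = False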
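
The padding branch delegates to the trainer's _pad_tensors_to_max_len helper. A rough, self-contained sketch of that kind of right-padding follows; the function name and explicit pad_token_id argument are illustrative, not the trainer's exact code:

    import torch

    def pad_to_max_len(tensor, max_length, pad_token_id):
        # Build a (batch, max_length) tensor filled with the pad id and copy the
        # generated ids into its left part, so every batch has the same width.
        padded = torch.full(
            (tensor.shape[0], max_length), pad_token_id, dtype=tensor.dtype, device=tensor.device
        )
        padded[:, : tensor.shape[-1]] = tensor
        return padded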