Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
Do not log the generation config for each prediction step in TrainerSeq2Seq (#21385)
Do not log the generation config for each iteration
This commit is contained in:
parent 98d40fed3a
commit d31497b196
@@ -199,6 +199,11 @@ class Seq2SeqTrainer(Trainer):
             generation_inputs,
             **gen_kwargs,
         )
+        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
+        # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
+        # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
+        if self.model.generation_config._from_model_config:
+            self.model.generation_config._from_model_config = False
         # in case the batch is shorter than max length, the output should be padded
         if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
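For context, here is a minimal sketch of the legacy behavior this patch works around. The names GenerationConfig, ToyModel, and eval_loop below are illustrative stand-ins, not the actual transformers implementation; the sketch only shows why a warning fires on every prediction step when the generation config was derived from the model config, and how flipping the _from_model_config flag after the first generate() call stops the repetition.

# Illustrative sketch, not the actual transformers code.
import warnings


class GenerationConfig:
    """Toy stand-in for a generation config object."""

    def __init__(self, from_model_config=False):
        # Marks whether this config was derived from the model config (legacy path).
        self._from_model_config = from_model_config


class ToyModel:
    """Toy stand-in for a model whose generation config came from its model config."""

    def __init__(self):
        self.generation_config = GenerationConfig(from_model_config=True)

    def generate(self, inputs, **gen_kwargs):
        # Legacy path: while the flag is set, every call re-initializes the
        # config and emits a log/warning line.
        if self.generation_config._from_model_config:
            warnings.warn("Generation config was created from a model config (legacy behavior).")
        return inputs  # placeholder for generated tokens


def eval_loop(model, batches):
    tokens = None
    for batch in batches:
        tokens = model.generate(batch)
        # The patched Seq2SeqTrainer does the equivalent of this: clear the flag
        # after the first call so the legacy re-initialization (and its log line)
        # does not repeat on every evaluation step.
        if model.generation_config._from_model_config:
            model.generation_config._from_model_config = False
    return tokens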