diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index e760ab55c17..4068f0c3837 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -358,9 +358,9 @@ class CallbackHandler(TrainerCallback): def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): return self.call_event("on_step_end", args, state, control) - def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics): + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs): control.should_evaluate = False - return self.call_event("on_evaluate", args, state, control, metrics=metrics) + return self.call_event("on_evaluate", args, state, control, metrics=metrics, **kwargs) def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): control.should_save = False