diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ae1b2458524..f7c3836d4af 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1488,6 +1488,7 @@ class Trainer: raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") # This might change the seed so needs to run first. self._hp_search_setup(trial) + self._train_batch_size = self.args.train_batch_size # Model re-init model_reloaded = False