update the train_batch_size in case HPO change batch_size_per_device (#18918)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2022-09-07 20:01:30 +08:00 committed by GitHub
parent 4f299b2446
commit d842f2d5b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1488,6 +1488,7 @@ class Trainer:
raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
# This might change the seed so needs to run first.
self._hp_search_setup(trial)
self._train_batch_size = self.args.train_batch_size
# Model re-init
model_reloaded = False