Fix test_number_of_steps_in_training_with_ipex (#17889)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar 2022-06-28 08:55:02 +02:00 committed by GitHub
parent 0b0dd97737
commit f717d47fe0

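The change swaps self.batch_size for trainer.args.train_batch_size in the expected step counts, so the assertions track the batch size the Trainer actually uses. Below is a minimal sketch of that arithmetic, assuming the regression trainer's 64-sample dataset together with 3 epochs and an effective train batch size of 8 on a single device; the epoch and batch-size values are illustrative assumptions, not taken from this diff.

# Step-count sketch (illustrative): only the 64-sample dataset size comes
# from the test; 3 epochs and a batch size of 8 are assumed values.
dataset_size = 64
train_batch_size = 8   # stands in for trainer.args.train_batch_size
n_epochs = 3           # stands in for self.n_epochs

# Regular training runs n_epochs * len(train_dl) optimizer steps.
expected_full = n_epochs * dataset_size / train_batch_size      # 24.0

# A fractional num_train_epochs truncates to whole optimizer steps.
expected_partial = int(1.5 * dataset_size / train_batch_size)   # 12

print(expected_full, expected_partial)  # 24.0 12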

@@ -649,14 +649,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             # Regular training has n_epochs * len(train_dl) steps
             trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True)
             train_output = trainer.train()
-            self.assertEqual(train_output.global_step, self.n_epochs * 64 / self.batch_size)
+            self.assertEqual(train_output.global_step, self.n_epochs * 64 / trainer.args.train_batch_size)
             # Check passing num_train_epochs works (and a float version too):
             trainer = get_regression_trainer(
                 learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True
             )
             train_output = trainer.train()
-            self.assertEqual(train_output.global_step, int(1.5 * 64 / self.batch_size))
+            self.assertEqual(train_output.global_step, int(1.5 * 64 / trainer.args.train_batch_size))
             # If we pass a max_steps, num_train_epochs is ignored
             trainer = get_regression_trainer(