mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix test_number_of_steps_in_training_with_ipex
(#17889)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
0b0dd97737
commit
f717d47fe0
@ -649,14 +649,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# Regular training has n_epochs * len(train_dl) steps
|
||||
trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True)
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, self.n_epochs * 64 / self.batch_size)
|
||||
self.assertEqual(train_output.global_step, self.n_epochs * 64 / trainer.args.train_batch_size)
|
||||
|
||||
# Check passing num_train_epochs works (and a float version too):
|
||||
trainer = get_regression_trainer(
|
||||
learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True
|
||||
)
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, int(1.5 * 64 / self.batch_size))
|
||||
self.assertEqual(train_output.global_step, int(1.5 * 64 / trainer.args.train_batch_size))
|
||||
|
||||
# If we pass a max_steps, num_train_epochs is ignored
|
||||
trainer = get_regression_trainer(
|
||||
|
Loading…
Reference in New Issue
Block a user