diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 81694b7fbeb..eb23e7eb1af 100755 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -195,7 +195,9 @@ class TrainerIntegrationTest(unittest.TestCase): args = TrainingArguments("./regression") dict1, dict2 = args.to_dict(), trainer.args.to_dict() for key in dict1.keys(): - self.assertEqual(dict1[key], dict2[key]) + # Logging dir can be slightly different as they default to something with the time. + if key != "loggin_dir": + self.assertEqual(dict1[key], dict2[key]) def test_reproducible_training(self): # Checks that training worked, model trained and seed made a reproducible training.