mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-21 21:49:06 +06:00
Modify warnings
in a with
block to avoid flaky tests (#31893)
* fix * [test_all] check before merge --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
ec03d97b27
commit
080e14b24c
@ -218,52 +218,53 @@ class TrainerCallbackTest(unittest.TestCase):
|
||||
import warnings
|
||||
|
||||
# XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested
|
||||
warnings.simplefilter(action="ignore", category=UserWarning)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter(action="ignore", category=UserWarning)
|
||||
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
# Independent log/save/eval
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5)
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
# Independent log/save/eval
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5)
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5)
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5)
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, eval_strategy="steps")
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, eval_strategy="steps")
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_strategy="epoch")
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_strategy="epoch")
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
# A bit of everything
|
||||
trainer = self.get_trainer(
|
||||
callbacks=[MyTestTrainerCallback],
|
||||
logging_steps=3,
|
||||
save_steps=10,
|
||||
eval_steps=5,
|
||||
eval_strategy="steps",
|
||||
)
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
# warning should be emitted for duplicated callbacks
|
||||
with patch("transformers.trainer_callback.logger.warning") as warn_mock:
|
||||
# A bit of everything
|
||||
trainer = self.get_trainer(
|
||||
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
|
||||
callbacks=[MyTestTrainerCallback],
|
||||
logging_steps=3,
|
||||
save_steps=10,
|
||||
eval_steps=5,
|
||||
eval_strategy="steps",
|
||||
)
|
||||
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
# warning should be emitted for duplicated callbacks
|
||||
with patch("transformers.trainer_callback.logger.warning") as warn_mock:
|
||||
trainer = self.get_trainer(
|
||||
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
|
||||
)
|
||||
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
|
||||
|
||||
def test_stateful_callbacks(self):
|
||||
# Use something with non-defaults
|
||||
|
Loading…
Reference in New Issue
Block a user