mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-22 14:00:33 +06:00
Modify warnings
in a with
block to avoid flaky tests (#31893)
* fix * [test_all] check before merge --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
ec03d97b27
commit
080e14b24c
@ -218,52 +218,53 @@ class TrainerCallbackTest(unittest.TestCase):
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
# XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested
|
# XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested
|
||||||
warnings.simplefilter(action="ignore", category=UserWarning)
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter(action="ignore", category=UserWarning)
|
||||||
|
|
||||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
|
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
|
||||||
trainer.train()
|
trainer.train()
|
||||||
events = trainer.callback_handler.callbacks[-2].events
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
self.assertEqual(events, self.get_expected_events(trainer))
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
# Independent log/save/eval
|
# Independent log/save/eval
|
||||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5)
|
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
events = trainer.callback_handler.callbacks[-2].events
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
self.assertEqual(events, self.get_expected_events(trainer))
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5)
|
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
events = trainer.callback_handler.callbacks[-2].events
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
self.assertEqual(events, self.get_expected_events(trainer))
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, eval_strategy="steps")
|
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, eval_strategy="steps")
|
||||||
trainer.train()
|
trainer.train()
|
||||||
events = trainer.callback_handler.callbacks[-2].events
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
self.assertEqual(events, self.get_expected_events(trainer))
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_strategy="epoch")
|
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_strategy="epoch")
|
||||||
trainer.train()
|
trainer.train()
|
||||||
events = trainer.callback_handler.callbacks[-2].events
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
self.assertEqual(events, self.get_expected_events(trainer))
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
# A bit of everything
|
# A bit of everything
|
||||||
trainer = self.get_trainer(
|
|
||||||
callbacks=[MyTestTrainerCallback],
|
|
||||||
logging_steps=3,
|
|
||||||
save_steps=10,
|
|
||||||
eval_steps=5,
|
|
||||||
eval_strategy="steps",
|
|
||||||
)
|
|
||||||
trainer.train()
|
|
||||||
events = trainer.callback_handler.callbacks[-2].events
|
|
||||||
self.assertEqual(events, self.get_expected_events(trainer))
|
|
||||||
|
|
||||||
# warning should be emitted for duplicated callbacks
|
|
||||||
with patch("transformers.trainer_callback.logger.warning") as warn_mock:
|
|
||||||
trainer = self.get_trainer(
|
trainer = self.get_trainer(
|
||||||
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
|
callbacks=[MyTestTrainerCallback],
|
||||||
|
logging_steps=3,
|
||||||
|
save_steps=10,
|
||||||
|
eval_steps=5,
|
||||||
|
eval_strategy="steps",
|
||||||
)
|
)
|
||||||
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
|
trainer.train()
|
||||||
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
|
# warning should be emitted for duplicated callbacks
|
||||||
|
with patch("transformers.trainer_callback.logger.warning") as warn_mock:
|
||||||
|
trainer = self.get_trainer(
|
||||||
|
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
|
||||||
|
)
|
||||||
|
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
|
||||||
|
|
||||||
def test_stateful_callbacks(self):
|
def test_stateful_callbacks(self):
|
||||||
# Use something with non-defaults
|
# Use something with non-defaults
|
||||||
|
Loading…
Reference in New Issue
Block a user