From d8ca57d2cea34f949a14992d0f68d8173f30cecd Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 16 Oct 2020 00:19:51 -0700 Subject: [PATCH] fix/hide warnings (#7837) s --- tests/test_trainer_callback.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/test_trainer_callback.py b/tests/test_trainer_callback.py index 0469c077fa7..ad3662adf26 100644 --- a/tests/test_trainer_callback.py +++ b/tests/test_trainer_callback.py @@ -21,7 +21,7 @@ if is_torch_available(): from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel -class TestTrainerCallback(TrainerCallback): +class MyTestTrainerCallback(TrainerCallback): "A callback that registers the events that goes through." def __init__(self): @@ -134,8 +134,8 @@ class TrainerCallbackTest(unittest.TestCase): self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) # Callbacks passed at init are added to the default callbacks - trainer = self.get_trainer(callbacks=[TestTrainerCallback]) - expected_callbacks.append(TestTrainerCallback) + trainer = self.get_trainer(callbacks=[MyTestTrainerCallback]) + expected_callbacks.append(MyTestTrainerCallback) self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) # TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback @@ -179,35 +179,44 @@ class TrainerCallbackTest(unittest.TestCase): self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) def test_event_flow(self): - trainer = self.get_trainer(callbacks=[TestTrainerCallback]) + import warnings + + # XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested + warnings.simplefilter(action="ignore", category=UserWarning) + + trainer = self.get_trainer(callbacks=[MyTestTrainerCallback]) trainer.train() events = trainer.callback_handler.callbacks[-2].events self.assertEqual(events, self.get_expected_events(trainer)) # Independent log/save/eval - trainer = self.get_trainer(callbacks=[TestTrainerCallback], logging_steps=5) + trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5) trainer.train() events = trainer.callback_handler.callbacks[-2].events self.assertEqual(events, self.get_expected_events(trainer)) - trainer = self.get_trainer(callbacks=[TestTrainerCallback], save_steps=5) + trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5) trainer.train() events = trainer.callback_handler.callbacks[-2].events self.assertEqual(events, self.get_expected_events(trainer)) - trainer = self.get_trainer(callbacks=[TestTrainerCallback], eval_steps=5, evaluation_strategy="steps") + trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, evaluation_strategy="steps") trainer.train() events = trainer.callback_handler.callbacks[-2].events self.assertEqual(events, self.get_expected_events(trainer)) - trainer = self.get_trainer(callbacks=[TestTrainerCallback], evaluation_strategy="epoch") + trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], evaluation_strategy="epoch") trainer.train() events = trainer.callback_handler.callbacks[-2].events self.assertEqual(events, self.get_expected_events(trainer)) # A bit of everything trainer = self.get_trainer( - callbacks=[TestTrainerCallback], logging_steps=3, save_steps=10, eval_steps=5, evaluation_strategy="steps" + callbacks=[MyTestTrainerCallback], + logging_steps=3, + save_steps=10, + eval_steps=5, + evaluation_strategy="steps", ) trainer.train() events = trainer.callback_handler.callbacks[-2].events