fix/hide warnings (#7837)

s
This commit is contained in:
Stas Bekman 2020-10-16 00:19:51 -07:00 committed by GitHub
parent c6e865ac2b
commit d8ca57d2ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -21,7 +21,7 @@ if is_torch_available():
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
class TestTrainerCallback(TrainerCallback): class MyTestTrainerCallback(TrainerCallback):
"A callback that registers the events that goes through." "A callback that registers the events that goes through."
def __init__(self): def __init__(self):
@ -134,8 +134,8 @@ class TrainerCallbackTest(unittest.TestCase):
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
# Callbacks passed at init are added to the default callbacks # Callbacks passed at init are added to the default callbacks
trainer = self.get_trainer(callbacks=[TestTrainerCallback]) trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
expected_callbacks.append(TestTrainerCallback) expected_callbacks.append(MyTestTrainerCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
# TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback # TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback
@ -179,35 +179,44 @@ class TrainerCallbackTest(unittest.TestCase):
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
def test_event_flow(self): def test_event_flow(self):
trainer = self.get_trainer(callbacks=[TestTrainerCallback]) import warnings
# XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested
warnings.simplefilter(action="ignore", category=UserWarning)
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
trainer.train() trainer.train()
events = trainer.callback_handler.callbacks[-2].events events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer)) self.assertEqual(events, self.get_expected_events(trainer))
# Independent log/save/eval # Independent log/save/eval
trainer = self.get_trainer(callbacks=[TestTrainerCallback], logging_steps=5) trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5)
trainer.train() trainer.train()
events = trainer.callback_handler.callbacks[-2].events events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer)) self.assertEqual(events, self.get_expected_events(trainer))
trainer = self.get_trainer(callbacks=[TestTrainerCallback], save_steps=5) trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5)
trainer.train() trainer.train()
events = trainer.callback_handler.callbacks[-2].events events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer)) self.assertEqual(events, self.get_expected_events(trainer))
trainer = self.get_trainer(callbacks=[TestTrainerCallback], eval_steps=5, evaluation_strategy="steps") trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, evaluation_strategy="steps")
trainer.train() trainer.train()
events = trainer.callback_handler.callbacks[-2].events events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer)) self.assertEqual(events, self.get_expected_events(trainer))
trainer = self.get_trainer(callbacks=[TestTrainerCallback], evaluation_strategy="epoch") trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], evaluation_strategy="epoch")
trainer.train() trainer.train()
events = trainer.callback_handler.callbacks[-2].events events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer)) self.assertEqual(events, self.get_expected_events(trainer))
# A bit of everything # A bit of everything
trainer = self.get_trainer( trainer = self.get_trainer(
callbacks=[TestTrainerCallback], logging_steps=3, save_steps=10, eval_steps=5, evaluation_strategy="steps" callbacks=[MyTestTrainerCallback],
logging_steps=3,
save_steps=10,
eval_steps=5,
evaluation_strategy="steps",
) )
trainer.train() trainer.train()
events = trainer.callback_handler.callbacks[-2].events events = trainer.callback_handler.callbacks[-2].events