mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-12 17:20:03 +06:00
parent
c6e865ac2b
commit
d8ca57d2ce
@ -21,7 +21,7 @@ if is_torch_available():
|
||||
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
|
||||
|
||||
|
||||
class TestTrainerCallback(TrainerCallback):
|
||||
class MyTestTrainerCallback(TrainerCallback):
|
||||
"A callback that registers the events that goes through."
|
||||
|
||||
def __init__(self):
|
||||
@ -134,8 +134,8 @@ class TrainerCallbackTest(unittest.TestCase):
|
||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||
|
||||
# Callbacks passed at init are added to the default callbacks
|
||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback])
|
||||
expected_callbacks.append(TestTrainerCallback)
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
|
||||
expected_callbacks.append(MyTestTrainerCallback)
|
||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||
|
||||
# TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback
|
||||
@ -179,35 +179,44 @@ class TrainerCallbackTest(unittest.TestCase):
|
||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||
|
||||
def test_event_flow(self):
|
||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback])
|
||||
import warnings
|
||||
|
||||
# XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested
|
||||
warnings.simplefilter(action="ignore", category=UserWarning)
|
||||
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
# Independent log/save/eval
|
||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback], logging_steps=5)
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5)
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback], save_steps=5)
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5)
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback], eval_steps=5, evaluation_strategy="steps")
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, evaluation_strategy="steps")
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback], evaluation_strategy="epoch")
|
||||
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], evaluation_strategy="epoch")
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
# A bit of everything
|
||||
trainer = self.get_trainer(
|
||||
callbacks=[TestTrainerCallback], logging_steps=3, save_steps=10, eval_steps=5, evaluation_strategy="steps"
|
||||
callbacks=[MyTestTrainerCallback],
|
||||
logging_steps=3,
|
||||
save_steps=10,
|
||||
eval_steps=5,
|
||||
evaluation_strategy="steps",
|
||||
)
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
|
Loading…
Reference in New Issue
Block a user