mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-13 01:30:04 +06:00
parent
c6e865ac2b
commit
d8ca57d2ce
@ -21,7 +21,7 @@ if is_torch_available():
|
|||||||
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
|
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
class TestTrainerCallback(TrainerCallback):
|
class MyTestTrainerCallback(TrainerCallback):
|
||||||
"A callback that registers the events that goes through."
|
"A callback that registers the events that goes through."
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -134,8 +134,8 @@ class TrainerCallbackTest(unittest.TestCase):
|
|||||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||||
|
|
||||||
# Callbacks passed at init are added to the default callbacks
|
# Callbacks passed at init are added to the default callbacks
|
||||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback])
|
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
|
||||||
expected_callbacks.append(TestTrainerCallback)
|
expected_callbacks.append(MyTestTrainerCallback)
|
||||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||||
|
|
||||||
# TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback
|
# TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback
|
||||||
@ -179,35 +179,44 @@ class TrainerCallbackTest(unittest.TestCase):
|
|||||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||||
|
|
||||||
def test_event_flow(self):
|
def test_event_flow(self):
|
||||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback])
|
import warnings
|
||||||
|
|
||||||
|
# XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested
|
||||||
|
warnings.simplefilter(action="ignore", category=UserWarning)
|
||||||
|
|
||||||
|
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
|
||||||
trainer.train()
|
trainer.train()
|
||||||
events = trainer.callback_handler.callbacks[-2].events
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
self.assertEqual(events, self.get_expected_events(trainer))
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
# Independent log/save/eval
|
# Independent log/save/eval
|
||||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback], logging_steps=5)
|
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
events = trainer.callback_handler.callbacks[-2].events
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
self.assertEqual(events, self.get_expected_events(trainer))
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback], save_steps=5)
|
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
events = trainer.callback_handler.callbacks[-2].events
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
self.assertEqual(events, self.get_expected_events(trainer))
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback], eval_steps=5, evaluation_strategy="steps")
|
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, evaluation_strategy="steps")
|
||||||
trainer.train()
|
trainer.train()
|
||||||
events = trainer.callback_handler.callbacks[-2].events
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
self.assertEqual(events, self.get_expected_events(trainer))
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback], evaluation_strategy="epoch")
|
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], evaluation_strategy="epoch")
|
||||||
trainer.train()
|
trainer.train()
|
||||||
events = trainer.callback_handler.callbacks[-2].events
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
self.assertEqual(events, self.get_expected_events(trainer))
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
# A bit of everything
|
# A bit of everything
|
||||||
trainer = self.get_trainer(
|
trainer = self.get_trainer(
|
||||||
callbacks=[TestTrainerCallback], logging_steps=3, save_steps=10, eval_steps=5, evaluation_strategy="steps"
|
callbacks=[MyTestTrainerCallback],
|
||||||
|
logging_steps=3,
|
||||||
|
save_steps=10,
|
||||||
|
eval_steps=5,
|
||||||
|
evaluation_strategy="steps",
|
||||||
)
|
)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
events = trainer.callback_handler.callbacks[-2].events
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
|
Loading…
Reference in New Issue
Block a user