mirror of https://github.com/huggingface/transformers.git
Flos fix (#7384)

commit 4083a55ab0
parent ae3e84f3ba
@@ -695,7 +695,7 @@ class Trainer:
             # set global_step to global_step of last saved checkpoint from model path
             try:
                 self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
-                self.total_flos = getattr(model.config, "total_flos", 0)
+                self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0)

                 epochs_trained = self.global_step // num_update_steps_per_epoch
                 steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
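Context for the one-line change above: torch.nn.DataParallel and torch.nn.parallel.DistributedDataParallel do not proxy arbitrary attributes of the wrapped model, so when the model arrives wrapped, accessing model.config raises AttributeError and total_flos cannot be read directly. A minimal illustration of that failure mode (not part of the commit; TinyModel and its ad-hoc config object are invented for the sketch):

import torch
from torch import nn


class TinyModel(nn.Module):
    """Stand-in for a transformers model that carries a `config` attribute."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.config = type("Config", (), {"total_flos": 0})()  # ad-hoc config stand-in

    def forward(self, x):
        return self.linear(x)


wrapped = nn.DataParallel(TinyModel())

# The DataParallel wrapper does not expose the inner model's attributes...
print(hasattr(wrapped, "config"))  # False -> wrapped.config would raise AttributeError
# ...so total_flos has to be read from the unwrapped module, which is what
# self._actual_model(model) does in the hunk above.
print(getattr(wrapped.module.config, "total_flos", 0))  # 0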
@@ -1448,15 +1448,29 @@ class Trainer:
             :obj:`int`: The number of floating-point operations.
         """

-        if isinstance(self.model, torch.nn.DataParallel) or isinstance(
-            self.model, torch.nn.parallel.DistributedDataParallel
-        ):
-            model = self.model.module
-        else:
-            model = self.model
+        model = self._actual_model(self.model)

         if hasattr(model, "floating_point_ops"):
             return model.floating_point_ops(inputs)

         else:
             return 0
+
+    @staticmethod
+    def _actual_model(
+        model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
+    ) -> torch.nn.modules.Module:
+        """
+
+        Args:
+            model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
+                Model object used during training
+
+        Returns:
+            :obj:`torch.nn.modules.Module`: unwrapped module
+        """
+        if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            model = model.module
+        else:
+            model = model
+        return model
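The refactor above centralizes the unwrapping in the new _actual_model helper and leaves floating_point_ops as plain duck typing: any module exposing a floating_point_ops(inputs) method reports a per-step count, everything else falls back to 0. Below is a self-contained sketch of that pattern under invented names (actual_model, CountingModel, and its toy FLO estimate); it mirrors the shape of the diff rather than calling the real Trainer:

from typing import Union

import torch
from torch import nn


def actual_model(model: Union[nn.DataParallel, nn.parallel.DistributedDataParallel, nn.Module]) -> nn.Module:
    """Return the underlying module if model is a (D)DP wrapper, else model itself."""
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model


def floating_point_ops(wrapped_model: nn.Module, inputs: dict) -> int:
    """Same shape as Trainer.floating_point_ops in the diff: unwrap, then duck-type."""
    model = actual_model(wrapped_model)
    if hasattr(model, "floating_point_ops"):
        return model.floating_point_ops(inputs)
    return 0


class CountingModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def floating_point_ops(self, inputs: dict) -> int:
        # toy estimate: 2 * parameter count * tokens in the batch
        n_tokens = inputs["input_ids"].numel()
        return 2 * sum(p.numel() for p in self.parameters()) * n_tokens

    def forward(self, x):
        return self.linear(x)


inputs = {"input_ids": torch.ones(2, 5, dtype=torch.long)}
print(floating_point_ops(CountingModel(), inputs))                   # counted via the method
print(floating_point_ops(nn.DataParallel(CountingModel()), inputs))  # same value, after unwrapping
print(floating_point_ops(nn.Linear(3, 3), inputs))                   # 0 fallback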
@@ -336,3 +336,16 @@ class TrainerIntegrationTest(unittest.TestCase):
         trainer = get_regression_trainer(train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5)
         train_output = trainer.train()
         self.assertEqual(train_output.global_step, int(self.n_epochs))
+
+    def test_flos_extraction(self):
+        trainer = get_regression_trainer(learning_rate=0.1)
+
+        def assert_flos_extraction(trainer, wrapped_model_to_check):
+            self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check))
+            self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0)
+
+        # with plain model
+        assert_flos_extraction(trainer, trainer.model)
+
+        # with enforced DataParallel
+        assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
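The new test can be run on its own with pytest; assuming the integration tests still live in tests/test_trainer.py (the path at the time of this commit), something like:

python -m pytest tests/test_trainer.py -k test_flos_extraction -v

On a CPU-only machine the DataParallel case still exercises the unwrapping, since torch.nn.DataParallel can be constructed without GPUs.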