mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Flos fix (#7384)
This commit is contained in:
parent
ae3e84f3ba
commit
4083a55ab0
@ -695,7 +695,7 @@ class Trainer:
|
||||
# set global_step to global_step of last saved checkpoint from model path
|
||||
try:
|
||||
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
|
||||
self.total_flos = getattr(model.config, "total_flos", 0)
|
||||
self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0)
|
||||
|
||||
epochs_trained = self.global_step // num_update_steps_per_epoch
|
||||
steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
|
||||
@ -1448,15 +1448,29 @@ class Trainer:
|
||||
:obj:`int`: The number of floating-point operations.
|
||||
"""
|
||||
|
||||
if isinstance(self.model, torch.nn.DataParallel) or isinstance(
|
||||
self.model, torch.nn.parallel.DistributedDataParallel
|
||||
):
|
||||
model = self.model.module
|
||||
else:
|
||||
model = self.model
|
||||
model = self._actual_model(self.model)
|
||||
|
||||
if hasattr(model, "floating_point_ops"):
|
||||
return model.floating_point_ops(inputs)
|
||||
|
||||
else:
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _actual_model(
|
||||
model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
|
||||
) -> torch.nn.modules.Module:
|
||||
"""
|
||||
|
||||
Args:
|
||||
model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
|
||||
Model object used during training
|
||||
|
||||
Returns:
|
||||
:obj:`torch.nn.modules.Module`: unwrapped module
|
||||
"""
|
||||
if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
model = model.module
|
||||
else:
|
||||
model = model
|
||||
return model
|
||||
|
@ -336,3 +336,16 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
trainer = get_regression_trainer(train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5)
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, int(self.n_epochs))
|
||||
|
||||
def test_flos_extraction(self):
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
|
||||
def assert_flos_extraction(trainer, wrapped_model_to_check):
|
||||
self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check))
|
||||
self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0)
|
||||
|
||||
# with plain model
|
||||
assert_flos_extraction(trainer, trainer.model)
|
||||
|
||||
# with enforced DataParallel
|
||||
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
|
||||
|
Loading…
Reference in New Issue
Block a user