This commit is contained in:
Marcin Zabłocki 2020-09-28 10:09:26 +02:00 committed by GitHub
parent ae3e84f3ba
commit 4083a55ab0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 7 deletions

View File

@ -695,7 +695,7 @@ class Trainer:
# set global_step to global_step of last saved checkpoint from model path
try:
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
self.total_flos = getattr(model.config, "total_flos", 0)
self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0)
epochs_trained = self.global_step // num_update_steps_per_epoch
steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
@ -1448,15 +1448,29 @@ class Trainer:
:obj:`int`: The number of floating-point operations.
"""
if isinstance(self.model, torch.nn.DataParallel) or isinstance(
self.model, torch.nn.parallel.DistributedDataParallel
):
model = self.model.module
else:
model = self.model
model = self._actual_model(self.model)
if hasattr(model, "floating_point_ops"):
return model.floating_point_ops(inputs)
else:
return 0
@staticmethod
def _actual_model(
model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
) -> torch.nn.modules.Module:
"""
Args:
model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
Model object used during training
Returns:
:obj:`torch.nn.modules.Module`: unwrapped module
"""
if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
model = model.module
else:
model = model
return model

View File

@ -336,3 +336,16 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer = get_regression_trainer(train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5)
train_output = trainer.train()
self.assertEqual(train_output.global_step, int(self.n_epochs))
def test_flos_extraction(self):
trainer = get_regression_trainer(learning_rate=0.1)
def assert_flos_extraction(trainer, wrapped_model_to_check):
self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check))
self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0)
# with plain model
assert_flos_extraction(trainer, trainer.model)
# with enforced DataParallel
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))