Add timing inside Trainer (#9196)

* Add timing inside Trainer

* Fix tests

* Add n_objs for train

* Sort logs
Sylvain Gugger 2020-12-18 15:10:39 -05:00 committed by GitHub
parent 9a25c5bd3a
commit 1198ba8fba
6 changed files with 76 additions and 49 deletions

View File

@ -16,7 +16,6 @@
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from typing import Optional
@@ -120,30 +119,6 @@ class DataTrainingArguments:
     )


-def speed_metrics(split, start_time, num_samples):
-    """
-    Measure and return speed performance metrics.
-
-    This function requires a time snapshot `start_time` before the operation to be measured starts and this
-    function should be run immediately after the operation to be measured has completed.
-
-    Args:
-    - split: one of train, val, test
-    - start_time: operation start time
-    - num_samples: number of samples processed
-    """
-    runtime = time.time() - start_time
-    result = {}
-    samples_per_second = 1 / (runtime / num_samples)
-    result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
-    result[f"{split}_runtime"] = round(runtime, 4)
-    result[f"{split}_n_ojbs"] = num_samples
-    return result
-
-
 def handle_metrics(split, metrics, output_dir):
     """
     Log and save metrics
@@ -155,8 +130,8 @@ def handle_metrics(split, metrics, output_dir):
     """
     logger.info(f"***** {split} metrics *****")
-    for key, value in metrics.items():
-        logger.info(f"  {key} = {value}")
+    for key in sorted(metrics.keys()):
+        logger.info(f"  {key} = {metrics[key]}")

     save_json(metrics, os.path.join(output_dir, f"{split}_results.json"))
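
As a quick illustration of the sorting change (the values below are made up), keys now print in a deterministic order no matter how the metrics dict was built:

    metrics = {"val_runtime": 12.34, "val_loss": 0.98, "val_n_objs": 500}
    for key in sorted(metrics.keys()):
        print(f"  {key} = {metrics[key]}")
    # Prints val_loss = 0.98, then val_n_objs = 500, then val_runtime = 12.34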
@@ -311,11 +286,11 @@ def main():
     if training_args.do_train:
         logger.info("*** Train ***")

-        start_time = time.time()
-        trainer.train(
+        train_result = trainer.train(
             model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
         )
-        metrics = speed_metrics("train", start_time, data_args.n_train)
+        metrics = train_result.metrics
+        metrics["train_n_objs"] = data_args.n_train

         trainer.save_model()  # this also saves the tokenizer
@@ -334,9 +309,8 @@ def main():
     if training_args.do_eval:
         logger.info("*** Evaluate ***")

-        start_time = time.time()
         metrics = trainer.evaluate(metric_key_prefix="val")
-        metrics.update(speed_metrics("val", start_time, data_args.n_val))
+        metrics["val_n_objs"] = data_args.n_val
         metrics["val_loss"] = round(metrics["val_loss"], 4)

         if trainer.is_world_process_zero():
@@ -347,10 +321,9 @@ def main():
     if training_args.do_predict:
         logger.info("*** Predict ***")

-        start_time = time.time()
         test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
         metrics = test_output.metrics
-        metrics.update(speed_metrics("test", start_time, data_args.n_test))
+        metrics["test_n_objs"] = data_args.n_test

         if trainer.is_world_process_zero():
             metrics["test_loss"] = round(metrics["test_loss"], 4)

View File

@@ -97,9 +97,7 @@ class ExamplesTests(TestCasePlus):
         with patch.object(sys, "argv", testargs):
             result = run_glue.main()
-            del result["eval_loss"]
-            for value in result.values():
-                self.assertGreaterEqual(value, 0.75)
+            self.assertGreaterEqual(result["eval_accuracy"], 0.75)

     @require_torch_non_multi_gpu_but_fix_me
     def test_run_clm(self):

View File

@@ -22,6 +22,7 @@ import math
 import os
 import re
 import shutil
+import time
 import warnings
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -89,6 +90,7 @@ from .trainer_utils import (
     default_compute_objective,
     default_hp_space,
     set_seed,
+    speed_metrics,
 )
 from .training_args import TrainingArguments
 from .utils import logging
@@ -707,6 +709,7 @@ class Trainer:
         logger.info(f"  Total optimization steps = {max_steps}")

         self.state.epoch = 0
+        start_time = time.time()
         epochs_trained = 0
         steps_trained_in_current_epoch = 0
@@ -870,15 +873,17 @@ class Trainer:
                 state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
                 self.model.load_state_dict(state_dict)

+        metrics = speed_metrics("train", start_time, self.state.max_steps)
         if self._total_flos is not None:
             self.store_flos()
-            self.log({"total_flos": self.state.total_flos})
+            metrics["total_flos"] = self.state.total_flos
+        self.log(metrics)

         self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
         # add remaining tr_loss
         self._total_loss_scalar += tr_loss.item()

-        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)
+        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)

     def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
         if self.control.should_log:
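
With `metrics` attached to the returned TrainOutput, a caller can read the timing without instrumenting the loop itself. A minimal sketch, assuming a configured Trainer instance:

    train_result = trainer.train()
    print(train_result.global_step, train_result.training_loss)
    print(train_result.metrics["train_runtime"])   # seconds, rounded to 4 places
    print(train_result.metrics.get("total_flos"))  # present when flos were tracked

Note that `speed_metrics` is called here with `self.state.max_steps`, so `train_samples_per_second` counts optimization steps rather than individual examples.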
@@ -1317,6 +1322,7 @@ class Trainer:
             raise ValueError("eval_dataset must implement __len__")

         eval_dataloader = self.get_eval_dataloader(eval_dataset)
+        start_time = time.time()

         output = self.prediction_loop(
             eval_dataloader,
@@ -1328,6 +1334,8 @@ class Trainer:
             metric_key_prefix=metric_key_prefix,
         )

+        n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset)
+        output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
         self.log(output.metrics)

         if self.args.tpu_metrics_debug or self.args.debug:
@@ -1374,10 +1382,13 @@ class Trainer:
             raise ValueError("test_dataset must implement __len__")

         test_dataloader = self.get_test_dataloader(test_dataset)
+        start_time = time.time()

-        return self.prediction_loop(
+        output = self.prediction_loop(
             test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
         )
+        output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))
+        return output

     def prediction_loop(
         self,
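
`evaluate()` and `predict()` now share the same wrap pattern: snapshot the clock, run the prediction loop, then merge `speed_metrics` into the output's metrics. A hedged usage sketch, mirroring the finetune script's prefixes and assuming a configured Trainer and a `test_dataset`:

    eval_metrics = trainer.evaluate(metric_key_prefix="val")
    test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
    print(eval_metrics["val_runtime"], test_output.metrics["test_runtime"])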

View File

@@ -18,6 +18,7 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TF.
 import copy
 import random
+import time
 from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

 import numpy as np
@@ -70,6 +71,7 @@ class PredictionOutput(NamedTuple):
 class TrainOutput(NamedTuple):
     global_step: int
     training_loss: float
+    metrics: Dict[str, float]


 PREFIX_CHECKPOINT_DIR = "checkpoint"
@@ -179,3 +181,23 @@ def total_processes_number(local_rank):
         return torch.distributed.get_world_size()
     return 1
+
+
+def speed_metrics(split, start_time, num_samples=None):
+    """
+    Measure and return speed performance metrics.
+
+    This function requires a time snapshot `start_time` before the operation to be measured starts and this function
+    should be run immediately after the operation to be measured has completed.
+
+    Args:
+    - split: name to prefix metric (like train, eval, test...)
+    - start_time: operation start time
+    - num_samples: number of samples processed
+    """
+    runtime = time.time() - start_time
+    result = {f"{split}_runtime": round(runtime, 4)}
+    if num_samples is not None:
+        samples_per_second = 1 / (runtime / num_samples)
+        result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
+    return result
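
Since `num_samples` is optional, callers that only want wall-clock time get just the `*_runtime` key. A runnable sketch of both branches (the sleep stands in for the measured operation; printed numbers will vary):

    import time

    from transformers.trainer_utils import speed_metrics

    start = time.time()
    time.sleep(0.25)  # stand-in for the operation being measured
    print(speed_metrics("eval", start))                  # roughly {'eval_runtime': 0.2503}
    print(speed_metrics("eval", start, num_samples=64))  # adds 'eval_samples_per_second'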

View File

@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import dataclasses
 import json
 import os
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
@@ -411,7 +410,16 @@ class TrainingArguments:
             self.run_name = self.output_dir

         if is_torch_available() and self.device.type != "cuda" and self.fp16:
-            raise ValueError("AMP (`--fp16`) can only be used on CUDA devices.")
+            raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
+
+    def __repr__(self):
+        # We override the default repr to remove deprecated arguments from the repr. This method should be removed
+        # once those deprecated arguments are removed from TrainingArguments. (TODO: v5)
+        self_as_dict = asdict(self)
+        del self_as_dict["per_gpu_train_batch_size"]
+        del self_as_dict["per_gpu_eval_batch_size"]
+        attrs_as_str = [f"{k}={v}" for k, v in self_as_dict.items()]
+        return f"{self.__class__.__name__}({', '.join(attrs_as_str)})"

     @property
     def train_batch_size(self) -> int:
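
A quick way to check the new `__repr__` (illustrative; assumes `transformers` at this commit with a torch backend installed): the two deprecated per-GPU fields no longer show up in the printed arguments.

    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="out")
    assert "per_gpu_train_batch_size" not in repr(args)
    assert "per_gpu_eval_batch_size" not in repr(args)
    print(args)  # TrainingArguments(output_dir=out, overwrite_output_dir=False, ...)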
@@ -523,7 +531,7 @@ class TrainingArguments:
         """
         Serializes this instance while replacing `Enum` by their values (for JSON serialization support).
         """
-        d = dataclasses.asdict(self)
+        d = asdict(self)
         for k, v in d.items():
             if isinstance(v, Enum):
                 d[k] = v.value

View File

@@ -265,6 +265,21 @@ class TrainerIntegrationTest(unittest.TestCase):
             metrics = trainer.evaluate()
             self.assertEqual(metrics[metric], best_value)

+    def check_trainer_state_are_the_same(self, trainer_state, trainer_state1):
+        # We'll pop things so operate on copies.
+        state = trainer_state.copy()
+        state1 = trainer_state1.copy()
+        # Log history may contain different logs for the time metrics (after resuming a training).
+        log_history = state.pop("log_history", None)
+        log_history1 = state1.pop("log_history", None)
+        self.assertEqual(state, state1)
+        for log, log1 in zip(log_history, log_history1):
+            _ = log.pop("train_runtime", None)
+            _ = log1.pop("train_runtime", None)
+            _ = log.pop("train_samples_per_second", None)
+            _ = log1.pop("train_samples_per_second", None)
+            self.assertEqual(log, log1)
+
     def test_trainer_works_with_dict(self):
         # Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
         # anything.
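
In isolation, the reason for the popping looks like this (pure-Python illustration with made-up numbers): two otherwise identical resumed runs differ only in their wall-clock metrics.

    log = {"loss": 0.5, "train_runtime": 12.01, "train_samples_per_second": 83.3}
    log1 = {"loss": 0.5, "train_runtime": 11.87, "train_samples_per_second": 84.2}
    for d in (log, log1):
        d.pop("train_runtime", None)
        d.pop("train_samples_per_second", None)
    assert log == log1  # equal once the timing keys are removed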
@@ -552,7 +567,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)
             self.assertEqual(b, b1)
-            self.assertEqual(state, state1)
+            self.check_trainer_state_are_the_same(state, state1)

             # Now check with a later checkpoint that it also works when we span over one epoch
             checkpoint = os.path.join(tmpdir, "checkpoint-15")
@@ -566,7 +581,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)
             self.assertEqual(b, b1)
-            self.assertEqual(state, state1)
+            self.check_trainer_state_are_the_same(state, state1)

         # With a regular model that is not a PreTrainedModel
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -590,7 +605,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)
             self.assertEqual(b, b1)
-            self.assertEqual(state, state1)
+            self.check_trainer_state_are_the_same(state, state1)

             # Now check with a later checkpoint that it also works when we span over one epoch
             checkpoint = os.path.join(tmpdir, "checkpoint-15")
@@ -606,7 +621,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)
             self.assertEqual(b, b1)
-            self.assertEqual(state, state1)
+            self.check_trainer_state_are_the_same(state, state1)

     def test_resume_training_with_gradient_accumulation(self):
         if torch.cuda.device_count() > 2:
@@ -638,7 +653,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)
             self.assertEqual(b, b1)
-            self.assertEqual(state, state1)
+            self.check_trainer_state_are_the_same(state, state1)

     def test_load_best_model_at_end(self):
         total = int(self.n_epochs * 64 / self.batch_size)