Exclude torch.compile time from metrics computation (#31443)

* exclude compile time from metrics computation

* fix the quality issue
This commit is contained in:
xiangdong 2024-07-05 14:11:55 +08:00 committed by GitHub
parent 2aa2a14481
commit d19b5a90c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3670,6 +3670,8 @@ class Trainer:
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
output.metrics.update(
speed_metrics(
metric_key_prefix,
@ -3739,6 +3741,8 @@ class Trainer:
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
output.metrics.update(
speed_metrics(
metric_key_prefix,
@ -3777,11 +3781,13 @@ class Trainer:
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
if len(self.accelerator._models) == 0 and model is self.model:
start_time = time.time()
model = (
self.accelerator.prepare(model)
if self.is_deepspeed_enabled
else self.accelerator.prepare_model(model, evaluation_mode=True)
)
self.model_preparation_time = round(time.time() - start_time, 4)
if self.is_fsdp_enabled:
self.model = model
@ -3954,6 +3960,8 @@ class Trainer:
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
if hasattr(self, "jit_compilation_time"):
metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
if hasattr(self, "model_preparation_time"):
metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):