mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Exclude torch.compile time from metrics computation (#31443)
* exclude compile time from metrics computation * fix the quality issue
This commit is contained in:
parent
2aa2a14481
commit
d19b5a90c2
@ -3670,6 +3670,8 @@ class Trainer:
|
||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||
if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
|
||||
start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
|
||||
output.metrics.update(
|
||||
speed_metrics(
|
||||
metric_key_prefix,
|
||||
@ -3739,6 +3741,8 @@ class Trainer:
|
||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||
if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
|
||||
start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
|
||||
output.metrics.update(
|
||||
speed_metrics(
|
||||
metric_key_prefix,
|
||||
@ -3777,11 +3781,13 @@ class Trainer:
|
||||
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
|
||||
|
||||
if len(self.accelerator._models) == 0 and model is self.model:
|
||||
start_time = time.time()
|
||||
model = (
|
||||
self.accelerator.prepare(model)
|
||||
if self.is_deepspeed_enabled
|
||||
else self.accelerator.prepare_model(model, evaluation_mode=True)
|
||||
)
|
||||
self.model_preparation_time = round(time.time() - start_time, 4)
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
self.model = model
|
||||
@ -3954,6 +3960,8 @@ class Trainer:
|
||||
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
|
||||
if hasattr(self, "jit_compilation_time"):
|
||||
metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
|
||||
if hasattr(self, "model_preparation_time"):
|
||||
metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time
|
||||
|
||||
# Prefix all keys with metric_key_prefix + '_'
|
||||
for key in list(metrics.keys()):
|
||||
|
Loading…
Reference in New Issue
Block a user