diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8cd8858312d..4119e547a37 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3670,6 +3670,8 @@ class Trainer:
         total_batch_size = self.args.eval_batch_size * self.args.world_size
         if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
             start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+        if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -3739,6 +3741,8 @@ class Trainer:
         total_batch_size = self.args.eval_batch_size * self.args.world_size
         if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
             start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+        if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -3777,11 +3781,13 @@ class Trainer:
         model = self._wrap_model(self.model, training=False, dataloader=dataloader)

         if len(self.accelerator._models) == 0 and model is self.model:
+            start_time = time.time()
             model = (
                 self.accelerator.prepare(model)
                 if self.is_deepspeed_enabled
                 else self.accelerator.prepare_model(model, evaluation_mode=True)
             )
+            self.model_preparation_time = round(time.time() - start_time, 4)

             if self.is_fsdp_enabled:
                 self.model = model
@@ -3954,6 +3960,8 @@ class Trainer:
             metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
         if hasattr(self, "jit_compilation_time"):
             metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
+        if hasattr(self, "model_preparation_time"):
+            metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time

         # Prefix all keys with metric_key_prefix + '_'
         for key in list(metrics.keys()):
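
For context, here is a minimal standalone sketch (not part of the patch) of the bookkeeping the diff introduces: the one-time model preparation is timed, added to `start_time`, and therefore excluded from the reported runtime and throughput. The `speed_metrics` helper below is a simplified stand-in, assuming the real one in `transformers.trainer_utils` measures runtime as `time.time() - start_time`; all sleep durations and sample counts are illustrative.

    import time

    def speed_metrics(prefix, start_time, num_samples):
        # Simplified stand-in for transformers.trainer_utils.speed_metrics:
        # runtime is measured from start_time, so anything added to
        # start_time (JIT compilation, model preparation) is excluded
        # from the reported runtime and samples/second.
        runtime = time.time() - start_time
        return {
            f"{prefix}_runtime": round(runtime, 4),
            f"{prefix}_samples_per_second": round(num_samples / runtime, 3),
        }

    # Illustrative timeline (values hypothetical):
    start_time = time.time()

    prep_start = time.time()
    time.sleep(0.05)  # pretend accelerator.prepare_model() takes ~50 ms
    model_preparation_time = round(time.time() - prep_start, 4)

    time.sleep(0.10)  # pretend the evaluation loop itself takes ~100 ms

    # The patch shifts start_time forward by the preparation time, so the
    # reported eval_runtime covers only the evaluation loop.
    start_time += model_preparation_time
    metrics = speed_metrics("eval", start_time, num_samples=32)
    metrics["eval_model_preparation_time"] = model_preparation_time
    print(metrics)

Without the `start_time` adjustment, the preparation cost would be folded into `eval_runtime` and deflate `eval_samples_per_second` on the first evaluation; surfacing it as a separate `{prefix}_model_preparation_time` metric keeps the throughput numbers comparable across runs.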