improving efficiency of mlflow metric logging (#14232)

Signed-off-by: Walter Martin <wamartin@microsoft.com>
This commit is contained in:
Walter Martin 2021-11-01 13:46:11 -04:00 committed by GitHub
parent ce91bf9a34
commit 8b32578119
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -726,9 +726,10 @@ class MLflowCallback(TrainerCallback):
if not self._initialized:
self.setup(args, state, model)
if state.is_world_process_zero:
metrics = {}
for k, v in logs.items():
if isinstance(v, (int, float)):
self._ml_flow.log_metric(k, v, step=state.global_step)
metrics[k] = v
else:
logger.warning(
f"Trainer is attempting to log a value of "
@ -736,6 +737,7 @@ class MLflowCallback(TrainerCallback):
f"MLflow's log_metric() only accepts float and "
f"int types so we dropped this attribute."
)
self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)
def on_train_end(self, args, state, control, **kwargs):
if self._initialized and state.is_world_process_zero: