Fix progress callback deepcopy (#32070)

* Replacing ProgressCallback's deepcopy with a shallow copy

* Using items instead of entries

* code cleanup for copy in trainer callback

* Style fix for ProgressCallback
This commit is contained in:
Keith Stevens 2024-07-19 03:56:45 -07:00 committed by GitHub
parent e316c5214f
commit 566b0f1fbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -16,7 +16,6 @@
Callbacks to use with the Trainer class and customize the training loop. Callbacks to use with the Trainer class and customize the training loop.
""" """
import copy
import dataclasses import dataclasses
import json import json
from dataclasses import dataclass from dataclasses import dataclass
@@ -617,13 +616,16 @@ class ProgressCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs): def on_log(self, args, state, control, logs=None, **kwargs):
if state.is_world_process_zero and self.training_bar is not None: if state.is_world_process_zero and self.training_bar is not None:
# avoid modifying the logs object as it is shared between callbacks # make a shallow copy of logs so we can mutate the fields copied
logs = copy.deepcopy(logs) # but avoid doing any value pickling.
_ = logs.pop("total_flos", None) shallow_logs = {}
for k, v in logs.items():
shallow_logs[k] = v
_ = shallow_logs.pop("total_flos", None)
# round numbers so that it looks better in console # round numbers so that it looks better in console
if "epoch" in logs: if "epoch" in shallow_logs:
logs["epoch"] = round(logs["epoch"], 2) shallow_logs["epoch"] = round(shallow_logs["epoch"], 2)
self.training_bar.write(str(logs)) self.training_bar.write(str(shallow_logs))
def on_train_end(self, args, state, control, **kwargs): def on_train_end(self, args, state, control, **kwargs):
if state.is_world_process_zero: if state.is_world_process_zero: