Update notebook.py to support multi eval datasets (#25796)

* Update notebook.py

fix multi eval datasets

* Update notebook.py

* Update notebook.py

using `black` to reformat

* Update notebook.py

support Validation Loss

* Update notebook.py

reformat

* Update notebook.py
This commit is contained in:
Matrix 2023-09-15 23:52:18 +08:00 committed by GitHub
parent c7b4d0b4e2
commit ebd21e904f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -235,13 +235,25 @@ class NotebookTrainingTracker(NotebookProgressBar):
self.inner_table = [list(values.keys()), list(values.values())]
else:
columns = self.inner_table[0]
if len(self.inner_table) == 1:
# We give a chance to update the column names at the first iteration
for key in values.keys():
if key not in columns:
columns.append(key)
self.inner_table[0] = columns
self.inner_table.append([values[c] for c in columns])
for key in values.keys():
if key not in columns:
columns.append(key)
self.inner_table[0] = columns
if len(self.inner_table) > 1:
last_values = self.inner_table[-1]
first_column = self.inner_table[0][0]
if last_values[0] != values[first_column]:
# write new line
self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
else:
# update last line
new_values = values
for c in columns:
if c not in new_values.keys():
new_values[c] = last_values[columns.index(c)]
self.inner_table[-1] = [new_values[c] for c in columns]
else:
self.inner_table.append([values[c] for c in columns])
def add_child(self, total, prefix=None, width=300):
"""
@ -341,12 +353,12 @@ class NotebookProgressCallback(TrainerCallback):
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
_ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
for k, v in metrics.items():
if k == f"{metric_key_prefix}_loss":
values["Validation Loss"] = v
else:
splits = k.split("_")
name = " ".join([part.capitalize() for part in splits[1:]])
values[name] = v
splits = k.split("_")
name = " ".join([part.capitalize() for part in splits[1:]])
if name == "Loss":
# Single dataset
name = "Validation Loss"
values[name] = v
self.training_tracker.write_line(values)
self.training_tracker.remove_child()
self.prediction_bar = None