Update notebook.py to support multi eval datasets (#25796)

* Update notebook.py

fix multi eval datasets

* Update notebook.py

* Update notebook.py

using `black` to reformat

* Update notebook.py

support Validation Loss

* Update notebook.py

reformat

* Update notebook.py
This commit is contained in:
Matrix 2023-09-15 23:52:18 +08:00 committed by GitHub
parent c7b4d0b4e2
commit ebd21e904f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -235,13 +235,25 @@ class NotebookTrainingTracker(NotebookProgressBar):
self.inner_table = [list(values.keys()), list(values.values())] self.inner_table = [list(values.keys()), list(values.values())]
else: else:
columns = self.inner_table[0] columns = self.inner_table[0]
if len(self.inner_table) == 1: for key in values.keys():
# We give a chance to update the column names at the first iteration if key not in columns:
for key in values.keys(): columns.append(key)
if key not in columns: self.inner_table[0] = columns
columns.append(key) if len(self.inner_table) > 1:
self.inner_table[0] = columns last_values = self.inner_table[-1]
self.inner_table.append([values[c] for c in columns]) first_column = self.inner_table[0][0]
if last_values[0] != values[first_column]:
# write new line
self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
else:
# update last line
new_values = values
for c in columns:
if c not in new_values.keys():
new_values[c] = last_values[columns.index(c)]
self.inner_table[-1] = [new_values[c] for c in columns]
else:
self.inner_table.append([values[c] for c in columns])
def add_child(self, total, prefix=None, width=300): def add_child(self, total, prefix=None, width=300):
""" """
@ -341,12 +353,12 @@ class NotebookProgressCallback(TrainerCallback):
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None) _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
_ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None) _ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
for k, v in metrics.items(): for k, v in metrics.items():
if k == f"{metric_key_prefix}_loss": splits = k.split("_")
values["Validation Loss"] = v name = " ".join([part.capitalize() for part in splits[1:]])
else: if name == "Loss":
splits = k.split("_") # Single dataset
name = " ".join([part.capitalize() for part in splits[1:]]) name = "Validation Loss"
values[name] = v values[name] = v
self.training_tracker.write_line(values) self.training_tracker.write_line(values)
self.training_tracker.remove_child() self.training_tracker.remove_child()
self.prediction_bar = None self.prediction_bar = None