mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Update notebook.py to support multi eval datasets (#25796)
* Update notebook.py fix multi eval datasets * Update notebook.py * Update notebook.py using `black` to reformat * Update notebook.py support Validation Loss * Update notebook.py reformat * Update notebook.py
This commit is contained in:
parent
c7b4d0b4e2
commit
ebd21e904f
@ -235,13 +235,25 @@ class NotebookTrainingTracker(NotebookProgressBar):
|
||||
self.inner_table = [list(values.keys()), list(values.values())]
|
||||
else:
|
||||
columns = self.inner_table[0]
|
||||
if len(self.inner_table) == 1:
|
||||
# We give a chance to update the column names at the first iteration
|
||||
for key in values.keys():
|
||||
if key not in columns:
|
||||
columns.append(key)
|
||||
self.inner_table[0] = columns
|
||||
self.inner_table.append([values[c] for c in columns])
|
||||
for key in values.keys():
|
||||
if key not in columns:
|
||||
columns.append(key)
|
||||
self.inner_table[0] = columns
|
||||
if len(self.inner_table) > 1:
|
||||
last_values = self.inner_table[-1]
|
||||
first_column = self.inner_table[0][0]
|
||||
if last_values[0] != values[first_column]:
|
||||
# write new line
|
||||
self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
|
||||
else:
|
||||
# update last line
|
||||
new_values = values
|
||||
for c in columns:
|
||||
if c not in new_values.keys():
|
||||
new_values[c] = last_values[columns.index(c)]
|
||||
self.inner_table[-1] = [new_values[c] for c in columns]
|
||||
else:
|
||||
self.inner_table.append([values[c] for c in columns])
|
||||
|
||||
def add_child(self, total, prefix=None, width=300):
|
||||
"""
|
||||
@ -341,12 +353,12 @@ class NotebookProgressCallback(TrainerCallback):
|
||||
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
|
||||
_ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
|
||||
for k, v in metrics.items():
|
||||
if k == f"{metric_key_prefix}_loss":
|
||||
values["Validation Loss"] = v
|
||||
else:
|
||||
splits = k.split("_")
|
||||
name = " ".join([part.capitalize() for part in splits[1:]])
|
||||
values[name] = v
|
||||
splits = k.split("_")
|
||||
name = " ".join([part.capitalize() for part in splits[1:]])
|
||||
if name == "Loss":
|
||||
# Single dataset
|
||||
name = "Validation Loss"
|
||||
values[name] = v
|
||||
self.training_tracker.write_line(values)
|
||||
self.training_tracker.remove_child()
|
||||
self.prediction_bar = None
|
||||
|
Loading…
Reference in New Issue
Block a user