mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Update notebook.py to support multi eval datasets (#25796)
* Update notebook.py fix multi eval datasets * Update notebook.py * Update notebook.py using `black` to reformat * Update notebook.py support Validation Loss * Update notebook.py reformat * Update notebook.py
This commit is contained in:
parent
c7b4d0b4e2
commit
ebd21e904f
@ -235,13 +235,25 @@ class NotebookTrainingTracker(NotebookProgressBar):
|
|||||||
self.inner_table = [list(values.keys()), list(values.values())]
|
self.inner_table = [list(values.keys()), list(values.values())]
|
||||||
else:
|
else:
|
||||||
columns = self.inner_table[0]
|
columns = self.inner_table[0]
|
||||||
if len(self.inner_table) == 1:
|
for key in values.keys():
|
||||||
# We give a chance to update the column names at the first iteration
|
if key not in columns:
|
||||||
for key in values.keys():
|
columns.append(key)
|
||||||
if key not in columns:
|
self.inner_table[0] = columns
|
||||||
columns.append(key)
|
if len(self.inner_table) > 1:
|
||||||
self.inner_table[0] = columns
|
last_values = self.inner_table[-1]
|
||||||
self.inner_table.append([values[c] for c in columns])
|
first_column = self.inner_table[0][0]
|
||||||
|
if last_values[0] != values[first_column]:
|
||||||
|
# write new line
|
||||||
|
self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
|
||||||
|
else:
|
||||||
|
# update last line
|
||||||
|
new_values = values
|
||||||
|
for c in columns:
|
||||||
|
if c not in new_values.keys():
|
||||||
|
new_values[c] = last_values[columns.index(c)]
|
||||||
|
self.inner_table[-1] = [new_values[c] for c in columns]
|
||||||
|
else:
|
||||||
|
self.inner_table.append([values[c] for c in columns])
|
||||||
|
|
||||||
def add_child(self, total, prefix=None, width=300):
|
def add_child(self, total, prefix=None, width=300):
|
||||||
"""
|
"""
|
||||||
@ -341,12 +353,12 @@ class NotebookProgressCallback(TrainerCallback):
|
|||||||
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
|
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
|
||||||
_ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
|
_ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
|
||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
if k == f"{metric_key_prefix}_loss":
|
splits = k.split("_")
|
||||||
values["Validation Loss"] = v
|
name = " ".join([part.capitalize() for part in splits[1:]])
|
||||||
else:
|
if name == "Loss":
|
||||||
splits = k.split("_")
|
# Single dataset
|
||||||
name = " ".join([part.capitalize() for part in splits[1:]])
|
name = "Validation Loss"
|
||||||
values[name] = v
|
values[name] = v
|
||||||
self.training_tracker.write_line(values)
|
self.training_tracker.write_line(values)
|
||||||
self.training_tracker.remove_child()
|
self.training_tracker.remove_child()
|
||||||
self.prediction_bar = None
|
self.prediction_bar = None
|
||||||
|
Loading…
Reference in New Issue
Block a user