mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-19 12:38:23 +06:00
Add resume checkpoint support to ClearML callback #37502
This commit is contained in:
parent
861917173f
commit
82087b5722
@ -1,15 +1,17 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from clearml import Task
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
TrainingArguments,
|
|
||||||
Trainer,
|
|
||||||
DataCollatorWithPadding,
|
DataCollatorWithPadding,
|
||||||
|
Trainer,
|
||||||
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
from datasets import load_dataset
|
|
||||||
from clearml import Task
|
|
||||||
import torch
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
os.environ["CLEARML_PROJECT"] = "Test Project"
|
os.environ["CLEARML_PROJECT"] = "Test Project"
|
||||||
@ -17,11 +19,7 @@ os.environ["CLEARML_TASK"] = "Test Task"
|
|||||||
os.environ["CLEARML_LOG_MODEL"] = "TRUE"
|
os.environ["CLEARML_LOG_MODEL"] = "TRUE"
|
||||||
|
|
||||||
# Initialize ClearML task
|
# Initialize ClearML task
|
||||||
task = Task.init(
|
task = Task.init(project_name="Test Project", task_name="Test Task", reuse_last_task_id=False)
|
||||||
project_name="Test Project",
|
|
||||||
task_name="Test Task",
|
|
||||||
reuse_last_task_id=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load model and tokenizer
|
# Load model and tokenizer
|
||||||
model_name = "bert-base-uncased"
|
model_name = "bert-base-uncased"
|
||||||
@ -95,5 +93,5 @@ trainer.train()
|
|||||||
initial_params = {name: param.data.clone() for name, param in model.named_parameters()}
|
initial_params = {name: param.data.clone() for name, param in model.named_parameters()}
|
||||||
trainer.train()
|
trainer.train()
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if 'weight' in name:
|
if "weight" in name:
|
||||||
diff = torch.abs(param.data - initial_params[name]).mean().item()
|
diff = torch.abs(param.data - initial_params[name]).mean().item()
|
Loading…
Reference in New Issue
Block a user