Mirror of https://github.com/huggingface/transformers.git
Continue training args and tqdm in notebooks (#3939)
* Continue training args
* Added explanation
* Fixed tqdm auto
* Update src/transformers/training_args.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
parent ab90353f1a
commit 8b5e5ebcf9
src/transformers/trainer.py
@@ -15,7 +15,7 @@ from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.dataset import Dataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler
-from tqdm import tqdm, trange
+from tqdm.auto import tqdm, trange

 from .data.data_collator import DataCollator, DefaultDataCollator
 from .modeling_utils import PreTrainedModel
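Context for the import swap above: tqdm.auto dispatches to the widget-based tqdm.notebook when running under Jupyter/IPython and falls back to the plain console bar otherwise, so the Trainer's progress bars render correctly in notebooks. A minimal sketch of the behavior (illustrative, not code from this commit):

from tqdm.auto import tqdm, trange  # notebook widget under Jupyter, console bar elsewhere

for epoch in trange(3, desc="Epoch"):
    for step in tqdm(range(100), desc="Iteration", leave=False):
        pass  # one training step would run here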
src/transformers/training_args.py
@@ -29,20 +29,27 @@ class TrainingArguments:
         metadata={"help": "The output directory where the model predictions and checkpoints will be written."}
     )
     overwrite_output_dir: bool = field(
-        default=False, metadata={"help": "Overwrite the content of the output directory"}
+        default=False,
+        metadata={
+            "help": (
+                "Overwrite the content of the output directory."
+                "Use this to continue training if output_dir points to a checkpoint directory."
+            )
+        },
     )

     do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
     do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
     do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
     evaluate_during_training: bool = field(
-        default=False, metadata={"help": "Run evaluation during training at each logging step."}
+        default=False, metadata={"help": "Run evaluation during training at each logging step."},
     )

     per_gpu_train_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for training."})
     per_gpu_eval_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for evaluation."})
     gradient_accumulation_steps: int = field(
-        default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}
+        default=1,
+        metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},
     )

     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."})
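The overwrite_output_dir reword is the "continue training" half of this PR: pointing output_dir at a checkpoint directory and opting in to overwriting lets a run resume instead of aborting on a non-empty directory. A hypothetical usage sketch (the checkpoint path is illustrative):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./output/checkpoint-500",  # hypothetical existing checkpoint directory
    overwrite_output_dir=True,             # reuse the non-empty directory to continue training
    do_train=True,
    per_gpu_train_batch_size=8,
    gradient_accumulation_steps=2,         # 2 forward passes accumulated per optimizer update
    learning_rate=5e-5,
)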
@@ -64,7 +71,10 @@ class TrainingArguments:
     save_total_limit: Optional[int] = field(
         default=None,
         metadata={
-            "help": "Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default"
+            "help": (
+                "Limit the total amount of checkpoints."
+                "Deletes the older checkpoints in the output_dir. Default is unlimited checkpoints"
+            )
         },
     )
     no_cuda: bool = field(default=False, metadata={"help": "Avoid using CUDA even if it is available"})
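For readers unfamiliar with save_total_limit, a rough sketch of the rotation the help text describes (an assumption about the intended behavior, not the Trainer's actual implementation): keep the most recent checkpoint-* directories, delete the rest, and do nothing when the limit is None.

import os
import re
import shutil
from typing import Optional

def rotate_checkpoints(output_dir: str, save_total_limit: Optional[int]) -> None:
    """Keep at most `save_total_limit` checkpoint-<step> directories,
    deleting the oldest first; None (the default) deletes nothing."""
    if save_total_limit is None or save_total_limit <= 0:
        return
    checkpoints = []
    for name in os.listdir(output_dir):
        match = re.fullmatch(r"checkpoint-(\d+)", name)
        path = os.path.join(output_dir, name)
        if match and os.path.isdir(path):
            checkpoints.append((int(match.group(1)), path))
    checkpoints.sort()  # oldest (lowest global step) first
    for _, path in checkpoints[: max(0, len(checkpoints) - save_total_limit)]:
        shutil.rmtree(path)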
@@ -77,8 +87,10 @@ class TrainingArguments:
     fp16_opt_level: str = field(
         default="O1",
         metadata={
-            "help": "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
-            "See details at https://nvidia.github.io/apex/amp.html"
+            "help": (
+                "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+                "See details at https://nvidia.github.io/apex/amp.html"
+            )
         },
     )
     local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
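For context on fp16_opt_level, a minimal Apex AMP sketch (assumes NVIDIA Apex and a CUDA device are available; the model and optimizer below are placeholders): "O0" is pure fp32, "O1" patches selected ops to fp16, "O2" casts almost the whole model, "O3" is pure fp16.

import torch
from apex import amp  # NVIDIA Apex, see https://nvidia.github.io/apex/amp.html

model = torch.nn.Linear(16, 2).cuda()                      # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)  # matches the learning_rate default above

# opt_level corresponds to fp16_opt_level in TrainingArguments.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(4, 16).cuda()).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:  # scales the loss to avoid fp16 underflow
    scaled_loss.backward()
optimizer.step()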