Revert frozen training arguments (#25903)
* Revert frozen training arguments
* TODO
This commit is contained in:
parent 69c5b8f186
commit be0e189bd3
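An earlier change had made `TrainingArguments` effectively immutable once `__post_init__` finished, by setting a private `_frozen` flag and raising `FrozenInstanceError` from `__setattr__`. This commit reverts that across the library, the example scripts, and the tests, so arguments can again be mutated in place after construction. A minimal sketch of what the revert restores (hypothetical values, not taken from the diff):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", learning_rate=1e-4)
# With the frozen behaviour this assignment raised dataclasses.FrozenInstanceError;
# after the revert it is ordinary attribute assignment again.
args.learning_rate = 3e-4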
@@ -163,15 +163,6 @@ class CustomTrainingArguments(TrainingArguments):
         default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."}
     )
 
-    def __post_init__(self):
-        # Compute absolute learning rate while args are mutable
-        super().__post_init__()
-        if self.base_learning_rate is not None:
-            total_train_batch_size = self.train_batch_size * self.gradient_accumulation_steps * self.world_size
-            delattr(self, "_frozen")
-            self.learning_rate = self.base_learning_rate * total_train_batch_size / 256
-            setattr(self, "_frozen", True)
-
 
 def collate_fn(examples):
     pixel_values = torch.stack([example["pixel_values"] for example in examples])
@@ -362,6 +353,13 @@ def main():
         # Set the validation transforms
         ds["validation"].set_transform(preprocess_images)
 
+    # Compute absolute learning rate
+    total_train_batch_size = (
+        training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+    )
+    if training_args.base_learning_rate is not None:
+        training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256
+
     # Initialize our trainer
     trainer = Trainer(
         model=model,
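The block restored in `main()` implements the scaling rule spelled out in the argument's help string, absolute_lr = base_lr * total_batch_size / 256, where the total batch size folds in gradient accumulation and the number of processes. A worked example with hypothetical values:

base_learning_rate = 1e-3
train_batch_size = 32            # per-process batch size (hypothetical)
gradient_accumulation_steps = 4  # hypothetical
world_size = 2                   # hypothetical
total_train_batch_size = train_batch_size * gradient_accumulation_steps * world_size  # 256
learning_rate = base_learning_rate * total_train_batch_size / 256                     # 1e-3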
@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence.
 """
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
-import dataclasses
 import logging
 import os
 import sys
@@ -675,10 +674,14 @@ def main():
         return result
 
     # Override the decoding parameters of Seq2SeqTrainer
-    if training_args.generation_max_length is None:
-        training_args = dataclasses.replace(training_args, generation_max_length=data_args.val_max_target_length)
-    if training_args.generation_num_beams is None:
-        training_args = dataclasses.replace(training_args, generation_num_beams=data_args.num_beams)
+    training_args.generation_max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    training_args.generation_num_beams = (
+        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
+    )
 
     # Initialize our Trainer
     trainer = Seq2SeqTrainer(
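Under the frozen scheme the script could not assign to `training_args`, so each override rebuilt the whole object with `dataclasses.replace` and rebound the name; the revert goes back to plain in-place assignment with the same precedence: an explicit `generation_max_length` wins over `val_max_target_length`, and `num_beams` from the data arguments overrides the training-arguments default. A small sketch of the difference, assuming `training_args` is an existing arguments instance and 128 is a hypothetical value:

import dataclasses

# dataclasses.replace builds a *new* arguments object; the original is untouched,
# which is why the frozen version had to rebind the `training_args` name.
training_args = dataclasses.replace(training_args, generation_max_length=128)

# In-place mutation, restored by this revert, changes the existing object directly.
training_args.generation_max_length = 128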
@@ -21,7 +21,6 @@ https://huggingface.co/models?filter=fill-mask
 """
 # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
 
-import dataclasses
 import json
 import logging
 import math
@@ -367,7 +366,7 @@ def main():
     # If we have ref files, need to avoid it removed by trainer
     has_ref = data_args.train_ref_file or data_args.validation_ref_file
     if has_ref:
-        training_args = dataclasses.replace(training_args, remove_unused_columns=False)
+        training_args.remove_unused_columns = False
 
     # Data collator
     # This one will take care of randomly masking the tokens.
@@ -259,6 +259,7 @@ def main():
             assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
 
     if training_args.output_dir is not None:
+        training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)
     # endregion
 
@@ -266,8 +267,8 @@ def main():
     # Detecting last checkpoint.
     checkpoint = None
     if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
-        config_path = Path(training_args.output_dir) / CONFIG_NAME
-        weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
+        config_path = training_args.output_dir / CONFIG_NAME
+        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
         if config_path.is_file() and weights_path.is_file():
             checkpoint = training_args.output_dir
             logger.info(
@@ -265,6 +265,7 @@ def main():
             assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
 
     if training_args.output_dir is not None:
+        training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)
 
     if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:
@@ -276,8 +277,8 @@ def main():
     # Detecting last checkpoint.
     checkpoint = None
     if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
-        config_path = Path(training_args.output_dir) / CONFIG_NAME
-        weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
+        config_path = training_args.output_dir / CONFIG_NAME
+        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
         if config_path.is_file() and weights_path.is_file():
             checkpoint = training_args.output_dir
             logger.warning(
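With mutable arguments back, the TensorFlow example scripts convert `output_dir` to a `pathlib.Path` once, right after checking it is set, and the later checkpoint detection can use the `/` operator on it directly instead of re-wrapping the string each time. A minimal standalone sketch of the idea (the file names are assumed to match the library's `CONFIG_NAME` and `TF2_WEIGHTS_NAME` constants):

from pathlib import Path

output_dir = Path("out")                   # training_args.output_dir after the conversion
config_path = output_dir / "config.json"   # CONFIG_NAME
weights_path = output_dir / "tf_model.h5"  # TF2_WEIGHTS_NAME
has_checkpoint = config_path.is_file() and weights_path.is_file()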
@@ -1172,8 +1172,6 @@ class Trainer:
         elif self.hp_search_backend == HPSearchBackend.WANDB:
             params = trial
 
-        # Unfreeze args for hyperparameter search
-        delattr(self.args, "_frozen")
         for key, value in params.items():
             if not hasattr(self.args, key):
                 logger.warning(
@@ -1205,8 +1203,6 @@ class Trainer:
             self.args.hf_deepspeed_config.trainer_config_process(self.args)
             self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
 
-        # Re-freeze them
-        setattr(self.args, "_frozen", True)
         self.create_accelerator_and_postprocess()
 
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
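Because the arguments are mutable again, the `Trainer` no longer needs to delete the private `_frozen` flag before hyperparameter search and restore it afterwards; the sampled trial values can be assigned straight onto `self.args`. A simplified sketch of that idea (a hypothetical helper, not the verbatim Trainer code):

def apply_trial_params(args, params, logger):
    # With mutable TrainingArguments, trial values are set directly on the existing object.
    for key, value in params.items():
        if not hasattr(args, key):
            logger.warning(f"No field `{key}` on TrainingArguments; skipping.")
            continue
        setattr(args, key, value)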
@@ -18,7 +18,7 @@ import json
 import math
 import os
 import warnings
-from dataclasses import FrozenInstanceError, asdict, dataclass, field, fields
+from dataclasses import asdict, dataclass, field, fields
 from datetime import timedelta
 from enum import Enum
 from pathlib import Path
@@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
     PAGED_LION_8BIT = "paged_lion_8bit"
 
 
+# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
 @dataclass
 class TrainingArguments:
     """
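The TODO added above the class points at a possible middle ground: keep `TrainingArguments` mutable, but guard the few fields that really must not change after initialization. One purely hypothetical way such a guard could look (illustrative only, not what the library does):

PROTECTED_FIELDS = {"output_dir"}  # hypothetical whitelist of guarded fields

class GuardedArguments:
    def __setattr__(self, name, value):
        # Refuse to reassign a guarded field once it has been set; everything else stays mutable.
        if name in PROTECTED_FIELDS and name in self.__dict__:
            raise AttributeError(f"cannot reassign field {name}")
        super().__setattr__(name, value)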
@@ -1707,16 +1708,6 @@ class TrainingArguments:
                 FutureWarning,
             )
 
-        # Finally set the `TrainingArguments` to be immutable
-        self._frozen = True
-
-    def __setattr__(self, name, value):
-        # Once fully through the `__post_init__`, `TrainingArguments` are immutable
-        if not name.startswith("_") and getattr(self, "_frozen", False):
-            raise FrozenInstanceError(f"cannot assign to field {name}")
-        else:
-            super().__setattr__(name, value)
-
     def __str__(self):
         self_as_dict = asdict(self)
 
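For reference, the mechanism deleted here is the usual trick for freezing a non-frozen dataclass after construction: set a private flag at the very end of `__post_init__`, and have `__setattr__` reject public assignments while the flag is present. A self-contained sketch of the pattern (illustrative, not the full class):

from dataclasses import FrozenInstanceError, dataclass

@dataclass
class Config:
    learning_rate: float = 1e-4

    def __post_init__(self):
        self._frozen = True  # set last; private names are still assignable afterwards

    def __setattr__(self, name, value):
        if not name.startswith("_") and getattr(self, "_frozen", False):
            raise FrozenInstanceError(f"cannot assign to field {name}")
        super().__setattr__(name, value)

cfg = Config()
# cfg.learning_rate = 2e-4 would now raise FrozenInstanceError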
@@ -139,9 +139,9 @@ class RegressionTrainingArguments(TrainingArguments):
     b: float = 0.0
 
     def __post_init__(self):
+        super().__post_init__()
         # save resources not dealing with reporting (also avoids the warning when it's not set)
         self.report_to = []
-        super().__post_init__()
 
 
 class RepeatDataset:
@@ -529,8 +529,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.check_trained_model(trainer.model)
 
         # Re-training should restart from scratch, thus lead the same results and new seed should be used.
-        args = TrainingArguments("./regression", learning_rate=0.1, seed=314)
-        trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
+        trainer.args.seed = 314
         trainer.train()
         self.check_trained_model(trainer.model, alternate_seed=True)
 
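The test now keeps the same `Trainer` and only flips the seed on its existing arguments before training again; because the trainer was built with `model_init`, calling `train()` still starts from a freshly initialized model, so the re-training check is unchanged. In short, as in the diff:

trainer.args.seed = 314  # mutate the existing arguments in place
trainer.train()          # model_init() provides a fresh model, trained with the new seed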
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import dataclasses
 from typing import Dict
 
 import numpy as np
@@ -206,14 +205,7 @@ if __name__ == "__main__":
        logger.error(p.metrics)
        exit(1)
 
-    training_args = dataclasses.replace(training_args, eval_accumulation_steps=2)
-    trainer = Trainer(
-        model=DummyModel(),
-        args=training_args,
-        data_collator=DummyDataCollator(),
-        eval_dataset=dataset,
-        compute_metrics=compute_metrics,
-    )
+    trainer.args.eval_accumulation_steps = 2
 
    metrics = trainer.evaluate()
    logger.info(metrics)
@@ -227,22 +219,15 @@ if __name__ == "__main__":
        logger.error(p.metrics)
        exit(1)
 
-    training_args = dataclasses.replace(training_args, eval_accumulation_steps=None)
-    trainer = Trainer(
-        model=DummyModel(),
-        args=training_args,
-        data_collator=DummyDataCollator(),
-        eval_dataset=dataset,
-        compute_metrics=compute_metrics,
-    )
+    trainer.args.eval_accumulation_steps = None
 
    # Check that `dispatch_batches=False` will work on a finite iterable dataset
 
    train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
 
    model = RegressionModel()
-    training_args = dataclasses.replace(
-        training_args, per_device_train_batch_size=1, max_steps=1, dispatch_batches=False
-    )
+    training_args.per_device_train_batch_size = 1
+    training_args.max_steps = 1
+    training_args.dispatch_batches = False
    trainer = Trainer(model, training_args, train_dataset=train_dataset)
    trainer.train()