diff --git a/examples/pytorch/image-pretraining/run_mae.py b/examples/pytorch/image-pretraining/run_mae.py
index 4483a654439..159729f5a59 100644
--- a/examples/pytorch/image-pretraining/run_mae.py
+++ b/examples/pytorch/image-pretraining/run_mae.py
@@ -163,15 +163,6 @@ class CustomTrainingArguments(TrainingArguments):
         default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."}
     )
 
-    def __post_init__(self):
-        # Compute absolute learning rate while args are mutable
-        super().__post_init__()
-        if self.base_learning_rate is not None:
-            total_train_batch_size = self.train_batch_size * self.gradient_accumulation_steps * self.world_size
-            delattr(self, "_frozen")
-            self.learning_rate = self.base_learning_rate * total_train_batch_size / 256
-            setattr(self, "_frozen", True)
-
 
 def collate_fn(examples):
     pixel_values = torch.stack([example["pixel_values"] for example in examples])
@@ -362,6 +353,13 @@ def main():
         # Set the validation transforms
         ds["validation"].set_transform(preprocess_images)
 
+    # Compute absolute learning rate
+    total_train_batch_size = (
+        training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+    )
+    if training_args.base_learning_rate is not None:
+        training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256
+
     # Initialize our trainer
     trainer = Trainer(
         model=model,
diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py
index 8f135557766..f74f754a7c6 100755
--- a/examples/pytorch/summarization/run_summarization.py
+++ b/examples/pytorch/summarization/run_summarization.py
@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence.
 """
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
-import dataclasses
 import logging
 import os
 import sys
@@ -675,10 +674,14 @@ def main():
         return result
 
     # Override the decoding parameters of Seq2SeqTrainer
-    if training_args.generation_max_length is None:
-        training_args = dataclasses.replace(training_args, generation_max_length=data_args.val_max_target_length)
-    if training_args.generation_num_beams is None:
-        training_args = dataclasses.replace(training_args, generation_num_beams=data_args.num_beams)
+    training_args.generation_max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    training_args.generation_num_beams = (
+        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
+    )
 
     # Initialize our Trainer
     trainer = Seq2SeqTrainer(
diff --git a/examples/research_projects/mlm_wwm/run_mlm_wwm.py b/examples/research_projects/mlm_wwm/run_mlm_wwm.py
index 4bb138de832..f14ad5adfef 100644
--- a/examples/research_projects/mlm_wwm/run_mlm_wwm.py
+++ b/examples/research_projects/mlm_wwm/run_mlm_wwm.py
@@ -21,7 +21,6 @@ https://huggingface.co/models?filter=fill-mask
 """
 # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
 
-import dataclasses
 import json
 import logging
 import math
@@ -367,7 +366,7 @@ def main():
     # If we have ref files, need to avoid it removed by trainer
     has_ref = data_args.train_ref_file or data_args.validation_ref_file
     if has_ref:
-        training_args = dataclasses.replace(training_args, remove_unused_columns=False)
+        training_args.remove_unused_columns = False
 
     # Data collator
     # This one will take care of randomly masking the tokens.
diff --git a/examples/tensorflow/language-modeling/run_clm.py b/examples/tensorflow/language-modeling/run_clm.py
index 033baf59170..1614bbd4b12 100755
--- a/examples/tensorflow/language-modeling/run_clm.py
+++ b/examples/tensorflow/language-modeling/run_clm.py
@@ -259,6 +259,7 @@ def main():
         assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
 
     if training_args.output_dir is not None:
+        training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)
     # endregion
 
@@ -266,8 +267,8 @@ def main():
     # Detecting last checkpoint.
     checkpoint = None
     if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
-        config_path = Path(training_args.output_dir) / CONFIG_NAME
-        weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
+        config_path = training_args.output_dir / CONFIG_NAME
+        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
         if config_path.is_file() and weights_path.is_file():
             checkpoint = training_args.output_dir
             logger.info(
diff --git a/examples/tensorflow/language-modeling/run_mlm.py b/examples/tensorflow/language-modeling/run_mlm.py
index 7423817f584..671331745de 100755
--- a/examples/tensorflow/language-modeling/run_mlm.py
+++ b/examples/tensorflow/language-modeling/run_mlm.py
@@ -265,6 +265,7 @@ def main():
         assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
 
     if training_args.output_dir is not None:
+        training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)
 
     if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:
@@ -276,8 +277,8 @@ def main():
     # Detecting last checkpoint.
     checkpoint = None
     if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
-        config_path = Path(training_args.output_dir) / CONFIG_NAME
-        weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
+        config_path = training_args.output_dir / CONFIG_NAME
+        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
         if config_path.is_file() and weights_path.is_file():
             checkpoint = training_args.output_dir
             logger.warning(
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index deb788e1c70..45e858feca1 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1172,8 +1172,6 @@ class Trainer:
             elif self.hp_search_backend == HPSearchBackend.WANDB:
                 params = trial
 
-        # Unfreeze args for hyperparameter search
-        delattr(self.args, "_frozen")
         for key, value in params.items():
             if not hasattr(self.args, key):
                 logger.warning(
@@ -1205,8 +1203,6 @@
                 self.args.hf_deepspeed_config.trainer_config_process(self.args)
                 self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
 
-        # Re-freeze them
-        setattr(self.args, "_frozen", True)
         self.create_accelerator_and_postprocess()
 
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 8549739deb1..68458a64b0e 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -18,7 +18,7 @@ import json
 import math
 import os
 import warnings
-from dataclasses import FrozenInstanceError, asdict, dataclass, field, fields
+from dataclasses import asdict, dataclass, field, fields
 from datetime import timedelta
 from enum import Enum
 from pathlib import Path
@@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
     PAGED_LION_8BIT = "paged_lion_8bit"
 
 
+# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
 @dataclass
 class TrainingArguments:
     """
@@ -1707,16 +1708,6 @@ class TrainingArguments:
                 FutureWarning,
             )
 
-        # Finally set the `TrainingArguments` to be immutable
-        self._frozen = True
-
-    def __setattr__(self, name, value):
-        # Once fully through the `__post_init__`, `TrainingArguments` are immutable
-        if not name.startswith("_") and getattr(self, "_frozen", False):
-            raise FrozenInstanceError(f"cannot assign to field {name}")
-        else:
-            super().__setattr__(name, value)
-
     def __str__(self):
         self_as_dict = asdict(self)
 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 152fab898cc..db7f9bb20c9 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -139,9 +139,9 @@ class RegressionTrainingArguments(TrainingArguments):
     b: float = 0.0
 
     def __post_init__(self):
+        super().__post_init__()
         # save resources not dealing with reporting (also avoids the warning when it's not set)
         self.report_to = []
-        super().__post_init__()
 
 
 class RepeatDataset:
@@ -529,8 +529,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.check_trained_model(trainer.model)
 
         # Re-training should restart from scratch, thus lead the same results and new seed should be used.
-        args = TrainingArguments("./regression", learning_rate=0.1, seed=314)
-        trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
+        trainer.args.seed = 314
         trainer.train()
         self.check_trained_model(trainer.model, alternate_seed=True)
 
diff --git a/tests/trainer/test_trainer_distributed.py b/tests/trainer/test_trainer_distributed.py
index f8b59d967c7..5a7734b8ba1 100644
--- a/tests/trainer/test_trainer_distributed.py
+++ b/tests/trainer/test_trainer_distributed.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import dataclasses
 from typing import Dict
 
 import numpy as np
@@ -206,14 +205,7 @@ if __name__ == "__main__":
             logger.error(p.metrics)
             exit(1)
 
-    training_args = dataclasses.replace(training_args, eval_accumulation_steps=2)
-    trainer = Trainer(
-        model=DummyModel(),
-        args=training_args,
-        data_collator=DummyDataCollator(),
-        eval_dataset=dataset,
-        compute_metrics=compute_metrics,
-    )
+    trainer.args.eval_accumulation_steps = 2
 
     metrics = trainer.evaluate()
     logger.info(metrics)
@@ -227,22 +219,15 @@ if __name__ == "__main__":
             logger.error(p.metrics)
             exit(1)
 
-    training_args = dataclasses.replace(training_args, eval_accumulation_steps=None)
-    trainer = Trainer(
-        model=DummyModel(),
-        args=training_args,
-        data_collator=DummyDataCollator(),
-        eval_dataset=dataset,
-        compute_metrics=compute_metrics,
-    )
+    trainer.args.eval_accumulation_steps = None
 
     # Check that `dispatch_batches=False` will work on a finite iterable dataset
 
     train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
     model = RegressionModel()
-    training_args = dataclasses.replace(
-        training_args, per_device_train_batch_size=1, max_steps=1, dispatch_batches=False
-    )
+    training_args.per_device_train_batch_size = 1
+    training_args.max_steps = 1
+    training_args.dispatch_batches = False
     trainer = Trainer(model, training_args, train_dataset=train_dataset)
     trainer.train()
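Note (not part of the patch): a minimal sketch of the usage pattern this change restores. Once `TrainingArguments.__post_init__` no longer freezes the instance, callers can assign fields in place instead of rebuilding the object with `dataclasses.replace`, which is exactly what the updated scripts and tests above do. The `output_dir` value below is hypothetical and used only for illustration.

    from transformers import TrainingArguments

    # Hypothetical output directory, purely for illustration.
    args = TrainingArguments(output_dir="tmp_example")

    # With the freeze removed, plain attribute assignment works again
    # (previously this raised FrozenInstanceError after __post_init__).
    args.remove_unused_columns = False
    args.eval_accumulation_steps = 2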