Revert frozen training arguments (#25903)

* Revert frozen training arguments

* TODO
Zach Mueller 2023-09-01 11:24:12 -04:00 committed by GitHub
parent 69c5b8f186
commit be0e189bd3
9 changed files with 31 additions and 58 deletions


@@ -163,15 +163,6 @@ class CustomTrainingArguments(TrainingArguments):
         default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."}
     )
 
-    def __post_init__(self):
-        # Compute absolute learning rate while args are mutable
-        super().__post_init__()
-        if self.base_learning_rate is not None:
-            total_train_batch_size = self.train_batch_size * self.gradient_accumulation_steps * self.world_size
-            delattr(self, "_frozen")
-            self.learning_rate = self.base_learning_rate * total_train_batch_size / 256
-            setattr(self, "_frozen", True)
-
 
 def collate_fn(examples):
     pixel_values = torch.stack([example["pixel_values"] for example in examples])

@@ -362,6 +353,13 @@ def main():
         # Set the validation transforms
         ds["validation"].set_transform(preprocess_images)
 
+    # Compute absolute learning rate
+    total_train_batch_size = (
+        training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+    )
+    if training_args.base_learning_rate is not None:
+        training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256
+
     # Initialize our trainer
     trainer = Trainer(
         model=model,

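The scaling these two hunks restore is the linear rule from the field's help text, absolute_lr = base_lr * total_batch_size / 256, now computed in main() on a mutable training_args instead of inside __post_init__. A minimal, self-contained sketch with made-up batch-size numbers:

def absolute_lr(base_lr, train_batch_size, grad_accum_steps, world_size):
    # mirrors train_batch_size * gradient_accumulation_steps * world_size from the hunk above
    total_train_batch_size = train_batch_size * grad_accum_steps * world_size
    return base_lr * total_train_batch_size / 256


# e.g. base_lr=1e-3, 8 samples per device, 4 accumulation steps, 8 processes -> total batch 256
print(absolute_lr(1e-3, 8, 4, 8))  # 0.001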

@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence.
 """
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
-import dataclasses
 import logging
 import os
 import sys

@@ -675,10 +674,14 @@ def main():
         return result
 
     # Override the decoding parameters of Seq2SeqTrainer
-    if training_args.generation_max_length is None:
-        training_args = dataclasses.replace(training_args, generation_max_length=data_args.val_max_target_length)
-    if training_args.generation_num_beams is None:
-        training_args = dataclasses.replace(training_args, generation_num_beams=data_args.num_beams)
+    training_args.generation_max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    training_args.generation_num_beams = (
+        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
+    )
 
     # Initialize our Trainer
     trainer = Seq2SeqTrainer(

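The seq2seq script drops `import dataclasses` because it no longer needs `dataclasses.replace`, which returns a brand-new arguments object; with the frozen behaviour reverted it can assign to the existing training_args again. A small sketch of the difference on a plain dataclass (an illustrative stand-in, not the real Seq2SeqTrainingArguments):

from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class Args:
    generation_max_length: Optional[int] = None
    generation_num_beams: int = 1


args = Args()

# replace() builds a new instance; code still holding the old `args` keeps seeing stale values.
new_args = replace(args, generation_max_length=128)
assert args.generation_max_length is None and new_args.generation_max_length == 128

# In-place assignment (the reverted style) updates the one shared object.
args.generation_max_length = 128
assert args.generation_max_length == 128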

@@ -21,7 +21,6 @@ https://huggingface.co/models?filter=fill-mask
 """
 # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
 
-import dataclasses
 import json
 import logging
 import math

@@ -367,7 +366,7 @@ def main():
     # If we have ref files, need to avoid it removed by trainer
     has_ref = data_args.train_ref_file or data_args.validation_ref_file
     if has_ref:
-        training_args = dataclasses.replace(training_args, remove_unused_columns=False)
+        training_args.remove_unused_columns = False
 
     # Data collator
     # This one will take care of randomly masking the tokens.


@@ -259,6 +259,7 @@ def main():
         assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
 
     if training_args.output_dir is not None:
+        training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)
 
     # endregion

@@ -266,8 +267,8 @@ def main():
     # Detecting last checkpoint.
     checkpoint = None
     if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
-        config_path = Path(training_args.output_dir) / CONFIG_NAME
-        weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
+        config_path = training_args.output_dir / CONFIG_NAME
+        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
         if config_path.is_file() and weights_path.is_file():
             checkpoint = training_args.output_dir
             logger.info(

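This TensorFlow example (and the near-identical one below) goes back to converting output_dir to a pathlib.Path once, right after argument parsing, so the later checkpoint check can join paths with `/`. A rough standalone sketch of that flow; the literal file names stand in for the library's CONFIG_NAME and TF2_WEIGHTS_NAME constants:

import os
from pathlib import Path


def find_checkpoint(output_dir, overwrite_output_dir=False):
    # Convert once and reuse the Path object (the added line in the hunks above)
    output_dir = Path(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    if len(os.listdir(output_dir)) > 0 and not overwrite_output_dir:
        config_path = output_dir / "config.json"   # CONFIG_NAME in the real scripts
        weights_path = output_dir / "tf_model.h5"  # TF2_WEIGHTS_NAME in the real scripts
        if config_path.is_file() and weights_path.is_file():
            return output_dir  # resume training from this directory
    return None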

@@ -265,6 +265,7 @@ def main():
         assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
 
     if training_args.output_dir is not None:
+        training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)
 
     if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:

@@ -276,8 +277,8 @@ def main():
     # Detecting last checkpoint.
     checkpoint = None
     if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
-        config_path = Path(training_args.output_dir) / CONFIG_NAME
-        weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
+        config_path = training_args.output_dir / CONFIG_NAME
+        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
         if config_path.is_file() and weights_path.is_file():
             checkpoint = training_args.output_dir
             logger.warning(


@@ -1172,8 +1172,6 @@ class Trainer:
         elif self.hp_search_backend == HPSearchBackend.WANDB:
             params = trial
 
-        # Unfreeze args for hyperparameter search
-        delattr(self.args, "_frozen")
         for key, value in params.items():
             if not hasattr(self.args, key):
                 logger.warning(

@@ -1205,8 +1203,6 @@
             self.args.hf_deepspeed_config.trainer_config_process(self.args)
             self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
 
-        # Re-freeze them
-        setattr(self.args, "_frozen", True)
         self.create_accelerator_and_postprocess()
 
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):

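With TrainingArguments mutable again, the hyperparameter-search setup can write trial values straight onto self.args; the delattr/setattr dance around _frozen disappears and only the plain loop remains. A rough sketch of that loop in isolation, using an ordinary object in place of the real args (names are illustrative):

import logging

logger = logging.getLogger(__name__)


def apply_trial(args, params):
    """Assign hyperparameter-search trial values onto an existing args object."""
    for key, value in params.items():
        if not hasattr(args, key):
            # same spirit as the warning kept in the hunk above
            logger.warning("Trying to set %s in the hyperparameter search but there is no corresponding field.", key)
            continue
        setattr(args, key, value)  # plain assignment works now; no unfreezing required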

@@ -18,7 +18,7 @@ import json
 import math
 import os
 import warnings
-from dataclasses import FrozenInstanceError, asdict, dataclass, field, fields
+from dataclasses import asdict, dataclass, field, fields
 from datetime import timedelta
 from enum import Enum
 from pathlib import Path

@@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
     PAGED_LION_8BIT = "paged_lion_8bit"
 
 
+# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
 @dataclass
 class TrainingArguments:
     """

@@ -1707,16 +1708,6 @@
                 FutureWarning,
             )
 
-        # Finally set the `TrainingArguments` to be immutable
-        self._frozen = True
-
-    def __setattr__(self, name, value):
-        # Once fully through the `__post_init__`, `TrainingArguments` are immutable
-        if not name.startswith("_") and getattr(self, "_frozen", False):
-            raise FrozenInstanceError(f"cannot assign to field {name}")
-        else:
-            super().__setattr__(name, value)
-
     def __str__(self):
         self_as_dict = asdict(self)

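For context, the mechanism deleted from training_args.py worked roughly like the standalone sketch below: __post_init__ set a _frozen flag and __setattr__ rejected later assignments to public fields, which is exactly what the example scripts, tests, and Trainer code above had to work around. This is an illustration of the removed pattern, not the full TrainingArguments class:

from dataclasses import FrozenInstanceError, dataclass


@dataclass
class FrozenishArgs:
    learning_rate: float = 5e-5

    def __post_init__(self):
        # finally mark the instance as immutable
        self._frozen = True

    def __setattr__(self, name, value):
        # once __post_init__ has run, public fields can no longer be reassigned
        if not name.startswith("_") and getattr(self, "_frozen", False):
            raise FrozenInstanceError(f"cannot assign to field {name}")
        super().__setattr__(name, value)


args = FrozenishArgs()
try:
    args.learning_rate = 1e-3  # raises while frozen
except FrozenInstanceError:
    delattr(args, "_frozen")   # the unfreeze workaround the revert removes
    args.learning_rate = 1e-3
    args._frozen = True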

@@ -139,9 +139,9 @@ class RegressionTrainingArguments(TrainingArguments):
     b: float = 0.0
 
     def __post_init__(self):
+        super().__post_init__()
         # save resources not dealing with reporting (also avoids the warning when it's not set)
         self.report_to = []
-        super().__post_init__()
 
 
 class RepeatDataset:

@@ -529,8 +529,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.check_trained_model(trainer.model)
 
         # Re-training should restart from scratch, thus lead the same results and new seed should be used.
-        args = TrainingArguments("./regression", learning_rate=0.1, seed=314)
-        trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
+        trainer.args.seed = 314
         trainer.train()
         self.check_trained_model(trainer.model, alternate_seed=True)


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import dataclasses
 from typing import Dict
 
 import numpy as np

@@ -206,14 +205,7 @@
             logger.error(p.metrics)
             exit(1)
 
-        training_args = dataclasses.replace(training_args, eval_accumulation_steps=2)
-        trainer = Trainer(
-            model=DummyModel(),
-            args=training_args,
-            data_collator=DummyDataCollator(),
-            eval_dataset=dataset,
-            compute_metrics=compute_metrics,
-        )
+        trainer.args.eval_accumulation_steps = 2
 
         metrics = trainer.evaluate()
         logger.info(metrics)

@@ -227,22 +219,15 @@
             logger.error(p.metrics)
             exit(1)
 
-        training_args = dataclasses.replace(training_args, eval_accumulation_steps=None)
-        trainer = Trainer(
-            model=DummyModel(),
-            args=training_args,
-            data_collator=DummyDataCollator(),
-            eval_dataset=dataset,
-            compute_metrics=compute_metrics,
-        )
+        trainer.args.eval_accumulation_steps = None
 
     # Check that `dispatch_batches=False` will work on a finite iterable dataset
     train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
     model = RegressionModel()
-    training_args = dataclasses.replace(
-        training_args, per_device_train_batch_size=1, max_steps=1, dispatch_batches=False
-    )
+    training_args.per_device_train_batch_size = 1
+    training_args.max_steps = 1
+    training_args.dispatch_batches = False
     trainer = Trainer(model, training_args, train_dataset=train_dataset)
     trainer.train()