Revert frozen training arguments (#25903)

* Revert frozen training arguments

* TODO
Zach Mueller 2023-09-01 11:24:12 -04:00 committed by GitHub
parent 69c5b8f186
commit be0e189bd3
9 changed files with 31 additions and 58 deletions

View File

@@ -163,15 +163,6 @@ class CustomTrainingArguments(TrainingArguments):
         default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."}
     )
 
-    def __post_init__(self):
-        # Compute absolute learning rate while args are mutable
-        super().__post_init__()
-        if self.base_learning_rate is not None:
-            total_train_batch_size = self.train_batch_size * self.gradient_accumulation_steps * self.world_size
-            delattr(self, "_frozen")
-            self.learning_rate = self.base_learning_rate * total_train_batch_size / 256
-            setattr(self, "_frozen", True)
-
 
 def collate_fn(examples):
     pixel_values = torch.stack([example["pixel_values"] for example in examples])
@@ -362,6 +353,13 @@ def main():
         # Set the validation transforms
         ds["validation"].set_transform(preprocess_images)
 
+    # Compute absolute learning rate
+    total_train_batch_size = (
+        training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+    )
+    if training_args.base_learning_rate is not None:
+        training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256
+
     # Initialize our trainer
     trainer = Trainer(
         model=model,

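The hunk above moves the base-learning-rate scaling out of `__post_init__` (where the frozen design forced it, complete with the `delattr`/`setattr` dance) and back into `main()`, where the parsed arguments are plainly mutable. A minimal sketch of the same linear-scaling rule, with made-up batch-size numbers standing in for the parsed `TrainingArguments`:

# Sketch of the scaling rule used above: absolute_lr = base_lr * total_batch_size / 256.
# The concrete numbers below are illustrative, not the script's defaults.
base_learning_rate = 1e-3
per_device_train_batch_size = 32
gradient_accumulation_steps = 4
world_size = 2  # number of training processes

# TrainingArguments.train_batch_size is the per-device batch size times the local device count;
# everything is folded into one product here for clarity.
total_train_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size

learning_rate = base_learning_rate * total_train_batch_size / 256
print(learning_rate)  # 1e-3 * 256 / 256 == 0.001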
View File

@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence.
 """
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
-import dataclasses
 import logging
 import os
 import sys
@@ -675,10 +674,14 @@ def main():
         return result
 
     # Override the decoding parameters of Seq2SeqTrainer
-    if training_args.generation_max_length is None:
-        training_args = dataclasses.replace(training_args, generation_max_length=data_args.val_max_target_length)
-    if training_args.generation_num_beams is None:
-        training_args = dataclasses.replace(training_args, generation_num_beams=data_args.num_beams)
+    training_args.generation_max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    training_args.generation_num_beams = (
+        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
+    )
 
     # Initialize our Trainer
     trainer = Seq2SeqTrainer(

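The removed lines used `dataclasses.replace`, which builds a brand-new `TrainingArguments` object instead of changing the one the rest of the script already references; with the arguments mutable again, the generation defaults are simply assigned in place. A small self-contained sketch of the difference, using a toy dataclass rather than the real `TrainingArguments`:

from dataclasses import dataclass, replace
from typing import Optional

@dataclass
class ToyArgs:
    generation_max_length: Optional[int] = None
    generation_num_beams: int = 1

args = ToyArgs()

# replace() returns a *new* object; anything still holding the old one sees no change.
new_args = replace(args, generation_max_length=128)
print(args.generation_max_length, new_args.generation_max_length)  # None 128

# In-place assignment mutates the object that every other component already holds.
args.generation_max_length = 128
print(args.generation_max_length)  # 128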
View File

@@ -21,7 +21,6 @@ https://huggingface.co/models?filter=fill-mask
 """
 # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
 
-import dataclasses
 import json
 import logging
 import math
@@ -367,7 +366,7 @@ def main():
     # If we have ref files, need to avoid it removed by trainer
     has_ref = data_args.train_ref_file or data_args.validation_ref_file
     if has_ref:
-        training_args = dataclasses.replace(training_args, remove_unused_columns=False)
+        training_args.remove_unused_columns = False
 
     # Data collator
     # This one will take care of randomly masking the tokens.

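The same substitution shows up here: flipping `remove_unused_columns` is now a single attribute write, whereas `dataclasses.replace` re-runs the whole constructor, including `__post_init__`. A toy sketch of that side effect (not the real `TrainingArguments`):

from dataclasses import dataclass, replace

CALLS = []

@dataclass
class ToyArgs:
    remove_unused_columns: bool = True

    def __post_init__(self):
        # Stands in for the heavy validation/derivation work the real __post_init__ does.
        CALLS.append("post_init")

args = ToyArgs()                                       # CALLS == ["post_init"]
new_args = replace(args, remove_unused_columns=False)  # replace() re-runs __post_init__
print(len(CALLS))                                      # 2
args.remove_unused_columns = False                     # plain mutation, nothing re-runs
print(len(CALLS))                                      # still 2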
View File

@@ -259,6 +259,7 @@ def main():
         assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
 
     if training_args.output_dir is not None:
+        training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)
     # endregion
 
@@ -266,8 +267,8 @@ def main():
     # Detecting last checkpoint.
     checkpoint = None
     if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
-        config_path = Path(training_args.output_dir) / CONFIG_NAME
-        weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
+        config_path = training_args.output_dir / CONFIG_NAME
+        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
         if config_path.is_file() and weights_path.is_file():
             checkpoint = training_args.output_dir
             logger.info(

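With `output_dir` writable again, the TF script converts it to a `pathlib.Path` once and reuses it for the checkpoint probe below, instead of wrapping it in `Path(...)` at every use. A sketch of that probe in isolation; the two filename constants are stand-ins for the `CONFIG_NAME` / `TF2_WEIGHTS_NAME` constants the script imports, and their values here are assumed:

import os
from pathlib import Path

CONFIG_NAME = "config.json"        # assumed value, stand-in for the imported constant
TF2_WEIGHTS_NAME = "tf_model.h5"   # assumed value, stand-in for the imported constant

output_dir = Path("./out")         # imagine this came from training_args.output_dir
os.makedirs(output_dir, exist_ok=True)

checkpoint = None
if any(output_dir.iterdir()):      # directory is not empty
    config_path = output_dir / CONFIG_NAME
    weights_path = output_dir / TF2_WEIGHTS_NAME
    if config_path.is_file() and weights_path.is_file():
        checkpoint = output_dir
        print(f"Checkpoint detected at {checkpoint}")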
View File

@@ -265,6 +265,7 @@ def main():
         assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
 
     if training_args.output_dir is not None:
+        training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)
 
     if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:
@@ -276,8 +277,8 @@ def main():
     # Detecting last checkpoint.
     checkpoint = None
     if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
-        config_path = Path(training_args.output_dir) / CONFIG_NAME
-        weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
+        config_path = training_args.output_dir / CONFIG_NAME
+        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
         if config_path.is_file() and weights_path.is_file():
             checkpoint = training_args.output_dir
             logger.warning(

View File

@@ -1172,8 +1172,6 @@ class Trainer:
         elif self.hp_search_backend == HPSearchBackend.WANDB:
             params = trial
 
-        # Unfreeze args for hyperparameter search
-        delattr(self.args, "_frozen")
         for key, value in params.items():
             if not hasattr(self.args, key):
                 logger.warning(
@@ -1205,8 +1203,6 @@
             self.args.hf_deepspeed_config.trainer_config_process(self.args)
             self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
 
-        # Re-freeze them
-        setattr(self.args, "_frozen", True)
         self.create_accelerator_and_postprocess()
 
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):

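The deleted `delattr`/`setattr` pair was the escape hatch the frozen design required: `_hp_search_setup` has to write trial values straight onto `self.args`, so it temporarily lifted the guard and restored it afterwards. With the guard gone, only the plain assignment loop remains. A toy sketch of applying a trial dictionary to a writable args object; the `Args` stand-in and the trial values are hypothetical, not the Trainer's code:

import logging

logger = logging.getLogger(__name__)

class Args:                       # toy stand-in for self.args
    learning_rate = 5e-5
    num_train_epochs = 3

# A trial as an HP-search backend might suggest it (one key is deliberately bogus).
params = {"learning_rate": 3e-4, "num_train_epochs": 5, "not_a_real_field": 1}

args = Args()
for key, value in params.items():
    if not hasattr(args, key):
        logger.warning(f"Trying to set {key} in the hyperparameter search but there is no corresponding field in Args.")
        continue
    setattr(args, key, value)     # a direct write is enough now that args are not frozen

print(args.learning_rate, args.num_train_epochs)  # 0.0003 5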
View File

@@ -18,7 +18,7 @@ import json
 import math
 import os
 import warnings
-from dataclasses import FrozenInstanceError, asdict, dataclass, field, fields
+from dataclasses import asdict, dataclass, field, fields
 from datetime import timedelta
 from enum import Enum
 from pathlib import Path
@@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
     PAGED_LION_8BIT = "paged_lion_8bit"
 
 
+# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
 @dataclass
 class TrainingArguments:
     """
@@ -1707,16 +1708,6 @@ class TrainingArguments:
                     FutureWarning,
                 )
 
-        # Finally set the `TrainingArguments` to be immutable
-        self._frozen = True
-
-    def __setattr__(self, name, value):
-        # Once fully through the `__post_init__`, `TrainingArguments` are immutable
-        if not name.startswith("_") and getattr(self, "_frozen", False):
-            raise FrozenInstanceError(f"cannot assign to field {name}")
-        else:
-            super().__setattr__(name, value)
-
     def __str__(self):
         self_as_dict = asdict(self)

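For reference, the deleted lines implemented a freeze-after-init guard: `__post_init__` ended by setting `self._frozen = True`, and `__setattr__` then rejected writes to any public field. A minimal standalone reproduction of that pattern, kept only to show what this commit removes:

from dataclasses import FrozenInstanceError, dataclass

@dataclass
class FrozenAfterInit:
    learning_rate: float = 5e-5

    def __post_init__(self):
        # Finally set the instance to be immutable, as the removed code did.
        self._frozen = True

    def __setattr__(self, name, value):
        # Once fully through __post_init__, public fields can no longer be assigned.
        if not name.startswith("_") and getattr(self, "_frozen", False):
            raise FrozenInstanceError(f"cannot assign to field {name}")
        super().__setattr__(name, value)

args = FrozenAfterInit()
try:
    args.learning_rate = 3e-4
except FrozenInstanceError as err:
    print(err)  # cannot assign to field learning_rate

The TODO added above the class keeps the longer-term idea alive: rather than freezing everything, a future change might narrow mutability to a handful of keys.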
View File

@@ -139,9 +139,9 @@ class RegressionTrainingArguments(TrainingArguments):
     b: float = 0.0
 
     def __post_init__(self):
+        super().__post_init__()
         # save resources not dealing with reporting (also avoids the warning when it's not set)
         self.report_to = []
-        super().__post_init__()
 
 
 class RepeatDataset:
@@ -529,8 +529,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.check_trained_model(trainer.model)
 
         # Re-training should restart from scratch, thus lead the same results and new seed should be used.
-        args = TrainingArguments("./regression", learning_rate=0.1, seed=314)
-        trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
+        trainer.args.seed = 314
         trainer.train()
         self.check_trained_model(trainer.model, alternate_seed=True)

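Both test tweaks lean on mutability: `RegressionTrainingArguments` can run `super().__post_init__()` first and then blank out `report_to`, and the re-training check can simply set `trainer.args.seed` on the live object instead of rebuilding the arguments and the Trainer. A toy sketch of the ordering point, assuming a parent `__post_init__` that resolves a default (the real `TrainingArguments` expands its reporting default in a similar spirit):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ParentArgs:
    report_to: Optional[List[str]] = None

    def __post_init__(self):
        # Stands in for the parent resolving report_to into a concrete list of integrations.
        if self.report_to is None:
            self.report_to = ["tensorboard", "wandb"]

@dataclass
class RegressionArgs(ParentArgs):
    def __post_init__(self):
        super().__post_init__()
        # With mutable args the override can come *after* the parent logic has run.
        self.report_to = []

print(RegressionArgs().report_to)  # []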
View File

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import dataclasses
 from typing import Dict
 
 import numpy as np
@@ -206,14 +205,7 @@ if __name__ == "__main__":
         logger.error(p.metrics)
         exit(1)
 
-    training_args = dataclasses.replace(training_args, eval_accumulation_steps=2)
-    trainer = Trainer(
-        model=DummyModel(),
-        args=training_args,
-        data_collator=DummyDataCollator(),
-        eval_dataset=dataset,
-        compute_metrics=compute_metrics,
-    )
+    trainer.args.eval_accumulation_steps = 2
 
     metrics = trainer.evaluate()
     logger.info(metrics)
@@ -227,22 +219,15 @@ if __name__ == "__main__":
         logger.error(p.metrics)
         exit(1)
 
-    training_args = dataclasses.replace(training_args, eval_accumulation_steps=None)
-    trainer = Trainer(
-        model=DummyModel(),
-        args=training_args,
-        data_collator=DummyDataCollator(),
-        eval_dataset=dataset,
-        compute_metrics=compute_metrics,
-    )
+    trainer.args.eval_accumulation_steps = None
 
     # Check that `dispatch_batches=False` will work on a finite iterable dataset
     train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
 
     model = RegressionModel()
-    training_args = dataclasses.replace(
-        training_args, per_device_train_batch_size=1, max_steps=1, dispatch_batches=False
-    )
+    training_args.per_device_train_batch_size = 1
+    training_args.max_steps = 1
+    training_args.dispatch_batches = False
     trainer = Trainer(model, training_args, train_dataset=train_dataset)
     trainer.train()
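The test script follows the same recipe end to end: keep one long-lived `Trainer` and adjust `trainer.args` between runs rather than rebuilding everything through `dataclasses.replace`. A toy sketch of why that works, with stand-in classes instead of the real Trainer; the only point is that `trainer.args` is held by reference, so later mutation is visible to the trainer:

class ToyArgs:
    def __init__(self):
        self.eval_accumulation_steps = None
        self.per_device_train_batch_size = 8
        self.max_steps = 100
        self.dispatch_batches = True

class ToyTrainer:
    def __init__(self, args):
        self.args = args          # same object the caller holds, not a copy

    def evaluate(self):
        return {"eval_accumulation_steps": self.args.eval_accumulation_steps}

trainer = ToyTrainer(ToyArgs())

trainer.args.eval_accumulation_steps = 2
print(trainer.evaluate())         # {'eval_accumulation_steps': 2}

trainer.args.eval_accumulation_steps = None
trainer.args.per_device_train_batch_size = 1
trainer.args.max_steps = 1
trainer.args.dispatch_batches = False
print(trainer.evaluate())         # {'eval_accumulation_steps': None}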