Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Commit 2b550c47b2 (parent 44715225e3)

Remove deprecated training arguments (#36946)

* Remove deprecated training arguments
* More fixes
* More fixes
* More fixes
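In short: the long-deprecated aliases `evaluation_strategy`, `dispatch_batches`, and `split_batches` are deleted from `TrainingArguments`, and the examples, tests, and type hints that still used them are updated to match. A minimal before/after sketch of the main rename (the `output_dir` value is a placeholder, not from this commit):

    from transformers import TrainingArguments

    # Before this commit (deprecated, emitted a FutureWarning):
    #   args = TrainingArguments(output_dir="out", evaluation_strategy="epoch")

    # After: the option is simply spelled `eval_strategy`.
    args = TrainingArguments(output_dir="out", eval_strategy="epoch")

The hunks below walk through the removal, file by file.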
@@ -66,7 +66,7 @@ python run_instance_segmentation.py \
     --dataloader_persistent_workers \
     --dataloader_prefetch_factor 4 \
     --do_eval \
-    --evaluation_strategy epoch \
+    --eval_strategy epoch \
     --logging_strategy epoch \
     --save_strategy epoch \
     --save_total_limit 2 \
@@ -56,7 +56,7 @@ python run_object_detection.py \
     --greater_is_better true \
     --load_best_model_at_end true \
     --logging_strategy epoch \
-    --evaluation_strategy epoch \
+    --eval_strategy epoch \
    --save_strategy epoch \
     --save_total_limit 2 \
     --push_to_hub true \
@@ -667,7 +667,7 @@ class ExamplesTests(TestCasePlus):
             --per_device_train_batch_size 2
             --per_device_eval_batch_size 1
             --do_eval
-            --evaluation_strategy epoch
+            --eval_strategy epoch
             --seed 32
         """.split()
 
@@ -1263,7 +1263,7 @@ class AcceleratorConfig:
             " in your script multiplied by the number of processes."
         },
     )
-    dispatch_batches: bool = field(
+    dispatch_batches: Optional[bool] = field(
        default=None,
        metadata={
            "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
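Widening the annotation to `Optional[bool]` matches the existing `default=None` and makes `None` a meaningful third state: Accelerate then decides per dataloader type. A small sketch of the tri-state logic this enables (function name and signature are illustrative, not code from the repository):

    from typing import Optional

    def resolve_dispatch_batches(value: Optional[bool], is_iterable_dataset: bool) -> bool:
        # None means "let Accelerate decide": per the help text above, batches
        # are prepared on the main process and dispatched by default only when
        # the underlying dataset is an IterableDataset.
        if value is None:
            return is_iterable_dataset
        return value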
@@ -768,14 +768,6 @@ class TrainingArguments:
             Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
 
             This flag is experimental and subject to change in future releases.
-        split_batches (`bool`, *optional*):
-            Whether or not the accelerator should split the batches yielded by the dataloaders across the devices
-            during distributed training. If
-
-            set to `True`, the actual batch size used will be the same on any kind of distributed processes, but it
-            must be a
-
-            round multiple of the number of processes you are using (such as GPUs).
         include_tokens_per_second (`bool`, *optional*):
             Whether or not to compute the number of tokens per second per device for training speed metrics.
 
@@ -1426,10 +1418,6 @@ class TrainingArguments:
             "choices": ["auto", "apex", "cpu_amp"],
         },
     )
-    evaluation_strategy: Union[IntervalStrategy, str] = field(
-        default=None,
-        metadata={"help": "Deprecated. Use `eval_strategy` instead"},
-    )
     push_to_hub_model_id: Optional[str] = field(
         default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
     )
@@ -1504,16 +1492,6 @@ class TrainingArguments:
         },
     )
 
-    dispatch_batches: Optional[bool] = field(
-        default=None,
-        metadata={"help": "Deprecated. Pass {'dispatch_batches':VALUE} to `accelerator_config`."},
-    )
-
-    split_batches: Optional[bool] = field(
-        default=None,
-        metadata={"help": "Deprecated. Pass {'split_batches':True} to `accelerator_config`."},
-    )
-
     include_tokens_per_second: Optional[bool] = field(
         default=False,
         metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
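With the standalone fields gone, both knobs are passed through `accelerator_config`, exactly as the removed help strings instructed. A minimal sketch of the replacement call (`output_dir` is a placeholder):

    from transformers import TrainingArguments

    # Both settings now live inside `accelerator_config`:
    args = TrainingArguments(
        output_dir="out",
        accelerator_config={
            "dispatch_batches": False,
            "split_batches": True,
        },
    )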
@@ -1606,13 +1584,6 @@ class TrainingArguments:
         if self.disable_tqdm is None:
             self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
 
-        if self.evaluation_strategy is not None:
-            warnings.warn(
-                "`evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead",
-                FutureWarning,
-            )
-            self.eval_strategy = self.evaluation_strategy
-
         if isinstance(self.eval_strategy, EvaluationStrategy):
             warnings.warn(
                 "using `EvaluationStrategy` for `eval_strategy` is deprecated and will be removed in version 5"
@@ -1771,7 +1742,7 @@ class TrainingArguments:
 
         # We need to setup the accelerator config here *before* the first call to `self.device`
         if is_accelerate_available():
-            if not isinstance(self.accelerator_config, (AcceleratorConfig)):
+            if not isinstance(self.accelerator_config, AcceleratorConfig):
                 if self.accelerator_config is None:
                     self.accelerator_config = AcceleratorConfig()
                 elif isinstance(self.accelerator_config, dict):
@@ -1786,22 +1757,6 @@ class TrainingArguments:
             else:
                 self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
 
-            if self.dispatch_batches is not None:
-                warnings.warn(
-                    "Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
-                    " `--accelerator_config {'dispatch_batches':VALUE} instead",
-                    FutureWarning,
-                )
-                self.accelerator_config.dispatch_batches = self.dispatch_batches
-
-            if self.split_batches is not None:
-                warnings.warn(
-                    "Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
-                    " `--accelerator_config {'split_batches':VALUE} instead",
-                    FutureWarning,
-                )
-                self.accelerator_config.split_batches = self.split_batches
-
         # Initialize device before we proceed
         if self.framework == "pt" and is_torch_available():
             self.device
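Since the forwarding shim above is deleted, the old keyword arguments no longer degrade to a FutureWarning: as unknown dataclass kwargs they now fail immediately. A sketch of the expected post-removal behavior (an assumption based on ordinary dataclass semantics, not a test from this commit):

    from transformers import TrainingArguments

    try:
        TrainingArguments(output_dir="out", dispatch_batches=False)  # removed kwarg
    except TypeError as err:
        # e.g. "__init__() got an unexpected keyword argument 'dispatch_batches'"
        print(err)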
@@ -646,7 +646,7 @@ class GPTQConfig(QuantizationConfigMixin):
         sym: bool = True,
         true_sequential: bool = True,
         checkpoint_format: str = "gptq",
-        meta: Optional[Dict[str, any]] = None,
+        meta: Optional[Dict[str, Any]] = None,
         backend: Optional[str] = None,
         use_cuda_fp16: bool = False,
         model_seqlen: Optional[int] = None,
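The `GPTQConfig` change is a typo fix: lowercase `any` is the builtin function, not a type. The old annotation slipped through at runtime because `typing` accepts arbitrary callables as type parameters, but it is meaningless to static checkers. A tiny illustration (variable names are not from the repository):

    from typing import Any, Dict, Optional

    meta_bad: Optional[Dict[str, any]] = None   # builtin function `any`: tolerated at runtime, rejected by type checkers
    meta_good: Optional[Dict[str, Any]] = None  # `typing.Any`: what this commit switches to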
@@ -28,7 +28,7 @@ import unittest
 from functools import partial
 from itertools import product
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any
 from unittest.mock import Mock, patch
 
 import numpy as np
@@ -2982,7 +2982,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
                 self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
                 self.tokenizer.add_tokens(["<NEW_TOKEN1>", "<NEW_TOKEN2>"])
 
-            def __call__(self, features: List[Any], return_tensors="pt") -> Dict[str, Any]:
+            def __call__(self, features: list[Any], return_tensors="pt") -> dict[str, Any]:
                 return default_data_collator(features, return_tensors)
 
         data_collator = FakeCollator()
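The typing changes through this test file follow PEP 585: since Python 3.9 the builtin `list` and `dict` are subscriptable directly, which is why the `typing.List`/`typing.Dict` imports above shrink to just `Any`. A self-contained sketch of the new style (the function body is illustrative):

    from typing import Any

    # Builtin generics (PEP 585) replace typing.List / typing.Dict on Python 3.9+.
    def collate(features: list[Any]) -> dict[str, Any]:
        return {"features": features}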
@@ -2999,7 +2999,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             trainer = get_regression_trainer(
                 output_dir=tmp_dir,
                 save_steps=5,
-                evaluation_strategy="steps",
+                eval_strategy="steps",
                 eval_steps=5,
                 max_steps=9,
             )
@@ -3020,7 +3020,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             trainer = get_regression_trainer(
                 output_dir=tmp_dir,
                 save_steps=5,
-                evaluation_strategy="steps",
+                eval_strategy="steps",
                 eval_steps=5,
                 load_best_model_at_end=True,
                 save_total_limit=2,
@@ -4260,7 +4260,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             model = RegressionPreTrainedModel(config)
             eval_dataset = SampleIterableDataset()
 
-            accelerator_config = {
+            accelerator_config: dict[str, Any] = {
                 "split_batches": True,
                 "dispatch_batches": True,
                 "even_batches": False,
@@ -4370,56 +4370,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             self.assertEqual(trainer.accelerator.even_batches, True)
             self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
 
-    def test_accelerator_config_from_dict_with_deprecated_args(self):
-        # Checks that accelerator kwargs can be passed through
-        # and the accelerator is initialized respectively
-        # and maintains the deprecated args if passed in
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            config = RegressionModelConfig(a=1.5, b=2.5)
-            model = RegressionPreTrainedModel(config)
-            eval_dataset = SampleIterableDataset()
-
-            # Leaves all options as something *not* basic
-            with self.assertWarns(FutureWarning) as cm:
-                args = RegressionTrainingArguments(
-                    output_dir=tmp_dir,
-                    accelerator_config={
-                        "split_batches": True,
-                    },
-                    dispatch_batches=False,
-                )
-            self.assertIn("dispatch_batches", str(cm.warnings[0].message))
-            trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
-            self.assertEqual(trainer.accelerator.dispatch_batches, False)
-            self.assertEqual(trainer.accelerator.split_batches, True)
-            with self.assertWarns(FutureWarning) as cm:
-                args = RegressionTrainingArguments(
-                    output_dir=tmp_dir,
-                    accelerator_config={
-                        "even_batches": False,
-                    },
-                    split_batches=True,
-                )
-            self.assertIn("split_batches", str(cm.warnings[0].message))
-            trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
-            self.assertEqual(trainer.accelerator.split_batches, True)
-            self.assertEqual(trainer.accelerator.even_batches, False)
-            self.assertEqual(trainer.accelerator.dispatch_batches, None)
-
-    def test_accelerator_config_only_deprecated_args(self):
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            with self.assertWarns(FutureWarning) as cm:
-                args = RegressionTrainingArguments(
-                    output_dir=tmp_dir,
-                    split_batches=True,
-                )
-            self.assertIn("split_batches", str(cm.warnings[0].message))
-            config = RegressionModelConfig(a=1.5, b=2.5)
-            model = RegressionPreTrainedModel(config)
-            eval_dataset = SampleIterableDataset()
-            trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
-            self.assertEqual(trainer.accelerator.split_batches, True)
-
     def test_accelerator_custom_state(self):
         AcceleratorState._reset_state(reset_partial_state=True)
         with tempfile.TemporaryDirectory() as tmp_dir:
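The two deleted tests existed only to exercise the compatibility shim removed from `__post_init__` above. The surviving pattern configures everything through `accelerator_config` and asserts on the prepared `Accelerator`; a sketch of that shape, reusing the same helpers the hunk shows (`RegressionTrainingArguments`, `RegressionModelConfig`, `RegressionPreTrainedModel`, `SampleIterableDataset` come from this test module):

    # Configure via `accelerator_config` only; no deprecated kwargs, no warning.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config = RegressionModelConfig(a=1.5, b=2.5)
        model = RegressionPreTrainedModel(config)
        eval_dataset = SampleIterableDataset()
        args = RegressionTrainingArguments(
            output_dir=tmp_dir,
            accelerator_config={"split_batches": True, "dispatch_batches": False},
        )
        trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
        assert trainer.accelerator.split_batches is True
        assert trainer.accelerator.dispatch_batches is False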
@@ -5191,7 +5141,7 @@ class TrainerHyperParameterMultiObjectOptunaIntegrationTest(unittest.TestCase):
         def hp_name(trial):
             return MyTrialShortNamer.shortname(trial.params)
 
-        def compute_objective(metrics: Dict[str, float]) -> List[float]:
+        def compute_objective(metrics: dict[str, float]) -> list[float]:
             return metrics["eval_loss"], metrics["eval_accuracy"]
 
         with tempfile.TemporaryDirectory() as tmp_dir:
|
@ -200,6 +200,8 @@ if __name__ == "__main__":
|
||||
model = RegressionModel()
|
||||
training_args.per_device_train_batch_size = 1
|
||||
training_args.max_steps = 1
|
||||
training_args.dispatch_batches = False
|
||||
training_args.accelerator_config = {
|
||||
"dispatch_batches": False,
|
||||
}
|
||||
trainer = Trainer(model, training_args, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
|
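The same migration has a file-based form. Note from the `__post_init__` hunk earlier that a string value for `accelerator_config` is read as a path to a JSON file (the `AcceleratorConfig.from_json_file` branch), so an equivalent of the change above could look like this (file name is a placeholder):

    import json
    from transformers import TrainingArguments

    # Write the Accelerate settings to a JSON file and point the argument at it.
    with open("accel_config.json", "w") as f:
        json.dump({"dispatch_batches": False}, f)

    args = TrainingArguments(output_dir="out", accelerator_config="accel_config.json")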