Fix amp deprecation issue (#38100)

apex.amp is deprecated and has been removed from recent apex releases (NVIDIA/apex#1896), so import it lazily only where the apex backend is actually used, and validate its availability up front in TrainingArguments. A sketch of the native torch.amp path that half_precision_backend="auto" falls back to follows the diffstat below.
Marc Sun, 2025-06-02 16:15:41 +02:00 (committed by GitHub)
commit 1a25fd2f6d, parent 05ad826002
2 changed files with 20 additions and 9 deletions
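For background: apex.amp was NVIDIA's pre-PyTorch-1.6 mixed-precision API, and recent apex releases dropped it (NVIDIA/apex#1896), which is why the old module-level `from apex import amp` in trainer.py broke `import transformers` for anyone with a new apex build installed, even if they never selected the apex backend. The sketch below is illustrative only and is not part of this commit: it shows the torch.amp workflow that the "auto" backend relies on, with a toy model, optimizer, and batch standing in for real Trainer code (assumes a CUDA device and a recent PyTorch).

import torch

# Illustrative sketch only (not part of this commit): native torch.amp replacing apex.amp.
# `model`, `optimizer`, `inputs`, and `labels` are toy placeholders.
model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("cuda")  # loss scaling, replacing apex's opt_level machinery

inputs = torch.randn(4, 8, device="cuda")
labels = torch.randint(0, 2, (4,), device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):  # replaces amp.initialize(...)
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)

scaler.scale(loss).backward()  # replaces `with amp.scale_loss(loss, optimizer): ...`
scaler.step(optimizer)
scaler.update()

In the Trainer itself this path is driven through Accelerate, which is why the commit only needs to stop importing apex eagerly rather than add any torch.amp code of its own.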

src/transformers/trainer.py

@@ -151,7 +151,6 @@ from .utils import (
     check_torch_load_is_safe,
     find_labels,
     is_accelerate_available,
-    is_apex_available,
     is_apollo_torch_available,
     is_bitsandbytes_available,
     is_datasets_available,
@@ -191,9 +190,6 @@ if is_in_notebook():
     DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
-if is_apex_available():
-    from apex import amp
 if is_datasets_available():
     import datasets
@@ -761,11 +757,6 @@ class Trainer:
                 self.use_cpu_amp = True
                 self.amp_dtype = torch.bfloat16
             elif args.half_precision_backend == "apex":
-                if not is_apex_available():
-                    raise ImportError(
-                        "Using FP16 with APEX but APEX is not installed, please refer to"
-                        " https://www.github.com/nvidia/apex."
-                    )
                 self.use_apex = True
         # Label smoothing
@@ -1992,6 +1983,8 @@ class Trainer:
         # Mixed precision training with apex
         if self.use_apex and training:
+            from apex import amp
             model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
         # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
@@ -2579,6 +2572,8 @@ class Trainer:
                         if is_sagemaker_mp_enabled() and args.fp16:
                             _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                         elif self.use_apex:
+                            from apex import amp
                             # Revert to normal clipping otherwise, handling Apex or full precision
                             _grad_norm = nn.utils.clip_grad_norm_(
                                 amp.master_params(self.optimizer),
@@ -3776,6 +3771,8 @@ class Trainer:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training
         if self.use_apex:
+            from apex import amp
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
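All three trainer.py hunks above apply the same fix: the apex import moves from module scope into the branch that actually uses it, so a broken or amp-less apex install can no longer take down `import transformers`. A standalone sketch of that deferred-import pattern (the function name is hypothetical, not a Trainer method):

def backward_with_apex(loss, optimizer):
    # Hypothetical helper illustrating the deferred import: apex is only touched
    # when this code path runs, i.e. when the user explicitly picked the apex backend.
    from apex import amp

    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()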

src/transformers/training_args.py

@@ -39,6 +39,7 @@ from .utils import (
     ExplicitEnum,
     cached_property,
     is_accelerate_available,
+    is_apex_available,
     is_ipex_available,
     is_safetensors_available,
     is_sagemaker_dp_enabled,
@@ -1702,6 +1703,19 @@ class TrainingArguments:
             if self.half_precision_backend == "apex":
                 raise ValueError(" `--half_precision_backend apex`: GPU bf16 is not supported by apex.")
+        if self.half_precision_backend == "apex":
+            if not is_apex_available():
+                raise ImportError(
+                    "Using FP16 with APEX but APEX is not installed, please refer to"
+                    " https://www.github.com/nvidia/apex."
+                )
+            try:
+                from apex import amp  # noqa: F401
+            except ImportError as e:
+                raise ImportError(
+                    f"apex.amp has been removed from recent apex releases, which caused this error: {e}. Either revert to an older apex version or use PyTorch AMP by setting half_precision_backend='auto' instead. See https://github.com/NVIDIA/apex/pull/1896"
+                )
         if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
             if self.eval_strategy == IntervalStrategy.NO:
                 raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
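With the check moved into TrainingArguments, an incompatible apex install now fails at argument-construction time with an actionable message instead of deep inside a training run. A hypothetical usage sketch (not from the repository) of how a launch script might handle it:

from transformers import TrainingArguments

try:
    args = TrainingArguments(output_dir="out", fp16=True, half_precision_backend="apex")
except ImportError:
    # Raised by the validation above when apex is missing or no longer ships apex.amp;
    # fall back to native PyTorch AMP.
    args = TrainingArguments(output_dir="out", fp16=True, half_precision_backend="auto")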