mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
parent
05ad826002
commit
1a25fd2f6d
@ -151,7 +151,6 @@ from .utils import (
|
||||
check_torch_load_is_safe,
|
||||
find_labels,
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_apollo_torch_available,
|
||||
is_bitsandbytes_available,
|
||||
is_datasets_available,
|
||||
@ -191,9 +190,6 @@ if is_in_notebook():
|
||||
|
||||
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
|
||||
|
||||
if is_apex_available():
|
||||
from apex import amp
|
||||
|
||||
if is_datasets_available():
|
||||
import datasets
|
||||
|
||||
@ -761,11 +757,6 @@ class Trainer:
|
||||
self.use_cpu_amp = True
|
||||
self.amp_dtype = torch.bfloat16
|
||||
elif args.half_precision_backend == "apex":
|
||||
if not is_apex_available():
|
||||
raise ImportError(
|
||||
"Using FP16 with APEX but APEX is not installed, please refer to"
|
||||
" https://www.github.com/nvidia/apex."
|
||||
)
|
||||
self.use_apex = True
|
||||
|
||||
# Label smoothing
|
||||
@ -1992,6 +1983,8 @@ class Trainer:
|
||||
|
||||
# Mixed precision training with apex
|
||||
if self.use_apex and training:
|
||||
from apex import amp
|
||||
|
||||
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
||||
|
||||
# Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
|
||||
@ -2579,6 +2572,8 @@ class Trainer:
|
||||
if is_sagemaker_mp_enabled() and args.fp16:
|
||||
_grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
|
||||
elif self.use_apex:
|
||||
from apex import amp
|
||||
|
||||
# Revert to normal clipping otherwise, handling Apex or full precision
|
||||
_grad_norm = nn.utils.clip_grad_norm_(
|
||||
amp.master_params(self.optimizer),
|
||||
@ -3776,6 +3771,8 @@ class Trainer:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
|
||||
if self.use_apex:
|
||||
from apex import amp
|
||||
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
|
@ -39,6 +39,7 @@ from .utils import (
|
||||
ExplicitEnum,
|
||||
cached_property,
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_ipex_available,
|
||||
is_safetensors_available,
|
||||
is_sagemaker_dp_enabled,
|
||||
@ -1702,6 +1703,19 @@ class TrainingArguments:
|
||||
if self.half_precision_backend == "apex":
|
||||
raise ValueError(" `--half_precision_backend apex`: GPU bf16 is not supported by apex.")
|
||||
|
||||
if self.half_precision_backend == "apex":
|
||||
if not is_apex_available():
|
||||
raise ImportError(
|
||||
"Using FP16 with APEX but APEX is not installed, please refer to"
|
||||
" https://www.github.com/nvidia/apex."
|
||||
)
|
||||
try:
|
||||
from apex import amp # noqa: F401
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
f"apex.amp is deprecated in the latest version of apex, causing this error {e}. Either revert to an older version or use pytorch amp by setting half_precision_backend='auto' instead. See https://github.com/NVIDIA/apex/pull/1896 "
|
||||
)
|
||||
|
||||
if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
|
||||
if self.eval_strategy == IntervalStrategy.NO:
|
||||
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
|
||||
|
Loading…
Reference in New Issue
Block a user