Repository: https://github.com/huggingface/transformers.git
Commit: 1a25fd2f6d (parent: 05ad826002)
src/transformers/trainer.py

@@ -151,7 +151,6 @@ from .utils import (
     check_torch_load_is_safe,
     find_labels,
     is_accelerate_available,
-    is_apex_available,
     is_apollo_torch_available,
     is_bitsandbytes_available,
     is_datasets_available,
@@ -191,9 +190,6 @@ if is_in_notebook():
 
     DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
 
-if is_apex_available():
-    from apex import amp
-
 if is_datasets_available():
     import datasets
 
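The two hunks above stop importing apex when `transformers.trainer` is loaded: the `is_apex_available` import is dropped and the module-level `if is_apex_available(): from apex import amp` guard is removed, so apex is only touched on code paths that actually request it. An availability probe like this can stay cheap because it never has to import the package. The snippet below is only a standard-library sketch of that idea, not the real `is_apex_available` from transformers' import utilities; the helper name is made up.

```python
# Illustrative stand-in for an availability probe like is_apex_available().
# NOT the transformers implementation, just a minimal sketch of the idea:
# probe package metadata without importing the package, so the check is side-effect free.
import importlib.util


def apex_is_installed() -> bool:
    # find_spec locates the package without executing it, so the probe stays
    # cheap even when importing apex (and its CUDA extensions) would be slow
    # or would fail outright.
    return importlib.util.find_spec("apex") is not None


if __name__ == "__main__":
    print("apex installed:", apex_is_installed())
```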
@@ -761,11 +757,6 @@ class Trainer:
                 self.use_cpu_amp = True
                 self.amp_dtype = torch.bfloat16
             elif args.half_precision_backend == "apex":
-                if not is_apex_available():
-                    raise ImportError(
-                        "Using FP16 with APEX but APEX is not installed, please refer to"
-                        " https://www.github.com/nvidia/apex."
-                    )
                 self.use_apex = True
 
         # Label smoothing
@@ -1992,6 +1983,8 @@ class Trainer:
 
         # Mixed precision training with apex
         if self.use_apex and training:
+            from apex import amp
+
             model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
 
         # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
@@ -2579,6 +2572,8 @@ class Trainer:
                     if is_sagemaker_mp_enabled() and args.fp16:
                         _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                     elif self.use_apex:
+                        from apex import amp
+
                         # Revert to normal clipping otherwise, handling Apex or full precision
                         _grad_norm = nn.utils.clip_grad_norm_(
                             amp.master_params(self.optimizer),
@@ -3776,6 +3771,8 @@ class Trainer:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training
 
         if self.use_apex:
+            from apex import amp
+
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
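The remaining Trainer hunks all apply the same change: `from apex import amp` moves from module scope into the branches that use it (`amp.initialize` when wrapping the model, `amp.master_params` during gradient clipping, `amp.scale_loss` in the backward pass), so the import cost, and any breakage from apex builds that dropped `apex.amp`, is only paid when `self.use_apex` is set. A minimal sketch of that deferred-import pattern, using a hypothetical helper rather than the Trainer's actual methods:

```python
# Minimal sketch of the deferred-import pattern used in the hunks above.
# backward_pass is a hypothetical helper, not a Trainer method; the apex call
# (amp.scale_loss) mirrors the one shown in the diff.
import torch


def backward_pass(loss: torch.Tensor, optimizer: torch.optim.Optimizer, use_apex: bool = False) -> None:
    """Run the backward pass, importing apex only if the apex backend is requested."""
    if use_apex:
        # Deferred import: environments without apex (or with an apex build
        # that no longer ships apex.amp) only fail if they opt into this path.
        from apex import amp

        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
```

Calling `backward_pass(loss, optimizer)` with the default `use_apex=False` never imports apex, which is the behavior the diff is after for users who do not opt into the apex backend.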
src/transformers/training_args.py

@@ -39,6 +39,7 @@ from .utils import (
     ExplicitEnum,
     cached_property,
     is_accelerate_available,
+    is_apex_available,
     is_ipex_available,
     is_safetensors_available,
     is_sagemaker_dp_enabled,
@@ -1702,6 +1703,19 @@ class TrainingArguments:
             if self.half_precision_backend == "apex":
                 raise ValueError(" `--half_precision_backend apex`: GPU bf16 is not supported by apex.")
 
+        if self.half_precision_backend == "apex":
+            if not is_apex_available():
+                raise ImportError(
+                    "Using FP16 with APEX but APEX is not installed, please refer to"
+                    " https://www.github.com/nvidia/apex."
+                )
+            try:
+                from apex import amp  # noqa: F401
+            except ImportError as e:
+                raise ImportError(
+                    f"apex.amp is deprecated in the latest version of apex, causing this error {e}. Either revert to an older version or use pytorch amp by setting half_precision_backend='auto' instead. See https://github.com/NVIDIA/apex/pull/1896"
+                )
+
         if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
             if self.eval_strategy == IntervalStrategy.NO:
                 raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
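On the TrainingArguments side, the check is added at argument-validation time and covers two failure modes: apex not installed at all, and apex installed but missing `apex.amp`, which recent apex releases removed (see NVIDIA/apex#1896). A self-contained sketch of that two-step check is below; `check_apex_amp` is a hypothetical helper, and the `find_spec` probe only stands in for `is_apex_available`, which is implemented elsewhere in transformers.

```python
# Hedged sketch of the validation added above: fail fast at configuration time,
# with a message that distinguishes "apex missing" from "apex present but
# apex.amp removed". check_apex_amp() is a hypothetical helper, not transformers API.
import importlib.util


def check_apex_amp() -> None:
    """Fail fast if the apex backend was requested but cannot actually be used."""
    if importlib.util.find_spec("apex") is None:
        raise ImportError(
            "Using FP16 with APEX but APEX is not installed, please refer to"
            " https://www.github.com/nvidia/apex."
        )
    try:
        from apex import amp  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "apex is installed but apex.amp could not be imported; it was removed in recent "
            f"apex releases (see https://github.com/NVIDIA/apex/pull/1896): {e}. Revert to an "
            "older apex or set half_precision_backend='auto'."
        ) from e
```

Running the check this early, alongside the other TrainingArguments validation, surfaces a broken apex setup when the arguments are built instead of partway through training setup.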