Extend Trainer to enable Ascend NPU to use the fused AdamW optimizer when training (#26194)
This commit is contained in:
parent fc296f419e
commit 4fdf47cd3c
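For context, a minimal usage sketch (not part of the commit): after this change, the fused optimizer can be requested through the existing optim training argument, or with --optim adamw_torch_npu_fused on the command line. The output_dir value below is a placeholder, and an Ascend NPU environment with torch_npu installed is assumed at training time.

from transformers import TrainingArguments

# Select the NPU fused AdamW added by this commit; the string value maps to
# OptimizerNames.ADAMW_TORCH_NPU_FUSED inside the Trainer.
args = TrainingArguments(
    output_dir="out",  # placeholder
    optim="adamw_torch_npu_fused",
)
# Passing `args` to a Trainer then makes the optimizer selection below pick
# torch_npu.optim.NpuFusedAdamW when torch_npu is importable.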
@@ -1068,6 +1068,14 @@ class Trainer:
                 optimizer_kwargs.update(adam_kwargs)
             except ImportError:
                 raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
+        elif args.optim == OptimizerNames.ADAMW_TORCH_NPU_FUSED:
+            try:
+                from torch_npu.optim import NpuFusedAdamW
+
+                optimizer_cls = NpuFusedAdamW
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer failed to import FusedAdamW from torch_npu.")
         elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
             try:
                 from apex.optimizers import FusedAdam
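As a rough standalone illustration of what the new branch does (a sketch, not code from the commit; it assumes NpuFusedAdamW accepts the same keyword arguments as torch.optim.AdamW and falls back to the stock optimizer when torch_npu is absent):

import torch

# Tiny placeholder model; on Ascend hardware it would be moved to the NPU device.
model = torch.nn.Linear(16, 2)

try:
    # Available only when the torch_npu package is installed (Ascend NPU stack).
    from torch_npu.optim import NpuFusedAdamW

    optimizer_cls = NpuFusedAdamW
except ImportError:
    optimizer_cls = torch.optim.AdamW  # fallback for machines without torch_npu

# The same Adam keyword arguments the Trainer collects in adam_kwargs
# (betas and eps), plus the learning rate it passes separately.
optimizer = optimizer_cls(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8)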
@@ -140,6 +140,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_TORCH = "adamw_torch"
     ADAMW_TORCH_FUSED = "adamw_torch_fused"
     ADAMW_TORCH_XLA = "adamw_torch_xla"
+    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
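A quick sanity check of the new enum member (a sketch; it assumes OptimizerNames is importable from transformers.training_args, where the enum is defined):

from transformers.training_args import OptimizerNames

# The string accepted by --optim maps onto the member added in this commit.
assert OptimizerNames("adamw_torch_npu_fused") is OptimizerNames.ADAMW_TORCH_NPU_FUSED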