Extend Trainer to enable Ascend NPU to use the fused AdamW optimizer when training (#26194)

statelesshz 2023-10-04 20:57:11 +08:00 committed by GitHub
parent fc296f419e
commit 4fdf47cd3c
2 changed files with 9 additions and 0 deletions


@@ -1068,6 +1068,14 @@ class Trainer:
                 optimizer_kwargs.update(adam_kwargs)
             except ImportError:
                 raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
+        elif args.optim == OptimizerNames.ADAMW_TORCH_NPU_FUSED:
+            try:
+                from torch_npu.optim import NpuFusedAdamW
+
+                optimizer_cls = NpuFusedAdamW
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer failed to import FusedAdamW from torch_npu.")
         elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
             try:
                 from apex.optimizers import FusedAdam
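The new branch follows the same guarded-import pattern as the existing XLA and Apex cases: NpuFusedAdamW is only imported when the user explicitly selects it, so environments without torch_npu are unaffected. A minimal standalone sketch of that pattern, with a hypothetical helper name and an illustrative fallback to torch.optim.AdamW (not part of this commit):

import torch


def pick_adamw(use_npu_fused: bool, **adam_kwargs):
    """Return an (optimizer class, kwargs) pair, preferring the NPU-fused AdamW when requested."""
    if use_npu_fused:
        try:
            # Requires an Ascend NPU environment with torch_npu installed.
            from torch_npu.optim import NpuFusedAdamW

            return NpuFusedAdamW, adam_kwargs
        except ImportError:
            raise ValueError("Failed to import FusedAdamW from torch_npu.")
    # Fallback to the stock PyTorch implementation (illustration only).
    return torch.optim.AdamW, adam_kwargs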


@@ -140,6 +140,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_TORCH = "adamw_torch"
     ADAMW_TORCH_FUSED = "adamw_torch_fused"
     ADAMW_TORCH_XLA = "adamw_torch_xla"
+    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
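With the enum value registered, the fused optimizer can be selected through TrainingArguments like any other optim choice. A minimal usage sketch, assuming an Ascend NPU environment with torch_npu installed; model and train_dataset are placeholders, not part of this commit:

from transformers import Trainer, TrainingArguments

# "adamw_torch_npu_fused" is the string added by this commit; it maps to
# OptimizerNames.ADAMW_TORCH_NPU_FUSED and resolves to torch_npu's NpuFusedAdamW.
args = TrainingArguments(
    output_dir="out",
    optim="adamw_torch_npu_fused",
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # placeholders
trainer.train()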