Enable custom optimizer for DeepSpeed (#32049)

* transformers: enable custom optimizer for DeepSpeed

* transformers: modify error message

---------

Co-authored-by: datakim1201 <roy.kim@maum.ai>
roy, 2024-10-07 22:36:54 +09:00 (committed by GitHub)
parent 7bae833728
commit 55be7c4c48

@@ -599,11 +599,11 @@ class Trainer:
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
" `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
)
-        if (self.is_deepspeed_enabled or self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
+        if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
self.optimizer is not None or self.lr_scheduler is not None
):
raise RuntimeError(
"Passing `optimizers` is not allowed if Deepspeed or PyTorch FSDP is enabled. "
"Passing `optimizers` is not allowed if PyTorch FSDP is enabled. "
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
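
With this change, a custom optimizer/scheduler pair can be passed straight to `Trainer` while DeepSpeed is enabled, instead of subclassing `Trainer` and overriding `create_optimizer_and_scheduler`. A minimal sketch of the newly allowed usage; the toy model, toy dataset, and the `ds_config.json` path are illustrative assumptions, and actually running it requires DeepSpeed to be installed and a launcher such as `deepspeed`:

```python
# Sketch: passing a custom optimizer/scheduler to Trainer with DeepSpeed enabled.
# "ds_config.json" is a hypothetical config path; the model/dataset below are
# toys that exist only to make the example self-contained.
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments


class ToyDataset(Dataset):
    # Minimal dataset of feature/label dicts, enough for the default collator.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"x": torch.randn(4), "labels": torch.tensor(0)}


class ToyModel(nn.Module):
    # Minimal model whose forward returns a dict with a "loss" key, as Trainer expects.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 2)

    def forward(self, x=None, labels=None):
        logits = self.proj(x)
        loss = nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}


model = ToyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)

args = TrainingArguments(
    output_dir="out",
    deepspeed="ds_config.json",  # enables the DeepSpeed integration
    per_device_train_batch_size=2,
)

# Before this commit, constructing Trainer like this raised the RuntimeError
# above whenever DeepSpeed was enabled; now the check only covers FSDP.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ToyDataset(),
    optimizers=(optimizer, scheduler),
)
trainer.train()
```

Passing `optimizers` with PyTorch FSDP or FSDP-XLA still raises the `RuntimeError`, per the updated check in the diff.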