Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
[DeepSpeed] don't ignore --adafactor (#12257)
commit b75b5605c9
parent eb881674f2
@@ -318,7 +318,13 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
     # 4. HF scheduler + DS optimizer: No
 
     optimizer = None
-    if "optimizer" not in config:
+    if "optimizer" in config:
+        if trainer.args.adafactor:
+            raise ValueError(
+                "--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. "
+                "Only one optimizer can be configured."
+            )
+    else:
         if hf_deepspeed_config.is_offload():
             raise ValueError("ZeRO Offload can only work with DeepSpeed optimizers")
 
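As the commit title says, --adafactor used to be ignored when a DeepSpeed config was in play; after this change the two optimizer sources are checked against each other. Below is a minimal, self-contained sketch of the resulting guard logic. The names ds_config, adafactor, and is_offload are hypothetical stand-ins for the DeepSpeed config dict, the --adafactor training argument, and hf_deepspeed_config.is_offload(); this is not the Trainer's actual code path, just an illustration of the checks.

    # Sketch of the guard added in this commit, outside the Trainer context.
    def check_optimizer_config(ds_config: dict, adafactor: bool, is_offload: bool) -> None:
        if "optimizer" in ds_config:
            # A DeepSpeed-configured optimizer and --adafactor are mutually exclusive.
            if adafactor:
                raise ValueError(
                    "--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. "
                    "Only one optimizer can be configured."
                )
        else:
            # Without a DeepSpeed-configured optimizer, ZeRO Offload cannot be used.
            if is_offload:
                raise ValueError("ZeRO Offload can only work with DeepSpeed optimizers")

    # Example: requesting both a DeepSpeed optimizer and --adafactor now raises.
    try:
        check_optimizer_config({"optimizer": {"type": "AdamW"}}, adafactor=True, is_offload=False)
    except ValueError as err:
        print(err)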