diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py
index 24fd01e14ea..7cf9fb07f0a 100644
--- a/src/transformers/deepspeed.py
+++ b/src/transformers/deepspeed.py
@@ -295,11 +295,13 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
 
     """
     import deepspeed
+    from deepspeed.utils import logger as ds_logger
 
     model = trainer.model
+    args = trainer.args
 
-    hf_deepspeed_config = trainer.args.hf_deepspeed_config
-    hf_deepspeed_config.trainer_config_finalize(trainer.args, model, num_training_steps)
+    hf_deepspeed_config = args.hf_deepspeed_config
+    hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)
 
     # resume config update - some bits like `model` and `num_training_steps` only become available during train
     config = hf_deepspeed_config.config
@@ -319,7 +321,7 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
     optimizer = None
     if "optimizer" in config:
-        if trainer.args.adafactor:
+        if args.adafactor:
             raise ValueError(
                 "--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. "
                 "Only one optimizer can be configured."
             )
@@ -356,6 +358,9 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
     # keep for quick debug:
     # from pprint import pprint; pprint(config)
 
+    # set the Deepspeed log level consistent with the trainer
+    ds_logger.setLevel(args.get_process_log_level())
+
     model_parameters = filter(lambda p: p.requires_grad, model.parameters())
 
     model, optimizer, _, lr_scheduler = deepspeed.initialize(
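
For context, the new `ds_logger.setLevel(args.get_process_log_level())` line makes DeepSpeed's own logger follow the same per-process log level the `Trainer` already uses (`--log_level` on the main process, `--log_level_replica` on the others), and the `args = trainer.args` alias simply shortens the repeated `trainer.args.*` accesses. Below is a minimal standalone sketch of that resolution step, assuming a working `transformers` install; the `log_level`/`log_level_replica` values are illustrative examples, not taken from the patch.

```python
# Minimal sketch (not part of the patch): how the per-process log level is
# resolved and then mirrored onto DeepSpeed's logger. The TrainingArguments
# values below are hypothetical example settings.
import logging

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="tmp",
    log_level="warning",        # log level for the main process
    log_level_replica="error",  # log level for replica processes
)

# Resolves to the right level for the current process (main vs. replica);
# this is the value deepspeed_init now forwards to DeepSpeed's logger.
level = args.get_process_log_level()
print(logging.getLevelName(level))

# Inside deepspeed_init the same value is applied to DeepSpeed's own logger:
#     from deepspeed.utils import logger as ds_logger
#     ds_logger.setLevel(level)
```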