[DeepSpeed] simplify init (#10762)

This commit is contained in:
Stas Bekman 2021-03-17 10:21:03 -07:00 committed by GitHub
parent 0486ccdd3d
commit 01c7fb04be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -22,7 +22,6 @@ import os
import re
import tempfile
from pathlib import Path
from types import SimpleNamespace
from .utils import logging
from .utils.versions import require_version
@ -430,16 +429,12 @@ def init_deepspeed(trainer, num_training_steps):
"enabled": True,
}
# for clarity extract the specific cl args that are being passed to deepspeed
ds_args = dict(local_rank=args.local_rank)
# keep for quick debug:
# from pprint import pprint; pprint(config)
# init that takes part of the config via `args`, and the bulk of it via `config_params`
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
model, optimizer, _, lr_scheduler = deepspeed.initialize(
args=SimpleNamespace(**ds_args), # expects an obj
model=model,
model_parameters=model_parameters,
config_params=config,