mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[DeepSpeed] simplify init (#10762)
This commit is contained in:
parent
0486ccdd3d
commit
01c7fb04be
@ -22,7 +22,6 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
from .utils import logging
|
||||
from .utils.versions import require_version
|
||||
@ -430,16 +429,12 @@ def init_deepspeed(trainer, num_training_steps):
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
# for clarity extract the specific cl args that are being passed to deepspeed
|
||||
ds_args = dict(local_rank=args.local_rank)
|
||||
|
||||
# keep for quick debug:
|
||||
# from pprint import pprint; pprint(config)
|
||||
|
||||
# init that takes part of the config via `args`, and the bulk of it via `config_params`
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
model, optimizer, _, lr_scheduler = deepspeed.initialize(
|
||||
args=SimpleNamespace(**ds_args), # expects an obj
|
||||
model=model,
|
||||
model_parameters=model_parameters,
|
||||
config_params=config,
|
||||
|
Loading…
Reference in New Issue
Block a user