[trainer] refactor place_model_on_device logic, add deepspeed (#10243)

* refactor place_model_on_device logic, add deepspeed

* doc

* style
This commit is contained in:
Stas Bekman 2021-02-17 15:52:36 -08:00 committed by GitHub
parent d1eb88f42d
commit dee876ceff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -214,6 +214,10 @@ class Trainer:
inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
- **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
data parallelism, this means some of the model layers are split on different GPUs).
- **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
to :obj:`False` if model parallel or deepspeed is used, or if the default
``TrainingArguments.place_model_on_device`` is overridden to return :obj:`False` .
"""
def __init__(
@ -262,6 +266,11 @@ class Trainer:
else:
self.is_model_parallel = False
# one place to sort out whether to place the model on device or not
self.place_model_on_device = args.place_model_on_device
if self.is_model_parallel or (args.deepspeed and args.do_train):
self.place_model_on_device = False
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
self.data_collator = data_collator if data_collator is not None else default_collator
self.train_dataset = train_dataset
@ -272,7 +281,7 @@ class Trainer:
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
# and we only use deepspeed for training at the moment
if not (self.is_model_parallel or (args.deepspeed and args.do_train)) and self.args.place_model_on_device:
if self.place_model_on_device:
model = model.to(args.device)
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
@ -780,7 +789,7 @@ class Trainer:
# If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded:
if not self.is_model_parallel and self.args.place_model_on_device:
if self.place_model_on_device:
self.model = self.model.to(self.args.device)
self.model_wrapped = self.model
@ -1033,7 +1042,7 @@ class Trainer:
)
if isinstance(self.model, PreTrainedModel):
self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
if not self.is_model_parallel and self.args.place_model_on_device:
if self.place_model_on_device:
self.model = self.model.to(self.args.device)
else:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))