mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
[trainer] refactor place_model_on_device logic, add deepspeed (#10243)
* refactor place_model_on_device logic, add deepspeed * doc * style
This commit is contained in:
parent
d1eb88f42d
commit
dee876ceff
@ -214,6 +214,10 @@ class Trainer:
|
||||
inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
|
||||
- **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
|
||||
data parallelism, this means some of the model layers are split on different GPUs).
|
||||
- **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
|
||||
to :obj:`False` if model parallel or deepspeed is used, or if the default
|
||||
``TrainingArguments.place_model_on_device`` is overridden to return :obj:`False` .
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -262,6 +266,11 @@ class Trainer:
|
||||
else:
|
||||
self.is_model_parallel = False
|
||||
|
||||
# one place to sort out whether to place the model on device or not
|
||||
self.place_model_on_device = args.place_model_on_device
|
||||
if self.is_model_parallel or (args.deepspeed and args.do_train):
|
||||
self.place_model_on_device = False
|
||||
|
||||
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
||||
self.data_collator = data_collator if data_collator is not None else default_collator
|
||||
self.train_dataset = train_dataset
|
||||
@ -272,7 +281,7 @@ class Trainer:
|
||||
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
|
||||
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
|
||||
# and we only use deepspeed for training at the moment
|
||||
if not (self.is_model_parallel or (args.deepspeed and args.do_train)) and self.args.place_model_on_device:
|
||||
if self.place_model_on_device:
|
||||
model = model.to(args.device)
|
||||
|
||||
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
|
||||
@ -780,7 +789,7 @@ class Trainer:
|
||||
|
||||
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||
if model_reloaded:
|
||||
if not self.is_model_parallel and self.args.place_model_on_device:
|
||||
if self.place_model_on_device:
|
||||
self.model = self.model.to(self.args.device)
|
||||
self.model_wrapped = self.model
|
||||
|
||||
@ -1033,7 +1042,7 @@ class Trainer:
|
||||
)
|
||||
if isinstance(self.model, PreTrainedModel):
|
||||
self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
|
||||
if not self.is_model_parallel and self.args.place_model_on_device:
|
||||
if self.place_model_on_device:
|
||||
self.model = self.model.to(self.args.device)
|
||||
else:
|
||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||
|
Loading…
Reference in New Issue
Block a user