Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Add SageMakerTrainer for model parallelism (#10122)
* Refactor things out of main train
* Store signature
* Add SageMakerTrainer
* Init + Copyright
* Address review comments
This commit is contained in:
parent b54cb0bd82
commit 31245775e5
src/transformers/sagemaker/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .trainer_sm import SageMakerTrainer
from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_distributed_available
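For orientation, here is a minimal sketch of how the new subpackage is meant to be used from a training entry point. This is not part of the commit: the tiny model, dataset, and argument values are hypothetical placeholders chosen only to keep the example self-contained; outside SageMaker, `mp_parameters` stays empty, so the trainer falls back to the regular `Trainer` behavior.

# Hypothetical usage sketch (not part of this commit).
import torch
from transformers import BertConfig, BertForSequenceClassification
from transformers.sagemaker import SageMakerTrainer, SageMakerTrainingArguments

class ToyDataset(torch.utils.data.Dataset):
    # A minimal stand-in dataset; real scripts would load their own data.
    def __len__(self):
        return 8
    def __getitem__(self, i):
        return {"input_ids": torch.tensor([101, 2000 + i, 102]), "labels": torch.tensor(0)}

model = BertForSequenceClassification(
    BertConfig(num_hidden_layers=1, hidden_size=32, num_attention_heads=2, intermediate_size=64)
)
args = SageMakerTrainingArguments(output_dir="test_output", num_train_epochs=1, per_device_train_batch_size=4)
trainer = SageMakerTrainer(model=model, args=args, train_dataset=ToyDataset())
trainer.train()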
src/transformers/sagemaker/trainer_sm.py (new file, 178 lines)
@@ -0,0 +1,178 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler

from ..trainer import Trainer
from ..trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    SequentialDistributedSampler,
    nested_detach,
    nested_numpify,
)
from ..utils import logging
from .training_args_sm import is_smdistributed_available


logger = logging.get_logger(__name__)


if is_smdistributed_available():
    import smdistributed.modelparallel.torch as smp

    @smp.step()
    def forward_backward(model, inputs):
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        model.backward(loss)
        return loss

    @smp.step()
    def forward_only(model, inputs):
        return model(**inputs)

    def smp_gather(tensor):
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(smp_gather(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )
        all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
        return torch.cat([t.cpu() for t in all_tensors], dim=0)

    def nested_smp_concat(tensor):
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(nested_smp_concat(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: nested_smp_concat(v) for k, v in tensor.items()})
        # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
        # which is also the name of the decorator so Python is confused.
        return tensor.concat().detach().cpu()


class SageMakerTrainer(Trainer):
    def __init__(self, args=None, **kwargs):
        super().__init__(args=args, **kwargs)
        self.is_model_parallel_enabled = is_smdistributed_available() and self.args.mp_parameters != ""
        if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
            raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")

    def _get_train_sampler(self):
        if self.is_model_parallel_enabled:
            if self.args.group_by_length:
                return DistributedLengthGroupedSampler(
                    self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
                )
            else:
                return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
        else:
            return super()._get_train_sampler()

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if self.is_model_parallel_enabled:
            return SequentialDistributedSampler(eval_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
        else:
            return super()._get_eval_sampler(eval_dataset)

    def _wrap_model(self, model, training=True):
        if self.is_model_parallel_enabled:
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model)
        else:
            return super()._wrap_model(model)

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        super().create_optimizer_and_scheduler(num_training_steps)
        if self.is_model_parallel_enabled:
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        if self.is_model_parallel_enabled:
            model.train()
            inputs = self._prepare_inputs(inputs)
            loss_mb = forward_backward(model, inputs)
            return loss_mb.reduce_mean().detach().to(self.args.device)
        else:
            return super().training_step(model, inputs)

    def _gather_and_numpify(self, tensors, name):
        if self.is_model_parallel_enabled:
            tensors = smp_gather(tensors)
            return nested_numpify(tensors)
        else:
            return super()._gather_and_numpify(tensors, name)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        if self.is_model_parallel_enabled:
            has_labels = all(inputs.get(k) is not None for k in self.label_names)
            inputs = self._prepare_inputs(inputs)

            if ignore_keys is None:
                if hasattr(self.model, "config"):
                    ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
                else:
                    ignore_keys = []

            with torch.no_grad():
                raw_outputs = forward_only(model, inputs)
                if has_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = nested_smp_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = nested_smp_concat(logits_mb)

            if prediction_loss_only:
                return (loss, None, None)

            if len(logits) == 1:
                logits = logits[0]

            if has_labels:
                labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
                if len(labels) == 1:
                    labels = labels[0]
            else:
                labels = None

            return (loss, logits, labels)
        else:
            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
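The smp-specific helpers above (smp_gather, nested_smp_concat) follow the same recursive traversal pattern as the existing nested_* utilities in trainer_pt_utils: walk lists, tuples, and dicts, then apply a leaf operation to each tensor. As a rough, CPU-only illustration of that traversal (a toy stand-in for intuition, with no smdistributed dependency):

import torch

def nested_to_cpu(obj):
    # Same shape of recursion as smp_gather / nested_smp_concat above:
    # recurse through containers, apply the leaf operation to tensors.
    if isinstance(obj, (list, tuple)):
        return type(obj)(nested_to_cpu(o) for o in obj)
    if isinstance(obj, dict):
        return type(obj)({k: nested_to_cpu(v) for k, v in obj.items()})
    return obj.detach().cpu()

outputs = {"loss": torch.tensor(0.5), "logits": (torch.ones(2, 3), [torch.zeros(1)])}
print(nested_to_cpu(outputs))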
src/transformers/sagemaker/training_args_sm.py (new file, 89 lines)
@@ -0,0 +1,89 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
from dataclasses import dataclass, field

import torch

from transformers.file_utils import cached_property, is_sagemaker_distributed_available
from transformers.training_args import TrainingArguments
from transformers.utils import logging


logger = logging.get_logger(__name__)


def is_smdistributed_available():
    return importlib.util.find_spec("smdistributed") is not None


if is_smdistributed_available():
    import smdistributed.modelparallel.torch as smp


@dataclass
class SageMakerTrainingArguments(TrainingArguments):
    mp_parameters: str = field(
        default="", metadata={"help": "Used by the SageMaker launcher to send mp-specific args."}
    )

    def __post_init__(self):
        super().__post_init__()
        if is_smdistributed_available() and self.mp_parameters != "":
            smp.init()

    @cached_property
    def _setup_devices(self) -> "torch.device":
        logger.info("PyTorch: setting up devices")
        if self.no_cuda:
            device = torch.device("cpu")
            self._n_gpu = 0
        elif is_smdistributed_available() and self.mp_parameters != "":
            local_rank = smp.local_rank()
            device = torch.device("cuda", local_rank)
            self._n_gpu = 1
        elif is_sagemaker_distributed_available():
            import smdistributed.dataparallel.torch.distributed as dist

            dist.init_process_group()
            self.local_rank = dist.get_local_rank()
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1
        elif self.local_rank == -1:
            # if n_gpu is > 1 we'll use nn.DataParallel.
            # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
            # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
            # trigger an error that a device index is missing. Index 0 takes into account the
            # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
            # will use the first GPU in that env, i.e. GPU#1
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
            # the default value.
            self._n_gpu = torch.cuda.device_count()
        else:
            # Here, we'll use torch.distributed.
            # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
            torch.distributed.init_process_group(backend="nccl")
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1

        if device.type == "cuda":
            torch.cuda.set_device(device)

        return device

    @property
    def place_model_on_device(self):
        return not (is_smdistributed_available() and self.mp_parameters != "")
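A small sketch of the toggle these arguments introduce (assuming transformers with this commit and torch are installed; the output_dir value is a placeholder). When `mp_parameters` is empty, `place_model_on_device` stays True and the Trainer moves the model to the device itself; when the SageMaker launcher passes `mp_parameters` and smdistributed is installed, the property returns False and `smp.DistributedModel` takes over placement instead.

from transformers.sagemaker import SageMakerTrainingArguments

# Outside SageMaker, mp_parameters is empty, so these behave like plain TrainingArguments.
args = SageMakerTrainingArguments(output_dir="test_output")
print(args.mp_parameters)          # "" -> model parallelism disabled
print(args.place_model_on_device)  # True -> Trainer still calls model.to(device)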
src/transformers/trainer.py
@@ -272,7 +272,7 @@ class Trainer:
         # 1. MP - since we are trying to fit a much bigger than 1 gpu model
         # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
         # and we only use deepspeed for training at the moment
-        if not self.is_model_parallel and not (args.deepspeed and args.do_train):
+        if not (self.is_model_parallel or (args.deepspeed and args.do_train)) and self.args.place_model_on_device:
             model = model.to(args.device)

         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
@@ -319,6 +319,7 @@ class Trainer:
         if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
             raise ValueError("eval_dataset must implement __len__")

+        self._signature_columns = None
         if is_datasets_available():
             if isinstance(train_dataset, datasets.Dataset):
                 self._remove_unused_columns(self.train_dataset, description="training")
@@ -425,16 +426,18 @@ class Trainer:
     def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
         if not self.args.remove_unused_columns:
             return
-        # Inspect model forward signature to keep only the arguments it accepts.
-        signature = inspect.signature(self.model.forward)
-        signature_columns = list(signature.parameters.keys())
-        # Labels may be named label or label_ids, the default data collator handles that.
-        signature_columns += ["label", "label_ids"]
-        columns = [k for k in signature_columns if k in dataset.column_names]
-        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
+        if self._signature_columns is None:
+            # Inspect model forward signature to keep only the arguments it accepts.
+            signature = inspect.signature(self.model.forward)
+            self._signature_columns = list(signature.parameters.keys())
+            # Labels may be named label or label_ids, the default data collator handles that.
+            self._signature_columns += ["label", "label_ids"]
+        columns = [k for k in self._signature_columns if k in dataset.column_names]
+        ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
         dset_description = "" if description is None else f"in the {description} set "
         logger.info(
-            f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
+            f"The following columns {dset_description}don't have a corresponding argument in "
+            f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
         )
         dataset.set_format(type=dataset.format["type"], columns=columns)

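The hunk above, together with the `self._signature_columns = None` initialization added earlier, caches the forward-signature inspection so it runs once per Trainer rather than once per dataset. A toy stand-in for that caching pattern (not Trainer code; the class and model below are hypothetical):

import inspect

class ColumnCache:
    """Toy illustration of the signature-column caching introduced above."""

    def __init__(self, model):
        self.model = model
        self._signature_columns = None

    def signature_columns(self):
        if self._signature_columns is None:
            # Inspect the model forward signature only on the first call.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys()) + ["label", "label_ids"]
        return self._signature_columns

class DummyModel:
    def forward(self, input_ids, attention_mask=None):
        return input_ids

cache = ColumnCache(DummyModel())
print(cache.signature_columns())  # ['input_ids', 'attention_mask', 'label', 'label_ids']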
@@ -684,6 +687,45 @@ class Trainer:

         return model

+    def _wrap_model(self, model, training=True):
+        # Mixed precision training with apex (torch < 1.6)
+        if self.use_apex and training:
+            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
+
+        # Multi-gpu training (should be after apex fp16 initialization)
+        if self.args.n_gpu > 1:
+            model = torch.nn.DataParallel(model)
+
+        # Note: in torch.distributed mode, there's no point in wrapping the model
+        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
+        if not training:
+            return model
+
+        # Distributed training (should be after apex fp16 initialization)
+        if self.sharded_dpp:
+            model = ShardedDDP(model, self.optimizer)
+        elif is_sagemaker_distributed_available():
+            model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
+        elif self.deepspeed:
+            pass  # already initialized its own DDP earlier
+        elif self.args.local_rank != -1:
+            if self.args.ddp_find_unused_parameters is not None:
+                find_unused_parameters = self.args.ddp_find_unused_parameters
+            elif isinstance(model, PreTrainedModel):
+                # find_unused_parameters breaks checkpointing as per
+                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+                find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
+            else:
+                find_unused_parameters = True
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.args.local_rank],
+                output_device=self.args.local_rank,
+                find_unused_parameters=find_unused_parameters,
+            )
+
+        return model
+
     def train(
         self,
         resume_from_checkpoint: Optional[str] = None,
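Factoring this wrapping logic into `_wrap_model` is what lets `SageMakerTrainer._wrap_model` (in trainer_sm.py above) swap the DDP wrapping for `smp.DistributedModel` while leaving the rest of the training loop untouched; as the later hunks show, `train` and the evaluation loop now call `_wrap_model` instead of wrapping the model inline.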
@@ -736,7 +778,7 @@ class Trainer:

         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
-            if not self.is_model_parallel:
+            if not self.is_model_parallel and self.args.place_model_on_device:
                 self.model = self.model.to(self.args.device)
             self.model_wrapped = self.model

@@ -783,38 +825,7 @@ class Trainer:
         # Check if saved optimizer or scheduler states exist
         self._load_optimizer_and_scheduler(resume_from_checkpoint)

-        model = self.model_wrapped
-
-        # Mixed precision training with apex (torch < 1.6)
-        if self.use_apex:
-            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
-
-        # Multi-gpu training (should be after apex fp16 initialization)
-        if self.args.n_gpu > 1:
-            model = torch.nn.DataParallel(model)
-
-        # Distributed training (should be after apex fp16 initialization)
-        if self.sharded_dpp:
-            model = ShardedDDP(model, self.optimizer)
-        elif is_sagemaker_distributed_available():
-            model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
-        elif self.deepspeed:
-            pass  # already initialized its own DDP earlier
-        elif self.args.local_rank != -1:
-            if self.args.ddp_find_unused_parameters is not None:
-                find_unused_parameters = self.args.ddp_find_unused_parameters
-            elif isinstance(model, PreTrainedModel):
-                # find_unused_parameters breaks checkpointing as per
-                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
-                find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
-            else:
-                find_unused_parameters = True
-            model = torch.nn.parallel.DistributedDataParallel(
-                model,
-                device_ids=[self.args.local_rank],
-                output_device=self.args.local_rank,
-                find_unused_parameters=find_unused_parameters,
-            )
+        model = self._wrap_model(self.model_wrapped)

         # for the rest of this function `model` is the outside model, whether it was wrapped or not
         if model is not self.model:
@@ -1020,7 +1031,7 @@ class Trainer:
             )
             if isinstance(self.model, PreTrainedModel):
                 self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
-                if not self.is_model_parallel:
+                if not self.is_model_parallel and self.args.place_model_on_device:
                     self.model = self.model.to(self.args.device)
             else:
                 state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
@@ -1610,13 +1621,7 @@ class Trainer:
             # flagging only for when --do_train wasn't passed as only then it's redundant
             logger.info("Detected the deepspeed argument but it will not be used for evaluation")

-        model = self.model
-
-        # multi-gpu eval
-        if self.args.n_gpu > 1:
-            model = torch.nn.DataParallel(model)
-        # Note: in torch.distributed mode, there's no point in wrapping the model
-        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
+        model = self._wrap_model(self.model, training=False)

         batch_size = dataloader.batch_size
         num_examples = self.num_examples(dataloader)
src/transformers/training_args.py
@@ -637,6 +637,13 @@ class TrainingArguments:
         else:
             return ParallelMode.NOT_PARALLEL

+    @property
+    def place_model_on_device(self):
+        """
+        Can be subclassed and overridden for some specific integrations.
+        """
+        return True
+
     def to_dict(self):
         """
         Serializes this instance while replace `Enum` by their values (for JSON serialization support).