Add CPU distributed fine-tuning support for the transformers Trainer API (#13574)

* Update Trainer with CPU distributed fine-tuning support.

Signed-off-by: Ding, Ke <ke.ding@intel.com>

* Style.

* Refine the CPU distributed training check.

Signed-off-by: Ding, Ke <ke.ding@intel.com>

* Style.

Signed-off-by: Ding, Ke <ke.ding@intel.com>

* Test against the private field, not the public one.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
Co-authored-by: Funtowicz Morgan <mfuntowicz@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Commit 8632a60d33 (parent 6a3a197fcd), authored by kding1 on 2021-09-23 09:15:27 -07:00, committed via GitHub.
3 changed files with 23 additions and 6 deletions

src/transformers/trainer.py

@@ -1001,8 +1001,8 @@ class Trainer:
                 find_unused_parameters = True
             model = nn.parallel.DistributedDataParallel(
                 model,
-                device_ids=[self.args.local_rank],
-                output_device=self.args.local_rank,
+                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
+                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                 find_unused_parameters=find_unused_parameters,
             )
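When `_n_gpu` is 0 (a CPU-only run), DDP has no local GPU to pin, so `device_ids` and `output_device` must both be `None`; PyTorch then treats the wrapped module as a CPU module. A minimal sketch of that construction, assuming a toy `nn.Linear` model and the `gloo` backend (which ships with PyTorch and supports CPU tensors; the patch itself targets `mpi`/`ccl`):

import os
import torch.nn as nn
import torch.distributed as dist

# Single-process demo rendezvous; real runs get this from a launcher such as mpirun.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = nn.Linear(4, 2)  # toy CPU model
# On CPU both arguments are None, exactly the case the `_n_gpu != 0` check above guards.
ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=None, output_device=None)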
@@ -2005,7 +2005,9 @@ class Trainer:
     def store_flos(self):
         # Storing the number of floating-point operations that went into the model
         if self.args.local_rank != -1:
-            self.state.total_flos += distributed_broadcast_scalars([self.current_flos]).sum().item()
+            self.state.total_flos += (
+                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
+            )
             self.current_flos = 0
         else:
             self.state.total_flos += self.current_flos
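`store_flos` now forwards the Trainer's resolved device instead of letting the helper default to CUDA. A tiny hedged illustration of the failure this avoids (the FLOs value is hypothetical; `torch.device("cpu")` stands in for what `args.device` resolves to under `--no_cuda`):

import torch

device = torch.device("cpu")  # what args.device resolves to on a CPU-only run
current_flos = 3.2e12         # hypothetical per-rank FLOs counter
# The old code called .cuda() here, which raises on a machine without a GPU.
t = torch.tensor([current_flos]).to(device)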

src/transformers/trainer_pt_utils.py

@@ -175,10 +175,12 @@ def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int]
 def distributed_broadcast_scalars(
-    scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
+    scalars: List[Union[int, float]],
+    num_total_examples: Optional[int] = None,
+    device: Optional[torch.device] = torch.device("cuda"),
 ) -> torch.Tensor:
     try:
-        tensorized_scalar = torch.tensor(scalars).cuda()
+        tensorized_scalar = torch.tensor(scalars).to(device)
         output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
         dist.all_gather(output_tensors, tensorized_scalar)
         concat = torch.cat(output_tensors, dim=0)
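The new `device` parameter lets the same all-gather run on CPU tensors. A minimal sketch of the gather-and-sum pattern the helper implements, runnable single-process with the `gloo` backend (an assumption for the demo; the PR's CPU path uses `mpi` or `ccl`):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

scalars = [123.0]                                # e.g. this rank's current_flos
tensorized = torch.tensor(scalars).to(torch.device("cpu"))
gathered = [tensorized.clone() for _ in range(dist.get_world_size())]
dist.all_gather(gathered, tensorized)            # every rank receives every rank's tensor
total = torch.cat(gathered, dim=0).sum().item()  # 123.0 with world_size == 1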

src/transformers/training_args.py

@@ -220,6 +220,8 @@ class TrainingArguments:
             can harm metric values.
         local_rank (:obj:`int`, `optional`, defaults to -1):
             Rank of the process during distributed training.
+        xpu_backend (:obj:`str`, `optional`):
+            The backend to use for xpu distributed training. Must be one of :obj:`"mpi"` or :obj:`"ccl"`.
         tpu_num_cores (:obj:`int`, `optional`):
             When training on TPU, the number of TPU cores (automatically passed by launcher script).
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -526,7 +528,10 @@ class TrainingArguments:
         metadata={"help": "Whether to use full 16-bit precision evaluation instead of 32-bit"},
     )
     local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
+    xpu_backend: str = field(
+        default=None,
+        metadata={"help": "The backend to be used for distributed training on Intel XPU.", "choices": ["mpi", "ccl"]},
+    )
     tpu_num_cores: Optional[int] = field(
         default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"}
     )
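Like any other `TrainingArguments` field, the new flag is exposed on the command line through `HfArgumentParser`. A short usage sketch (the argument values are illustrative):

from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
# Equivalent to: python train.py --output_dir out --no_cuda --xpu_backend ccl
(args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--no_cuda", "--xpu_backend", "ccl"]
)
assert args.xpu_backend == "ccl"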
@@ -894,6 +899,14 @@ class TrainingArguments:
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
+            if self.local_rank != -1:
+                # Initializes distributed backend for cpu
+                if self.xpu_backend not in ("mpi", "ccl"):
+                    raise ValueError(
+                        "CPU distributed training backend is not properly set. "
+                        "Please set '--xpu_backend' to either 'mpi' or 'ccl'."
+                    )
+                torch.distributed.init_process_group(backend=self.xpu_backend)
         elif is_torch_tpu_available():
             device = xm.xla_device()
             self._n_gpu = 0
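For reference, a standalone sketch of the initialization this new branch performs. The backend string here is a stand-in: `"mpi"` requires a PyTorch build with MPI support and `"ccl"` requires Intel's oneCCL bindings for PyTorch to be installed and imported first, so the demo uses `"gloo"`, which also runs on CPU:

import os
import torch.distributed as dist

# Rendezvous normally comes from the launcher; a real CPU run might look like:
#   mpirun -n 2 python train.py --no_cuda --xpu_backend ccl ...
# Set it here only so the sketch is self-contained.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")

xpu_backend = "gloo"  # demo stand-in; the Trainer path accepts only "mpi" or "ccl"
dist.init_process_group(backend=xpu_backend, rank=0, world_size=1)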