Add cpu distributed fine-tuning support for transformers Trainer API (#13574)
* update trainer with cpu distributed fine-tuning support.
  Signed-off-by: Ding, Ke <ke.ding@intel.com>
* Style.
* refinement on cpu dist training check.
  Signed-off-by: Ding, Ke <ke.ding@intel.com>
* style.
  Signed-off-by: Ding, Ke <ke.ding@intel.com>
* Test over private field not public one.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
Co-authored-by: Funtowicz Morgan <mfuntowicz@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 6a3a197fcd
commit 8632a60d33
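Before the diff itself, a rough sketch of the user-facing knobs this change introduces (the output directory is made up; with local_rank left at its default of -1 this is a single-process CPU run, so no mpi/ccl backend is actually initialized, whereas a launcher such as mpirun would supply a real rank):

from transformers import TrainingArguments

# Sketch only: /tmp/cpu_run is an invented path.
args = TrainingArguments(
    output_dir="/tmp/cpu_run",
    no_cuda=True,         # keep everything on CPU
    xpu_backend="ccl",    # used only when local_rank != -1; must be "mpi" or "ccl"
)
print(args.device)        # cpu
print(args.local_rank)    # -1 (not distributed in this sketch)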
@@ -1001,8 +1001,8 @@ class Trainer:
                 find_unused_parameters = True
             model = nn.parallel.DistributedDataParallel(
                 model,
-                device_ids=[self.args.local_rank],
-                output_device=self.args.local_rank,
+                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
+                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                 find_unused_parameters=find_unused_parameters,
             )
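For illustration, a minimal single-process sketch of why device_ids/output_device must be None when the module stays on CPU; the gloo backend and the fixed port are assumptions made only so the snippet runs without an MPI/CCL build:

import torch
import torch.distributed as dist
from torch import nn

# Single-process process group; gloo stands in for mpi/ccl purely for illustration.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = nn.Linear(4, 2)  # a CPU module, never moved to a GPU
# On CPU (_n_gpu == 0 in the hunk above), DDP must be built without device_ids/output_device.
ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=None, output_device=None)
print(ddp_model(torch.randn(3, 4)).shape)  # torch.Size([3, 2])
dist.destroy_process_group()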
@@ -2005,7 +2005,9 @@ class Trainer:
     def store_flos(self):
         # Storing the number of floating-point operations that went into the model
         if self.args.local_rank != -1:
-            self.state.total_flos += distributed_broadcast_scalars([self.current_flos]).sum().item()
+            self.state.total_flos += (
+                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
+            )
             self.current_flos = 0
         else:
             self.state.total_flos += self.current_flos
@@ -175,10 +175,12 @@ def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int]


 def distributed_broadcast_scalars(
-    scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
+    scalars: List[Union[int, float]],
+    num_total_examples: Optional[int] = None,
+    device: Optional[torch.device] = torch.device("cuda"),
 ) -> torch.Tensor:
     try:
-        tensorized_scalar = torch.tensor(scalars).cuda()
+        tensorized_scalar = torch.tensor(scalars).to(device)
         output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
         dist.all_gather(output_tensors, tensorized_scalar)
         concat = torch.cat(output_tensors, dim=0)
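A minimal sketch of the gather pattern this function implements, kept entirely on CPU; this is what store_flos() now accumulates into total_flos. Again, gloo and the single-process world are assumptions made only so the snippet runs without an MPI/CCL build:

import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)

device = torch.device("cpu")
tensorized_scalar = torch.tensor([3.0]).to(device)      # .to(device) rather than .cuda()
output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, tensorized_scalar)
total = torch.cat(output_tensors, dim=0).sum().item()
print(total)  # 3.0 with a single process
dist.destroy_process_group()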
@@ -220,6 +220,8 @@ class TrainingArguments:
             can harm metric values.
         local_rank (:obj:`int`, `optional`, defaults to -1):
             Rank of the process during distributed training.
+        xpu_backend (:obj:`str`, `optional`):
+            The backend to use for xpu distributed training. Must be one of :obj:`"mpi"` or :obj:`"ccl"`.
         tpu_num_cores (:obj:`int`, `optional`):
             When training on TPU, the number of TPU cores (automatically passed by launcher script).
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -526,7 +528,10 @@ class TrainingArguments:
         metadata={"help": "Whether to use full 16-bit precision evaluation instead of 32-bit"},
     )
     local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
+    xpu_backend: str = field(
+        default=None,
+        metadata={"help": "The backend to be used for distributed training on Intel XPU.", "choices": ["mpi", "ccl"]},
+    )
     tpu_num_cores: Optional[int] = field(
         default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"}
     )
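Because TrainingArguments is exposed through HfArgumentParser, the new field also surfaces as a --xpu_backend command-line option whose values are restricted by the "choices" metadata. A small sketch of that path (the output directory is made up):

from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
# Anything other than mpi/ccl for --xpu_backend is rejected at parse time.
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "/tmp/out", "--xpu_backend", "ccl", "--no_cuda"]
)
print(training_args.xpu_backend)  # ccl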
@@ -894,6 +899,14 @@ class TrainingArguments:
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
+            if self.local_rank != -1:
+                # Initializes distributed backend for cpu
+                if self.xpu_backend not in ("mpi", "ccl"):
+                    raise ValueError(
+                        "CPU distributed training backend is not properly set. "
+                        "Please set '--xpu_backend' to either 'mpi' or 'ccl'."
+                    )
+                torch.distributed.init_process_group(backend=self.xpu_backend)
         elif is_torch_tpu_available():
             device = xm.xla_device()
             self._n_gpu = 0
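The same validation logic, pulled out into a hypothetical standalone helper to make the control flow explicit (the function name is invented; actually calling it with local_rank != -1 requires a PyTorch build or plugin that provides the mpi or ccl backend, with rank and world size supplied by the launcher environment):

import torch

def init_cpu_distributed(xpu_backend: str, local_rank: int) -> None:
    # Hypothetical helper mirroring the check added to _setup_devices above.
    if local_rank == -1:
        return  # single-process CPU run, nothing to initialize
    if xpu_backend not in ("mpi", "ccl"):
        raise ValueError(
            "CPU distributed training backend is not properly set. "
            "Please set '--xpu_backend' to either 'mpi' or 'ccl'."
        )
    torch.distributed.init_process_group(backend=xpu_backend)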