mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Fix Trainer with a parallel model (#9578)
* Fix Trainer with a parallel model * More clean up
This commit is contained in:
parent
126fd281bc
commit
5e1bea4f16
@ -16,7 +16,7 @@ import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
|
||||
from .trainer_utils import EvaluationStrategy, SchedulerType
|
||||
@ -426,7 +426,6 @@ class TrainingArguments:
|
||||
|
||||
if is_torch_available() and self.device.type != "cuda" and self.fp16:
|
||||
raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
|
||||
self._n_gpu = torch.cuda.device_count()
|
||||
|
||||
def __repr__(self):
|
||||
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
||||
@ -467,14 +466,14 @@ class TrainingArguments:
|
||||
|
||||
@cached_property
|
||||
@torch_required
|
||||
def _setup_devices(self) -> Tuple["torch.device", int]:
|
||||
def _setup_devices(self) -> "torch.device":
|
||||
logger.info("PyTorch: setting up devices")
|
||||
if self.no_cuda:
|
||||
device = torch.device("cpu")
|
||||
n_gpu = 0
|
||||
self._n_gpu = 0
|
||||
elif is_torch_tpu_available():
|
||||
device = xm.xla_device()
|
||||
n_gpu = 0
|
||||
self._n_gpu = 0
|
||||
elif self.local_rank == -1:
|
||||
# if n_gpu is > 1 we'll use nn.DataParallel.
|
||||
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
|
||||
@ -485,9 +484,7 @@ class TrainingArguments:
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
# Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
|
||||
# the default value.
|
||||
if self._n_gpu == -1:
|
||||
self._n_gpu = torch.cuda.device_count()
|
||||
n_gpu = self._n_gpu
|
||||
else:
|
||||
# Here, we'll use torch.distributed.
|
||||
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
|
||||
@ -507,12 +504,12 @@ class TrainingArguments:
|
||||
else:
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
device = torch.device("cuda", self.local_rank)
|
||||
n_gpu = 1
|
||||
self._n_gpu = 1
|
||||
|
||||
if device.type == "cuda":
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
return device, n_gpu
|
||||
return device
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
@ -520,7 +517,7 @@ class TrainingArguments:
|
||||
"""
|
||||
The device used by this process.
|
||||
"""
|
||||
return self._setup_devices[0]
|
||||
return self._setup_devices
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
@ -532,7 +529,9 @@ class TrainingArguments:
|
||||
This will only be greater than one when you have multiple GPUs available but are not using distributed
|
||||
training. For distributed training, it will always be 1.
|
||||
"""
|
||||
return self._setup_devices[1]
|
||||
# Make sure `self._n_gpu` is properly setup.
|
||||
_ = self._setup_devices
|
||||
return self._n_gpu
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
|
@ -381,9 +381,11 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
# Make the Trainer believe it's a parallelized model
|
||||
model.is_parallelizable = True
|
||||
model.model_parallel = True
|
||||
trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
|
||||
args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
|
||||
trainer = Trainer(model, args, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
|
||||
# Check the Trainer was fooled
|
||||
self.assertTrue(trainer.is_model_parallel)
|
||||
self.assertEqual(trainer.args.n_gpu, 1)
|
||||
|
||||
# The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
|
||||
self.assertEqual(trainer.get_train_dataloader().batch_size, 16)
|
||||
|
Loading…
Reference in New Issue
Block a user