Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 18:22:34 +06:00
Fix Trainer with a parallel model (#9578)
* Fix Trainer with a parallel model
* More clean up
This commit is contained in:
parent 126fd281bc
commit 5e1bea4f16
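
The hunks below touch src/transformers/training_args.py (the TrainingArguments class) and tests/test_trainer.py (the TrainerIntegrationTest model-parallel test), per the class names in the hunk headers. The core of the fix: instead of baking the GPU count into the cached (device, n_gpu) tuple returned by _setup_devices, the setup now records it on self._n_gpu as a side effect and the n_gpu property reads that attribute back, so a caller such as the Trainer can overwrite _n_gpu afterwards (for example to force a single device for a model-parallel model). A minimal, self-contained sketch of that pattern follows; it is not the real class, and it uses functools.cached_property where the actual code imports its cached_property backport from .file_utils:

# Minimal sketch of the new pattern -- not the real TrainingArguments.
from dataclasses import dataclass, field
from functools import cached_property  # the real class uses the backport in .file_utils

import torch


@dataclass
class ArgsSketch:
    no_cuda: bool = False
    # -1 means "not computed yet"; after this commit __post_init__ no longer fills it eagerly.
    _n_gpu: int = field(default=-1, repr=False)

    @cached_property
    def _setup_devices(self) -> "torch.device":
        # Runs once per instance; records the GPU count as a side effect instead of returning it.
        if self.no_cuda or not torch.cuda.is_available():
            self._n_gpu = 0
            return torch.device("cpu")
        self._n_gpu = torch.cuda.device_count()
        return torch.device("cuda:0")

    @property
    def device(self) -> "torch.device":
        return self._setup_devices

    @property
    def n_gpu(self) -> int:
        _ = self._setup_devices  # make sure _n_gpu has been filled in
        return self._n_gpu


args = ArgsSketch()
print(args.device, args.n_gpu)  # e.g. "cuda:0 2" on a 2-GPU machine, "cpu 0" without CUDA
args._n_gpu = 1                 # what the Trainer does for a model-parallel model
print(args.n_gpu)               # 1 -- the cached setup does not run again and clobber it

Because cached_property runs _setup_devices only once per instance, a later override of _n_gpu survives subsequent device or n_gpu accesses.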
@@ -16,7 +16,7 @@ import json
 import os
 from dataclasses import asdict, dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
 from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
 from .trainer_utils import EvaluationStrategy, SchedulerType
@@ -426,7 +426,6 @@ class TrainingArguments:
 
         if is_torch_available() and self.device.type != "cuda" and self.fp16:
             raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
-        self._n_gpu = torch.cuda.device_count()
 
     def __repr__(self):
         # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
@@ -467,14 +466,14 @@ class TrainingArguments:
 
     @cached_property
     @torch_required
-    def _setup_devices(self) -> Tuple["torch.device", int]:
+    def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
         if self.no_cuda:
             device = torch.device("cpu")
-            n_gpu = 0
+            self._n_gpu = 0
         elif is_torch_tpu_available():
             device = xm.xla_device()
-            n_gpu = 0
+            self._n_gpu = 0
         elif self.local_rank == -1:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
@@ -485,9 +484,7 @@ class TrainingArguments:
             device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
             # the default value.
-            if self._n_gpu == -1:
-                self._n_gpu = torch.cuda.device_count()
-            n_gpu = self._n_gpu
+            self._n_gpu = torch.cuda.device_count()
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
@@ -507,12 +504,12 @@ class TrainingArguments:
             else:
                 torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
-            n_gpu = 1
+            self._n_gpu = 1
 
         if device.type == "cuda":
             torch.cuda.set_device(device)
 
-        return device, n_gpu
+        return device
 
     @property
     @torch_required
@@ -520,7 +517,7 @@ class TrainingArguments:
         """
         The device used by this process.
         """
-        return self._setup_devices[0]
+        return self._setup_devices
 
     @property
     @torch_required
@@ -532,7 +529,9 @@ class TrainingArguments:
         This will only be greater than one when you have multiple GPUs available but are not using distributed
         training. For distributed training, it will always be 1.
         """
-        return self._setup_devices[1]
+        # Make sure `self._n_gpu` is properly setup.
+        _ = self._setup_devices
+        return self._n_gpu
 
     @property
     @torch_required
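
With the GPU count kept on self._n_gpu and read back through the n_gpu property above, a caller can pin it to 1 for a model that already spreads itself across devices. A hedged sketch of the Trainer-side idea this enables; the Trainer change itself is not among the hunks shown here, and the helper name below is hypothetical (the two model flags match the test that follows):

# Hypothetical helper, not code from this commit.
def pin_single_device_for_model_parallel(model, args) -> None:
    is_model_parallel = getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False)
    if is_model_parallel:
        # The model manages its own GPUs, so it must not be wrapped in nn.DataParallel
        # and the effective batch size must not be multiplied by n_gpu.
        args._n_gpu = 1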
@@ -381,9 +381,11 @@ class TrainerIntegrationTest(unittest.TestCase):
         # Make the Trainer believe it's a parallelized model
         model.is_parallelizable = True
         model.model_parallel = True
-        trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
+        args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
+        trainer = Trainer(model, args, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
         # Check the Trainer was fooled
         self.assertTrue(trainer.is_model_parallel)
+        self.assertEqual(trainer.args.n_gpu, 1)
 
         # The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
         self.assertEqual(trainer.get_train_dataloader().batch_size, 16)
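
The test above exercises the user-visible contract. Roughly the same scenario from user code, as a sketch: ToyDataset and ToyModel are hypothetical stand-ins for the test suite's own RegressionDataset and regression-model helpers, and the asserts mirror the assertions in the test:

import torch
from torch import nn
from torch.utils.data import Dataset

from transformers import Trainer, TrainingArguments


class ToyDataset(Dataset):
    # Stand-in dataset; the real test uses its RegressionDataset helper.
    def __len__(self):
        return 64

    def __getitem__(self, i):
        return {"input_x": torch.tensor([float(i)]), "labels": torch.tensor([2.0 * i])}


class ToyModel(nn.Module):
    # Stand-in model; the Trainer only inspects the two flags set below.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, input_x, labels=None):
        preds = self.linear(input_x)
        if labels is None:
            return (preds,)
        return (nn.functional.mse_loss(preds, labels), preds)


model = ToyModel()
# Make the Trainer treat the model as already parallelized across GPUs.
model.is_parallelizable = True
model.model_parallel = True

args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
trainer = Trainer(model, args, train_dataset=ToyDataset(), eval_dataset=ToyDataset())

assert trainer.is_model_parallel
assert trainer.args.n_gpu == 1                          # forced to 1, so no nn.DataParallel
assert trainer.get_train_dataloader().batch_size == 16  # per-device, not 16 * n_gpu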