Fix Trainer with a parallel model (#9578)

* Fix Trainer with a parallel model

* More clean up
Sylvain Gugger 2021-01-14 03:23:41 -05:00 committed by GitHub
parent 126fd281bc
commit 5e1bea4f16
2 changed files with 14 additions and 13 deletions

src/transformers/training_args.py
@@ -16,7 +16,7 @@ import json
 import os
 from dataclasses import asdict, dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
 from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
 from .trainer_utils import EvaluationStrategy, SchedulerType
@@ -426,7 +426,6 @@ class TrainingArguments:
         if is_torch_available() and self.device.type != "cuda" and self.fp16:
             raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
-        self._n_gpu = torch.cuda.device_count()
 
     def __repr__(self):
         # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
@@ -467,14 +466,14 @@
     @cached_property
     @torch_required
-    def _setup_devices(self) -> Tuple["torch.device", int]:
+    def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
         if self.no_cuda:
             device = torch.device("cpu")
-            n_gpu = 0
+            self._n_gpu = 0
         elif is_torch_tpu_available():
             device = xm.xla_device()
-            n_gpu = 0
+            self._n_gpu = 0
         elif self.local_rank == -1:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
@@ -485,9 +484,7 @@
             device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
             # the default value.
-            if self._n_gpu == -1:
-                self._n_gpu = torch.cuda.device_count()
-            n_gpu = self._n_gpu
+            self._n_gpu = torch.cuda.device_count()
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
@ -507,12 +504,12 @@ class TrainingArguments:
else: else:
torch.distributed.init_process_group(backend="nccl") torch.distributed.init_process_group(backend="nccl")
device = torch.device("cuda", self.local_rank) device = torch.device("cuda", self.local_rank)
n_gpu = 1 self._n_gpu = 1
if device.type == "cuda": if device.type == "cuda":
torch.cuda.set_device(device) torch.cuda.set_device(device)
return device, n_gpu return device
@property @property
@torch_required @torch_required
@@ -520,7 +517,7 @@
         """
         The device used by this process.
         """
-        return self._setup_devices[0]
+        return self._setup_devices
 
     @property
     @torch_required
@@ -532,7 +529,9 @@
         This will only be greater than one when you have multiple GPUs available but are not using distributed
         training. For distributed training, it will always be 1.
         """
-        return self._setup_devices[1]
+        # Make sure `self._n_gpu` is properly setup.
+        _ = self._setup_devices
+        return self._n_gpu
 
     @property
     @torch_required
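For context on the training_args.py change: `n_gpu` can no longer be unpacked from the return value of `_setup_devices`, so the property now touches the cached `_setup_devices` once (which sets `self._n_gpu` as a side effect) and then reads that attribute, leaving room for a caller such as the Trainer to override `_n_gpu` afterwards (see the test below). A minimal sketch of this lazy-setup pattern, using `functools.cached_property` and illustrative names rather than the real `TrainingArguments` internals:

from functools import cached_property

class DeviceConfig:
    """Illustrative stand-in for TrainingArguments' device handling, not the real class."""

    def __init__(self):
        self._n_gpu = -1  # "not determined yet" sentinel

    @cached_property
    def _setup_devices(self) -> str:
        # Runs once; later accesses reuse the cached return value.
        # Stand-in for the torch.device / torch.cuda.device_count() logic above.
        self._n_gpu = 0
        return "cpu"

    @property
    def device(self) -> str:
        return self._setup_devices

    @property
    def n_gpu(self) -> int:
        # Make sure `self._n_gpu` is properly set up before reading it;
        # later overrides of `_n_gpu` are preserved because setup never reruns.
        _ = self._setup_devices
        return self._n_gpu

cfg = DeviceConfig()
assert cfg.device == "cpu" and cfg.n_gpu == 0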

tests/test_trainer.py
@@ -381,9 +381,11 @@ class TrainerIntegrationTest(unittest.TestCase):
         # Make the Trainer believe it's a parallelized model
         model.is_parallelizable = True
         model.model_parallel = True
-        trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
+        args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
+        trainer = Trainer(model, args, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
         # Check the Trainer was fooled
         self.assertTrue(trainer.is_model_parallel)
+        self.assertEqual(trainer.args.n_gpu, 1)
 
         # The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
         self.assertEqual(trainer.get_train_dataloader().batch_size, 16)
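The test pins down the user-visible behavior: when a model reports itself as model-parallel, the Trainer keeps `args.n_gpu` at 1 and therefore does not scale the dataloader batch size by the local GPU count. A self-contained sketch of the same scenario, assuming the transformers API at this commit; `ToyModel` and `ToyDataset` are illustrative placeholders for the model and `RegressionDataset` used in the test:

import torch
from torch import nn
from torch.utils.data import Dataset

from transformers import Trainer, TrainingArguments


class ToyDataset(Dataset):
    # Tiny regression dataset standing in for the test's RegressionDataset.
    def __len__(self):
        return 64

    def __getitem__(self, i):
        x = torch.rand(1)
        return {"input_x": x, "labels": 2.0 * x + 1.0}


class ToyModel(nn.Module):
    # A model that claims to already be parallelized across devices.
    is_parallelizable = True
    model_parallel = True

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, input_x, labels=None):
        preds = self.linear(input_x)
        if labels is None:
            return (preds,)
        return (nn.functional.mse_loss(preds, labels), preds)


args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
trainer = Trainer(ToyModel(), args, train_dataset=ToyDataset(), eval_dataset=ToyDataset())

assert trainer.is_model_parallel                        # the Trainer detected the parallelized model
assert trainer.args.n_gpu == 1                          # so multi-GPU data parallelism stays disabled
assert trainer.get_train_dataloader().batch_size == 16  # per-device size, not 16 * n_gpu

If the model did not flag itself as parallelized, `n_gpu` would instead reflect `torch.cuda.device_count()` and the train batch size would be scaled by it.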