diff --git a/tests/test_modeling_perceiver.py b/tests/test_modeling_perceiver.py
index d6fba44c581..4e6e271448c 100644
--- a/tests/test_modeling_perceiver.py
+++ b/tests/test_modeling_perceiver.py
@@ -28,7 +28,7 @@ from datasets import load_dataset
 from transformers import PerceiverConfig
 from transformers.file_utils import is_torch_available, is_vision_available
 from transformers.models.auto import get_values
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
 
 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -757,6 +757,31 @@ class PerceiverModelTest(ModelTesterMixin, unittest.TestCase):
             loss.backward()
 
+    @require_torch_multi_gpu
+    def test_multi_gpu_data_parallel_forward(self):
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
+
+            # some params shouldn't be scattered by nn.DataParallel
+            # so just remove them if they are present.
+            blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
+            for k in blacklist_non_batched_params:
+                inputs_dict.pop(k, None)
+
+            # move input tensors to cuda:0
+            for k, v in inputs_dict.items():
+                if torch.is_tensor(v):
+                    inputs_dict[k] = v.to(0)
+
+            model = model_class(config=config)
+            model.to(0)
+            model.eval()
+
+            # Wrap model in nn.DataParallel
+            model = nn.DataParallel(model)
+            with torch.no_grad():
+                _ = model(**self._prepare_for_class(inputs_dict, model_class))
+
     @unittest.skip(reason="Perceiver models don't have a typical head like is the case with BERT")
     def test_save_load_fast_init_from_base(self):
         pass
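
For reference, below is a minimal, self-contained sketch of the `nn.DataParallel` forward pattern the new test exercises, outside the test harness. The `ToyModel` class, its `inputs` keyword, and the tensor shapes are illustrative only and are not part of the PR; it assumes at least two visible CUDA devices, mirroring the `@require_torch_multi_gpu` guard.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    # Hypothetical stand-in for a Perceiver model class; only here to show the pattern.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 2)

    def forward(self, inputs):
        return self.linear(inputs)


if torch.cuda.device_count() >= 2:
    model = ToyModel().to(0)  # replicas are created from the copy on cuda:0
    model.eval()
    model = nn.DataParallel(model)  # batched tensor args are scattered across GPUs
    batch = torch.randn(4, 8).to(0)  # batch dim is split; non-tensor/scalar kwargs are not
    with torch.no_grad():
        _ = model(inputs=batch)
```

This mirrors the test body: non-batched keyword arguments such as `head_mask` are dropped because `nn.DataParallel` would otherwise try to scatter them along a batch dimension they do not have.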