mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add test (#14810)
This commit is contained in:
parent
c4a0fb5199
commit
cbf036f7ae
@ -28,7 +28,7 @@ from datasets import load_dataset
|
||||
from transformers import PerceiverConfig
|
||||
from transformers.file_utils import is_torch_available, is_vision_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
@ -757,6 +757,31 @@ class PerceiverModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
loss.backward()
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
|
||||
|
||||
# some params shouldn't be scattered by nn.DataParallel
|
||||
# so just remove them if they are present.
|
||||
blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
|
||||
for k in blacklist_non_batched_params:
|
||||
inputs_dict.pop(k, None)
|
||||
|
||||
# move input tensors to cuda:O
|
||||
for k, v in inputs_dict.items():
|
||||
if torch.is_tensor(v):
|
||||
inputs_dict[k] = v.to(0)
|
||||
|
||||
model = model_class(config=config)
|
||||
model.to(0)
|
||||
model.eval()
|
||||
|
||||
# Wrap model in nn.DataParallel
|
||||
model = nn.DataParallel(model)
|
||||
with torch.no_grad():
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
@unittest.skip(reason="Perceiver models don't have a typical head like is the case with BERT")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user