From 34307bb358b568b271362f40cb82ad4ca1ef0e8f Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Sat, 6 Nov 2021 15:08:58 +0100 Subject: [PATCH] Fix tests (#14289) --- tests/test_modeling_beit.py | 8 ++++++-- tests/test_modeling_segformer.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index f0a89031416..c38f956cd8b 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -232,7 +232,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase): # this can then be incorporated into _prepare_for_class in test_modeling_common.py elif model_class.__name__ == "BeitForSemanticSegmentation": batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape - inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long() + inputs_dict["labels"] = torch.zeros( + [self.model_tester.batch_size, height, width], device=torch_device + ).long() model = model_class(config) model.to(torch_device) model.train() @@ -259,7 +261,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase): # this can then be incorporated into _prepare_for_class in test_modeling_common.py elif model_class.__name__ == "BeitForSemanticSegmentation": batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape - inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long() + inputs_dict["labels"] = torch.zeros( + [self.model_tester.batch_size, height, width], device=torch_device + ).long() model = model_class(config) model.to(torch_device) model.train() diff --git a/tests/test_modeling_segformer.py b/tests/test_modeling_segformer.py index 6934f9b1871..2e84b1f41d1 100644 --- a/tests/test_modeling_segformer.py +++ b/tests/test_modeling_segformer.py @@ -318,7 +318,9 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase): # this can then be incorporated into _prepare_for_class in test_modeling_common.py if model_class.__name__ == "SegformerForSemanticSegmentation": batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape - inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long() + inputs_dict["labels"] = torch.zeros( + [self.model_tester.batch_size, height, width], device=torch_device + ).long() model = model_class(config) model.to(torch_device) model.train()