diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index dc219bbd61a..f56f8f6d3ec 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -406,7 +406,7 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase): def test_model(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: - model = model_class(config) + model = model_class(config).to(torch_device) output = model(**input_dict) self.assertEqual(