diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index 714580e13e4..d3d8afef07b 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -437,6 +437,9 @@ class SwinBlock(nn.Module):
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
 
+        if self.attn_mask is not None:
+            self.attn_mask = self.attn_mask.to(hidden_states_windows.device)
+
         self_attention_outputs = self.attention(
             hidden_states_windows,
             self.attn_mask,
diff --git a/tests/test_modeling_vilt.py b/tests/test_modeling_vilt.py
index e9eca63adca..f1bac75a6cd 100644
--- a/tests/test_modeling_vilt.py
+++ b/tests/test_modeling_vilt.py
@@ -595,8 +595,8 @@ class ViltModelIntegrationTest(unittest.TestCase):
 
         # forward pass
         outputs = model(
-            input_ids=encoding_1.input_ids,
-            pixel_values=pixel_values,
+            input_ids=encoding_1.input_ids.to(torch_device),
+            pixel_values=pixel_values.to(torch_device),
         )
 
         # verify the logits
diff --git a/tests/test_modeling_vit_mae.py b/tests/test_modeling_vit_mae.py
index 02d3a73f1e9..9cf9fa2759d 100644
--- a/tests/test_modeling_vit_mae.py
+++ b/tests/test_modeling_vit_mae.py
@@ -327,9 +327,6 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
        for model_class in self.all_model_classes:
-
-            print("Model class:", model_class)
-
            model = model_class(config)
            model.to(torch_device)
            model.eval()
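
Context on the Swin hunk: a tensor built at module construction time and stored as a plain attribute (not a registered buffer) starts out on the CPU and is not moved by model.to(device), so pairing it with CUDA activations raises a device-mismatch error. The hunk fixes this by moving the cached attn_mask to the activations' device at forward time. Below is a minimal, self-contained sketch of that pattern; BlockWithCachedMask and its shapes are invented for illustration and are not the real SwinBlock internals.

import torch
from torch import nn


class BlockWithCachedMask(nn.Module):
    # Hypothetical module for illustration only.
    def __init__(self, seq_len: int):
        super().__init__()
        # Built once in __init__, so this plain attribute lives on the CPU;
        # module.to(device)/.cuda() will NOT move it, since it is neither a
        # Parameter nor a registered buffer.
        self.attn_mask = torch.zeros(seq_len, seq_len)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Same fix as the Swin hunk above: move the cached mask to wherever
        # the activations live (a no-op when the devices already match).
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(hidden_states.device)
        # Stand-in for the masked attention computation.
        return hidden_states + self.attn_mask


# Usage: runs on CPU as-is; on a CUDA machine the mask migrates on the first
# forward call instead of crashing with a device-mismatch error.
block = BlockWithCachedMask(seq_len=4)
inputs = torch.randn(2, 4, 4)
if torch.cuda.is_available():
    block, inputs = block.cuda(), inputs.cuda()
outputs = block(inputs)

The ViLT test hunk applies the same idea on the input side (tensors produced by a processor default to CPU and must be sent to torch_device before the forward pass), and the ViTMAE hunk only drops a leftover debug print.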