diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 0eefac892c9..236551d27cc 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -667,7 +667,8 @@ class BeitPooler(nn.Module): @add_start_docstrings( - "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).", BEIT_START_DOCSTRING + "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).", + BEIT_START_DOCSTRING, ) class BeitForMaskedImageModeling(BeitPreTrainedModel): def __init__(self, config): @@ -695,6 +696,9 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel): return_dict=None, ): r""" + bool_masked_pos (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ..., config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), @@ -722,6 +726,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel): outputs = self.beit( pixel_values, + bool_masked_pos=bool_masked_pos, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index d0cc5f6eb49..b4c670356f5 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -379,6 +379,31 @@ class BeitModelIntegrationTest(unittest.TestCase): BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224") if is_vision_available() else None ) + @slow + def test_inference_masked_image_modeling_head(self): + model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k").to(torch_device) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(torch_device) + + # prepare bool_masked_pos + bool_masked_pos = torch.ones((1, 196), dtype=torch.bool).to(torch_device) + + # forward pass + outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos) + logits = outputs.logits + + # verify the logits + expected_shape = torch.Size((1, 196, 8192)) + self.assertEqual(logits.shape, expected_shape) + + expected_slice = torch.tensor( + [[-3.2437, 0.5072, -13.9174], [-3.2456, 0.4948, -13.9401], [-3.2033, 0.5121, -13.8550]] + ).to(torch_device) + + self.assertTrue(torch.allclose(logits[bool_masked_pos][:3, :3], expected_slice, atol=1e-2)) + @slow def test_inference_image_classification_head_imagenet_1k(self): model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224").to(torch_device)