Fix BeitForMaskedImageModeling (#13275)

* First pass

* Fix docs of bool_masked_pos

* Add integration script

* Fix docstring

* Add integration test for BeitForMaskedImageModeling

* Remove file

* Fix docs
This commit is contained in:
NielsRogge 2021-08-27 15:09:57 +02:00 committed by GitHub
parent a3f96f366a
commit cc27ac1a87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 1 deletions

View File

@ -667,7 +667,8 @@ class BeitPooler(nn.Module):
@add_start_docstrings(
"Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).", BEIT_START_DOCSTRING
"Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).",
BEIT_START_DOCSTRING,
)
class BeitForMaskedImageModeling(BeitPreTrainedModel):
def __init__(self, config):
@ -695,6 +696,9 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
return_dict=None,
):
r"""
bool_masked_pos (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
@ -722,6 +726,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
outputs = self.beit(
pixel_values,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,

View File

@ -379,6 +379,31 @@ class BeitModelIntegrationTest(unittest.TestCase):
BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224") if is_vision_available() else None
)
@slow
def test_inference_masked_image_modeling_head(self):
    """Check the masked-image-modeling head of a pretrained BEiT checkpoint.

    Runs a single forward pass with a ``(1, 196)`` boolean mask whose entries
    are all ``True`` and compares the shape of the logits, plus a 3x3 slice of
    the logits selected by that mask, against reference values.
    """
    checkpoint = "microsoft/beit-base-patch16-224-pt22k"
    model = BeitForMaskedImageModeling.from_pretrained(checkpoint).to(torch_device)

    # Convert the test image into a pixel tensor on the target device.
    encoding = self.default_feature_extractor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding.pixel_values.to(torch_device)

    # Every position of the mask is set, i.e. all entries are flagged as masked.
    bool_masked_pos = torch.ones((1, 196), dtype=torch.bool).to(torch_device)

    # Forward pass.
    logits = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos).logits

    # The logits must have the expected (1, 196, 8192) shape.
    self.assertEqual(logits.shape, torch.Size((1, 196, 8192)))

    # Compare the top-left 3x3 corner of the mask-selected logits with the
    # reference values, within an absolute tolerance of 1e-2.
    expected_slice = torch.tensor(
        [
            [-3.2437, 0.5072, -13.9174],
            [-3.2456, 0.4948, -13.9401],
            [-3.2033, 0.5121, -13.8550],
        ]
    ).to(torch_device)
    masked_logits = logits[bool_masked_pos]
    actual_slice = masked_logits[:3, :3]
    self.assertTrue(torch.allclose(actual_slice, expected_slice, atol=1e-2))
@slow
def test_inference_image_classification_head_imagenet_1k(self):
model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224").to(torch_device)