mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Fix BeitForMaskedImageModeling (#13275)
* First pass * Fix docs of bool_masked_pos * Add integration script * Fix docstring * Add integration test for BeitForMaskedImageModeling * Remove file * Fix docs
This commit is contained in:
parent
a3f96f366a
commit
cc27ac1a87
@ -667,7 +667,8 @@ class BeitPooler(nn.Module):
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).", BEIT_START_DOCSTRING
|
||||
"Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).",
|
||||
BEIT_START_DOCSTRING,
|
||||
)
|
||||
class BeitForMaskedImageModeling(BeitPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
@ -695,6 +696,9 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
bool_masked_pos (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, num_patches)`):
|
||||
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
||||
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ...,
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
@ -722,6 +726,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
|
||||
|
||||
outputs = self.beit(
|
||||
pixel_values,
|
||||
bool_masked_pos=bool_masked_pos,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
|
@ -379,6 +379,31 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224") if is_vision_available() else None
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_inference_masked_image_modeling_head(self):
|
||||
model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k").to(torch_device)
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(torch_device)
|
||||
|
||||
# prepare bool_masked_pos
|
||||
bool_masked_pos = torch.ones((1, 196), dtype=torch.bool).to(torch_device)
|
||||
|
||||
# forward pass
|
||||
outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 196, 8192))
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[-3.2437, 0.5072, -13.9174], [-3.2456, 0.4948, -13.9401], [-3.2033, 0.5121, -13.8550]]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[bool_masked_pos][:3, :3], expected_slice, atol=1e-2))
|
||||
|
||||
@slow
|
||||
def test_inference_image_classification_head_imagenet_1k(self):
|
||||
model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224").to(torch_device)
|
||||
|
Loading…
Reference in New Issue
Block a user