Fix ViT-MAE decoder interpolate (#33330)

* Fix ViT-MAE decoder interpolate

* Add unit test for `interpolate_pos_encoding` w/ custom sizes

* [run_slow] vit_mae
This commit is contained in:
Joshua Lochner 2024-09-30 18:47:13 +02:00 committed by GitHub
parent 1dba608df9
commit 18c5b216f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 16 deletions

View File

@ -841,21 +841,20 @@ class ViTMAEDecoder(nn.Module):
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
"""
This method is a modified version of the interpolation function for ViT-mae model at the deocder, that
This method is a modified version of the interpolation function for ViT-mae model at the decoder, that
allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
resolution images.
Source:
Adapted from:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
# -1 removes the class dimension since we later append it without interpolation
embeddings_positions = embeddings.shape[1] - 1
num_positions = self.decoder_pos_embed.shape[1] - 1
# Separation of class token and patch tokens
class_pos_embed = self.decoder_pos_embed[:, 0, :]
patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
class_pos_embed = self.decoder_pos_embed[:, :1]
patch_pos_embed = self.decoder_pos_embed[:, 1:]
# To retain the final 3d tensor with the required dimensions
dim = self.decoder_pos_embed.shape[-1]
@ -867,10 +866,10 @@ class ViTMAEDecoder(nn.Module):
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
# Interpolating the decoder position embeddings shape wrt embeddings shape i.e (x).
# 1 keeps the other dimension constant
# we keep the second last dimension constant
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(1, embeddings_positions / num_positions),
size=(patch_pos_embed.shape[-2], embeddings_positions),
mode="bicubic",
align_corners=False,
)
@ -878,7 +877,7 @@ class ViTMAEDecoder(nn.Module):
# Converting back to the original shape
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
# Adding the class token back
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
def initialize_weights(self, num_patches):
# initialize (and freeze) position embeddings by sin-cos embedding

View File

@ -298,12 +298,16 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
def default_image_processor(self):
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
@cached_property
def default_model(self):
return ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
@slow
def test_inference_for_pretraining(self):
# make random mask reproducible across the PT and TF model
np.random.seed(2)
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
model = self.default_model
image_processor = self.default_image_processor
image = prepare_img()
@ -313,11 +317,11 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
# (this way we can ensure that the PT and TF models operate on the same inputs)
vit_mae_config = ViTMAEConfig()
num_patches = int((vit_mae_config.image_size // vit_mae_config.patch_size) ** 2)
noise = np.random.uniform(size=(1, num_patches))
noise = torch.from_numpy(np.random.uniform(size=(1, num_patches))).to(device=torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs, noise=torch.from_numpy(noise).to(device=torch_device))
outputs = model(**inputs, noise=noise)
# verify the logits
expected_shape = torch.Size((1, 196, 768))
@ -339,7 +343,7 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
# make random mask reproducible across the PT and TF model
np.random.seed(2)
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
model = self.default_model
image_processor = self.default_image_processor
image = prepare_img()
@ -349,14 +353,38 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
# (this way we can ensure that the PT and TF models operate on the same inputs)
vit_mae_config = ViTMAEConfig()
num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size)
noise = np.random.uniform(size=(1, num_patches))
noise = torch.from_numpy(np.random.uniform(size=(1, num_patches))).to(device=torch_device)
# forward pass
with torch.no_grad():
outputs = model(
**inputs, noise=torch.from_numpy(noise).to(device=torch_device), interpolate_pos_encoding=True
)
outputs = model(**inputs, noise=noise, interpolate_pos_encoding=True)
# verify the logits
expected_shape = torch.Size((1, 1200, 768))
self.assertEqual(outputs.logits.shape, expected_shape)
@slow
def test_inference_interpolate_pos_encoding_custom_sizes(self):
# Ensure custom sizes are correctly handled when interpolating the position embeddings
# make random mask reproducible across the PT and TF model
np.random.seed(2)
model = self.default_model
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt", size={"height": 256, "width": 256}).to(
torch_device
)
# forward pass
with torch.no_grad():
outputs = model(
**inputs,
interpolate_pos_encoding=True,
)
# verify the logits
expected_shape = torch.Size((1, 256, 768))
self.assertEqual(outputs.logits.shape, expected_shape)