Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 11:11:05 +06:00
Fix ViT-MAE decoder interpolate (#33330)
* Fix ViT-MAE decoder interpolate
* Add unit test for `interpolate_pos_encoding` w/ custom sizes
* [run_slow] vit_mae
parent 1dba608df9
commit 18c5b216f1
@@ -841,21 +841,20 @@ class ViTMAEDecoder(nn.Module):

     def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
         """
-        This method is a modified version of the interpolation function for ViT-mae model at the deocder, that
+        This method is a modified version of the interpolation function for ViT-mae model at the decoder, that
         allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
         resolution images.

-        Source:
+        Adapted from:
         https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
         """

         # -1 removes the class dimension since we later append it without interpolation
         embeddings_positions = embeddings.shape[1] - 1
-        num_positions = self.decoder_pos_embed.shape[1] - 1

         # Separation of class token and patch tokens
-        class_pos_embed = self.decoder_pos_embed[:, 0, :]
-        patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
+        class_pos_embed = self.decoder_pos_embed[:, :1]
+        patch_pos_embed = self.decoder_pos_embed[:, 1:]

         # To retain the final 3d tensor with the required dimensions
         dim = self.decoder_pos_embed.shape[-1]
@@ -867,10 +866,10 @@ class ViTMAEDecoder(nn.Module):
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

         # Interpolating the decoder position embeddings shape wrt embeddings shape i.e (x).
-        # 1 keeps the other dimension constant
+        # we keep the second last dimension constant
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(1, embeddings_positions / num_positions),
+            size=(patch_pos_embed.shape[-2], embeddings_positions),
             mode="bicubic",
             align_corners=False,
         )
@@ -878,7 +877,7 @@ class ViTMAEDecoder(nn.Module):
         # Converting back to the original shape
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         # Adding the class token back
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

     def initialize_weights(self, num_patches):
         # initialize (and freeze) position embeddings by sin-cos embedding
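In short, the decoder now keeps the class position embedding three-dimensional (so the final torch.cat no longer needs unsqueeze) and asks nn.functional.interpolate for an exact output size instead of a fractional scale_factor. PyTorch derives the output width from a scale factor by flooring input_width * scale_factor, which can land one position short for some image sizes; passing size= directly guarantees the resized table matches the number of patch tokens. A minimal standalone sketch of the same idea, with illustrative numbers (a 512-dim table trained for 196 patches, stretched to 300 positions) rather than the model's real attributes:

import torch
import torch.nn as nn

# Illustrative shapes: the pre-trained table covers 196 patches (14 x 14),
# while the input at hand produces 300 patch tokens.
dim = 512
decoder_pos_embed = torch.randn(1, 1 + 196, dim)  # class position embedding + 196 patch positions
embeddings_positions = 300

# Keep the class position embedding aside (still 3D, so it can be concatenated back directly)
class_pos_embed = decoder_pos_embed[:, :1]
patch_pos_embed = decoder_pos_embed[:, 1:]

# Bring the sequence axis last so 2D bicubic resampling can stretch it
patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2)  # (1, dim, 1, 196)

# Requesting an exact output size sidesteps the rounding a fractional scale_factor can introduce
patch_pos_embed = nn.functional.interpolate(
    patch_pos_embed,
    size=(patch_pos_embed.shape[-2], embeddings_positions),
    mode="bicubic",
    align_corners=False,
)

# Back to (batch, sequence, dim) and re-attach the class position embedding
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
interpolated = torch.cat((class_pos_embed, patch_pos_embed), dim=1)
print(interpolated.shape)  # torch.Size([1, 301, 512])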
@@ -298,12 +298,16 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
     def default_image_processor(self):
         return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")

+    @cached_property
+    def default_model(self):
+        return ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
+
     @slow
     def test_inference_for_pretraining(self):
         # make random mask reproducible across the PT and TF model
         np.random.seed(2)

-        model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
+        model = self.default_model

         image_processor = self.default_image_processor
         image = prepare_img()
@@ -313,11 +317,11 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
         # (this way we can ensure that the PT and TF models operate on the same inputs)
         vit_mae_config = ViTMAEConfig()
         num_patches = int((vit_mae_config.image_size // vit_mae_config.patch_size) ** 2)
-        noise = np.random.uniform(size=(1, num_patches))
+        noise = torch.from_numpy(np.random.uniform(size=(1, num_patches))).to(device=torch_device)

         # forward pass
         with torch.no_grad():
-            outputs = model(**inputs, noise=torch.from_numpy(noise).to(device=torch_device))
+            outputs = model(**inputs, noise=noise)

         # verify the logits
         expected_shape = torch.Size((1, 196, 768))
@@ -339,7 +343,7 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
         # make random mask reproducible across the PT and TF model
         np.random.seed(2)

-        model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
+        model = self.default_model

         image_processor = self.default_image_processor
         image = prepare_img()
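Both integration tests build the random mask the same way; in isolation the pattern looks roughly like this (a small sketch using the default ViTMAEConfig values, where 224x224 images and 16x16 patches give 196 patches; seeding numpy rather than torch is what keeps the mask identical across the PyTorch and TensorFlow implementations):

import numpy as np
import torch
from transformers import ViTMAEConfig

config = ViTMAEConfig()
num_patches = int((config.image_size // config.patch_size) ** 2)  # (224 // 16) ** 2 = 196

# A seeded numpy draw, converted to a torch tensor, gives the PT and TF models the same mask
np.random.seed(2)
noise = torch.from_numpy(np.random.uniform(size=(1, num_patches)))

print(noise.shape)  # torch.Size([1, 196])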
@@ -349,14 +353,38 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
         # (this way we can ensure that the PT and TF models operate on the same inputs)
         vit_mae_config = ViTMAEConfig()
         num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size)
-        noise = np.random.uniform(size=(1, num_patches))
+        noise = torch.from_numpy(np.random.uniform(size=(1, num_patches))).to(device=torch_device)

         # forward pass
         with torch.no_grad():
-            outputs = model(
-                **inputs, noise=torch.from_numpy(noise).to(device=torch_device), interpolate_pos_encoding=True
-            )
+            outputs = model(**inputs, noise=noise, interpolate_pos_encoding=True)

         # verify the logits
         expected_shape = torch.Size((1, 1200, 768))
         self.assertEqual(outputs.logits.shape, expected_shape)
+
+    @slow
+    def test_inference_interpolate_pos_encoding_custom_sizes(self):
+        # Ensure custom sizes are correctly handled when interpolating the position embeddings
+
+        # make random mask reproducible across the PT and TF model
+        np.random.seed(2)
+
+        model = self.default_model
+        image_processor = self.default_image_processor
+
+        image = prepare_img()
+        inputs = image_processor(images=image, return_tensors="pt", size={"height": 256, "width": 256}).to(
+            torch_device
+        )
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(
+                **inputs,
+                interpolate_pos_encoding=True,
+            )
+
+        # verify the logits
+        expected_shape = torch.Size((1, 256, 768))
+        self.assertEqual(outputs.logits.shape, expected_shape)
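Outside the test suite, the behaviour the new test covers can be reproduced roughly as follows (a sketch: the COCO image URL is only an example input, and any RGB image works):

import requests
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTMAEForPreTraining

# Example image; any RGB image can be used here
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

# Resize to a non-default 256x256 input; the checkpoint was pre-trained on 224x224 images
inputs = image_processor(images=image, return_tensors="pt", size={"height": 256, "width": 256})

with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)

# 256x256 pixels in 16x16 patches -> 16 * 16 = 256 positions, each with 16 * 16 * 3 = 768 logits
print(outputs.logits.shape)  # torch.Size([1, 256, 768])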