IDEFICS: allow interpolation of vision's pos embeddings (#26029)

* add pos embed interpolation for vision encoder

* style

* update config with interpolate_pos_encoding arg

* fix imports formatting

* remove the "Copied from" comment on the vision embeddings

* add test for image embeddings interpolation

* add credit for interpolation code

* Update src/transformers/models/idefics/configuration_idefics.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/idefics/vision.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix condition checking that the number of image patches matches the shape of the pos embeddings

* use kwargs in the forward methods for interpolation

* fix tests

* have interpolate_pos_encoding default to False instead of None

* Update tests/models/idefics/test_modeling_idefics.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/idefics/test_modeling_idefics.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/idefics/test_modeling_idefics.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/idefics/configuration_idefics.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove leftover for loop that printed k, v

* add interpolate_pos_encoding arg in prepare_inputs_for_generation

* add test for interpolated generation

* fix edge case num_patches == num_positions and height == width

* add test for edge case

* fix pos_embed in interpolate

* allow interpolation in bf16 with upcasting

* Update src/transformers/models/idefics/vision.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/idefics/vision.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add multiple images tests for interpolation and generation

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Leo Tronchon, 2023-09-14 19:27:40 -04:00, committed by GitHub
parent 5469c18762
commit 869733ab62
3 changed files with 163 additions and 29 deletions
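For context, a minimal usage sketch of the feature this commit adds, written against the public API. It is not part of the diff: the checkpoint name, the dummy PIL image, and the image-processor size override are illustrative assumptions; only the interpolate_pos_encoding=True kwarg comes from this change.

# Usage sketch (assumptions: "HuggingFaceM4/idefics-9b" checkpoint, a dummy image,
# and overriding the processor's default 224px resize so a larger image actually
# reaches the vision tower).
import torch
from PIL import Image
from transformers import AutoProcessor, IdeficsForVisionText2Text

checkpoint = "HuggingFaceM4/idefics-9b"  # illustrative choice of checkpoint
processor = AutoProcessor.from_pretrained(checkpoint)
processor.image_processor.image_size = 448  # assumed override; default is 224

model = IdeficsForVisionText2Text.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

image = Image.new("RGB", (448, 448))  # stand-in for a real high-resolution image
prompts = [["User: What is in this image?", image, "\nAssistant:"]]
inputs = processor(prompts, return_tensors="pt")

# The new flag is accepted by forward() and forwarded by generate() through
# prepare_inputs_for_generation; with bf16 weights, the interpolation is upcast
# to fp32 internally (also part of this commit).
generated_ids = model.generate(**inputs, interpolate_pos_encoding=True, max_new_tokens=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])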

src/transformers/models/idefics/modeling_idefics.py

@@ -236,6 +236,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
image_attention_mask = kwargs.get("image_attention_mask", None)
interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False)
return {
"input_ids": input_ids,
@@ -248,6 +249,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
"image_encoder_embeddings": image_encoder_embeddings,
"perceiver_embeddings": perceiver_embeddings,
"image_attention_mask": image_attention_mask,
"interpolate_pos_encoding": interpolate_pos_encoding,
}
@@ -1157,6 +1159,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, IdeficsBaseModelOutputWithPast]:
device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1212,7 +1215,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
image_hidden_states = self.vision_model(
pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
).last_hidden_state
elif image_encoder_embeddings is not None:
batch_size, num_images, image_seq_len, image_hidden_size = image_encoder_embeddings.size()
@@ -1468,6 +1473,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
r"""
@@ -1516,6 +1522,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

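The diff above threads the new interpolate_pos_encoding argument from prepare_inputs_for_generation through IdeficsForVisionText2Text.forward and IdeficsModel.forward down to the vision model. A minimal forward-pass sketch with a randomly initialized tiny config (all sizes are arbitrary and only mirror the model tester further below; not a real checkpoint):

import torch
from transformers import IdeficsConfig, IdeficsModel

# Tiny, randomly initialized model; the sizes below are illustrative only.
config = IdeficsConfig(
    vocab_size=100,
    hidden_size=32,
    intermediate_size=37,
    num_hidden_layers=2,
    num_attention_heads=4,
    vision_config={
        "embed_dim": 32,
        "image_size": 30,  # resolution the position embeddings were built for
        "patch_size": 2,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "intermediate_size": 37,
    },
)
model = IdeficsModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 7))
attention_mask = torch.ones_like(input_ids)
pixel_values = torch.rand(1, 1, 3, 34, 34)  # 34x34 images, larger than 30x30
image_attention_mask = torch.ones(1, 7, 1, dtype=torch.long)

with torch.no_grad():
    out = model(
        input_ids,
        attention_mask=attention_mask,
        pixel_values=pixel_values,
        image_attention_mask=image_attention_mask,
        interpolate_pos_encoding=True,  # without this, the size check raises a ValueError
    )
print(out.last_hidden_state.shape)  # torch.Size([1, 7, 32])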
src/transformers/models/idefics/vision.py

@@ -15,6 +15,7 @@
""" PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -24,10 +25,7 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...utils import (
ModelOutput,
logging,
)
from ...utils import ModelOutput, logging
from .configuration_idefics import IdeficsVisionConfig
@@ -63,7 +61,7 @@ class IdeficsVisionModelOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics
# Adapted from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings
class IdeficsVisionEmbeddings(nn.Module):
def __init__(self, config: IdeficsVisionConfig):
super().__init__()
@@ -87,15 +85,79 @@ class IdeficsVisionEmbeddings(nn.Module):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
# Heavily inspired from https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/vit/modeling_vit.py#L82
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
pos_embed = self.position_embedding(self.position_ids)
num_positions = pos_embed.shape[1] - 1
if num_patches == num_positions and height == width:
return pos_embed
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
embed_dim = embeddings.shape[-1]
num_h_patches = height // self.config.patch_size
num_w_patches = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1
sqrt_num_positions = math.sqrt(num_positions)
patch_pos_embed = patch_pos_embed.reshape(1, int(sqrt_num_positions), int(sqrt_num_positions), embed_dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
fp32_upcasting = patch_pos_embed.dtype == torch.bfloat16
if fp32_upcasting:
logger.warning_once(
"Upcasting patch_pos_embed to fp32 for interpolation since `upsample_bicubic2d_out_frame` in nn.functional.interpolate"
"is not implemented for 'torch.bfloat16' dtype. This will result in a slight overhead"
)
patch_pos_embed = patch_pos_embed.to(torch.float)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(num_h_patches / sqrt_num_positions, num_w_patches / sqrt_num_positions),
mode="bicubic",
align_corners=False,
)
if fp32_upcasting:
patch_pos_embed = patch_pos_embed.to(torch.bfloat16)
if int(num_h_patches) != patch_pos_embed.shape[-2] or int(num_w_patches) != patch_pos_embed.shape[-1]:
raise ValueError(
f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the "
f"shape of position embedding ({patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]})"
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if not interpolate_pos_encoding:
if height != self.image_size or width != self.image_size:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size}*{self.image_size}). You should try to set `interpolate_pos_encoding=True`"
)
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
@@ -387,12 +449,13 @@ class IdeficsVisionTransformer(nn.Module):
self.encoder = IdeficsVisionEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
# copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
@@ -408,7 +471,7 @@ class IdeficsVisionTransformer(nn.Module):
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(

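To make the interpolation above concrete, here is a standalone sketch in plain torch (no transformers imports): the pre-trained patch position embeddings form a square grid that is bicubically resized to the patch grid of a larger input, while the CLS position embedding is carried over unchanged. All numbers are illustrative.

import torch
import torch.nn.functional as F

embed_dim, patch_size = 64, 14
pretrained_grid = 16                            # 16*16 patches -> 257 positions incl. CLS
pos_embed = torch.randn(1, pretrained_grid**2 + 1, embed_dim)

new_height = new_width = 20 * patch_size        # a larger input image (280x280)
num_h = new_height // patch_size + 0.1          # +0.1 avoids float rounding, as in the code above
num_w = new_width // patch_size + 0.1

class_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
patch_pos = patch_pos.reshape(1, pretrained_grid, pretrained_grid, embed_dim).permute(0, 3, 1, 2)
patch_pos = F.interpolate(
    patch_pos,
    scale_factor=(num_h / pretrained_grid, num_w / pretrained_grid),
    mode="bicubic",
    align_corners=False,
)
patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
new_pos_embed = torch.cat((class_pos, patch_pos), dim=1)
print(new_pos_embed.shape)                      # torch.Size([1, 401, 64]) = 20*20 patches + CLS

Note the early return in the method above when num_patches == num_positions and height == width: for square inputs at the pre-trained resolution, the original embeddings are returned untouched, which is the edge case fixed in this commit.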
tests/models/idefics/test_modeling_idefics.py

@@ -74,8 +74,6 @@ class IdeficsModelTester:
num_labels=3,
scope=None,
modality_type_vocab_size=2,
add_multiple_images=False,
num_images=-1,
vision_embed_dim=32,
vision_patch_size=2,
vision_image_size=30,
@@ -113,8 +111,6 @@ class IdeficsModelTester:
self.num_labels = num_labels
self.scope = scope
self.modality_type_vocab_size = modality_type_vocab_size
self.add_multiple_images = add_multiple_images
self.num_images = num_images
self.vision_embed_dim = vision_embed_dim
self.vision_patch_size = vision_patch_size
@@ -150,14 +146,17 @@ class IdeficsModelTester:
# this is equal to the seq length of the text tokens + number of image patches + 1 for the CLS token
self.expected_seq_len = self.seq_length + (self.image_size // self.patch_size) ** 2 + 1
def prepare_config_and_inputs(self):
self.seq_length = 42
def prepare_config_and_inputs(self, num_images=1, interpolate_pos_encoding=False, image_expansion=0):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
num_images = 2 if self.add_multiple_images else 1
pixel_values = floats_tensor(
[self.batch_size, num_images, self.num_channels, self.image_size, self.image_size]
[
self.batch_size,
num_images,
self.num_channels,
self.image_size + image_expansion,
self.image_size + image_expansion,
]
)
input_mask = None
if self.use_input_mask:
@@ -166,8 +165,7 @@ class IdeficsModelTester:
image_attention_mask = random_attention_mask([self.batch_size, self.seq_length, num_images])
config = self.get_config()
return (config, input_ids, input_mask, pixel_values, image_attention_mask)
return (config, input_ids, input_mask, pixel_values, image_attention_mask, interpolate_pos_encoding)
def get_config(self):
return IdeficsConfig(
@@ -188,7 +186,6 @@ class IdeficsModelTester:
initializer_range=self.initializer_range,
num_labels=self.num_labels,
modality_type_vocab_size=self.modality_type_vocab_size,
num_images=self.num_images,
vision_config=self.vision_config,
)
@@ -199,17 +196,43 @@ class IdeficsModelTester:
input_mask,
pixel_values,
image_attention_mask,
interpolate_pos_encoding,
):
model = IdeficsModel(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids, attention_mask=input_mask, pixel_values=pixel_values, image_attention_mask=image_attention_mask
input_ids,
attention_mask=input_mask,
pixel_values=pixel_values,
image_attention_mask=image_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, input_ids.shape[1], self.hidden_size)
)
def create_and_check_model_gen(
self,
config,
input_ids,
input_mask,
pixel_values,
image_attention_mask,
interpolate_pos_encoding,
):
model = IdeficsForVisionText2Text(config)
model.to(torch_device)
model.eval()
model.generate(
input_ids,
attention_mask=input_mask,
pixel_values=pixel_values,
image_attention_mask=image_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
max_length=self.seq_length + 2,
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -218,12 +241,14 @@ class IdeficsModelTester:
input_mask,
pixel_values,
image_attention_mask,
interpolate_pos_encoding,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": input_mask,
"pixel_values": pixel_values,
"image_attention_mask": image_attention_mask,
"interpolate_pos_encoding": interpolate_pos_encoding,
}
return config, inputs_dict
@@ -268,10 +293,50 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
def test_model_single_image(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=1, interpolate_pos_encoding=False, image_expansion=0
)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_multiple_images(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=2, interpolate_pos_encoding=False, image_expansion=0
)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_image_pos_embeddings_interpolation_single_image(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=1, interpolate_pos_encoding=True, image_expansion=2
)
self.model_tester.create_and_check_model(*config_and_inputs)
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=1, interpolate_pos_encoding=True, image_expansion=0
)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_image_pos_embeddings_interpolation_multiple_images(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=2, interpolate_pos_encoding=True, image_expansion=2
)
self.model_tester.create_and_check_model(*config_and_inputs)
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=2, interpolate_pos_encoding=True, image_expansion=0
)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_generate_with_image_pos_embeddings_interpolation_single_image(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=1, interpolate_pos_encoding=True, image_expansion=2
)
self.model_tester.create_and_check_model_gen(*config_and_inputs)
def test_generate_with_image_pos_embeddings_interpolation_multiple_images(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=2, interpolate_pos_encoding=True, image_expansion=2
)
self.model_tester.create_and_check_model_gen(*config_and_inputs)
def test_training(self):
if not self.model_tester.is_training:
return
@@ -289,8 +354,6 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
for k, v in inputs.items():
print(k, v.shape)
loss = model(**inputs).loss
loss.backward()
@@ -416,7 +479,8 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase):
def setUp(self):
self.model_tester = IdeficsModelTester(
self, modality_type_vocab_size=3, add_multiple_images=True, num_images=2
self,
modality_type_vocab_size=3,
)
self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
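The new interpolation tests can be targeted directly; a minimal invocation from the repository root using the standard pytest API (the file path and the -k substring come from the test names added above, pytest itself is assumed to be installed):

import pytest

# Runs only the tests whose names contain "pos_embeddings_interpolation".
pytest.main(["-v", "-k", "pos_embeddings_interpolation", "tests/models/idefics/test_modeling_idefics.py"])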