From 869733ab621495b938d0754176f7f1e360ae7ea9 Mon Sep 17 00:00:00 2001 From: Leo Tronchon Date: Thu, 14 Sep 2023 19:27:40 -0400 Subject: [PATCH] IDEFICS: allow interpolation of vision's pos embeddings (#26029) * add pos embed interpolation for vision encoder * style * update config with interpolate_pos_encoding arg * fix imports formatting * take off copied from on vision embeddings * add test for image embeddings interpolation * add credit for interpolation code * Update src/transformers/models/idefics/configuration_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/idefics/vision.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix condition to check nbr image patches match shape of pos embeddings * use kwargs in the forward methods for interpolation * fix tests * have interpolate_pos_encoding default to False instead of None * Update tests/models/idefics/test_modeling_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/idefics/test_modeling_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/idefics/test_modeling_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/idefics/configuration_idefics.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * take off for loop meant to print k,v * add interpolate_pos_encoding arg in prepare_inputs_for_generation * add test for interpolated generation * fix edge case num_patches == num_positions and height == width * add test for edge case * fix pos_embed in interpolate * allow interpolation in bf16 with upcasting * Update src/transformers/models/idefics/vision.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/idefics/vision.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add multiple images tests for interpolation and generation --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .../models/idefics/modeling_idefics.py | 9 +- src/transformers/models/idefics/vision.py | 83 +++++++++++++-- tests/models/idefics/test_modeling_idefics.py | 100 ++++++++++++++---- 3 files changed, 163 insertions(+), 29 deletions(-) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index db5cbb75fe5..b52b7d5f93b 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -236,6 +236,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs): image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None) perceiver_embeddings = kwargs.get("perceiver_embeddings", None) image_attention_mask = kwargs.get("image_attention_mask", None) + interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False) return { "input_ids": input_ids, @@ -248,6 +249,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs): "image_encoder_embeddings": image_encoder_embeddings, "perceiver_embeddings": perceiver_embeddings, "image_attention_mask": image_attention_mask, + "interpolate_pos_encoding": interpolate_pos_encoding, } @@ -1157,6 +1159,7 @@ class IdeficsModel(IdeficsPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, IdeficsBaseModelOutputWithPast]: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -1212,7 +1215,9 @@ class IdeficsModel(IdeficsPreTrainedModel): pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:]) # Get sequence from the vision encoder - image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state + image_hidden_states = self.vision_model( + pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ).last_hidden_state elif image_encoder_embeddings is not None: batch_size, num_images, image_seq_len, image_hidden_size = image_encoder_embeddings.size() @@ -1468,6 +1473,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, IdeficsCausalLMOutputWithPast]: r""" @@ -1516,6 +1522,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index 614de18c1d8..8b7a14c56a2 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -15,6 +15,7 @@ """ PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object""" +import math from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -24,10 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from ...utils import ( - ModelOutput, - logging, -) +from ...utils import ModelOutput, logging from .configuration_idefics import IdeficsVisionConfig @@ -63,7 +61,7 @@ class IdeficsVisionModelOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None -# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics +# Adapted from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings class IdeficsVisionEmbeddings(nn.Module): def __init__(self, config: IdeficsVisionConfig): super().__init__() @@ -87,15 +85,79 @@ class IdeficsVisionEmbeddings(nn.Module): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + # Heavily inspired from https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/vit/modeling_vit.py#L82 + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + pos_embed = self.position_embedding(self.position_ids) + num_positions = pos_embed.shape[1] - 1 + if num_patches == num_positions and height == width: + return pos_embed + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + + embed_dim = embeddings.shape[-1] + num_h_patches = height // self.config.patch_size + num_w_patches = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1 + sqrt_num_positions = math.sqrt(num_positions) + patch_pos_embed = patch_pos_embed.reshape(1, int(sqrt_num_positions), int(sqrt_num_positions), embed_dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + fp32_upcasting = patch_pos_embed.dtype == torch.bfloat16 + if fp32_upcasting: + logger.warning_once( + "Upcasting patch_pos_embed to fp32 for interpolation since `upsample_bicubic2d_out_frame` in nn.functional.interpolate" + "is not implemented for 'torch.bfloat16' dtype. This will result in a slight overhead" + ) + patch_pos_embed = patch_pos_embed.to(torch.float) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(num_h_patches / sqrt_num_positions, num_w_patches / sqrt_num_positions), + mode="bicubic", + align_corners=False, + ) + if fp32_upcasting: + patch_pos_embed = patch_pos_embed.to(torch.bfloat16) + if int(num_h_patches) != patch_pos_embed.shape[-2] or int(num_w_patches) != patch_pos_embed.shape[-1]: + raise ValueError( + f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the " + f"shape of position embedding ({patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]})" + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if not interpolate_pos_encoding: + if height != self.image_size or width != self.image_size: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size}*{self.image_size}). You should try to set `interpolate_pos_encoding=True`" + ) + target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings @@ -387,12 +449,13 @@ class IdeficsVisionTransformer(nn.Module): self.encoder = IdeficsVisionEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward + # Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward def forward( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -408,7 +471,7 @@ class IdeficsVisionTransformer(nn.Module): if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index 64bda6a037e..c6df84b11fc 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -74,8 +74,6 @@ class IdeficsModelTester: num_labels=3, scope=None, modality_type_vocab_size=2, - add_multiple_images=False, - num_images=-1, vision_embed_dim=32, vision_patch_size=2, vision_image_size=30, @@ -113,8 +111,6 @@ class IdeficsModelTester: self.num_labels = num_labels self.scope = scope self.modality_type_vocab_size = modality_type_vocab_size - self.add_multiple_images = add_multiple_images - self.num_images = num_images self.vision_embed_dim = vision_embed_dim self.vision_patch_size = vision_patch_size @@ -150,14 +146,17 @@ class IdeficsModelTester: # this is equal to the seq length of the text tokens + number of image patches + 1 for the CLS token self.expected_seq_len = self.seq_length + (self.image_size // self.patch_size) ** 2 + 1 - def prepare_config_and_inputs(self): - self.seq_length = 42 - + def prepare_config_and_inputs(self, num_images=1, interpolate_pos_encoding=False, image_expansion=0): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - num_images = 2 if self.add_multiple_images else 1 pixel_values = floats_tensor( - [self.batch_size, num_images, self.num_channels, self.image_size, self.image_size] + [ + self.batch_size, + num_images, + self.num_channels, + self.image_size + image_expansion, + self.image_size + image_expansion, + ] ) input_mask = None if self.use_input_mask: @@ -166,8 +165,7 @@ class IdeficsModelTester: image_attention_mask = random_attention_mask([self.batch_size, self.seq_length, num_images]) config = self.get_config() - - return (config, input_ids, input_mask, pixel_values, image_attention_mask) + return (config, input_ids, input_mask, pixel_values, image_attention_mask, interpolate_pos_encoding) def get_config(self): return IdeficsConfig( @@ -188,7 +186,6 @@ class IdeficsModelTester: initializer_range=self.initializer_range, num_labels=self.num_labels, modality_type_vocab_size=self.modality_type_vocab_size, - num_images=self.num_images, vision_config=self.vision_config, ) @@ -199,17 +196,43 @@ class IdeficsModelTester: input_mask, pixel_values, image_attention_mask, + interpolate_pos_encoding, ): model = IdeficsModel(config=config) model.to(torch_device) model.eval() result = model( - input_ids, attention_mask=input_mask, pixel_values=pixel_values, image_attention_mask=image_attention_mask + input_ids, + attention_mask=input_mask, + pixel_values=pixel_values, + image_attention_mask=image_attention_mask, + interpolate_pos_encoding=interpolate_pos_encoding, ) self.parent.assertEqual( result.last_hidden_state.shape, (self.batch_size, input_ids.shape[1], self.hidden_size) ) + def create_and_check_model_gen( + self, + config, + input_ids, + input_mask, + pixel_values, + image_attention_mask, + interpolate_pos_encoding, + ): + model = IdeficsForVisionText2Text(config) + model.to(torch_device) + model.eval() + model.generate( + input_ids, + attention_mask=input_mask, + pixel_values=pixel_values, + image_attention_mask=image_attention_mask, + interpolate_pos_encoding=interpolate_pos_encoding, + max_length=self.seq_length + 2, + ) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -218,12 +241,14 @@ class IdeficsModelTester: input_mask, pixel_values, image_attention_mask, + interpolate_pos_encoding, ) = config_and_inputs inputs_dict = { "input_ids": input_ids, "attention_mask": input_mask, "pixel_values": pixel_values, "image_attention_mask": image_attention_mask, + "interpolate_pos_encoding": interpolate_pos_encoding, } return config, inputs_dict @@ -268,10 +293,50 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) def test_config(self): self.config_tester.run_common_tests() - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() + def test_model_single_image(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=1, interpolate_pos_encoding=False, image_expansion=0 + ) self.model_tester.create_and_check_model(*config_and_inputs) + def test_model_multiple_images(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=2, interpolate_pos_encoding=False, image_expansion=0 + ) + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_with_image_pos_embeddings_interpolation_single_image(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=1, interpolate_pos_encoding=True, image_expansion=2 + ) + self.model_tester.create_and_check_model(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=1, interpolate_pos_encoding=True, image_expansion=0 + ) + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_with_image_pos_embeddings_interpolation_multiple_images(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=2, interpolate_pos_encoding=True, image_expansion=2 + ) + self.model_tester.create_and_check_model(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=2, interpolate_pos_encoding=True, image_expansion=0 + ) + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_generate_with_image_pos_embeddings_interpolation_single_image(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=1, interpolate_pos_encoding=True, image_expansion=2 + ) + self.model_tester.create_and_check_model_gen(*config_and_inputs) + + def test_generate_with_image_pos_embeddings_interpolation_multiple_images(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs( + num_images=2, interpolate_pos_encoding=True, image_expansion=2 + ) + self.model_tester.create_and_check_model_gen(*config_and_inputs) + def test_training(self): if not self.model_tester.is_training: return @@ -289,8 +354,6 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) model.to(torch_device) model.train() inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) - for k, v in inputs.items(): - print(k, v.shape) loss = model(**inputs).loss loss.backward() @@ -416,7 +479,8 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase): def setUp(self): self.model_tester = IdeficsModelTester( - self, modality_type_vocab_size=3, add_multiple_images=True, num_images=2 + self, + modality_type_vocab_size=3, ) self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)