From a29eabd0eb1bec9d8ca71048fdca5f7e0465e15d Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Tue, 13 Aug 2024 10:14:39 +0500
Subject: [PATCH] Expand inputs in processors for VLMs (#30962)

* let it be
* draft
* should not have changed
* add warnings
* fix & add tests
* fix tests
* inputs embeds cannot be passed with pixels
* more updates
* paligemma ready!
* minor typos
* update blip-2
* fix tests & raise error
* docstring
* add blip2 test
* tmp
* add image seq length to config
* update docstring
* delete
* fix tests
* fix blip
* fix paligemma
* out-of-place scatter
* add llava-next-video
* Update src/transformers/models/blip_2/modeling_blip_2.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* remove tmp
* codestyle
* nits
* more nits
* remove overriding in tests
* comprehension when merging video
* fix-copies
* revert changes for embeds test
* fix tests after making comprehension
* Update src/transformers/models/blip_2/processing_blip_2.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* Update src/transformers/models/blip_2/processing_blip_2.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* more updates
* fix tests

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
---
 .../models/blip_2/configuration_blip_2.py     |  13 +-
 .../models/blip_2/modeling_blip_2.py          |  55 +++-
 .../models/blip_2/processing_blip_2.py        |  61 +++-
 .../configuration_instructblip.py             |  13 +-
 .../instructblip/modeling_instructblip.py     |  51 +++-
 .../instructblip/processing_instructblip.py   |  58 +++-
 .../configuration_instructblipvideo.py        |  13 +-
 .../diff_instructblipvideo.py                 |  50 +++-
 .../modeling_instructblipvideo.py             |  49 +++-
 .../processing_instructblipvideo.py           |  60 +++-
 .../models/llava/configuration_llava.py       |   4 +
 .../models/llava/modeling_llava.py            | 209 +++++++-------
 .../models/llava/processing_llava.py          |  68 ++++-
 .../llava_next/configuration_llava_next.py    |   4 +
 .../models/llava_next/modeling_llava_next.py  | 272 +++++++++---------
 .../llava_next/processing_llava_next.py       | 110 ++++++-
 .../configuration_llava_next_video.py         |  14 +
 .../llava_next_video/diff_llava_next_video.py | 184 ++++++------
 .../modeling_llava_next_video.py              | 180 ++++++------
 .../processing_llava_next_video.py            |  85 +++++-
 .../paligemma/configuration_paligemma.py      |  18 +-
 .../models/paligemma/modeling_paligemma.py    | 244 +++++++---------
 .../video_llava/configuration_video_llava.py  |   8 +
 .../video_llava/modeling_video_llava.py       | 207 +++++++------
 .../video_llava/processing_video_llava.py     |  74 ++++-
 .../models/vipllava/configuration_vipllava.py |   4 +
 .../models/vipllava/modeling_vipllava.py      | 144 +++++-----
 tests/models/blip_2/test_modeling_blip_2.py   |  30 ++
 .../test_modeling_instructblip.py             |  32 +++
 .../test_processor_instructblip.py            |   2 +-
 .../test_modeling_instructblipvideo.py        |  30 ++
 tests/models/llava/test_modeling_llava.py     |  73 +++++
 .../llava_next/test_modeling_llava_next.py    |  73 +++++
 .../test_modeling_llava_next_video.py         |  53 +++-
 .../paligemma/test_modeling_paligemma.py      |  58 +++-
 .../video_llava/test_modeling_video_llava.py  |  77 +++++
 .../models/vipllava/test_modeling_vipllava.py |  73 +++++
 37 files changed, 1951 insertions(+), 802 deletions(-)

diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py
index fbbe67764df..86380e89b6d 100644
--- a/src/transformers/models/blip_2/configuration_blip_2.py
+++ b/src/transformers/models/blip_2/configuration_blip_2.py
@@ 
-264,6 +264,8 @@ class Blip2Config(PretrainedConfig): num_query_tokens (`int`, *optional*, defaults to 32): The number of query tokens passed through the Transformer. + image_token_index (`int`, *optional*): + Token index of special image token. kwargs (*optional*): Dictionary of keyword arguments. @@ -299,7 +301,15 @@ class Blip2Config(PretrainedConfig): model_type = "blip-2" - def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs): + def __init__( + self, + vision_config=None, + qformer_config=None, + text_config=None, + num_query_tokens=32, + image_token_index=None, + **kwargs, + ): super().__init__(**kwargs) if vision_config is None: @@ -323,6 +333,7 @@ class Blip2Config(PretrainedConfig): self.is_encoder_decoder = self.text_config.is_encoder_decoder self.num_query_tokens = num_query_tokens + self.image_token_index = image_token_index self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES self.initializer_factor = 1.0 diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 7aad5bea66c..e89576c67ec 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1767,12 +1767,25 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel): language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - if attention_mask is None: attention_mask = torch.ones_like(input_ids) - expected_device = language_model_attention_mask.device - attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1) + + # if the model already has "image_token_index" then the input is expanded to account for image embeds + # otherwise we expand manually by concating + if getattr(self.config, "image_token_index", None) is not None: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) + else: + logger.warning_once( + "Expanding inputs for image tokens in BLIP-2 should be done in processing. " + "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
+ ) + inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) + attention_mask = torch.cat( + [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1 + ) if self.config.use_decoder_only_language_model: outputs = self.language_model( @@ -1876,20 +1889,34 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel): .repeat(batch_size, 1) .to(image_embeds.device) ) + inputs_embeds = self.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1) - # concatenate query embeddings with prompt embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) + # if the model already has "image_token_index" then the input is expanded to account for image embeds + # otherwise we expand manually by concatenating + if getattr(self.config, "image_token_index", None) is not None: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() + else: + logger.warning_once( + "Expanding inputs for image tokens in BLIP-2 should be done in processing. " + "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + ) + inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) + attention_mask = torch.cat( + [language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1 + ) - # add image_embeds length to max_length, so that the final max_length in counted only on token embeds - # -1 is to account for the prepended BOS after `generate.` - # TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs - if not self.language_model.config.is_encoder_decoder: - generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1 - generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1] + # add image_embeds length to max_length, so that the final max_length in counted only on token embeds + # -1 is to account for the prepended BOS after `generate.` + # TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs + if not self.language_model.config.is_encoder_decoder: + generate_kwargs["max_length"] = ( + generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1 + ) + generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1] outputs = self.language_model.generate( inputs_embeds=inputs_embeds, diff --git a/src/transformers/models/blip_2/processing_blip_2.py b/src/transformers/models/blip_2/processing_blip_2.py index 2d526a17ba6..e879b41eb15 100644 --- a/src/transformers/models/blip_2/processing_blip_2.py +++ b/src/transformers/models/blip_2/processing_blip_2.py @@ -20,8 +20,18 @@ from typing import List, Optional, Union from ...image_utils import ImageInput from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy 
-from ...utils import TensorType +from ...tokenization_utils_base import ( + AddedToken, + BatchEncoding, + PaddingStrategy, + PreTokenizedInput, + TextInput, + TruncationStrategy, +) +from ...utils import TensorType, logging + + +logger = logging.get_logger(__name__) class Blip2Processor(ProcessorMixin): @@ -36,20 +46,24 @@ class Blip2Processor(ProcessorMixin): An instance of [`BlipImageProcessor`]. The image processor is a required input. tokenizer (`AutoTokenizer`): An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input. + num_query_tokens (`int`, *optional*): + Number of tokens used by the Qformer as queries, should be same as in model's config. """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = [] + valid_kwargs = ["num_query_tokens"] image_processor_class = "BlipImageProcessor" tokenizer_class = "AutoTokenizer" - # Copied from transformers.models.blip.processing_blip.BlipProcessor.__init__ - def __init__(self, image_processor, tokenizer, **kwargs): + def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs): tokenizer.return_token_type_ids = False - super().__init__(image_processor, tokenizer) - self.current_processor = self.image_processor + self.current_processor = image_processor + self.image_token = AddedToken("", normalized=False, special=True) + tokenizer.add_tokens([self.image_token], special_tokens=True) + self.num_query_tokens = num_query_tokens + + super().__init__(image_processor, tokenizer) - # Copied from transformers.models.blip.processing_blip.BlipProcessor.__call__ def __call__( self, images: ImageInput = None, @@ -106,7 +120,13 @@ class Blip2Processor(ProcessorMixin): encoding_image_processor = self.image_processor(images, return_tensors=return_tensors) if text is not None: - text_encoding = self.tokenizer( + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + text_encoding = {} + _text_encoding = self.tokenizer( text=text, add_special_tokens=add_special_tokens, padding=padding, @@ -121,9 +141,30 @@ class Blip2Processor(ProcessorMixin): return_token_type_ids=return_token_type_ids, return_length=return_length, verbose=verbose, - return_tensors=return_tensors, + return_tensors=None, # hardcode "None" here for prepending image tokens **kwargs, ) + + # if we know how many query tokens, expand text inside processor. We need this hacky manipulation + # because BLIP expects image tokens to be at the beginning even before BOS token + if self.num_query_tokens is not None: + image_tokens = self.image_token.content * self.num_query_tokens + image_token_encoding = self.tokenizer([image_tokens], add_special_tokens=False, return_tensors=None) + for k in _text_encoding: + text_encoding[k] = [ + img_encoding + txt_encoding + for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k]) + ] + else: + text_encoding = _text_encoding + logger.warning_once( + "Expanding inputs for image tokens in BLIP-2 should be done in processing. " + "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
+ ) + + # cast to desired return tensors type + text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors) else: text_encoding = None diff --git a/src/transformers/models/instructblip/configuration_instructblip.py b/src/transformers/models/instructblip/configuration_instructblip.py index 77014e6f466..a274212a945 100644 --- a/src/transformers/models/instructblip/configuration_instructblip.py +++ b/src/transformers/models/instructblip/configuration_instructblip.py @@ -269,6 +269,8 @@ class InstructBlipConfig(PretrainedConfig): num_query_tokens (`int`, *optional*, defaults to 32): The number of query tokens passed through the Transformer. + image_token_index (`int`, *optional*): + Token index of special image token. kwargs (*optional*): Dictionary of keyword arguments. @@ -304,7 +306,15 @@ class InstructBlipConfig(PretrainedConfig): model_type = "instructblip" - def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs): + def __init__( + self, + vision_config=None, + qformer_config=None, + text_config=None, + num_query_tokens=32, + image_token_index=None, + **kwargs, + ): super().__init__(**kwargs) if vision_config is None: @@ -328,6 +338,7 @@ class InstructBlipConfig(PretrainedConfig): self.is_encoder_decoder = self.text_config.is_encoder_decoder self.num_query_tokens = num_query_tokens + self.image_token_index = image_token_index self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES self.initializer_factor = 1.0 diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 8ad47b308fd..f59f72a6699 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -1453,12 +1453,24 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel): ) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - if attention_mask is None: attention_mask = torch.ones_like(input_ids) - attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1) + + # if the model already has "image_token_index" then the input is expanded to account for image embeds + # otherwise we expand manually by concatenating + if getattr(self.config, "image_token_index", None) is not None: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() + else: + logger.warning_once( + "Expanding inputs for image tokens in InstructBLIP should be done in processing. " + "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
+ ) + inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) + attention_mask = torch.cat( + [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1 + ) if self.config.use_decoder_only_language_model: outputs = self.language_model( @@ -1580,17 +1592,32 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel): ) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1) - # concatenate query embeddings with prompt embeddings inputs_embeds = self.get_input_embeddings()(input_ids) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - # add image_embeds length to max_length, so that the final max_length in counted only on token embeds - # -1 is to account for the prepended BOS after `generate.` - if not self.language_model.config.is_encoder_decoder: - generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1 - generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1] + # if the model already has "image_token_index" then the input is expanded to account for image embeds + # otherwise we expand manually by concatenating + if getattr(self.config, "image_token_index", None) is not None: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() + else: + logger.warning_once( + "Expanding inputs for image tokens in InstructBLIP should be done in processing. " + "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
+ ) + inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) + attention_mask = torch.cat( + [language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1 + ) + + # add image_embeds length to max_length, so that the final max_length in counted only on token embeds + # -1 is to account for the prepended BOS after `generate.` + if not self.language_model.config.is_encoder_decoder: + generate_kwargs["max_length"] = ( + generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1 + ) + generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1] outputs = self.language_model.generate( inputs_embeds=inputs_embeds, diff --git a/src/transformers/models/instructblip/processing_instructblip.py b/src/transformers/models/instructblip/processing_instructblip.py index adebd22178e..bb0351b6718 100644 --- a/src/transformers/models/instructblip/processing_instructblip.py +++ b/src/transformers/models/instructblip/processing_instructblip.py @@ -22,11 +22,21 @@ from typing import List, Optional, Union from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...tokenization_utils_base import ( + AddedToken, + BatchEncoding, + PaddingStrategy, + PreTokenizedInput, + TextInput, + TruncationStrategy, +) +from ...utils import TensorType, logging from ..auto import AutoTokenizer +logger = logging.get_logger(__name__) + + class InstructBlipProcessor(ProcessorMixin): r""" Constructs an InstructBLIP processor which wraps a BLIP image processor and a LLaMa/T5 tokenizer into a single @@ -42,18 +52,22 @@ class InstructBlipProcessor(ProcessorMixin): An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input. qformer_tokenizer (`AutoTokenizer`, *optional*): An instance of ['PreTrainedTokenizer`]. The Q-Former tokenizer is a required input. + num_query_tokens (`int`, *optional*):" + Number of tokens used by the Qformer as queries, should be same as in model's config. """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = [] + valid_kwargs = ["num_query_tokens"] image_processor_class = "BlipImageProcessor" tokenizer_class = "AutoTokenizer" - def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, **kwargs): - super().__init__(image_processor, tokenizer) - + def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, num_query_tokens=None, **kwargs): # add QFormer tokenizer self.qformer_tokenizer = qformer_tokenizer + self.image_token = AddedToken("", normalized=False, special=True) + tokenizer.add_tokens([self.image_token], special_tokens=True) + self.num_query_tokens = num_query_tokens + super().__init__(image_processor, tokenizer) def __call__( self, @@ -87,7 +101,12 @@ class InstructBlipProcessor(ProcessorMixin): encoding = BatchFeature() if text is not None: - text_encoding = self.tokenizer( + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. 
Please provide a string, or a list of strings") + + _text_encoding = self.tokenizer( text=text, add_special_tokens=add_special_tokens, padding=padding, @@ -102,9 +121,32 @@ class InstructBlipProcessor(ProcessorMixin): return_token_type_ids=return_token_type_ids, return_length=return_length, verbose=verbose, - return_tensors=return_tensors, + return_tensors=None, # needed to concatenate below **kwargs, ) + + # if we know how many query tokens, expand text inside processor. We need this hacky manipulation + # because BLIP expects image tokens to be at the beginning even before BOS token + if self.num_query_tokens is not None and images is not None: + text_encoding = {} + image_tokens = self.image_token.content * self.num_query_tokens + image_token_encoding = self.tokenizer([image_tokens], add_special_tokens=False, return_tensors=None) + for k in _text_encoding: + text_encoding[k] = [ + img_encoding + txt_encoding + for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k]) + ] + else: + text_encoding = _text_encoding + if images is not None: + logger.warning_once( + "Expanding inputs for image tokens in InstructBLIP should be done in processing. " + "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + ) + + # cast to desired return tensors type after concatenating + text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors) encoding.update(text_encoding) qformer_text_encoding = self.qformer_tokenizer( text=text, diff --git a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py index 180372f35d1..051e8e21807 100644 --- a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py @@ -276,6 +276,8 @@ class InstructBlipVideoConfig(PretrainedConfig): num_query_tokens (`int`, *optional*, defaults to 32): The number of query tokens passed through the Transformer. + video_token_index (`int`, *optional*): + Token index of special video token. kwargs (*optional*): Dictionary of keyword arguments. 
@@ -311,7 +313,15 @@ class InstructBlipVideoConfig(PretrainedConfig): model_type = "instructblipvideo" - def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs): + def __init__( + self, + vision_config=None, + qformer_config=None, + text_config=None, + num_query_tokens=32, + video_token_index=None, + **kwargs, + ): super().__init__(**kwargs) if vision_config is None: @@ -335,6 +345,7 @@ class InstructBlipVideoConfig(PretrainedConfig): self.is_encoder_decoder = self.text_config.is_encoder_decoder self.num_query_tokens = num_query_tokens + self.video_token_index = video_token_index self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES self.initializer_factor = 1.0 diff --git a/src/transformers/models/instructblipvideo/diff_instructblipvideo.py b/src/transformers/models/instructblipvideo/diff_instructblipvideo.py index f400250d932..506da83c532 100644 --- a/src/transformers/models/instructblipvideo/diff_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/diff_instructblipvideo.py @@ -260,11 +260,24 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera ) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - if attention_mask is None: attention_mask = torch.ones_like(input_ids) - attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1) + + # if the model already has "video_token_index" then the input is expanded to account for image embeds + # otherwise we expand manually by concatenating + if getattr(self.config, "video_token_index", None) is not None: + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() + else: + logger.warning_once( + "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " + "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
+ ) + inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) + attention_mask = torch.cat( + [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1 + ) if self.config.use_decoder_only_language_model: outputs = self.language_model( @@ -394,17 +407,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera ) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1) - # concatenate query embeddings with prompt embeddings inputs_embeds = self.get_input_embeddings()(input_ids) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - # add image_embeds length to max_length, so that the final max_length in counted only on token embeds - # -1 is to account for the prepended BOS after `generate.` - if not self.language_model.config.is_encoder_decoder: - generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1 - generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1] + # if the model already has "video_token_index" then the input is expanded to account for image embeds + # otherwise we expand manually by concatenating + if getattr(self.config, "video_token_index", None) is not None: + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() + else: + logger.warning_once( + "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " + "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
+ ) + inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) + attention_mask = torch.cat( + [language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1 + ) + + # add image_embeds length to max_length, so that the final max_length in counted only on token embeds + # -1 is to account for the prepended BOS after `generate.` + if not self.language_model.config.is_encoder_decoder: + generate_kwargs["max_length"] = ( + generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1 + ) + generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1] outputs = self.language_model.generate( inputs_embeds=inputs_embeds, diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index d3b594e9c3f..701402241d4 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1495,11 +1495,25 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel ) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1) + + # if the model already has "video_token_index" then the input is expanded to account for image embeds + # otherwise we expand manually by concatenating + if getattr(self.config, "video_token_index", None) is not None: + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() + else: + logger.warning_once( + "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " + "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
+ ) + inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) + attention_mask = torch.cat( + [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1 + ) if self.config.use_decoder_only_language_model: outputs = self.language_model( @@ -1629,17 +1643,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel ) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1) - # concatenate query embeddings with prompt embeddings inputs_embeds = self.get_input_embeddings()(input_ids) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) - # add image_embeds length to max_length, so that the final max_length in counted only on token embeds - # -1 is to account for the prepended BOS after `generate.` - if not self.language_model.config.is_encoder_decoder: - generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1 - generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1] + # if the model already has "video_token_index" then the input is expanded to account for image embeds + # otherwise we expand manually by concatenating + if getattr(self.config, "video_token_index", None) is not None: + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() + else: + logger.warning_once( + "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " + "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
+ ) + inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) + attention_mask = torch.cat( + [language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1 + ) + + # add image_embeds length to max_length, so that the final max_length in counted only on token embeds + # -1 is to account for the prepended BOS after `generate.` + if not self.language_model.config.is_encoder_decoder: + generate_kwargs["max_length"] = ( + generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1 + ) + generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1] outputs = self.language_model.generate( inputs_embeds=inputs_embeds, diff --git a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py index 8310b68d736..f56f8186b07 100644 --- a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py @@ -22,11 +22,21 @@ from typing import List, Optional, Union from ...image_processing_utils import BatchFeature from ...image_utils import VideoInput from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...tokenization_utils_base import ( + AddedToken, + BatchEncoding, + PaddingStrategy, + PreTokenizedInput, + TextInput, + TruncationStrategy, +) +from ...utils import TensorType, logging from ..auto import AutoTokenizer +logger = logging.get_logger(__name__) + + class InstructBlipVideoProcessor(ProcessorMixin): r""" Constructs an InstructBLIPVideo processor which wraps a InstructBLIP image processor and a LLaMa/T5 tokenizer into a single @@ -42,18 +52,22 @@ class InstructBlipVideoProcessor(ProcessorMixin): An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input. qformer_tokenizer (`AutoTokenizer`, *optional*): An instance of ['PreTrainedTokenizer`]. The Q-Former tokenizer is a required input. + num_query_tokens (`int`, *optional*): + Number of tokens used by the Qformer as queries, should be same as in model's config. """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = [] + valid_kwargs = ["num_query_tokens"] image_processor_class = "InstructBlipVideoImageProcessor" tokenizer_class = "AutoTokenizer" - def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, **kwargs): - super().__init__(image_processor, tokenizer) - + def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, num_query_tokens=None, **kwargs): # add QFormer tokenizer self.qformer_tokenizer = qformer_tokenizer + self.video_token = AddedToken("