Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-24 14:58:56 +06:00)
Expand inputs in processors for VLMs (#30962)
* let it be
* draft
* should not have changed
* add warnings
* fix & add tests
* fix tests
* inputs embeds cannot be passed with pixels
* more updates
* paligemma ready!
* minor typos
* update blip-2
* fix tests & raise error
* docstring
* add blip2 test
* tmp
* add image seq length to config
* update docstring
* delete
* fix tests
* fix blip
* fix paligemma
* out-of-place scatter
* add llava-next-video
* Update src/transformers/models/blip_2/modeling_blip_2.py (Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>)
* remove tmp
* codestyle
* nits
* more nits
* remove overriding in tests
* comprehension when merging video
* fix-copies
* revert changes for embeds test
* fix tests after making comprehension
* Update src/transformers/models/blip_2/processing_blip_2.py (Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>)
* Update src/transformers/models/blip_2/processing_blip_2.py (Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>)
* more updates
* fix tests

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
parent 2a5a6ad18a
commit a29eabd0eb
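The core change in this PR: vision-language processors now expand the special image placeholder token to cover every position the image embeddings will occupy, so the models can merge image features by scattering them into those positions instead of concatenating them in front of the text embeddings. A minimal, library-independent sketch of the idea (the token string and count below are hypothetical):

num_query_tokens = 32                          # hypothetical; must match the model config
image_token = "<image>"
prompt = "Question: what is in the picture? Answer:"

# The processor repeats the placeholder so the tokenized prompt already reserves
# one position per image embedding; BLIP-2 expects these before the BOS token.
expanded_prompt = image_token * num_query_tokens + prompt
assert expanded_prompt.count(image_token) == num_query_tokens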
@ -264,6 +264,8 @@ class Blip2Config(PretrainedConfig):
|
||||
num_query_tokens (`int`, *optional*, defaults to 32):
|
||||
The number of query tokens passed through the Transformer.
|
||||
|
||||
image_token_index (`int`, *optional*):
|
||||
Token index of special image token.
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments.
|
||||
|
||||
@ -299,7 +301,15 @@ class Blip2Config(PretrainedConfig):
|
||||
|
||||
model_type = "blip-2"
|
||||
|
||||
def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
qformer_config=None,
|
||||
text_config=None,
|
||||
num_query_tokens=32,
|
||||
image_token_index=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
@ -323,6 +333,7 @@ class Blip2Config(PretrainedConfig):
|
||||
self.is_encoder_decoder = self.text_config.is_encoder_decoder
|
||||
|
||||
self.num_query_tokens = num_query_tokens
|
||||
self.image_token_index = image_token_index
|
||||
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
|
||||
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
self.initializer_factor = 1.0
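As a quick illustration of the new field, the updated signature lets a Blip2Config carry the id of the special image token (the index below is hypothetical; in practice it is the tokenizer id of the "<image>" token added by the updated processor):

from transformers import Blip2Config

config = Blip2Config(num_query_tokens=32, image_token_index=50265)  # hypothetical token id
print(config.num_query_tokens, config.image_token_index)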
|
||||
|
@ -1767,12 +1767,25 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
expected_device = language_model_attention_mask.device
|
||||
attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1)
|
||||
|
||||
# if the model already has "image_token_index" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
|
||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
attention_mask = torch.cat(
|
||||
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
|
||||
)
|
||||
|
||||
if self.config.use_decoder_only_language_model:
|
||||
outputs = self.language_model(
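A toy, self-contained illustration of the masked_scatter merge used in the new branch above (shapes and token ids are made up for the example):

import torch

image_token_index = 5
input_ids = torch.tensor([[5, 5, 7, 9]])   # two image placeholders followed by text tokens
inputs_embeds = torch.zeros(1, 4, 3)       # text embeddings (zeros just for illustration)
image_features = torch.ones(2, 3)          # one projected vector per image placeholder

# Positions whose token id equals image_token_index are overwritten with image features.
mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(mask, image_features)
assert merged[0, 0].sum() == 3 and merged[0, 2].sum() == 0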
|
||||
@ -1876,20 +1889,34 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
|
||||
.repeat(batch_size, 1)
|
||||
.to(image_embeds.device)
|
||||
)
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)
|
||||
|
||||
# concatenate query embeddings with prompt embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
# if the model already has "image_token_index" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
|
||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
attention_mask = torch.cat(
|
||||
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
|
||||
)
|
||||
|
||||
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
||||
# -1 is to account for the prepended BOS after `generate.`
|
||||
# TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
|
||||
if not self.language_model.config.is_encoder_decoder:
|
||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
||||
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
||||
# -1 is to account for the prepended BOS after `generate.`
|
||||
# TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
|
||||
if not self.language_model.config.is_encoder_decoder:
|
||||
generate_kwargs["max_length"] = (
|
||||
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
||||
)
|
||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
||||
|
||||
outputs = self.language_model.generate(
|
||||
inputs_embeds=inputs_embeds,
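The legacy generate path above adjusts the length budget by hand; a worked example with hypothetical numbers:

max_length = 20           # user-requested budget, meant to cover only text tokens
num_image_embeds = 32     # query-token embeddings prepended to inputs_embeds
# generate() prepends one BOS token, hence the -1 compensation.
adjusted_max_length = max_length + num_image_embeds - 1   # 51
adjusted_min_length = 0 + num_image_embeds                # 32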
|
||||
|
@ -20,8 +20,18 @@ from typing import List, Optional, Union
|
||||
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from ...utils import TensorType
|
||||
from ...tokenization_utils_base import (
|
||||
AddedToken,
|
||||
BatchEncoding,
|
||||
PaddingStrategy,
|
||||
PreTokenizedInput,
|
||||
TextInput,
|
||||
TruncationStrategy,
|
||||
)
|
||||
from ...utils import TensorType, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Blip2Processor(ProcessorMixin):
|
||||
@ -36,20 +46,24 @@ class Blip2Processor(ProcessorMixin):
|
||||
An instance of [`BlipImageProcessor`]. The image processor is a required input.
|
||||
tokenizer (`AutoTokenizer`):
|
||||
An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
|
||||
num_query_tokens (`int`, *optional*):
|
||||
Number of tokens used by the Qformer as queries; should be the same as in the model's config.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = []
|
||||
valid_kwargs = ["num_query_tokens"]
|
||||
image_processor_class = "BlipImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
# Copied from transformers.models.blip.processing_blip.BlipProcessor.__init__
|
||||
def __init__(self, image_processor, tokenizer, **kwargs):
|
||||
def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
|
||||
tokenizer.return_token_type_ids = False
|
||||
super().__init__(image_processor, tokenizer)
|
||||
self.current_processor = self.image_processor
|
||||
self.current_processor = image_processor
|
||||
self.image_token = AddedToken("<image>", normalized=False, special=True)
|
||||
tokenizer.add_tokens([self.image_token], special_tokens=True)
|
||||
self.num_query_tokens = num_query_tokens
|
||||
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
# Copied from transformers.models.blip.processing_blip.BlipProcessor.__call__
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
@ -106,7 +120,13 @@ class Blip2Processor(ProcessorMixin):
|
||||
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
|
||||
|
||||
if text is not None:
|
||||
text_encoding = self.tokenizer(
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
text_encoding = {}
|
||||
_text_encoding = self.tokenizer(
|
||||
text=text,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
@ -121,9 +141,30 @@ class Blip2Processor(ProcessorMixin):
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
return_tensors=return_tensors,
|
||||
return_tensors=None, # hardcode "None" here for prepending image tokens
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# if we know how many query tokens we have, expand the text inside the processor. We need this hacky manipulation
|
||||
# because BLIP expects image tokens to be at the beginning even before BOS token
|
||||
if self.num_query_tokens is not None:
|
||||
image_tokens = self.image_token.content * self.num_query_tokens
|
||||
image_token_encoding = self.tokenizer([image_tokens], add_special_tokens=False, return_tensors=None)
|
||||
for k in _text_encoding:
|
||||
text_encoding[k] = [
|
||||
img_encoding + txt_encoding
|
||||
for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k])
|
||||
]
|
||||
else:
|
||||
text_encoding = _text_encoding
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
|
||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
|
||||
# cast to desired return tensors type
|
||||
text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
|
||||
else:
|
||||
text_encoding = None
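A hedged usage sketch of the updated processor (the checkpoint name, image URL and prompt are only illustrative; the expansion happens only if the checkpoint's processor config carries num_query_tokens, otherwise the legacy warning above is emitted):

import requests
from PIL import Image
from transformers import Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# With num_query_tokens set, input_ids already contain the repeated "<image>" token.
inputs = processor(images=image, text="Question: how many cats are there? Answer:", return_tensors="pt")
print(inputs["input_ids"].shape)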
|
||||
|
||||
|
@ -269,6 +269,8 @@ class InstructBlipConfig(PretrainedConfig):
|
||||
num_query_tokens (`int`, *optional*, defaults to 32):
|
||||
The number of query tokens passed through the Transformer.
|
||||
|
||||
image_token_index (`int`, *optional*):
|
||||
Token index of special image token.
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments.
|
||||
|
||||
@ -304,7 +306,15 @@ class InstructBlipConfig(PretrainedConfig):
|
||||
|
||||
model_type = "instructblip"
|
||||
|
||||
def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
qformer_config=None,
|
||||
text_config=None,
|
||||
num_query_tokens=32,
|
||||
image_token_index=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
@ -328,6 +338,7 @@ class InstructBlipConfig(PretrainedConfig):
|
||||
self.is_encoder_decoder = self.text_config.is_encoder_decoder
|
||||
|
||||
self.num_query_tokens = num_query_tokens
|
||||
self.image_token_index = image_token_index
|
||||
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
|
||||
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
self.initializer_factor = 1.0
|
||||
|
@ -1453,12 +1453,24 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
|
||||
)
|
||||
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1)
|
||||
|
||||
# if the model already has "image_token_index" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
|
||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
attention_mask = torch.cat(
|
||||
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
|
||||
)
|
||||
|
||||
if self.config.use_decoder_only_language_model:
|
||||
outputs = self.language_model(
|
||||
@ -1580,17 +1592,32 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
|
||||
)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)
|
||||
|
||||
# concatenate query embeddings with prompt embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
|
||||
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
||||
# -1 is to account for the prepended BOS after `generate.`
|
||||
if not self.language_model.config.is_encoder_decoder:
|
||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
||||
# if the model already has "image_token_index" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
|
||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
attention_mask = torch.cat(
|
||||
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
|
||||
)
|
||||
|
||||
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
||||
# -1 is to account for the prepended BOS after `generate.`
|
||||
if not self.language_model.config.is_encoder_decoder:
|
||||
generate_kwargs["max_length"] = (
|
||||
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
||||
)
|
||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
||||
|
||||
outputs = self.language_model.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -22,11 +22,21 @@ from typing import List, Optional, Union
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from ...utils import TensorType
|
||||
from ...tokenization_utils_base import (
|
||||
AddedToken,
|
||||
BatchEncoding,
|
||||
PaddingStrategy,
|
||||
PreTokenizedInput,
|
||||
TextInput,
|
||||
TruncationStrategy,
|
||||
)
|
||||
from ...utils import TensorType, logging
|
||||
from ..auto import AutoTokenizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class InstructBlipProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs an InstructBLIP processor which wraps a BLIP image processor and a LLaMa/T5 tokenizer into a single
|
||||
@ -42,18 +52,22 @@ class InstructBlipProcessor(ProcessorMixin):
|
||||
An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
|
||||
qformer_tokenizer (`AutoTokenizer`, *optional*):
|
||||
An instance of [`PreTrainedTokenizer`]. The Q-Former tokenizer is a required input.
|
||||
num_query_tokens (`int`, *optional*):
|
||||
Number of tokens used by the Qformer as queries; should be the same as in the model's config.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = []
|
||||
valid_kwargs = ["num_query_tokens"]
|
||||
image_processor_class = "BlipImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, **kwargs):
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, num_query_tokens=None, **kwargs):
|
||||
# add QFormer tokenizer
|
||||
self.qformer_tokenizer = qformer_tokenizer
|
||||
self.image_token = AddedToken("<image>", normalized=False, special=True)
|
||||
tokenizer.add_tokens([self.image_token], special_tokens=True)
|
||||
self.num_query_tokens = num_query_tokens
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -87,7 +101,12 @@ class InstructBlipProcessor(ProcessorMixin):
|
||||
encoding = BatchFeature()
|
||||
|
||||
if text is not None:
|
||||
text_encoding = self.tokenizer(
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
_text_encoding = self.tokenizer(
|
||||
text=text,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
@ -102,9 +121,32 @@ class InstructBlipProcessor(ProcessorMixin):
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
return_tensors=return_tensors,
|
||||
return_tensors=None, # needed to concatenate below
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# if we know how many query tokens we have, expand the text inside the processor. We need this hacky manipulation
|
||||
# because BLIP expects image tokens to be at the beginning even before BOS token
|
||||
if self.num_query_tokens is not None and images is not None:
|
||||
text_encoding = {}
|
||||
image_tokens = self.image_token.content * self.num_query_tokens
|
||||
image_token_encoding = self.tokenizer([image_tokens], add_special_tokens=False, return_tensors=None)
|
||||
for k in _text_encoding:
|
||||
text_encoding[k] = [
|
||||
img_encoding + txt_encoding
|
||||
for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k])
|
||||
]
|
||||
else:
|
||||
text_encoding = _text_encoding
|
||||
if images is not None:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
|
||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
|
||||
# cast to desired return tensors type after concatenating
|
||||
text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
|
||||
encoding.update(text_encoding)
|
||||
qformer_text_encoding = self.qformer_tokenizer(
|
||||
text=text,
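The per-key manipulation above prepends the tokenized image tokens to every entry of the text encoding; a tiny stand-alone illustration with made-up token ids:

image_token_encoding = {"input_ids": [[9, 9, 9]], "attention_mask": [[1, 1, 1]]}
_text_encoding = {"input_ids": [[0, 41, 42]], "attention_mask": [[1, 1, 1]]}

text_encoding = {
    k: [img + txt for img, txt in zip(image_token_encoding[k], _text_encoding[k])]
    for k in _text_encoding
}
# text_encoding["input_ids"] -> [[9, 9, 9, 0, 41, 42]]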
|
||||
|
@ -276,6 +276,8 @@ class InstructBlipVideoConfig(PretrainedConfig):
|
||||
num_query_tokens (`int`, *optional*, defaults to 32):
|
||||
The number of query tokens passed through the Transformer.
|
||||
|
||||
video_token_index (`int`, *optional*):
|
||||
Token index of special video token.
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments.
|
||||
|
||||
@ -311,7 +313,15 @@ class InstructBlipVideoConfig(PretrainedConfig):
|
||||
|
||||
model_type = "instructblipvideo"
|
||||
|
||||
def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
qformer_config=None,
|
||||
text_config=None,
|
||||
num_query_tokens=32,
|
||||
video_token_index=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
@ -335,6 +345,7 @@ class InstructBlipVideoConfig(PretrainedConfig):
|
||||
self.is_encoder_decoder = self.text_config.is_encoder_decoder
|
||||
|
||||
self.num_query_tokens = num_query_tokens
|
||||
self.video_token_index = video_token_index
|
||||
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
|
||||
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
self.initializer_factor = 1.0
|
||||
|
@ -260,11 +260,24 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
)
|
||||
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1)
|
||||
|
||||
# if the model already has "video_token_index" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
attention_mask = torch.cat(
|
||||
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
|
||||
)
|
||||
|
||||
if self.config.use_decoder_only_language_model:
|
||||
outputs = self.language_model(
|
||||
@ -394,17 +407,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)
|
||||
|
||||
# concatenate query embeddings with prompt embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
|
||||
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
||||
# -1 is to account for the prepended BOS after `generate.`
|
||||
if not self.language_model.config.is_encoder_decoder:
|
||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
||||
# if the model already has "video_token_index" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
attention_mask = torch.cat(
|
||||
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
|
||||
)
|
||||
|
||||
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
||||
# -1 is to account for the prepended BOS after `generate.`
|
||||
if not self.language_model.config.is_encoder_decoder:
|
||||
generate_kwargs["max_length"] = (
|
||||
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
||||
)
|
||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
||||
|
||||
outputs = self.language_model.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -1495,11 +1495,25 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
)
|
||||
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1)
|
||||
|
||||
# if the model already has "video_token_index" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
attention_mask = torch.cat(
|
||||
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
|
||||
)
|
||||
|
||||
if self.config.use_decoder_only_language_model:
|
||||
outputs = self.language_model(
|
||||
@ -1629,17 +1643,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)
|
||||
|
||||
# concatenate query embeddings with prompt embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
|
||||
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
||||
# -1 is to account for the prepended BOS after `generate.`
|
||||
if not self.language_model.config.is_encoder_decoder:
|
||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
||||
# if the model already has "video_token_index" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
attention_mask = torch.cat(
|
||||
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
|
||||
)
|
||||
|
||||
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
||||
# -1 is to account for the prepended BOS after `generate.`
|
||||
if not self.language_model.config.is_encoder_decoder:
|
||||
generate_kwargs["max_length"] = (
|
||||
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
||||
)
|
||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
||||
|
||||
outputs = self.language_model.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -22,11 +22,21 @@ from typing import List, Optional, Union
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import VideoInput
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from ...utils import TensorType
|
||||
from ...tokenization_utils_base import (
|
||||
AddedToken,
|
||||
BatchEncoding,
|
||||
PaddingStrategy,
|
||||
PreTokenizedInput,
|
||||
TextInput,
|
||||
TruncationStrategy,
|
||||
)
|
||||
from ...utils import TensorType, logging
|
||||
from ..auto import AutoTokenizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class InstructBlipVideoProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs an InstructBLIPVideo processor which wraps an InstructBLIP image processor and a LLaMa/T5 tokenizer into a single
|
||||
@ -42,18 +52,22 @@ class InstructBlipVideoProcessor(ProcessorMixin):
|
||||
An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
|
||||
qformer_tokenizer (`AutoTokenizer`, *optional*):
|
||||
An instance of [`PreTrainedTokenizer`]. The Q-Former tokenizer is a required input.
|
||||
num_query_tokens (`int`, *optional*):
|
||||
Number of tokens used by the Qformer as queries; should be the same as in the model's config.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = []
|
||||
valid_kwargs = ["num_query_tokens"]
|
||||
image_processor_class = "InstructBlipVideoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, **kwargs):
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, num_query_tokens=None, **kwargs):
|
||||
# add QFormer tokenizer
|
||||
self.qformer_tokenizer = qformer_tokenizer
|
||||
self.video_token = AddedToken("<video>", normalized=False, special=True)
|
||||
tokenizer.add_tokens([self.video_token], special_tokens=True)
|
||||
self.num_query_tokens = num_query_tokens
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -84,7 +98,12 @@ class InstructBlipVideoProcessor(ProcessorMixin):
|
||||
encoding = BatchFeature()
|
||||
|
||||
if text is not None:
|
||||
text_encoding = self.tokenizer(
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
_text_encoding = self.tokenizer(
|
||||
text=text,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
@ -99,9 +118,34 @@ class InstructBlipVideoProcessor(ProcessorMixin):
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
return_tensors=return_tensors,
|
||||
return_tensors=None, # required to concatenate below
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# if we know how many query tokens we have, expand the text inside the processor. We need this hacky manipulation
|
||||
# because BLIP expects image tokens to be at the beginning even before BOS token
|
||||
if self.num_query_tokens is not None and images is not None:
|
||||
text_encoding = {}
|
||||
video_tokens = (
|
||||
self.video_token.content * self.num_query_tokens * 4
|
||||
) # InstructBLIP works with 4 frames only
|
||||
video_token_encoding = self.tokenizer([video_tokens], add_special_tokens=False, return_tensors=None)
|
||||
for k in _text_encoding:
|
||||
text_encoding[k] = [
|
||||
img_encoding + txt_encoding
|
||||
for img_encoding, txt_encoding in zip(video_token_encoding[k], _text_encoding[k])
|
||||
]
|
||||
else:
|
||||
text_encoding = _text_encoding
|
||||
if images is not None:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
||||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
|
||||
# cast to desired return tensors type after concatenating
|
||||
text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
|
||||
encoding.update(text_encoding)
|
||||
qformer_text_encoding = self.qformer_tokenizer(
|
||||
text=text,
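For the video processor above, the placeholder is repeated once per query token and per frame; with hypothetical numbers:

num_query_tokens = 32
num_frames = 4                 # InstructBLIP-Video works with 4 frames only (see the comment above)
video_tokens = "<video>" * (num_query_tokens * num_frames)
assert video_tokens.count("<video>") == 128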
|
||||
|
@ -48,6 +48,8 @@ class LlavaConfig(PretrainedConfig):
|
||||
Can be one of `"default"` or `"full"`.
|
||||
vision_feature_layer (`int`, *optional*, defaults to -2):
|
||||
The index of the layer to select the vision feature.
|
||||
image_seq_length (`int`, *optional*, defaults to 576):
|
||||
Sequence length of one image embedding.
|
||||
|
||||
Example:
|
||||
|
||||
@ -82,11 +84,13 @@ class LlavaConfig(PretrainedConfig):
|
||||
projector_hidden_act="gelu",
|
||||
vision_feature_select_strategy="default",
|
||||
vision_feature_layer=-2,
|
||||
image_seq_length=576,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.image_seq_length = image_seq_length
|
||||
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(
|
||||
|
@ -23,7 +23,6 @@ from torch import nn
|
||||
|
||||
from ... import PreTrainedModel
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_outputs import ModelOutput
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@ -230,6 +229,10 @@ LLAVA_INPUTS_DOCSTRING = r"""
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@ -373,6 +376,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@ -419,63 +423,90 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
legacy_processing = False
|
||||
if inputs_embeds is None:
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# 2. Merge text and images
|
||||
if pixel_values is not None and input_ids.shape[1] != 1:
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
# this is not memory efficient at all; (output_hidden_states=True) will save all the hidden states.
|
||||
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
||||
# if the number of image tokens is more than the image embeddings seq length, then we probably expanded it in processing
|
||||
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
|
||||
# In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
|
||||
legacy_processing = (
|
||||
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
|
||||
) or (input_ids.shape[-1] == 1 and pixel_values is not None)
|
||||
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
if pixel_values is not None:
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
# this is not memory efficient at all; (output_hidden_states=True) will save all the hidden states.
|
||||
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
if legacy_processing:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in LLaVa should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
# prefill stage vs decoding stage (legacy behavior copied)
|
||||
if input_ids.shape[1] != 1:
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
|
||||
# Get the target length
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
inputs_embeds = inputs_embeds.to(image_features.dtype)
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
else:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
|
||||
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
||||
# generation with cache
|
||||
elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
|
||||
# Get the target length
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
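A toy version of the legacy-processing check introduced above: when the prompt carries fewer image tokens than image_seq_length, the processor did not expand them and the old merging path is taken (ids and lengths are made up; the real LLaVA default for image_seq_length is 576):

import torch

image_token_index = 32000
image_seq_length = 4
input_ids = torch.tensor([[32000, 1, 15, 16, 17]])   # a single, unexpanded image token
legacy_processing = (input_ids == image_token_index).sum(1).max() < image_seq_length
assert bool(legacy_processing)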
|
||||
@ -486,6 +517,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
@ -519,56 +551,37 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
**kwargs,
|
||||
):
|
||||
if past_key_values is not None:
|
||||
if isinstance(past_key_values, Cache):
|
||||
cache_length = past_key_values.get_seq_length()
|
||||
past_length = past_key_values.seen_tokens
|
||||
else:
|
||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
elif self.config.image_token_index in input_ids:
|
||||
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
|
||||
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
||||
# older attention values, as their corresponding values are not part of the input.
|
||||
if cache_length < past_length and attention_mask is not None:
|
||||
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
# Trigger the new behavior if we have more than image embeddings seq length tokens for images
|
||||
legacy_processing = (
|
||||
input_ids is not None
|
||||
and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
|
||||
)
|
||||
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if legacy_processing:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
elif cache_position[0] == 0:
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
|
||||
return model_inputs
|
||||
|
||||
def _reorder_cache(self, *args, **kwargs):
|
||||
|
@ -19,10 +19,13 @@ Processor class for Llava.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput
|
||||
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from ...utils import TensorType
|
||||
from ...utils import TensorType, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LlavaProcessor(ProcessorMixin):
|
||||
@ -37,16 +40,35 @@ class LlavaProcessor(ProcessorMixin):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
patch_size (`int`, *optional*):
|
||||
Patch size from the vision tower.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Should be the same as in the model's config.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
Special token used to denote image location.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token"]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
patch_size=None,
|
||||
vision_feature_select_strategy=None,
|
||||
chat_template=None,
|
||||
image_token="<image>", # set the default and let users change if they have peculiar special tokens in rare cases
|
||||
**kwargs,
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.image_token = image_token
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def __call__(
|
||||
@ -107,10 +129,42 @@ class LlavaProcessor(ProcessorMixin):
|
||||
image_inputs = self.image_processor(images, return_tensors=return_tensors)
|
||||
else:
|
||||
image_inputs = {}
|
||||
text_inputs = self.tokenizer(
|
||||
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
# try to expand inputs in processing if we have the necessary parts
|
||||
if image_inputs.get("pixel_values") is not None:
|
||||
if self.patch_size is not None and self.vision_feature_select_strategy is not None:
|
||||
# Replace the image token with the expanded image token sequence
|
||||
pixel_values = image_inputs["pixel_values"]
|
||||
height, width = get_image_size(to_numpy_array(pixel_values[0]))
|
||||
num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
|
||||
prompt_strings = []
|
||||
for sample in text:
|
||||
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
|
||||
prompt_strings.append(sample)
|
||||
else:
|
||||
prompt_strings = text
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in LLaVa should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt_strings,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
)
|
||||
return BatchFeature(data={**text_inputs, **image_inputs})
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
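A worked example of the token-count formula used above, assuming the usual 336x336 CLIP ViT-L/14 tower and the "default" feature-select strategy (which drops the CLS position):

vision_feature_select_strategy = "default"
height = width = 336      # assumed input resolution of the vision tower
patch_size = 14
num_image_tokens = (height // patch_size) * (width // patch_size) + 1   # 577 (patches + CLS)
if vision_feature_select_strategy == "default":
    num_image_tokens -= 1  # 576, matching the image_seq_length default introduced in this PR
assert num_image_tokens == 576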
|
||||
|
@ -53,6 +53,8 @@ class LlavaNextConfig(PretrainedConfig):
|
||||
of the form `(height, width)`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
image_seq_length (`int`, *optional*, defaults to 576):
|
||||
Sequence length of one image embedding.
|
||||
|
||||
Example:
|
||||
|
||||
@ -89,11 +91,13 @@ class LlavaNextConfig(PretrainedConfig):
|
||||
vision_feature_layer=-2,
|
||||
image_grid_pinpoints=None,
|
||||
tie_word_embeddings=False,
|
||||
image_seq_length=576,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.image_seq_length = image_seq_length
|
||||
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(
|
||||
|
@ -25,7 +25,6 @@ from torch import nn
|
||||
|
||||
from ... import PreTrainedModel
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...image_processing_utils import select_best_resolution
|
||||
from ...modeling_outputs import ModelOutput
|
||||
from ...utils import (
|
||||
@ -336,6 +335,10 @@ LLAVA_NEXT_INPUTS_DOCSTRING = r"""
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@ -708,6 +711,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@ -754,104 +758,118 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
legacy_processing = False
|
||||
if inputs_embeds is None:
|
||||
# 1. Extract the input embeddings
|
||||
# In case image_token_index is not in the embeddings (extra token but embedding don't have it)
|
||||
for_inputs_embeds_ids = input_ids.clone()
|
||||
for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
|
||||
inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# 2. Merge text and images
|
||||
if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
|
||||
# ! infer image_num_patches from image_sizes
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
image_size=imsize,
|
||||
grid_pinpoints=self.config.image_grid_pinpoints,
|
||||
patch_size=self.config.vision_config.image_size,
|
||||
)
|
||||
for imsize in image_sizes
|
||||
# if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
|
||||
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
|
||||
# In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
|
||||
legacy_processing = (
|
||||
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
|
||||
) or (input_ids.shape[-1] == 1 and pixel_values is not None)
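# Illustrative sketch (not part of this diff) of the heuristic above: the model falls back to
# the legacy path when each prompt contains fewer image tokens than one image embedding needs.
# image_seq_length=576 and image_token_index=32000 are assumed example values.
import torch

image_token_index, image_seq_length = 32000, 576
legacy_ids = torch.tensor([[1, 32000, 5, 6]])            # a single <image> placeholder
expanded_ids = torch.full((1, 600), image_token_index)   # already expanded by the processor
assert (legacy_ids == image_token_index).sum(1).max() < image_seq_length          # -> legacy path
assert not ((expanded_ids == image_token_index).sum(1).max() < image_seq_length)  # -> new path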
|
||||
|
||||
if pixel_values is not None and pixel_values.size(0) > 0:
|
||||
# ! infer image_num_patches from image_sizes
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
image_size=imsize,
|
||||
grid_pinpoints=self.config.image_grid_pinpoints,
|
||||
patch_size=self.config.vision_config.image_size,
|
||||
)
|
||||
for imsize in image_sizes
|
||||
]
|
||||
# figure out if pixel_values is concatenated or stacked
|
||||
if pixel_values.dim() == 5:
|
||||
# stacking when input is (batch_size, num_patches, num_channels, height, width)
|
||||
_pixel_values_list = [
|
||||
pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
|
||||
]
|
||||
# figure out if pixel_values is concatenated or stacked
|
||||
if pixel_values.dim() == 5:
|
||||
# stacking when input is (batch_size, num_patches, num_channels, height, width)
|
||||
_pixel_values_list = [
|
||||
pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
|
||||
]
|
||||
pixel_values = torch.cat(_pixel_values_list, dim=0)
|
||||
elif pixel_values.dim() != 4:
|
||||
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
|
||||
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
|
||||
pixel_values = torch.cat(_pixel_values_list, dim=0)
|
||||
elif pixel_values.dim() != 4:
|
||||
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
|
||||
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
|
||||
|
||||
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
selected_image_feature = image_features.hidden_states[vision_feature_layer]
|
||||
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
selected_image_feature = image_features.hidden_states[vision_feature_layer]
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
image_newline=self.image_newline,
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
if legacy_processing:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
if input_ids.shape[1] != 1:
|
||||
inputs_embeds = inputs_embeds.to(image_features.dtype)
|
||||
inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
|
||||
image_features,
|
||||
feature_lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
)
|
||||
else:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
|
||||
inputs_embeds = inputs_embeds.to(image_features.dtype)
|
||||
inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
|
||||
image_features,
|
||||
feature_lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
|
||||
# Get the target length
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
else:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
|
||||
# pixel_values is not None but is empty ---> text only cases
|
||||
elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
|
||||
# there are no images
|
||||
pass
|
||||
|
||||
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
||||
# generation with cache
|
||||
elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
|
||||
# Get the target length
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
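# Illustrative sketch (not part of this diff) of the non-legacy merge: expanded <image>
# positions in `input_ids` are overwritten in the embedding tensor via masked_scatter.
# The token id (7) and the tiny shapes are arbitrary example values.
import torch

image_token_index = 7
input_ids = torch.tensor([[1, 7, 7, 7, 2]])      # three expanded image positions
inputs_embeds = torch.zeros(1, 5, 4)
image_features = torch.ones(3, 4)                # one feature row per image position
mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(mask, image_features)
assert inputs_embeds[0, 1:4].eq(1).all() and inputs_embeds[0, 0].eq(0).all()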
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
@ -862,6 +880,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
@ -902,57 +921,32 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
pixel_values=None,
|
||||
image_sizes=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
**kwargs,
|
||||
):
|
||||
if past_key_values is not None:
|
||||
if isinstance(past_key_values, Cache):
|
||||
cache_length = past_key_values.get_seq_length()
|
||||
past_length = past_key_values.seen_tokens
|
||||
else:
|
||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
elif self.config.image_token_index in input_ids:
|
||||
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
|
||||
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
||||
# older attention values, as their corresponding values are not part of the input.
|
||||
if cache_length < past_length and attention_mask is not None:
|
||||
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"image_sizes": image_sizes,
|
||||
}
|
||||
legacy_processing = (
|
||||
input_ids is not None
|
||||
and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
|
||||
)
|
||||
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if legacy_processing:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
model_inputs["image_sizes"] = image_sizes
|
||||
elif cache_position[0] == 0:
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
model_inputs["image_sizes"] = image_sizes
|
||||
|
||||
return model_inputs
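# Illustrative sketch (not part of this diff): with the new path, pixel values are only
# forwarded while prefilling (cache_position starts at 0) or when legacy processing is still
# in use; later decode steps run on the KV cache alone.
import torch

def should_forward_pixels(cache_position: torch.Tensor, legacy_processing: bool) -> bool:
    return legacy_processing or bool(cache_position[0] == 0)

assert should_forward_pixels(torch.arange(12), legacy_processing=False)        # prefill step
assert not should_forward_pixels(torch.tensor([12]), legacy_processing=False)  # decoding step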
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._reorder_cache
|
||||
|
@ -19,10 +19,14 @@ Processor class for LLaVa-NeXT.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput
|
||||
from ...image_processing_utils import select_best_resolution
|
||||
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from ...utils import TensorType
|
||||
from ...utils import TensorType, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LlavaNextProcessor(ProcessorMixin):
|
||||
@ -37,16 +41,35 @@ class LlavaNextProcessor(ProcessorMixin):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
        patch_size (`int`, *optional*):
            Patch size from the vision tower.
        vision_feature_select_strategy (`str`, *optional*):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Should be the same as in the model's config.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
        image_token (`str`, *optional*, defaults to `"<image>"`):
            Special token used to denote image location.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        patch_size=None,
        vision_feature_select_strategy=None,
        chat_template=None,
        image_token="<image>",  # set the default and let users change if they have peculiar special tokens in rare cases
        **kwargs,
    ):
        self.patch_size = patch_size
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.image_token = image_token
        super().__init__(image_processor, tokenizer, chat_template=chat_template)
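# Illustrative usage sketch (not part of this diff): enabling processor-side expansion.
# The checkpoint name is an assumption; patch_size / vision_feature_select_strategy must
# match the values in the corresponding model config.
from transformers import LlavaNextProcessor

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"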
|
||||
|
||||
def __call__(
|
||||
@ -111,12 +134,89 @@ class LlavaNextProcessor(ProcessorMixin):
|
||||
image_inputs = self.image_processor(images, do_pad=do_pad, return_tensors=return_tensors)
|
||||
else:
|
||||
image_inputs = {}
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
if self.patch_size is None or self.vision_feature_select_strategy is None:
|
||||
prompt_strings = text
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
# cannot infer image expansion length if no images are found
|
||||
elif not image_inputs:
|
||||
prompt_strings = text
|
||||
else:
|
||||
image_sizes = image_inputs["image_sizes"]
|
||||
height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
|
||||
prompt_strings = []
|
||||
for image_size, sample in zip(image_sizes, text):
|
||||
# Replace the image token with the expanded image token sequence
|
||||
orig_height, orig_width = image_size
|
||||
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
|
||||
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
|
||||
prompt_strings.append(sample)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
|
||||
prompt_strings,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
return BatchFeature(data={**text_inputs, **image_inputs})
|
||||
|
||||
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
|
||||
image_grid_pinpoints = self.image_processor.image_grid_pinpoints
|
||||
|
||||
height_best_resolution, width_best_resolution = select_best_resolution(
|
||||
[orig_height, orig_width], image_grid_pinpoints
|
||||
)
|
||||
scale_height, scale_width = height_best_resolution // height, width_best_resolution // width
|
||||
|
||||
patches_height = height // self.patch_size
|
||||
patches_width = width // self.patch_size
|
||||
unpadded_features, newline_features = self._get_unpadded_features(
|
||||
orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
|
||||
)
|
||||
# The base patch covers the entire image (+1 for the CLS)
|
||||
base_features = patches_height * patches_width + 1
|
||||
num_image_tokens = unpadded_features + newline_features + base_features
|
||||
return num_image_tokens
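# Worked example (not part of this diff) for the computation above, assuming patch_size=14,
# a 336x336 vision tower and a 672x672 input whose best grid resolution is 672x672.
patch_size, height, width = 14, 336, 336
scale_height, scale_width = 672 // height, 672 // width                    # 2, 2
patches_height, patches_width = height // patch_size, width // patch_size  # 24, 24

current_height = patches_height * scale_height                             # 48
current_width = patches_width * scale_width                                # 48
unpadded_features = current_height * current_width                         # 2304 (square image, no crop)
newline_features = current_height                                          # 48
base_features = patches_height * patches_width + 1                         # 577 (+1 for CLS)
assert unpadded_features + newline_features + base_features == 2929        # 2928 after the "default" -1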
|
||||
|
||||
def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width):
|
||||
"""
|
||||
Get number of features for a given image with height/width. LLaVA-NeXT is different from LLaVA
|
||||
because it divided each image into patches depending on its resolution. Therefore we need to calculate how many
|
||||
patches an image is divided into and get the number of features from that.
|
||||
"""
|
||||
        current_height = patches_height * scale_height
        current_width = patches_width * scale_width
|
||||
|
||||
original_aspect_ratio = width / height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
new_height = (height * current_width) // width
|
||||
padding = (current_height - new_height) // 2
|
||||
current_height -= padding * 2
|
||||
else:
|
||||
new_width = (width * current_height) // height
|
||||
padding = (current_width - new_width) // 2
|
||||
current_width -= padding * 2
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
return (unpadded_features, newline_features)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
|
@ -22,9 +22,15 @@
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from ...utils import (
|
||||
logging,
|
||||
)
|
||||
from ..auto import CONFIG_MAPPING
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LlavaNextVideoConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`LlavaNextVideoForConditionalGeneration`]. It is used to instantiate an
|
||||
@ -62,6 +68,10 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
Pooling mode to use for videos. Can be "average", "max" or "conv".
|
||||
spatial_pool_stride (`int`, *optional*, defaults to 2):
|
||||
Stride used in the pooling layer for videos.
|
||||
image_seq_length (`int`, *optional*, defaults to 576):
|
||||
Sequence length of one image embedding.
|
||||
video_seq_length (`int`, *optional*, defaults to 288):
|
||||
Sequence length of one video embedding.
|
||||
|
||||
Example:
|
||||
|
||||
@ -99,11 +109,15 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
video_token_index=32000,
|
||||
spatial_pool_mode="average",
|
||||
spatial_pool_stride=2,
|
||||
image_seq_length=576,
|
||||
video_seq_length=288,
|
||||
**kwargs,
|
||||
):
|
||||
self.video_token_index = video_token_index
|
||||
self.spatial_pool_mode = spatial_pool_mode
|
||||
self.spatial_pool_stride = spatial_pool_stride
|
||||
self.image_seq_length = image_seq_length
|
||||
self.video_seq_length = video_seq_length
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
|
@ -64,6 +64,10 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
Pooling mode to use for videos. Can be "average", "max" or "conv".
|
||||
spatial_pool_stride (`int`, *optional*, defaults to 2):
|
||||
Stride used in the pooling layer for videos.
|
||||
image_seq_length (`int`, *optional*, defaults to 576):
|
||||
Sequence length of one image embedding.
|
||||
video_seq_length (`int`, *optional*, defaults to 288):
|
||||
Sequence length of one video embedding.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
The activation function used by the multimodal projector.
|
||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||
@ -114,11 +118,15 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
video_token_index=32000,
|
||||
spatial_pool_mode="average",
|
||||
spatial_pool_stride=2,
|
||||
image_seq_length=576,
|
||||
video_seq_length=288,
|
||||
**kwargs,
|
||||
):
|
||||
self.video_token_index = video_token_index
|
||||
self.spatial_pool_mode = spatial_pool_mode
|
||||
self.spatial_pool_stride = spatial_pool_stride
|
||||
self.image_seq_length = image_seq_length
|
||||
self.video_seq_length = video_seq_length
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
@ -375,90 +383,106 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
legacy_processing = False
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# Merge text and images in prefill stage
|
||||
if past_key_values is None:
|
||||
# First merge image tokens if there are any
|
||||
if pixel_values is not None and pixel_values.size(0) > 0:
|
||||
image_features = self._get_image_features(pixel_values, image_sizes)
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
inputs_embeds = inputs_embeds.to(image_features.dtype)
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels,
|
||||
input_ids,
|
||||
) = self._merge_input_ids_with_image_features(
|
||||
image_features,
|
||||
feature_lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
image_token_index=self.config.image_token_index,
|
||||
)
|
||||
# Then merge video tokens if there are any
|
||||
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
|
||||
video_features = self._get_video_features(pixel_values_videos)
|
||||
video_features = [feature.flatten(0, 1) for feature in video_features]
|
||||
feature_lens = [feature.size(0) for feature in video_features]
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=video_features.device)
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels,
|
||||
input_ids,
|
||||
) = self._merge_input_ids_with_image_features(
|
||||
video_features,
|
||||
feature_lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
image_token_index=self.config.video_token_index,
|
||||
)
|
||||
|
||||
# pixel_values is not None but is empty ---> text only cases
|
||||
elif (pixel_values is not None and pixel_values.size(0) == 0) or (
|
||||
pixel_values_videos is not None and pixel_values_videos.size(0) == 0
|
||||
):
|
||||
pass
|
||||
|
||||
# generation with cache, decoding stage
|
||||
elif past_key_values is not None and (pixel_values is not None or pixel_values_videos is not None):
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
# Get the target length
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
# if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing
|
||||
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
|
||||
img_token_count = (input_ids == self.config.image_token_index).sum(1).max()
|
||||
video_token_count = (input_ids == self.config.video_token_index).sum(1).max()
|
||||
inputs_expanded = (
|
||||
img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length
|
||||
)
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
pixels_present = input_ids.shape[-1] == 1 and pixel_values is not None and pixel_values_videos is not None
|
||||
legacy_processing = inputs_expanded or pixels_present
|
||||
|
||||
image_features = feature_lens = None
|
||||
if pixel_values is not None and pixel_values.size(0) > 0:
|
||||
image_features = self._get_image_features(pixel_values, image_sizes)
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
|
||||
video_features = video_feature_lens = None
|
||||
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
|
||||
video_features = self._get_video_features(pixel_values_videos)
|
||||
video_features = [feature.flatten(0, 1) for feature in video_features]
|
||||
video_feature_lens = [feature.size(0) for feature in video_features]
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
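# Illustrative sketch (not part of this diff): packing per-video frame features before the
# merge. The shapes (2 videos, 8 frames, 144 pooled tokens per frame, hidden size 32) are made up.
import torch

video_features = [torch.randn(8, 144, 32), torch.randn(8, 144, 32)]
flat = [feature.flatten(0, 1) for feature in video_features]    # (8 * 144, 32) per video
lens = torch.tensor([feature.size(0) for feature in flat])      # [1152, 1152]
packed = torch.cat(flat, dim=0)
assert packed.shape == (2304, 32) and lens.tolist() == [1152, 1152]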
|
||||
|
||||
if legacy_processing:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
if input_ids.shape[1] != 1:
|
||||
iterator = (
|
||||
(image_features, feature_lens, self.config.image_token_index),
|
||||
(video_features, video_feature_lens, self.config.video_token_index),
|
||||
)
|
||||
                for features, lens, special_token in iterator:
|
||||
if features is not None:
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels,
|
||||
input_ids,
|
||||
) = self._merge_input_ids_with_image_features(
|
||||
features,
|
||||
lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
image_token_index=special_token,
|
||||
)
|
||||
else:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
# Get the target length
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
else:
|
||||
if image_features is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
if video_features is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
|
@ -376,6 +376,10 @@ LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r"""
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@ -849,90 +853,106 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
legacy_processing = False
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# Merge text and images in prefill stage
|
||||
if past_key_values is None:
|
||||
# First merge image tokens if there are any
|
||||
if pixel_values is not None and pixel_values.size(0) > 0:
|
||||
image_features = self._get_image_features(pixel_values, image_sizes)
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
inputs_embeds = inputs_embeds.to(image_features.dtype)
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels,
|
||||
input_ids,
|
||||
) = self._merge_input_ids_with_image_features(
|
||||
image_features,
|
||||
feature_lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
image_token_index=self.config.image_token_index,
|
||||
)
|
||||
# Then merge video tokens if there are any
|
||||
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
|
||||
video_features = self._get_video_features(pixel_values_videos)
|
||||
video_features = [feature.flatten(0, 1) for feature in video_features]
|
||||
feature_lens = [feature.size(0) for feature in video_features]
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=video_features.device)
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels,
|
||||
input_ids,
|
||||
) = self._merge_input_ids_with_image_features(
|
||||
video_features,
|
||||
feature_lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
image_token_index=self.config.video_token_index,
|
||||
)
|
||||
|
||||
# pixel_values is not None but is empty ---> text only cases
|
||||
elif (pixel_values is not None and pixel_values.size(0) == 0) or (
|
||||
pixel_values_videos is not None and pixel_values_videos.size(0) == 0
|
||||
):
|
||||
pass
|
||||
|
||||
# generation with cache, decoding stage
|
||||
elif past_key_values is not None and (pixel_values is not None or pixel_values_videos is not None):
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
# Get the target length
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
# if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing
|
||||
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
|
||||
img_token_count = (input_ids == self.config.image_token_index).sum(1).max()
|
||||
video_token_count = (input_ids == self.config.video_token_index).sum(1).max()
|
||||
inputs_expanded = (
|
||||
img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length
|
||||
)
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
pixels_present = input_ids.shape[-1] == 1 and pixel_values is not None and pixel_values_videos is not None
|
||||
legacy_processing = inputs_expanded or pixels_present
|
||||
|
||||
image_features = feature_lens = None
|
||||
if pixel_values is not None and pixel_values.size(0) > 0:
|
||||
image_features = self._get_image_features(pixel_values, image_sizes)
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
|
||||
video_features = video_feature_lens = None
|
||||
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
|
||||
video_features = self._get_video_features(pixel_values_videos)
|
||||
video_features = [feature.flatten(0, 1) for feature in video_features]
|
||||
video_feature_lens = [feature.size(0) for feature in video_features]
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
|
||||
|
||||
if legacy_processing:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
if input_ids.shape[1] != 1:
|
||||
iterator = (
|
||||
(image_features, feature_lens, self.config.image_token_index),
|
||||
(video_features, video_feature_lens, self.config.video_token_index),
|
||||
)
|
||||
for features, lens, special_token in iterator:
|
||||
if features is not None:
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels,
|
||||
input_ids,
|
||||
) = self._merge_input_ids_with_image_features(
|
||||
features,
|
||||
lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
image_token_index=special_token,
|
||||
)
|
||||
else:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
# Get the target length
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
else:
|
||||
if image_features is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
if video_features is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
|
@ -19,7 +19,7 @@ Processor class for LLaVa-NeXT-Video.
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, VideoInput
|
||||
from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from ...utils import TensorType, logging
|
||||
@ -48,17 +48,41 @@ class LlavaNextVideoProcessor(ProcessorMixin):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`str`, *optional*):
|
||||
Jinja chat template that will be used in tokenizer's `apply_chat_template`
|
||||
patch_size (`int`, *optional*):
|
||||
Patch size from the vision tower.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
            Should be the same as in the model's config.
|
||||
video_token (`str`, *optional*, defaults to `"<video>"`):
|
||||
Special token used to denote video location.
|
||||
image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
Special token used to denote image location.
|
||||
"""
|
||||
|
||||
# video and image processor share same args, but have different processing logic
|
||||
# only image processor config is saved in the hub
|
||||
attributes = ["video_processor", "image_processor", "tokenizer"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token", "video_token"]
|
||||
image_processor_class = "LlavaNextImageProcessor"
|
||||
video_processor_class = "LlavaNextVideoImageProcessor"
|
||||
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
|
||||
|
||||
def __init__(self, video_processor=None, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
video_processor=None,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
chat_template=None,
|
||||
patch_size=None,
|
||||
vision_feature_select_strategy=None,
|
||||
video_token="<video>",
|
||||
image_token="<image>",
|
||||
**kwargs,
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.image_token = image_token
|
||||
self.video_token = video_token
|
||||
super().__init__(video_processor, image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def __call__(
|
||||
@ -131,9 +155,62 @@ class LlavaNextVideoProcessor(ProcessorMixin):
|
||||
else:
|
||||
videos_inputs = {}
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
|
||||
|
||||
if self.patch_size is None or self.vision_feature_select_strategy is None:
|
||||
prompt_strings = text
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image/video tokens in LLaVa-NeXT-Video should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
# cannot infer image expansion length if no images/videos are found
|
||||
elif not image_inputs and not videos_inputs:
|
||||
prompt_strings = text
|
||||
else:
|
||||
# images expand taking into account num_of_patches in each image
|
||||
if image_inputs:
|
||||
image_sizes = image_inputs["image_sizes"]
|
||||
height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
|
||||
prompt_strings = []
|
||||
for image_size, sample in zip(image_sizes, text):
|
||||
# Replace the image token with the expanded image token sequence
|
||||
orig_height, orig_width = image_size
|
||||
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
|
||||
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
|
||||
prompt_strings.append(sample)
|
||||
text = prompt_strings
|
||||
|
||||
# videos are easier, simply get frames and multiply
|
||||
if videos_inputs:
|
||||
one_video = to_numpy_array(videos_inputs.get("pixel_values_videos")[0])
|
||||
height, width = get_image_size(one_video[0])
|
||||
num_frames = one_video.shape[0] # frame dim is always after batch dim
|
||||
num_image_tokens = (height // self.patch_size) * (width // self.patch_size)
|
||||
num_video_tokens = num_image_tokens // 4 * num_frames # divide by 4 needed for avg pooling layer
|
||||
|
||||
prompt_strings = []
|
||||
for sample in text:
|
||||
sample = sample.replace(self.video_token, self.video_token * num_video_tokens)
|
||||
prompt_strings.append(sample)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
|
||||
prompt_strings,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
|
||||
|
||||
|
@ -86,7 +86,7 @@ class PaliGemmaConfig(PretrainedConfig):
|
||||
hidden_size=2048,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self._ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self._vocab_size = vocab_size
|
||||
self.projection_dim = projection_dim
|
||||
@ -110,14 +110,11 @@ class PaliGemmaConfig(PretrainedConfig):
|
||||
vocab_size=257152,
|
||||
vision_use_head=False,
|
||||
)
|
||||
self.vocab_size = self.vocab_size
|
||||
|
||||
self.text_config = text_config
|
||||
|
||||
if isinstance(self.text_config, dict):
|
||||
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
|
||||
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
||||
self.vocab_size = self.text_config.vocab_size
|
||||
elif text_config is None:
|
||||
self.text_config = CONFIG_MAPPING["gemma"](
|
||||
hidden_size=2048,
|
||||
@ -132,6 +129,18 @@ class PaliGemmaConfig(PretrainedConfig):
|
||||
self.vision_config.projection_dim = projection_dim
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def ignore_index(self):
|
||||
warnings.warn(
|
||||
"The `ignore_index` attribute is deprecated and will be removed in v4.47.",
|
||||
FutureWarning,
|
||||
)
|
||||
return self._ignore_index
|
||||
|
||||
@ignore_index.setter
|
||||
def ignore_index(self, value):
|
||||
self._ignore_index = value
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
warnings.warn(
|
||||
@ -147,4 +156,5 @@ class PaliGemmaConfig(PretrainedConfig):
|
||||
def to_dict(self):
|
||||
output = super().to_dict()
|
||||
output.pop("_vocab_size", None)
|
||||
output.pop("_ignore_index", None)
|
||||
return output
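# Illustrative sketch (not part of this diff): the deprecated attributes keep working but emit
# a FutureWarning when read. Constructing PaliGemmaConfig() with all defaults is assumed to be
# sufficient for this example.
import warnings

from transformers import PaliGemmaConfig

config = PaliGemmaConfig()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = config.ignore_index
assert any(issubclass(w.category, FutureWarning) for w in caught)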
|
||||
|
@ -21,7 +21,7 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...cache_utils import Cache, StaticCache
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
@ -126,6 +126,9 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["PaliGemmaMultiModalProjector"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
|
||||
@ -222,6 +225,10 @@ PALIGEMMA_INPUTS_DOCSTRING = r"""
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@ -285,77 +292,52 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
self.vocab_size = model_embeds.num_embeddings
|
||||
return model_embeds
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
|
||||
def _update_causal_mask(
|
||||
self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
|
||||
):
|
||||
_, _, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
dtype, device = inputs_embeds.dtype, inputs_embeds.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
scaled_image_features = image_features / (self.config.hidden_size**0.5)
|
||||
final_embedding = torch.zeros(
|
||||
batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
)
|
||||
|
||||
text_mask = (input_ids != self.config.image_token_index) & (input_ids != self.pad_token_id)
|
||||
image_mask = input_ids == self.config.image_token_index
|
||||
pad_mask = input_ids == self.pad_token_id
|
||||
|
||||
# expand masks to match embedding dimension
|
||||
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
|
||||
pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
|
||||
# insert padding and text token embeddings
|
||||
final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
|
||||
final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
|
||||
# insert image embeddings - the image mask is always less or equal to the sentence in length
|
||||
final_embedding = final_embedding.masked_scatter(
|
||||
image_mask.unsqueeze(-1).expand_as(final_embedding).to(device=final_embedding.device),
|
||||
scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
|
||||
)
|
||||
final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
|
||||
if attention_mask is not None:
|
||||
position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
|
||||
sequence_length = inputs_embeds.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_length()
|
||||
else:
|
||||
position_ids = None
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else cache_position[0] + sequence_length + 1
|
||||
)
|
||||
|
||||
if token_type_ids is not None and labels is not None:
|
||||
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
|
||||
target_length = cache_position[-1] + 1
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
# unmask the prefill
|
||||
if is_training:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
else:
|
||||
causal_mask = torch.zeros_like(causal_mask)
|
||||
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
|
||||
if is_training:
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
final_labels = torch.full(
|
||||
(batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
|
||||
else:
|
||||
causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
|
||||
# invert causal mask
|
||||
causal_mask = torch.where(causal_mask == 0, min_dtype, 0)
|
||||
causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
|
||||
final_labels = None
|
||||
|
||||
return final_embedding, causal_mask, final_labels, position_ids
|
||||
return causal_mask
|
||||
|
||||
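For reference, a minimal standalone sketch (not part of the diff) of the out-of-place `masked_scatter` merge used above: image features are written, in order, into the embedding positions that hold the image placeholder token. The token id 99 and every tensor size below are toy values chosen for illustration.

import torch

batch_size, seq_len, embed_dim, num_image_tokens = 1, 6, 4, 2
input_ids = torch.tensor([[99, 99, 5, 6, 7, 8]])  # 99 stands in for the image token id
inputs_embeds = torch.randn(batch_size, seq_len, embed_dim)
image_features = torch.randn(batch_size, num_image_tokens, embed_dim)

# boolean mask over every embedding element that belongs to an image placeholder
special_image_mask = (input_ids == 99).unsqueeze(-1).expand_as(inputs_embeds)
# masked_scatter is out-of-place: it returns a new tensor with the masked slots filled in order
merged = inputs_embeds.masked_scatter(special_image_mask, image_features)

assert torch.allclose(merged[:, :2], image_features)
assert torch.allclose(merged[:, 2:], inputs_embeds[:, 2:])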
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
@ -411,66 +393,63 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# the attention mask is turned 4d after, we keep track of the original one
|
||||
input_attention_mask = attention_mask
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
if inputs_embeds is None:
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)

# 2. Merge text and images
if pixel_values is not None and input_ids.shape[1] != 1:
|
||||
image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
|
||||
selected_image_feature = image_outputs.last_hidden_state
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
|
||||
|
||||
# Merge text and images
|
||||
if pixel_values is not None:
|
||||
image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
|
||||
selected_image_feature = image_outputs.last_hidden_state
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_features = image_features / (self.config.hidden_size**0.5)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
if inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||
"tokens from image embeddings."
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
else:
|
||||
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
||||
# generation with cache
|
||||
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
# TODO @molbap this will only work for dynamic cache.
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
# mask out pad-token-ids in labels for BC
|
||||
if labels is not None and self.pad_token_id in labels:
|
||||
logger.warning_once(
|
||||
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. ",
|
||||
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
||||
)
|
||||
labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
|
||||
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training
|
||||
)
|
||||
|
||||
# Get the target length
|
||||
target_seqlen = cache_position[-1] + 1
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1] + 1),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses PaliGemma+ Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
|
||||
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
attention_mask = attention_mask.to(inputs_embeds.dtype)
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@ -487,9 +466,9 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
if labels is not None:
|
||||
shift_logits = logits[..., :-1, :]
|
||||
shift_labels = labels[..., 1:]
|
||||
if input_attention_mask is not None:
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
shift_attention_mask = input_attention_mask[..., 1:]
|
||||
shift_attention_mask = attention_mask[..., 1:]
|
||||
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
|
||||
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
|
||||
else:
|
||||
@ -498,7 +477,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
|
||||
flat_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
||||
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
||||
loss = loss_fct(flat_logits, flat_labels)
|
||||
if not return_dict:
|
||||
@ -526,37 +505,24 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"cache_position": cache_position,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
model_inputs["token_type_ids"] = token_type_ids
|
||||
|
||||
# position_ids in Paligemma are 1-indexed
|
||||
if model_inputs.get("position_ids") is not None:
|
||||
model_inputs["position_ids"] += 1
|
||||
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
||||
if cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
|
||||
return model_inputs
|
||||
|
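A small illustrative sketch (not part of the diff) of the generation convention used above: pixel values are only attached on the prefill step, where `cache_position` starts at 0; on later decoding steps the image content already lives in the KV cache. The helper name is hypothetical.

import torch

def attach_pixel_values(model_inputs, pixel_values, cache_position):
    # hypothetical helper mirroring the `cache_position[0] == 0` gate above
    if cache_position[0] == 0:
        model_inputs["pixel_values"] = pixel_values
    return model_inputs

dummy_pixels = torch.zeros(1, 3, 224, 224)
prefill = attach_pixel_values({}, dummy_pixels, cache_position=torch.arange(10))
decode = attach_pixel_values({}, dummy_pixels, cache_position=torch.tensor([10]))
assert "pixel_values" in prefill and "pixel_values" not in decode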
@ -51,6 +51,10 @@ class VideoLlavaConfig(PretrainedConfig):
|
||||
Can be either "full" to select all features or "default" to select features without `CLS`.
|
||||
vision_feature_layer (`int`, *optional*, defaults to -2):
|
||||
The index of the layer to select the vision feature.
|
||||
image_seq_length (`int`, *optional*, defaults to 256):
|
||||
Sequence length of one image embedding.
|
||||
video_seq_length (`int`, *optional*, defaults to 2056):
|
||||
Sequence length of one video embedding.
|
||||
|
||||
Example:
|
||||
|
||||
@ -86,6 +90,8 @@ class VideoLlavaConfig(PretrainedConfig):
|
||||
projector_hidden_act="gelu",
|
||||
vision_feature_select_strategy="default",
|
||||
vision_feature_layer=-2,
|
||||
image_seq_length=256,
|
||||
video_seq_length=2056,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
@ -94,6 +100,8 @@ class VideoLlavaConfig(PretrainedConfig):
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.vision_feature_layer = vision_feature_layer
|
||||
self.image_seq_length = image_seq_length
|
||||
self.video_seq_length = video_seq_length
|
||||
|
||||
self.vision_config = vision_config
|
||||
|
||||
|
@ -23,7 +23,6 @@ from torch import nn
|
||||
|
||||
from ... import PreTrainedModel
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@ -228,6 +227,10 @@ VIDEO_LLAVA_INPUTS_DOCSTRING = r"""
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@ -413,6 +416,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, VideoLlavaCausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@ -503,51 +507,71 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
legacy_processing = False
|
||||
if inputs_embeds is None:
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# 2. Merge text and images
|
||||
if (pixel_values_images is not None or pixel_values_videos is not None) and input_ids.shape[1] != 1:
|
||||
image_outputs, video_outputs, num_frames = self._get_vision_features(
|
||||
pixel_values_images=pixel_values_images,
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
# if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing
|
||||
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
|
||||
img_token_count = (input_ids == self.config.image_token_index).sum(1).max()
|
||||
video_token_count = (input_ids == self.config.video_token_index).sum(1).max()
|
||||
inputs_expanded = (
|
||||
img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length
|
||||
)
|
||||
pixels_present = (
|
||||
input_ids.shape[-1] == 1 and pixel_values_images is not None and pixel_values_videos is not None
|
||||
)
|
||||
legacy_processing = inputs_expanded or pixels_present
|
||||
|
||||
# first add image embeds where possible, then expand again and add video embeds
|
||||
if image_outputs is not None:
|
||||
visual_features = self.multi_modal_projector(image_outputs)
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
labels,
|
||||
position_ids,
|
||||
input_ids,
|
||||
) = self._merge_input_ids_with_visual_features(
|
||||
visual_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
if video_outputs is not None:
|
||||
visual_features = self.multi_modal_projector(video_outputs)
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
labels,
|
||||
position_ids,
|
||||
_,
|
||||
) = self._merge_input_ids_with_visual_features(
|
||||
visual_features,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
labels,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
else:
|
||||
# In case input_ids.shape[1] == 1 & past_key_values != None, we are in the case of
|
||||
# generation with cache
|
||||
if past_key_values is not None and input_ids.shape[1] == 1:
|
||||
if pixel_values_images is not None or pixel_values_videos is not None:
|
||||
image_outputs, video_outputs, num_frames = self._get_vision_features(
|
||||
pixel_values_images=pixel_values_images,
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
image_features = video_features = None
|
||||
if image_outputs is not None:
|
||||
image_features = self.multi_modal_projector(image_outputs)
|
||||
if video_outputs is not None:
|
||||
video_features = self.multi_modal_projector(video_outputs)
|
||||
|
||||
if legacy_processing:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in Video-LLaVa should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
if input_ids.shape[1] != 1:
|
||||
for features, frames in ((image_features, 1), (video_features, num_frames)):
|
||||
if features is not None:
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
labels,
|
||||
position_ids,
|
||||
input_ids,
|
||||
) = self._merge_input_ids_with_visual_features(
|
||||
features,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
labels,
|
||||
num_frames=frames,
|
||||
)
|
||||
else:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
@ -577,6 +601,22 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
else:
|
||||
if image_outputs is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
if video_outputs is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@ -586,6 +626,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
@ -626,60 +667,40 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
||||
pixel_values_images=None,
|
||||
pixel_values_videos=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
**kwargs,
|
||||
):
|
||||
if past_key_values is not None:
|
||||
if isinstance(past_key_values, Cache):
|
||||
cache_length = past_key_values.get_seq_length()
|
||||
past_length = past_key_values.seen_tokens
|
||||
else:
|
||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
else:
|
||||
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
|
||||
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
||||
# older attention values, as their corresponding values are not part of the input.
|
||||
if cache_length < past_length and attention_mask is not None:
|
||||
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
||||
|
||||
pixel_values_videos = None
|
||||
pixel_values_images = None
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values_videos": pixel_values_videos,
|
||||
"pixel_values_images": pixel_values_images,
|
||||
}
|
||||
# Trigger the new behavior if we have more than image embeddings seq length tokens for images
|
||||
legacy_processing = input_ids is not None and (
|
||||
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
|
||||
and (input_ids == self.config.video_token_index).sum(1).max() < self.config.video_seq_length
|
||||
)
|
||||
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if legacy_processing:
|
||||
# legacy specific code copied from prev version, we assume that we always have one more new token (assisted decoding doesn't work for VLMs)
|
||||
# if cache_position[0] != 0:
|
||||
# model_inputs["input_ids"] = model_inputs["input_ids"][:, -1:]
|
||||
# if "position_ids" in model_inputs:
|
||||
# model_inputs["position_ids"] = model_inputs["position_ids"][:, -1:]
|
||||
|
||||
model_inputs["pixel_values_images"] = pixel_values_images
|
||||
model_inputs["pixel_values_videos"] = pixel_values_videos
|
||||
|
||||
elif cache_position[0] == 0:
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
model_inputs["pixel_values_images"] = pixel_values_images
|
||||
model_inputs["pixel_values_videos"] = pixel_values_videos
|
||||
|
||||
return model_inputs
|
||||
|
||||
def _reorder_cache(self, *args, **kwargs):
|
||||
|
@ -19,10 +19,13 @@ Processor class for VideoLlava.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput
|
||||
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from ...utils import TensorType
|
||||
from ...utils import TensorType, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VideoLlavaProcessor(ProcessorMixin):
|
||||
@ -37,16 +40,39 @@ class VideoLlavaProcessor(ProcessorMixin):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
patch_size (`int`, *optional*):
|
||||
Patch size from the vision tower.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Should be the same as in the model's config.
|
||||
image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
Special token used to denote image location.
|
||||
video_token (`str`, *optional*, defaults to `"<video>"`):
|
||||
Special token used to denote video location.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token", "video_token"]
|
||||
image_processor_class = "VideoLlavaImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
patch_size=None,
|
||||
vision_feature_select_strategy=None,
|
||||
image_token="<image>", # set the default and let users change if they have peculiar special tokens in rare cases
|
||||
video_token="<video>",
|
||||
chat_template=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.image_token = image_token
|
||||
self.video_token = video_token
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
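An illustrative usage sketch (not part of the diff): once `patch_size` and `vision_feature_select_strategy` are set on the processor, `__call__` below expands the placeholder tokens itself and the model-side expansion becomes unnecessary. The checkpoint name is an assumption made for illustration.

from transformers import VideoLlavaProcessor

processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"
# prompts passed to `processor(...)` now come back with "<image>"/"<video>" repeated
# to the full visual sequence length, matching what the model expects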
def __call__(
|
||||
@ -114,8 +140,46 @@ class VideoLlavaProcessor(ProcessorMixin):
|
||||
encoded_images = self.image_processor(images=images, videos=videos, return_tensors=return_tensors)
data.update(encoded_images)

if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")

if encoded_images is not None and self.patch_size is None or self.vision_feature_select_strategy is None:
prompt_strings = text
logger.warning_once(
"Expanding inputs for image tokens in Video-LLaVa should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.44."
)
elif encoded_images is not None:
# Replace the image token with the expanded image token sequence
if "pixel_values" in encoded_images:
height, width = get_image_size(to_numpy_array(encoded_images.get("pixel_values")[0]))
num_frames = 1
else:
one_video = to_numpy_array(encoded_images.get("pixel_values_videos")[0])
height, width = get_image_size(one_video[0])
num_frames = one_video.shape[0] # frame dim is always after batch dim

num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
num_video_tokens = num_image_tokens * num_frames
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1

prompt_strings = []
for sample in text:
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
sample = sample.replace(self.video_token, self.video_token * num_video_tokens)
prompt_strings.append(sample)

text_inputs = self.tokenizer(
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
prompt_strings,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
max_length=max_length,
)
data.update(text_inputs)

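A worked sketch (not part of the diff) of the placeholder arithmetic in `__call__` above, using made-up input sizes that happen to reproduce the config defaults (`image_seq_length=256`, `video_seq_length=2056`):

def num_visual_tokens(height, width, patch_size, num_frames=1, strategy="default"):
    per_image = (height // patch_size) * (width // patch_size) + 1  # +1 for the CLS feature
    per_video = per_image * num_frames  # videos keep the CLS feature for every frame
    if strategy == "default":
        per_image -= 1  # the "default" strategy drops CLS for single images
    return per_image, per_video

assert num_visual_tokens(224, 224, patch_size=14, num_frames=8) == (256, 2056)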
@ -47,6 +47,8 @@ class VipLlavaConfig(PretrainedConfig):
|
||||
The layer norm epsilon of the projector layernorm
|
||||
vision_feature_layers (`List[int]`, *optional*, defaults to `[-2, -5, -8, -11, 6]`):
|
||||
The list of layers to select the vision features from.
|
||||
image_seq_length (`int`, *optional*, defaults to 576):
|
||||
Sequence length of one image embedding.
|
||||
|
||||
Example:
|
||||
|
||||
@ -81,6 +83,7 @@ class VipLlavaConfig(PretrainedConfig):
|
||||
projector_hidden_act="gelu",
|
||||
projector_layernorm_eps=1e-5,
|
||||
vision_feature_layers=[-2, -5, -8, -11, 6],
|
||||
image_seq_length=576,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
@ -88,6 +91,7 @@ class VipLlavaConfig(PretrainedConfig):
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.projector_layernorm_eps = projector_layernorm_eps
|
||||
self.vision_feature_layers = vision_feature_layers
|
||||
self.image_seq_length = image_seq_length
|
||||
self.vision_config = vision_config
|
||||
|
||||
if isinstance(self.vision_config, dict):
|
||||
|
@ -23,7 +23,6 @@ from torch import nn
|
||||
|
||||
from ... import PreTrainedModel
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_outputs import ModelOutput
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@ -231,6 +230,10 @@ VIPLLAVA_INPUTS_DOCSTRING = r"""
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@ -375,6 +378,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, VipLlavaCausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@ -419,26 +423,48 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
||||
vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
legacy_processing = False
|
||||
if inputs_embeds is None:
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# 2. Merge text and images
|
||||
if pixel_values is not None and input_ids.shape[1] != 1:
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
# For VIP-llava, the image features are computed this way
|
||||
# We select the features from index 1: for the layers -2, -5, -8, -11 and 6
|
||||
image_features = [image_outputs.hidden_states[index][:, 1:] for index in vision_feature_layers]
|
||||
image_features = torch.cat(image_features, dim=-1)
|
||||
# if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
|
||||
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
|
||||
# In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
|
||||
legacy_processing = (
|
||||
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
|
||||
) or (input_ids.shape[-1] == 1 and pixel_values is not None)
|
||||
|
||||
image_features = self.multi_modal_projector(image_features)
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
if pixel_values is not None:
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
|
||||
# For VIP-llava, the image features are computed this way
|
||||
# We select the features from index 1: for the layers -2, -5, -8, -11 and 6
|
||||
image_features = [image_outputs.hidden_states[index][:, 1:] for index in vision_feature_layers]
|
||||
image_features = torch.cat(image_features, dim=-1)
|
||||
image_features = self.multi_modal_projector(image_features)
|
||||
|
||||
if legacy_processing:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in VipLLaVa should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
else:
|
||||
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
||||
# generation with cache
|
||||
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
||||
# prefill stage vs decoding stage (legacy behavior copied)
|
||||
if input_ids.shape[1] != 1:
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
else:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
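A minimal sketch (not part of the diff) of the legacy-input heuristic used above: if the prompt carries fewer image tokens than one image embedding is long, the processor did not expand it and the old merging path is taken. The token id 32000 and the length 576 are toy assumptions.

import torch

image_token_index, image_seq_length = 32000, 576

def needs_legacy_expansion(input_ids):
    # fewer image placeholders than one image embedding -> prompt was not expanded in processing
    return (input_ids == image_token_index).sum(1).max() < image_seq_length

legacy = torch.tensor([[1, 32000, 5, 6]])  # a single "<image>" placeholder
expanded = torch.cat([torch.tensor([[1]]), torch.full((1, image_seq_length), image_token_index)], dim=1)
assert needs_legacy_expansion(legacy) and not needs_legacy_expansion(expanded)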
@ -468,6 +494,14 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
else:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@ -477,6 +511,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
@ -510,56 +545,37 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
**kwargs,
|
||||
):
|
||||
if past_key_values is not None:
|
||||
if isinstance(past_key_values, Cache):
|
||||
cache_length = past_key_values.get_seq_length()
|
||||
past_length = past_key_values.seen_tokens
|
||||
else:
|
||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
elif self.config.image_token_index in input_ids:
|
||||
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
|
||||
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
||||
# older attention values, as their corresponding values are not part of the input.
|
||||
if cache_length < past_length and attention_mask is not None:
|
||||
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
# Trigger the new behavior if we have more than image embeddings seq length tokens for images
|
||||
legacy_processing = (
|
||||
input_ids is not None
|
||||
and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
|
||||
)
|
||||
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if legacy_processing:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
elif cache_position[0] == 0:
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
|
||||
return model_inputs
|
||||
|
||||
def _reorder_cache(self, *args, **kwargs):
|
||||
|
@ -1033,3 +1033,33 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
[0, 3, 7, 152, 67, 839, 1],
|
||||
)
|
||||
self.assertEqual(generated_text, "san diego")
|
||||
|
||||
def test_expansion_in_processing(self):
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
image = prepare_img()
|
||||
prompt = "Question: which city is this? Answer:"
|
||||
|
||||
# Make sure we will go the legacy path by setting these args to None
|
||||
processor.num_query_tokens = None
|
||||
model.config.image_token_index = None
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# Add args to the config to trigger new logic when inputs are expanded in processing file
|
||||
processor.num_query_tokens = model.config.num_query_tokens
|
||||
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
|
||||
model.config.image_token_index = len(processor.tokenizer) - 1
|
||||
model.resize_token_embeddings(processor.tokenizer.vocab_size, pad_to_multiple_of=64)
|
||||
|
||||
# Generate again with new inputs
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
predictions_expanded = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
||||
generated_text_expanded = processor.batch_decode(predictions_expanded, skip_special_tokens=True)[0].strip()
|
||||
|
||||
self.assertTrue(generated_text_expanded == generated_text)
|
||||
|
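Condensed from the test above, a sketch of how a user would opt an existing BLIP-2 checkpoint into the new expansion path; apart from using `len(processor.tokenizer)` for the resize, it mirrors the test code.

from transformers import Blip2ForConditionalGeneration, Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

# let the processor insert the query-token placeholders and tell the model which id they use
processor.num_query_tokens = model.config.num_query_tokens
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
model.config.image_token_index = len(processor.tokenizer) - 1
model.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=64)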
@ -637,3 +637,35 @@ class InstructBlipModelIntegrationTest(unittest.TestCase):
|
||||
predictions[0].tolist(), [0, 37, 1023, 753, 3, 9, 2335, 3823, 30, 8, 2608, 28, 3, 9, 1782, 5, 1]
|
||||
)
|
||||
self.assertEqual(generated_text, "The image features a woman sitting on the beach with a dog.")
|
||||
|
||||
def test_expansion_in_processing(self):
|
||||
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
|
||||
model = InstructBlipForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/instructblip-flan-t5-xl",
|
||||
torch_dtype=torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
).to(torch_device)
|
||||
|
||||
image = prepare_img()
|
||||
prompt = "What's in the image?"
|
||||
|
||||
# Make sure we will go the legacy path by setting these args to None
|
||||
processor.num_query_tokens = None
|
||||
model.config.image_token_index = None
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# Add args to the config to trigger new logic when inputs are expanded in processing file
|
||||
processor.num_query_tokens = model.config.num_query_tokens
|
||||
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
|
||||
model.config.image_token_index = len(processor.tokenizer) - 1
|
||||
model.resize_token_embeddings(processor.tokenizer.vocab_size, pad_to_multiple_of=64)
|
||||
|
||||
# Generate again with new inputs
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
predictions_expanded = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
||||
generated_text_expanded = processor.batch_decode(predictions_expanded, skip_special_tokens=True)[0].strip()
|
||||
|
||||
self.assertTrue(generated_text_expanded == generated_text)
|
||||
|
@ -119,7 +119,7 @@ class InstructBlipProcessorTest(unittest.TestCase):
|
||||
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
|
||||
)
|
||||
|
||||
input_str = "lower newer"
|
||||
input_str = ["lower newer"]
|
||||
|
||||
encoded_processor = processor(text=input_str)
|
||||
|
||||
|
@ -583,3 +583,33 @@ class InstructBlipVideoModelIntegrationTest(unittest.TestCase):
|
||||
generated_text,
|
||||
"a baby girl wearing glasses is reading a book on the bed 1080p",
|
||||
)
|
||||
|
||||
def test_expansion_in_processing(self):
|
||||
processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
|
||||
model = InstructBlipVideoForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/instructblip-vicuna-7b", load_in_8bit=True, low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
clip = prepare_video()
|
||||
prompt = "Explain what is happening in this short video."
|
||||
|
||||
# Make sure we will go the legacy path by setting these args to None
|
||||
processor.num_query_tokens = None
|
||||
model.config.video_token_index = None
|
||||
inputs = processor(images=clip, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# Add args to the config to trigger new logic when inputs are expanded in processing file
|
||||
processor.num_query_tokens = model.config.num_query_tokens
|
||||
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<video>"]})
|
||||
model.config.video_token_index = len(processor.tokenizer) - 1
|
||||
model.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=64)
|
||||
|
||||
# Generate again with new inputs
|
||||
inputs = processor(images=clip, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
predictions_expanded = model.generate(**inputs, do_sample=False, max_new_tokens=15)
|
||||
generated_text_expanded = processor.batch_decode(predictions_expanded, skip_special_tokens=True)[0].strip()
|
||||
|
||||
self.assertTrue(generated_text_expanded == generated_text)
|
||||
|
@ -186,6 +186,49 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase
|
||||
self.model_tester = LlavaVisionText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=LlavaConfig, has_text_modality=False)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
@ -471,3 +514,33 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
# Make sure that `generate` works
|
||||
_ = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_expansion_in_processing(self):
|
||||
model_id = "llava-hf/llava-1.5-7b-hf"
|
||||
model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
prompt = "USER: <image>\nDescribe the image:\nASSISTANT:"
|
||||
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||
|
||||
# check processing with expansion of inputs
|
||||
processor.vision_feature_select_strategy = "default"
|
||||
processor.patch_size = 14
|
||||
inputs_expanded = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
|
||||
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593)
|
||||
|
||||
# check processing without expansion of inputs (legacy behavior)
|
||||
processor.vision_feature_select_strategy = None
|
||||
processor.patch_size = None
|
||||
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
|
||||
self.assertTrue(inputs.input_ids.shape[-1] == 18)
|
||||
|
||||
# generate exactly 20 tokens
|
||||
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
|
||||
output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)
|
||||
|
||||
# check that both inputs are handled correctly and generate the same output
|
||||
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
|
||||
|
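A quick arithmetic check (not part of the diff) of the 593-token figure in the test above, assuming the llava-1.5 checkpoint uses a 336x336 CLIP ViT-L/14 vision tower:

patches_per_side = 336 // 14          # 24 patches per side
image_tokens = patches_per_side ** 2  # 576 patch features; the "default" strategy drops the CLS token
legacy_prompt_length = 18             # prompt tokenized with a single "<image>" placeholder
expanded_length = legacy_prompt_length - 1 + image_tokens  # the one placeholder becomes 576 tokens
assert expanded_length == 593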
@ -237,6 +237,49 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
@ -505,3 +548,33 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
output_train = model(**inputs_batched, output_hidden_states=True)
|
||||
self.assertTrue((output_train.hidden_states[0][0, -1414:, ...] == 0).all().item())
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_expansion_in_processing(self):
|
||||
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
|
||||
model = LlavaNextForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
prompt = "USER: <image>\nDescribe the image:\nASSISTANT:"
|
||||
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||
|
||||
# check processing with expansion of inputs
|
||||
processor.vision_feature_select_strategy = "default"
|
||||
processor.patch_size = 14
|
||||
inputs_expanded = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
|
||||
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2356)
|
||||
|
||||
# check processing without expansion of inputs (legacy behavior)
|
||||
processor.vision_feature_select_strategy = None
|
||||
processor.patch_size = None
|
||||
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
|
||||
self.assertTrue(inputs.input_ids.shape[-1] == 17)
|
||||
|
||||
# generate exactly 20 tokens
|
||||
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
|
||||
output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)
|
||||
|
||||
# check that both inputs are handled correctly and generate the same output
|
||||
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
|
||||
|
@ -252,8 +252,8 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
# overwrite because llava can't support both inputs_embeds and pixel values at input
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@ -274,6 +274,29 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
del inputs["pixel_values_videos"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
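The two overridden tests above pin down the same contract for these LVLMs: pixel inputs must be dropped because `inputs_embeds` cannot be combined with pixel values, and once they are dropped, feeding precomputed embeddings must produce exactly the same outputs as feeding `input_ids`. A minimal, self-contained sketch of that contract with a toy module (plain PyTorch, not the Transformers API; `ToyLM` is a made-up class for illustration):

    import torch
    import torch.nn as nn

    class ToyLM(nn.Module):
        # Toy stand-in that accepts either token ids or precomputed embeddings.
        def __init__(self, vocab_size=32, hidden=8):
            super().__init__()
            self.wte = nn.Embedding(vocab_size, hidden)
            self.head = nn.Linear(hidden, vocab_size)

        def forward(self, input_ids=None, inputs_embeds=None):
            if inputs_embeds is None:
                inputs_embeds = self.wte(input_ids)
            return self.head(inputs_embeds)

    model = ToyLM().eval()
    ids = torch.randint(0, 32, (1, 5))
    with torch.no_grad():
        out_ids = model(input_ids=ids)
        out_embeds = model(inputs_embeds=model.wte(ids))
    assert torch.allclose(out_ids, out_embeds)  # same logits either way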
@ -487,3 +510,31 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
        with torch.no_grad():
            output_train = model(**inputs_batched, output_hidden_states=True)
            self.assertTrue((output_train.hidden_states[0][0, -1482:, ...] == 0).all().item())

    @slow
    @require_bitsandbytes
    def test_expansion_in_processing(self):
        model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"
        model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            "llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True
        )
        processor = AutoProcessor.from_pretrained(model_id)

        # check processing with expansion of inputs
        processor.vision_feature_select_strategy = "default"
        processor.patch_size = 14
        inputs_expanded = processor(self.prompt_video, videos=[self.video], return_tensors="pt").to(torch_device)
        self.assertTrue(inputs_expanded.input_ids.shape[-1] == 1170)

        # check processing without expansion of inputs (legacy behavior)
        processor.vision_feature_select_strategy = None
        processor.patch_size = None
        inputs = processor(self.prompt_video, videos=[self.video], return_tensors="pt").to(torch_device)
        self.assertTrue(inputs.input_ids.shape[-1] == 19)

        # generate exactly 20 tokens
        output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
        output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)

        # check that both inputs are handled correctly and generate the same output
        self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
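The 1170 vs. 19 assertion follows the same pattern for video: the single `<video>` placeholder is replaced by one token per pooled frame feature. One decomposition consistent with those numbers, assuming 8 sampled frames, a 336-pixel/patch-14 vision tower, and 2x2 pooling of each frame's patch grid (treat these as assumptions used for illustration, not documented constants):

    # Hedged arithmetic sketch; frame count and pooling factor are assumptions.
    frames = 8
    patches_per_frame = (336 // 14) ** 2          # 576 patch features per frame
    pooled_per_frame = patches_per_frame // 4     # 2x2 pooling -> 144 tokens per frame
    video_tokens = frames * pooled_per_frame      # 1152
    print(19 - 1 + video_tokens)                  # 1170: legacy prompt minus <video>, plus features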
@ -53,9 +53,9 @@ class PaliGemmaVisionText2TextModelTester:
        self,
        parent,
        ignore_index=-100,
        image_token_index=98,
        image_token_index=0,
        projector_hidden_act="gelu",
        seq_length=7,
        seq_length=25,
        vision_feature_select_strategy="default",
        vision_feature_layer=-1,
        projection_dim=32,
@ -87,8 +87,8 @@ class PaliGemmaVisionText2TextModelTester:
        is_training=True,
        vision_config={
            "use_labels": True,
            "image_size": 30,
            "patch_size": 2,
            "image_size": 20,
            "patch_size": 5,
            "num_image_tokens": 4,
            "num_channels": 3,
            "is_training": True,
@ -106,6 +106,7 @@ class PaliGemmaVisionText2TextModelTester:
    ):
        self.parent = parent
        self.ignore_index = ignore_index
        # `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        self.vision_feature_select_strategy = vision_feature_select_strategy
@ -157,8 +158,10 @@ class PaliGemmaVisionText2TextModelTester:
        config, pixel_values = config_and_inputs
        input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
        attention_mask = input_ids.ne(1).to(torch_device)
        # setting the 4 first tokens to be image
        input_ids[:, :4] = config.image_token_index
        # set the first 16 tokens to be image tokens, and ensure that no other tokens are image tokens
        # do not change this unless you modified image size or patch size
        input_ids = torch.where(input_ids == config.image_token_index, 2, input_ids)
        input_ids[:, :16] = config.image_token_index
        inputs_dict = {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
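The jump from 4 to 16 image positions above is tied to the new toy vision config: with `image_size=20` and `patch_size=5` the vision tower yields a 4x4 grid of patches, i.e. 16 image features, so exactly the first 16 input ids must carry `config.image_token_index`. A one-liner to check the comment's warning before touching those values (assuming the usual ViT-style patch count of `(image_size // patch_size) ** 2`):

    image_size, patch_size = 20, 5                      # values from the tester's vision_config
    num_image_tokens = (image_size // patch_size) ** 2  # 16
    assert num_image_tokens == 16                       # hence input_ids[:, :16] = config.image_token_index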
@ -185,6 +188,49 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, unittest.Test
        self.model_tester = PaliGemmaVisionText2TextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=PaliGemmaConfig, has_text_modality=False)

    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)

            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            wte = model.get_input_embeddings()
            inputs["inputs_embeds"] = wte(input_ids)

            with torch.no_grad():
                model(**inputs)

    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
    # while some other models require pixel_values to be present
    def test_inputs_embeds_matches_input_ids(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            inputs_embeds = model.get_input_embeddings()(input_ids)

            with torch.no_grad():
                out_ids = model(input_ids=input_ids, **inputs)[0]
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
@ -322,6 +322,51 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
            for key in model_batched_output:
                recursive_check(model_batched_output[key], model_row_output[key], model_name, key)

    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)

            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values_images"]
            del inputs["pixel_values_videos"]

            wte = model.get_input_embeddings()
            inputs["inputs_embeds"] = wte(input_ids)

            with torch.no_grad():
                model(**inputs)

    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
    # while some other models require pixel_values to be present
    def test_inputs_embeds_matches_input_ids(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values_images"]
            del inputs["pixel_values_videos"]

            inputs_embeds = model.get_input_embeddings()(input_ids)

            with torch.no_grad():
                out_ids = model(input_ids=input_ids, **inputs)[0]
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))


@require_torch
class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
@ -545,3 +590,35 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
            labels=input_ids,
        ).loss
        loss.backward()

    @slow
    @require_bitsandbytes
    def test_expansion_in_processing(self):
        model_id = "LanguageBind/Video-LLaVA-7B-hf"
        model = VideoLlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
        processor = VideoLlavaProcessor.from_pretrained(model_id)

        prompt = "USER: <video>Describe the video in details. ASSISTANT:"
        video_file = hf_hub_download(
            repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
        )
        video_file = np.load(video_file)

        # check processing with expansion of inputs
        processor.vision_feature_select_strategy = "default"
        processor.patch_size = 14
        inputs_expanded = processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
        self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2073)

        # check processing without expansion of inputs (legacy behavior)
        processor.vision_feature_select_strategy = None
        processor.patch_size = None
        inputs = processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
        self.assertTrue(inputs.input_ids.shape[-1] == 18)

        # generate exactly 20 tokens
        output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
        output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)

        # check that both inputs are handled correctly and generate the same output
        self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
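The 2073 vs. 18 assertion again encodes placeholder expansion, this time for `<video>`. One decomposition that reproduces the numbers, assuming 8 frames from a 224-pixel, patch-14 vision tower with one extra feature kept per frame (these are assumptions made for illustration, not documented constants of the checkpoint):

    # Hedged sketch that merely reproduces the asserted lengths above.
    frames = 8
    per_frame = (224 // 14) ** 2 + 1   # 256 patch features + 1 extra feature = 257 (assumption)
    video_tokens = frames * per_frame  # 2056
    print(18 - 1 + video_tokens)       # 2073: legacy prompt minus <video>, plus video features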
@ -167,6 +167,49 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestC
        self.model_tester = VipLlavaVisionText2TextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=VipLlavaConfig, has_text_modality=False)

    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)

            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            wte = model.get_input_embeddings()
            inputs["inputs_embeds"] = wte(input_ids)

            with torch.no_grad():
                model(**inputs)

    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
    # while some other models require pixel_values to be present
    def test_inputs_embeds_matches_input_ids(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            inputs_embeds = model.get_input_embeddings()(input_ids)

            with torch.no_grad():
                out_ids = model(input_ids=input_ids, **inputs)[0]
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
@ -260,3 +303,33 @@ class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
            labels=input_ids,
        ).loss
        loss.backward()

    @slow
    @require_bitsandbytes
    def test_expansion_in_processing(self):
        model_id = "llava-hf/vip-llava-7b-hf"
        model = VipLlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
        processor = AutoProcessor.from_pretrained(model_id)

        prompt = "USER: <image>\nDescribe the image:\nASSISTANT:"
        image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
        raw_image = Image.open(requests.get(image_file, stream=True).raw)

        # check processing with expansion of inputs
        processor.vision_feature_select_strategy = "default"
        processor.patch_size = 14
        inputs_expanded = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
        self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593)

        # check processing without expansion of inputs (legacy behavior)
        processor.vision_feature_select_strategy = None
        processor.patch_size = None
        inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
        self.assertTrue(inputs.input_ids.shape[-1] == 18)

        # generate exactly 20 tokens
        output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
        output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)

        # check that both inputs are handled correctly and generate the same output
        self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
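All four `test_expansion_in_processing` variants assert the same relationship: each `<image>`/`<video>` placeholder in the legacy prompt is replaced in the expanded prompt by as many tokens as the model will emit visual features for it. A hypothetical helper (not part of the Transformers API) that states the rule, checked against the VipLlava numbers above, where a 336-pixel, patch-14 tower with the "default" selection strategy yields 576 features per image:

    # Hypothetical helper, for illustration only.
    def expected_expanded_length(legacy_len, num_placeholders, features_per_placeholder):
        # Each placeholder token is replaced by `features_per_placeholder` tokens.
        return legacy_len + num_placeholders * (features_per_placeholder - 1)

    features_per_image = (336 // 14) ** 2                 # 576, assuming a CLIP-336 backbone
    assert expected_expanded_length(18, 1, 576) == 593    # matches the VipLlava assertions above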