🔴 [VLM] Add base model without head (#37033)

* i guess reverted all CdGen classes

* style

* llava onevision

* fix copies

* fix some tests

* some more tests

* dump

* skip these

* nevermind, i am dumb

* revert fix not needed

* fixup

* fixup

* another fixup

* more fixup to make ci finally happy

* fixup after rebasing

* fix qwen tests

* add internVL + typos here and there

* image token index -> id

* style

* fix init weights

* revert blip-2 not supported

* address comments

* fix copies

* revert blip2 test file as well

* as discussed internally, revert back CdGen models

* fix some tests

* fix more tests for compile

* CI red

* fix copies

* enumerate explicitly allowed models

* address comments

* fix tests

* fixup

* style again

* add tests for new model class

* another fixup ( x _ x )

* [fixup] unused attributes can be removed post-deprecation
Raushan Turganbay 2025-05-07 17:47:51 +02:00 committed by GitHub
parent 3fa8d9c20e
commit 17742bd9c8
85 changed files with 7590 additions and 2904 deletions
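
In short, each VLM covered by this PR gains a head-less `XxxModel` base class (vision tower + multi-modal projector + language model, returning hidden states), while `XxxForConditionalGeneration` becomes a thin wrapper that holds that base model plus the `lm_head`. A rough usage sketch, assuming a LLaVA checkpoint; the checkpoint id is illustrative only:

```python
from transformers import LlavaForConditionalGeneration, LlavaModel

ckpt = "llava-hf/llava-1.5-7b-hf"  # illustrative checkpoint id

# New head-less base class: vision tower + projector + language model,
# whose forward() returns hidden states instead of logits.
base = LlavaModel.from_pretrained(ckpt)

# The generation class now wraps the base model and adds the LM head on top.
gen = LlavaForConditionalGeneration.from_pretrained(ckpt)
print(type(gen.model).__name__)           # LlavaModel
print(type(gen.lm_head).__name__)         # Linear
print(type(gen.language_model).__name__)  # old attribute path still resolves via a BC property
```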


@ -102,6 +102,10 @@ response = processor.decode(output_ids, skip_special_tokens=True)
[[autodoc]] AriaTextModel
## AriaModel
[[autodoc]] AriaModel
## AriaTextForCausalLM
[[autodoc]] AriaTextForCausalLM


@ -237,6 +237,10 @@ for i, output in enumerate(batch_outputs):
[[autodoc]] AyaVisionConfig
## AyaVisionModel
[[autodoc]] AyaVisionModel
## AyaVisionForConditionalGeneration
[[autodoc]] AyaVisionForConditionalGeneration


@ -174,6 +174,10 @@ for i, image in enumerate(images['pixel_values']):
[[autodoc]] Emu3TextModel
- forward
## Emu3Model
[[autodoc]] Emu3Model
## Emu3ForCausalLM
[[autodoc]] Emu3ForCausalLM


@ -103,6 +103,10 @@ The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece.
[[autodoc]] FuyuConfig
## FuyuModel
[[autodoc]] FuyuModel
## FuyuForCausalLM
[[autodoc]] FuyuForCausalLM


@ -254,6 +254,10 @@ visualizer("<img>What is shown in this image?")
[[autodoc]] Gemma3TextModel
- forward
## Gemma3Model
[[autodoc]] Gemma3Model
## Gemma3ForCausalLM
[[autodoc]] Gemma3ForCausalLM


@ -277,6 +277,10 @@ alt="drawing" width="600"/>
[[autodoc]] GotOcr2Processor
## GotOcr2Model
[[autodoc]] GotOcr2Model
## GotOcr2ForConditionalGeneration
[[autodoc]] GotOcr2ForConditionalGeneration


@ -69,6 +69,10 @@ The attributes can be obtained from model config, as `model.config.num_query_tok
[[autodoc]] InstructBlipQFormerModel
- forward
## InstructBlipModel
[[autodoc]] InstructBlipModel
## InstructBlipForConditionalGeneration
[[autodoc]] InstructBlipForConditionalGeneration


@ -73,6 +73,10 @@ The attributes can be obtained from model config, as `model.config.num_query_tok
[[autodoc]] InstructBlipVideoQFormerModel
- forward
## InstructBlipVideoModel
[[autodoc]] InstructBlipVideoModel
- forward
## InstructBlipVideoForConditionalGeneration
[[autodoc]] InstructBlipVideoForConditionalGeneration


@ -340,6 +340,11 @@ This example showcases how to handle a batch of chat conversations with interlea
[[autodoc]] InternVLVisionModel
- forward
## InternVLModel
[[autodoc]] InternVLModel
- forward
## InternVLForConditionalGeneration
[[autodoc]] InternVLForConditionalGeneration


@ -256,6 +256,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] LlavaProcessor
## LlavaModel
[[autodoc]] LlavaModel
## LlavaForConditionalGeneration
[[autodoc]] LlavaForConditionalGeneration


@ -315,6 +315,10 @@ model = AutoModelForImageTextToText.from_pretrained(
[[autodoc]] LlavaNextProcessor
## LlavaNextModel
[[autodoc]] LlavaNextModel
## LlavaNextForConditionalGeneration
[[autodoc]] LlavaNextForConditionalGeneration


@ -262,6 +262,10 @@ model = LlavaNextVideoForConditionalGeneration.from_pretrained(
[[autodoc]] LlavaNextVideoImageProcessor
## LlavaNextVideoModel
[[autodoc]] LlavaNextVideoModel
## LlavaNextVideoForConditionalGeneration
[[autodoc]] LlavaNextVideoForConditionalGeneration


@ -313,6 +313,10 @@ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
[[autodoc]] LlavaOnevisionVideoProcessor
## LlavaOnevisionModel
[[autodoc]] LlavaOnevisionModel
## LlavaOnevisionForConditionalGeneration
[[autodoc]] LlavaOnevisionForConditionalGeneration


@ -227,6 +227,9 @@ This example also how to use `BitsAndBytes` to load the model in 4bit quantizati
[[autodoc]] Mistral3Config
## Mistral3Model
[[autodoc]] Mistral3Model
## Mistral3ForConditionalGeneration


@ -130,6 +130,10 @@ print(processor.decode(output[0], skip_special_tokens=True))
[[autodoc]] MllamaTextModel
- forward
## MllamaModel
[[autodoc]] MllamaModel
## MllamaForCausalLM
[[autodoc]] MllamaForCausalLM


@ -174,6 +174,10 @@ visualizer("<img> What is in this image?")
[[autodoc]] PaliGemmaProcessor
## PaliGemmaModel
[[autodoc]] PaliGemmaModel
## PaliGemmaForConditionalGeneration
[[autodoc]] PaliGemmaForConditionalGeneration


@ -240,6 +240,10 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
[[autodoc]] Qwen2_5_VLProcessor
## Qwen2_5_VLTextModel
[[autodoc]] Qwen2_5_VLTextModel
- forward
## Qwen2_5_VLModel


@ -296,6 +296,11 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
[[autodoc]] Qwen2VLProcessor
## Qwen2VLTextModel
[[autodoc]] Qwen2VLTextModel
- forward
## Qwen2VLModel
[[autodoc]] Qwen2VLModel


@ -215,6 +215,10 @@ model = VideoLlavaForConditionalGeneration.from_pretrained(
[[autodoc]] VideoLlavaProcessor
## VideoLlavaModel
[[autodoc]] VideoLlavaModel
## VideoLlavaForConditionalGeneration
[[autodoc]] VideoLlavaForConditionalGeneration


@ -101,6 +101,10 @@ A chat between a curious human and an artificial intelligence assistant. The ass
[[autodoc]] VipLlavaConfig
## VipLlavaModel
[[autodoc]] VipLlavaModel
## VipLlavaForConditionalGeneration
[[autodoc]] VipLlavaForConditionalGeneration


@ -216,6 +216,28 @@ TORCH_INIT_FUNCTIONS = {
"kaiming_normal": nn.init.kaiming_normal,
}
# DO NOT MODIFY, KEPT FOR BC ONLY
VLMS = [
"aria",
"aya_vision",
"emu3",
"fuyu",
"got_ocr2",
"gemma3",
"internvl",
"llava",
"llava_next",
"llava_next_video",
"llava_onevision",
"mistral3",
"mllama",
"paligemma",
"qwen2_vl",
"qwem2_5_vl",
"video_llava",
"vipllava",
]
@contextmanager
def no_init_weights():
@ -1778,6 +1800,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
main_input_name = "input_ids"
model_tags = None
_checkpoint_conversion_mapping = {} # used for BC support in VLMs, not meant to be used by new models
_auto_class = None
_no_split_modules = None
_skip_keys_device_placement = None
@ -3484,6 +3508,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
module_map[name + f".{key}"] = module
state_dict = model_to_save.state_dict()
if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
original_state_dict = {}
for key, value in state_dict.items():
for pattern, replacement in reverse_key_mapping.items():
replacement = replacement.lstrip("^") # strip off un-needed chars and patterns
replacement = re.sub(r"\(.*?\)", "", pattern)
key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:
break
original_state_dict[key] = value
state_dict = original_state_dict
# Translate state_dict from smp to hf if saving with smp >= 1.10
if IS_SAGEMAKER_MP_POST_1_10:
for smp_to_hf, _ in smp.state.module_manager.translate_functions:
@ -4071,7 +4110,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
key_mapping = kwargs.pop("key_mapping", None)
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
key_mapping = kwargs.pop("key_mapping", cls._checkpoint_conversion_mapping)
else:
key_mapping = kwargs.pop("key_mapping", None)
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
_ = kwargs.pop("trust_remote_code", None)


@ -42,7 +42,7 @@ from ...utils import (
replace_return_docstrings,
)
from ...utils.import_utils import is_torch_available
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_aria import AriaConfig, AriaTextConfig
@ -58,7 +58,9 @@ if is_torch_flex_attn_available():
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "AriaTextConfig"
_CONFIG_FOR_DOC = "AriaConfig"
@use_kernel_forward_from_hub("RMSNorm")
@ -659,7 +661,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
"""
config_class = AriaConfig
config_class = AriaTextConfig
base_model_prefix = "model"
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
supports_gradient_checkpointing = True
@ -706,8 +708,8 @@ ARIA_TEXT_START_DOCSTRING = r"""
ARIA_TEXT_START_DOCSTRING,
)
class AriaPreTrainedModel(PreTrainedModel):
config_class = AriaTextConfig
base_model_prefix = "model"
config_class = AriaConfig
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["AriaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
@ -1097,6 +1099,9 @@ class AriaTextModel(AriaTextPreTrainedModel):
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
_CONFIG_FOR_TEXT_DOC = "AriaTextConfig"
class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
"""
Aria model for causal language modeling tasks.
@ -1112,7 +1117,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
config_class = AriaTextConfig
def __init__(self, config: AriaTextConfig):
super().__init__(config)
@ -1141,9 +1145,8 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.model
@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_TEXT_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -1255,7 +1258,7 @@ class AriaCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@ -1267,6 +1270,39 @@ class AriaCausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class AriaModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for Aria outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
ARIA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor`, *optional*):
@ -1320,30 +1356,131 @@ ARIA_START_DOCSTRING = r"""
@add_start_docstrings(
"""Aria model for conditional generation tasks.
This model combines a vision tower, a multi-modal projector, and a language model
to perform tasks that involve both image and text inputs.""",
"""The Aria model which consists of a vision backbone and a language model, without a language modeling head.""",
ARIA_START_DOCSTRING,
)
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_flex_attn = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]
class AriaModel(AriaPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: AriaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = AriaProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_image_features(
self,
pixel_values: torch.FloatTensor,
pixel_mask: torch.FloatTensor = None,
vision_feature_layer: int = -1,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
pixel_mask (`torch.FloatTensor]`, *optional*):
The tensors corresponding to the input image mask.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
)
image_attn_mask = None
if patch_attention_mask is not None:
flattened_mask = patch_attention_mask.flatten(1)
image_attn_mask = torch.logical_not(flattened_mask)
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_mask: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, AriaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
if pixel_values is not None and inputs_embeds.shape[1] != 1:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
image_embeds = input_ids == self.config.image_token_id
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
image_features = self.get_image_features(
pixel_values=pixel_values,
pixel_mask=pixel_mask,
vision_feature_layer=self.config.vision_feature_layer,
)
n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
n_image_features = n_images * n_features_per_image
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
)
output = AriaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
def _create_patch_attention_mask(self, pixel_mask):
if pixel_mask is None:
return None
@ -1360,51 +1497,61 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
)
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
@add_start_docstrings(
"""Aria model for conditional generation tasks.
This model combines a vision tower, a multi-modal projector, and a language model
to perform tasks that involve both image and text inputs.""",
ARIA_START_DOCSTRING,
)
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: AriaConfig):
super().__init__(config)
self.model = AriaModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model
def get_decoder(self):
return self.language_model.get_decoder()
@property
def vision_tower(self):
return self.model.vision_tower
def get_image_features(
self,
pixel_values: torch.FloatTensor,
pixel_mask: Optional[torch.FloatTensor] = None,
vision_feature_layer: int = -1,
):
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
)
image_attn_mask = None
if patch_attention_mask is not None:
flattened_mask = patch_attention_mask.flatten(1)
image_attn_mask = torch.logical_not(flattened_mask)
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_mask: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -1413,10 +1560,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
) -> AriaCausalLMOutputWithPast:
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -1482,37 +1630,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
if pixel_values is not None and inputs_embeds.shape[1] != 1:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
image_embeds = input_ids == self.config.image_token_id
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
image_features = self.get_image_features(
pixel_values=pixel_values,
pixel_mask=pixel_mask,
vision_feature_layer=self.config.vision_feature_layer,
)
n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
n_image_features = n_images * n_features_per_image
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs: CausalLMOutputWithPast = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_mask=pixel_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@ -1520,11 +1643,14 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
logits_to_keep=logits_to_keep,
return_dict=return_dict,
cache_position=cache_position,
)
logits = outputs.logits
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
@ -1552,7 +1678,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
logits_to_keep=None,
**kwargs,
):
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -1570,11 +1696,67 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [
"AriaForConditionalGeneration",
"AriaPreTrainedModel",
"AriaTextPreTrainedModel",
"AriaTextModel",
"AriaModel",
"AriaTextForCausalLM",
]
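
Put together, the Aria refactor follows the pattern repeated for every model in this PR: the base class owns the sub-modules and returns hidden states; the generation class owns the head, applies it only to the positions requested via `logits_to_keep`, and re-exposes the old attribute names as read-only properties. A toy, self-contained sketch of that pattern (the `Tiny*` names are made up; this is not the real Aria code):

```python
import torch
from torch import nn

class TinyBaseModel(nn.Module):              # stands in for AriaModel / LlavaModel / ...
    def __init__(self, hidden=8):
        super().__init__()
        self.language_model = nn.Identity()  # stands in for AutoModel.from_config(text_config)

    def forward(self, inputs_embeds):
        return self.language_model(inputs_embeds)  # "last_hidden_state"

class TinyForConditionalGeneration(nn.Module):  # stands in for AriaForConditionalGeneration
    def __init__(self, hidden=8, vocab=32):
        super().__init__()
        self.model = TinyBaseModel(hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    @property
    def language_model(self):                # BC: the old attribute path still resolves
        return self.model.language_model

    def forward(self, inputs_embeds, logits_to_keep=0):
        hidden_states = self.model(inputs_embeds)
        # logits_to_keep=0 -> slice(0, None): keep all positions (training);
        # logits_to_keep=1 -> slice(-1, None): keep only the last position (decoding).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        return self.lm_head(hidden_states[:, slice_indices, :])

m = TinyForConditionalGeneration()
x = torch.randn(2, 5, 8)
print(m(x).shape)                                   # torch.Size([2, 5, 32])
print(m(x, logits_to_keep=1).shape)                 # torch.Size([2, 1, 32])
print(m.language_model is m.model.language_model)   # True
```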


@ -18,7 +18,6 @@ import numpy as np
from ...activations import ACT2FN
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
@ -49,7 +48,7 @@ from ...utils import (
replace_return_docstrings,
)
from ...utils.import_utils import is_torch_available
from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from ..auto import CONFIG_MAPPING, AutoConfig, AutoTokenizer
from ..llama.configuration_llama import LlamaConfig
from ..llama.modeling_llama import (
LlamaDecoderLayer,
@ -59,7 +58,12 @@ from ..llama.modeling_llama import (
LlamaPreTrainedModel,
LlamaRMSNorm,
)
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast
from ..llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
LlavaModelOutputWithPast,
)
from ..llava_next.image_processing_llava_next import divide_to_patches
@ -70,6 +74,11 @@ if is_torch_available():
from torch import nn
_CONFIG_FOR_DOC = "AriaConfig"
_CONFIG_FOR_TEXT_DOC = "AriaTextConfig"
ARIA_TEXT_INPUTS_DOCSTRING = None
def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
"""
Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
@ -1223,7 +1232,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
"""
config_class = AriaConfig
config_class = AriaTextConfig
base_model_prefix = "model"
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
supports_gradient_checkpointing = True
@ -1249,6 +1258,8 @@ class AriaTextPreTrainedModel(PreTrainedModel):
class AriaPreTrainedModel(LlamaPreTrainedModel):
config_class = AriaConfig
base_model_prefix = ""
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = False
@ -1292,7 +1303,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
"""
_tied_weights_keys = ["lm_head.weight"]
config_class = AriaTextConfig
def __init__(self, config: AriaTextConfig):
super().__init__(config)
@ -1303,11 +1313,20 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_TEXT_DOC)
def forward(self, **super_kwargs):
super().forward(self, **super_kwargs)
class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class AriaModelOutputWithPast(LlavaModelOutputWithPast):
pass
ARIA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor`, *optional*):
@ -1360,30 +1379,10 @@ ARIA_START_DOCSTRING = r"""
"""
@add_start_docstrings(
"""Aria model for conditional generation tasks.
This model combines a vision tower, a multi-modal projector, and a language model
to perform tasks that involve both image and text inputs.""",
ARIA_START_DOCSTRING,
)
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_flex_attn = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]
class AriaModel(LlavaModel):
def __init__(self, config: AriaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = AriaProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
self.post_init()
def _create_patch_attention_mask(self, pixel_mask):
if pixel_mask is None:
@ -1401,30 +1400,27 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
)
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
pixel_mask: Optional[torch.FloatTensor] = None,
pixel_mask: torch.FloatTensor = None,
vision_feature_layer: int = -1,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
pixel_mask (`torch.FloatTensor]`, *optional*):
The tensors corresponding to the input image mask.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
@ -1438,14 +1434,94 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_mask: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, AriaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
if pixel_values is not None and inputs_embeds.shape[1] != 1:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
image_embeds = input_ids == self.config.image_token_id
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
image_features = self.get_image_features(
pixel_values=pixel_values,
pixel_mask=pixel_mask,
vision_feature_layer=self.config.vision_feature_layer,
)
n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
n_image_features = n_images * n_features_per_image
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
)
output = AriaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
"""Aria model for conditional generation tasks.
This model combines a vision tower, a multi-modal projector, and a language model
to perform tasks that involve both image and text inputs.""",
ARIA_START_DOCSTRING,
)
class AriaForConditionalGeneration(LlavaForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_mask: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -1454,10 +1530,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
) -> AriaCausalLMOutputWithPast:
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -1523,37 +1600,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
if pixel_values is not None and inputs_embeds.shape[1] != 1:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
image_embeds = input_ids == self.config.image_token_id
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
image_features = self.get_image_features(
pixel_values=pixel_values,
pixel_mask=pixel_mask,
vision_feature_layer=self.config.vision_feature_layer,
)
n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
n_image_features = n_images * n_features_per_image
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs: CausalLMOutputWithPast = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_mask=pixel_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@ -1561,11 +1613,14 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
logits_to_keep=logits_to_keep,
return_dict=return_dict,
cache_position=cache_position,
)
logits = outputs.logits
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
@ -1593,7 +1648,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
logits_to_keep=None,
**kwargs,
):
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -1621,5 +1676,6 @@ __all__ = [
"AriaPreTrainedModel",
"AriaTextPreTrainedModel",
"AriaTextModel",
"AriaModel",
"AriaTextForCausalLM",
]
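
The text/image merge that both the modeling and modular files repeat boils down to building a boolean mask over image-token positions and filling those positions with the projected image features via `masked_scatter`. A toy illustration (token id 99 stands in for `config.image_token_id`):

```python
import torch

inputs_embeds = torch.zeros(1, 6, 4)              # (batch, seq_len, hidden)
input_ids = torch.tensor([[5, 7, 99, 99, 8, 9]])  # 99 = placeholder image token id
image_features = torch.ones(1, 2, 4)              # 2 projected image "patches"

# Expand the token mask to the embedding dimension, then scatter the image
# features into the masked positions in order.
special_image_mask = (input_ids == 99).unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(special_image_mask, image_features)

print(merged[0, :, 0])   # tensor([0., 0., 1., 1., 0., 0.])
```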


@ -35,10 +35,11 @@ MODEL_MAPPING_NAMES = OrderedDict(
("albert", "AlbertModel"),
("align", "AlignModel"),
("altclip", "AltCLIPModel"),
("aria", "AriaForConditionalGeneration"),
("aria", "AriaModel"),
("aria_text", "AriaTextModel"),
("audio-spectrogram-transformer", "ASTModel"),
("autoformer", "AutoformerModel"),
("aya_vision", "AyaVisionModel"),
("bamba", "BambaModel"),
("bark", "BarkModel"),
("bart", "BartModel"),
@ -108,6 +109,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("efficientformer", "EfficientFormerModel"),
("efficientnet", "EfficientNetModel"),
("electra", "ElectraModel"),
("emu3", "Emu3Model"),
("encodec", "EncodecModel"),
("ernie", "ErnieModel"),
("ernie_m", "ErnieMModel"),
@ -121,14 +123,16 @@ MODEL_MAPPING_NAMES = OrderedDict(
("focalnet", "FocalNetModel"),
("fsmt", "FSMTModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("fuyu", "FuyuModel"),
("gemma", "GemmaModel"),
("gemma2", "Gemma2Model"),
("gemma3", "Gemma3Model"),
("gemma3_text", "Gemma3TextModel"),
("git", "GitModel"),
("glm", "GlmModel"),
("glm4", "Glm4Model"),
("glpn", "GLPNModel"),
("got_ocr2", "GotOcr2ForConditionalGeneration"),
("got_ocr2", "GotOcr2Model"),
("gpt-sw3", "GPT2Model"),
("gpt2", "GPT2Model"),
("gpt_bigcode", "GPTBigCodeModel"),
@ -156,6 +160,9 @@ MODEL_MAPPING_NAMES = OrderedDict(
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("instructblip", "InstructBlipModel"),
("instructblipvideo", "InstructBlipVideoModel"),
("internvl", "InternVLModel"),
("internvl_vision", "InternVLVisionModel"),
("jamba", "JambaModel"),
("janus", "JanusModel"),
@ -170,6 +177,10 @@ MODEL_MAPPING_NAMES = OrderedDict(
("lilt", "LiltModel"),
("llama", "LlamaModel"),
("llama4", "Llama4ForConditionalGeneration"),
("llava", "LlavaModel"),
("llava_next", "LlavaNextModel"),
("llava_next_video", "LlavaNextVideoModel"),
("llava_onevision", "LlavaOnevisionModel"),
("longformer", "LongformerModel"),
("longt5", "LongT5Model"),
("luke", "LukeModel"),
@ -189,8 +200,10 @@ MODEL_MAPPING_NAMES = OrderedDict(
("mgp-str", "MgpstrForSceneTextRecognition"),
("mimi", "MimiModel"),
("mistral", "MistralModel"),
("mistral3", "Mistral3Model"),
("mixtral", "MixtralModel"),
("mlcd", "MLCDVisionModel"),
("mllama", "MllamaModel"),
("mobilebert", "MobileBertModel"),
("mobilenet_v1", "MobileNetV1Model"),
("mobilenet_v2", "MobileNetV2Model"),
@ -221,6 +234,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("opt", "OPTModel"),
("owlv2", "Owlv2Model"),
("owlvit", "OwlViTModel"),
("paligemma", "PaliGemmaModel"),
("patchtsmixer", "PatchTSMixerModel"),
("patchtst", "PatchTSTModel"),
("pegasus", "PegasusModel"),
@ -240,11 +254,11 @@ MODEL_MAPPING_NAMES = OrderedDict(
("qdqbert", "QDQBertModel"),
("qwen2", "Qwen2Model"),
("qwen2_5_vl", "Qwen2_5_VLModel"),
("qwen2_5_vl_text", "Qwen2_5_VLModel"),
("qwen2_5_vl_text", "Qwen2_5_VLTextModel"),
("qwen2_audio_encoder", "Qwen2AudioEncoder"),
("qwen2_moe", "Qwen2MoeModel"),
("qwen2_vl", "Qwen2VLModel"),
("qwen2_vl_text", "Qwen2VLModel"),
("qwen2_vl_text", "Qwen2VLTextModel"),
("qwen3", "Qwen3Model"),
("qwen3_moe", "Qwen3MoeModel"),
("recurrent_gemma", "RecurrentGemmaModel"),
@ -306,8 +320,10 @@ MODEL_MAPPING_NAMES = OrderedDict(
("unispeech-sat", "UniSpeechSatModel"),
("univnet", "UnivNetModel"),
("van", "VanModel"),
("video_llava", "VideoLlavaModel"),
("videomae", "VideoMAEModel"),
("vilt", "ViltModel"),
("vipllava", "VipLlavaModel"),
("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
("visual_bert", "VisualBertModel"),
("vit", "ViTModel"),
@ -879,6 +895,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("llama4", "Llama4ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
("mistral3", "Mistral3ForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"),


@ -27,15 +27,16 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_aya_vision import AyaVisionConfig
@ -115,9 +116,8 @@ AYA_VISION_START_DOCSTRING = r"""
)
class AyaVisionPreTrainedModel(PreTrainedModel):
config_class = AyaVisionConfig
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["AyaVisionVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@ -169,7 +169,7 @@ class AyaVisionCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@ -181,7 +181,40 @@ class AyaVisionCausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None
AYA_VISION_INPUTS_DOCSTRING = """
@dataclass
class AyaVisionModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for AyaVision outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
AYA_VISION_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@ -193,8 +226,8 @@ AYA_VISION_INPUTS_DOCSTRING = """
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`GotOcr2ImageProcessor.__call__`] for details. [`AyaVisionProcessor`] uses
[`GotOcr2ImageProcessor`] for processing images.
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`AyaVisionProcessor`] uses
[`CLIPImageProcessor`] for processing images).
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@ -259,23 +292,18 @@ AYA_VISION_INPUTS_DOCSTRING = """
@add_start_docstrings(
"""The AyaVision model which consists of a vision backbone and a language model.""",
"""The AyaVision model which consists of a vision backbone and a language model, without a language modeling head.""",
AYA_VISION_START_DOCSTRING,
)
class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
class AyaVisionModel(AyaVisionPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: AyaVisionConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = AyaVisionMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@ -284,18 +312,6 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -307,7 +323,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
@ -342,6 +358,140 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
@add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, AyaVisionModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = AyaVisionModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
"""The AyaVision model which consists of a vision backbone and a language model.""",
AYA_VISION_START_DOCSTRING,
)
class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: AyaVisionConfig):
super().__init__(config)
self.model = AyaVisionModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AyaVisionCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -410,7 +560,6 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
>>> gen_tokens = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
>>> processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -425,73 +574,32 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
image_sizes=image_sizes,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return AyaVisionCausalLMOutputWithPast(
loss=loss,
@ -499,7 +607,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@ -515,7 +623,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -532,15 +640,60 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
return model_inputs
def tie_weights(self):
return self.language_model.tie_weights()
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
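The mask helper above is easier to follow on a tiny concrete case. The sketch below re-runs the same steps outside the class, with made-up shapes: two new query tokens at cache positions 2 and 3, a static-cache target length of 4, and the first key position padded out.

```python
import torch

# Tiny worked example of the 4D causal mask construction above.
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length, batch_size = 2, 4, 1
cache_position = torch.tensor([2, 3])                   # the two new tokens sit at positions 2 and 3
attention_mask = torch.tensor([[0.0, 1.0, 1.0, 1.0]])   # 2D padding mask, 0 = padded key

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask == 0, min_dtype
)

# The row for the query at position 2 ends up masking key 0 (padding) and key 3 (future);
# the row for the query at position 3 masks only key 0.
print(causal_mask[0, 0])
```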
__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel"]
__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]

View File

@ -22,6 +22,7 @@ from torch import nn
from transformers.models.llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
LlavaPreTrainedModel,
)
@ -88,21 +89,8 @@ class AyaVisionMultiModalProjector(nn.Module):
return image_features
AYA_VISION_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`AyaVisionConfig`] or [`AyaVisionVisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
AYA_VISION_START_DOCSTRING = None
AYA_VISION_INPUTS_DOCSTRING = None
@add_start_docstrings(
@ -133,81 +121,8 @@ class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
AYA_VISION_INPUTS_DOCSTRING = """
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`GotOcr2ImageProcessor.__call__`] for details. [`AyaVisionProcessor`] uses
[`GotOcr2ImageProcessor`] for processing images.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
class AyaVisionModel(LlavaModel):
pass
@add_start_docstrings(
@ -215,16 +130,6 @@ AYA_VISION_INPUTS_DOCSTRING = """
AYA_VISION_START_DOCSTRING,
)
class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
def tie_weights(self):
return self.language_model.tie_weights()
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -312,4 +217,4 @@ class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
)
__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel"]
__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]

View File

@ -175,8 +175,8 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
self.vocab_size = config.vlm_config.text_config.vocab_size
vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
if vlm.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"vlm.language_model.{k}" for k in vlm.language_model._tied_weights_keys]
if vlm._tied_weights_keys is not None:
self._tied_weights_keys = [f"vlm.{k}" for k in vlm._tied_weights_keys]
self.vlm = vlm
self.embedding_dim = self.config.embedding_dim
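The retrieval wrapper now builds its tied-weight keys from the top-level VLM rather than from `vlm.language_model`. A small illustration of the resulting prefixes; the inner key name is a hypothetical example value, not read from a real checkpoint:

```python
# Mirrors the list comprehension in the diff above.
vlm_tied_keys = ["lm_head.weight"]                      # hypothetical value of vlm._tied_weights_keys
colpali_tied_keys = [f"vlm.{k}" for k in vlm_tied_keys]
print(colpali_tied_keys)  # ['vlm.lm_head.weight'] instead of the old ['vlm.language_model.lm_head.weight']
```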
@ -246,25 +246,25 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
)
def get_input_embeddings(self):
return self.vlm.language_model.get_input_embeddings()
return self.vlm.get_input_embeddings()
def set_input_embeddings(self, value):
self.vlm.language_model.set_input_embeddings(value)
self.vlm.set_input_embeddings(value)
def get_output_embeddings(self):
return self.vlm.language_model.get_output_embeddings()
return self.vlm.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.vlm.language_model.set_output_embeddings(new_embeddings)
self.vlm.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.vlm.language_model.set_decoder(decoder)
self.vlm.set_decoder(decoder)
def get_decoder(self):
return self.vlm.language_model.get_decoder()
return self.vlm.get_decoder()
def tie_weights(self):
return self.vlm.language_model.tie_weights()
return self.vlm.tie_weights()
def resize_token_embeddings(
self,
@ -272,7 +272,7 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
pad_to_multiple_of: Optional[int] = None,
mean_resizing: bool = True,
) -> nn.Embedding:
model_embeds = self.vlm.language_model.resize_token_embeddings(
model_embeds = self.vlm.resize_token_embeddings(
new_num_tokens=new_num_tokens,
pad_to_multiple_of=pad_to_multiple_of,
mean_resizing=mean_resizing,

View File

@ -1778,13 +1778,16 @@ EMU3_INPUTS_DOCSTRING = r"""
"""
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_model.lm_head.weight"]
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
class Emu3Model(Emu3PreTrainedModel):
_checkpoint_conversion_mapping = {"text_model.model": "text_model"}
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
def __init__(self, config):
super().__init__(config)
self.text_model = Emu3ForCausalLM._from_config(config.text_config)
self.text_model = Emu3TextModel._from_config(config.text_config)
if self.text_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys]
self.vqmodel = Emu3VQVAE(config.vq_config)
self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
@ -1833,14 +1836,12 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
image = self.vqmodel.decode(image_tokens)
return image
@can_return_tuple
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
@ -1848,10 +1849,99 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
return outputs
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
base_model_prefix = ""
_checkpoint_conversion_mapping = {
"^text_model.model": "model.text_model",
"^vqmodel": "model.vqmodel",
"^text_model.lm_head": "lm_head",
}
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
def __init__(self, config):
super().__init__(config)
self.model = Emu3Model(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
# Make modules available through the conditional class for BC
@property
def text_model(self):
return self.model.text_model
@property
def vqmodel(self):
return self.model.vqmodel
@can_return_tuple
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> CausalLMOutputWithPast:
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -1906,25 +1996,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
return self.text_model(
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@ -1933,8 +2007,25 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
@ -1968,5 +2059,68 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["Emu3ForConditionalGeneration", "Emu3ForCausalLM", "Emu3TextModel", "Emu3PreTrainedModel", "Emu3VQVAE"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [
"Emu3ForConditionalGeneration",
"Emu3ForCausalLM",
"Emu3TextModel",
"Emu3PreTrainedModel",
"Emu3VQVAE",
"Emu3Model",
]
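As with the other models in this PR, the generative class is now a thin head over a reusable base. A minimal loading sketch follows; the checkpoint id is a placeholder and the dtype choice is arbitrary:

```python
import torch
from transformers import Emu3ForConditionalGeneration, Emu3Model

ckpt = "BAAI/Emu3-Chat-hf"  # placeholder checkpoint id, substitute a real one
base = Emu3Model.from_pretrained(ckpt, torch_dtype=torch.bfloat16)  # text backbone + VQ-VAE, no lm_head
lm = Emu3ForConditionalGeneration.from_pretrained(ckpt, torch_dtype=torch.bfloat16)

assert isinstance(lm.model, Emu3Model)
# BC property, but note it now resolves to the headless Emu3TextModel
# instead of the old Emu3ForCausalLM submodule.
assert lm.text_model is lm.model.text_model
```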

View File

@ -1121,13 +1121,16 @@ class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
super().forward()
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_model.lm_head.weight"]
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
class Emu3Model(Emu3PreTrainedModel):
_checkpoint_conversion_mapping = {"text_model.model": "text_model"}
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
def __init__(self, config):
super().__init__(config)
self.text_model = Emu3ForCausalLM._from_config(config.text_config)
self.text_model = Emu3TextModel._from_config(config.text_config)
if self.text_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys]
self.vqmodel = Emu3VQVAE(config.vq_config)
self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
@ -1176,14 +1179,12 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
image = self.vqmodel.decode(image_tokens)
return image
@can_return_tuple
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
@ -1191,10 +1192,99 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
return outputs
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
base_model_prefix = ""
_checkpoint_conversion_mapping = {
"^text_model.model": "model.text_model",
"^vqmodel": "model.vqmodel",
"^text_model.lm_head": "lm_head",
}
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
def __init__(self, config):
super().__init__(config)
self.model = Emu3Model(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
# Make modules available through the conditional class for BC
@property
def text_model(self):
return self.model.text_model
@property
def vqmodel(self):
return self.model.vqmodel
@can_return_tuple
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> CausalLMOutputWithPast:
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -1249,25 +1339,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
return self.text_model(
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@ -1276,8 +1350,25 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
@ -1311,6 +1402,62 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [
"Emu3ForConditionalGeneration",
@ -1318,4 +1465,5 @@ __all__ = [
"Emu3TextModel",
"Emu3PreTrainedModel",
"Emu3VQVAE",
"Emu3Model",
]
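The `_checkpoint_conversion_mapping` entries are regex patterns that rename keys from checkpoints saved with the old layout. Assuming the mapping is applied as a per-key regex substitution (the layer and key names below are made up for illustration):

```python
import re

mapping = {
    "^text_model.model": "model.text_model",
    "^vqmodel": "model.vqmodel",
    "^text_model.lm_head": "lm_head",
}

old_keys = [  # hypothetical keys from an old-layout checkpoint
    "text_model.model.layers.0.self_attn.q_proj.weight",
    "vqmodel.encoder.conv_in.weight",
    "text_model.lm_head.weight",
]

for key in old_keys:
    new_key = key
    for pattern, repl in mapping.items():
        new_key = re.sub(pattern, repl, new_key)
    print(f"{key} -> {new_key}")
# text_model.model.layers.0.self_attn.q_proj.weight -> model.text_model.layers.0.self_attn.q_proj.weight
# vqmodel.encoder.conv_in.weight                    -> model.vqmodel.encoder.conv_in.weight
# text_model.lm_head.weight                         -> lm_head.weight
```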

View File

@ -21,9 +21,9 @@ import torch.utils.checkpoint
from torch import nn
from ...generation import GenerationMixin
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...models.auto.modeling_auto import AutoModelForCausalLM
from ...models.auto.modeling_auto import AutoModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_fuyu import FuyuConfig
@ -143,18 +143,17 @@ FUYU_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.",
"""The Fuyu model which consists of a vision backbone and a language model, without a language modeling head.""",
FUYU_START_DOCSTRING,
)
class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
class FuyuModel(FuyuPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: FuyuConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.language_model = AutoModel.from_config(config.text_config)
self.vision_embed_tokens = nn.Linear(
config.patch_size * config.patch_size * config.num_channels, config.hidden_size
)
@ -169,18 +168,6 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def gather_continuous_embeddings(
self,
word_embeddings: torch.Tensor,
@ -224,56 +211,21 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
return output_embeddings
@add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
image_patches: Optional[
torch.Tensor
] = None, # [batch_size, num_total_patches, patch_size_ x patch_size x num_channels ]
image_patches_indices: Optional[torch.Tensor] = None,
input_ids: torch.LongTensor = None,
image_patches: torch.Tensor = None, # [batch_size, num_total_patches, patch_size x patch_size x num_channels]
image_patches_indices: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
Returns:
Examples:
```python
>>> from transformers import FuyuProcessor, FuyuForCausalLM
>>> from PIL import Image
>>> import requests
>>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
>>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
>>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> prompt = "Generate a coco-style caption.\n"
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> outputs = model(**inputs)
>>> generated_ids = model.generate(**inputs, max_new_tokens=7)
>>> generation_text = processor.batch_decode(generated_ids[:, -7:], skip_special_tokens=True)
>>> print(generation_text[0])
A blue bus parked on the side of a road.
```"""
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -327,7 +279,6 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
labels=labels,
use_cache=use_cache,
return_dict=return_dict,
**kwargs,
@ -335,6 +286,139 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
return outputs
@add_start_docstrings(
"Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.",
FUYU_START_DOCSTRING,
)
class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_embed_tokens": "model.vision_embed_tokens",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: FuyuConfig):
super().__init__(config)
self.model = FuyuModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.set_decoder(decoder)
def get_decoder(self):
return self.model.get_decoder()
@add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
image_patches: torch.Tensor = None, # [batch_size, num_total_patches, patch_size x patch_size x num_channels]
image_patches_indices: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Optional[int] = 0,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
Returns:
Examples:
```python
>>> from transformers import FuyuProcessor, FuyuForCausalLM
>>> from PIL import Image
>>> import requests
>>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
>>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
>>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> prompt = "Generate a coco-style caption.\n"
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> outputs = model(**inputs)
>>> generated_ids = model.generate(**inputs, max_new_tokens=7)
>>> generation_text = processor.batch_decode(generated_ids[:, -7:], skip_special_tokens=True)
>>> print(generation_text[0])
A blue bus parked on the side of a road.
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
image_patches=image_patches,
image_patches_indices=image_patches_indices,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
return_dict=return_dict,
# don't pass kwargs because Persimmon-backbone doesn't accept FA2 kwargs yet, TODO: raushan
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
@ -373,4 +457,4 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
return reordered_past
__all__ = ["FuyuForCausalLM", "FuyuPreTrainedModel"]
__all__ = ["FuyuForCausalLM", "FuyuPreTrainedModel", "FuyuModel"]

View File

@ -32,11 +32,12 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -46,7 +47,7 @@ from ...utils import (
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
@ -60,6 +61,39 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Gemma3Config"
@dataclass
class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for Gemma3 outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class Gemma3CausalLMOutputWithPast(ModelOutput):
"""
@ -88,7 +122,7 @@ class Gemma3CausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
"""
@ -480,7 +514,7 @@ GEMMA3_START_DOCSTRING = r"""
)
class Gemma3PreTrainedModel(PreTrainedModel):
config_class = Gemma3Config
base_model_prefix = "language_model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = [
"Gemma3DecoderLayer",
@ -1066,20 +1100,19 @@ class Gemma3MultiModalProjector(nn.Module):
@add_start_docstrings(
"""The GEMMA3 model which consists of a vision backbone and a language model.""",
"""Base Gemma3 model which consists of a vision backbone and a language model withou language modeling head.""",
GEMMA3_START_DOCSTRING,
)
class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
class Gemma3Model(Gemma3PreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: Gemma3Config):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = Gemma3MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
language_model = AutoModelForCausalLM.from_config(config=config.text_config)
if language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
language_model = AutoModel.from_config(config=config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
@ -1091,18 +1124,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def _update_causal_mask(
self,
attention_mask,
@ -1188,11 +1209,10 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
@ -1203,6 +1223,145 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**lm_kwargs,
) -> Union[Tuple, Gemma3ModelOutputWithPast]:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs = self.language_model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
return Gemma3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
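A rough usage sketch of the new headless class; the checkpoint id and prompt format are placeholders, and the image URL is the one used in the Fuyu example above:

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3Model

ckpt = "google/gemma-3-4b-it"  # placeholder checkpoint id, substitute a real one
processor = AutoProcessor.from_pretrained(ckpt)
model = Gemma3Model.from_pretrained(ckpt)

url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="<start_of_image> Describe this image.", return_tensors="pt")

with torch.no_grad():
    out = model(**inputs)

print(out.last_hidden_state.shape)    # (batch, seq_len, hidden_size); no logits without the head
print(out.image_hidden_states.shape)  # projected vision features exposed by Gemma3ModelOutputWithPast
```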
@add_start_docstrings(
"""The Gemma3 model which consists of a vision backbone and a language model.""",
GEMMA3_START_DOCSTRING,
)
class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: Gemma3Config):
super().__init__(config)
self.model = Gemma3Model(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
@ -1260,80 +1419,34 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
```
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# mask out pad-token-ids in labels for BC
if labels is not None and self.pad_token_id in labels:
logger.warning_once(
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
)
labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs: CausalLMOutputWithPast = self.language_model(
attention_mask=causal_mask,
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs.logits
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
@ -1356,13 +1469,17 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@ -1381,7 +1498,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -1401,15 +1518,73 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self._update_causal_mask(
causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs
def tie_weights(self):
return self.language_model.tie_weights()
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
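# --- Illustration only: what the static helper above produces for a tiny case. Entries
# equal to `min_dtype` mark key positions a query may NOT attend to; zeros mark allowed
# positions. The sizes and cache positions below are made up for the demo.
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length, batch_size = 2, 4, 1
cache_position = torch.tensor([2, 3])               # the two new tokens sit at cache slots 2 and 3

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask = causal_mask * (torch.arange(target_length) > cache_position.reshape(-1, 1))
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

# Query at cache slot 2 may attend keys 0..2 (slot 3 is masked); query at slot 3 sees all four.
assert causal_mask[0, 0, 0, 3] == min_dtype and causal_mask[0, 0, 0, :3].eq(0).all()
assert causal_mask[0, 0, 1].eq(0).all()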
__all__ = ["Gemma3PreTrainedModel", "Gemma3TextModel", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration"]
__all__ = [
"Gemma3PreTrainedModel",
"Gemma3TextModel",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3Model",
]
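# --- Illustration only: what the new headless class enables. `Gemma3Model` returns raw
# hidden states (no lm_head), while `Gemma3ForConditionalGeneration` keeps the head and
# `generate()`. The checkpoint path is a placeholder, not a real repo id, and this
# assumes the classes are exported exactly as listed in `__all__` above.
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, Gemma3Model

checkpoint = "path/to/gemma3-checkpoint"                           # hypothetical
processor = AutoProcessor.from_pretrained(checkpoint)
base = Gemma3Model.from_pretrained(checkpoint)                     # -> Gemma3ModelOutputWithPast
full = Gemma3ForConditionalGeneration.from_pretrained(checkpoint)  # -> Gemma3CausalLMOutputWithPast

inputs = processor(text="Describe the weather today.", return_tensors="pt")
hidden = base(**inputs).last_hidden_state                          # (batch, seq, hidden_size)
logits = full(**inputs).logits                                     # (batch, seq, vocab_size)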


@ -26,11 +26,12 @@ import torch.utils.checkpoint
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
@ -50,7 +51,12 @@ from ..gemma2.modeling_gemma2 import (
apply_rotary_pos_emb,
eager_attention_forward,
)
from ..paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from ..paligemma.modeling_paligemma import (
PaligemmaCausalLMOutputWithPast,
PaliGemmaForConditionalGeneration,
PaliGemmaModel,
PaligemmaModelOutputWithPast,
)
from ..siglip import SiglipVisionConfig
@ -302,43 +308,13 @@ class Gemma3Config(PretrainedConfig):
@dataclass
class Gemma3CausalLMOutputWithPast(ModelOutput):
"""
Base class for Gemma3 causal language model (or autoregressive) outputs.
class Gemma3ModelOutputWithPast(PaligemmaModelOutputWithPast):
pass
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class Gemma3CausalLMOutputWithPast(PaligemmaCausalLMOutputWithPast):
pass
class Gemma3TextScaledWordEmbedding(nn.Embedding):
@ -545,7 +521,7 @@ GEMMA3_START_DOCSTRING = None
class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
base_model_prefix = "language_model"
base_model_prefix = ""
_no_split_modules = [
"Gemma3DecoderLayer",
"SiglipVisionEmbeddings",
@ -755,10 +731,7 @@ class Gemma3MultiModalProjector(nn.Module):
return projected_vision_outputs.type_as(vision_outputs)
class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
def tie_weights(self):
return self.language_model.tie_weights()
class Gemma3Model(PaliGemmaModel):
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
Projects the last hidden state from the vision model into language model space.
@ -844,11 +817,10 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
@ -859,6 +831,106 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**lm_kwargs,
) -> Union[Tuple, Gemma3ModelOutputWithPast]:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs = self.language_model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
return Gemma3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
@add_start_docstrings(
"""The Gemma3 model which consists of a vision backbone and a language model.""",
GEMMA3_START_DOCSTRING,
)
class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
@ -916,80 +988,34 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
```
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# mask out pad-token-ids in labels for BC
if labels is not None and self.pad_token_id in labels:
logger.warning_once(
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
)
labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs: CausalLMOutputWithPast = self.language_model(
attention_mask=causal_mask,
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs.logits
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
@ -1012,13 +1038,17 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@ -1037,7 +1067,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -1057,7 +1087,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self._update_causal_mask(
causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
@ -1072,4 +1102,5 @@ __all__ = [
"Gemma3TextModel",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3Model",
]
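# --- Illustration only: a rough sketch of the renaming implied by the
# `_checkpoint_conversion_mapping` patterns above, so old checkpoints (with a top-level
# `language_model`) still load into the new `model.` + `lm_head` layout. The actual
# remapping is handled inside `from_pretrained`; the keys below are made up.
import re

mapping = {
    "^language_model.model": "model.language_model",
    "^vision_tower": "model.vision_tower",
    "^multi_modal_projector": "model.multi_modal_projector",
    "^language_model.lm_head": "lm_head",
}
old_keys = [
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "vision_tower.vision_model.embeddings.patch_embedding.weight",
    "language_model.lm_head.weight",
]
for key in old_keys:
    new_key = key
    for pattern, replacement in mapping.items():
        new_key = re.sub(pattern, replacement, new_key)
    print(f"{key} -> {new_key}")
# language_model.model.*   -> model.language_model.*
# vision_tower.*           -> model.vision_tower.*
# language_model.lm_head.* -> lm_head.*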


@ -28,11 +28,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
@ -40,7 +38,7 @@ from ...utils import (
can_return_tuple,
replace_return_docstrings,
)
from ..auto import AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig
@ -545,7 +543,7 @@ class GotOcr2CausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@ -557,6 +555,39 @@ class GotOcr2CausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class GotOcr2ModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for GotOcr2 outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
GOT_OCR2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@ -575,14 +606,13 @@ GOT_OCR2_START_DOCSTRING = r"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
"The bare GotOcr2 Model outputting raw hidden-states without any specific head on top.",
GOT_OCR2_START_DOCSTRING,
)
class GotOcr2PreTrainedModel(PreTrainedModel):
config_class = GotOcr2Config
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["GotOcr2VisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@ -680,23 +710,18 @@ GOT_OCR2_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The GOT_OCR2 model which consists of a vision backbone and a language model.""",
"""The GotOcr2 model which consists of a vision backbone and a language model, without a language modeling head.""",
GOT_OCR2_START_DOCSTRING,
)
class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
class GotOcr2Model(GotOcr2PreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: GotOcr2Config):
super().__init__(config)
self.vision_tower = GotOcr2VisionEncoder(config.vision_config)
self.multi_modal_projector = GotOcr2MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = config.pad_token_id
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@ -705,18 +730,6 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -732,13 +745,124 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, GotOcr2ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
)
output = GotOcr2ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
"""The GOT_OCR2 model which consists of a vision backbone and a language model.""",
GOT_OCR2_START_DOCSTRING,
)
class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: GotOcr2Config):
super().__init__(config)
self.model = GotOcr2Model(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -747,9 +871,10 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> GotOcr2CausalLMOutputWithPast:
) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -794,37 +919,15 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
"You should keep in mind what features from the module should be used, especially
when you're planning to sell a template."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs: CausalLMOutputWithPast = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@ -832,29 +935,18 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
)
logits = outputs.logits
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return GotOcr2CausalLMOutputWithPast(
loss=loss,
@ -862,7 +954,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@ -878,7 +970,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -895,5 +987,60 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["GotOcr2PreTrainedModel", "GotOcr2ForConditionalGeneration"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"]


@ -14,16 +14,17 @@
# limitations under the License.
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
LlavaModelOutputWithPast,
LlavaPreTrainedModel,
)
from transformers.models.sam.modeling_sam import SamMLPBlock, SamVisionAttention, SamVisionEncoder, SamVisionLayer
@ -36,7 +37,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ..auto import CONFIG_MAPPING, AutoConfig, AutoModelForCausalLM
from ..auto import CONFIG_MAPPING, AutoConfig
if is_vision_available():
@ -278,6 +279,10 @@ class GotOcr2CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class GotOcr2ModelOutputWithPast(LlavaModelOutputWithPast):
pass
class GotOcr2PreTrainedModel(LlavaPreTrainedModel):
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@ -368,22 +373,11 @@ GOT_OCR2_INPUTS_DOCSTRING = r"""
"""
class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
class GotOcr2Model(LlavaModel):
def __init__(self, config: GotOcr2Config):
super().__init__(config)
self.vision_tower = GotOcr2VisionEncoder(config.vision_config)
self.multi_modal_projector = GotOcr2MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = config.pad_token_id
self.post_init()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -399,13 +393,81 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, GotOcr2ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
)
output = GotOcr2ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -414,9 +476,10 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> GotOcr2CausalLMOutputWithPast:
) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -461,37 +524,15 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
"You should keep in mind what features from the module should be used, especially
when you're planning to sell a template."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs: CausalLMOutputWithPast = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@ -499,29 +540,18 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
)
logits = outputs.logits
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return GotOcr2CausalLMOutputWithPast(
loss=loss,
@ -529,7 +559,7 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
@ -537,5 +567,6 @@ __all__ = [
"GotOcr2VisionConfig",
"GotOcr2Config",
"GotOcr2PreTrainedModel",
"GotOcr2Model",
"GotOcr2ForConditionalGeneration",
]
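# --- Illustration only: the causal-LM loss that the removed block computed by hand and
# that the shared `self.loss_function(...)` helper now encapsulates (roughly; padding /
# ignore_index handling may differ). Token values and shapes below are made up.
import torch
import torch.nn as nn

vocab_size = 11
logits = torch.randn(2, 5, vocab_size)                # (batch, seq, vocab)
labels = torch.randint(0, vocab_size, (2, 5))         # (batch, seq)

# Shift so that tokens < n predict n: the logits at position t are scored against label t+1.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = nn.CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))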


@ -28,7 +28,7 @@ from torch import nn
import transformers.models.jamba.modeling_jamba as modeling_jamba
from transformers.activations import ACT2FN
from ...cache_utils import Cache, StaticCache
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_layers import GradientCheckpointingLayer
@ -1511,10 +1511,10 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@ -1525,7 +1525,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (


@ -41,7 +41,7 @@ from ...utils import (
replace_return_docstrings,
torch_int,
)
from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig
@ -315,6 +315,9 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
config_class = InstructBlipConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)
_no_split_modules = [
"InstructBlipQFormerEmbeddings",
@ -339,7 +342,7 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
elif isinstance(module, InstructBlipVisionEmbeddings):
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
elif isinstance(module, InstructBlipForConditionalGeneration):
elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)):
module.query_tokens.data.zero_()
@ -1274,6 +1277,156 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
)
@add_start_docstrings(
"""
InstructBLIP base Model consisting of language model, qformer and vision encoder.
""",
INSTRUCTBLIP_START_DOCSTRING,
)
class InstructBlipModel(InstructBlipPreTrainedModel):
main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
def __init__(self, config: InstructBlipConfig):
super().__init__(config)
self.vision_model = InstructBlipVisionModel(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
self.qformer = InstructBlipQFormerModel(config.qformer_config)
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
self.language_model = AutoModel.from_config(config.text_config)
if self.language_model._no_split_modules is not None:
self._no_split_modules.extend(self.language_model._no_split_modules)
if self.language_model._keep_in_fp32_modules is not None:
self._keep_in_fp32_modules.extend(self.language_model._keep_in_fp32_modules)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
self.language_model.decoder.embed_tokens = self.language_model.shared
def _preprocess_accelerate(self):
r"""
Some pre-processing hacks to make the model `accelerate` compatible. Check
https://github.com/huggingface/transformers/pull/21707 for more details.
"""
hf_device_map = self.hf_device_map
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
# warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
logger.warning(
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
" Please pass a `device_map` that contains `language_model` to remove this warning."
" Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
" more details on creating a `device_map` for large models.",
)
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
@add_start_docstrings_to_model_forward(INSTRUCTBLIP_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.FloatTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
if self.config.use_decoder_only_language_model:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
else:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
if not return_dict:
return (vision_outputs, query_outputs, outputs)
return InstructBlipForConditionalGenerationModelOutput(
vision_outputs=vision_outputs,
qformer_outputs=query_outputs,
language_model_outputs=outputs,
)
@add_start_docstrings(
"""
InstructBLIP Model for generating text given an image and an optional text prompt. The model consists of a vision
@ -1336,11 +1489,13 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
def get_decoder(self):
return self.language_model.get_decoder()
# Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._tie_weights
def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
self.language_model.decoder.embed_tokens = self.language_model.shared
# Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._preprocess_accelerate
def _preprocess_accelerate(self):
r"""
Some pre-processing hacks to make the model `accelerate` compatible. Check
@ -1645,6 +1800,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
__all__ = [
"InstructBlipQFormerModel",
"InstructBlipPreTrainedModel",
"InstructBlipModel",
"InstructBlipForConditionalGeneration",
"InstructBlipVisionModel",
]
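The new `InstructBlipModel` returns the raw vision, Q-Former and language-model outputs instead of logits, which makes it the natural entry point for feature extraction. A minimal sketch, assuming a transformers build that already includes this class and the existing `Salesforce/instructblip-vicuna-7b` checkpoint:

```python
import requests
import torch
from PIL import Image
from transformers import InstructBlipModel, InstructBlipProcessor

checkpoint = "Salesforce/instructblip-vicuna-7b"
processor = InstructBlipProcessor.from_pretrained(checkpoint)
# device_map="auto" requires `accelerate`
model = InstructBlipModel.from_pretrained(checkpoint, torch_dtype=torch.float16, device_map="auto")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="Describe the image.", return_tensors="pt").to(model.device, torch.float16)

with torch.no_grad():
    outputs = model(**inputs)

# No LM head: the output bundles the three sub-module outputs instead of logits.
print(outputs.language_model_outputs.last_hidden_state.shape)
print(outputs.qformer_outputs.last_hidden_state.shape)
```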

View File

@ -45,7 +45,7 @@ from ...utils import (
replace_return_docstrings,
torch_int,
)
from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .configuration_instructblipvideo import (
InstructBlipVideoConfig,
InstructBlipVideoQFormerConfig,
@ -945,6 +945,9 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
config_class = InstructBlipVideoConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)
_no_split_modules = [
"InstructBlipVideoQFormerEmbeddings",
@ -969,7 +972,7 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
elif isinstance(module, InstructBlipVideoVisionEmbeddings):
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
elif isinstance(module, InstructBlipVideoForConditionalGeneration):
elif isinstance(module, (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel)):
module.query_tokens.data.zero_()
@ -1269,6 +1272,166 @@ class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
)
@add_start_docstrings(
"""
InstructBlipVideo base model, consisting of a vision encoder, Q-Former and language model, without a language modeling head on top.
""",
INSTRUCTBLIPVIDEO_START_DOCSTRING,
)
class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
def __init__(self, config: InstructBlipVideoConfig):
super().__init__(config)
self.vision_model = InstructBlipVideoVisionModel(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
self.qformer = InstructBlipVideoQFormerModel(config.qformer_config)
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
self.language_model = AutoModel.from_config(config.text_config)
if self.language_model._no_split_modules is not None:
self._no_split_modules.extend(self.language_model._no_split_modules)
if self.language_model._keep_in_fp32_modules is not None:
self._keep_in_fp32_modules.extend(self.language_model._keep_in_fp32_modules)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
self.language_model.decoder.embed_tokens = self.language_model.shared
def _preprocess_accelerate(self):
r"""
Some pre-processing hacks to make the model `accelerate` compatible. Check
https://github.com/huggingface/transformers/pull/21707 for more details.
"""
hf_device_map = self.hf_device_map
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
# warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
logger.warning(
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
" Please pass a `device_map` that contains `language_model` to remove this warning."
" Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
" more details on creating a `device_map` for large models.",
)
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.FloatTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# step 1: forward the images through the vision encoder,
# we process frames in a batched way and unbatch them back later (videos always have frames=4)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
# unbatch inputs back; each video frame contributes `num_query_tokens` tokens to the sequence length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
if self.config.use_decoder_only_language_model:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
else:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
if not return_dict:
return (vision_outputs, query_outputs, outputs)
return InstructBlipVideoForConditionalGenerationModelOutput(
vision_outputs=vision_outputs,
qformer_outputs=query_outputs,
language_model_outputs=outputs,
)
@add_start_docstrings(
"""
InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision
@ -1682,5 +1845,6 @@ __all__ = [
"InstructBlipVideoVisionModel",
"InstructBlipVideoPreTrainedModel",
"InstructBlipVideoQFormerModel",
"InstructBlipVideoModel",
"InstructBlipVideoForConditionalGeneration",
]
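The video variant only differs in how frames are folded into the batch for the vision encoder and how the projected Q-Former outputs are unfolded again, so the language model sees `num_query_tokens * frames` visual tokens per video. A shape-only sketch of that bookkeeping (illustrative sizes, not taken from a real config):

```python
# Shape bookkeeping in InstructBlipVideoModel.forward (illustrative sizes only).
import torch

batch_size, frames, channels, height, width = 2, 4, 3, 224, 224
num_query_tokens, text_hidden = 32, 4096

pixel_values = torch.randn(batch_size, frames, channels, height, width)

# step 1: fold frames into the batch dimension for the vision encoder
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
print(pixel_values.shape)  # torch.Size([8, 3, 224, 224])

# steps 2-3: each frame yields `num_query_tokens` projected query outputs, which are
# unfolded back so every video contributes num_query_tokens * frames visual tokens.
language_model_inputs = torch.randn(batch_size * frames, num_query_tokens, text_hidden)
language_model_inputs = language_model_inputs.reshape(batch_size, num_query_tokens * frames, -1)
print(language_model_inputs.shape)  # torch.Size([2, 128, 4096])
```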

View File

@ -27,6 +27,7 @@ from transformers.models.instructblip.configuration_instructblip import (
from transformers.models.instructblip.modeling_instructblip import (
InstructBlipForConditionalGeneration,
InstructBlipForConditionalGenerationModelOutput,
InstructBlipModel,
InstructBlipPreTrainedModel,
InstructBlipQFormerModel,
InstructBlipVisionModel,
@ -34,7 +35,7 @@ from transformers.models.instructblip.modeling_instructblip import (
from ...configuration_utils import PretrainedConfig
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from ...utils import logging
from ...utils import add_start_docstrings_to_model_forward, logging
from ..auto import CONFIG_MAPPING, AutoConfig
@ -191,6 +192,110 @@ class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForCondit
pass
INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = None
class InstructBlipVideoModel(InstructBlipModel):
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.FloatTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# step 1: forward the images through the vision encoder,
# we process frames in a batched way and unbatch them back later (videos always have frames=4)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
# unbatch inputs back; each video frame contributes `num_query_tokens` tokens to the sequence length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
if self.config.use_decoder_only_language_model:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
else:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
if not return_dict:
return (vision_outputs, query_outputs, outputs)
return InstructBlipVideoForConditionalGenerationModelOutput(
vision_outputs=vision_outputs,
qformer_outputs=query_outputs,
language_model_outputs=outputs,
)
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
def forward(
self,
@ -508,5 +613,6 @@ __all__ = [
"InstructBlipVideoVisionModel",
"InstructBlipVideoPreTrainedModel",
"InstructBlipVideoQFormerModel",
"InstructBlipVideoModel",
"InstructBlipVideoForConditionalGeneration",
]

View File

@ -31,7 +31,7 @@ from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
@ -45,7 +45,7 @@ from ...utils import (
replace_return_docstrings,
torch_int,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_internvl import InternVLConfig, InternVLVisionConfig
@ -608,14 +608,13 @@ INTERNVL_START_DOCSTRING = r"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
"The bare InternVL Model outputting raw hidden-states without any specific head on top.",
INTERNVL_START_DOCSTRING,
)
class InternVLPreTrainedModel(PreTrainedModel):
config_class = InternVLConfig
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["InternVLVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@ -654,15 +653,13 @@ class InternVLMultiModalProjector(nn.Module):
@dataclass
class InternVLCausalLMOutputWithPast(ModelOutput):
class InternVLModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for InternVL causal language model (or autoregressive) outputs.
Base class for InternVL outputs, with hidden states and attentions.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
@ -681,15 +678,10 @@ class InternVLCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
@ -771,23 +763,18 @@ INTERNVL_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The INTERNVL model which consists of a vision backbone and a language model.""",
"""The InternVL model which consists of a vision backbone and a language model, without a language modeling head.""",
INTERNVL_START_DOCSTRING,
)
class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin):
class InternVLModel(InternVLPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: InternVLConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = InternVLMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@ -796,18 +783,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -853,12 +828,221 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
return vision_features
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, InternVLModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = InternVLModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
"""Perform pixel shuffle downsampling on vision features.
Args:
vision_features (`torch.Tensor`):
Input tensor of shape (batch_size, width, height, channels).
scale_factor (`float`, *optional*, defaults to `0.5`):
Factor by which to downsample. Default is 0.5, which halves the dimensions.
Returns:
vision_features (`torch.Tensor`):
Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
"""
batch_size, width, height, channels = vision_features.size()
if height % scale_factor != 0 or width % scale_factor != 0:
raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
# Reshape to allow downsampling
vision_features = vision_features.view(
batch_size, width, int(height * scale_factor), int(channels / scale_factor)
)
# Permute dimensions to align downsampled axis correctly
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
# Reshape to achieve final downsampled dimensions
vision_features = vision_features.view(
batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
)
# Swap height and width back for proper orientation
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
return vision_features
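With the default `scale_factor=0.5`, `pixel_shuffle` trades a 2x downsampling in each spatial dimension for a 4x wider channel dimension, i.e. four times fewer visual tokens. A standalone re-implementation of the method above, just to confirm the shapes:

```python
# Quick shape check for pixel_shuffle with the default scale_factor=0.5
# (standalone copy of the method above, for illustration only).
import torch

def pixel_shuffle(vision_features: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor:
    batch_size, width, height, channels = vision_features.size()
    vision_features = vision_features.view(
        batch_size, width, int(height * scale_factor), int(channels / scale_factor)
    )
    vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
    vision_features = vision_features.view(
        batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
    )
    return vision_features.permute(0, 2, 1, 3).contiguous()

features = torch.randn(1, 16, 16, 64)   # (batch, width, height, channels)
print(pixel_shuffle(features).shape)     # torch.Size([1, 8, 8, 256]): 4x fewer tokens, 4x wider channels
```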
@dataclass
class InternVLCausalLMOutputWithPast(ModelOutput):
"""
Base class for InternVL causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
@add_start_docstrings(
"""The INTERNVL model which consists of a vision backbone and a language model.""",
INTERNVL_START_DOCSTRING,
)
class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: InternVLConfig):
super().__init__(config)
self.model = InternVLModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional generation class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -872,7 +1056,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: Optional[torch.Tensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, InternVLCausalLMOutputWithPast]:
r"""
@ -925,7 +1109,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
>>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
The images depict the Statue of Liberty and the Golden Gate Bridge.
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -940,73 +1123,32 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
image_sizes=image_sizes,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return InternVLCausalLMOutputWithPast(
loss=loss,
@ -1014,7 +1156,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@ -1030,7 +1172,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -1047,45 +1189,66 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
return model_inputs
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
"""Perform pixel shuffle downsampling on vision features.
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
vision_features (`torch.Tensor`):
Input tensor of shape (batch_size, width, height, channels).
scale_factor (`float`, *optional*, defaults to `0.5`):
Factor by which to downsample. Default is 0.5, which halves the dimensions.
Returns:
vision_features (`torch.Tensor`):
Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
batch_size, width, height, channels = vision_features.size()
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if height % scale_factor != 0 or width % scale_factor != 0:
raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
# Reshape to allow downsampling
vision_features = vision_features.view(
batch_size, width, int(height * scale_factor), int(channels / scale_factor)
)
# Permute dimensions to align downsampled axis correctly
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
# Reshape to achieve final downsampled dimensions
vision_features = vision_features.view(
batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
)
# Swap height and width back for proper orientation
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
return vision_features
return causal_mask
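For reference, the helper above produces a `(batch_size, 1, query_length, key_value_length)` mask in which `0.0` marks attendable cache slots and the dtype minimum marks masked ones. A standalone sketch of the same logic for a tiny decode step (no padding mask, `sequence_length != 1` branch):

```python
# Tiny illustration of the 4D causal mask built above (same logic, small sizes).
import torch

sequence_length, target_length, batch_size = 3, 5, 1
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(2, 2 + sequence_length)  # decoding positions 2, 3, 4

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

# Each query token may attend to every cache slot up to and including its own position.
print(causal_mask[0, 0])
```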
__all__ = [
"InternVLVisionPreTrainedModel",
"InternVLVisionModel",
"InternVLPreTrainedModel",
"InternVLModel",
"InternVLForConditionalGeneration",
]
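The refactor puts the multimodal trunk under `model.model` (an `InternVLModel`) with a separate `lm_head`, while `_checkpoint_conversion_mapping` remaps old checkpoints and the properties above keep the old attribute paths alive. A quick sanity-check sketch, assuming a transformers build that contains this change; the checkpoint id is only illustrative:

```python
import torch
from transformers import InternVLForConditionalGeneration, InternVLModel

checkpoint = "OpenGVLab/InternVL3-1B-hf"  # illustrative checkpoint id
model = InternVLForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

# The trunk is now a head-less InternVLModel, with the LM head kept separately ...
assert isinstance(model.model, InternVLModel)
print(model.lm_head)

# ... and the old attribute paths still resolve through the BC properties.
assert model.language_model is model.model.language_model
assert model.vision_tower is model.model.vision_tower
assert model.multi_modal_projector is model.model.multi_modal_projector
```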

View File

@ -39,7 +39,12 @@ from ...utils import (
from ..clip.modeling_clip import CLIPMLP
from ..janus.modeling_janus import JanusVisionAttention
from ..llama.modeling_llama import LlamaRMSNorm
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel
from ..llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
LlavaPreTrainedModel,
)
from .configuration_internvl import InternVLConfig, InternVLVisionConfig
@ -573,11 +578,7 @@ class InternVLMultiModalProjector(nn.Module):
return hidden_states
class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
class InternVLModel(LlavaModel):
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
"""Perform pixel shuffle downsampling on vision features.
@ -658,6 +659,13 @@ class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
return vision_features
class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(**super_kwargs):
@ -718,5 +726,6 @@ __all__ = [
"InternVLVisionPreTrainedModel",
"InternVLVisionModel",
"InternVLPreTrainedModel",
"InternVLModel",
"InternVLForConditionalGeneration",
]

View File

@ -23,16 +23,17 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_llava import LlavaConfig
@ -44,6 +45,39 @@ _CONFIG_FOR_DOC = "LlavaConfig"
_CHECKPOINT_FOR_DOC = "llava-hf/llava-1.5-7b-hf"
@dataclass
class LlavaModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for Llava outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class LlavaCausalLMOutputWithPast(ModelOutput):
"""
@ -72,7 +106,7 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@ -124,14 +158,13 @@ LLAVA_START_DOCSTRING = r"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
"The bare Llava Model outputting raw hidden-states without any specific head on top.",
LLAVA_START_DOCSTRING,
)
class LlavaPreTrainedModel(PreTrainedModel):
config_class = LlavaConfig
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@ -149,6 +182,9 @@ class LlavaPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
LLAVA_INPUTS_DOCSTRING = r"""
@ -229,23 +265,18 @@ LLAVA_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The LLAVA model which consists of a vision backbone and a language model.""",
"""The Llava model which consists of a vision backbone and a language model, without a language modeling head.""",
LLAVA_START_DOCSTRING,
)
class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
class LlavaModel(LlavaPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: LlavaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@ -254,18 +285,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -277,7 +296,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
@ -312,12 +331,146 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, LlavaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = LlavaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
"""The LLAVA model which consists of a vision backbone and a language model.""",
LLAVA_START_DOCSTRING,
)
class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: LlavaConfig):
super().__init__(config)
self.model = LlavaModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional generation class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -331,7 +484,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: Optional[torch.Tensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
r"""
@ -371,7 +524,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -386,73 +538,32 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
image_sizes=image_sizes,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaCausalLMOutputWithPast(
loss=loss,
@@ -460,7 +571,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
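For orientation: the manual label shifting and `nn.CrossEntropyLoss` that used to live in this forward are now delegated to `self.loss_function`. Below is a minimal standalone sketch, with toy sizes and plain PyTorch, of the causal-LM objective that call encapsulates; the library helper additionally ignores label positions set to -100, which this sketch omits.
import torch
import torch.nn.functional as F
# Toy sizes, purely illustrative; this is not the library implementation.
batch, seq_len, vocab = 2, 6, 50
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
# Causal LM objective: the logits at position t are scored against the label at position t + 1.
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss.item())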
def prepare_inputs_for_generation(
@@ -476,7 +587,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -493,5 +604,61 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel", "LlavaModel"]

View File

@@ -26,16 +26,17 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_llava_next import LlavaNextConfig
@@ -151,6 +152,39 @@ def unpad_image(tensor, original_size):
return unpadded_tensor
@dataclass
class LlavaNextModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for Llava outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model, produced by the vision encoder after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class LlavaNextCausalLMOutputWithPast(ModelOutput):
"""
@@ -237,9 +271,9 @@ LLAVA_NEXT_START_DOCSTRING = r"""
)
class LlavaNextPreTrainedModel(PreTrainedModel):
config_class = LlavaNextConfig
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaNextVisionAttention"]
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@@ -254,7 +288,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, LlavaNextForConditionalGeneration):
elif isinstance(module, LlavaNextModel):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
@@ -340,10 +374,12 @@ LLAVA_NEXT_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
"""The Llava-Next model which consists of a vision backbone and a language model without language modeling head.""",
LLAVA_NEXT_START_DOCSTRING,
)
class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
class LlavaNextModel(LlavaNextPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: LlavaNextConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
@@ -353,48 +389,16 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.language_model = AutoModel.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.post_init()
@property
def padding_side(self):
return self._padding_side
@padding_side.setter
def padding_side(self, padding_side: str):
if padding_side not in ["left", "right"]:
raise ValueError(f"{padding_side} is not `left` or `right`.")
self._padding_side = padding_side
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
def get_decoder(self):
return self.language_model.get_decoder()
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -524,12 +528,152 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
@add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
) -> Union[Tuple, LlavaNextModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = LlavaNextModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
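With this split, the multimodal backbone can be loaded on its own and returns hidden states rather than logits. A hedged usage sketch, assuming a transformers version that ships `LlavaNextModel`; the checkpoint name, prompt, and blank image below are illustrative assumptions only, and real checkpoints are several GB.
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaNextModel
repo = "llava-hf/llava-v1.6-mistral-7b-hf"  # example checkpoint, not prescribed by this PR
processor = AutoProcessor.from_pretrained(repo)
model = LlavaNextModel.from_pretrained(repo)
image = Image.new("RGB", (336, 336), "white")  # placeholder image purely for illustration
inputs = processor(images=image, text="<image>\nDescribe the image.", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)    # (batch, sequence_length, hidden_size); no lm_head applied
print(outputs.image_hidden_states.shape)  # packed, projected vision features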
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
LLAVA_NEXT_START_DOCSTRING,
)
class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^image_newline": "model.image_newline",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: LlavaNextConfig):
super().__init__(config)
self.model = LlavaNextModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -597,45 +741,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
outputs = self.model(
input_ids,
pixel_values=pixel_values,
image_sizes=image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -645,33 +756,17 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaNextCausalLMOutputWithPast(
loss=loss,
@@ -679,7 +774,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@@ -696,7 +791,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -714,5 +809,61 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["LlavaNextForConditionalGeneration", "LlavaNextPreTrainedModel"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["LlavaNextForConditionalGeneration", "LlavaNextPreTrainedModel", "LlavaNextModel"]

View File

@@ -30,24 +30,65 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_llava_next_video import LlavaNextVideoConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlavaNextVideoConfig"
@dataclass
class LlavaNextVideoModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for Llava outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model, produced by the vision encoder after projecting the last hidden state.
video_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
video_hidden_states of the model, produced by the vision encoder after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
video_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):
"""
@@ -167,6 +208,34 @@ LLAVA_NEXT_VIDEO_START_DOCSTRING = r"""
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAVA_NEXT_VIDEO_START_DOCSTRING,
)
class LlavaNextVideoPreTrainedModel(PreTrainedModel):
config_class = LlavaNextVideoConfig
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, LlavaNextVideoModel):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
@@ -355,38 +424,12 @@ LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
"""The Llava-Next model which consists of a vision backbone and a language model without language modeling head.""",
LLAVA_NEXT_VIDEO_START_DOCSTRING,
)
class LlavaNextVideoPreTrainedModel(PreTrainedModel):
config_class = LlavaNextVideoConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaNextVideoVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, LlavaNextVideoForConditionalGeneration):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
LLAVA_NEXT_VIDEO_START_DOCSTRING,
)
class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
def __init__(
self,
config: LlavaNextVideoConfig,
@@ -399,43 +442,17 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.language_model = AutoModel.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.vision_resampler = LlavaNextVideoPooler(config)
self.post_init()
@property
def padding_side(self):
return self._padding_side
@padding_side.setter
def padding_side(self, padding_side: str):
if padding_side not in ["left", "right"]:
raise ValueError(f"{padding_side} is not `left` or `right`.")
self._padding_side = padding_side
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -564,13 +581,222 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
@add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_values_videos: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
self.vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
self.vision_feature_select_strategy,
image_newline=self.image_newline,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
video_features = self.get_video_features(
pixel_values_videos,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
video_features = [feature.flatten(0, 1) for feature in video_features]
video_feature_lens = [feature.size(0) for feature in video_features]
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = LlavaNextVideoModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)
return output if return_dict else output.to_tuple()
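The image/video embedding merge above relies on `masked_scatter`: placeholder-token positions in the text embeddings are overwritten, in order, by the projected visual features. A tiny self-contained sketch with made-up token ids and sizes:
import torch
image_token_id = 32000                              # hypothetical id, for illustration only
input_ids = torch.tensor([[1, 32000, 32000, 5, 6]])  # invented text/image token sequence
inputs_embeds = torch.zeros(1, 5, 4)                 # (batch, seq_len, hidden), toy hidden size
image_features = torch.arange(8, dtype=torch.float).reshape(2, 4)  # 2 image tokens worth of features
# Broadcast the boolean image-token mask over the hidden dimension, then scatter the features in.
special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, 1])  # tensor([0., 1., 2., 3.]) - the first image feature row landed at the first placeholder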
def get_video_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
):
"""
Obtains video last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`):
The tensors corresponding to the input video.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
video_features (List[`torch.Tensor`]): List of video feature tensors, each containing the visual features of all
patches and of shape `(num_videos, video_length, embed_dim)`.
"""
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_video_features = video_features.hidden_states[vision_feature_layer]
else:
hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_video_features = torch.cat(hs_pool, dim=-1)
if vision_feature_select_strategy == "default":
selected_video_features = selected_video_features[:, 1:]
elif vision_feature_select_strategy == "full":
selected_video_features = selected_video_features
# Same as image features except that video has pooling layer
video_features = self.vision_resampler(selected_video_features)
video_features = self.multi_modal_projector(video_features)
video_features = torch.split(video_features, frames, dim=0)
return video_features
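A shape-only sketch of what `get_video_features` does with the frame dimension; all sizes are invented, and the vision tower, resampler, and projector are stand-ins simulated with a random tensor.
import torch
# Toy shapes only (hypothetical); the real vision tower and projector are modules.
batch_size, frames, channels, height, width = 2, 8, 3, 336, 336
pixel_values = torch.randn(batch_size, frames, channels, height, width)
# Frames are folded into the batch dimension before the vision tower...
flat = pixel_values.reshape(batch_size * frames, channels, height, width)
print(flat.shape)  # torch.Size([16, 3, 336, 336])
# ...and after pooling/projection the per-frame features are split back per video.
num_patches, hidden = 144, 4096  # made-up sizes for illustration
projected = torch.randn(batch_size * frames, num_patches, hidden)
per_video = torch.split(projected, frames, dim=0)
print(len(per_video), per_video[0].shape)  # 2 torch.Size([8, 144, 4096])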
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
LLAVA_NEXT_VIDEO_START_DOCSTRING,
)
class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^image_newline": "model.image_newline",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: LlavaNextVideoConfig):
super().__init__(config)
self.model = LlavaNextVideoModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_values_videos: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -662,117 +888,47 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER: \nWhat's the content of the image? ASSISTANT: The image shows a red stop sign on a pole, with a traditional Chinese archway (...)"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.vision_feature_layer = (
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
self.vision_feature_select_strategy = (
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
self.vision_feature_select_strategy,
image_newline=self.image_newline,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
video_features = self.get_video_features(
pixel_values_videos,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
video_features = [feature.flatten(0, 1) for feature in video_features]
video_feature_lens = [feature.size(0) for feature in video_features]
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
outputs = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
image_sizes=image_sizes,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaNextVideoCausalLMOutputWithPast(
loss=loss,
@@ -780,8 +936,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
image_hidden_states=outputs.image_hidden_states,
video_hidden_states=outputs.video_hidden_states,
)
def prepare_inputs_for_generation(
@@ -799,7 +955,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
):
# Overwritten -- extra custom processing
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -818,51 +974,60 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
return model_inputs
def get_video_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Obtains video last hidden states from the vision tower and applies multimodal projection.
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`):
The tensors corresponding to the input video.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
video_features (List[`torch.Tensor`]): List of video feature tensors, each containing the visual features of all
patches and of shape `(num_videos, video_length, embed_dim)`.
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_video_features = video_features.hidden_states[vision_feature_layer]
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_video_features = torch.cat(hs_pool, dim=-1)
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if vision_feature_select_strategy == "default":
selected_video_features = selected_video_features[:, 1:]
elif vision_feature_select_strategy == "full":
selected_video_features = selected_video_features
# Same as image features except that video has pooling layer
video_features = self.vision_resampler(selected_video_features)
video_features = self.multi_modal_projector(video_features)
video_features = torch.split(video_features, frames, dim=0)
return video_features
return causal_mask
__all__ = ["LlavaNextVideoForConditionalGeneration", "LlavaNextVideoPreTrainedModel"]
__all__ = ["LlavaNextVideoForConditionalGeneration", "LlavaNextVideoModel", "LlavaNextVideoPreTrainedModel"]

View File

@@ -24,15 +24,19 @@ from torch import nn
from transformers.models.llava_next.modeling_llava_next import (
LlavaNextCausalLMOutputWithPast,
LlavaNextForConditionalGeneration,
LlavaNextModel,
LlavaNextModelOutputWithPast,
LlavaNextMultiModalProjector,
LlavaNextPreTrainedModel,
image_size_to_num_patches,
)
from ...configuration_utils import PretrainedConfig
from ...utils import (
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ..auto import CONFIG_MAPPING, AutoConfig
@@ -40,6 +44,9 @@ from ..auto import CONFIG_MAPPING, AutoConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlavaNextVideoConfig"
class LlavaNextVideoConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlavaNextVideoForConditionalGeneration`]. It is used to instantiate an
@@ -182,6 +189,17 @@ class LlavaNextVideoConfig(PretrainedConfig):
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
@dataclass
class LlavaNextVideoModelOutputWithPast(LlavaNextModelOutputWithPast):
"""
video_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
video_hidden_states of the model, produced by the vision encoder after projecting the last hidden state.
"""
video_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class LlavaNextVideoCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast):
"""
@@ -231,20 +249,7 @@ class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector):
pass
class LlavaNextVideoPreTrainedModel(LlavaNextPreTrainedModel):
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, LlavaNextVideoForConditionalGeneration):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
class LlavaNextVideoModel(LlavaNextModel):
def __init__(self, config: LlavaNextVideoConfig, **super_kwargs):
super().__init__(config, **super_kwargs)
self.vision_resampler = LlavaNextVideoPooler(config)
@@ -358,9 +363,209 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_values_videos: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
self.vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
self.vision_feature_select_strategy,
image_newline=self.image_newline,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
video_features = self.get_video_features(
pixel_values_videos,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
video_features = [feature.flatten(0, 1) for feature in video_features]
video_feature_lens = [feature.size(0) for feature in video_features]
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = LlavaNextVideoModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)
return output if return_dict else output.to_tuple()
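The base model's forward above splices vision features into the text embeddings purely through placeholder tokens and `masked_scatter`. A minimal sketch of that mechanism, with an invented token id and toy sizes (nothing here comes from a real checkpoint):

```python
import torch

# Invented placeholder id and tiny sizes, only to illustrate the merging step above.
IMAGE_TOKEN_ID = 32000
hidden_size = 8

input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 5]])
inputs_embeds = torch.zeros(1, 4, hidden_size)   # stand-in for the text embeddings
image_features = torch.ones(2, hidden_size)      # one projected vector per placeholder token

# Mark every position holding the placeholder token, broadcast over the hidden dim...
special_image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)

# ...and scatter the vision features into exactly those positions.
merged = inputs_embeds.masked_scatter(special_image_mask, image_features)

print(merged[0, 1])  # the first placeholder slot now carries a vision feature vector
```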
LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`LlavaNextVideoImageProcessor.__call__`] for details. [`LlavaProcessor`] uses
[`LlavaNextVideoImageProcessor`] for processing images.
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
The sizes of the images in the batch, being (height, width) for each image.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
If `"full"`, the full vision features are used.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_values_videos: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -452,117 +657,47 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER: \nWhat's the content of the image? ASSISTANT: The image shows a red stop sign on a pole, with a traditional Chinese archway (...)"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.vision_feature_layer = (
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
self.vision_feature_select_strategy = (
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
self.vision_feature_select_strategy,
image_newline=self.image_newline,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
video_features = self.get_video_features(
pixel_values_videos,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
video_features = [feature.flatten(0, 1) for feature in video_features]
video_feature_lens = [feature.size(0) for feature in video_features]
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
outputs = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
image_sizes=image_sizes,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaNextVideoCausalLMOutputWithPast(
loss=loss,
@ -570,8 +705,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
image_hidden_states=outputs.image_hidden_states,
video_hidden_states=outputs.video_hidden_states,
)
def prepare_inputs_for_generation(
@ -589,7 +724,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
):
# Overwritten -- extra custom processing
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -609,4 +744,9 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
return model_inputs
__all__ = ["LlavaNextVideoConfig", "LlavaNextVideoForConditionalGeneration", "LlavaNextVideoPreTrainedModel"]
__all__ = [
"LlavaNextVideoConfig",
"LlavaNextVideoForConditionalGeneration",
"LlavaNextVideoModel",
"LlavaNextVideoPreTrainedModel", # noqa: F822
]
View File
@ -4,9 +4,25 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_llava_onevision.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import torch
from ...image_processing_utils import BatchFeature, get_patch_output_size, select_best_resolution
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
@ -28,11 +44,9 @@ from ...image_utils import (
make_flat_list_of_images,
)
from ...processing_utils import Unpack
from ...utils import TensorType, add_start_docstrings, is_torch_available, is_torchvision_v2_available
from ...utils import TensorType, add_start_docstrings, is_torchvision_v2_available
if is_torch_available():
import torch
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
View File
@ -1,3 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/llava_onevision/modular_llava_onevision.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_llava_onevision.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
@ -12,7 +18,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Llava-Onevision model."""
import math
from dataclasses import dataclass
@ -20,140 +25,66 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
can_return_tuple,
is_torchdynamo_compiling,
logging,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_llava_onevision import LlavaOnevisionConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlavaNextConfig"
# Copied from transformers.models.llava_next.modeling_llava_next.get_anyres_image_grid_shape
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
@dataclass
class LlavaOnevisionModelOutputWithPast(BaseModelOutputWithPast):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Base class for Llava outputs, with hidden states and attentions.
Args:
image_size (`tuple`):
The size of the input image in the format (width, height).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Returns:
tuple: The shape of the image patch grid in the format (width, height).
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
video_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints should be a list of tuples or lists")
# ! VERY IMPORTANT if image_size is a tensor, it must be converted into a tuple, otherwise it will cause a wrong calculation
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError(
f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
)
image_size = image_size.tolist()
image_hidden_states: Optional[torch.FloatTensor] = None
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
# Copied from transformers.models.llava_next.modeling_llava_next.image_size_to_num_patches
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
"""
Calculate the number of patches after the preprocessing for images of any resolution.
Args:
image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
int: the number of patches
"""
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints should be a list of tuples or lists")
# ! VERY IMPORTANT if image_size is a tensor, it must be converted into a tuple, otherwise it will cause a wrong calculation
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
image_size = image_size.tolist()
best_resolution = select_best_resolution(image_size, grid_pinpoints)
height, width = best_resolution
num_patches = 0
# consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
num_patches += 1
# add the base patch
num_patches += 1
return num_patches
# Copied from transformers.models.llava_next.modeling_llava_next.unpad_image
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (`torch.Tensor`):
The image tensor, assumed to be of shape (num_channels, height, width).
original_size (`tuple`):
The original size of the image (height, width).
Returns:
`torch.Tensor`: The unpadded image tensor.
"""
if not isinstance(original_size, (list, tuple)):
if not isinstance(original_size, (torch.Tensor, np.ndarray)):
raise TypeError(
f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
)
original_size = original_size.tolist()
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
else:
scale_factor = current_height / original_height
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
return unpadded_tensor
video_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
# Copied from transformers.models.llava_next_video.modeling_llava_next_video.LlavaNextVideoCausalLMOutputWithPast with LlavaNextVideo->LlavaOnevision
class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
"""
Base class for LlavaOnevision causal language model (or autoregressive) outputs.
@ -195,10 +126,44 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
video_hidden_states: Optional[torch.FloatTensor] = None
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaOnevision
class LlavaOnevisionPooler(nn.Module):
def __init__(self, config):
super().__init__()
mode = config.spatial_pool_mode
stride = config.spatial_pool_stride
out_channels = getattr(config, "spatial_pool_out_channels", config.vision_config.hidden_size)
self.image_size = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
if mode == "average":
self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
elif mode == "max":
self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
elif mode == "conv":
self.pool = nn.Conv2d(
in_channels=config.vision_config.hidden_size,
out_channels=out_channels,
kernel_size=stride,
stride=stride,
)
else:
raise ValueError(f"Unknown pooling mode: {mode}. Has to be one of [`average`, `max`, `conv`]")
def forward(self, image_features):
ori_width = int(math.sqrt(image_features.shape[1] * self.image_size // self.image_size))
ori_height = int(ori_width * self.image_size // self.image_size)
batch_size, _, dim = image_features.shape
image_features_spatial = image_features.view(batch_size, ori_height, ori_height, dim).permute(0, 3, 1, 2)
image_features_spatial_pool = self.pool(image_features_spatial)
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
class LlavaOnevisionMultiModalProjector(nn.Module):
def __init__(self, config: LlavaOnevisionConfig):
super().__init__()
@ -231,40 +196,118 @@ LLAVA_ONEVISION_START_DOCSTRING = r"""
and behavior.
Parameters:
config ([`LlavaNextConfig`] or [`LlavaNextVisionConfig`]):
config ([`LlavaOnevisionConfig`] or [`LlavaOnevisionVisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare LLaVA-Onevision Model outputting raw hidden-states without any specific head on top.",
LLAVA_ONEVISION_START_DOCSTRING,
)
class LlavaOnevisionPreTrainedModel(PreTrainedModel):
config_class = LlavaOnevisionConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaOnevisionVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = True
_supports_sdpa = True
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextPreTrainedModel._init_weights with LlavaNext->LlavaOnevision
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
Args:
image_size (`tuple`):
The size of the input image in the format (width, height).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, LlavaOnevisionForConditionalGeneration):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints should be a list of tuples or lists")
# ! VERY IMPORTANT if image_size is a tensor, it must be converted into a tuple, otherwise it will cause a wrong calculation
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError(
f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
)
image_size = image_size.tolist()
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
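A quick worked example of the grid arithmetic performed by the helper above (the resolutions are illustrative; 336 matches the common CLIP ViT-L/14-336 input size):

```python
# Suppose select_best_resolution picked a 672x1008 canvas for the input image
# and the vision tower consumes 336x336 crops.
best_height, best_width, patch_size = 672, 1008, 336

grid = (best_height // patch_size, best_width // patch_size)
print(grid)  # (2, 3): the image is tiled into a 2x3 grid of 336x336 crops
```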
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
"""
Calculate the number of patches after the preprocessing for images of any resolution.
Args:
image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
int: the number of patches
"""
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints should be a list of tuples or lists")
# ! VERY IMPORTANT if image_size is a tensor, it must be converted into a tuple, otherwise it will cause a wrong calculation
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
image_size = image_size.tolist()
best_resolution = select_best_resolution(image_size, grid_pinpoints)
height, width = best_resolution
num_patches = 0
# consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
num_patches += 1
# add the base patch
num_patches += 1
return num_patches
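Continuing the same illustrative numbers, `image_size_to_num_patches` counts one crop per 336x336 tile plus the downscaled base image:

```python
import math

# 672x1008 best resolution, 336x336 crops (illustrative numbers as above).
best_height, best_width, patch_size = 672, 1008, 336

# The nested loop above is equivalent to ceil(H / P) * ceil(W / P), plus the base patch.
num_patches = math.ceil(best_height / patch_size) * math.ceil(best_width / patch_size) + 1
print(num_patches)  # 7
```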
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (`torch.Tensor`):
The image tensor, assumed to be of shape (num_channels, height, width).
original_size (`tuple`):
The original size of the image (height, width).
Returns:
`torch.Tensor`: The unpadded image tensor.
"""
if not isinstance(original_size, (list, tuple)):
if not isinstance(original_size, (torch.Tensor, np.ndarray)):
raise TypeError(
f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
)
original_size = original_size.tolist()
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
else:
scale_factor = current_height / original_height
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
return unpadded_tensor
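A small worked example of the unpadding branch above (shapes invented: a 100x200 image that was resized and vertically padded to 224x224):

```python
import math
import torch

padded = torch.zeros(3, 224, 224)             # (channels, height, width) after padding
original_height, original_width = 100, 200    # original (height, width)
current_height, current_width = padded.shape[1:]

# Same arithmetic as unpad_image when the original aspect ratio is wider than the padded one:
scale_factor = current_width / original_width
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded = padded[:, padding : current_height - (padding + r), :]
print(unpadded.shape)  # torch.Size([3, 112, 224])
```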
LLAVA_ONEVISION_INPUTS_DOCSTRING = r"""
@ -279,16 +322,10 @@ LLAVA_ONEVISION_INPUTS_DOCSTRING = r"""
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`LlavaNextImageProcessor.__call__`] for details. [`LlavaProcessor`] uses
[`LlavaNextImageProcessor`] for processing images.
[`AutoImageProcessor`]. See [`LlavaOnevisionImageProcessor.__call__`] for details. [`LlavaProcessor`] uses
[`LlavaOnevisionImageProcessor`] for processing images.
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
The sizes of the images in the batch, being (height, width) for each image.
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, frames, num_channels, image_size, image_size)):
The tensors corresponding to the input videos. Pixel values can be obtained using
[`LlavaNextVideoProcessor`]. See [`LlavaNextVideoProcessor.__call__`] for details. [`LlavaProcessor`] uses
[`LlavaNextVideoProcessor`] for processing videos.
image_sizes_videos (`torch.LongTensor` of shape `(batch_size, frames, 2)`, *optional*):
The sizes of the videos in the batch, being (height, width) for each frame in the video.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@ -335,8 +372,6 @@ LLAVA_ONEVISION_INPUTS_DOCSTRING = r"""
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
If `"full"`, the full vision features are used.
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
Aspect ratio used when processing image features. The default value is "anyres_max_9".
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
@ -356,11 +391,41 @@ LLAVA_ONEVISION_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The LLaVA-Onevision model which consists of a vision backbone and a language model.""",
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAVA_ONEVISION_START_DOCSTRING,
)
class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
def __init__(self, config: LlavaOnevisionConfig):
class LlavaOnevisionPreTrainedModel(PreTrainedModel):
config_class = LlavaOnevisionConfig
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, LlavaOnevisionModel):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
@add_start_docstrings(
"""The Llava-Next model which consists of a vision backbone and a language model without language modeling head.""",
LLAVA_ONEVISION_START_DOCSTRING,
)
class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
@ -369,36 +434,16 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.language_model = AutoModel.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
def get_decoder(self):
return self.language_model.get_decoder()
def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@ -465,20 +510,6 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
return image_features, feature_lens
def apply_pooling(self, image_features):
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
batch_frames, seq_len, dim = image_features.shape
image_features = image_features.view(batch_frames, height, width, -1)
image_features = image_features.permute(0, 3, 1, 2).contiguous()
height, width = image_features.shape[2:]
scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]
image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear")
image_features = image_features.permute(0, 2, 3, 1)
image_features = image_features.view(batch_frames, -1, dim)
return image_features
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -494,7 +525,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
@ -539,59 +570,13 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
def get_video_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
):
"""
Obtains video last hidden states from the vision tower, apply multimodal projection and pooling.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`):
The tensors corresponding to the input video.
vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
video_features (List[`torch.Tensor`]): List of video feature tensors, each containing all the visual features of all patches
and of shape `(num_videos, video_length, embed_dim)`.
"""
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_video_feature = video_features.hidden_states[vision_feature_layer]
else:
hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_video_feature = torch.cat(hs_pool, dim=-1)
if vision_feature_select_strategy == "default":
selected_video_feature = selected_video_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_video_feature = selected_video_feature
video_features = self.multi_modal_projector(selected_video_feature)
video_features = self.apply_pooling(video_features)
video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
return video_features
@add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
pixel_values_videos: torch.FloatTensor = None,
image_sizes_videos: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -600,62 +585,13 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
[`~LlavaOnevisionCausalLMOutputWithPast`] (if `return_dict=True`) or a `tuple`.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> import torch
>>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration
>>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype="float16", device_map="cuda:0")
>>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
>>> conversation = [
... {
... "role": "user",
... "content": [
... {"type": "text", "text": "What is shown in this image?"},
... {"type": "image"},
... ],
... },
... ]
>>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
>>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> raw_image = Image.open(requests.get(image_file, stream=True).raw)
>>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16)
>>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
>>> processor.batch_decode(output, skip_special_tokens=True)[0]
"user\n\nWhat is shown in this image?\nassistant\ncat"
```"""
) -> Union[Tuple, LlavaOnevisionModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -743,35 +679,247 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = LlavaOnevisionModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)
return output if return_dict else output.to_tuple()
def get_video_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
):
"""
Obtains video last hidden states from the vision tower, apply multimodal projection and pooling.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`):
The tensors corresponding to the input video.
vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
video_features (List[`torch.Tensor`]): List of video feature tensors, each containing all the visual features of all patches
and of shape `(num_videos, video_length, embed_dim)`.
"""
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_video_feature = video_features.hidden_states[vision_feature_layer]
else:
hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_video_feature = torch.cat(hs_pool, dim=-1)
if vision_feature_select_strategy == "default":
selected_video_feature = selected_video_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_video_feature = selected_video_feature
video_features = self.multi_modal_projector(selected_video_feature)
video_features = self.apply_pooling(video_features)
video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
return video_features
def apply_pooling(self, image_features):
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
batch_frames, seq_len, dim = image_features.shape
image_features = image_features.view(batch_frames, height, width, -1)
image_features = image_features.permute(0, 3, 1, 2).contiguous()
height, width = image_features.shape[2:]
scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]
image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear")
image_features = image_features.permute(0, 2, 3, 1)
image_features = image_features.view(batch_frames, -1, dim)
return image_features
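For orientation, `apply_pooling` above halves each spatial side of the per-frame patch grid with bilinear interpolation before the frames are flattened back into a token sequence. A toy reproduction with invented sizes:

```python
import math
import torch
import torch.nn.functional as F

# Two frames' worth of projected patch features on a 27x27 grid, toy hidden size 8.
batch_frames, height, width, dim = 2, 27, 27, 8
image_features = torch.randn(batch_frames, height * width, dim)

# Reshape to (N, C, H, W), halve each spatial side, then flatten back to tokens.
spatial = image_features.view(batch_frames, height, width, dim).permute(0, 3, 1, 2)
scaled = F.interpolate(spatial, size=(math.ceil(height / 2), math.ceil(width / 2)), mode="bilinear")
pooled = scaled.permute(0, 2, 3, 1).reshape(batch_frames, -1, dim)
print(pooled.shape)  # torch.Size([2, 196, 8]): 14 * 14 tokens per frame
```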
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
LLAVA_ONEVISION_START_DOCSTRING,
)
class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^image_newline": "model.image_newline",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: LlavaOnevisionConfig):
super().__init__(config)
self.model = LlavaOnevisionModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional generation class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
pixel_values_videos: torch.FloatTensor = None,
image_sizes_videos: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
[`~LlavaOnevisionCausalLMOutputWithPast`] (if `return_dict=True`) or a `tuple`.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> import torch
>>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration
>>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype="float16", device_map="cuda:0")
>>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
>>> conversation = [
... {
... "role": "user",
... "content": [
... {"type": "text", "text": "What is shown in this image?"},
... {"type": "image"},
... ],
... },
... ]
>>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
>>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> raw_image = Image.open(requests.get(image_file, stream=True).raw)
>>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16)
>>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
>>> processor.batch_decode(output, skip_special_tokens=True)[0]
"user\n\nWhat is shown in this image?\nassistant\ncat"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
vision_aspect_ratio = (
vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
)
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_sizes=image_sizes,
image_sizes_videos=image_sizes_videos,
vision_aspect_ratio=vision_aspect_ratio,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaOnevisionCausalLMOutputWithPast(
loss=loss,
@ -779,8 +927,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
image_hidden_states=outputs.image_hidden_states,
video_hidden_states=outputs.video_hidden_states,
)
def prepare_inputs_for_generation(
@ -799,7 +947,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -819,5 +967,60 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
return model_inputs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["LlavaOnevisionForConditionalGeneration", "LlavaOnevisionPreTrainedModel"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
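For intuition, a minimal sketch of the mask this helper builds, using toy sizes and assuming two tokens are already in the cache (the padding-mask merge branch is omitted):
```python
import torch

# Toy reproduction of the core of `_prepare_4d_causal_attention_mask_with_cache_position`
# (assumed sizes; the real call also folds in the 2D padding mask).
batch_size, sequence_length, target_length = 1, 3, 5
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(2, 2 + sequence_length)  # query tokens sit at cache slots 2, 3, 4

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

print((causal_mask == 0).int()[0, 0])  # 1 where a query position is allowed to attend
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```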
__all__ = ["LlavaOnevisionModel", "LlavaOnevisionForConditionalGeneration", "LlavaOnevisionPreTrainedModel"]

View File

@ -1,4 +1,34 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from transformers.models.llava_next.image_processing_llava_next_fast import LlavaNextImageProcessorFast
from transformers.models.llava_next_video.modeling_llava_next_video import (
LlavaNextVideoCausalLMOutputWithPast,
LlavaNextVideoForConditionalGeneration,
LlavaNextVideoModel,
LlavaNextVideoModelOutputWithPast,
LlavaNextVideoPreTrainedModel,
get_anyres_image_grid_shape,
unpad_image,
)
from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING
from ...image_utils import (
@ -6,11 +36,13 @@ from ...image_utils import (
OPENAI_CLIP_STD,
PILImageResampling,
)
from ...utils import add_start_docstrings, logging
from ...utils import add_start_docstrings, can_return_tuple, is_torchdynamo_compiling, logging
logger = logging.get_logger(__name__)
LLAVA_ONEVISION_INPUTS_DOCSTRING = None
@add_start_docstrings(
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
@ -42,4 +74,446 @@ class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
model_input_names = ["pixel_values_videos"]
__all__ = ["LlavaOnevisionImageProcessorFast"]
class LlavaOnevisionModelOutputWithPast(LlavaNextVideoModelOutputWithPast):
pass
class LlavaOnevisionCausalLMOutputWithPast(LlavaNextVideoCausalLMOutputWithPast):
pass
class LlavaOnevisionPreTrainedModel(LlavaNextVideoPreTrainedModel):
pass
class LlavaOnevisionModel(LlavaNextVideoModel):
def __init__(self, config):
super().__init__(config)
del self.vision_resampler
def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
Args:
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
List of image feature tensors, each containing all the visual features of all patches.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
Aspect ratio used when processing image features. The default value is `"anyres_max_9"`.
Returns:
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
feature_lens (`List[int]`)
token length of each image in image_features
"""
new_image_features = []
feature_lens = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
if height * width != base_image_feature.shape[0]:
raise ValueError("The number of patches is not consistent with the image size.")
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
max_num_patches = int(vision_aspect_ratio.strip("anyres_max_"))
channels, curr_height, curr_width = image_feature.shape
ratio = math.sqrt(curr_height * curr_width / (max_num_patches * height**2))
if ratio > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(
image_feature, [int(curr_height // ratio), int(curr_width // ratio)], mode="bilinear"
)[0]
if image_newline is not None:
image_feature = torch.cat(
(
image_feature,
image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device, image_feature.dtype),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if image_newline is not None:
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
new_image_features.append(image_feature)
feature_lens.append(image_feature.size(0))
image_features = torch.cat(new_image_features, dim=0)
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
return image_features, feature_lens
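As a rough illustration of the `anyres_max_*` budget check above (toy numbers; the real grid comes from the vision config):
```python
import math

# Assumed toy setup: a 384px vision tower with 14px patches and the default "anyres_max_9" budget.
vision_aspect_ratio = "anyres_max_9"
max_num_patches = int(vision_aspect_ratio.strip("anyres_max_"))  # -> 9
height = width = 384 // 14                                       # per-tile patch grid (27)
curr_height, curr_width = 4 * height, 4 * width                  # unpadded grid spanning a 4x4 tile layout

ratio = math.sqrt(curr_height * curr_width / (max_num_patches * height**2))
if ratio > 1.1:
    # the feature map is bilinearly downscaled by `ratio` so it roughly fits the patch budget
    curr_height, curr_width = int(curr_height // ratio), int(curr_width // ratio)
print(round(ratio, 3), curr_height, curr_width)  # 1.333 81 81
```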
def apply_pooling(self, image_features):
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
batch_frames, seq_len, dim = image_features.shape
image_features = image_features.view(batch_frames, height, width, -1)
image_features = image_features.permute(0, 3, 1, 2).contiguous()
height, width = image_features.shape[2:]
scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]
image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear")
image_features = image_features.permute(0, 2, 3, 1)
image_features = image_features.view(batch_frames, -1, dim)
return image_features
def get_video_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
):
"""
Obtains video last hidden states from the vision tower, apply multimodal projection and pooling.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`)
The tensors corresponding to the input video.
vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
video_features (List[`torch.Tensor`]): List of video feature tensors, each containing all the visual
features of all patches, of shape `(num_videos, video_length, embed_dim)`.
"""
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_video_feature = video_features.hidden_states[vision_feature_layer]
else:
hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_video_feature = torch.cat(hs_pool, dim=-1)
if vision_feature_select_strategy == "default":
selected_video_feature = selected_video_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_video_feature = selected_video_feature
video_features = self.multi_modal_projector(selected_video_feature)
video_features = self.apply_pooling(video_features)
video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
return video_features
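A quick shape walkthrough of this video path, with assumed toy dimensions (the real patch grid and hidden size come from the checkpoint config):
```python
import math
import torch

batch_size, frames = 2, 8
grid = 384 // 14          # per-frame patch grid after the vision tower, assumed
embed_dim = 1152          # assumed projector output width
video_features = torch.randn(batch_size * frames, grid * grid, embed_dim)

# apply_pooling: bilinear 2x spatial downsampling of each frame's feature map
pooled_side = math.ceil(grid / 2)
pooled = torch.nn.functional.interpolate(
    video_features.view(batch_size * frames, grid, grid, embed_dim).permute(0, 3, 1, 2),
    size=(pooled_side, pooled_side),
    mode="bilinear",
)
pooled = pooled.permute(0, 2, 3, 1).reshape(batch_size * frames, -1, embed_dim)

# frames are then flattened back into one sequence per video
video_features = pooled.reshape(batch_size, frames * pooled.shape[1], -1)
print(video_features.shape)  # torch.Size([2, 1568, 1152])
```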
@add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
pixel_values_videos: torch.FloatTensor = None,
image_sizes_videos: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
) -> Union[Tuple, LlavaOnevisionModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
vision_aspect_ratio = (
vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# Images are processed with Anyres
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
image_newline=self.image_newline,
vision_aspect_ratio=vision_aspect_ratio,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# Videos are simply embedded and further pooled to decrease the sequence length
if pixel_values_videos is not None:
video_features = self.get_video_features(
pixel_values_videos,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
image_newline = (
self.image_newline[None, None, :].repeat(video_features.shape[0], 1, 1).to(video_features.device)
)
video_features = torch.cat((video_features, image_newline), dim=1)
video_features = video_features.flatten(0, 1)
special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = LlavaOnevisionModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)
return output if return_dict else output.to_tuple()
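The `masked_scatter` merge used for both images and videos can be illustrated with a toy example (the token id and sizes below are hypothetical):
```python
import torch

# Placeholder image tokens in the prompt are replaced, in order, by the projected image features.
image_token_id = 32000                      # hypothetical id; the real one comes from the config
input_ids = torch.tensor([[1, 32000, 32000, 5]])
inputs_embeds = torch.zeros(1, 4, 8)        # toy text embeddings
image_features = torch.ones(2, 8)           # one row per image placeholder token

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0.])
```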
class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGeneration):
@can_return_tuple
@add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
pixel_values_videos: torch.FloatTensor = None,
image_sizes_videos: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
[`~LlavaOnevisionCausalLMOutputWithPast`] (if `return_dict=True`) or a `tuple`.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> import torch
>>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration
>>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype="float16", device_map="cuda:0")
>>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
>>> conversation = [
... {
... "role": "user",
... "content": [
... {"type": "text", "text": "What is shown in this image?"},
... {"type": "image"},
... ],
... },
... ]
>>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
>>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> raw_image = Image.open(requests.get(image_file, stream=True).raw)
>>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16)
>>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
>>> processor.batch_decode(output, skip_special_tokens=True)[0]
"user\n\nWhat is shown in this image?\nassistant\ncat"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
vision_aspect_ratio = (
vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
)
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_sizes=image_sizes,
image_sizes_videos=image_sizes_videos,
vision_aspect_ratio=vision_aspect_ratio,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaOnevisionCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
video_hidden_states=outputs.video_hidden_states,
)
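A small sketch of the `logits_to_keep` slicing performed right before the head (toy dimensions, random projection):
```python
import torch
from torch import nn

hidden_states = torch.randn(1, 10, 16)      # (batch, seq_len, hidden) from the base model
lm_head = nn.Linear(16, 100, bias=False)    # toy vocabulary of 100 tokens

logits_to_keep = 1                          # during generation only the last position is needed
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])
print(logits.shape)  # torch.Size([1, 1, 100]); logits_to_keep=0 would keep the full sequence
```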
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
image_sizes=None,
pixel_values_videos=None,
image_sizes_videos=None,
attention_mask=None,
cache_position=None,
logits_to_keep=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**kwargs,
)
if cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
model_inputs["image_sizes"] = image_sizes
model_inputs["pixel_values_videos"] = pixel_values_videos
model_inputs["image_sizes_videos"] = image_sizes_videos
return model_inputs
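The `cache_position[0] == 0` gate above amounts to prefill-only forwarding of vision tensors; a toy sketch of the behaviour:
```python
import torch

# Hypothetical helper mirroring the gate: vision inputs accompany only the first (prefill) step,
# since the placeholder image tokens live in the prompt and are already merged into the KV cache.
def attach_vision_inputs(model_inputs, pixel_values, cache_position):
    if cache_position[0] == 0:
        model_inputs["pixel_values"] = pixel_values
    return model_inputs

dummy_pixels = torch.zeros(1, 3, 384, 384)
print(attach_vision_inputs({}, dummy_pixels, torch.tensor([0])).keys())   # dict_keys(['pixel_values'])
print(attach_vision_inputs({}, dummy_pixels, torch.tensor([17])).keys())  # dict_keys([])
```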
__all__ = [
"LlavaOnevisionImageProcessorFast",
"LlavaOnevisionModel",
"LlavaOnevisionForConditionalGeneration",
"LlavaOnevisionPreTrainedModel",
]

View File

@ -28,19 +28,20 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_mistral3 import Mistral3Config
_CONFIG_FOR_DOC = "Mistral3Config"
_CONFIG_FOR_DOC = "Mistra3Config"
@use_kernel_forward_from_hub("RMSNorm")
@ -156,7 +157,7 @@ class Mistral3CausalLMOutputWithPast(ModelOutput):
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@ -168,6 +169,39 @@ class Mistral3CausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class Mistral3ModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for Mistral3 outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
MISTRAL3_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@ -186,14 +220,13 @@ MISTRAL3_START_DOCSTRING = r"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
"The bare Mistral3 Model outputting raw hidden-states without any specific head on top.",
MISTRAL3_START_DOCSTRING,
)
class Mistral3PreTrainedModel(PreTrainedModel):
config_class = Mistral3Config
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["Mistral3VisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@ -202,12 +235,18 @@ class Mistral3PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of Mistral3 isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, Mistral3RMSNorm):
module.weight.data.fill_(1.0)
@ -290,23 +329,18 @@ MISTRAL3_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The MISTRAL3 model which consists of a vision backbone and a language model.""",
"""The Mistral3 model which consists of a vision backbone and a language model, without a language modeling head.""",
MISTRAL3_START_DOCSTRING,
)
class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin):
class Mistral3Model(Mistral3PreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: Mistral3Config):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = Mistral3MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@ -315,18 +349,6 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -364,64 +386,23 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
return image_features
@add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: Optional[torch.Tensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is the image?The image depicts two cats lying on a pink blanket."
```"""
) -> Union[Tuple, Mistral3ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -468,35 +449,153 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs[0]
output = Mistral3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
"""The MISTRAL3 model which consists of a vision backbone and a language model.""",
MISTRAL3_START_DOCSTRING,
)
class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: Mistral3Config):
super().__init__(config)
self.model = Mistral3Model(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional generation class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is the image?The image depicts two cats lying on a pink blanket."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
image_sizes=image_sizes,
**lm_kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return Mistral3CausalLMOutputWithPast(
loss=loss,
@ -504,7 +603,7 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@ -520,7 +619,7 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -537,5 +636,60 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
return model_inputs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["Mistral3Model", "Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"]

View File

@ -19,14 +19,29 @@ import torch
from torch import nn
from ...activations import ACT2FN
from ...utils import is_torchdynamo_compiling, logging
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel
from ...utils import (
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ..llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
LlavaModelOutputWithPast,
LlavaPreTrainedModel,
)
from ..mistral.modeling_mistral import MistralRMSNorm
from .configuration_mistral3 import Mistral3Config
logger = logging.get_logger(__name__)
MISTRAL3_INPUTS_DOCSTRING = None
_CONFIG_FOR_DOC = "Mistra3Config"
class Mistral3RMSNorm(MistralRMSNorm):
pass
@ -100,19 +115,29 @@ class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class Mistral3ModelOutputWithPast(LlavaModelOutputWithPast):
pass
class Mistral3PreTrainedModel(LlavaPreTrainedModel):
def _init_weights(self, module):
# important: this ported version of Mistral3 isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, Mistral3RMSNorm):
module.weight.data.fill_(1.0)
class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
class Mistral3Model(LlavaModel):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -151,61 +176,21 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: Optional[torch.Tensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is the image?The image depicts two cats lying on a pink blanket."
```"""
) -> Union[Tuple, Mistral3ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -252,35 +237,110 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs[0]
output = Mistral3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is the image?The image depicts two cats lying on a pink blanket."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
image_sizes=image_sizes,
**lm_kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return Mistral3CausalLMOutputWithPast(
loss=loss,
@ -288,11 +348,12 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
__all__ = [
"Mistral3Model",
"Mistral3PreTrainedModel", # noqa
"Mistral3ForConditionalGeneration",
]

View File

@ -32,6 +32,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
@ -1968,10 +1969,11 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
@add_start_docstrings(
"""The Mllama model which consists of a vision encoder and a language model.""",
"""The Mllama model which consists of a vision encoder and a language model without language modeling head.""",
MLLAMA_START_DOCSTRING,
)
class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
class MllamaModel(MllamaPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
_supports_quantized_cache = False # quant cache not supported in encoder-decoder setting
def __init__(self, config: MllamaConfig):
@ -1983,10 +1985,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.vision_model = MllamaVisionModel._from_config(config.vision_config)
self.language_model = MllamaForCausalLM._from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.language_model = MllamaTextModel._from_config(config.text_config)
self.multi_modal_projector = nn.Linear(
config.vision_config.vision_output_dim,
config.text_config.hidden_size,
@ -2000,18 +1999,139 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
aspect_ratio_mask: Optional[torch.Tensor] = None,
aspect_ratio_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and cross_attention_states is not None:
raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")
if pixel_values is not None:
if aspect_ratio_ids is None:
raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
# get vision tokens from vision model
vision_outputs = self.vision_model(
pixel_values=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
)
cross_attention_states = vision_outputs[0]
cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
-1, cross_attention_states.shape[-2], self.hidden_size
)
if cross_attention_mask is not None:
cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
cross_attention_mask,
num_vision_tokens=self.vision_model.num_patches,
dtype=self.dtype,
)
else:
full_text_row_masked_out_mask = None
if cross_attention_mask is not None and cache_position is not None:
cross_attention_mask = cross_attention_mask[:, :, cache_position]
full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
outputs = self.language_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=True,
cache_position=cache_position,
)
output = BaseModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
"""The Mllama model which consists of a vision encoder and a language model.""",
MLLAMA_START_DOCSTRING,
)
class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_model": "model.vision_model",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_supports_quantized_cache = False # quant cache not supported in encoder-decoder setting
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: MllamaConfig):
super().__init__(config)
self.model = MllamaModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
def get_decoder(self):
return self.language_model.get_decoder()
@property
def vision_model(self):
return self.model.vision_model
@can_return_tuple
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaConfig")
def forward(
@ -2084,78 +2204,36 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and cross_attention_states is not None:
raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")
if pixel_values is not None:
if aspect_ratio_ids is None:
raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
# get vision tokens from vision model
vision_outputs = self.vision_model(
pixel_values=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
)
cross_attention_states = vision_outputs[0]
cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
-1, cross_attention_states.shape[-2], self.hidden_size
)
if cross_attention_mask is not None:
cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
cross_attention_mask,
num_vision_tokens=self.vision_model.num_patches,
dtype=self.dtype,
)
else:
full_text_row_masked_out_mask = None
if cross_attention_mask is not None and cache_position is not None:
cross_attention_mask = cross_attention_mask[:, :, cache_position]
full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
outputs = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
aspect_ratio_mask=aspect_ratio_mask,
aspect_ratio_ids=aspect_ratio_ids,
cross_attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
attention_mask=attention_mask,
position_ids=position_ids,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
output_attentions=output_attentions,
return_dict=return_dict,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**loss_kwargs,
)
# Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
logits = outputs[0]
if labels is not None:
loss = self.loss_function(logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs)
if not return_dict:
return (loss,) + outputs if loss is not None else outputs
loss = self.loss_function(logits, labels, self.config.text_config.vocab_size, **loss_kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=outputs.logits,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
@ -2179,58 +2257,28 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
if past_key_values is not None:
if (
inputs_embeds is not None # Exception 1
or cache_position[-1] >= input_ids.shape[1] # Exception 3
):
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
# TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if logits_to_keep is not None:
model_inputs["logits_to_keep"] = logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"cross_attention_mask": cross_attention_mask,
}
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
position_ids=position_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
cross_attention_mask=cross_attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**kwargs,
)
# If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
# to compute image hidden states, otherwise they are cached within each cross attn layer
if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
model_inputs["aspect_ratio_mask"] = aspect_ratio_mask
if cache_position[0] != 0:
model_inputs["pixel_values"] = None
model_inputs["aspect_ratio_ids"] = None
model_inputs["aspect_ratio_mask"] = None
return model_inputs
@ -2257,4 +2305,5 @@ __all__ = [
"MllamaTextModel",
"MllamaVisionModel",
"MllamaPreTrainedModel",
"MllamaModel",
]
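A minimal usage sketch for the new headless `MllamaModel` exported above. The checkpoint id, the `<|image|>` prompt convention, and the assumption that the installed `transformers` build already contains this change are all assumptions, not part of the diff; the point is only that the base class returns hidden states with no `lm_head`.

```python
# Sketch only: assumes a transformers build that includes MllamaModel and
# access to an Mllama-style checkpoint (the id below is an assumption).
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaModel

model_id = "meta-llama/Llama-3.2-11B-Vision"  # assumed checkpoint id
processor = AutoProcessor.from_pretrained(model_id)
model = MllamaModel.from_pretrained(model_id, torch_dtype=torch.bfloat16)

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text="<|image|>Describe this image.", images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# The base model returns a BaseModelOutputWithPast: hidden states, no logits.
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```

Compared to `MllamaForConditionalGeneration`, this is the class to reach for when only encoder features are needed, e.g. feature extraction or probing.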

View File

@ -13,8 +13,6 @@
# limitations under the License.
"""PaliGemmamodel configuration"""
import warnings
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING, AutoConfig
@ -39,8 +37,6 @@ class PaliGemmaConfig(PretrainedConfig):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 256000):
The image token index to encode the image prompt.
vocab_size (`int`, *optional*, defaults to 257152):
@ -83,16 +79,13 @@ class PaliGemmaConfig(PretrainedConfig):
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=256000,
vocab_size=257152,
projection_dim=2048,
hidden_size=2048,
**kwargs,
):
self._ignore_index = ignore_index
self.image_token_index = image_token_index
self._vocab_size = vocab_size
self.projection_dim = projection_dim
self.hidden_size = hidden_size
self.vision_config = vision_config
@ -133,22 +126,5 @@ class PaliGemmaConfig(PretrainedConfig):
self.vision_config.projection_dim = projection_dim
super().__init__(**kwargs)
@property
def ignore_index(self):
warnings.warn(
"The `ignore_index` attribute is deprecated and will be removed in v4.47.",
FutureWarning,
)
return self._ignore_index
@ignore_index.setter
def ignore_index(self, value):
self._ignore_index = value
def to_dict(self):
output = super().to_dict()
output.pop("_ignore_index", None)
return output
__all__ = ["PaliGemmaConfig"]

View File

@ -23,17 +23,18 @@ from torch import nn
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_paligemma import PaliGemmaConfig
@ -42,78 +43,43 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PaliGemmaConfig"
# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
# But Paligemma has no causal mask on prefix
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
is_training: bool = False,
token_type_ids: Optional[torch.Tensor] = None,
**kwargs,
):
@dataclass
class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Base class for Paligemma outputs, with hidden states and attentions.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
is_training (`bool`):
Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels`
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
if is_training:
causal_mask = torch.triu(causal_mask, diagonal=1)
else:
causal_mask[:, :sequence_length] = 0.0
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
if is_training:
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
)
return causal_mask
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
"""
Base class for PaliGemmacausal language model (or autoregressive) outputs.
Base class for PaliGemma causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
@ -184,7 +150,7 @@ PALIGEMMA_START_DOCSTRING = r"""
)
class PaliGemmaPreTrainedModel(PreTrainedModel):
config_class = PaliGemmaConfig
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["PaliGemmaMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"
@ -276,49 +242,32 @@ PALIGEMMA_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The PALIGEMMA model which consists of a vision backbone and a language model.""",
"""Base Paligemma model which consists of a vision backbone and a language model withou language modeling head.""",
PALIGEMMA_START_DOCSTRING,
)
class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
class PaliGemmaModel(PaliGemmaPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: PaliGemmaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
language_model = AutoModelForCausalLM.from_config(config=config.text_config)
if language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
language_model = AutoModel.from_config(config=config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma
# Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma
# Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma
def get_decoder(self):
return self.language_model.get_decoder()
def _update_causal_mask(
self,
attention_mask,
@ -326,7 +275,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
past_key_values=None,
cache_position=None,
input_tensor=None,
is_training: Optional[bool] = None,
is_training: bool = None,
):
if self.config.text_config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
@ -403,12 +352,154 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
image_features = image_features / (self.config.text_config.hidden_size**0.5)
return image_features
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**lm_kwargs,
) -> Union[Tuple, PaligemmaModelOutputWithPast]:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs = self.language_model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = PaligemmaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
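As with Mllama above, here is a hedged sketch of driving the new `PaliGemmaModel` directly; the checkpoint id and prompt are assumptions, and the only claim is that the headless model returns `last_hidden_state` plus the projected `image_hidden_states`.

```python
# Sketch only: assumes a transformers build with PaliGemmaModel and access to
# a PaliGemma checkpoint (the id below is an assumption).
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaModel

model_id = "google/paligemma-3b-pt-224"  # assumed checkpoint id
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaModel.from_pretrained(model_id, torch_dtype=torch.bfloat16)

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text="caption en", images=image, return_tensors="pt")

with torch.no_grad():
    out = model(**inputs)

print(out.last_hidden_state.shape)   # (batch_size, sequence_length, hidden_size)
print(out.image_hidden_states.shape)  # projected image features from the vision tower
```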
@add_start_docstrings(
"""Base Paligemma model which consists of a vision backbone and a language model withou language modeling head.""",
PALIGEMMA_START_DOCSTRING,
)
class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: PaliGemmaConfig):
super().__init__(config)
self.model = PaliGemmaModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
@ -459,117 +550,46 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Where is the cat standing?\nsnow"
```"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# mask out pad-token-ids in labels for BC
if labels is not None and self.pad_token_id in labels:
logger.warning_once(
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
)
labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs: CausalLMOutputWithPast = self.language_model(
attention_mask=causal_mask,
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
else:
shift_logits = shift_logits.contiguous()
shift_labels = shift_labels.contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
output = PaliGemmaCausalLMOutputWithPast(
return PaliGemmaCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
return output if return_dict else output.to_tuple()
def prepare_inputs_for_generation(
self,
@ -587,7 +607,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -610,12 +630,68 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self._update_causal_mask(
causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]

View File

@ -2056,7 +2056,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniThinkerText. Make sure to "
" this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniThinker. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
@ -2156,7 +2156,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
config (`Qwen25OmniThinkerTextConfig`):
config (`Qwen25OmniThinkerConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
@ -2772,7 +2772,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniTalker. Make sure to "
" this may lead to unexpected behaviour for Flash Attention version of Qwen2_5Omni. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
@ -2872,7 +2872,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
config (`Qwen25OmniTalkerConfig`):
config (`Qwen2_5OmniConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate

View File

@ -33,8 +33,8 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLAttention,
Qwen2_5_VLMLP,
Qwen2_5_VLModel,
Qwen2_5_VLPreTrainedModel,
Qwen2_5_VLTextModel,
Qwen2_5_VLVisionBlock,
Qwen2RMSNorm,
)
@ -2165,7 +2165,7 @@ QWEN2_5OMNI_START_DOCSTRING = r"""
"The bare Qwen2.5OmniThinker Model outputting raw hidden-states without any specific head on top.",
QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTextConfig"),
)
class Qwen2_5OmniThinkerTextModel(Qwen2_5_VLModel):
class Qwen2_5OmniThinkerTextModel(Qwen2_5_VLTextModel):
config_class = Qwen2_5OmniTextConfig
_no_split_modules = ["Qwen2_5OmniDecoderLayer"]
@ -2591,7 +2591,7 @@ class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput):
"The bare Qwen2.5OmniTalker Model outputting raw hidden-states without any specific head on top.",
QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTalkerConfig"),
)
class Qwen2_5OmniTalkerModel(Qwen2_5_VLModel):
class Qwen2_5OmniTalkerModel(Qwen2_5_VLTextModel):
config_class = Qwen2_5OmniTalkerConfig
_no_split_modules = ["Qwen2_5OmniTalkerDecoderLayer"]

View File

@ -31,7 +31,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
@ -44,6 +43,7 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
@ -565,6 +565,42 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return hidden_states
@dataclass
class Qwen2_5_VLModelOutputWithPast(ModelOutput):
"""
Base class for Qwen2_5_VL outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
rope_deltas: Optional[torch.LongTensor] = None
class Qwen2_5_VLRotaryEmbedding(nn.Module):
def __init__(self, config: Qwen2_5_VLTextConfig, device=None):
super().__init__()
@ -1076,7 +1112,7 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
"The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.",
Qwen2_5_VL_START_DOCSTRING,
)
class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
config_class = Qwen2_5_VLTextConfig
def __init__(self, config: Qwen2_5_VLTextConfig):
@ -1374,45 +1410,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
return causal_mask
@dataclass
class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
"""
Base class for Qwen2_5_VL causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
rope_deltas: Optional[torch.LongTensor] = None
QWEN2_5_VL_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@ -1489,41 +1486,25 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r"""
"""
class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
_checkpoint_conversion_mapping = {"^model": "language_model"}
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
def __init__(self, config):
super().__init__(config)
self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
text_config = config.get_text_config()
self.model = Qwen2_5_VLModel._from_config(text_config)
self.vocab_size = text_config.vocab_size
self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
self.language_model = Qwen2_5_VLTextModel._from_config(config.text_config)
self.rope_deltas = None # cache rope_deltas here
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
self.language_model.set_input_embeddings(value)
def get_rope_index(
self,
@ -1708,15 +1689,13 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
return position_ids, mrope_position_deltas
@add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@ -1728,46 +1707,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""
) -> Union[Tuple, Qwen2_5_VLModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1775,7 +1715,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
@ -1846,7 +1786,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
outputs = self.model(
outputs = self.language_model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
@ -1855,6 +1795,182 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
)
output = Qwen2_5_VLModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=self.rope_deltas,
)
return output if return_dict else output.to_tuple()
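A hedged sketch of the new composite `Qwen2_5_VLModel` used as a feature extractor; the checkpoint id and chat-template flow mirror the docstring example later in this file and assume an installed build that includes this refactor.

```python
# Sketch only: the headless composite model wires the vision tower and the
# text decoder together but applies no lm_head.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLModel

model_id = "Qwen/Qwen2.5-VL-7B-Instruct"  # assumed checkpoint id
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2_5_VLModel.from_pretrained(model_id, torch_dtype=torch.bfloat16)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    }
]
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[image], return_tensors="pt")

with torch.no_grad():
    out = model(**inputs)

print(out.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```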
@dataclass
class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
"""
Base class for Qwen2_5_VL causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
rope_deltas: Optional[torch.LongTensor] = None
class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^visual": "model.visual",
r"^model(?!\.(language_model|visual))": "model.language_model",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = Qwen2_5_VLModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def visual(self):
return self.model.visual
@can_return_tuple
@add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
@ -1864,18 +1980,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -1887,7 +1992,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=self.rope_deltas,
rope_deltas=outputs.rope_deltas,
)
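The hand-rolled shift-and-flatten cross entropy deleted in the hunk above is now delegated to the shared `self.loss_function` helper. As a rough, illustrative sketch of what that default causal-LM loss boils down to (not the exact transformers implementation):

```python
import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # Shift so that tokens < n predict n, flatten, and ignore -100 labels,
    # mirroring the block that the diff above removes.
    logits = logits.float()
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
```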
def prepare_inputs_for_generation(
@@ -2055,5 +2160,60 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
return input_ids, model_kwargs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
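A quick shape check of the mask helper above (illustrative only; it assumes the class is importable from transformers and pokes at a private static method):

```python
import torch
from transformers import Qwen2_5_VLForConditionalGeneration

# 5 query positions attending into a static cache of length 8
mask = Qwen2_5_VLForConditionalGeneration._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=torch.ones(1, 5),
    sequence_length=5,
    target_length=8,
    dtype=torch.float32,
    cache_position=torch.arange(5),
    batch_size=1,
)
print(mask.shape)  # torch.Size([1, 1, 5, 8])
```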
__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLTextModel"]

View File

@@ -26,7 +26,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
@@ -36,6 +35,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLCausalLMOutputWithPast,
Qwen2VLForConditionalGeneration,
Qwen2VLModel,
Qwen2VLModelOutputWithPast,
Qwen2VLPreTrainedModel,
VisionAttention,
VisionRotaryEmbedding,
@@ -50,7 +50,12 @@ from ...image_utils import ImageInput, VideoInput
from ...modeling_flash_attention_utils import is_flash_attn_available
from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
from ...utils import (
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
)
if is_flash_attn_available():
@@ -59,6 +64,8 @@ if is_flash_attn_available():
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Qwen2_5_VLConfig"
def apply_rotary_pos_emb_flashatt(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
@@ -406,16 +413,12 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return hidden_states
class Qwen2_5_VLModel(Qwen2VLModel):
pass
@dataclass
class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast):
class Qwen2_5_VLModelOutputWithPast(Qwen2VLModelOutputWithPast):
pass
class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
class Qwen2_5_VLModel(Qwen2VLModel):
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
@@ -607,12 +610,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -624,46 +626,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""
) -> Union[Tuple, Qwen2_5_VLModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -671,7 +634,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
@@ -742,7 +705,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
outputs = self.model(
outputs = self.language_model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
@@ -751,6 +714,111 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
)
output = Qwen2_5_VLModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=self.rope_deltas,
)
return output if return_dict else output.to_tuple()
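Since the point of this PR is the headless base class, a minimal usage sketch for the new `Qwen2_5_VLModel` (checkpoint name taken from the docstrings in this file; loading a `ForConditionalGeneration` checkpoint into the base class simply skips the `lm_head` weights):

```python
from transformers import AutoProcessor, Qwen2_5_VLModel

model = Qwen2_5_VLModel.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

inputs = processor(text=["Describe the weather today."], return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```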
@dataclass
class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast):
pass
QWEN2_5_VL_INPUTS_DOCSTRING = None
class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=[text], images=[image], return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
@@ -760,18 +828,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -783,7 +840,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=self.rope_deltas,
rope_deltas=outputs.rope_deltas,
)
def prepare_inputs_for_generation(
@@ -985,4 +1042,5 @@ __all__ = [
"Qwen2_5_VLModel",
"Qwen2_5_VLPreTrainedModel",
"Qwen2_5_VLProcessor",
"Qwen2_5_VLTextModel", # noqa: F822
]

View File

@@ -799,27 +799,21 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
raise ValueError(f"{padding_side} is not `left` or `right`.")
self._padding_side = padding_side
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
def get_decoder(self):
return self.language_model.get_decoder()
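The hunk above only drops the `# Copied from` markers; the accessors themselves keep delegating to the wrapped language model. A small sanity check (checkpoint name is an assumption):

```python
from transformers import Qwen2AudioForConditionalGeneration

model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
# Each accessor simply forwards to the inner causal LM
assert model.get_input_embeddings() is model.language_model.get_input_embeddings()
assert model.get_decoder() is model.language_model.get_decoder()
```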

View File

@@ -27,7 +27,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn import LayerNorm
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
@@ -40,7 +40,9 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@@ -61,6 +63,42 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Qwen2VLConfig"
@dataclass
class Qwen2VLModelOutputWithPast(ModelOutput):
"""
Base class for Qwen2VL base model outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
rope_deltas: Optional[torch.LongTensor] = None
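For intuition, `rope_deltas` is consumed at decode time roughly like this (toy numbers, simplified from the `forward` further down in this file):

```python
import torch

cache_position = torch.tensor([7])        # generating the 8th token
rope_deltas = torch.tensor([[-3]])        # example value cached from get_rope_index at pre-fill
delta = cache_position[0] + rope_deltas   # tensor([[4]])
position_ids = torch.arange(1).view(1, -1) + delta
print(position_ids)                       # tensor([[4]]) -> aligned with the multimodal RoPE index
```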
@dataclass
class Qwen2VLCausalLMOutputWithPast(ModelOutput):
"""
@@ -1027,7 +1065,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
"The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.",
QWEN2VL_START_DOCSTRING,
)
class Qwen2VLModel(Qwen2VLPreTrainedModel):
class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
config_class = Qwen2VLTextConfig
def __init__(self, config: Qwen2VLTextConfig):
@@ -1403,39 +1441,23 @@ QWEN2_VL_INPUTS_DOCSTRING = r"""
"""
class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
class Qwen2VLModel(Qwen2VLPreTrainedModel):
_checkpoint_conversion_mapping = {"^model": "language_model"}
def __init__(self, config):
def __init__(self, config: Qwen2VLConfig):
super().__init__(config)
self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
text_config = config.get_text_config()
self.model = Qwen2VLModel._from_config(text_config)
self.vocab_size = text_config.vocab_size
self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
self.language_model = Qwen2VLTextModel._from_config(config.text_config)
self.rope_deltas = None # cache rope_deltas here
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
self.language_model.set_input_embeddings(value)
def get_rope_index(
self,
@@ -1586,11 +1608,166 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
return position_ids, mrope_position_deltas
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.get_dtype())
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_embeds.shape[0]
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_mask = (
(input_ids == self.config.video_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (
(cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0)
):
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
delta = delta.to(position_ids.device)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
outputs = self.language_model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
)
output = Qwen2VLModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=self.rope_deltas,
)
return output if return_dict else output.to_tuple()
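The image/video merge in the forward above hinges on `masked_scatter`; a toy illustration with made-up sizes and a made-up placeholder token id:

```python
import torch

image_token_id = 9                                 # placeholder id, for illustration only
input_ids = torch.tensor([[1, 9, 9, 2]])           # two image placeholder tokens
inputs_embeds = torch.zeros(1, 4, 8)
image_embeds = torch.arange(16, dtype=torch.float).view(2, 8)  # one row per placeholder

mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(mask, image_embeds)

assert torch.equal(inputs_embeds[0, 1], image_embeds[0])
assert torch.equal(inputs_embeds[0, 2], image_embeds[1])
```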
class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^visual": "model.visual",
r"^model(?!\.(language_model|visual))": "model.language_model",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = Qwen2VLModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def visual(self):
return self.model.visual
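The `_checkpoint_conversion_mapping` regexes above are what let old `Qwen2VLForConditionalGeneration` checkpoints load into the new nested layout. A rough sketch of how such a key remap could be applied (the real renaming happens inside `from_pretrained`; this is only an illustration):

```python
import re

mapping = {
    r"^visual": "model.visual",
    r"^model(?!\.(language_model|visual))": "model.language_model",
}

def remap_key(key: str) -> str:
    for pattern, replacement in mapping.items():
        key, n = re.subn(pattern, replacement, key)
        if n:  # stop after the first pattern that matched
            break
    return key

assert remap_key("visual.blocks.0.attn.qkv.weight") == "model.visual.blocks.0.attn.qkv.weight"
assert remap_key("model.layers.0.self_attn.q_proj.weight") == "model.language_model.layers.0.self_attn.q_proj.weight"
assert remap_key("lm_head.weight") == "lm_head.weight"
```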
@can_return_tuple
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -1652,70 +1829,12 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.get_dtype())
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_mask = (
(input_ids == self.config.video_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (
(cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0)
):
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
delta = delta.to(position_ids.device)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
outputs = self.model(
input_ids=None,
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
@@ -1732,22 +1851,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
return Qwen2VLCausalLMOutputWithPast(
loss=loss,
@@ -1755,7 +1859,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=self.rope_deltas,
rope_deltas=outputs.rope_deltas,
)
def prepare_inputs_for_generation(
@@ -1921,5 +2025,61 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
return input_ids, model_kwargs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel", "Qwen2VLTextModel"]

View File

@@ -28,11 +28,12 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_video_llava import VideoLlavaConfig
@@ -41,6 +42,47 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VideoLlavaConfig"
@dataclass
class VideoLlavaModelOutputWithPast(ModelOutput):
"""
Base class for VideoLlava base model outputs.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
video_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
video_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class VideoLlavaCausalLMOutputWithPast(ModelOutput):
"""
@@ -130,7 +172,7 @@ VIDEO_LLAVA_START_DOCSTRING = r"""
)
class VideoLlavaPreTrainedModel(PreTrainedModel):
config_class = VideoLlavaConfig
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["VideoLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
@@ -242,10 +284,12 @@ VIDEO_LLAVA_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The VideoLlava model which consists of a vision backbone and a language model.""",
"""The VideoLlava model which consists of a vision backbone and a language model without language modeling head.""",
VIDEO_LLAVA_START_DOCSTRING,
)
class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
class VideoLlavaModel(VideoLlavaPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: VideoLlavaConfig):
super().__init__(config)
self.video_tower = AutoModel.from_config(config.vision_config)
@@ -253,10 +297,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
self.multi_modal_projector = VideoLlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.language_model = AutoModel.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
@@ -266,18 +307,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values_images: torch.FloatTensor,
@@ -358,13 +387,165 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
return video_features, num_frames
@add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values_images: torch.FloatTensor = None,
pixel_values_videos: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
) -> Union[Tuple, VideoLlavaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same "
"time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values_images is not None:
image_features = self.get_image_features(
pixel_values_images,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if pixel_values_videos is not None:
video_features, num_frames = self.get_video_features(
pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_features.shape[0] * video_features.shape[1]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = VideoLlavaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values_images is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
"""The VideoLlava model which consists of a vision backbone and a language model.""",
VIDEO_LLAVA_START_DOCSTRING,
)
class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^image_tower": "model.image_tower",
"^video_tower": "model.video_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: VideoLlavaConfig):
super().__init__(config)
self.model = VideoLlavaModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def video_tower(self):
return self.model.video_tower
@property
def image_tower(self):
return self.model.image_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
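Because of these properties, code written against the old attribute layout keeps resolving; a small check (the checkpoint name is the public Video-LLaVA conversion and is an assumption here):

```python
from transformers import VideoLlavaForConditionalGeneration

model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
assert model.language_model is model.model.language_model
assert model.video_tower is model.model.video_tower
assert model.multi_modal_projector is model.model.multi_modal_projector
```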
@can_return_tuple
@add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values_images: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values_images: torch.FloatTensor = None,
pixel_values_videos: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -475,88 +656,32 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same "
"time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values_images is not None:
image_features = self.get_image_features(
pixel_values_images,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if pixel_values_videos is not None:
video_features, num_frames = self.get_video_features(
pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_features.shape[0] * video_features.shape[1]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
outputs = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values_images=pixel_values_images,
pixel_values_videos=pixel_values_videos,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return VideoLlavaCausalLMOutputWithPast(
loss=loss,
@@ -564,8 +689,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values_images is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
image_hidden_states=outputs.image_hidden_states,
video_hidden_states=outputs.video_hidden_states,
)
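The `logits_to_keep` slicing a few lines above avoids materialising logits for the whole prompt during generation; a toy version of the same indexing:

```python
import torch

hidden_states = torch.randn(2, 10, 16)    # (batch, seq_len, hidden), made-up sizes
logits_to_keep = 1                         # during decoding only the last position is needed
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 1, 16])
```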
def prepare_inputs_for_generation(
@@ -582,7 +707,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -600,5 +725,61 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["VideoLlavaPreTrainedModel", "VideoLlavaForConditionalGeneration"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["VideoLlavaPreTrainedModel", "VideoLlavaModel", "VideoLlavaForConditionalGeneration"]

View File

@@ -1,3 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/vipllava/modular_vipllava.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_vipllava.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
#
@@ -12,37 +18,65 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch VipLlava model."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_vipllava import VipLlavaConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VipLlavaConfig"
@dataclass
# Copied from transformers.models.llava.modeling_llava.LlavaCausalLMOutputWithPast with Llava->VipLlava
class VipLlavaModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for VipLlava outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
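Subclassing `BaseModelOutputWithPast` means the new output only has to declare the extra field; an illustrative construction (import path assumed):

```python
import torch
from transformers.models.vipllava.modeling_vipllava import VipLlavaModelOutputWithPast

out = VipLlavaModelOutputWithPast(
    last_hidden_state=torch.zeros(1, 4, 8),
    image_hidden_states=torch.zeros(1, 3, 8),
)
print(out.last_hidden_state.shape, out.image_hidden_states.shape)
```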
@dataclass
class VipLlavaCausalLMOutputWithPast(ModelOutput):
"""
Base class for VipLlava causal language model (or autoregressive) outputs.
@@ -70,7 +104,7 @@ class VipLlavaCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@@ -129,9 +163,8 @@ VIPLLAVA_START_DOCSTRING = r"""
)
class VipLlavaPreTrainedModel(PreTrainedModel):
config_class = VipLlavaConfig
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["VipLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@ -140,6 +173,9 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of VipLlava isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/VipLlava/tree/main/vipllava should serve for that purpose
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
@ -147,8 +183,8 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.data.zero_()
VIPLLAVA_INPUTS_DOCSTRING = r"""
@ -163,7 +199,7 @@ VIPLLAVA_INPUTS_DOCSTRING = r"""
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`VipLlavaProcessor`] uses
[`CLIPImageProcessor`] for processing images).
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@ -203,6 +239,13 @@ VIPLLAVA_INPUTS_DOCSTRING = r"""
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
@ -222,24 +265,18 @@ VIPLLAVA_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The VIPLLAVA model which consists of a vision backbone and a language model.""",
"""The VipLlava model which consists of a vision backbone and a language model, without a language modeling head.""",
VIPLLAVA_START_DOCSTRING,
)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration with LLAVA->VIPLLAVA,Llava->VipLlava
class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
class VipLlavaModel(VipLlavaPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: VipLlavaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = VipLlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@ -248,19 +285,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
# Ignore copy
def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: Union[int, List[int]]):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.
@ -288,12 +312,132 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
return image_features
@add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
# Ignore copy
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layers: Optional[Union[int, List[int]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
) -> Union[Tuple, VipLlavaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layers = (
vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = VipLlavaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
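A note on the merge step above: `masked_scatter` fills every `<image>` placeholder position in `inputs_embeds`, in order, with the projected vision features. A toy sketch of that mechanic (the image token id and all tensor values below are made up for illustration):

```python
import torch

image_token_id = 32000                  # hypothetical placeholder id
input_ids = torch.tensor([[1, 32000, 32000, 32000, 5, 6]])
inputs_embeds = torch.zeros(1, 6, 4)    # (batch, seq_len, hidden) text embeddings
image_features = torch.ones(1, 3, 4)    # (num_images, tokens_per_image, hidden) projected features

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(special_image_mask, image_features)

print(merged[0, :, 0])                  # tensor([0., 1., 1., 1., 0., 0.]) -> image slots were overwritten
```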
@add_start_docstrings(
"""The VIPLLAVA model which consists of a vision backbone and a language model.""",
VIPLLAVA_START_DOCSTRING,
)
class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: VipLlavaConfig):
super().__init__(config)
self.model = VipLlavaModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -358,68 +502,30 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
vision_feature_layers=vision_feature_layers,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return VipLlavaCausalLMOutputWithPast(
loss=loss,
@ -427,7 +533,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@ -443,7 +549,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -460,5 +566,60 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
return model_inputs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
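For intuition, here is what the helper above returns for a tiny decode step (a minimal sketch assuming a transformers build that still carries this static method; the sizes are arbitrary):

```python
import torch
from transformers import VipLlavaForConditionalGeneration

# Three query tokens at cache positions 2..4, against a static cache of length 5.
mask = VipLlavaForConditionalGeneration._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=None,
    sequence_length=3,
    target_length=5,
    dtype=torch.float32,
    cache_position=torch.arange(2, 5),
    batch_size=1,
)
print(mask.shape)               # torch.Size([1, 1, 3, 5])
print((mask == 0).int()[0, 0])  # 1 = attendable, 0 = masked with the dtype's minimum value
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)
```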
__all__ = ["VipLlavaModel", "VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"]

View File

@ -0,0 +1,294 @@
# coding=utf-8
# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from transformers.models.llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
LlavaModelOutputWithPast,
LlavaPreTrainedModel,
)
from ...activations import ACT2FN
from ...utils import (
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from .configuration_vipllava import VipLlavaConfig
logger = logging.get_logger(__name__)
VIPLLAVA_INPUTS_DOCSTRING = None
_CONFIG_FOR_DOC = "VipLlavaConfig"
class VipLlavaModelOutputWithPast(LlavaModelOutputWithPast):
pass
class VipLlavaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class VipLlavaMultiModalProjector(nn.Module):
def __init__(self, config: VipLlavaConfig):
super().__init__()
num_feature_layers = 1 if isinstance(config.vision_feature_layers, int) else len(config.vision_feature_layers)
self.projector_layernorm = nn.LayerNorm(
num_feature_layers * config.vision_config.hidden_size, eps=config.projector_layernorm_eps
)
self.linear_1 = nn.Linear(
num_feature_layers * config.vision_config.hidden_size,
config.text_config.hidden_size,
bias=True,
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
def forward(self, hidden_states):
hidden_states = self.projector_layernorm(hidden_states)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
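Because VipLlava concatenates hidden states from several vision layers before projecting them into the text embedding space, the projector's input width is `num_feature_layers * vision_hidden_size`. A quick shape check against the default config (the patch count is arbitrary and the defaults are assumed from the config class):

```python
import torch
from transformers import VipLlavaConfig
from transformers.models.vipllava.modeling_vipllava import VipLlavaMultiModalProjector

config = VipLlavaConfig()  # default vision_feature_layers is a list of layer indices
projector = VipLlavaMultiModalProjector(config)

layers = config.vision_feature_layers
num_feature_layers = 1 if isinstance(layers, int) else len(layers)
num_patches = 576  # arbitrary number of image patches for the sketch

features = torch.randn(1, num_patches, num_feature_layers * config.vision_config.hidden_size)
print(projector(features).shape)  # (1, num_patches, config.text_config.hidden_size)
```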
class VipLlavaPreTrainedModel(LlavaPreTrainedModel):
pass
class VipLlavaModel(LlavaModel):
def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: Union[int, List[int]]):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
vision_feature_layers (`Union[int, List[int]]`):
The vision feature layer, or the list of indexes of the layers to select
the vision feature.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# If multiple feature layers are provided (which is usually the case)
# then the image features are concatenated after the CLS is removed.
if isinstance(vision_feature_layers, int):
image_features = image_outputs.hidden_states[vision_feature_layers][:, 1:]
else:
# Usually, we select the features from index 1: the layers -2, -5, -8, -11 and 6
image_features = [image_outputs.hidden_states[index][:, 1:] for index in vision_feature_layers]
image_features = torch.cat(image_features, dim=-1)
image_features = self.multi_modal_projector(image_features)
return image_features
@add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layers: Optional[Union[int, List[int]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
) -> Union[Tuple, VipLlavaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layers = (
vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = VipLlavaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
class VipLlavaForConditionalGeneration(LlavaForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
# Ignore copy
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layers: Optional[Union[int, List[int]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, VipLlavaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> import torch
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, VipLlavaForConditionalGeneration
>>> model = VipLlavaForConditionalGeneration.from_pretrained("llava-hf/vip-llava-7b-hf", device_map="auto", torch_dtype=torch.float16)
>>> processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")
>>> prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: <image>\n{}###Assistant:"
>>> question = "Can you please describe this image?"
>>> prompt = prompt.format(question)
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(text=prompt, images=image, return_tensors="pt").to(0, torch.float16)
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=20)
>>> processor.decode(generate_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
The image features a brown and white cat sitting on a green surface, with a red ball in its
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layers = (
vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
)
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
vision_feature_layers=vision_feature_layers,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return VipLlavaCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)
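The `logits_to_keep` slicing used in the head class keeps generation cheap: only the hidden states that actually need logits are pushed through `lm_head`. A toy illustration of the same slicing (all sizes are made up):

```python
import torch

hidden_states = torch.randn(2, 10, 8)          # (batch, seq_len, hidden) from the base model
lm_head = torch.nn.Linear(8, 32, bias=False)   # toy head; real vocabularies are ~32k entries

logits_to_keep = 1                              # generation only needs the last position
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])
print(logits.shape)                             # torch.Size([2, 1, 32]) instead of (2, 10, 32)
```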
__all__ = ["VipLlavaModel", "VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"]

View File

@ -21,6 +21,7 @@ import requests
from transformers import (
AriaConfig,
AriaForConditionalGeneration,
AriaModel,
AriaTextConfig,
AutoProcessor,
AutoTokenizer,
@ -175,7 +176,7 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
Model tester for `AriaForConditionalGeneration`.
"""
all_model_classes = (AriaForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (AriaModel, AriaForConditionalGeneration) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
_is_composite = True
@ -281,6 +282,18 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
def test_cpu_offload(self):
pass
@unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
def test_disk_offload_bin(self):
pass
@unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
def test_disk_offload_safetensors(self):
pass
@require_torch
class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@ -46,6 +46,7 @@ if is_torch_available():
from transformers import (
AyaVisionForConditionalGeneration,
AyaVisionModel,
)
@ -158,7 +159,14 @@ class AyaVisionVisionText2TextModelTester:
@require_torch
class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (AyaVisionForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
AyaVisionModel,
AyaVisionForConditionalGeneration,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (AyaVisionForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{

View File

@ -46,6 +46,7 @@ if is_torch_available():
from transformers import (
Emu3ForCausalLM,
Emu3ForConditionalGeneration,
Emu3Model,
Emu3Processor,
Emu3TextModel,
)
@ -310,7 +311,14 @@ class Emu3Vision2TextModelTester:
@require_torch
class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Emu3ForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
Emu3Model,
Emu3ForConditionalGeneration,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = {}
test_headmasking = False
test_pruning = False
@ -395,6 +403,10 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
def test_generate_with_static_cache(self):
pass
@unittest.skip("Emu3 doesn't support Flex attn yet!")
def test_flex_attention_with_grads(self):
pass
@require_torch
class Emu3IntegrationTest(unittest.TestCase):

View File

@ -38,7 +38,7 @@ if is_torch_available() and is_vision_available():
if is_torch_available():
from transformers import FuyuForCausalLM
from transformers import FuyuForCausalLM, FuyuModel
class FuyuModelTester:
@ -145,7 +145,14 @@ class FuyuModelTester:
@require_torch
class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (FuyuForCausalLM,) if is_torch_available() else ()
all_model_classes = (
(
FuyuModel,
FuyuForCausalLM,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{"text-generation": FuyuForCausalLM, "image-text-to-text": FuyuForCausalLM} if is_torch_available() else {}
)

View File

@ -50,6 +50,7 @@ if is_torch_available():
from transformers import (
Gemma3ForCausalLM,
Gemma3ForConditionalGeneration,
Gemma3Model,
Gemma3Processor,
Gemma3TextModel,
)
@ -148,9 +149,9 @@ class Gemma3Vision2TextModelTester:
self,
parent,
mm_tokens_per_image=2,
image_token_index=1,
boi_token_index=2,
eoi_token_index=3,
image_token_index=4,
boi_token_index=5,
eoi_token_index=6,
seq_length=25,
is_training=True,
vision_config={
@ -242,7 +243,14 @@ class Gemma3Vision2TextModelTester:
@require_torch
class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
Gemma3Model,
Gemma3ForConditionalGeneration,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else ()
test_headmasking = False
test_pruning = False

View File

@ -34,6 +34,7 @@ if is_torch_available():
from transformers import (
GotOcr2ForConditionalGeneration,
GotOcr2Model,
)
@ -140,7 +141,14 @@ class GotOcr2VisionText2TextModelTester:
@require_torch
class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GotOcr2ForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
GotOcr2Model,
GotOcr2ForConditionalGeneration,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
"image-to-text": GotOcr2ForConditionalGeneration,
@ -228,6 +236,10 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_past_key_values_format(self):
pass
@unittest.skip(reason="Vision backbone doesn't support FLEX yet!")
def test_flex_attention_with_grads(self):
pass
@require_torch
class GotOcr2IntegrationTest(unittest.TestCase):

View File

@ -54,7 +54,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import InstructBlipForConditionalGeneration, InstructBlipVisionModel
from transformers import InstructBlipForConditionalGeneration, InstructBlipModel, InstructBlipVisionModel
if is_vision_available():
@ -460,14 +460,20 @@ class InstructBlipForConditionalGenerationDecoderOnlyModelTester:
"attention_mask": attention_mask,
"qformer_input_ids": qformer_input_ids,
"qformer_attention_mask": qformer_attention_mask,
"labels": input_ids,
}
return config, inputs_dict
@require_torch
class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (InstructBlipForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
InstructBlipModel,
InstructBlipForConditionalGeneration,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = {"image-text-to-text": InstructBlipForConditionalGeneration}
fx_compatible = False
test_head_masking = False

View File

@ -54,7 +54,11 @@ if is_torch_available():
import torch
from torch import nn
from transformers import InstructBlipVideoForConditionalGeneration, InstructBlipVideoVisionModel
from transformers import (
InstructBlipVideoForConditionalGeneration,
InstructBlipVideoModel,
InstructBlipVideoVisionModel,
)
class InstructBlipVideoVisionModelTester:
@ -477,7 +481,6 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyModelTester:
"attention_mask": attention_mask,
"qformer_input_ids": qformer_input_ids,
"qformer_attention_mask": qformer_attention_mask,
"labels": input_ids,
}
return config, inputs_dict
@ -486,7 +489,9 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyModelTester:
class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
):
all_model_classes = (InstructBlipVideoForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel) if is_torch_available() else ()
)
fx_compatible = False
test_head_masking = False
test_pruning = False

View File

@ -47,9 +47,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
InternVLForConditionalGeneration,
)
from transformers import InternVLForConditionalGeneration, InternVLModel
if is_vision_available():
@ -191,7 +189,7 @@ class InternVLVisionText2TextModelTester:
@require_torch
class InternVLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (InternVLForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (InternVLForConditionalGeneration, InternVLModel) if is_torch_available() else ()
all_generative_model_classes = (InternVLForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Llava model."""
import copy
import unittest
import requests
@ -23,6 +24,7 @@ from transformers import (
AutoTokenizer,
LlavaConfig,
LlavaForConditionalGeneration,
LlavaModel,
is_torch_available,
is_vision_available,
)
@ -166,7 +168,14 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
Model tester for `LlavaForConditionalGeneration`.
"""
all_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
LlavaModel,
LlavaForConditionalGeneration,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{"image-to-text": LlavaForConditionalGeneration, "image-text-to-text": LlavaForConditionalGeneration}
if is_torch_available()
@ -238,16 +247,17 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications
curr_input_dict = copy.deepcopy(input_dict) # in-place modifications further below
_ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)
_ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:1]
input_ids = curr_input_dict["input_ids"][:1]
pixel_values = curr_input_dict["pixel_values"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@ -281,7 +291,8 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
model = model_class(config).to(torch_device)
# We should have the right number of input features,
# and should be able to run a forward pass without exploding
assert model.multi_modal_projector.linear_1.in_features == expected_features
base_model = getattr(model, "model", model)
assert base_model.multi_modal_projector.linear_1.in_features == expected_features
model(**input_dict)
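The `getattr(model, "model", model)` indirection above is the pattern these reworked tests rely on: the head class now wraps the new base class under `.model`, while the base class exposes the submodules directly. A tiny sketch with deliberately small, arbitrary config sizes (assuming a transformers version that exports `LlavaModel`):

```python
from transformers import CLIPVisionConfig, LlamaConfig, LlavaConfig, LlavaForConditionalGeneration, LlavaModel

config = LlavaConfig(
    vision_config=CLIPVisionConfig(
        hidden_size=32, intermediate_size=64, num_hidden_layers=2,
        num_attention_heads=2, image_size=30, patch_size=6,
    ),
    text_config=LlamaConfig(
        hidden_size=32, intermediate_size=64, num_hidden_layers=2,
        num_attention_heads=2, vocab_size=99,
    ),
)

for cls in (LlavaModel, LlavaForConditionalGeneration):
    model = cls(config)
    base = getattr(model, "model", model)  # LlavaModel has no `.model`; the head class wraps one
    print(cls.__name__, base.multi_modal_projector.linear_1.in_features)
```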
@unittest.skip(

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Llava-NeXT model."""
import copy
import unittest
import requests
@ -23,6 +24,7 @@ from transformers import (
AutoProcessor,
LlavaNextConfig,
LlavaNextForConditionalGeneration,
LlavaNextModel,
is_torch_available,
is_vision_available,
)
@ -181,7 +183,14 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
Model tester for `LlavaNextForConditionalGeneration`.
"""
all_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
LlavaNextModel,
LlavaNextForConditionalGeneration,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = {"image-text-to-text": LlavaNextForConditionalGeneration} if is_torch_available() else {}
test_pruning = False
test_head_masking = False
@ -265,18 +274,19 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications
curr_input_dict = copy.deepcopy(input_dict) # in-place modifications further below
_ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
curr_input_dict["image_sizes"] = curr_input_dict["image_sizes"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)
_ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:1]
image_sizes = input_dict["image_sizes"][:1]
input_ids = curr_input_dict["input_ids"][:1]
pixel_values = curr_input_dict["pixel_values"][:1]
image_sizes = curr_input_dict["image_sizes"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@ -324,7 +334,8 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
model = model_class(config).to(torch_device)
# We should have the right number of input features,
# and should be able to run a forward pass without exploding
assert model.multi_modal_projector.linear_1.in_features == expected_features
base_model = getattr(model, "model", model)
assert base_model.multi_modal_projector.linear_1.in_features == expected_features
model(**input_dict)
@unittest.skip(

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Llava-NeXT-Video model."""
import copy
import unittest
import numpy as np
@ -23,6 +24,7 @@ from transformers import (
AutoProcessor,
LlavaNextVideoConfig,
LlavaNextVideoForConditionalGeneration,
LlavaNextVideoModel,
is_torch_available,
is_vision_available,
)
@ -196,7 +198,14 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
Model tester for `LlavaNextVideoForConditionalGeneration`.
"""
all_model_classes = (LlavaNextVideoForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
LlavaNextVideoModel,
LlavaNextVideoForConditionalGeneration,
)
if is_torch_available()
else ()
)
test_pruning = False
test_head_masking = False
_is_composite = True
@ -281,18 +290,19 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications
curr_input_dict = copy.deepcopy(input_dict) # in-place modifications further below
_ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
curr_input_dict["image_sizes"] = curr_input_dict["image_sizes"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)
_ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:1]
image_sizes = input_dict["image_sizes"][:1]
input_ids = curr_input_dict["input_ids"][:1]
pixel_values = curr_input_dict["pixel_values"][:1]
image_sizes = curr_input_dict["image_sizes"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@ -340,7 +350,8 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
model = model_class(config).to(torch_device)
# We should have the right number of input features,
# and should be able to run a forward pass without exploding
assert model.multi_modal_projector.linear_1.in_features == expected_features
base_model = getattr(model, "model", model)
assert base_model.multi_modal_projector.linear_1.in_features == expected_features
model(**input_dict)
@unittest.skip(

View File

@ -24,6 +24,7 @@ from transformers import (
AutoProcessor,
LlavaOnevisionConfig,
LlavaOnevisionForConditionalGeneration,
LlavaOnevisionModel,
is_torch_available,
is_vision_available,
)
@ -182,7 +183,14 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
Model tester for `LlavaOnevisionForConditionalGeneration`.
"""
all_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
LlavaOnevisionModel,
LlavaOnevisionForConditionalGeneration,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{"image-text-to-text": LlavaOnevisionForConditionalGeneration} if is_torch_available() else {}
)
@ -296,7 +304,8 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
model = model_class(config).to(torch_device)
# We should have the right number of input features,
# and should be able to run a forward pass without exploding
assert model.multi_modal_projector.linear_1.in_features == expected_features
base_model = getattr(model, "model", model)
assert base_model.multi_modal_projector.linear_1.in_features == expected_features
model(**input_dict)
@unittest.skip(

View File

@ -42,6 +42,7 @@ if is_torch_available():
from transformers import (
Mistral3ForConditionalGeneration,
Mistral3Model,
)
@ -162,7 +163,14 @@ class Mistral3VisionText2TextModelTester:
@require_torch
class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
Mistral3Model,
Mistral3ForConditionalGeneration,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
@ -278,6 +286,10 @@ class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip("Pixtral does not support attention interfaces.")
def test_flex_attention_with_grads(self):
pass
@slow
@require_torch_gpu

View File

@ -25,6 +25,7 @@ from transformers import (
MllamaConfig,
MllamaForCausalLM,
MllamaForConditionalGeneration,
MllamaModel,
is_torch_available,
is_vision_available,
)
@ -262,7 +263,14 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
Model tester for `MllamaForConditionalGeneration`.
"""
all_model_classes = (MllamaForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
MllamaModel,
MllamaForConditionalGeneration,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = {"image-text-to-text": MllamaForConditionalGeneration} if is_torch_available() else ()
test_pruning = False
test_head_masking = False
@ -325,19 +333,18 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
# resizing embeddings should result in successful loss computation
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model_vocab_size = config.get_text_config().vocab_size
inputs = self._prepare_for_class(inputs, model_class, return_labels=True)
# Resize embeddings and call forward
model.resize_token_embeddings(model_vocab_size + 10)
output = model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
labels=inputs["labels"],
return_dict=True,
)
self.assertTrue("loss" in output)
model = MllamaForConditionalGeneration(config).to(torch_device)
model_vocab_size = config.get_text_config().vocab_size
inputs = self._prepare_for_class(inputs, MllamaForConditionalGeneration, return_labels=True)
# Resize embeddings and call forward
model.resize_token_embeddings(model_vocab_size + 10)
output = model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
labels=inputs["labels"],
return_dict=True,
)
self.assertTrue("loss" in output)
def _check_attentions_for_generate(
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
@ -409,6 +416,18 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
def test_assisted_decoding_with_num_logits_to_keep(self):
pass
@unittest.skip(reason="Mllama uses self.weights dirrectly causing device mismatch when offloading`")
def test_cpu_offload(self):
pass
@unittest.skip(reason="Mllama uses self.weights dirrectly causing device mismatch when offloading`")
def test_disk_offload_bin(self):
pass
@unittest.skip(reason="Mllama uses self.weights dirrectly causing device mismatch when offloading`")
def test_disk_offload_safetensors(self):
pass
@pytest.mark.generate
# overridden because mllama is not an encoder-decoder model, but has encoder-decoder-like cache
def test_past_key_values_format(self):
@ -501,7 +520,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
for model_class in self.all_generative_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch PaliGemma model."""
import copy
import unittest
import requests
@ -20,6 +21,7 @@ import requests
from transformers import (
PaliGemmaConfig,
PaliGemmaForConditionalGeneration,
PaliGemmaModel,
PaliGemmaProcessor,
is_torch_available,
is_vision_available,
@ -177,7 +179,14 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
Model tester for `PaliGemmaForConditionalGeneration`.
"""
all_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
PaliGemmaModel,
PaliGemmaForConditionalGeneration,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = {"image-text-to-text": PaliGemmaForConditionalGeneration}
fx_compatible = False
test_pruning = False
@ -242,16 +251,17 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications
curr_input_dict = copy.deepcopy(input_dict) # in-place modifications further below
_ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)
_ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:1]
input_ids = curr_input_dict["input_ids"][:1]
pixel_values = curr_input_dict["pixel_values"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch PaliGemma model."""
import copy
import unittest
import pytest
@ -239,16 +240,17 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications
curr_input_dict = copy.deepcopy(input_dict) # in-place modifications further below
_ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)
_ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:1]
input_ids = curr_input_dict["input_ids"][:1]
pixel_values = curr_input_dict["pixel_values"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Qwen2.5-VL model."""
import copy
import gc
import tempfile
import unittest
@ -23,6 +24,7 @@ from transformers import (
AutoProcessor,
Qwen2_5_VLConfig,
Qwen2_5_VLForConditionalGeneration,
Qwen2_5_VLModel,
is_torch_available,
is_vision_available,
)
@ -180,17 +182,11 @@ class Qwen2_5_VLVisionText2TextModelTester:
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
input_ids[:, self.num_image_tokens] = self.image_token_id
input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
labels = torch.zeros(
(self.batch_size, self.seq_length),
dtype=torch.long,
device=torch_device,
)
inputs_dict = {
"pixel_values": pixel_values,
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size),
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
return config, inputs_dict
@ -201,7 +197,14 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
Model tester for `Qwen2_5_VLForConditionalGeneration`.
"""
all_model_classes = (Qwen2_5_VLForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
Qwen2_5_VLModel,
Qwen2_5_VLForConditionalGeneration,
)
if is_torch_available()
else ()
)
test_pruning = False
test_head_masking = False
@ -236,19 +239,20 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications
curr_input_dict = copy.deepcopy(input_dict)
# remove one image but leave the image token in text
patch_size = config.vision_config.patch_size
one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...]
input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...]
curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-one_img_length:, ...]
curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)
_ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:one_img_length]
image_grid_thw = input_dict["image_grid_thw"][:1]
input_ids = curr_input_dict["input_ids"][:1]
pixel_values = curr_input_dict["pixel_values"][:one_img_length]
image_grid_thw = curr_input_dict["image_grid_thw"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@ -375,6 +379,29 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
def test_save_load_fast_init_from_base(self):
pass
# The multimodal base model embeds will not match ids, due to pixel values. We can't change base test
# because in some models `pixel_values` are required. Will be fixed when we add support for merging `embeds+pixels`
# TODO: @raushan
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@require_torch
class Qwen2_5_VLIntegrationTest(unittest.TestCase):

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Qwen2-VL model."""
import copy
import gc
import unittest
@ -22,6 +23,7 @@ from transformers import (
AutoProcessor,
Qwen2VLConfig,
Qwen2VLForConditionalGeneration,
Qwen2VLModel,
is_torch_available,
is_vision_available,
)
@ -169,17 +171,12 @@ class Qwen2VLVisionText2TextModelTester:
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
input_ids[:, self.num_image_tokens] = self.image_token_id
input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
labels = torch.zeros(
(self.batch_size, self.seq_length),
dtype=torch.long,
device=torch_device,
)
inputs_dict = {
"pixel_values": pixel_values,
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size),
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
return config, inputs_dict
@ -190,7 +187,14 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
Model tester for `Qwen2VLForConditionalGeneration`.
"""
all_model_classes = (Qwen2VLForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
Qwen2VLModel,
Qwen2VLForConditionalGeneration,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = {"image-text-to-text": Qwen2VLForConditionalGeneration}
test_pruning = False
test_head_masking = False
@ -226,20 +230,21 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications
curr_input_dict = copy.deepcopy(input_dict)
_ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
patch_size = config.vision_config.patch_size
one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...]
input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...]
curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-one_img_length:, ...]
curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)
_ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:one_img_length]
image_grid_thw = input_dict["image_grid_thw"][:1]
input_ids = curr_input_dict["input_ids"][:1]
pixel_values = curr_input_dict["pixel_values"][:one_img_length]
image_grid_thw = curr_input_dict["image_grid_thw"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@ -262,11 +267,11 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
model = model_class(config).to(torch_device)
# Generate and make sure rope_deltas are not `None`
self.assertTrue(model.rope_deltas is None)
self.assertTrue(model.model.rope_deltas is None)
generation_output = model.generate(
**input_dict, max_new_tokens=4, return_dict_in_generate=True, output_logits=True
)
self.assertTrue(model.rope_deltas is not None)
self.assertTrue(model.model.rope_deltas is not None)
# Now if we try to do a forward pass, we should get new rope logits, because the cache is not passed
forward_output = model(**input_dict)
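The two assertions above reflect the base/head split: `rope_deltas` is now cached on the wrapped base model (`model.model`) rather than on the generation head. A small hedged helper, mirroring the `getattr(model, "model", model)` pattern used elsewhere in this diff, resolves the attribute regardless of which class is in hand:
def get_rope_deltas(model):
    """Return the cached rope offsets whether `model` is the headed class or the base model."""
    # Headed classes (e.g. *ForConditionalGeneration) wrap the multimodal base model
    # as `model.model`; the base class carries `rope_deltas` directly.
    base = getattr(model, "model", model)
    return base.rope_deltas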
@ -320,6 +325,29 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
def test_save_load_fast_init_from_base(self):
pass
# The multimodal base model's embeds will not match its ids, due to pixel values. We can't change the base test
# because in some models `pixel_values` are required. Will be fixed when we add support for merging `embeds+pixels`.
# TODO: @raushan
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@require_torch
class Qwen2VLIntegrationTest(unittest.TestCase):

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch VideoLlava model."""
import copy
import unittest
import numpy as np
@ -23,6 +24,7 @@ from parameterized import parameterized
from transformers import (
VideoLlavaConfig,
VideoLlavaForConditionalGeneration,
VideoLlavaModel,
VideoLlavaProcessor,
is_torch_available,
is_vision_available,
@ -190,7 +192,14 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
Model tester for `VideoLlavaForConditionalGeneration`.
"""
all_model_classes = (VideoLlavaForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
VideoLlavaModel,
VideoLlavaForConditionalGeneration,
)
if is_torch_available()
else ()
)
fx_compatible = False
test_pruning = False
test_resize_embeddings = True
@ -235,46 +244,49 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
def test_mixed_input(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
curr_inputs = copy.deepcopy(inputs)
model = model_class(config).to(torch_device).eval()
# test that the forward does not fail
with torch.no_grad():
_ = model(**inputs)
_ = model(**curr_inputs)
# if we remove some images from the inputs, leaving only one,
# an image-count mismatch error should be raised
inputs["pixel_values_images"] = inputs["pixel_values_images"][:1]
curr_inputs["pixel_values_images"] = curr_inputs["pixel_values_images"][:1]
with self.assertRaises(ValueError):
_ = model(**inputs)
_ = model(**curr_inputs)
def test_video_only_input(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
curr_inputs = copy.deepcopy(inputs)
model = model_class(config).to(torch_device).eval()
# replace image token id with dummy id
# An error will be raised because the number of image tokens does not match the number of image embeds
inputs["input_ids"][:, : self.model_tester.num_image_tokens] = 2
curr_inputs["input_ids"][:, : self.model_tester.num_image_tokens] = 2
with self.assertRaises(ValueError):
_ = model(**inputs)
_ = model(**curr_inputs)
inputs["pixel_values_images"] = None
_ = model(**inputs)
curr_inputs["pixel_values_images"] = None
_ = model(**curr_inputs)
def test_image_only_input(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
curr_inputs = copy.deepcopy(inputs)
model = model_class(config).to(torch_device).eval()
# set a dummy id that is not the video token id
# An error will be raised because the number of video tokens does not match the number of video embeds
inputs["input_ids"][
curr_inputs["input_ids"][
:,
self.model_tester.num_image_tokens : self.model_tester.num_image_tokens
+ self.model_tester.num_video_tokens,
] = 2
with self.assertRaises(ValueError):
_ = model(**inputs)
_ = model(**curr_inputs)
inputs["pixel_values_videos"] = None
_ = model(**inputs)
curr_inputs["pixel_values_videos"] = None
_ = model(**curr_inputs)
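These tests now deep-copy the shared inputs before mutating them, so in-place tweaks made for one model class do not leak into the next iteration of the loop over `all_model_classes`. A toy illustration of why the copy matters (the dictionary values are stand-ins, not real tensors):
import copy

shared_inputs = {"input_ids": [[1, 2, 3]], "pixel_values_images": [[0.0]]}  # toy stand-ins
for _ in range(2):  # stands in for the loop over model classes
    curr_inputs = copy.deepcopy(shared_inputs)  # mutate a private copy only
    curr_inputs["pixel_values_images"] = None   # per-iteration tweak stays local
assert shared_inputs["pixel_values_images"] is not None  # the shared dict stays pristine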
def test_batching_equivalence(self):
def recursive_check(batched_object, single_row_object, model_name, key):
@ -386,16 +398,17 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications
curr_input_dict = copy.deepcopy(input_dict)
_ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
input_dict["pixel_values_images"] = input_dict["pixel_values_images"][-1:, ...]
curr_input_dict["pixel_values_images"] = curr_input_dict["pixel_values_images"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)
_ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values_images"][:1]
input_ids = curr_input_dict["input_ids"][:1]
pixel_values = curr_input_dict["pixel_values_images"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@ -429,7 +442,8 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
model = model_class(config).to(torch_device)
# We should have the right number of input features,
# and should be able to run a forward pass without exploding
assert model.multi_modal_projector.linear_1.in_features == expected_features
base_model = getattr(model, "model", model)
assert base_model.multi_modal_projector.linear_1.in_features == expected_features
model(**input_dict)

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch VipLlava model."""
import copy
import unittest
import requests
@ -22,6 +23,7 @@ from transformers import (
AutoProcessor,
VipLlavaConfig,
VipLlavaForConditionalGeneration,
VipLlavaModel,
is_torch_available,
is_vision_available,
)
@ -165,7 +167,14 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
Model tester for `VipLlavaForConditionalGeneration`.
"""
all_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (
(
VipLlavaModel,
VipLlavaForConditionalGeneration,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = {"image-text-to-text": VipLlavaForConditionalGeneration} if is_torch_available() else {}
fx_compatible = False
test_pruning = False
@ -236,16 +245,17 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications
curr_input_dict = copy.deepcopy(input_dict) # will be modified in-place below
_ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)
_ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:1]
input_ids = curr_input_dict["input_ids"][:1]
pixel_values = curr_input_dict["pixel_values"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@ -284,7 +294,8 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
model = model_class(config).to(torch_device)
# We should have the right number of input features,
# and should be able to run a forward pass without exploding
assert model.multi_modal_projector.linear_1.in_features == expected_features
base_model = getattr(model, "model", model)
assert base_model.multi_modal_projector.linear_1.in_features == expected_features
model(**input_dict)
@unittest.skip(
@ -311,6 +322,10 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("LLaVA vision backbones doesn't support flex attention yet")
def test_flex_attention_with_grads(self):
pass
@require_torch
class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@ -3494,8 +3494,8 @@ class ModelTesterMixin:
vision_model_name = [name for name in vision_model_names if hasattr(model_sdpa, name)][0]
language_model_name = [name for name in language_model_names if hasattr(model_sdpa, name)][0]
vision_model_sdpa = getattr(model, vision_model_name)
language_model_sdpa = getattr(model, language_model_name)
vision_model_sdpa = getattr(model_sdpa, vision_model_name)
language_model_sdpa = getattr(model_sdpa, language_model_name)
text_attn = "sdpa" if language_model_sdpa._supports_sdpa else "eager"
vision_attn = "sdpa" if vision_model_sdpa._supports_sdpa else "eager"
@ -4489,7 +4489,8 @@ class ModelTesterMixin:
@require_torch_gpu
def test_flex_attention_with_grads(self):
for model_class in self.all_model_classes:
if not model_class._supports_flex_attn:
# TODO: raushan, fix for composite models after making VLMs support new attn API
if not model_class._supports_flex_attn or self._is_composite:
self.skipTest(reason="This model does not support flex attention")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config._attn_implementation = "flex_attention"

View File

@ -378,8 +378,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
"rope_theta",
"partial_rotary_factor",
"pretraining_tp",
"boi_token_index",
"eoi_token_index",
"boi_token_id",
"eoi_token_id",
]
attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]

View File

@ -156,6 +156,8 @@ IGNORE_NON_TESTED = (
"Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests
"Emu3VQVAE", # Building part of bigger (tested) model
"Emu3TextModel", # Building part of bigger (tested) model
"Qwen2VLTextModel", # Building part of bigger (tested) model
"Qwen2_5_VLTextModel", # Building part of bigger (tested) model
"InternVLVisionModel", # Building part of bigger (tested) model
"JanusVisionModel", # Building part of bigger (tested) model
"TimesFmModel", # Building part of bigger (tested) model