diff --git a/docs/source/en/model_doc/aria.md b/docs/source/en/model_doc/aria.md
index 7b58f59cab7..89cd26db649 100644
--- a/docs/source/en/model_doc/aria.md
+++ b/docs/source/en/model_doc/aria.md
@@ -102,6 +102,10 @@ response = processor.decode(output_ids, skip_special_tokens=True)
[[autodoc]] AriaTextModel
+## AriaModel
+
+[[autodoc]] AriaModel
+
## AriaTextForCausalLM
[[autodoc]] AriaTextForCausalLM
diff --git a/docs/source/en/model_doc/aya_vision.md b/docs/source/en/model_doc/aya_vision.md
index 17daf494920..f2a82089506 100644
--- a/docs/source/en/model_doc/aya_vision.md
+++ b/docs/source/en/model_doc/aya_vision.md
@@ -237,6 +237,10 @@ for i, output in enumerate(batch_outputs):
[[autodoc]] AyaVisionConfig
+## AyaVisionModel
+
+[[autodoc]] AyaVisionModel
+
## AyaVisionForConditionalGeneration
[[autodoc]] AyaVisionForConditionalGeneration
diff --git a/docs/source/en/model_doc/emu3.md b/docs/source/en/model_doc/emu3.md
index 4ac7d0b0c4f..20b8a5e1cdb 100644
--- a/docs/source/en/model_doc/emu3.md
+++ b/docs/source/en/model_doc/emu3.md
@@ -174,6 +174,10 @@ for i, image in enumerate(images['pixel_values']):
[[autodoc]] Emu3TextModel
- forward
+## Emu3Model
+
+[[autodoc]] Emu3Model
+
## Emu3ForCausalLM
[[autodoc]] Emu3ForCausalLM
diff --git a/docs/source/en/model_doc/fuyu.md b/docs/source/en/model_doc/fuyu.md
index c0ea89ad19f..60ae9efdf3f 100644
--- a/docs/source/en/model_doc/fuyu.md
+++ b/docs/source/en/model_doc/fuyu.md
@@ -103,6 +103,10 @@ The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece.
[[autodoc]] FuyuConfig
+## FuyuModel
+
+[[autodoc]] FuyuModel
+
## FuyuForCausalLM
[[autodoc]] FuyuForCausalLM
diff --git a/docs/source/en/model_doc/gemma3.md b/docs/source/en/model_doc/gemma3.md
index f3864012c84..8372fd9ed15 100644
--- a/docs/source/en/model_doc/gemma3.md
+++ b/docs/source/en/model_doc/gemma3.md
@@ -254,6 +254,10 @@ visualizer("
What is shown in this image?")
[[autodoc]] Gemma3TextModel
- forward
+## Gemma3Model
+
+[[autodoc]] Gemma3Model
+
## Gemma3ForCausalLM
[[autodoc]] Gemma3ForCausalLM
diff --git a/docs/source/en/model_doc/got_ocr2.md b/docs/source/en/model_doc/got_ocr2.md
index bb14cdbcf84..c7a73659e88 100644
--- a/docs/source/en/model_doc/got_ocr2.md
+++ b/docs/source/en/model_doc/got_ocr2.md
@@ -277,6 +277,10 @@ alt="drawing" width="600"/>
[[autodoc]] GotOcr2Processor
+## GotOcr2Model
+
+[[autodoc]] GotOcr2Model
+
## GotOcr2ForConditionalGeneration
[[autodoc]] GotOcr2ForConditionalGeneration
diff --git a/docs/source/en/model_doc/instructblip.md b/docs/source/en/model_doc/instructblip.md
index 4f2feb015f1..944e8888fcf 100644
--- a/docs/source/en/model_doc/instructblip.md
+++ b/docs/source/en/model_doc/instructblip.md
@@ -69,6 +69,10 @@ The attributes can be obtained from model config, as `model.config.num_query_tok
[[autodoc]] InstructBlipQFormerModel
- forward
+## InstructBlipModel
+
+[[autodoc]] InstructBlipModel
+
## InstructBlipForConditionalGeneration
[[autodoc]] InstructBlipForConditionalGeneration
diff --git a/docs/source/en/model_doc/instructblipvideo.md b/docs/source/en/model_doc/instructblipvideo.md
index c26562a8530..c021a4c7afa 100644
--- a/docs/source/en/model_doc/instructblipvideo.md
+++ b/docs/source/en/model_doc/instructblipvideo.md
@@ -73,6 +73,10 @@ The attributes can be obtained from model config, as `model.config.num_query_tok
[[autodoc]] InstructBlipVideoQFormerModel
- forward
+## InstructBlipVideoModel
+
+[[autodoc]] InstructBlipVideoModel
+    - forward
## InstructBlipVideoForConditionalGeneration
[[autodoc]] InstructBlipVideoForConditionalGeneration
diff --git a/docs/source/en/model_doc/internvl.md b/docs/source/en/model_doc/internvl.md
index 4ac56c85377..8a19726e3e0 100644
--- a/docs/source/en/model_doc/internvl.md
+++ b/docs/source/en/model_doc/internvl.md
@@ -340,6 +340,11 @@ This example showcases how to handle a batch of chat conversations with interlea
[[autodoc]] InternVLVisionModel
- forward
+## InternVLModel
+
+[[autodoc]] InternVLModel
+ - forward
+
## InternVLForConditionalGeneration
[[autodoc]] InternVLForConditionalGeneration
diff --git a/docs/source/en/model_doc/llava.md b/docs/source/en/model_doc/llava.md
index 79033ec5a18..dcf2cd2f3f6 100644
--- a/docs/source/en/model_doc/llava.md
+++ b/docs/source/en/model_doc/llava.md
@@ -256,6 +256,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] LlavaProcessor
+## LlavaModel
+
+[[autodoc]] LlavaModel
+
## LlavaForConditionalGeneration
[[autodoc]] LlavaForConditionalGeneration
diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md
index 7d85ab8b696..2af882b6118 100644
--- a/docs/source/en/model_doc/llava_next.md
+++ b/docs/source/en/model_doc/llava_next.md
@@ -315,6 +315,10 @@ model = AutoModelForImageTextToText.from_pretrained(
[[autodoc]] LlavaNextProcessor
+## LlavaNextModel
+
+[[autodoc]] LlavaNextModel
+
## LlavaNextForConditionalGeneration
[[autodoc]] LlavaNextForConditionalGeneration
diff --git a/docs/source/en/model_doc/llava_next_video.md b/docs/source/en/model_doc/llava_next_video.md
index e435861cbe2..d3306256846 100644
--- a/docs/source/en/model_doc/llava_next_video.md
+++ b/docs/source/en/model_doc/llava_next_video.md
@@ -262,6 +262,10 @@ model = LlavaNextVideoForConditionalGeneration.from_pretrained(
[[autodoc]] LlavaNextVideoImageProcessor
+## LlavaNextVideoModel
+
+[[autodoc]] LlavaNextVideoModel
+
## LlavaNextVideoForConditionalGeneration
[[autodoc]] LlavaNextVideoForConditionalGeneration
diff --git a/docs/source/en/model_doc/llava_onevision.md b/docs/source/en/model_doc/llava_onevision.md
index 77fe807d46d..a00dd5a0e12 100644
--- a/docs/source/en/model_doc/llava_onevision.md
+++ b/docs/source/en/model_doc/llava_onevision.md
@@ -313,6 +313,10 @@ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
[[autodoc]] LlavaOnevisionVideoProcessor
+## LlavaOnevisionModel
+
+[[autodoc]] LlavaOnevisionModel
+
## LlavaOnevisionForConditionalGeneration
[[autodoc]] LlavaOnevisionForConditionalGeneration
diff --git a/docs/source/en/model_doc/mistral3.md b/docs/source/en/model_doc/mistral3.md
index 4efdb641559..8eedb5de6bd 100644
--- a/docs/source/en/model_doc/mistral3.md
+++ b/docs/source/en/model_doc/mistral3.md
@@ -227,6 +227,9 @@ This example also how to use `BitsAndBytes` to load the model in 4bit quantizati
[[autodoc]] Mistral3Config
+## Mistral3Model
+
+[[autodoc]] Mistral3Model
## Mistral3ForConditionalGeneration
diff --git a/docs/source/en/model_doc/mllama.md b/docs/source/en/model_doc/mllama.md
index 77f5e211f17..cdd4da240af 100644
--- a/docs/source/en/model_doc/mllama.md
+++ b/docs/source/en/model_doc/mllama.md
@@ -130,6 +130,10 @@ print(processor.decode(output[0], skip_special_tokens=True))
[[autodoc]] MllamaTextModel
- forward
+## MllamaModel
+
+[[autodoc]] MllamaModel
+
## MllamaForCausalLM
[[autodoc]] MllamaForCausalLM
diff --git a/docs/source/en/model_doc/paligemma.md b/docs/source/en/model_doc/paligemma.md
index fa119a5f836..a0a0c1b714f 100644
--- a/docs/source/en/model_doc/paligemma.md
+++ b/docs/source/en/model_doc/paligemma.md
@@ -174,6 +174,10 @@ visualizer("
What is in this image?")
[[autodoc]] PaliGemmaProcessor
+## PaliGemmaModel
+
+[[autodoc]] PaliGemmaModel
+
## PaliGemmaForConditionalGeneration
[[autodoc]] PaliGemmaForConditionalGeneration
diff --git a/docs/source/en/model_doc/qwen2_5_vl.md b/docs/source/en/model_doc/qwen2_5_vl.md
index c414b41faac..57b88d1b8da 100644
--- a/docs/source/en/model_doc/qwen2_5_vl.md
+++ b/docs/source/en/model_doc/qwen2_5_vl.md
@@ -240,6 +240,10 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
[[autodoc]] Qwen2_5_VLProcessor
+## Qwen2_5_VLTextModel
+
+[[autodoc]] Qwen2_5_VLTextModel
+ - forward
## Qwen2_5_VLModel
diff --git a/docs/source/en/model_doc/qwen2_vl.md b/docs/source/en/model_doc/qwen2_vl.md
index 3d1845b6015..7fef4e2fdbd 100644
--- a/docs/source/en/model_doc/qwen2_vl.md
+++ b/docs/source/en/model_doc/qwen2_vl.md
@@ -296,6 +296,11 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
[[autodoc]] Qwen2VLProcessor
+## Qwen2VLTextModel
+
+[[autodoc]] Qwen2VLTextModel
+ - forward
+
## Qwen2VLModel
[[autodoc]] Qwen2VLModel
diff --git a/docs/source/en/model_doc/video_llava.md b/docs/source/en/model_doc/video_llava.md
index f407b4dc5eb..ca1a06d4cdc 100644
--- a/docs/source/en/model_doc/video_llava.md
+++ b/docs/source/en/model_doc/video_llava.md
@@ -215,6 +215,10 @@ model = VideoLlavaForConditionalGeneration.from_pretrained(
[[autodoc]] VideoLlavaProcessor
+## VideoLlavaModel
+
+[[autodoc]] VideoLlavaModel
+
## VideoLlavaForConditionalGeneration
[[autodoc]] VideoLlavaForConditionalGeneration
diff --git a/docs/source/en/model_doc/vipllava.md b/docs/source/en/model_doc/vipllava.md
index 9438893dfb1..8edf1540268 100644
--- a/docs/source/en/model_doc/vipllava.md
+++ b/docs/source/en/model_doc/vipllava.md
@@ -101,6 +101,10 @@ A chat between a curious human and an artificial intelligence assistant. The ass
[[autodoc]] VipLlavaConfig
+## VipLlavaModel
+
+[[autodoc]] VipLlavaModel
+
## VipLlavaForConditionalGeneration
[[autodoc]] VipLlavaForConditionalGeneration
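The `## <Name>Model` sections added above document the new head-less multimodal backbones that the existing `*ForConditionalGeneration` classes now wrap (see the Aria refactor later in this diff). A hedged sketch of what that split looks like from the user side, assuming Llava follows the same `model` + `lm_head` layout as Aria; the tiny config values are arbitrary and only exist so the model builds instantly:

```python
from transformers import CLIPVisionConfig, LlamaConfig, LlavaConfig, LlavaForConditionalGeneration

# Tiny, arbitrary config values so the model builds instantly (illustration only).
config = LlavaConfig(
    vision_config=CLIPVisionConfig(
        hidden_size=32, intermediate_size=64, num_hidden_layers=2, num_attention_heads=4,
        image_size=32, patch_size=8,
    ),
    text_config=LlamaConfig(
        hidden_size=32, intermediate_size=64, num_hidden_layers=2, num_attention_heads=4, vocab_size=64,
    ),
)

model = LlavaForConditionalGeneration(config)
print(type(model.model).__name__)    # expected "LlavaModel": the backbone documented by the new doc section
print(type(model.lm_head).__name__)  # "Linear": the generation head stays on the wrapper class
```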
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index be84fdee54b..344990c2c9f 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -216,6 +216,28 @@ TORCH_INIT_FUNCTIONS = {
"kaiming_normal": nn.init.kaiming_normal,
}
+# DO NOT MODIFY, KEPT FOR BC ONLY
+VLMS = [
+ "aria",
+ "aya_vision",
+ "emu3",
+ "fuyu",
+ "got_ocr2",
+ "gemma3",
+ "internvl",
+ "llava",
+ "llava_next",
+ "llava_next_video",
+ "llava_onevision",
+ "mistral3",
+ "mllama",
+ "paligemma",
+ "qwen2_vl",
+    "qwen2_5_vl",
+ "video_llava",
+ "vipllava",
+]
+
@contextmanager
def no_init_weights():
@@ -1778,6 +1800,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
main_input_name = "input_ids"
model_tags = None
+ _checkpoint_conversion_mapping = {} # used for BC support in VLMs, not meant to be used by new models
+
_auto_class = None
_no_split_modules = None
_skip_keys_device_placement = None
@@ -3484,6 +3508,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
module_map[name + f".{key}"] = module
state_dict = model_to_save.state_dict()
+ if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
+ reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
+
+ original_state_dict = {}
+ for key, value in state_dict.items():
+ for pattern, replacement in reverse_key_mapping.items():
+                    replacement = replacement.lstrip("^")  # strip off un-needed chars and patterns
+                    replacement = re.sub(r"\(.*?\)", "", replacement)  # drop regex group syntax from the replacement text
+                    key, n_replace = re.subn(pattern, replacement, key)
+                    # Early exit of the loop
+                    if n_replace > 0:
+                        break
+                original_state_dict[key] = value
+ state_dict = original_state_dict
+
# Translate state_dict from smp to hf if saving with smp >= 1.10
if IS_SAGEMAKER_MP_POST_1_10:
for smp_to_hf, _ in smp.state.module_manager.translate_functions:
@@ -4071,7 +4110,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
- key_mapping = kwargs.pop("key_mapping", None)
+
+ # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
+ if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
+ key_mapping = kwargs.pop("key_mapping", cls._checkpoint_conversion_mapping)
+ else:
+ key_mapping = kwargs.pop("key_mapping", None)
+
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
_ = kwargs.pop("trust_remote_code", None)
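Taken together, the two hunks above implement the backward-compatibility scheme: `from_pretrained` applies the class-level `_checkpoint_conversion_mapping` as regex renames when loading legacy checkpoints, and `save_pretrained` applies the reversed mapping so files on disk keep the legacy key layout. A minimal standalone sketch of that round trip, using the Aria mapping that appears later in this diff (helper names are illustrative; the group-stripping `re.sub` from the real code is a no-op here because this mapping has no capturing groups):

```python
import re

# Mapping copied from AriaForConditionalGeneration below; keys are regex patterns.
checkpoint_conversion_mapping = {
    "^language_model.model": "model.language_model",
    "^vision_tower": "model.vision_tower",
    "^multi_modal_projector": "model.multi_modal_projector",
    "^language_model.lm_head": "lm_head",
}


def rename_for_loading(keys):
    """Legacy checkpoint keys -> new composite layout (what from_pretrained does)."""
    renamed = []
    for key in keys:
        for pattern, replacement in checkpoint_conversion_mapping.items():
            key, n_replace = re.subn(pattern, replacement, key)
            if n_replace > 0:
                break  # first matching pattern wins, mirroring the early exit above
        renamed.append(key)
    return renamed


def rename_for_saving(keys):
    """New layout -> legacy key names (what save_pretrained does with the reversed mapping)."""
    reverse_mapping = {v: k for k, v in checkpoint_conversion_mapping.items()}
    renamed = []
    for key in keys:
        for pattern, replacement in reverse_mapping.items():
            replacement = replacement.lstrip("^")  # the replacement is plain text, not a regex
            key, n_replace = re.subn(pattern, replacement, key)
            if n_replace > 0:
                break
        renamed.append(key)
    return renamed


legacy = ["language_model.model.layers.0.self_attn.q_proj.weight", "language_model.lm_head.weight"]
new = rename_for_loading(legacy)
# ['model.language_model.layers.0.self_attn.q_proj.weight', 'lm_head.weight']
assert rename_for_saving(new) == legacy  # saving reproduces the legacy key layout
```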
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index f6a1861d308..31e0980fdeb 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -42,7 +42,7 @@ from ...utils import (
replace_return_docstrings,
)
from ...utils.import_utils import is_torch_available
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_aria import AriaConfig, AriaTextConfig
@@ -58,7 +58,9 @@ if is_torch_flex_attn_available():
logger = logging.get_logger(__name__)
-_CONFIG_FOR_DOC = "AriaTextConfig"
+
+
+_CONFIG_FOR_DOC = "AriaConfig"
@use_kernel_forward_from_hub("RMSNorm")
@@ -659,7 +661,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
"""
- config_class = AriaConfig
+ config_class = AriaTextConfig
base_model_prefix = "model"
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
supports_gradient_checkpointing = True
@@ -706,8 +708,8 @@ ARIA_TEXT_START_DOCSTRING = r"""
ARIA_TEXT_START_DOCSTRING,
)
class AriaPreTrainedModel(PreTrainedModel):
- config_class = AriaTextConfig
- base_model_prefix = "model"
+ config_class = AriaConfig
+ base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["AriaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
@@ -1097,6 +1099,9 @@ class AriaTextModel(AriaTextPreTrainedModel):
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+_CONFIG_FOR_TEXT_DOC = "AriaTextConfig"
+
+
class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
"""
Aria model for causal language modeling tasks.
@@ -1112,7 +1117,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
- config_class = AriaTextConfig
def __init__(self, config: AriaTextConfig):
super().__init__(config)
@@ -1141,9 +1145,8 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.model
- @can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_TEXT_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -1255,7 +1258,7 @@ class AriaCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@@ -1267,6 +1270,39 @@ class AriaCausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None
+@dataclass
+class AriaModelOutputWithPast(BaseModelOutputWithPast):
+ """
+ Base class for Aria outputs, with hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
ARIA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor`, *optional*):
@@ -1320,30 +1356,131 @@ ARIA_START_DOCSTRING = r"""
@add_start_docstrings(
- """Aria model for conditional generation tasks.
-
- This model combines a vision tower, a multi-modal projector, and a language model
- to perform tasks that involve both image and text inputs.""",
+ """The Aria model which consists of a vision backbone and a language model, without a language modeling head.""",
ARIA_START_DOCSTRING,
)
-class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
- config_class = AriaConfig
- _supports_flash_attn_2 = False
- _supports_flex_attn = False
- _supports_sdpa = False
- _tied_weights_keys = ["language_model.lm_head.weight"]
+class AriaModel(AriaPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: AriaConfig):
super().__init__(config)
-
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = AriaProjector(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
- self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
+ self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: torch.FloatTensor = None,
+ vision_feature_layer: int = -1,
+ ):
+ """
+ Obtains image last hidden states from the vision tower and apply multimodal projection.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+            pixel_mask (`torch.FloatTensor`, *optional*):
+ The tensors corresponding to the input image mask.
+ vision_feature_layer (`Union[int, List[int]]`):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
+ image_outputs = self.vision_tower(
+ pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
+ )
+ image_attn_mask = None
+ if patch_attention_mask is not None:
+ flattened_mask = patch_attention_mask.flatten(1)
+ image_attn_mask = torch.logical_not(flattened_mask)
+
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+ image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
+ return image_features
+
+ @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ pixel_mask: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, AriaModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # 2. Merge text and images
+ if pixel_values is not None and inputs_embeds.shape[1] != 1:
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ else:
+ image_embeds = input_ids == self.config.image_token_id
+ special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ vision_feature_layer=self.config.vision_feature_layer,
+ )
+ n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
+ n_image_features = n_images * n_features_per_image
+ if n_image_tokens != n_image_features:
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ )
+
+ output = AriaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
def _create_patch_attention_mask(self, pixel_mask):
if pixel_mask is None:
return None
@@ -1360,51 +1497,61 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
)
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+@add_start_docstrings(
+ """Aria model for conditional generation tasks.
+ This model combines a vision tower, a multi-modal projector, and a language model
+ to perform tasks that involve both image and text inputs.""",
+ ARIA_START_DOCSTRING,
+)
+class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: AriaConfig):
+ super().__init__(config)
+ self.model = AriaModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
def get_input_embeddings(self):
- return self.language_model.get_input_embeddings()
+ return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
- self.language_model.set_input_embeddings(value)
+ self.model.set_input_embeddings(value)
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
+ self.lm_head = new_embeddings
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
+    # Make modules available through the conditional generation class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
- def get_decoder(self):
- return self.language_model.get_decoder()
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
- def get_image_features(
- self,
- pixel_values: torch.FloatTensor,
- pixel_mask: Optional[torch.FloatTensor] = None,
- vision_feature_layer: int = -1,
- ):
- patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
- image_outputs = self.vision_tower(
- pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
- )
- image_attn_mask = None
- if patch_attention_mask is not None:
- flattened_mask = patch_attention_mask.flatten(1)
- image_attn_mask = torch.logical_not(flattened_mask)
-
- selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
- image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
- return image_features
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
+ @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- pixel_mask: Optional[torch.LongTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ pixel_mask: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -1413,10 +1560,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
- ) -> AriaCausalLMOutputWithPast:
+ ) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1482,37 +1630,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- # 2. Merge text and images
- if pixel_values is not None and inputs_embeds.shape[1] != 1:
- if input_ids is None:
- special_image_mask = inputs_embeds == self.get_input_embeddings()(
- torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
- else:
- image_embeds = input_ids == self.config.image_token_id
- special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
- image_features = self.get_image_features(
- pixel_values=pixel_values,
- pixel_mask=pixel_mask,
- vision_feature_layer=self.config.vision_feature_layer,
- )
- n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
- n_image_features = n_images * n_features_per_image
- if n_image_tokens != n_image_features:
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
-
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs: CausalLMOutputWithPast = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -1520,11 +1643,14 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- logits_to_keep=logits_to_keep,
+ return_dict=return_dict,
cache_position=cache_position,
)
- logits = outputs.logits
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
@@ -1552,7 +1678,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
logits_to_keep=None,
**kwargs,
):
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1570,11 +1696,67 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
return model_inputs
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
__all__ = [
"AriaForConditionalGeneration",
"AriaPreTrainedModel",
"AriaTextPreTrainedModel",
"AriaTextModel",
+ "AriaModel",
"AriaTextForCausalLM",
]
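The Aria changes above replace the monolithic generation class with a composite design: `AriaModel` owns the vision tower, projector and head-less language model, while `AriaForConditionalGeneration` wraps it, adds `lm_head`, and re-exposes the submodules as read-only properties so the old attribute paths keep working. A toy sketch of that wrapper-plus-properties pattern, with plain `nn.Linear` layers standing in for the real Aria submodules:

```python
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Stand-in for AriaModel: vision tower + projector + head-less language model."""

    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(4, 8)
        self.multi_modal_projector = nn.Linear(8, 8)
        self.language_model = nn.Linear(8, 8)


class TinyForConditionalGeneration(nn.Module):
    """Stand-in for AriaForConditionalGeneration: wraps the backbone and adds the lm_head."""

    def __init__(self):
        super().__init__()
        self.model = TinyBackbone()
        self.lm_head = nn.Linear(8, 16, bias=False)

    # Read-only BC properties, mirroring the ones added in the diff above
    @property
    def language_model(self):
        return self.model.language_model

    @property
    def vision_tower(self):
        return self.model.vision_tower

    @property
    def multi_modal_projector(self):
        return self.model.multi_modal_projector


m = TinyForConditionalGeneration()
assert m.language_model is m.model.language_model  # old attribute path still resolves
assert any(k.startswith("model.language_model.") for k in m.state_dict())  # new checkpoint key layout
```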
diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py
index e4d063f7827..a42c7227772 100644
--- a/src/transformers/models/aria/modular_aria.py
+++ b/src/transformers/models/aria/modular_aria.py
@@ -18,7 +18,6 @@ import numpy as np
from ...activations import ACT2FN
from ...configuration_utils import PretrainedConfig
-from ...generation import GenerationMixin
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
@@ -49,7 +48,7 @@ from ...utils import (
replace_return_docstrings,
)
from ...utils.import_utils import is_torch_available
-from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
+from ..auto import CONFIG_MAPPING, AutoConfig, AutoTokenizer
from ..llama.configuration_llama import LlamaConfig
from ..llama.modeling_llama import (
LlamaDecoderLayer,
@@ -59,7 +58,12 @@ from ..llama.modeling_llama import (
LlamaPreTrainedModel,
LlamaRMSNorm,
)
-from ..llava.modeling_llava import LlavaCausalLMOutputWithPast
+from ..llava.modeling_llava import (
+ LlavaCausalLMOutputWithPast,
+ LlavaForConditionalGeneration,
+ LlavaModel,
+ LlavaModelOutputWithPast,
+)
from ..llava_next.image_processing_llava_next import divide_to_patches
@@ -70,6 +74,11 @@ if is_torch_available():
from torch import nn
+_CONFIG_FOR_DOC = "AriaConfig"
+_CONFIG_FOR_TEXT_DOC = "AriaTextConfig"
+ARIA_TEXT_INPUTS_DOCSTRING = None
+
+
def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
"""
Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
@@ -1223,7 +1232,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
"""
- config_class = AriaConfig
+ config_class = AriaTextConfig
base_model_prefix = "model"
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
supports_gradient_checkpointing = True
@@ -1249,6 +1258,8 @@ class AriaTextPreTrainedModel(PreTrainedModel):
class AriaPreTrainedModel(LlamaPreTrainedModel):
+ config_class = AriaConfig
+ base_model_prefix = ""
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = False
@@ -1292,7 +1303,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
"""
_tied_weights_keys = ["lm_head.weight"]
- config_class = AriaTextConfig
def __init__(self, config: AriaTextConfig):
super().__init__(config)
@@ -1303,11 +1313,20 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
# Initialize weights and apply final processing
self.post_init()
+ @add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_TEXT_DOC)
+ def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
+class AriaModelOutputWithPast(LlavaModelOutputWithPast):
+ pass
+
+
ARIA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor`, *optional*):
@@ -1360,30 +1379,10 @@ ARIA_START_DOCSTRING = r"""
"""
-@add_start_docstrings(
- """Aria model for conditional generation tasks.
-
- This model combines a vision tower, a multi-modal projector, and a language model
- to perform tasks that involve both image and text inputs.""",
- ARIA_START_DOCSTRING,
-)
-class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
- config_class = AriaConfig
- _supports_flash_attn_2 = False
- _supports_flex_attn = False
- _supports_sdpa = False
- _tied_weights_keys = ["language_model.lm_head.weight"]
-
+class AriaModel(LlavaModel):
def __init__(self, config: AriaConfig):
super().__init__(config)
-
- self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = AriaProjector(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
- self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
- self.post_init()
def _create_patch_attention_mask(self, pixel_mask):
if pixel_mask is None:
@@ -1401,30 +1400,27 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
)
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
- def get_input_embeddings(self):
- return self.language_model.get_input_embeddings()
-
- def set_input_embeddings(self, value):
- self.language_model.set_input_embeddings(value)
-
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def get_image_features(
self,
pixel_values: torch.FloatTensor,
- pixel_mask: Optional[torch.FloatTensor] = None,
+ pixel_mask: torch.FloatTensor = None,
vision_feature_layer: int = -1,
):
+ """
+ Obtains image last hidden states from the vision tower and apply multimodal projection.
+
+ Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+        pixel_mask (`torch.FloatTensor`, *optional*):
+ The tensors corresponding to the input image mask.
+ vision_feature_layer (`Union[int, List[int]]`):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ Returns:
+        image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
@@ -1438,14 +1434,94 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
- @can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- pixel_mask: Optional[torch.LongTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ pixel_mask: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, AriaModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # 2. Merge text and images
+ if pixel_values is not None and inputs_embeds.shape[1] != 1:
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ else:
+ image_embeds = input_ids == self.config.image_token_id
+ special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ vision_feature_layer=self.config.vision_feature_layer,
+ )
+ n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
+ n_image_features = n_images * n_features_per_image
+ if n_image_tokens != n_image_features:
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ )
+
+ output = AriaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+@add_start_docstrings(
+ """Aria model for conditional generation tasks.
+ This model combines a vision tower, a multi-modal projector, and a language model
+ to perform tasks that involve both image and text inputs.""",
+ ARIA_START_DOCSTRING,
+)
+class AriaForConditionalGeneration(LlavaForConditionalGeneration):
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ pixel_mask: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -1454,10 +1530,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
- ) -> AriaCausalLMOutputWithPast:
+ ) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1523,37 +1600,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- # 2. Merge text and images
- if pixel_values is not None and inputs_embeds.shape[1] != 1:
- if input_ids is None:
- special_image_mask = inputs_embeds == self.get_input_embeddings()(
- torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
- else:
- image_embeds = input_ids == self.config.image_token_id
- special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
- image_features = self.get_image_features(
- pixel_values=pixel_values,
- pixel_mask=pixel_mask,
- vision_feature_layer=self.config.vision_feature_layer,
- )
- n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
- n_image_features = n_images * n_features_per_image
- if n_image_tokens != n_image_features:
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
-
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs: CausalLMOutputWithPast = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -1561,11 +1613,14 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- logits_to_keep=logits_to_keep,
+ return_dict=return_dict,
cache_position=cache_position,
)
- logits = outputs.logits
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
@@ -1593,7 +1648,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
logits_to_keep=None,
**kwargs,
):
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1621,5 +1676,6 @@ __all__ = [
"AriaPreTrainedModel",
"AriaTextPreTrainedModel",
"AriaTextModel",
+ "AriaModel",
"AriaTextForCausalLM",
]
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index b196a7718f7..bd7dff185b2 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -35,10 +35,11 @@ MODEL_MAPPING_NAMES = OrderedDict(
("albert", "AlbertModel"),
("align", "AlignModel"),
("altclip", "AltCLIPModel"),
- ("aria", "AriaForConditionalGeneration"),
+ ("aria", "AriaModel"),
("aria_text", "AriaTextModel"),
("audio-spectrogram-transformer", "ASTModel"),
("autoformer", "AutoformerModel"),
+ ("aya_vision", "AyaVisionModel"),
("bamba", "BambaModel"),
("bark", "BarkModel"),
("bart", "BartModel"),
@@ -108,6 +109,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("efficientformer", "EfficientFormerModel"),
("efficientnet", "EfficientNetModel"),
("electra", "ElectraModel"),
+ ("emu3", "Emu3Model"),
("encodec", "EncodecModel"),
("ernie", "ErnieModel"),
("ernie_m", "ErnieMModel"),
@@ -121,14 +123,16 @@ MODEL_MAPPING_NAMES = OrderedDict(
("focalnet", "FocalNetModel"),
("fsmt", "FSMTModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
+ ("fuyu", "FuyuModel"),
("gemma", "GemmaModel"),
("gemma2", "Gemma2Model"),
+ ("gemma3", "Gemma3Model"),
("gemma3_text", "Gemma3TextModel"),
("git", "GitModel"),
("glm", "GlmModel"),
("glm4", "Glm4Model"),
("glpn", "GLPNModel"),
- ("got_ocr2", "GotOcr2ForConditionalGeneration"),
+ ("got_ocr2", "GotOcr2Model"),
("gpt-sw3", "GPT2Model"),
("gpt2", "GPT2Model"),
("gpt_bigcode", "GPTBigCodeModel"),
@@ -156,6 +160,9 @@ MODEL_MAPPING_NAMES = OrderedDict(
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
+ ("instructblip", "InstructBlipModel"),
+ ("instructblipvideo", "InstructBlipVideoModel"),
+ ("internvl", "InternVLModel"),
("internvl_vision", "InternVLVisionModel"),
("jamba", "JambaModel"),
("janus", "JanusModel"),
@@ -170,6 +177,10 @@ MODEL_MAPPING_NAMES = OrderedDict(
("lilt", "LiltModel"),
("llama", "LlamaModel"),
("llama4", "Llama4ForConditionalGeneration"),
+ ("llava", "LlavaModel"),
+ ("llava_next", "LlavaNextModel"),
+ ("llava_next_video", "LlavaNextVideoModel"),
+ ("llava_onevision", "LlavaOnevisionModel"),
("longformer", "LongformerModel"),
("longt5", "LongT5Model"),
("luke", "LukeModel"),
@@ -189,8 +200,10 @@ MODEL_MAPPING_NAMES = OrderedDict(
("mgp-str", "MgpstrForSceneTextRecognition"),
("mimi", "MimiModel"),
("mistral", "MistralModel"),
+ ("mistral3", "Mistral3Model"),
("mixtral", "MixtralModel"),
("mlcd", "MLCDVisionModel"),
+ ("mllama", "MllamaModel"),
("mobilebert", "MobileBertModel"),
("mobilenet_v1", "MobileNetV1Model"),
("mobilenet_v2", "MobileNetV2Model"),
@@ -221,6 +234,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("opt", "OPTModel"),
("owlv2", "Owlv2Model"),
("owlvit", "OwlViTModel"),
+ ("paligemma", "PaliGemmaModel"),
("patchtsmixer", "PatchTSMixerModel"),
("patchtst", "PatchTSTModel"),
("pegasus", "PegasusModel"),
@@ -240,11 +254,11 @@ MODEL_MAPPING_NAMES = OrderedDict(
("qdqbert", "QDQBertModel"),
("qwen2", "Qwen2Model"),
("qwen2_5_vl", "Qwen2_5_VLModel"),
- ("qwen2_5_vl_text", "Qwen2_5_VLModel"),
+ ("qwen2_5_vl_text", "Qwen2_5_VLTextModel"),
("qwen2_audio_encoder", "Qwen2AudioEncoder"),
("qwen2_moe", "Qwen2MoeModel"),
("qwen2_vl", "Qwen2VLModel"),
- ("qwen2_vl_text", "Qwen2VLModel"),
+ ("qwen2_vl_text", "Qwen2VLTextModel"),
("qwen3", "Qwen3Model"),
("qwen3_moe", "Qwen3MoeModel"),
("recurrent_gemma", "RecurrentGemmaModel"),
@@ -306,8 +320,10 @@ MODEL_MAPPING_NAMES = OrderedDict(
("unispeech-sat", "UniSpeechSatModel"),
("univnet", "UnivNetModel"),
("van", "VanModel"),
+ ("video_llava", "VideoLlavaModel"),
("videomae", "VideoMAEModel"),
("vilt", "ViltModel"),
+ ("vipllava", "VipLlavaModel"),
("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
("visual_bert", "VisualBertModel"),
("vit", "ViTModel"),
@@ -879,6 +895,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("llama4", "Llama4ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
+ ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
("mistral3", "Mistral3ForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"),
diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py
index 45c2ab66e3e..042ff0c05a7 100644
--- a/src/transformers/models/aya_vision/modeling_aya_vision.py
+++ b/src/transformers/models/aya_vision/modeling_aya_vision.py
@@ -27,15 +27,16 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
-from ...modeling_outputs import ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ can_return_tuple,
is_torchdynamo_compiling,
replace_return_docstrings,
)
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_aya_vision import AyaVisionConfig
@@ -115,9 +116,8 @@ AYA_VISION_START_DOCSTRING = r"""
)
class AyaVisionPreTrainedModel(PreTrainedModel):
config_class = AyaVisionConfig
- base_model_prefix = "model"
+ base_model_prefix = ""
supports_gradient_checkpointing = True
- _no_split_modules = ["AyaVisionVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@@ -169,7 +169,7 @@ class AyaVisionCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@@ -181,7 +181,40 @@ class AyaVisionCausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None
-AYA_VISION_INPUTS_DOCSTRING = """
+@dataclass
+class AyaVisionModelOutputWithPast(BaseModelOutputWithPast):
+ """
+ Base class for AyaVision outputs, with hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+AYA_VISION_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -193,8 +226,8 @@ AYA_VISION_INPUTS_DOCSTRING = """
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
- [`AutoImageProcessor`]. See [`GotOcr2ImageProcessor.__call__`] for details. [`AyaVisionProcessor`] uses
- [`GotOcr2ImageProcessor`] for processing images.
+        [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`AyaVisionProcessor`] uses
+ [`CLIPImageProcessor`] for processing images).
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -259,23 +292,18 @@ AYA_VISION_INPUTS_DOCSTRING = """
@add_start_docstrings(
- """The AyaVision model which consists of a vision backbone and a language model.""",
+ """The AyaVision model which consists of a vision backbone and a language model, without a language modeling head.""",
AYA_VISION_START_DOCSTRING,
)
-class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
+class AyaVisionModel(AyaVisionPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: AyaVisionConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = AyaVisionMultiModalProjector(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@@ -284,18 +312,6 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@@ -307,7 +323,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
- pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
@@ -342,6 +358,140 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
+ @add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, List[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ image_sizes: torch.Tensor = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, AyaVisionModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_sizes=image_sizes,
+ )
+
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ else:
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ output = AyaVisionModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+@add_start_docstrings(
+ """The AyaVision model which consists of a vision backbone and a language model.""",
+ AYA_VISION_START_DOCSTRING,
+)
+class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: AyaVisionConfig):
+ super().__init__(config)
+ self.model = AyaVisionModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
@add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AyaVisionCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -410,7 +560,6 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
>>> gen_tokens = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
>>> processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
```"""
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -425,73 +574,32 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
else self.config.vision_feature_select_strategy
)
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None:
- image_features = self.get_image_features(
- pixel_values=pixel_values,
- vision_feature_layer=vision_feature_layer,
- vision_feature_select_strategy=vision_feature_select_strategy,
- image_sizes=image_sizes,
- )
-
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- n_image_tokens = (input_ids == self.config.image_token_id).sum()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
+ image_sizes=image_sizes,
**lm_kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return AyaVisionCausalLMOutputWithPast(
loss=loss,
@@ -499,7 +607,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@@ -515,7 +623,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -532,15 +640,60 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
return model_inputs
- def tie_weights(self):
- return self.language_model.tie_weights()
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
- def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
- model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
- # update vocab size
- self.config.text_config.vocab_size = model_embeds.num_embeddings
- self.vocab_size = model_embeds.num_embeddings
- return model_embeds
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
-__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel"]
+__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]
diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py
index 96f0de888dd..5d7acecff88 100644
--- a/src/transformers/models/aya_vision/modular_aya_vision.py
+++ b/src/transformers/models/aya_vision/modular_aya_vision.py
@@ -22,6 +22,7 @@ from torch import nn
from transformers.models.llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
+ LlavaModel,
LlavaPreTrainedModel,
)
@@ -88,21 +89,8 @@ class AyaVisionMultiModalProjector(nn.Module):
return image_features
-AYA_VISION_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`AyaVisionConfig`] or [`AyaVisionVisionConfig`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
+AYA_VISION_START_DOCSTRING = None
+AYA_VISION_INPUTS_DOCSTRING = None
@add_start_docstrings(
@@ -133,81 +121,8 @@ class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
-AYA_VISION_INPUTS_DOCSTRING = """
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
- The tensors corresponding to the input images. Pixel values can be obtained using
- [`AutoImageProcessor`]. See [`GotOcr2ImageProcessor.__call__`] for details. [`AyaVisionProcessor`] uses
- [`GotOcr2ImageProcessor`] for processing images.
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
- The index of the layer to select the vision feature. If multiple indices are provided,
- the vision feature of the corresponding indices will be concatenated to form the
- vision features.
- vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
- The feature selection strategy used to select the vision feature from the vision backbone.
- Can be one of `"default"` or `"full"`.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
- the complete sequence length.
-"""
+class AyaVisionModel(LlavaModel):
+ pass
@add_start_docstrings(
@@ -215,16 +130,6 @@ AYA_VISION_INPUTS_DOCSTRING = """
AYA_VISION_START_DOCSTRING,
)
class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
- def tie_weights(self):
- return self.language_model.tie_weights()
-
- def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
- model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
- # update vocab size
- self.config.text_config.vocab_size = model_embeds.num_embeddings
- self.vocab_size = model_embeds.num_embeddings
- return model_embeds
-
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -312,4 +217,4 @@ class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
)
-__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel"]
+__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]
diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py
index f1d6b3a0476..c3a669d4c98 100644
--- a/src/transformers/models/colpali/modeling_colpali.py
+++ b/src/transformers/models/colpali/modeling_colpali.py
@@ -175,8 +175,8 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
self.vocab_size = config.vlm_config.text_config.vocab_size
vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
- if vlm.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"vlm.language_model.{k}" for k in vlm.language_model._tied_weights_keys]
+ if vlm._tied_weights_keys is not None:
+ self._tied_weights_keys = [f"vlm.{k}" for k in vlm._tied_weights_keys]
self.vlm = vlm
self.embedding_dim = self.config.embedding_dim
@@ -246,25 +246,25 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
)
def get_input_embeddings(self):
- return self.vlm.language_model.get_input_embeddings()
+ return self.vlm.get_input_embeddings()
def set_input_embeddings(self, value):
- self.vlm.language_model.set_input_embeddings(value)
+ self.vlm.set_input_embeddings(value)
def get_output_embeddings(self):
- return self.vlm.language_model.get_output_embeddings()
+ return self.vlm.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
- self.vlm.language_model.set_output_embeddings(new_embeddings)
+ self.vlm.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
- self.vlm.language_model.set_decoder(decoder)
+ self.vlm.set_decoder(decoder)
def get_decoder(self):
- return self.vlm.language_model.get_decoder()
+ return self.vlm.get_decoder()
def tie_weights(self):
- return self.vlm.language_model.tie_weights()
+ return self.vlm.tie_weights()
def resize_token_embeddings(
self,
@@ -272,7 +272,7 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
pad_to_multiple_of: Optional[int] = None,
mean_resizing: bool = True,
) -> nn.Embedding:
- model_embeds = self.vlm.language_model.resize_token_embeddings(
+ model_embeds = self.vlm.resize_token_embeddings(
new_num_tokens=new_num_tokens,
pad_to_multiple_of=pad_to_multiple_of,
mean_resizing=mean_resizing,
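
The ColPali change follows from the same refactor: the embedding accessors, weight tying, and resizing now live on the wrapped VLM itself rather than on `vlm.language_model`, so the retrieval wrapper delegates one level up. A toy sketch of that delegation with made-up module names (not the actual ColPali/PaliGemma code):

```python
import torch
from torch import nn

# Toy stand-ins: embeddings live on the inner language model, but the accessors
# (and the lm_head) are exposed on the outer VLM class after the refactor.
class ToyLanguageModel(nn.Module):
    def __init__(self, vocab_size: int = 10, hidden_size: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

class ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = ToyLanguageModel()
        self.lm_head = nn.Linear(4, 10, bias=False)  # head sits on the outer class

    def get_input_embeddings(self):
        return self.language_model.embed_tokens

    def get_output_embeddings(self):
        return self.lm_head

class ToyRetrievalWrapper(nn.Module):
    def __init__(self, vlm: nn.Module):
        super().__init__()
        self.vlm = vlm

    # Delegate to the VLM as a whole, mirroring the updated ColPali accessors.
    def get_input_embeddings(self):
        return self.vlm.get_input_embeddings()

    def get_output_embeddings(self):
        return self.vlm.get_output_embeddings()

wrapper = ToyRetrievalWrapper(ToyVLM())
print(wrapper.get_input_embeddings())   # Embedding(10, 4)
print(wrapper.get_output_embeddings())  # Linear(in_features=4, out_features=10, bias=False)
```
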
diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py
index 67a643b273e..45031a1a647 100644
--- a/src/transformers/models/emu3/modeling_emu3.py
+++ b/src/transformers/models/emu3/modeling_emu3.py
@@ -1778,13 +1778,16 @@ EMU3_INPUTS_DOCSTRING = r"""
"""
-class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
- _tied_weights_keys = ["text_model.lm_head.weight"]
- _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
+class Emu3Model(Emu3PreTrainedModel):
+ _checkpoint_conversion_mapping = {"text_model.model": "text_model"}
+    _supports_static_cache = False  # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
def __init__(self, config):
super().__init__(config)
- self.text_model = Emu3ForCausalLM._from_config(config.text_config)
+ self.text_model = Emu3TextModel._from_config(config.text_config)
+ if self.text_model._tied_weights_keys is not None:
+ self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys]
+
self.vqmodel = Emu3VQVAE(config.vq_config)
self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
@@ -1833,14 +1836,12 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
image = self.vqmodel.decode(image_tokens)
return image
- @can_return_tuple
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- image_sizes: Optional[torch.Tensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
@@ -1848,10 +1849,99 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None:
+ image_tokens = self.get_image_tokens(pixel_values, image_sizes)
+ special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
+ image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
+ input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ return outputs
+
+
+class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
+ base_model_prefix = ""
+ _checkpoint_conversion_mapping = {
+ "^text_model.model": "model.text_model",
+ "^vqmodel": "model.vqmodel",
+ "^text_model.lm_head": "lm_head",
+ }
+    _supports_static_cache = False  # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Emu3Model(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+    # Make modules available through the conditional class for BC
+ @property
+ def text_model(self):
+ return self.model.text_model
+
+ @property
+ def vqmodel(self):
+ return self.model.vqmodel
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- ) -> CausalLMOutputWithPast:
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1906,25 +1996,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError(
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
- )
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if pixel_values is not None:
- image_tokens = self.get_image_tokens(pixel_values, image_sizes)
- special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
- image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
- input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- return self.text_model(
+ outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -1933,8 +2007,25 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
@@ -1968,5 +2059,68 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
return model_inputs
+ @staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-__all__ = ["Emu3ForConditionalGeneration", "Emu3ForCausalLM", "Emu3TextModel", "Emu3PreTrainedModel", "Emu3VQVAE"]
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = [
+ "Emu3ForConditionalGeneration",
+ "Emu3ForCausalLM",
+ "Emu3TextModel",
+ "Emu3PreTrainedModel",
+ "Emu3VQVAE",
+ "Emu3Model",
+]
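
The `logits_to_keep` slicing that now appears in each refactored head is worth spelling out: an integer keeps only the last N positions (0 keeps all of them), while a tensor selects explicit indices, so during generation only the final position is ever projected through `lm_head`. A small self-contained sketch with made-up sizes:

```python
import torch
from torch import nn

batch_size, seq_len, hidden_size, vocab_size = 2, 7, 8, 32
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

logits_to_keep = 1  # keep only the last position, as during autoregressive decoding
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(lm_head(hidden_states[:, slice_indices, :]).shape)  # torch.Size([2, 1, 32])

logits_to_keep = 0  # slice(0, None): all positions, needed when computing the training loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(lm_head(hidden_states[:, slice_indices, :]).shape)  # torch.Size([2, 7, 32])
```
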
diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py
index 52d32dbdeea..6075b8d7370 100644
--- a/src/transformers/models/emu3/modular_emu3.py
+++ b/src/transformers/models/emu3/modular_emu3.py
@@ -1121,13 +1121,16 @@ class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
super().forward()
-class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
- _tied_weights_keys = ["text_model.lm_head.weight"]
- _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
+class Emu3Model(Emu3PreTrainedModel):
+ _checkpoint_conversion_mapping = {"text_model.model": "text_model"}
+    _supports_static_cache = False  # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
def __init__(self, config):
super().__init__(config)
- self.text_model = Emu3ForCausalLM._from_config(config.text_config)
+ self.text_model = Emu3TextModel._from_config(config.text_config)
+ if self.text_model._tied_weights_keys is not None:
+ self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys]
+
self.vqmodel = Emu3VQVAE(config.vq_config)
self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
@@ -1176,14 +1179,12 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
image = self.vqmodel.decode(image_tokens)
return image
- @can_return_tuple
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- image_sizes: Optional[torch.Tensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
@@ -1191,10 +1192,99 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None:
+ image_tokens = self.get_image_tokens(pixel_values, image_sizes)
+ special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
+ image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
+ input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ return outputs
+
+
+class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
+ base_model_prefix = ""
+ _checkpoint_conversion_mapping = {
+ "^text_model.model": "model.text_model",
+ "^vqmodel": "model.vqmodel",
+ "^text_model.lm_head": "lm_head",
+ }
+    _supports_static_cache = False  # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Emu3Model(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+    # Make modules available through the conditional class for BC
+ @property
+ def text_model(self):
+ return self.model.text_model
+
+ @property
+ def vqmodel(self):
+ return self.model.vqmodel
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- ) -> CausalLMOutputWithPast:
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1249,25 +1339,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError(
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
- )
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if pixel_values is not None:
- image_tokens = self.get_image_tokens(pixel_values, image_sizes)
- special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
- image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
- input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- return self.text_model(
+ outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -1276,8 +1350,25 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
@@ -1311,6 +1402,62 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
return model_inputs
+ @staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
__all__ = [
"Emu3ForConditionalGeneration",
@@ -1318,4 +1465,5 @@ __all__ = [
"Emu3TextModel",
"Emu3PreTrainedModel",
"Emu3VQVAE",
+ "Emu3Model",
]
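
Since `_prepare_4d_causal_attention_mask_with_cache_position` is copied onto each of these generation classes, here is a standalone numeric sketch of what it builds: an additive mask of shape `(batch_size, 1, query_length, key_value_length)` that is 0 where attention is allowed and the dtype minimum where it is not, with padded key/value slots from the 2D mask also blocked. The concrete sizes, cache positions, and padding pattern below are invented for illustration:

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
batch_size, sequence_length, target_length = 1, 3, 5
cache_position = torch.arange(2, 2 + sequence_length)  # queries sit at absolute positions 2, 3, 4
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])        # last key/value slot is padding

# Causal part: position j is masked for query i whenever j > cache_position[i].
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

# Padding part: slots where both the causal mask allows attention and the 2D mask is 0 get blocked.
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask == 0, min_dtype
)

# 1 = attend, 0 = masked (future position or padding)
print((causal_mask == 0).int().squeeze())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 0]], dtype=torch.int32)
```
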
diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py
index fd19ff7b8d4..ec74b9be3d5 100644
--- a/src/transformers/models/fuyu/modeling_fuyu.py
+++ b/src/transformers/models/fuyu/modeling_fuyu.py
@@ -21,9 +21,9 @@ import torch.utils.checkpoint
from torch import nn
from ...generation import GenerationMixin
-from ...modeling_outputs import CausalLMOutputWithPast
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
-from ...models.auto.modeling_auto import AutoModelForCausalLM
+from ...models.auto.modeling_auto import AutoModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_fuyu import FuyuConfig
@@ -143,18 +143,17 @@ FUYU_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
- "Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.",
+ """The Fuyu model which consists of a vision backbone and a language model, without a language modeling head.""",
FUYU_START_DOCSTRING,
)
-class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
+class FuyuModel(FuyuPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: FuyuConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.vision_embed_tokens = nn.Linear(
config.patch_size * config.patch_size * config.num_channels, config.hidden_size
)
@@ -169,18 +168,6 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def gather_continuous_embeddings(
self,
word_embeddings: torch.Tensor,
@@ -224,56 +211,21 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
return output_embeddings
@add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- image_patches: Optional[
- torch.Tensor
- ] = None, # [batch_size, num_total_patches, patch_size_ x patch_size x num_channels ]
- image_patches_indices: Optional[torch.Tensor] = None,
+ input_ids: torch.LongTensor = None,
+        image_patches: torch.Tensor = None,  # [batch_size, num_total_patches, patch_size x patch_size x num_channels]
+ image_patches_indices: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
- labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
-
- Returns:
-
- Examples:
-
- ```python
- >>> from transformers import FuyuProcessor, FuyuForCausalLM
- >>> from PIL import Image
- >>> import requests
-
- >>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
- >>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
-
- >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
- >>> image = Image.open(requests.get(url, stream=True).raw)
- >>> prompt = "Generate a coco-style caption.\n"
-
- >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
- >>> outputs = model(**inputs)
-
- >>> generated_ids = model.generate(**inputs, max_new_tokens=7)
- >>> generation_text = processor.batch_decode(generated_ids[:, -7:], skip_special_tokens=True)
- >>> print(generation_text[0])
- A blue bus parked on the side of a road.
- ```"""
-
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -327,7 +279,6 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- labels=labels,
use_cache=use_cache,
return_dict=return_dict,
**kwargs,
@@ -335,6 +286,139 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
return outputs
+
+@add_start_docstrings(
+ "Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.",
+ FUYU_START_DOCSTRING,
+)
+class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_embed_tokens": "model.vision_embed_tokens",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: FuyuConfig):
+ super().__init__(config)
+ self.model = FuyuModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ @add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+        image_patches: torch.Tensor = None,  # [batch_size, num_total_patches, patch_size x patch_size x num_channels]
+ image_patches_indices: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ logits_to_keep: Optional[int] = 0,
+ **kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import FuyuProcessor, FuyuForCausalLM
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
+ >>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
+
+ >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> prompt = "Generate a coco-style caption.\n"
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=7)
+ >>> generation_text = processor.batch_decode(generated_ids[:, -7:], skip_special_tokens=True)
+ >>> print(generation_text[0])
+ A blue bus parked on the side of a road.
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ image_patches=image_patches,
+ image_patches_indices=image_patches_indices,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ return_dict=return_dict,
+ # don't pass kwargs because Persimmon-backbone doesn't accept FA2 kwargs yet, TODO: raushan
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
def prepare_inputs_for_generation(
self,
input_ids,
@@ -373,4 +457,4 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
return reordered_past
-__all__ = ["FuyuForCausalLM", "FuyuPreTrainedModel"]
+__all__ = ["FuyuForCausalLM", "FuyuPreTrainedModel", "FuyuModel"]
diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 892b8898b62..74642ea6baa 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -32,11 +32,12 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
+ ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@@ -46,7 +47,7 @@ from ...utils import (
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
@@ -60,6 +61,39 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Gemma3Config"
+@dataclass
+class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
+ """
+ Base class for Gemma3 outputs, with hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
@dataclass
class Gemma3CausalLMOutputWithPast(ModelOutput):
"""
@@ -88,7 +122,7 @@ class Gemma3CausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
"""
@@ -480,7 +514,7 @@ GEMMA3_START_DOCSTRING = r"""
)
class Gemma3PreTrainedModel(PreTrainedModel):
config_class = Gemma3Config
- base_model_prefix = "language_model"
+ base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = [
"Gemma3DecoderLayer",
@@ -1066,20 +1100,19 @@ class Gemma3MultiModalProjector(nn.Module):
@add_start_docstrings(
- """The GEMMA3 model which consists of a vision backbone and a language model.""",
+ """Base Gemma3 model which consists of a vision backbone and a language model withou language modeling head.""",
GEMMA3_START_DOCSTRING,
)
-class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
+class Gemma3Model(Gemma3PreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: Gemma3Config):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = Gemma3MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
- language_model = AutoModelForCausalLM.from_config(config=config.text_config)
-
- if language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
+ language_model = AutoModel.from_config(config=config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
@@ -1091,18 +1124,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def _update_causal_mask(
self,
attention_mask,
@@ -1188,11 +1209,10 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
@@ -1203,6 +1223,145 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, Gemma3ModelOutputWithPast]:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ is_training = token_type_ids is not None and labels is not None
+
+ # Replace image id with PAD if the image token is OOV, to avoid index-errors
+ if input_ids is not None and self.config.image_token_id >= self.vocab_size:
+ special_image_mask = input_ids == self.config.image_token_id
+ llm_input_ids = input_ids.clone()
+ llm_input_ids[special_image_mask] = 0
+ else:
+ llm_input_ids = input_ids
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ # Merge text and images
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values)
+
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ else:
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ raise ValueError(
+ f"Number of images does not match number of special image tokens in the input text. "
+ f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
+ "tokens from image embeddings."
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
+ )
+ outputs = self.language_model(
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ return Gemma3ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+@add_start_docstrings(
+ """The Gemma3 model which consists of a vision backbone and a language model.""",
+ GEMMA3_START_DOCSTRING,
+)
+class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: Gemma3Config):
+ super().__init__(config)
+ self.model = Gemma3Model(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
@@ -1260,80 +1419,34 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
```
"""
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- is_training = token_type_ids is not None and labels is not None
-
- # Replace image id with PAD if the image token if OOV, to avoid index-errors
- if input_ids is not None and self.config.image_token_id >= self.vocab_size:
- special_image_mask = input_ids == self.config.image_token_id
- llm_input_ids = input_ids.clone()
- llm_input_ids[special_image_mask] = 0
- else:
- llm_input_ids = input_ids
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(llm_input_ids)
-
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
- )
-
- # Merge text and images
- if pixel_values is not None:
- image_features = self.get_image_features(pixel_values)
-
- if input_ids is None:
- special_image_mask = inputs_embeds == self.get_input_embeddings()(
- torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- else:
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
- raise ValueError(
- f"Number of images does not match number of special image tokens in the input text. "
- f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
- "tokens from image embeddings."
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- # mask out pad-token-ids in labels for BC
- if labels is not None and self.pad_token_id in labels:
- logger.warning_once(
- "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
- "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
- )
- labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
-
- causal_mask = self._update_causal_mask(
- attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
- )
- outputs: CausalLMOutputWithPast = self.language_model(
- attention_mask=causal_mask,
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
+ labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
**lm_kwargs,
)
- logits = outputs.logits
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
@@ -1356,13 +1469,17 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@@ -1381,7 +1498,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1401,15 +1518,73 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
- causal_mask = self._update_causal_mask(
+ causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs
- def tie_weights(self):
- return self.language_model.tie_weights()
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
-__all__ = ["Gemma3PreTrainedModel", "Gemma3TextModel", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration"]
+__all__ = [
+ "Gemma3PreTrainedModel",
+ "Gemma3TextModel",
+ "Gemma3ForCausalLM",
+ "Gemma3ForConditionalGeneration",
+ "Gemma3Model",
+]
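The `_checkpoint_conversion_mapping` entries added above are regex prefixes that rewrite old checkpoint keys (where everything lived under `language_model`) into the new `model.*` / `lm_head` layout. The remapping itself is handled inside `from_pretrained`; the sketch below only illustrates the prefix rewriting such a mapping implies, with hypothetical example keys:

```python
import re

conversion_mapping = {
    "^language_model.model": "model.language_model",
    "^vision_tower": "model.vision_tower",
    "^multi_modal_projector": "model.multi_modal_projector",
    "^language_model.lm_head": "lm_head",
}

# hypothetical old-layout keys, for illustration only
old_keys = [
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "language_model.lm_head.weight",
    "vision_tower.vision_model.embeddings.patch_embedding.weight",
    "multi_modal_projector.mm_input_projection_weight",
]

def remap(key: str) -> str:
    # rewrite the first matching prefix, leave other keys untouched
    for pattern, replacement in conversion_mapping.items():
        if re.match(pattern, key):
            return re.sub(pattern, replacement, key, count=1)
    return key

for key in old_keys:
    print(f"{key}  ->  {remap(key)}")
```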
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index ecac4921d2e..24206e6f838 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -26,11 +26,12 @@ import torch.utils.checkpoint
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import (
+ add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
@@ -50,7 +51,12 @@ from ..gemma2.modeling_gemma2 import (
apply_rotary_pos_emb,
eager_attention_forward,
)
-from ..paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
+from ..paligemma.modeling_paligemma import (
+ PaligemmaCausalLMOutputWithPast,
+ PaliGemmaForConditionalGeneration,
+ PaliGemmaModel,
+ PaligemmaModelOutputWithPast,
+)
from ..siglip import SiglipVisionConfig
@@ -302,43 +308,13 @@ class Gemma3Config(PretrainedConfig):
@dataclass
-class Gemma3CausalLMOutputWithPast(ModelOutput):
- """
- Base class for Gemma3 causal language model (or autoregressive) outputs.
+class Gemma3ModelOutputWithPast(PaligemmaModelOutputWithPast):
+ pass
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Language modeling loss (for next-token prediction).
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
- image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
- """
-
- loss: Optional[torch.FloatTensor] = None
- logits: Optional[torch.FloatTensor] = None
- past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
- image_hidden_states: Optional[torch.FloatTensor] = None
+@dataclass
+class Gemma3CausalLMOutputWithPast(PaligemmaCausalLMOutputWithPast):
+ pass
class Gemma3TextScaledWordEmbedding(nn.Embedding):
@@ -545,7 +521,7 @@ GEMMA3_START_DOCSTRING = None
class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
- base_model_prefix = "language_model"
+ base_model_prefix = ""
_no_split_modules = [
"Gemma3DecoderLayer",
"SiglipVisionEmbeddings",
@@ -755,10 +731,7 @@ class Gemma3MultiModalProjector(nn.Module):
return projected_vision_outputs.type_as(vision_outputs)
-class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
- def tie_weights(self):
- return self.language_model.tie_weights()
-
+class Gemma3Model(PaliGemmaModel):
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
Projects the last hidden state from the vision model into language model space.
@@ -844,11 +817,10 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
@@ -859,6 +831,106 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, Gemma3ModelOutputWithPast]:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ is_training = token_type_ids is not None and labels is not None
+
+ # Replace image id with PAD if the image token is OOV, to avoid index-errors
+ if input_ids is not None and self.config.image_token_id >= self.vocab_size:
+ special_image_mask = input_ids == self.config.image_token_id
+ llm_input_ids = input_ids.clone()
+ llm_input_ids[special_image_mask] = 0
+ else:
+ llm_input_ids = input_ids
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ # Merge text and images
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values)
+
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ else:
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ raise ValueError(
+ f"Number of images does not match number of special image tokens in the input text. "
+ f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
+ "tokens from image embeddings."
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
+ )
+ outputs = self.language_model(
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ return Gemma3ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+@add_start_docstrings(
+ """The Gemma3 model which consists of a vision backbone and a language model.""",
+ GEMMA3_START_DOCSTRING,
+)
+class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
+ @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
@@ -916,80 +988,34 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
```
"""
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- is_training = token_type_ids is not None and labels is not None
-
- # Replace image id with PAD if the image token if OOV, to avoid index-errors
- if input_ids is not None and self.config.image_token_id >= self.vocab_size:
- special_image_mask = input_ids == self.config.image_token_id
- llm_input_ids = input_ids.clone()
- llm_input_ids[special_image_mask] = 0
- else:
- llm_input_ids = input_ids
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(llm_input_ids)
-
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
- )
-
- # Merge text and images
- if pixel_values is not None:
- image_features = self.get_image_features(pixel_values)
-
- if input_ids is None:
- special_image_mask = inputs_embeds == self.get_input_embeddings()(
- torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- else:
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
- raise ValueError(
- f"Number of images does not match number of special image tokens in the input text. "
- f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
- "tokens from image embeddings."
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- # mask out pad-token-ids in labels for BC
- if labels is not None and self.pad_token_id in labels:
- logger.warning_once(
- "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
- "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
- )
- labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
-
- causal_mask = self._update_causal_mask(
- attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
- )
- outputs: CausalLMOutputWithPast = self.language_model(
- attention_mask=causal_mask,
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
+ labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
**lm_kwargs,
)
- logits = outputs.logits
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
@@ -1012,13 +1038,17 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@@ -1037,7 +1067,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1057,7 +1087,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
- causal_mask = self._update_causal_mask(
+ causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
@@ -1072,4 +1102,5 @@ __all__ = [
"Gemma3TextModel",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
+ "Gemma3Model",
]
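With the modular refactor above, `Gemma3Model` exposes the bare multimodal backbone (vision tower, projector and text model) while `Gemma3ForConditionalGeneration` wraps it with an `lm_head`, and the `language_model` / `vision_tower` / `multi_modal_projector` properties keep the old attribute paths working. A usage sketch, assuming the processor accepts a text-only prompt and a multimodal Gemma 3 checkpoint id such as `google/gemma-3-4b-it`:

```python
import torch
from transformers import AutoProcessor, Gemma3Model, Gemma3ForConditionalGeneration

ckpt = "google/gemma-3-4b-it"   # example checkpoint id, adjust to your setup
processor = AutoProcessor.from_pretrained(ckpt)
inputs = processor(text="The capital of France is", return_tensors="pt")

backbone = Gemma3Model.from_pretrained(ckpt)                  # hidden states only
with torch.no_grad():
    hidden = backbone(**inputs).last_hidden_state             # (batch, seq, hidden_size)

full = Gemma3ForConditionalGeneration.from_pretrained(ckpt)   # backbone + lm_head
with torch.no_grad():
    logits = full(**inputs).logits                            # (batch, seq, vocab_size)

# backward compatibility: the old sub-module names resolve through the new wrapper
assert full.language_model is full.model.language_model
assert full.vision_tower is full.model.vision_tower
```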
diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py
index d3a8a637ede..15013677c23 100644
--- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -28,11 +28,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
from ...activations import ACT2FN
from ...generation import GenerationMixin
-from ...modeling_outputs import ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
@@ -40,7 +38,7 @@ from ...utils import (
can_return_tuple,
replace_return_docstrings,
)
-from ..auto import AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig
@@ -545,7 +543,7 @@ class GotOcr2CausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@@ -557,6 +555,39 @@ class GotOcr2CausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None
+@dataclass
+class GotOcr2ModelOutputWithPast(BaseModelOutputWithPast):
+ """
+ Base class for GotOcr2 outputs, with hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
GOT_OCR2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -575,14 +606,13 @@ GOT_OCR2_START_DOCSTRING = r"""
@add_start_docstrings(
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ "The bare GotOcr2 Model outputting raw hidden-states without any specific head on top.",
GOT_OCR2_START_DOCSTRING,
)
class GotOcr2PreTrainedModel(PreTrainedModel):
config_class = GotOcr2Config
- base_model_prefix = "model"
+ base_model_prefix = ""
supports_gradient_checkpointing = True
- _no_split_modules = ["GotOcr2VisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@@ -680,23 +710,18 @@ GOT_OCR2_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
- """The GOT_OCR2 model which consists of a vision backbone and a language model.""",
+ """The GotOcr2 model which consists of a vision backbone and a language model, without a language modeling head.""",
GOT_OCR2_START_DOCSTRING,
)
-class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
+class GotOcr2Model(GotOcr2PreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: GotOcr2Config):
super().__init__(config)
self.vision_tower = GotOcr2VisionEncoder(config.vision_config)
self.multi_modal_projector = GotOcr2MultiModalProjector(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
- self.pad_token_id = config.pad_token_id
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@@ -705,18 +730,6 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@@ -732,13 +745,124 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)
+ @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, GotOcr2ModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if n_image_tokens != n_image_features:
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ )
+
+ output = GotOcr2ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+@add_start_docstrings(
+ """The GOT_OCR2 model which consists of a vision backbone and a language model.""",
+ GOT_OCR2_START_DOCSTRING,
+)
+class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: GotOcr2Config):
+ super().__init__(config)
+ self.model = GotOcr2Model(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
@can_return_tuple
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -747,9 +871,10 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- ) -> GotOcr2CausalLMOutputWithPast:
+ ) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -794,37 +919,15 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
"You should keep in mind what features from the module should be used, especially
when you're planning to sell a template."
```"""
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None:
- image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
- n_image_tokens = (input_ids == self.config.image_token_id).sum()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- if n_image_tokens != n_image_features:
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs: CausalLMOutputWithPast = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -832,29 +935,18 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
)
- logits = outputs.logits
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return GotOcr2CausalLMOutputWithPast(
loss=loss,
@@ -862,7 +954,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@@ -878,7 +970,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -895,5 +987,60 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
return model_inputs
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-__all__ = ["GotOcr2PreTrainedModel", "GotOcr2ForConditionalGeneration"]
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"]
diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py
index aec8c5e1749..f485146f2e0 100644
--- a/src/transformers/models/got_ocr2/modular_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py
@@ -14,16 +14,17 @@
# limitations under the License.
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
-from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
+ LlavaModel,
+ LlavaModelOutputWithPast,
LlavaPreTrainedModel,
)
from transformers.models.sam.modeling_sam import SamMLPBlock, SamVisionAttention, SamVisionEncoder, SamVisionLayer
@@ -36,7 +37,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
-from ..auto import CONFIG_MAPPING, AutoConfig, AutoModelForCausalLM
+from ..auto import CONFIG_MAPPING, AutoConfig
if is_vision_available():
@@ -278,6 +279,10 @@ class GotOcr2CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
+class GotOcr2ModelOutputWithPast(LlavaModelOutputWithPast):
+ pass
+
+
class GotOcr2PreTrainedModel(LlavaPreTrainedModel):
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@@ -368,22 +373,11 @@ GOT_OCR2_INPUTS_DOCSTRING = r"""
"""
-class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
+class GotOcr2Model(LlavaModel):
def __init__(self, config: GotOcr2Config):
super().__init__(config)
self.vision_tower = GotOcr2VisionEncoder(config.vision_config)
- self.multi_modal_projector = GotOcr2MultiModalProjector(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
- self.pad_token_id = config.pad_token_id
-
- self.post_init()
-
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@@ -399,13 +393,81 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)
+ @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, GotOcr2ModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if n_image_tokens != n_image_features:
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ )
+
+ output = GotOcr2ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -414,9 +476,10 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- ) -> GotOcr2CausalLMOutputWithPast:
+ ) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -461,37 +524,15 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
"You should keep in mind what features from the module should be used, especially
when you're planning to sell a template."
```"""
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None:
- image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
- n_image_tokens = (input_ids == self.config.image_token_id).sum()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- if n_image_tokens != n_image_features:
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs: CausalLMOutputWithPast = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -499,29 +540,18 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
)
- logits = outputs.logits
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return GotOcr2CausalLMOutputWithPast(
loss=loss,
@@ -529,7 +559,7 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
@@ -537,5 +567,6 @@ __all__ = [
"GotOcr2VisionConfig",
"GotOcr2Config",
"GotOcr2PreTrainedModel",
+ "GotOcr2Model",
"GotOcr2ForConditionalGeneration",
]
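The new `GotOcr2ForConditionalGeneration.forward` no longer asks the backbone for logits; it projects only the last `logits_to_keep` hidden states through `lm_head`. A minimal sketch of that slicing, with a hypothetical toy head and random hidden states:

```python
# Only the last N positions go through the LM head, which avoids materialising
# full-vocabulary logits for the whole prompt during generation. Toy sizes.
import torch
import torch.nn as nn

hidden_size, vocab_size = 16, 32
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

hidden_states = torch.randn(2, 10, hidden_size)  # (batch, seq_len, hidden)

logits_to_keep = 1  # during generation only the last position is needed; 0 keeps everything
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])

print(logits.shape)  # torch.Size([2, 1, 32])
```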
diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
index 016b2af2fa0..932c57ba0a6 100644
--- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -28,7 +28,7 @@ from torch import nn
import transformers.models.jamba.modeling_jamba as modeling_jamba
from transformers.activations import ACT2FN
-from ...cache_utils import Cache, StaticCache
+from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_layers import GradientCheckpointingLayer
@@ -1511,10 +1511,10 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- using_static_cache = isinstance(past_key_values, StaticCache)
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@@ -1525,7 +1525,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
- if using_static_cache:
+ if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
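The GraniteMoeHybrid hunk replaces the `isinstance(past_key_values, StaticCache)` test with the cache's own `is_compileable` flag, so any compile-friendly cache implementation takes the static-cache path. A toy sketch of the duck-typed check (the two classes below are stand-ins, not transformers caches):

```python
# Illustration only: any object exposing `is_compileable` opts into the
# static-cache handling; None (no cache yet) falls back to False.
class DynamicToyCache:
    is_compileable = False

class StaticToyCache:
    is_compileable = True

for past_key_values in (None, DynamicToyCache(), StaticToyCache()):
    using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
    print(type(past_key_values).__name__, using_compilable_cache)
```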
diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py
index cdfb59b5804..6abed48c86a 100644
--- a/src/transformers/models/instructblip/modeling_instructblip.py
+++ b/src/transformers/models/instructblip/modeling_instructblip.py
@@ -41,7 +41,7 @@ from ...utils import (
replace_return_docstrings,
torch_int,
)
-from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
+from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig
@@ -315,6 +315,9 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
config_class = InstructBlipConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
+ _supports_cache_class = True
+ _supports_static_cache = True
+    _supports_quantized_cache = False  # not all LM backbones support it (e.g. T5)
_no_split_modules = [
"InstructBlipQFormerEmbeddings",
@@ -339,7 +342,7 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
elif isinstance(module, InstructBlipVisionEmbeddings):
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
- elif isinstance(module, InstructBlipForConditionalGeneration):
+ elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)):
module.query_tokens.data.zero_()
@@ -1274,6 +1277,156 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
)
+@add_start_docstrings(
+ """
+ InstructBLIP base Model consisting of language model, qformer and vision encoder.
+ """,
+ INSTRUCTBLIP_START_DOCSTRING,
+)
+class InstructBlipModel(InstructBlipPreTrainedModel):
+ main_input_name = "pixel_values"
+ _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
+
+ def __init__(self, config: InstructBlipConfig):
+ super().__init__(config)
+
+ self.vision_model = InstructBlipVisionModel(config.vision_config)
+ self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
+ self.qformer = InstructBlipQFormerModel(config.qformer_config)
+
+ self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
+ self.language_model = AutoModel.from_config(config.text_config)
+
+ if self.language_model._no_split_modules is not None:
+ self._no_split_modules.extend(self.language_model._no_split_modules)
+
+ if self.language_model._keep_in_fp32_modules is not None:
+ self._keep_in_fp32_modules.extend(self.language_model._keep_in_fp32_modules)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def _tie_weights(self):
+ if not self.config.use_decoder_only_language_model:
+ self.language_model.encoder.embed_tokens = self.language_model.shared
+ self.language_model.decoder.embed_tokens = self.language_model.shared
+
+ def _preprocess_accelerate(self):
+ r"""
+ Some pre-processing hacks to make the model `accelerate` compatible. Check
+ https://github.com/huggingface/transformers/pull/21707 for more details.
+ """
+ hf_device_map = self.hf_device_map
+
+ if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
+ # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
+ logger.warning(
+ "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
+ " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
+ " Please pass a `device_map` that contains `language_model` to remove this warning."
+ " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
+ " more details on creating a `device_map` for large models.",
+ )
+
+ if hasattr(self.language_model, "_hf_hook"):
+ self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
+
+ @add_start_docstrings_to_model_forward(INSTRUCTBLIP_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: torch.FloatTensor,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ input_ids: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ use_cache: Optional[bool] = None,
+ ) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # step 1: forward the images through the vision encoder,
+ # to get image embeddings of shape (batch_size, seq_len, hidden_size)
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+ image_embeds = vision_outputs[0]
+
+ # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
+ if qformer_attention_mask is None:
+ qformer_attention_mask = torch.ones_like(qformer_input_ids)
+ qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
+ query_outputs = self.qformer(
+ input_ids=qformer_input_ids,
+ attention_mask=qformer_attention_mask,
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ query_output = query_outputs[0][:, : query_tokens.size(1), :]
+
+ # step 3: use the language model, conditioned on the query outputs and the prompt
+ language_model_inputs = self.language_projection(query_output)
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+ inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+
+ if self.config.use_decoder_only_language_model:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ )
+ else:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ )
+
+ if not return_dict:
+ return (vision_outputs, query_outputs, outputs)
+
+ return InstructBlipForConditionalGenerationModelOutput(
+ vision_outputs=vision_outputs,
+ qformer_outputs=query_outputs,
+ language_model_outputs=outputs,
+ )
+
+
@add_start_docstrings(
"""
InstructBLIP Model for generating text given an image and an optional text prompt. The model consists of a vision
@@ -1336,11 +1489,13 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
def get_decoder(self):
return self.language_model.get_decoder()
+ # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._tie_weights
def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
self.language_model.decoder.embed_tokens = self.language_model.shared
+ # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._preprocess_accelerate
def _preprocess_accelerate(self):
r"""
Some pre-processing hacks to make the model `accelerate` compatible. Check
@@ -1645,6 +1800,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
__all__ = [
"InstructBlipQFormerModel",
"InstructBlipPreTrainedModel",
+ "InstructBlipModel",
"InstructBlipForConditionalGeneration",
"InstructBlipVisionModel",
]
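`InstructBlipModel` builds its text backbone with `AutoModel.from_config`, i.e. the bare decoder/encoder without a language-modeling head, while the conditional-generation class keeps `AutoModelForCausalLM`. A hedged sketch of the difference using a tiny, made-up OPT config (any small decoder-only config would do):

```python
# The config values below are invented and deliberately tiny; they only exist to
# show which class each auto-factory returns.
from transformers import AutoModel, AutoModelForCausalLM, OPTConfig

text_config = OPTConfig(
    hidden_size=32,
    ffn_dim=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    vocab_size=100,
    max_position_embeddings=64,
)

bare_backbone = AutoModel.from_config(text_config)            # returns the bare model (last_hidden_state)
lm_backbone = AutoModelForCausalLM.from_config(text_config)   # adds the LM head (logits)

print(type(bare_backbone).__name__, type(lm_backbone).__name__)
```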
diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
index e9d9e4938a9..0ce752bed8b 100644
--- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
@@ -45,7 +45,7 @@ from ...utils import (
replace_return_docstrings,
torch_int,
)
-from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
+from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .configuration_instructblipvideo import (
InstructBlipVideoConfig,
InstructBlipVideoQFormerConfig,
@@ -945,6 +945,9 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
config_class = InstructBlipVideoConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
+ _supports_cache_class = True
+ _supports_static_cache = True
+    _supports_quantized_cache = False  # not all LM backbones support it (e.g. T5)
_no_split_modules = [
"InstructBlipVideoQFormerEmbeddings",
@@ -969,7 +972,7 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
elif isinstance(module, InstructBlipVideoVisionEmbeddings):
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
- elif isinstance(module, InstructBlipVideoForConditionalGeneration):
+ elif isinstance(module, (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel)):
module.query_tokens.data.zero_()
@@ -1269,6 +1272,166 @@ class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
)
+@add_start_docstrings(
+ """
+ InstructBlipVideo base Model consisting of language model, qformer and vision encoder.
+ """,
+ INSTRUCTBLIPVIDEO_START_DOCSTRING,
+)
+class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
+ main_input_name = "pixel_values"
+ _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
+
+ def __init__(self, config: InstructBlipVideoConfig):
+ super().__init__(config)
+
+ self.vision_model = InstructBlipVideoVisionModel(config.vision_config)
+ self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
+ self.qformer = InstructBlipVideoQFormerModel(config.qformer_config)
+
+ self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
+ self.language_model = AutoModel.from_config(config.text_config)
+
+ if self.language_model._no_split_modules is not None:
+ self._no_split_modules.extend(self.language_model._no_split_modules)
+
+ if self.language_model._keep_in_fp32_modules is not None:
+ self._keep_in_fp32_modules.extend(self.language_model._keep_in_fp32_modules)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def _tie_weights(self):
+ if not self.config.use_decoder_only_language_model:
+ self.language_model.encoder.embed_tokens = self.language_model.shared
+ self.language_model.decoder.embed_tokens = self.language_model.shared
+
+ def _preprocess_accelerate(self):
+ r"""
+ Some pre-processing hacks to make the model `accelerate` compatible. Check
+ https://github.com/huggingface/transformers/pull/21707 for more details.
+ """
+ hf_device_map = self.hf_device_map
+
+ if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
+ # warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
+ logger.warning(
+ "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
+ " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
+ " Please pass a `device_map` that contains `language_model` to remove this warning."
+ " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
+ " more details on creating a `device_map` for large models.",
+ )
+
+ if hasattr(self.language_model, "_hf_hook"):
+ self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
+
+ @add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: torch.FloatTensor,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ input_ids: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ use_cache: Optional[bool] = None,
+ ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # step 1: forward the images through the vision encoder,
+ # we process in a batched way, later unbatch it back (video has frames=4 always)
+ batch_size, frames, channel, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+ image_embeds = vision_outputs[0]
+
+ # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ if qformer_attention_mask is None:
+ qformer_attention_mask = torch.ones_like(qformer_input_ids)
+
+ qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
+ qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
+ qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
+ query_outputs = self.qformer(
+ input_ids=qformer_input_ids,
+ attention_mask=qformer_attention_mask,
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ query_output = query_outputs[0][:, : query_tokens.size(1), :]
+
+ # step 3: use the language model, conditioned on the query outputs and the prompt
+ language_model_inputs = self.language_projection(query_output)
+
+ # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
+ language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+
+ special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+ inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+
+ if self.config.use_decoder_only_language_model:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ )
+ else:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ )
+
+ if not return_dict:
+ return (vision_outputs, query_outputs, outputs)
+
+ return InstructBlipVideoForConditionalGenerationModelOutput(
+ vision_outputs=vision_outputs,
+ qformer_outputs=query_outputs,
+ language_model_outputs=outputs,
+ )
+
+
@add_start_docstrings(
"""
InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision
@@ -1682,5 +1845,6 @@ __all__ = [
"InstructBlipVideoVisionModel",
"InstructBlipVideoPreTrainedModel",
"InstructBlipVideoQFormerModel",
+ "InstructBlipVideoModel",
"InstructBlipVideoForConditionalGeneration",
]
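The video variant folds the frame dimension into the batch before the vision tower and Q-Former, then unfolds the projected query tokens so each video contributes `num_query_tokens * frames` positions to the language model. A shape-only sketch with illustrative sizes:

```python
# Shape-only walkthrough of the frame batching/unbatching in InstructBlipVideoModel.forward.
import torch

batch_size, frames, channel, height, width = 2, 4, 3, 224, 224
num_query_tokens, hidden_size = 32, 8

pixel_values = torch.randn(batch_size, frames, channel, height, width)
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
print(pixel_values.shape)  # torch.Size([8, 3, 224, 224]) -> one "image" per frame

# Pretend output of the Q-Former projection: one row of query tokens per frame.
language_model_inputs = torch.randn(batch_size * frames, num_query_tokens, hidden_size)
language_model_inputs = language_model_inputs.reshape(batch_size, num_query_tokens * frames, -1)
print(language_model_inputs.shape)  # torch.Size([2, 128, 8]) -> 32 tokens x 4 frames per video
```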
diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py
index 212050877a4..d28485545f7 100644
--- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py
@@ -27,6 +27,7 @@ from transformers.models.instructblip.configuration_instructblip import (
from transformers.models.instructblip.modeling_instructblip import (
InstructBlipForConditionalGeneration,
InstructBlipForConditionalGenerationModelOutput,
+ InstructBlipModel,
InstructBlipPreTrainedModel,
InstructBlipQFormerModel,
InstructBlipVisionModel,
@@ -34,7 +35,7 @@ from transformers.models.instructblip.modeling_instructblip import (
from ...configuration_utils import PretrainedConfig
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
-from ...utils import logging
+from ...utils import add_start_docstrings_to_model_forward, logging
from ..auto import CONFIG_MAPPING, AutoConfig
@@ -191,6 +192,110 @@ class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForCondit
pass
+INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = None
+
+
+class InstructBlipVideoModel(InstructBlipModel):
+ @add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: torch.FloatTensor,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ input_ids: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ use_cache: Optional[bool] = None,
+ ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # step 1: forward the images through the vision encoder,
+ # we process in a batched way, later unbatch it back (video has frames=4 always)
+ batch_size, frames, channel, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+ image_embeds = vision_outputs[0]
+
+ # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ if qformer_attention_mask is None:
+ qformer_attention_mask = torch.ones_like(qformer_input_ids)
+
+ qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
+ qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
+ qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
+ query_outputs = self.qformer(
+ input_ids=qformer_input_ids,
+ attention_mask=qformer_attention_mask,
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ query_output = query_outputs[0][:, : query_tokens.size(1), :]
+
+ # step 3: use the language model, conditioned on the query outputs and the prompt
+ language_model_inputs = self.language_projection(query_output)
+
+ # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
+ language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+
+ special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+ inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+
+ if self.config.use_decoder_only_language_model:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ )
+ else:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ )
+
+ if not return_dict:
+ return (vision_outputs, query_outputs, outputs)
+
+ return InstructBlipVideoForConditionalGenerationModelOutput(
+ vision_outputs=vision_outputs,
+ qformer_outputs=query_outputs,
+ language_model_outputs=outputs,
+ )
+
+
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
def forward(
self,
@@ -508,5 +613,6 @@ __all__ = [
"InstructBlipVideoVisionModel",
"InstructBlipVideoPreTrainedModel",
"InstructBlipVideoQFormerModel",
+ "InstructBlipVideoModel",
"InstructBlipVideoForConditionalGeneration",
]
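In both the modular and the generated forwards, the projected query embeddings are written into the text embedding sequence by expanding the video-token positions into a boolean mask over the embedding tensor. A toy illustration with invented ids and sizes:

```python
# The ids, sizes, and values here are made up; only the masking pattern matters.
import torch

hidden_size, video_token_id = 4, 99
input_ids = torch.tensor([[5, 99, 99, 7]])             # two video placeholder tokens
inputs_embeds = torch.zeros(1, 4, hidden_size)
language_model_inputs = torch.ones(2, hidden_size)     # one projected row per placeholder

special_image_mask = (input_ids == video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()

print(inputs_embeds[0])  # rows 1 and 2 are now filled with the projected features
```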
diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py
index c181976e0ec..6c59d06d2ee 100644
--- a/src/transformers/models/internvl/modeling_internvl.py
+++ b/src/transformers/models/internvl/modeling_internvl.py
@@ -31,7 +31,7 @@ from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
@@ -45,7 +45,7 @@ from ...utils import (
replace_return_docstrings,
torch_int,
)
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_internvl import InternVLConfig, InternVLVisionConfig
@@ -608,14 +608,13 @@ INTERNVL_START_DOCSTRING = r"""
@add_start_docstrings(
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ "The bare InternVL Model outputting raw hidden-states without any specific head on top.",
INTERNVL_START_DOCSTRING,
)
class InternVLPreTrainedModel(PreTrainedModel):
config_class = InternVLConfig
- base_model_prefix = "model"
+ base_model_prefix = ""
supports_gradient_checkpointing = True
- _no_split_modules = ["InternVLVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@@ -654,15 +653,13 @@ class InternVLMultiModalProjector(nn.Module):
@dataclass
-class InternVLCausalLMOutputWithPast(ModelOutput):
+class InternVLModelOutputWithPast(BaseModelOutputWithPast):
"""
- Base class for InternVL causal language model (or autoregressive) outputs.
+ Base class for InternVL outputs, with hidden states and attentions.
Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Language modeling loss (for next-token prediction).
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
@@ -681,15 +678,10 @@ class InternVLCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
- loss: Optional[torch.FloatTensor] = None
- logits: Optional[torch.FloatTensor] = None
- past_key_values: Optional[List[torch.FloatTensor]] = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
@@ -771,23 +763,18 @@ INTERNVL_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
- """The INTERNVL model which consists of a vision backbone and a language model.""",
+ """The InternVL model which consists of a vision backbone and a language model, without a language modeling head.""",
INTERNVL_START_DOCSTRING,
)
-class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin):
+class InternVLModel(InternVLPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: InternVLConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = InternVLMultiModalProjector(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@@ -796,18 +783,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@@ -853,12 +828,221 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
return vision_features
+ @add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, List[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ image_sizes: torch.Tensor = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, InternVLModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_sizes=image_sizes,
+ )
+
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ else:
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ output = InternVLModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+ def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
+ """Perform pixel shuffle downsampling on vision features.
+
+ Args:
+ vision_features (`torch.Tensor`):
+ Input tensor of shape (batch_size, width, height, channels).
+ scale_factor (`float`, *optional*, defaults to `0.5`):
+ Factor by which to downsample. Default is 0.5, which halves the dimensions.
+
+ Returns:
+ vision_features (`torch.Tensor`):
+ Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
+ """
+ batch_size, width, height, channels = vision_features.size()
+
+ if height % scale_factor != 0 or width % scale_factor != 0:
+ raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
+
+ # Reshape to allow downsampling
+ vision_features = vision_features.view(
+ batch_size, width, int(height * scale_factor), int(channels / scale_factor)
+ )
+ # Permute dimensions to align downsampled axis correctly
+ vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
+
+ # Reshape to achieve final downsampled dimensions
+ vision_features = vision_features.view(
+ batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
+ )
+
+ # Swap height and width back for proper orientation
+ vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
+
+ return vision_features
+
+
+@dataclass
+class InternVLCausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for InternVL causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@add_start_docstrings(
+ """The INTERNVL model which consists of a vision backbone and a language model.""",
+ INTERNVL_START_DOCSTRING,
+)
+class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: InternVLConfig):
+ super().__init__(config)
+ self.model = InternVLModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -872,7 +1056,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- image_sizes: Optional[torch.Tensor] = None,
+ image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, InternVLCausalLMOutputWithPast]:
r"""
@@ -925,7 +1109,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
>>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
The images depict the Statue of Liberty and the Golden Gate Bridge.
```"""
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -940,73 +1123,32 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
else self.config.vision_feature_select_strategy
)
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None:
- image_features = self.get_image_features(
- pixel_values=pixel_values,
- vision_feature_layer=vision_feature_layer,
- vision_feature_select_strategy=vision_feature_select_strategy,
- image_sizes=image_sizes,
- )
-
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- n_image_tokens = (input_ids == self.config.image_token_id).sum()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
+ image_sizes=image_sizes,
**lm_kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return InternVLCausalLMOutputWithPast(
loss=loss,
@@ -1014,7 +1156,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@@ -1030,7 +1172,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1047,45 +1189,66 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
return model_inputs
- def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
- """Perform pixel shuffle downsampling on vision features.
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
- vision_features (`torch.Tensor`):
- Input tensor of shape (batch_size, width, height, channels).
- scale_factor (`float`, *optional*, defaults to `0.5`):
- Factor by which to downsample. Default is 0.5, which halves the dimensions.
-
- Returns:
- vision_features (`torch.Tensor`):
- Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
"""
- batch_size, width, height, channels = vision_features.size()
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
- if height % scale_factor != 0 or width % scale_factor != 0:
- raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
-
- # Reshape to allow downsampling
- vision_features = vision_features.view(
- batch_size, width, int(height * scale_factor), int(channels / scale_factor)
- )
- # Permute dimensions to align downsampled axis correctly
- vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
-
- # Reshape to achieve final downsampled dimensions
- vision_features = vision_features.view(
- batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
- )
-
- # Swap height and width back for proper orientation
- vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
-
- return vision_features
+ return causal_mask
__all__ = [
"InternVLVisionPreTrainedModel",
"InternVLVisionModel",
"InternVLPreTrainedModel",
+ "InternVLModel",
"InternVLForConditionalGeneration",
]
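`InternVLForConditionalGeneration._checkpoint_conversion_mapping` is a set of regex rewrites applied to old checkpoint keys so that weights saved under the flat `language_model`/`vision_tower` layout load into the new `model` + `lm_head` split. A hedged sketch of what the mapping implies for a few hypothetical keys (the loop below is illustrative, not the actual loading code transformers runs):

```python
import re

conversion_mapping = {
    "^language_model.model": "model.language_model",
    "^vision_tower": "model.vision_tower",
    "^multi_modal_projector": "model.multi_modal_projector",
    "^language_model.lm_head": "lm_head",
}

old_keys = [
    "language_model.model.layers.0.self_attn.q_proj.weight",  # hypothetical checkpoint keys
    "language_model.lm_head.weight",
    "vision_tower.embeddings.patch_embeddings.projection.weight",
]

for key in old_keys:
    new_key = key
    for pattern, replacement in conversion_mapping.items():
        new_key = re.sub(pattern, replacement, new_key)
    print(f"{key} -> {new_key}")
```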
diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py
index 00b516cf041..bc3c3bfc6da 100644
--- a/src/transformers/models/internvl/modular_internvl.py
+++ b/src/transformers/models/internvl/modular_internvl.py
@@ -39,7 +39,12 @@ from ...utils import (
from ..clip.modeling_clip import CLIPMLP
from ..janus.modeling_janus import JanusVisionAttention
from ..llama.modeling_llama import LlamaRMSNorm
-from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel
+from ..llava.modeling_llava import (
+ LlavaCausalLMOutputWithPast,
+ LlavaForConditionalGeneration,
+ LlavaModel,
+ LlavaPreTrainedModel,
+)
from .configuration_internvl import InternVLConfig, InternVLVisionConfig
@@ -573,11 +578,7 @@ class InternVLMultiModalProjector(nn.Module):
return hidden_states
-class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
- pass
-
-
-class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
+class InternVLModel(LlavaModel):
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
"""Perform pixel shuffle downsampling on vision features.
@@ -658,6 +659,13 @@ class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
return vision_features
+
+class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
+ pass
+
+
+class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
+ @can_return_tuple
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(**super_kwargs):
@@ -718,5 +726,6 @@ __all__ = [
"InternVLVisionPreTrainedModel",
"InternVLVisionModel",
"InternVLPreTrainedModel",
+ "InternVLModel",
"InternVLForConditionalGeneration",
]
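
The `pixel_shuffle` method appears to move unchanged from `InternVLForConditionalGeneration` to the new `InternVLModel`. For reference, the sketch below reproduces the same downsampling on a dummy tensor; the shapes are made up and the snippet is illustrative, not part of the patch.

```python
import torch


def pixel_shuffle(vision_features: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor:
    # Same reshape/permute sequence as the moved method: trade spatial resolution for channel depth.
    batch_size, width, height, channels = vision_features.size()
    if height % scale_factor != 0 or width % scale_factor != 0:
        raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
    vision_features = vision_features.view(batch_size, width, int(height * scale_factor), int(channels / scale_factor))
    vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
    vision_features = vision_features.view(
        batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
    )
    return vision_features.permute(0, 2, 1, 3).contiguous()


features = torch.randn(1, 32, 32, 64)  # (batch, width, height, channels)
print(pixel_shuffle(features).shape)   # torch.Size([1, 16, 16, 256]) -> half the resolution, 4x the channels
```
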
diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index bc78b571d95..3273d595a5a 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -23,16 +23,17 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
-from ...modeling_outputs import ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_llava import LlavaConfig
@@ -44,6 +45,39 @@ _CONFIG_FOR_DOC = "LlavaConfig"
_CHECKPOINT_FOR_DOC = "llava-hf/llava-1.5-7b-hf"
+@dataclass
+class LlavaModelOutputWithPast(BaseModelOutputWithPast):
+ """
+ Base class for Llava outputs, with hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
@dataclass
class LlavaCausalLMOutputWithPast(ModelOutput):
"""
@@ -72,7 +106,7 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@@ -124,14 +158,13 @@ LLAVA_START_DOCSTRING = r"""
@add_start_docstrings(
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ "The bare Llava Model outputting raw hidden-states without any specific head on top.",
LLAVA_START_DOCSTRING,
)
class LlavaPreTrainedModel(PreTrainedModel):
config_class = LlavaConfig
- base_model_prefix = "model"
+ base_model_prefix = ""
supports_gradient_checkpointing = True
- _no_split_modules = ["LlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@@ -149,6 +182,9 @@ class LlavaPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
LLAVA_INPUTS_DOCSTRING = r"""
@@ -229,23 +265,18 @@ LLAVA_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
- """The LLAVA model which consists of a vision backbone and a language model.""",
+ """The Llava model which consists of a vision backbone and a language model, without a language modeling head.""",
LLAVA_START_DOCSTRING,
)
-class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
+class LlavaModel(LlavaPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: LlavaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@@ -254,18 +285,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@@ -277,7 +296,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
- pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
@@ -312,12 +331,146 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
+ @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, List[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ image_sizes: torch.Tensor = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, LlavaModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_sizes=image_sizes,
+ )
+
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ else:
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ output = LlavaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+@add_start_docstrings(
+ """The LLAVA model which consists of a vision backbone and a language model.""",
+ LLAVA_START_DOCSTRING,
+)
+class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: LlavaConfig):
+ super().__init__(config)
+ self.model = LlavaModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -331,7 +484,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- image_sizes: Optional[torch.Tensor] = None,
+ image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
r"""
@@ -371,7 +524,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
```"""
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -386,73 +538,32 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
else self.config.vision_feature_select_strategy
)
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None:
- image_features = self.get_image_features(
- pixel_values=pixel_values,
- vision_feature_layer=vision_feature_layer,
- vision_feature_select_strategy=vision_feature_select_strategy,
- image_sizes=image_sizes,
- )
-
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- n_image_tokens = (input_ids == self.config.image_token_id).sum()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
+ image_sizes=image_sizes,
**lm_kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaCausalLMOutputWithPast(
loss=loss,
@@ -460,7 +571,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@@ -476,7 +587,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -493,5 +604,61 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
return model_inputs
+ @staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel"]
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel", "LlavaModel"]
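
The `_checkpoint_conversion_mapping` dictionaries added above are what keep existing checkpoints loadable after the vision tower, projector and decoder move under `self.model`. The sketch below shows, schematically, how such regex prefixes rewrite legacy state-dict keys; the example keys are invented, and the real rewrite is performed by the weight-loading machinery, not by this snippet.

```python
import re

# Prefix rewrites taken from the new LlavaForConditionalGeneration._checkpoint_conversion_mapping
mapping = {
    r"^language_model.model": "model.language_model",
    r"^vision_tower": "model.vision_tower",
    r"^multi_modal_projector": "model.multi_modal_projector",
    r"^language_model.lm_head": "lm_head",
}

legacy_keys = [
    "language_model.model.layers.0.self_attn.q_proj.weight",  # invented example keys
    "language_model.lm_head.weight",
    "vision_tower.vision_model.embeddings.patch_embedding.weight",
    "multi_modal_projector.linear_1.weight",
]

for key in legacy_keys:
    new_key = key
    for pattern, replacement in mapping.items():
        new_key = re.sub(pattern, replacement, new_key)
    print(f"{key} -> {new_key}")
# language_model.model.layers.0.self_attn.q_proj.weight -> model.language_model.layers.0.self_attn.q_proj.weight
# language_model.lm_head.weight -> lm_head.weight
```
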
diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index fa7c04bdf40..c17eb9622c8 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -26,16 +26,17 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
-from ...modeling_outputs import ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_llava_next import LlavaNextConfig
@@ -151,6 +152,39 @@ def unpad_image(tensor, original_size):
return unpadded_tensor
+@dataclass
+class LlavaNextModelOutputWithPast(BaseModelOutputWithPast):
+ """
+    Base class for LlavaNext outputs, with hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
@dataclass
class LlavaNextCausalLMOutputWithPast(ModelOutput):
"""
@@ -237,9 +271,9 @@ LLAVA_NEXT_START_DOCSTRING = r"""
)
class LlavaNextPreTrainedModel(PreTrainedModel):
config_class = LlavaNextConfig
- base_model_prefix = "model"
+ base_model_prefix = ""
supports_gradient_checkpointing = True
- _no_split_modules = ["LlavaNextVisionAttention"]
+ _no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@@ -254,7 +288,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
- elif isinstance(module, LlavaNextForConditionalGeneration):
+ elif isinstance(module, LlavaNextModel):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
@@ -340,10 +374,12 @@ LLAVA_NEXT_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
- """The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
+    """The Llava-Next model which consists of a vision backbone and a language model, without a language modeling head.""",
LLAVA_NEXT_START_DOCSTRING,
)
-class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
+class LlavaNextModel(LlavaNextPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: LlavaNextConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
@@ -353,48 +389,16 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
- self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.post_init()
- @property
- def padding_side(self):
- return self._padding_side
-
- @padding_side.setter
- def padding_side(self, padding_side: str):
- if padding_side not in ["left", "right"]:
- raise ValueError(f"{padding_side} is not `left` or `right`.")
- self._padding_side = padding_side
-
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -524,12 +528,152 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
+ @add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, List[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, LlavaNextModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None and pixel_values.size(0) > 0:
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+
+ # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+ image_features, feature_lens = self.pack_image_features(
+ image_features,
+ image_sizes,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_newline=self.image_newline,
+ )
+
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_features.shape[0]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ output = LlavaNextModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+@add_start_docstrings(
+ """The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
+ LLAVA_NEXT_START_DOCSTRING,
+)
+class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^image_newline": "model.image_newline",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: LlavaNextConfig):
+ super().__init__(config)
+ self.model = LlavaNextModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -597,45 +741,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
else self.config.vision_feature_select_strategy
)
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None and pixel_values.size(0) > 0:
- image_features = self.get_image_features(
- pixel_values,
- image_sizes,
- vision_feature_layer=vision_feature_layer,
- vision_feature_select_strategy=vision_feature_select_strategy,
- )
-
- # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
- image_features, feature_lens = self.pack_image_features(
- image_features,
- image_sizes,
- vision_feature_select_strategy=vision_feature_select_strategy,
- image_newline=self.image_newline,
- )
-
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- n_image_tokens = (input_ids == self.config.image_token_id).sum()
- n_image_features = image_features.shape[0]
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids,
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -645,33 +756,17 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
**lm_kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaNextCausalLMOutputWithPast(
loss=loss,
@@ -679,7 +774,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@@ -696,7 +791,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -714,5 +809,61 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
return model_inputs
+ @staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-__all__ = ["LlavaNextForConditionalGeneration", "LlavaNextPreTrainedModel"]
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = ["LlavaNextForConditionalGeneration", "LlavaNextPreTrainedModel", "LlavaNextModel"]
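
Each new `*Model.forward` splices the projected image features into the token embeddings with `masked_scatter` at the placeholder positions. A toy, self-contained illustration of that step follows; the token id, sizes and tensors below are made up.

```python
import torch

image_token_id = 32000  # hypothetical placeholder id
hidden_size = 8

input_ids = torch.tensor([[1, 32000, 32000, 5]])   # two image placeholder tokens
inputs_embeds = torch.randn(1, 4, hidden_size)     # text embeddings from get_input_embeddings()
image_features = torch.randn(2, hidden_size)       # one projected feature row per placeholder

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
# The forward pass raises if these counts disagree; here they match by construction.
assert inputs_embeds[special_image_mask].numel() == image_features.numel()
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds.shape)  # torch.Size([1, 4, 8]) -- placeholder rows now hold image features
```
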
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index 113441fd2aa..6b6f14bc342 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -30,24 +30,65 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
-from ...modeling_outputs import ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_llava_next_video import LlavaNextVideoConfig
logger = logging.get_logger(__name__)
+
_CONFIG_FOR_DOC = "LlavaNextVideoConfig"
+@dataclass
+class LlavaNextVideoModelOutputWithPast(BaseModelOutputWithPast):
+ """
+    Base class for LlavaNextVideo outputs, with hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
+ video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+ video_hidden_states: Optional[torch.FloatTensor] = None
+
+
@dataclass
class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):
"""
@@ -167,6 +208,34 @@ LLAVA_NEXT_VIDEO_START_DOCSTRING = r"""
"""
+@add_start_docstrings(
+    "The bare Llava-Next-Video Model outputting raw hidden-states without any specific head on top.",
+ LLAVA_NEXT_VIDEO_START_DOCSTRING,
+)
+class LlavaNextVideoPreTrainedModel(PreTrainedModel):
+ config_class = LlavaNextVideoConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_cache_class = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, LlavaNextVideoModel):
+ embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
+ module.image_newline.data.normal_(mean=0.0, std=embed_std)
+
+
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
@@ -355,38 +424,12 @@ LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+    """The Llava-Next-Video model which consists of a vision backbone and a language model, without a language modeling head.""",
LLAVA_NEXT_VIDEO_START_DOCSTRING,
)
-class LlavaNextVideoPreTrainedModel(PreTrainedModel):
- config_class = LlavaNextVideoConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["LlavaNextVideoVisionAttention"]
- _skip_keys_device_placement = "past_key_values"
- _supports_cache_class = True
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_quantized_cache = True
- _supports_static_cache = True
+class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
- def _init_weights(self, module):
- std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
-
- if isinstance(module, (nn.Linear, nn.Conv2d)):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, LlavaNextVideoForConditionalGeneration):
- embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
- module.image_newline.data.normal_(mean=0.0, std=embed_std)
-
-
-@add_start_docstrings(
- """The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
- LLAVA_NEXT_VIDEO_START_DOCSTRING,
-)
-class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
def __init__(
self,
config: LlavaNextVideoConfig,
@@ -399,43 +442,17 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
- self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.vision_resampler = LlavaNextVideoPooler(config)
self.post_init()
- @property
- def padding_side(self):
- return self._padding_side
-
- @padding_side.setter
- def padding_side(self, padding_side: str):
- if padding_side not in ["left", "right"]:
- raise ValueError(f"{padding_side} is not `left` or `right`.")
- self._padding_side = padding_side
-
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -564,13 +581,222 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
+ @add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ pixel_values_videos: torch.FloatTensor = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, List[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ self.vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ self.vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+ "and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None and pixel_values.size(0) > 0:
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=self.vision_feature_layer,
+ vision_feature_select_strategy=self.vision_feature_select_strategy,
+ )
+ image_features, feature_lens = self.pack_image_features(
+ image_features,
+ image_sizes,
+ self.vision_feature_select_strategy,
+ image_newline=self.image_newline,
+ )
+
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_features.shape[0]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
+ video_features = self.get_video_features(
+ pixel_values_videos,
+ vision_feature_layer=self.vision_feature_layer,
+ vision_feature_select_strategy=self.vision_feature_select_strategy,
+ )
+ video_features = [feature.flatten(0, 1) for feature in video_features]
+ video_feature_lens = [feature.size(0) for feature in video_features]
+ video_features = torch.cat(video_features, dim=0)
+ video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
+
+ special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+ n_video_features = video_features.shape[0]
+ raise ValueError(
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+ )
+ video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ output = LlavaNextVideoModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ video_hidden_states=video_features if pixel_values_videos is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+ def get_video_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Union[int, List[int]],
+ vision_feature_select_strategy: str,
+ ):
+ """
+ Obtains video last hidden states from the vision tower and apply multimodal projection.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`):
+ The tensors corresponding to the input video.
+ vision_feature_layer (`Union[int, List[int]]`):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+            video_features (List[`torch.Tensor`]): List of video feature tensors, each containing the visual
+                features of all patches, of shape `(num_videos, video_length, embed_dim)`.
+ """
+ batch_size, frames, channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
+ video_features = self.vision_tower(pixel_values, output_hidden_states=True)
+
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_video_features = video_features.hidden_states[vision_feature_layer]
+ else:
+ hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+ selected_video_features = torch.cat(hs_pool, dim=-1)
+
+ if vision_feature_select_strategy == "default":
+ selected_video_features = selected_video_features[:, 1:]
+ elif vision_feature_select_strategy == "full":
+ selected_video_features = selected_video_features
+
+ # Same as image features except that video has pooling layer
+ video_features = self.vision_resampler(selected_video_features)
+ video_features = self.multi_modal_projector(video_features)
+ video_features = torch.split(video_features, frames, dim=0)
+ return video_features
+
+
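
`get_video_features` folds the frame axis into the batch axis before the vision tower, optionally drops the CLS token, pools, projects, and finally splits the result back into one chunk per video. A toy sketch of that reshape/split flow with stand-in modules (the pooling and projection layers below are placeholders, not the actual `vision_resampler`/`multi_modal_projector`):

```python
import torch
from torch import nn

batch_size, frames, channels, height, width = 2, 8, 3, 32, 32
pixel_values = torch.randn(batch_size, frames, channels, height, width)

# Fold frames into the batch so the vision tower sees ordinary images.
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)

# Stand-in for the vision tower output: 1 CLS token + 16 patch tokens per frame.
num_patches, vis_dim, text_dim = 17, 16, 24
patch_embeds = torch.randn(batch_size * frames, num_patches, vis_dim)

selected = patch_embeds[:, 1:]                      # "default" strategy: drop the CLS token
resampler = nn.AvgPool1d(kernel_size=2, stride=2)   # plays the role of the video pooler
pooled = resampler(selected.transpose(1, 2)).transpose(1, 2)
projector = nn.Linear(vis_dim, text_dim)            # plays the role of the projector
video_features = projector(pooled)

# Split back into per-video chunks of `frames` along the batch dimension.
video_features = torch.split(video_features, frames, dim=0)
print(len(video_features), video_features[0].shape)  # 2 torch.Size([8, 8, 24])
```
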
+@add_start_docstrings(
+ """The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
+ LLAVA_NEXT_VIDEO_START_DOCSTRING,
+)
+class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^image_newline": "model.image_newline",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: LlavaNextVideoConfig):
+ super().__init__(config)
+ self.model = LlavaNextVideoModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- pixel_values_videos: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ pixel_values_videos: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -662,117 +888,47 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER: \nWhat's the content of the image? ASSISTANT: The image shows a red stop sign on a pole, with a traditional Chinese archway (...)"
```"""
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- self.vision_feature_layer = (
+ vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
- self.vision_feature_select_strategy = (
+ vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
- "and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None and pixel_values.size(0) > 0:
- image_features = self.get_image_features(
- pixel_values,
- image_sizes,
- vision_feature_layer=self.vision_feature_layer,
- vision_feature_select_strategy=self.vision_feature_select_strategy,
- )
- image_features, feature_lens = self.pack_image_features(
- image_features,
- image_sizes,
- self.vision_feature_select_strategy,
- image_newline=self.image_newline,
- )
-
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- n_image_tokens = (input_ids == self.config.image_token_id).sum()
- n_image_features = image_features.shape[0]
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
- video_features = self.get_video_features(
- pixel_values_videos,
- vision_feature_layer=self.vision_feature_layer,
- vision_feature_select_strategy=self.vision_feature_select_strategy,
- )
- video_features = [feature.flatten(0, 1) for feature in video_features]
- video_feature_lens = [feature.size(0) for feature in video_features]
- video_features = torch.cat(video_features, dim=0)
- video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
-
- special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
- n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
- n_video_features = video_features.shape[0]
- raise ValueError(
- f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
- )
- video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
+ image_sizes=image_sizes,
**lm_kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaNextVideoCausalLMOutputWithPast(
loss=loss,
@@ -780,8 +936,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
- video_hidden_states=video_features if pixel_values_videos is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
+ video_hidden_states=outputs.video_hidden_states,
)
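
With this refactor the base `LlavaNextVideoModel` returns hidden states and the head class applies `lm_head` itself, using `logits_to_keep` to project only the positions that are actually needed (typically just the last one during generation). A small illustration of the slicing logic with toy shapes (names below are illustrative):

```python
import torch
from torch import nn

hidden_size, vocab_size = 8, 32
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(2, 10, hidden_size)  # (batch, seq_len, hidden)

# 0 keeps every position (slice(-0, None) == slice(None)), an int N keeps the
# last N positions, and a tensor selects explicit indices.
for logits_to_keep in (0, 1, torch.tensor([3, 7])):
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = lm_head(hidden_states[:, slice_indices, :])
    print(logits_to_keep, logits.shape)
# prints shapes (2, 10, 32), (2, 1, 32) and (2, 2, 32) respectively
```
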
def prepare_inputs_for_generation(
@@ -799,7 +955,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
):
# Overwritten -- extra custom processing
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -818,51 +974,60 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
return model_inputs
- def get_video_features(
- self,
- pixel_values: torch.FloatTensor,
- vision_feature_layer: Union[int, List[int]],
- vision_feature_select_strategy: str,
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
):
"""
- Obtains video last hidden states from the vision tower and apply multimodal projection.
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
- pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
- The tensors corresponding to the input video.
- vision_feature_layer (`Union[int, List[int]]`):
- The index of the layer to select the vision feature. If multiple indices are provided,
- the vision feature of the corresponding indices will be concatenated to form the
- vision features.
- vision_feature_select_strategy (`str`):
- The feature selection strategy used to select the vision feature from the vision backbone.
- Can be one of `"default"` or `"full"`
- Returns:
- video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches
- and are of shape `(num_videos, video_length, embed_dim)`).
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
"""
- batch_size, frames, channels, height, width = pixel_values.shape
- pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
- video_features = self.vision_tower(pixel_values, output_hidden_states=True)
-
- # If we have one vision feature layer, return the corresponding hidden states,
- # otherwise, select the hidden states of each feature layer and concatenate them
- if isinstance(vision_feature_layer, int):
- selected_video_features = video_features.hidden_states[vision_feature_layer]
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
else:
- hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
- selected_video_features = torch.cat(hs_pool, dim=-1)
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
- if vision_feature_select_strategy == "default":
- selected_video_features = selected_video_features[:, 1:]
- elif vision_feature_select_strategy == "full":
- selected_video_features = selected_video_features
-
- # Same as image features except that video has pooling layer
- video_features = self.vision_resampler(selected_video_features)
- video_features = self.multi_modal_projector(video_features)
- video_features = torch.split(video_features, frames, dim=0)
- return video_features
+ return causal_mask
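
The static helper above materializes the additive 4D mask used with static caches: `-inf` above the causal diagonal and beyond the current `cache_position`, plus `-inf` wherever the 2D attention mask marks padding. A compact reproduction of the same arithmetic on toy numbers (not tied to any checkpoint):

```python
import torch

batch_size, sequence_length, target_length = 1, 3, 6
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(2, 2 + sequence_length)  # the 3 query tokens occupy cache slots 2..4
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1]])     # first cache slot is padding

# Causal part: -inf above the diagonal and beyond the current cache position.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

# Padding part: positions the 2D mask marks as padding also become -inf.
padding_mask = (causal_mask[:, :, :, :target_length] + attention_mask[:, None, None, :]) == 0
causal_mask[:, :, :, :target_length] = causal_mask[:, :, :, :target_length].masked_fill(
    padding_mask, min_dtype
)

print(causal_mask.shape)               # torch.Size([1, 1, 3, 6])
print((causal_mask == 0).int()[0, 0])  # 1 where attention is allowed, 0 where it is blocked
```
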
-__all__ = ["LlavaNextVideoForConditionalGeneration", "LlavaNextVideoPreTrainedModel"]
+__all__ = ["LlavaNextVideoForConditionalGeneration", "LlavaNextVideoModel", "LlavaNextVideoPreTrainedModel"]
diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py
index b0b744c5b32..985a69a68ec 100644
--- a/src/transformers/models/llava_next_video/modular_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py
@@ -24,15 +24,19 @@ from torch import nn
from transformers.models.llava_next.modeling_llava_next import (
LlavaNextCausalLMOutputWithPast,
LlavaNextForConditionalGeneration,
+ LlavaNextModel,
+ LlavaNextModelOutputWithPast,
LlavaNextMultiModalProjector,
- LlavaNextPreTrainedModel,
image_size_to_num_patches,
)
from ...configuration_utils import PretrainedConfig
from ...utils import (
+ add_start_docstrings_to_model_forward,
+ can_return_tuple,
is_torchdynamo_compiling,
logging,
+ replace_return_docstrings,
)
from ..auto import CONFIG_MAPPING, AutoConfig
@@ -40,6 +44,9 @@ from ..auto import CONFIG_MAPPING, AutoConfig
logger = logging.get_logger(__name__)
+_CONFIG_FOR_DOC = "LlavaNextVideoConfig"
+
+
class LlavaNextVideoConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlavaNextVideoForConditionalGeneration`]. It is used to instantiate an
@@ -182,6 +189,17 @@ class LlavaNextVideoConfig(PretrainedConfig):
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+@dataclass
+class LlavaNextVideoModelOutputWithPast(LlavaNextModelOutputWithPast):
+ """
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
+ video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ video_hidden_states: Optional[torch.FloatTensor] = None
+
+
@dataclass
class LlavaNextVideoCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast):
"""
@@ -231,20 +249,7 @@ class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector):
pass
-class LlavaNextVideoPreTrainedModel(LlavaNextPreTrainedModel):
- def _init_weights(self, module):
- std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
-
- if isinstance(module, (nn.Linear, nn.Conv2d)):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, LlavaNextVideoForConditionalGeneration):
- embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
- module.image_newline.data.normal_(mean=0.0, std=embed_std)
-
-
-class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
+class LlavaNextVideoModel(LlavaNextModel):
def __init__(self, config: LlavaNextVideoConfig, **super_kwargs):
super().__init__(config, **super_kwargs)
self.vision_resampler = LlavaNextVideoPooler(config)
@@ -358,9 +363,209 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- pixel_values_videos: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ pixel_values_videos: torch.FloatTensor = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, List[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ self.vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ self.vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+ "and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None and pixel_values.size(0) > 0:
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=self.vision_feature_layer,
+ vision_feature_select_strategy=self.vision_feature_select_strategy,
+ )
+ image_features, feature_lens = self.pack_image_features(
+ image_features,
+ image_sizes,
+ self.vision_feature_select_strategy,
+ image_newline=self.image_newline,
+ )
+
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_features.shape[0]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
+ video_features = self.get_video_features(
+ pixel_values_videos,
+ vision_feature_layer=self.vision_feature_layer,
+ vision_feature_select_strategy=self.vision_feature_select_strategy,
+ )
+ video_features = [feature.flatten(0, 1) for feature in video_features]
+ video_feature_lens = [feature.size(0) for feature in video_features]
+ video_features = torch.cat(video_features, dim=0)
+ video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
+
+ special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+ n_video_features = video_features.shape[0]
+ raise ValueError(
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+ )
+ video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ output = LlavaNextVideoModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ video_hidden_states=video_features if pixel_values_videos is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`LlavaNextVideoImageProcessor.__call__`] for details. [`LlavaProcessor`] uses
+ [`LlavaNextVideoImageProcessor`] for processing images.
+ image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
+ The sizes of the images in the batch, being (height, width) for each image.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
+ If `"full"`, the full vision features are used.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ pixel_values_videos: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -452,117 +657,47 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER: \nWhat's the content of the image? ASSISTANT: The image shows a red stop sign on a pole, with a traditional Chinese archway (...)"
```"""
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- self.vision_feature_layer = (
+ vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
- self.vision_feature_select_strategy = (
+ vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
- "and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None and pixel_values.size(0) > 0:
- image_features = self.get_image_features(
- pixel_values,
- image_sizes,
- vision_feature_layer=self.vision_feature_layer,
- vision_feature_select_strategy=self.vision_feature_select_strategy,
- )
- image_features, feature_lens = self.pack_image_features(
- image_features,
- image_sizes,
- self.vision_feature_select_strategy,
- image_newline=self.image_newline,
- )
-
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- n_image_tokens = (input_ids == self.config.image_token_id).sum()
- n_image_features = image_features.shape[0]
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
- video_features = self.get_video_features(
- pixel_values_videos,
- vision_feature_layer=self.vision_feature_layer,
- vision_feature_select_strategy=self.vision_feature_select_strategy,
- )
- video_features = [feature.flatten(0, 1) for feature in video_features]
- video_feature_lens = [feature.size(0) for feature in video_features]
- video_features = torch.cat(video_features, dim=0)
- video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
-
- special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
- n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
- n_video_features = video_features.shape[0]
- raise ValueError(
- f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
- )
- video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
+ image_sizes=image_sizes,
**lm_kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaNextVideoCausalLMOutputWithPast(
loss=loss,
@@ -570,8 +705,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
- video_hidden_states=video_features if pixel_values_videos is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
+ video_hidden_states=outputs.video_hidden_states,
)
def prepare_inputs_for_generation(
@@ -589,7 +724,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
):
# Overwritten -- extra custom processing
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -609,4 +744,9 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
return model_inputs
-__all__ = ["LlavaNextVideoConfig", "LlavaNextVideoForConditionalGeneration", "LlavaNextVideoPreTrainedModel"]
+__all__ = [
+ "LlavaNextVideoConfig",
+ "LlavaNextVideoForConditionalGeneration",
+ "LlavaNextVideoModel",
+ "LlavaNextVideoPreTrainedModel", # noqa: F822
+]
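
The `_checkpoint_conversion_mapping` entries introduced throughout this PR are start-anchored regex rewrites over checkpoint keys, so weights saved with the old flat layout (`language_model.model.*`, `vision_tower.*`, ...) still load into the new `model.*` composition. Roughly, the remapping behaves like the sketch below (illustrative keys and loop, not the actual loading code in `modeling_utils`):

```python
import re

checkpoint_conversion_mapping = {
    r"^language_model.model": "model.language_model",
    r"^vision_tower": "model.vision_tower",
    r"^multi_modal_projector": "model.multi_modal_projector",
    r"^image_newline": "model.image_newline",
    r"^language_model.lm_head": "lm_head",
}

old_keys = [
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "language_model.lm_head.weight",
    "vision_tower.encoder.layers.0.mlp.fc1.weight",
    "image_newline",
]

def remap(key: str) -> str:
    for pattern, replacement in checkpoint_conversion_mapping.items():
        key, n_subs = re.subn(pattern, replacement, key)
        if n_subs:  # stop at the first pattern that matched
            break
    return key

for key in old_keys:
    print(f"{key} -> {remap(key)}")
# e.g. language_model.model.layers.0... -> model.language_model.layers.0...
#      language_model.lm_head.weight    -> lm_head.weight
```
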
diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
index f5e2da2cd9e..df49004eb24 100644
--- a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
+++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
@@ -4,9 +4,25 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_llava_onevision.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import List, Optional, Union
+import torch
+
from ...image_processing_utils import BatchFeature, get_patch_output_size, select_best_resolution
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
@@ -28,11 +44,9 @@ from ...image_utils import (
make_flat_list_of_images,
)
from ...processing_utils import Unpack
-from ...utils import TensorType, add_start_docstrings, is_torch_available, is_torchvision_v2_available
+from ...utils import TensorType, add_start_docstrings, is_torchvision_v2_available
-if is_torch_available():
- import torch
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index be67df9b3af..5ef23387e9f 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -1,3 +1,9 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/llava_onevision/modular_llava_onevision.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_llava_onevision.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
@@ -12,7 +18,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""PyTorch Llava-Onevision model."""
import math
from dataclasses import dataclass
@@ -20,140 +25,66 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
-import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
-from ...modeling_outputs import ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
+ can_return_tuple,
is_torchdynamo_compiling,
logging,
)
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_llava_onevision import LlavaOnevisionConfig
logger = logging.get_logger(__name__)
-_CONFIG_FOR_DOC = "LlavaNextConfig"
-
-# Copied from transformers.models.llava_next.modeling_llava_next.get_anyres_image_grid_shape
-def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+@dataclass
+class LlavaOnevisionModelOutputWithPast(BaseModelOutputWithPast):
"""
- Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+ Base class for Llava outputs, with hidden states and attentions.
Args:
- image_size (`tuple`):
- The size of the input image in the format (width, height).
- grid_pinpoints (`List`):
- A list containing possible resolutions. Each item in the list should be a tuple or list
- of the form `(height, width)`.
- patch_size (`int`):
- The size of each image patch.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
- Returns:
- tuple: The shape of the image patch grid in the format (width, height).
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
+ video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
- if not isinstance(grid_pinpoints, list):
- raise TypeError("grid_pinpoints should be a list of tuples or lists")
- # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
- if not isinstance(image_size, (list, tuple)):
- if not isinstance(image_size, (torch.Tensor, np.ndarray)):
- raise TypeError(
- f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
- )
- image_size = image_size.tolist()
+ image_hidden_states: Optional[torch.FloatTensor] = None
- height, width = select_best_resolution(image_size, grid_pinpoints)
- return height // patch_size, width // patch_size
-
-
-# Copied from transformers.models.llava_next.modeling_llava_next.image_size_to_num_patches
-def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
- """
- Calculate the number of patches after the preprocessing for images of any resolution.
-
- Args:
- image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
- The size of the input image in the format (height, width). ?
- grid_pinpoints (`List`):
- A list containing possible resolutions. Each item in the list should be a tuple or list
- of the form `(height, width)`.
- patch_size (`int`):
- The size of each image patch.
-
- Returns:
- int: the number of patches
- """
- if not isinstance(grid_pinpoints, list):
- raise TypeError("grid_pinpoints should be a list of tuples or lists")
-
- # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
- if not isinstance(image_size, (list, tuple)):
- if not isinstance(image_size, (torch.Tensor, np.ndarray)):
- raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
- image_size = image_size.tolist()
-
- best_resolution = select_best_resolution(image_size, grid_pinpoints)
- height, width = best_resolution
- num_patches = 0
- # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
- for i in range(0, height, patch_size):
- for j in range(0, width, patch_size):
- num_patches += 1
- # add the base patch
- num_patches += 1
- return num_patches
-
-
-# Copied from transformers.models.llava_next.modeling_llava_next.unpad_image
-def unpad_image(tensor, original_size):
- """
- Unpads a PyTorch tensor of a padded and resized image.
-
- Args:
- tensor (`torch.Tensor`):
- The image tensor, assumed to be of shape (num_channels, height, width).
- original_size (`tuple`):
- The original size of the image (height, width).
-
- Returns:
- `torch.Tensor`: The unpadded image tensor.
- """
- if not isinstance(original_size, (list, tuple)):
- if not isinstance(original_size, (torch.Tensor, np.ndarray)):
- raise TypeError(
- f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
- )
- original_size = original_size.tolist()
- original_height, original_width = original_size
- current_height, current_width = tensor.shape[1:]
-
- original_aspect_ratio = original_width / original_height
- current_aspect_ratio = current_width / current_height
-
- if original_aspect_ratio > current_aspect_ratio:
- scale_factor = current_width / original_width
- new_height = min(math.ceil(original_height * scale_factor), current_height)
- padding, r = divmod(current_height - new_height, 2)
- unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
- else:
- scale_factor = current_height / original_height
- new_width = min(math.ceil(original_width * scale_factor), current_width)
- padding, r = divmod(current_width - new_width, 2)
- unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
-
- return unpadded_tensor
+ video_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
-# Copied from transformers.models.llava_next_video.modeling_llava_next_video.LlavaNextVideoCausalLMOutputWithPast with LlavaNextVideo->LlavaOnevision
class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
"""
Base class for LlavaOnevision causal language model (or autoregressive) outputs.
@@ -195,10 +126,44 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
+
video_hidden_states: Optional[torch.FloatTensor] = None
-# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaOnevision
+class LlavaOnevisionPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ mode = config.spatial_pool_mode
+ stride = config.spatial_pool_stride
+ out_channels = getattr(config, "spatial_pool_out_channels", config.vision_config.hidden_size)
+ self.image_size = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
+
+ if mode == "average":
+ self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
+ elif mode == "max":
+ self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+ elif mode == "conv":
+ self.pool = nn.Conv2d(
+ in_channels=config.vision_config.hidden_size,
+ out_channels=out_channels,
+ kernel_size=stride,
+ stride=stride,
+ )
+ else:
+ raise ValueError(f"Unknown pooling mode: {mode}. Has to be one of [`average`, `max`, `conv`]")
+
+ def forward(self, image_features):
+ ori_width = int(math.sqrt(image_features.shape[1] * self.image_size // self.image_size))
+ ori_height = int(ori_width * self.image_size // self.image_size)
+
+ batch_size, _, dim = image_features.shape
+ image_features_spatial = image_features.view(batch_size, ori_height, ori_height, dim).permute(0, 3, 1, 2)
+ image_features_spatial_pool = self.pool(image_features_spatial)
+
+ return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
+
+
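
`LlavaOnevisionPooler` treats the flat patch sequence as a square feature map, pools it spatially, and flattens it back, so the number of video tokens per frame shrinks by roughly `stride**2`. A toy run of that reshape → pool → flatten round trip (dummy dimensions, average pooling only):

```python
import torch
from torch import nn

batch_frames, dim, side, stride = 4, 16, 24, 2     # 24 x 24 = 576 patch tokens per frame
image_features = torch.randn(batch_frames, side * side, dim)

# (batch, seq, dim) -> (batch, H, W, dim) -> (batch, dim, H, W) for 2D pooling
spatial = image_features.view(batch_frames, side, side, dim).permute(0, 3, 1, 2)
pooled = nn.AvgPool2d(kernel_size=stride, stride=stride)(spatial)

# Back to (batch, seq, dim); the sequence length drops from 576 to 144.
out = pooled.flatten(2).transpose(1, 2).contiguous()
print(image_features.shape, "->", out.shape)  # torch.Size([4, 576, 16]) -> torch.Size([4, 144, 16])
```
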
class LlavaOnevisionMultiModalProjector(nn.Module):
def __init__(self, config: LlavaOnevisionConfig):
super().__init__()
@@ -231,40 +196,118 @@ LLAVA_ONEVISION_START_DOCSTRING = r"""
and behavior.
Parameters:
- config ([`LlavaNextConfig`] or [`LlavaNextVisionConfig`]):
+ config ([`LlavaOnevisionConfig`] or [`LlavaOnevisionVisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
-@add_start_docstrings(
- "The bare LLaVA-Onevision Model outputting raw hidden-states without any specific head on top.",
- LLAVA_ONEVISION_START_DOCSTRING,
-)
-class LlavaOnevisionPreTrainedModel(PreTrainedModel):
- config_class = LlavaOnevisionConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["LlavaOnevisionVisionAttention"]
- _skip_keys_device_placement = "past_key_values"
- _supports_flash_attn_2 = True
- _supports_cache_class = True
- _supports_static_cache = True
- _supports_quantized_cache = True
- _supports_sdpa = True
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
- # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextPreTrainedModel._init_weights with LlavaNext->LlavaOnevision
- def _init_weights(self, module):
- std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+ Args:
+ image_size (`tuple`):
+ The size of the input image in the format (width, height).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, LlavaOnevisionForConditionalGeneration):
- embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
- module.image_newline.data.normal_(mean=0.0, std=embed_std)
+ Returns:
+ tuple: The shape of the image patch grid in the format (width, height).
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+    # ! VERY IMPORTANT: if image_size is a tensor, it must be converted into a tuple, otherwise the calculation will be wrong
+ if not isinstance(image_size, (list, tuple)):
+ if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(
+ f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
+ )
+ image_size = image_size.tolist()
+
+ height, width = select_best_resolution(image_size, grid_pinpoints)
+ return height // patch_size, width // patch_size
+
+
+def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
+ """
+ Calculate the number of patches after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
+            The size of the input image in the format (height, width).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+ int: the number of patches
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+    # ! VERY IMPORTANT: if image_size is a tensor, it must be converted into a tuple, otherwise the calculation will be wrong
+ if not isinstance(image_size, (list, tuple)):
+ if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
+ image_size = image_size.tolist()
+
+ best_resolution = select_best_resolution(image_size, grid_pinpoints)
+ height, width = best_resolution
+ num_patches = 0
+ # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ num_patches += 1
+ # add the base patch
+ num_patches += 1
+ return num_patches
+
+
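
To make the anyres bookkeeping concrete: `select_best_resolution` (imported from `transformers.image_processing_utils`, as above) picks the padded resolution that preserves the most of the original image, and the helpers then derive the patch grid and the number of vision-tower crops from it. The numbers below are just an example with made-up `grid_pinpoints`; the simplified count assumes the pinpoints are exact multiples of `patch_size`:

```python
from transformers.image_processing_utils import select_best_resolution

patch_size = 384
grid_pinpoints = [[384, 768], [768, 384], [768, 768], [384, 1152], [1152, 384]]
image_size = (600, 800)  # (height, width) of the raw image

best_height, best_width = select_best_resolution(image_size, grid_pinpoints)

# Patch grid of the best resolution, and number of crops fed to the vision tower
# (one per patch-sized tile, plus the base downscaled image).
grid = (best_height // patch_size, best_width // patch_size)
num_patches = grid[0] * grid[1] + 1
print(best_height, best_width, grid, num_patches)  # e.g. 768 768 (2, 2) 5
```
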
+def unpad_image(tensor, original_size):
+ """
+ Unpads a PyTorch tensor of a padded and resized image.
+
+ Args:
+ tensor (`torch.Tensor`):
+ The image tensor, assumed to be of shape (num_channels, height, width).
+ original_size (`tuple`):
+ The original size of the image (height, width).
+
+ Returns:
+ `torch.Tensor`: The unpadded image tensor.
+ """
+ if not isinstance(original_size, (list, tuple)):
+ if not isinstance(original_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(
+ f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
+ )
+ original_size = original_size.tolist()
+ original_height, original_width = original_size
+ current_height, current_width = tensor.shape[1:]
+
+ original_aspect_ratio = original_width / original_height
+ current_aspect_ratio = current_width / current_height
+
+ if original_aspect_ratio > current_aspect_ratio:
+ scale_factor = current_width / original_width
+ new_height = min(math.ceil(original_height * scale_factor), current_height)
+ padding, r = divmod(current_height - new_height, 2)
+ unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
+ else:
+ scale_factor = current_height / original_height
+ new_width = min(math.ceil(original_width * scale_factor), current_width)
+ padding, r = divmod(current_width - new_width, 2)
+ unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
+
+ return unpadded_tensor
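
`unpad_image` reverses the pad-to-aspect-ratio step: given the original (height, width), it crops the symmetric padding along whichever dimension had to grow. A quick check on a toy tensor (dummy sizes, same arithmetic as the helper above):

```python
import math
import torch

# A 3-channel 24x24 feature map that came from a 300x600 image: the image was
# resized to width 24 (height 12) and padded to 24 vertically.
tensor = torch.zeros(3, 24, 24)
original_height, original_width = 300, 600
current_height, current_width = tensor.shape[1:]

if original_width / original_height > current_width / current_height:
    scale = current_width / original_width
    new_height = min(math.ceil(original_height * scale), current_height)
    padding, r = divmod(current_height - new_height, 2)
    unpadded = tensor[:, padding : current_height - (padding + r), :]
else:
    scale = current_height / original_height
    new_width = min(math.ceil(original_width * scale), current_width)
    padding, r = divmod(current_width - new_width, 2)
    unpadded = tensor[:, :, padding : current_width - (padding + r)]

print(unpadded.shape)  # torch.Size([3, 12, 24])
```
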
LLAVA_ONEVISION_INPUTS_DOCSTRING = r"""
@@ -279,16 +322,10 @@ LLAVA_ONEVISION_INPUTS_DOCSTRING = r"""
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
- [`AutoImageProcessor`]. See [`LlavaNextImageProcessor.__call__`] for details. [`LlavaProcessor`] uses
- [`LlavaNextImageProcessor`] for processing images.
+ [`AutoImageProcessor`]. See [`LlavaOnevisionImageProcessor.__call__`] for details. [`LlavaProcessor`] uses
+ [`LlavaOnevisionImageProcessor`] for processing images.
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
The sizes of the images in the batch, being (height, width) for each image.
- pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, frames, num_channels, image_size, image_size)):
- The tensors corresponding to the input videos. Pixel values can be obtained using
- [`LlavaNextVideoProcessor`]. See [`LlavaNextVideoProcessor.__call__`] for details. [`LlavaProcessor`] uses
- [`LlavaNextVideoProcessor`] for processing videos.
- image_sizes_videos (`torch.LongTensor` of shape `(batch_size, frames, 2)`, *optional*):
- The sizes of the videos in the batch, being (height, width) for each frame in the video.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -335,8 +372,6 @@ LLAVA_ONEVISION_INPUTS_DOCSTRING = r"""
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
If `"full"`, the full vision features are used.
- vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
- Aspect ratio used when processong image features. The default value is "anyres_max_9".
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
@@ -356,11 +391,41 @@ LLAVA_ONEVISION_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
- """The LLaVA-Onevision model which consists of a vision backbone and a language model.""",
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAVA_ONEVISION_START_DOCSTRING,
)
-class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
- def __init__(self, config: LlavaOnevisionConfig):
+class LlavaOnevisionPreTrainedModel(PreTrainedModel):
+ config_class = LlavaOnevisionConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_cache_class = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, LlavaOnevisionModel):
+ embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
+ module.image_newline.data.normal_(mean=0.0, std=embed_std)
+
+
+@add_start_docstrings(
+    """The LLaVA-Onevision model which consists of a vision backbone and a language model without a language modeling head.""",
+ LLAVA_ONEVISION_START_DOCSTRING,
+)
+class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
+ def __init__(self, config):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
@@ -369,36 +434,16 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
- # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
- # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -465,20 +510,6 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
return image_features, feature_lens
- def apply_pooling(self, image_features):
- height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
- batch_frames, seq_len, dim = image_features.shape
- image_features = image_features.view(batch_frames, height, width, -1)
- image_features = image_features.permute(0, 3, 1, 2).contiguous()
-
- height, width = image_features.shape[2:]
- scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]
- image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear")
-
- image_features = image_features.permute(0, 2, 3, 1)
- image_features = image_features.view(batch_frames, -1, dim)
- return image_features
-
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@@ -494,7 +525,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
- vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
+ vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
@@ -539,59 +570,13 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
- def get_video_features(
- self,
- pixel_values: torch.FloatTensor,
- vision_feature_layer: Union[int, List[int]],
- vision_feature_select_strategy: str,
- ):
- """
- Obtains video last hidden states from the vision tower, apply multimodal projection and pooling.
-
- Args:
- pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
- The tensors corresponding to the input video.
- vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
- The index of the layer to select the vision feature. If multiple indices are provided,
- the vision feature of the corresponding indices will be concatenated to form the
- vision features.
- vision_feature_select_strategy (`str`):
- The feature selection strategy used to select the vision feature from the vision backbone.
- Can be one of `"default"` or `"full"`
- Returns:
- video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches
- and are of shape `(num_videos, video_length, embed_dim)`).
- """
- batch_size, frames, channels, height, width = pixel_values.shape
- pixel_values = pixel_values.view(batch_size * frames, channels, height, width)
- video_features = self.vision_tower(pixel_values, output_hidden_states=True)
-
- # If we have one vision feature layer, return the corresponding hidden states,
- # otherwise, select the hidden states of each feature layer and concatenate them
- if isinstance(vision_feature_layer, int):
- selected_video_feature = video_features.hidden_states[vision_feature_layer]
- else:
- hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
- selected_video_feature = torch.cat(hs_pool, dim=-1)
-
- if vision_feature_select_strategy == "default":
- selected_video_feature = selected_video_feature[:, 1:]
- elif vision_feature_select_strategy == "full":
- selected_video_feature = selected_video_feature
- video_features = self.multi_modal_projector(selected_video_feature)
-
- video_features = self.apply_pooling(video_features)
- video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
-
- return video_features
-
@add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
- pixel_values_videos: Optional[torch.FloatTensor] = None,
+ pixel_values_videos: torch.FloatTensor = None,
image_sizes_videos: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -600,62 +585,13 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
- labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
- ) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- logits_to_keep (`int` or `torch.Tensor`, *optional*):
- If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
- If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
- This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-
- Returns:
- [`~LlavaOnevisionCausalLMOutputWithPast`] (if `return_dict=True`) or a `tuple`.
-
- Example:
-
- ```python
- >>> from PIL import Image
- >>> import requests
- >>> import torch
- >>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration
-
- >>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype="float16", device_map="cuda:0")
- >>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
-
- >>> conversation = [
- ... {
- ... "role": "user",
- ... "content": [
- ... {"type": "text", "text": "What is shown in this image?"},
- ... {"type": "image"},
- ... ],
- ... },
- ... ]
- >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
-
- >>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
- >>> raw_image = Image.open(requests.get(image_file, stream=True).raw)
- >>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16)
-
- >>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- >>> processor.batch_decode(output, skip_special_tokens=True)[0]
- "user\n\nWhat is shown in this image?\nassistant\ncat"
- ```"""
+ ) -> Union[Tuple, LlavaOnevisionModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -743,35 +679,247 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ output = LlavaOnevisionModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ video_hidden_states=video_features if pixel_values_videos is not None else None,
+ )
+
+ return output if return_dict else output.to_tuple()
+
+ def get_video_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Union[int, List[int]],
+ vision_feature_select_strategy: str,
+ ):
+ """
+        Obtains video last hidden states from the vision tower, and applies multimodal projection and pooling.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`)
+ The tensors corresponding to the input video.
+            vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+            video_features (List[`torch.Tensor`]): List of video feature tensors, each containing all the visual features of all
+            patches, of shape `(num_videos, video_length, embed_dim)`.
+ """
+ batch_size, frames, channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.view(batch_size * frames, channels, height, width)
+ video_features = self.vision_tower(pixel_values, output_hidden_states=True)
+
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_video_feature = video_features.hidden_states[vision_feature_layer]
+ else:
+ hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+ selected_video_feature = torch.cat(hs_pool, dim=-1)
+
+ if vision_feature_select_strategy == "default":
+ selected_video_feature = selected_video_feature[:, 1:]
+ elif vision_feature_select_strategy == "full":
+ selected_video_feature = selected_video_feature
+ video_features = self.multi_modal_projector(selected_video_feature)
+
+ video_features = self.apply_pooling(video_features)
+ video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
+
+ return video_features
+
+ def apply_pooling(self, image_features):
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+ batch_frames, seq_len, dim = image_features.shape
+ image_features = image_features.view(batch_frames, height, width, -1)
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
+
+ height, width = image_features.shape[2:]
+ scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]
+ image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear")
+
+ image_features = image_features.permute(0, 2, 3, 1)
+ image_features = image_features.view(batch_frames, -1, dim)
+ return image_features
+
+
+@add_start_docstrings(
+    """The LLaVA-Onevision model which consists of a vision backbone and a language model.""",
+ LLAVA_ONEVISION_START_DOCSTRING,
+)
+class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^image_newline": "model.image_newline",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: LlavaOnevisionConfig):
+ super().__init__(config)
+ self.model = LlavaOnevisionModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+    # Make modules available through the conditional generation class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
+ @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ pixel_values_videos: torch.FloatTensor = None,
+ image_sizes_videos: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, List[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ vision_aspect_ratio: Optional[str] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+
+ Returns:
+ [`~LlavaOnevisionCausalLMOutputWithPast`] (if `return_dict=True`) or a `tuple`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> import torch
+ >>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration
+
+ >>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype="float16", device_map="cuda:0")
+ >>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
+
+ >>> conversation = [
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "text", "text": "What is shown in this image?"},
+ ... {"type": "image"},
+ ... ],
+ ... },
+ ... ]
+ >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+ >>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> raw_image = Image.open(requests.get(image_file, stream=True).raw)
+ >>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16)
+
+ >>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ >>> processor.batch_decode(output, skip_special_tokens=True)[0]
+ "user\n\nWhat is shown in this image?\nassistant\ncat"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+ vision_aspect_ratio = (
+ vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_sizes=image_sizes,
+ image_sizes_videos=image_sizes_videos,
+ vision_aspect_ratio=vision_aspect_ratio,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaOnevisionCausalLMOutputWithPast(
loss=loss,
@@ -779,8 +927,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
- video_hidden_states=video_features if pixel_values_videos is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
+ video_hidden_states=outputs.video_hidden_states,
)
def prepare_inputs_for_generation(
@@ -799,7 +947,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -819,5 +967,60 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
return model_inputs
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-__all__ = ["LlavaOnevisionForConditionalGeneration", "LlavaOnevisionPreTrainedModel"]
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = ["LlavaOnevisionModel", "LlavaOnevisionForConditionalGeneration", "LlavaOnevisionPreTrainedModel"]
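
To make the new `_prepare_4d_causal_attention_mask_with_cache_position` helper concrete, here is a minimal standalone sketch of the masking technique it implements; the batch size, cache offset and padding pattern below are made up for illustration and are not taken from the patch:

```python
import torch

# Hypothetical setup: 2 tokens already cached, 3 new query tokens, a static cache of length 5,
# and the last key position of the single batch entry is padding.
batch_size, sequence_length, target_length = 1, 3, 5
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(2, 2 + sequence_length)
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])

# Start from an all-masked (seq_len, target_len) grid, then open up causal and cached positions.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

# Fold the 2D padding mask into the key dimension: positions that are causally
# visible but padded get re-masked with min_dtype.
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask == 0, min_dtype
)

print(causal_mask.shape)  # torch.Size([1, 1, 3, 5]); 0.0 where attention is allowed
```
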
diff --git a/src/transformers/models/llava_onevision/modular_llava_onevision.py b/src/transformers/models/llava_onevision/modular_llava_onevision.py
index 5a25124e58c..bc692c10a64 100644
--- a/src/transformers/models/llava_onevision/modular_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modular_llava_onevision.py
@@ -1,4 +1,34 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
from transformers.models.llava_next.image_processing_llava_next_fast import LlavaNextImageProcessorFast
+from transformers.models.llava_next_video.modeling_llava_next_video import (
+ LlavaNextVideoCausalLMOutputWithPast,
+ LlavaNextVideoForConditionalGeneration,
+ LlavaNextVideoModel,
+ LlavaNextVideoModelOutputWithPast,
+ LlavaNextVideoPreTrainedModel,
+ get_anyres_image_grid_shape,
+ unpad_image,
+)
from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING
from ...image_utils import (
@@ -6,11 +36,13 @@ from ...image_utils import (
OPENAI_CLIP_STD,
PILImageResampling,
)
-from ...utils import add_start_docstrings, logging
+from ...utils import add_start_docstrings, can_return_tuple, is_torchdynamo_compiling, logging
logger = logging.get_logger(__name__)
+LLAVA_ONEVISION_INPUTS_DOCSTRING = None
+
@add_start_docstrings(
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
@@ -42,4 +74,446 @@ class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
model_input_names = ["pixel_values_videos"]
-__all__ = ["LlavaOnevisionImageProcessorFast"]
+class LlavaOnevisionModelOutputWithPast(LlavaNextVideoModelOutputWithPast):
+ pass
+
+
+class LlavaOnevisionCausalLMOutputWithPast(LlavaNextVideoCausalLMOutputWithPast):
+ pass
+
+
+class LlavaOnevisionPreTrainedModel(LlavaNextVideoPreTrainedModel):
+ pass
+
+
+class LlavaOnevisionModel(LlavaNextVideoModel):
+ def __init__(self, config):
+ super().__init__(config)
+ del self.vision_resampler
+
+ def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"):
+ """
+ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+ Args:
+ image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+                List of image feature tensors, each containing all the visual features of all patches.
+ image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each image (H, W).
+ image_newline (`torch.Tensor` of shape `(embed_dim)`)
+ New line embedding vector.
+            vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
+                Aspect ratio used when processing image features. The default value is `"anyres_max_9"`.
+ Returns:
+ image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+ feature_lens (`List[int]`)
+ token length of each image in image_features
+ """
+ new_image_features = []
+ feature_lens = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+ if height * width != base_image_feature.shape[0]:
+ raise ValueError("The number of patches is not consistent with the image size.")
+ num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+ image_sizes[image_idx],
+ self.config.image_grid_pinpoints,
+ self.config.vision_config.image_size,
+ )
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ max_num_patches = int(vision_aspect_ratio.strip("anyres_max_"))
+ channels, curr_height, curr_width = image_feature.shape
+ ratio = math.sqrt(curr_height * curr_width / (max_num_patches * height**2))
+ if ratio > 1.1:
+ image_feature = image_feature[None]
+ image_feature = nn.functional.interpolate(
+ image_feature, [int(curr_height // ratio), int(curr_width // ratio)], mode="bilinear"
+ )[0]
+ if image_newline is not None:
+ image_feature = torch.cat(
+ (
+ image_feature,
+ image_newline[:, None, None]
+ .expand(*image_feature.shape[:-1], 1)
+ .to(image_feature.device, image_feature.dtype),
+ ),
+ dim=-1,
+ )
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+ else:
+ image_feature = image_feature[0]
+ if image_newline is not None:
+ image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+ new_image_features.append(image_feature)
+ feature_lens.append(image_feature.size(0))
+ image_features = torch.cat(new_image_features, dim=0)
+ feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
+ return image_features, feature_lens
+
+ def apply_pooling(self, image_features):
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+ batch_frames, seq_len, dim = image_features.shape
+ image_features = image_features.view(batch_frames, height, width, -1)
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
+
+ height, width = image_features.shape[2:]
+ scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]
+ image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear")
+
+ image_features = image_features.permute(0, 2, 3, 1)
+ image_features = image_features.view(batch_frames, -1, dim)
+ return image_features
+
+ def get_video_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Union[int, List[int]],
+ vision_feature_select_strategy: str,
+ ):
+ """
+        Obtains video last hidden states from the vision tower, and applies multimodal projection and pooling.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`)
+ The tensors corresponding to the input video.
+            vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+            video_features (List[`torch.Tensor`]): List of video feature tensors, each containing all the visual features of all
+            patches, of shape `(num_videos, video_length, embed_dim)`.
+ """
+ batch_size, frames, channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.view(batch_size * frames, channels, height, width)
+ video_features = self.vision_tower(pixel_values, output_hidden_states=True)
+
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_video_feature = video_features.hidden_states[vision_feature_layer]
+ else:
+ hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+ selected_video_feature = torch.cat(hs_pool, dim=-1)
+
+ if vision_feature_select_strategy == "default":
+ selected_video_feature = selected_video_feature[:, 1:]
+ elif vision_feature_select_strategy == "full":
+ selected_video_feature = selected_video_feature
+ video_features = self.multi_modal_projector(selected_video_feature)
+
+ video_features = self.apply_pooling(video_features)
+ video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
+
+ return video_features
+
+ @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ pixel_values_videos: torch.FloatTensor = None,
+ image_sizes_videos: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, List[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ vision_aspect_ratio: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, LlavaOnevisionModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+ vision_aspect_ratio = (
+ vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+ "and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # Images are processed with Anyres
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+ image_features, feature_lens = self.pack_image_features(
+ image_features,
+ image_sizes,
+ image_newline=self.image_newline,
+ vision_aspect_ratio=vision_aspect_ratio,
+ )
+
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_features.shape[0]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ # Video are simply embedded and further pooled to decrease seq len
+ if pixel_values_videos is not None:
+ video_features = self.get_video_features(
+ pixel_values_videos,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+ image_newline = (
+ self.image_newline[None, None, :].repeat(video_features.shape[0], 1, 1).to(video_features.device)
+ )
+ video_features = torch.cat((video_features, image_newline), dim=1)
+ video_features = video_features.flatten(0, 1)
+
+ special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
+ special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_id).sum()
+ n_video_features = video_features.shape[0]
+ raise ValueError(
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+ )
+ video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ output = LlavaOnevisionModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ video_hidden_states=video_features if pixel_values_videos is not None else None,
+ )
+
+ return output if return_dict else output.to_tuple()
+
+
+class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGeneration):
+ @can_return_tuple
+ @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ pixel_values_videos: torch.FloatTensor = None,
+ image_sizes_videos: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, List[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ vision_aspect_ratio: Optional[str] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+
+ Returns:
+ [`~LlavaOnevisionCausalLMOutputWithPast`] (if `return_dict=True`) or a `tuple`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> import torch
+ >>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration
+
+ >>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype="float16", device_map="cuda:0")
+ >>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
+
+ >>> conversation = [
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "text", "text": "What is shown in this image?"},
+ ... {"type": "image"},
+ ... ],
+ ... },
+ ... ]
+ >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+ >>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> raw_image = Image.open(requests.get(image_file, stream=True).raw)
+ >>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16)
+
+ >>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ >>> processor.batch_decode(output, skip_special_tokens=True)[0]
+ "user\n\nWhat is shown in this image?\nassistant\ncat"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+ vision_aspect_ratio = (
+ vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_sizes=image_sizes,
+ image_sizes_videos=image_sizes_videos,
+ vision_aspect_ratio=vision_aspect_ratio,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **lm_kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return LlavaOnevisionCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ video_hidden_states=outputs.video_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ image_sizes=None,
+ pixel_values_videos=None,
+ image_sizes_videos=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+ # Otherwise we need pixel values to be passed to model
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["image_sizes"] = image_sizes
+ model_inputs["pixel_values_videos"] = pixel_values_videos
+ model_inputs["image_sizes_videos"] = image_sizes_videos
+
+ return model_inputs
+
+
+__all__ = [
+ "LlavaOnevisionImageProcessorFast",
+ "LlavaOnevisionModel",
+ "LlavaOnevisionForConditionalGeneration",
+ "LlavaOnevisionPreTrainedModel",
+]
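
The video path above relies on `apply_pooling` to halve the per-frame patch grid before the features are scattered into the text sequence. Below is a minimal standalone sketch of that 2x bilinear pooling; the 27x27 patch grid and feature width are assumed for illustration and are not read from any config:

```python
import math

import torch
from torch import nn

# Assumed geometry: a SigLIP-style tower with image_size=384, patch_size=14 -> 27x27 patches per frame.
batch_frames, dim = 2, 8
height = width = 27
frame_features = torch.randn(batch_frames, height * width, dim)

# (batch_frames, seq_len, dim) -> (batch_frames, dim, H, W) so interpolate can halve the grid.
grid = frame_features.view(batch_frames, height, width, dim).permute(0, 3, 1, 2).contiguous()
scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]
pooled = nn.functional.interpolate(grid, size=scaled_shape, mode="bilinear")

# Back to (batch_frames, new_seq_len, dim); the per-frame sequence shrinks 729 -> 196 tokens.
pooled = pooled.permute(0, 2, 3, 1).reshape(batch_frames, -1, dim)
print(frame_features.shape, "->", pooled.shape)  # (2, 729, 8) -> (2, 196, 8)
```
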
diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py
index 5ce7763dd7c..7078631552f 100644
--- a/src/transformers/models/mistral3/modeling_mistral3.py
+++ b/src/transformers/models/mistral3/modeling_mistral3.py
@@ -28,19 +28,20 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
-from ...modeling_outputs import ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ can_return_tuple,
is_torchdynamo_compiling,
replace_return_docstrings,
)
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_mistral3 import Mistral3Config
-_CONFIG_FOR_DOC = "Mistral3Config"
+_CONFIG_FOR_DOC = "Mistral3Config"
@use_kernel_forward_from_hub("RMSNorm")
@@ -156,7 +157,7 @@ class Mistral3CausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@@ -168,6 +169,39 @@ class Mistral3CausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None
+@dataclass
+class Mistral3ModelOutputWithPast(BaseModelOutputWithPast):
+ """
+ Base class for Mistral3 outputs, with hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
MISTRAL3_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -186,14 +220,13 @@ MISTRAL3_START_DOCSTRING = r"""
@add_start_docstrings(
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ "The bare Mistral3 Model outputting raw hidden-states without any specific head on top.",
MISTRAL3_START_DOCSTRING,
)
class Mistral3PreTrainedModel(PreTrainedModel):
config_class = Mistral3Config
- base_model_prefix = "model"
+ base_model_prefix = ""
supports_gradient_checkpointing = True
- _no_split_modules = ["Mistral3VisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@@ -202,12 +235,18 @@ class Mistral3PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
def _init_weights(self, module):
+ # important: this ported version of Mistral3 isn't meant for training from scratch - only
+        # inference and fine-tuning - so the proper init weights code has been removed - the original
+        # Mistral3 codebase should serve for that purpose
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
elif isinstance(module, Mistral3RMSNorm):
module.weight.data.fill_(1.0)
@@ -290,23 +329,18 @@ MISTRAL3_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
- """The MISTRAL3 model which consists of a vision backbone and a language model.""",
+ """The Mistral3 model which consists of a vision backbone and a language model, without a language modeling head.""",
MISTRAL3_START_DOCSTRING,
)
-class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin):
+class Mistral3Model(Mistral3PreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: Mistral3Config):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = Mistral3MultiModalProjector(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@@ -315,18 +349,6 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@@ -364,64 +386,23 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
return image_features
@add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
- labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- logits_to_keep: Union[int, torch.Tensor] = 0,
- image_sizes: Optional[torch.Tensor] = None,
+ image_sizes: torch.Tensor = None,
**lm_kwargs,
- ) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- logits_to_keep (`int` or `torch.Tensor`, *optional*):
- If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
- If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
- This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-
- Returns:
-
- Example:
-
- ```python
- >>> from PIL import Image
- >>> import requests
- >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
-
- >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
- >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
-
- >>> prompt = "[INST][IMG]What is the image?[/INST]"
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
- >>> image = Image.open(requests.get(url, stream=True).raw)
-
- >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
- >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "What is the image?The image depicts two cats lying on a pink blanket."
- ```"""
-
+ ) -> Union[Tuple, Mistral3ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -468,35 +449,153 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
**lm_kwargs,
)
- logits = outputs[0]
+ output = Mistral3ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+@add_start_docstrings(
+ """The MISTRAL3 model which consists of a vision backbone and a language model.""",
+ MISTRAL3_START_DOCSTRING,
+)
+class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: Mistral3Config):
+ super().__init__(config)
+ self.model = Mistral3Model(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ image_sizes: torch.Tensor = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
+
+ >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
+ >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
+
+ >>> prompt = "[INST][IMG]What is the image?[/INST]"
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is the image?The image depicts two cats lying on a pink blanket."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ image_sizes=image_sizes,
+ **lm_kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return Mistral3CausalLMOutputWithPast(
loss=loss,
@@ -504,7 +603,7 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@@ -520,7 +619,7 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -537,5 +636,60 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
return model_inputs
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-__all__ = ["Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"]
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = ["Mistral3Model", "Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"]
diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py
index 36fd4526838..5ef6663bde0 100644
--- a/src/transformers/models/mistral3/modular_mistral3.py
+++ b/src/transformers/models/mistral3/modular_mistral3.py
@@ -19,14 +19,29 @@ import torch
from torch import nn
from ...activations import ACT2FN
-from ...utils import is_torchdynamo_compiling, logging
-from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel
+from ...utils import (
+ add_start_docstrings_to_model_forward,
+ can_return_tuple,
+ is_torchdynamo_compiling,
+ logging,
+ replace_return_docstrings,
+)
+from ..llava.modeling_llava import (
+ LlavaCausalLMOutputWithPast,
+ LlavaForConditionalGeneration,
+ LlavaModel,
+ LlavaModelOutputWithPast,
+ LlavaPreTrainedModel,
+)
from ..mistral.modeling_mistral import MistralRMSNorm
from .configuration_mistral3 import Mistral3Config
logger = logging.get_logger(__name__)
+MISTRAL3_INPUTS_DOCSTRING = None
+_CONFIG_FOR_DOC = "Mistra3Config"
+
class Mistral3RMSNorm(MistralRMSNorm):
pass
@@ -100,19 +115,29 @@ class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
+class Mistral3ModelOutputWithPast(LlavaModelOutputWithPast):
+ pass
+
+
class Mistral3PreTrainedModel(LlavaPreTrainedModel):
def _init_weights(self, module):
+        # important: this ported version of Mistral3 isn't meant for training from scratch - only
+        # inference and fine-tuning - so the proper init weights code has been removed and the
+        # original training codebase should serve for that purpose
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
elif isinstance(module, Mistral3RMSNorm):
module.weight.data.fill_(1.0)
-class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
+class Mistral3Model(LlavaModel):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@@ -151,61 +176,21 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
- labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- logits_to_keep: Union[int, torch.Tensor] = 0,
- image_sizes: Optional[torch.Tensor] = None,
+ image_sizes: torch.Tensor = None,
**lm_kwargs,
- ) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- logits_to_keep (`int` or `torch.Tensor`, *optional*):
- If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
- If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
- This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-
- Returns:
-
- Example:
-
- ```python
- >>> from PIL import Image
- >>> import requests
- >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
-
- >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
- >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
-
- >>> prompt = "[INST][IMG]What is the image?[/INST]"
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
- >>> image = Image.open(requests.get(url, stream=True).raw)
-
- >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
- >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "What is the image?The image depicts two cats lying on a pink blanket."
- ```"""
-
+ ) -> Union[Tuple, Mistral3ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -252,35 +237,110 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
**lm_kwargs,
)
- logits = outputs[0]
+ output = Mistral3ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ image_sizes: torch.Tensor = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
+
+ >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
+ >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
+
+ >>> prompt = "[INST][IMG]What is the image?[/INST]"
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is the image?The image depicts two cats lying on a pink blanket."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ image_sizes=image_sizes,
+ **lm_kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return Mistral3CausalLMOutputWithPast(
loss=loss,
@@ -288,11 +348,12 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
__all__ = [
+ "Mistral3Model",
"Mistral3PreTrainedModel", # noqa
"Mistral3ForConditionalGeneration",
]
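
The modular file mirrors the same split: `Mistral3Model` owns the vision tower, projector and text backbone, while `Mistral3ForConditionalGeneration` adds only the `lm_head` and remaps old checkpoint keys through `_checkpoint_conversion_mapping`. A rough sketch of what that layout implies for callers follows; the checkpoint id is reused from the patch's own docstring example and loading it is illustrative only, not a prescribed workflow:

```python
from transformers import Mistral3ForConditionalGeneration

model = Mistral3ForConditionalGeneration.from_pretrained(
    "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
)

# New layout: vision tower, projector and text backbone live under `model.model`,
# while the LM head is a direct attribute of the generation wrapper.
assert model.get_output_embeddings() is model.lm_head

# Backward compatibility: the old attribute names still resolve to the same modules
# through the properties defined on the conditional-generation class.
assert model.language_model is model.model.language_model
assert model.vision_tower is model.model.vision_tower
assert model.multi_modal_projector is model.model.multi_modal_projector
```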
diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py
index d842bd7c131..79278c9892e 100644
--- a/src/transformers/models/mllama/modeling_mllama.py
+++ b/src/transformers/models/mllama/modeling_mllama.py
@@ -32,6 +32,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ can_return_tuple,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
@@ -1968,10 +1969,11 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
@add_start_docstrings(
- """The Mllama model which consists of a vision encoder and a language model.""",
+ """The Mllama model which consists of a vision encoder and a language model without language modeling head.""",
MLLAMA_START_DOCSTRING,
)
-class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
+class MllamaModel(MllamaPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
_supports_quantized_cache = False # quant cache not supported in encoder-decoder setting
def __init__(self, config: MllamaConfig):
@@ -1983,10 +1985,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.vision_model = MllamaVisionModel._from_config(config.vision_config)
- self.language_model = MllamaForCausalLM._from_config(config.text_config)
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
+ self.language_model = MllamaTextModel._from_config(config.text_config)
self.multi_modal_projector = nn.Linear(
config.vision_config.vision_output_dim,
config.text_config.hidden_size,
@@ -2000,18 +1999,139 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
+ @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ aspect_ratio_mask: Optional[torch.Tensor] = None,
+ aspect_ratio_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None and cross_attention_states is not None:
+ raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")
+
+ if pixel_values is not None:
+ if aspect_ratio_ids is None:
+ raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
+ # get vision tokens from vision model
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ aspect_ratio_ids=aspect_ratio_ids,
+ aspect_ratio_mask=aspect_ratio_mask,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ )
+ cross_attention_states = vision_outputs[0]
+ cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
+ -1, cross_attention_states.shape[-2], self.hidden_size
+ )
+
+ if cross_attention_mask is not None:
+ cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
+ cross_attention_mask,
+ num_vision_tokens=self.vision_model.num_patches,
+ dtype=self.dtype,
+ )
+ else:
+ full_text_row_masked_out_mask = None
+
+ if cross_attention_mask is not None and cache_position is not None:
+ cross_attention_mask = cross_attention_mask[:, :, cache_position]
+ full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
+
+ outputs = self.language_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cross_attention_states=cross_attention_states,
+ cross_attention_mask=cross_attention_mask,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ inputs_embeds=inputs_embeds,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=True,
+ cache_position=cache_position,
+ )
+
+ output = BaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+@add_start_docstrings(
+ """The Mllama model which consists of a vision encoder and a language model.""",
+ MLLAMA_START_DOCSTRING,
+)
+class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_model": "model.vision_model",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _supports_quantized_cache = False # quant cache not supported in encoder-decoder setting
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: MllamaConfig):
+ super().__init__(config)
+ self.model = MllamaModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
+ return self.lm_head
def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
+ self.lm_head = new_embeddings
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
- def get_decoder(self):
- return self.language_model.get_decoder()
+ @property
+ def vision_model(self):
+ return self.model.vision_model
+ @can_return_tuple
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaConfig")
def forward(
@@ -2084,78 +2204,36 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if pixel_values is not None and cross_attention_states is not None:
- raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")
-
- if pixel_values is not None:
- if aspect_ratio_ids is None:
- raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
- # get vision tokens from vision model
- vision_outputs = self.vision_model(
- pixel_values=pixel_values,
- aspect_ratio_ids=aspect_ratio_ids,
- aspect_ratio_mask=aspect_ratio_mask,
- output_hidden_states=output_hidden_states,
- output_attentions=output_attentions,
- return_dict=return_dict,
- )
- cross_attention_states = vision_outputs[0]
- cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
- -1, cross_attention_states.shape[-2], self.hidden_size
- )
-
- if cross_attention_mask is not None:
- cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
- cross_attention_mask,
- num_vision_tokens=self.vision_model.num_patches,
- dtype=self.dtype,
- )
- else:
- full_text_row_masked_out_mask = None
-
- if cross_attention_mask is not None and cache_position is not None:
- cross_attention_mask = cross_attention_mask[:, :, cache_position]
- full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
-
- outputs = self.language_model(
+ outputs = self.model(
input_ids=input_ids,
+ pixel_values=pixel_values,
+ aspect_ratio_mask=aspect_ratio_mask,
+ aspect_ratio_ids=aspect_ratio_ids,
+ cross_attention_mask=cross_attention_mask,
+ cross_attention_states=cross_attention_states,
attention_mask=attention_mask,
position_ids=position_ids,
- cross_attention_states=cross_attention_states,
- cross_attention_mask=cross_attention_mask,
- full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
- use_cache=use_cache,
inputs_embeds=inputs_embeds,
- output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
output_attentions=output_attentions,
- return_dict=return_dict,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
- **loss_kwargs,
)
- # Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
loss = None
- logits = outputs[0]
-
if labels is not None:
- loss = self.loss_function(logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs)
-
- if not return_dict:
- return (loss,) + outputs if loss is not None else outputs
+ loss = self.loss_function(logits, labels, self.config.text_config.vocab_size, **loss_kwargs)
return CausalLMOutputWithPast(
loss=loss,
- logits=outputs.logits,
+ logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
@@ -2179,58 +2257,28 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
- # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
- # Exception 1: when passing input_embeds, input_ids may be missing entries
- # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
- # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
- # (we can't check exception 3 while compiling)
- if past_key_values is not None:
- if (
- inputs_embeds is not None # Exception 1
- or cache_position[-1] >= input_ids.shape[1] # Exception 3
- ):
- input_ids = input_ids[:, -cache_position.shape[0] :]
- elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
- input_ids = input_ids[:, cache_position]
-
- # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -input_ids.shape[1] :]
-
- # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
- position_ids = position_ids.clone(memory_format=torch.contiguous_format)
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and cache_position[0] == 0:
- model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
- else:
- # The clone here is for the same reason as for `position_ids`.
- model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
-
- if logits_to_keep is not None:
- model_inputs["logits_to_keep"] = logits_to_keep
-
- model_inputs.update(
- {
- "position_ids": position_ids,
- "cache_position": cache_position,
- "past_key_values": past_key_values,
- "use_cache": use_cache,
- "attention_mask": attention_mask,
- "cross_attention_mask": cross_attention_mask,
- }
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ pixel_values=pixel_values,
+ aspect_ratio_ids=aspect_ratio_ids,
+ aspect_ratio_mask=aspect_ratio_mask,
+ cross_attention_mask=cross_attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
)
# If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
# to compute image hidden states, otherwise they are cached within each cross attn layer
- if cache_position[0] == 0:
- model_inputs["pixel_values"] = pixel_values
- model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
- model_inputs["aspect_ratio_mask"] = aspect_ratio_mask
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["aspect_ratio_ids"] = None
+ model_inputs["aspect_ratio_mask"] = None
return model_inputs
@@ -2257,4 +2305,5 @@ __all__ = [
"MllamaTextModel",
"MllamaVisionModel",
"MllamaPreTrainedModel",
+ "MllamaModel",
]
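
In the refactored Mllama, as in the other models in this patch, the generation wrapper applies the LM head itself and uses `logits_to_keep` to avoid materializing logits for the whole sequence during decoding. A tiny standalone sketch of that slicing, with made-up toy dimensions rather than Mllama's real sizes:

```python
import torch
from torch import nn

hidden_states = torch.randn(2, 128, 64)      # (batch, seq_len, hidden) from the base model
lm_head = nn.Linear(64, 1000, bias=False)    # hidden -> vocab projection

logits_to_keep = 1                            # decoding only needs the last position
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(lm_head(hidden_states[:, slice_indices, :]).shape)   # torch.Size([2, 1, 1000])

logits_to_keep = 0                            # 0 is the special case: keep logits for every position
slice_indices = slice(-logits_to_keep, None)
print(lm_head(hidden_states[:, slice_indices, :]).shape)   # torch.Size([2, 128, 1000])
```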
diff --git a/src/transformers/models/paligemma/configuration_paligemma.py b/src/transformers/models/paligemma/configuration_paligemma.py
index 4551b85bcd5..f32ad303bf1 100644
--- a/src/transformers/models/paligemma/configuration_paligemma.py
+++ b/src/transformers/models/paligemma/configuration_paligemma.py
@@ -13,8 +13,6 @@
# limitations under the License.
"""PaliGemmamodel configuration"""
-import warnings
-
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING, AutoConfig
@@ -39,8 +37,6 @@ class PaliGemmaConfig(PretrainedConfig):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
- ignore_index (`int`, *optional*, defaults to -100):
- The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 256000):
The image token index to encode the image prompt.
vocab_size (`int`, *optional*, defaults to 257152):
@@ -83,16 +79,13 @@ class PaliGemmaConfig(PretrainedConfig):
self,
vision_config=None,
text_config=None,
- ignore_index=-100,
image_token_index=256000,
vocab_size=257152,
projection_dim=2048,
hidden_size=2048,
**kwargs,
):
- self._ignore_index = ignore_index
self.image_token_index = image_token_index
- self._vocab_size = vocab_size
self.projection_dim = projection_dim
self.hidden_size = hidden_size
self.vision_config = vision_config
@@ -133,22 +126,5 @@ class PaliGemmaConfig(PretrainedConfig):
self.vision_config.projection_dim = projection_dim
super().__init__(**kwargs)
- @property
- def ignore_index(self):
- warnings.warn(
- "The `ignore_index` attribute is deprecated and will be removed in v4.47.",
- FutureWarning,
- )
- return self._ignore_index
-
- @ignore_index.setter
- def ignore_index(self, value):
- self._ignore_index = value
-
- def to_dict(self):
- output = super().to_dict()
- output.pop("_ignore_index", None)
- return output
-
__all__ = ["PaliGemmaConfig"]
diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index 924e3b1dcad..9561f0f0628 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -23,17 +23,18 @@ from torch import nn
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
-from ...modeling_outputs import CausalLMOutputWithPast
+from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_paligemma import PaliGemmaConfig
@@ -42,78 +43,43 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PaliGemmaConfig"
-# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
-# But Paligemma has no causal mask on prefix
-def _prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask: torch.Tensor,
- sequence_length: int,
- target_length: int,
- dtype: torch.dtype,
- min_dtype: float,
- cache_position: torch.Tensor,
- batch_size: int,
- is_training: bool = False,
- token_type_ids: Optional[torch.Tensor] = None,
- **kwargs,
-):
+@dataclass
+class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
"""
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+ Base class for Paligemma outputs, with hidden states and attentions.
Args:
- attention_mask (`torch.Tensor`):
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
- sequence_length (`int`):
- The sequence length being processed.
- target_length (`int`):
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
- dtype (`torch.dtype`):
- The dtype to use for the 4D attention mask.
- min_dtype (`float`):
- The minimum value representable with the dtype `dtype`.
- cache_position (`torch.Tensor`):
- Indices depicting the position of the input sequence tokens in the sequence.
- batch_size (`torch.Tensor`):
- Batch size.
- is_training (`bool`):
- Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels`
- """
- if attention_mask is not None and attention_mask.dim() == 4:
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
- causal_mask = attention_mask
- else:
- causal_mask = torch.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
- )
- # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
- if sequence_length != 1:
- if is_training:
- causal_mask = torch.triu(causal_mask, diagonal=1)
- else:
- causal_mask[:, :sequence_length] = 0.0
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
- causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
- mask_length = attention_mask.shape[-1]
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- padding_mask, min_dtype
- )
- # we are training thus we need to create a full mask on the image + prefix but causal on suffix
- if is_training:
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
- )
- return causal_mask
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
"""
- Base class for PaliGemmacausal language model (or autoregressive) outputs.
+ Base class for PaliGemma causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
@@ -184,7 +150,7 @@ PALIGEMMA_START_DOCSTRING = r"""
)
class PaliGemmaPreTrainedModel(PreTrainedModel):
config_class = PaliGemmaConfig
- base_model_prefix = "model"
+ base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["PaliGemmaMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"
@@ -276,49 +242,32 @@ PALIGEMMA_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
- """The PALIGEMMA model which consists of a vision backbone and a language model.""",
+ """Base Paligemma model which consists of a vision backbone and a language model withou language modeling head.""",
PALIGEMMA_START_DOCSTRING,
)
-class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
+class PaliGemmaModel(PaliGemmaPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: PaliGemmaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
- language_model = AutoModelForCausalLM.from_config(config=config.text_config)
-
- if language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
+ language_model = AutoModel.from_config(config=config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma
+ # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma
+ # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def _update_causal_mask(
self,
attention_mask,
@@ -326,7 +275,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
past_key_values=None,
cache_position=None,
input_tensor=None,
- is_training: Optional[bool] = None,
+ is_training: bool = None,
):
if self.config.text_config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
@@ -403,12 +352,154 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
image_features = image_features / (self.config.text_config.hidden_size**0.5)
return image_features
+ @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, PaligemmaModelOutputWithPast]:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ is_training = token_type_ids is not None and labels is not None
+
+        # Replace image id with PAD if the image token is OOV, to avoid index errors
+ if input_ids is not None and self.config.image_token_id >= self.vocab_size:
+ special_image_mask = input_ids == self.config.image_token_id
+ llm_input_ids = input_ids.clone()
+ llm_input_ids[special_image_mask] = 0
+ else:
+ llm_input_ids = input_ids
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
+
+ # Merge text and images
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values)
+
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ else:
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ raise ValueError(
+ f"Number of images does not match number of special image tokens in the input text. "
+ f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
+ "tokens from image embeddings."
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
+ )
+ outputs = self.language_model(
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ output = PaligemmaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+@add_start_docstrings(
+ """Base Paligemma model which consists of a vision backbone and a language model withou language modeling head.""",
+ PALIGEMMA_START_DOCSTRING,
+)
+class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: PaliGemmaConfig):
+ super().__init__(config)
+ self.model = PaliGemmaModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
@@ -459,117 +550,46 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Where is the cat standing?\nsnow"
```"""
-
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- is_training = token_type_ids is not None and labels is not None
-
- # Replace image id with PAD if the image token if OOV, to avoid index-errors
- if input_ids is not None and self.config.image_token_id >= self.vocab_size:
- special_image_mask = input_ids == self.config.image_token_id
- llm_input_ids = input_ids.clone()
- llm_input_ids[special_image_mask] = 0
- else:
- llm_input_ids = input_ids
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(llm_input_ids)
-
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
- )
-
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
-
- # Merge text and images
- if pixel_values is not None:
- image_features = self.get_image_features(pixel_values)
-
- if input_ids is None:
- special_image_mask = inputs_embeds == self.get_input_embeddings()(
- torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- else:
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
- raise ValueError(
- f"Number of images does not match number of special image tokens in the input text. "
- f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
- "tokens from image embeddings."
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- # mask out pad-token-ids in labels for BC
- if labels is not None and self.pad_token_id in labels:
- logger.warning_once(
- "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
- "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
- )
- labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
-
- causal_mask = self._update_causal_mask(
- attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
- )
- outputs: CausalLMOutputWithPast = self.language_model(
- attention_mask=causal_mask,
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
+ labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
**lm_kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
loss = None
if labels is not None:
- # Upcast to float if we need to compute the loss to avoid potential precision issues
- logits = logits.float()
- shift_logits = logits[..., :-1, :]
- shift_labels = labels[..., 1:]
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
- shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
- else:
- shift_logits = shift_logits.contiguous()
- shift_labels = shift_labels.contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
- flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
- flat_labels = shift_labels.view(-1).to(shift_logits.device)
- loss = loss_fct(flat_logits, flat_labels)
-
- output = PaliGemmaCausalLMOutputWithPast(
+ return PaliGemmaCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
- return output if return_dict else output.to_tuple()
def prepare_inputs_for_generation(
self,
@@ -587,7 +607,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -610,12 +630,68 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
- causal_mask = self._update_causal_mask(
+ causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs
+ @staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel"]
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]
diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index 9737e1437e8..b1b875abcd1 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -2056,7 +2056,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
- " this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniThinkerText. Make sure to "
+ " this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniThinker. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
@@ -2156,7 +2156,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
- config (`Qwen25OmniThinkerTextConfig`):
+ config (`Qwen25OmniThinkerConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
@@ -2772,7 +2772,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
- " this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniTalker. Make sure to "
+ " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5Omni. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
@@ -2872,7 +2872,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
- config (`Qwen25OmniTalkerConfig`):
+ config (`Qwen2_5OmniConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 9b5c4167646..5a51e787577 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -33,8 +33,8 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLAttention,
Qwen2_5_VLMLP,
- Qwen2_5_VLModel,
Qwen2_5_VLPreTrainedModel,
+ Qwen2_5_VLTextModel,
Qwen2_5_VLVisionBlock,
Qwen2RMSNorm,
)
@@ -2165,7 +2165,7 @@ QWEN2_5OMNI_START_DOCSTRING = r"""
"The bare Qwen2.5OmniThinker Model outputting raw hidden-states without any specific head on top.",
QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTextConfig"),
)
-class Qwen2_5OmniThinkerTextModel(Qwen2_5_VLModel):
+class Qwen2_5OmniThinkerTextModel(Qwen2_5_VLTextModel):
config_class = Qwen2_5OmniTextConfig
_no_split_modules = ["Qwen2_5OmniDecoderLayer"]
@@ -2591,7 +2591,7 @@ class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput):
"The bare Qwen2.5OmniTalker Model outputting raw hidden-states without any specific head on top.",
QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTalkerConfig"),
)
-class Qwen2_5OmniTalkerModel(Qwen2_5_VLModel):
+class Qwen2_5OmniTalkerModel(Qwen2_5_VLTextModel):
config_class = Qwen2_5OmniTalkerConfig
_no_split_modules = ["Qwen2_5OmniTalkerDecoderLayer"]
diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 6c19b9fdcfb..f87260c69b7 100644
--- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -31,7 +31,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
@@ -44,6 +43,7 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ can_return_tuple,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
@@ -565,6 +565,42 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return hidden_states
+@dataclass
+class Qwen2_5_VLModelOutputWithPast(ModelOutput):
+ """
+ Base class for Qwen2_5_VL model outputs, with hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: Optional[List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
class Qwen2_5_VLRotaryEmbedding(nn.Module):
def __init__(self, config: Qwen2_5_VLTextConfig, device=None):
super().__init__()
@@ -1076,7 +1112,7 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
"The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.",
Qwen2_5_VL_START_DOCSTRING,
)
-class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
+class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
config_class = Qwen2_5_VLTextConfig
def __init__(self, config: Qwen2_5_VLTextConfig):
@@ -1374,45 +1410,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
return causal_mask
-@dataclass
-class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
- """
- Base class for Qwen2_5_VL causal language model (or autoregressive) outputs.
-
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Language modeling loss (for next-token prediction).
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
- The rope index difference between sequence length and multimodal rope.
- """
-
- loss: Optional[torch.FloatTensor] = None
- logits: Optional[torch.FloatTensor] = None
- past_key_values: Optional[List[torch.FloatTensor]] = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
- rope_deltas: Optional[torch.LongTensor] = None
-
-
QWEN2_5_VL_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1489,41 +1486,25 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r"""
"""
-class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
- _tied_weights_keys = ["lm_head.weight"]
+class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
+ _checkpoint_conversion_mapping = {"^model": "language_model"}
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
def __init__(self, config):
super().__init__(config)
self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
-
- text_config = config.get_text_config()
- self.model = Qwen2_5_VLModel._from_config(text_config)
- self.vocab_size = text_config.vocab_size
- self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
+ self.language_model = Qwen2_5_VLTextModel._from_config(config.text_config)
self.rope_deltas = None # cache rope_deltas here
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
- return self.model.embed_tokens
+ return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
+ self.language_model.set_input_embeddings(value)
def get_rope_index(
self,
@@ -1708,15 +1689,13 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
return position_ids, mrope_position_deltas
@add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
+ input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -1728,46 +1707,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
- ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Returns:
-
- Example:
-
- ```python
- >>> from PIL import Image
- >>> import requests
- >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
-
- >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
- >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
-
- >>> messages = [
- {
- "role": "user",
- "content": [
- {"type": "image"},
- {"type": "text", "text": "What is shown in this image?"},
- ],
- },
- ]
- >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
- >>> image = Image.open(requests.get(url, stream=True).raw)
-
- >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
- ```"""
-
+ ) -> Union[Tuple, Qwen2_5_VLModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1775,7 +1715,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
- inputs_embeds = self.model.embed_tokens(input_ids)
+ inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
@@ -1846,7 +1786,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
- outputs = self.model(
+ outputs = self.language_model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
@@ -1855,6 +1795,182 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ )
+
+ output = Qwen2_5_VLModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+@dataclass
+class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for Qwen2_5_VL causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^visual": "model.visual",
+ r"^model(?!\.(language_model|visual))": "model.language_model",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Qwen2_5_VLModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def visual(self):
+ return self.model.visual
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+ >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": "What is shown in this image?"},
+ ],
+ },
+ ]
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
@@ -1864,18 +1980,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
loss = None
if labels is not None:
- # Upcast to float if we need to compute the loss to avoid potential precision issues
- logits = logits.float()
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -1887,7 +1992,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- rope_deltas=self.rope_deltas,
+ rope_deltas=outputs.rope_deltas,
)
def prepare_inputs_for_generation(
@@ -2055,5 +2160,60 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
return input_ids, model_kwargs
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"]
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLTextModel"]
diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
index e34724c7906..2d220aa2a90 100644
--- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -26,7 +26,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
-from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
@@ -36,6 +35,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLCausalLMOutputWithPast,
Qwen2VLForConditionalGeneration,
Qwen2VLModel,
+ Qwen2VLModelOutputWithPast,
Qwen2VLPreTrainedModel,
VisionAttention,
VisionRotaryEmbedding,
@@ -50,7 +50,12 @@ from ...image_utils import ImageInput, VideoInput
from ...modeling_flash_attention_utils import is_flash_attn_available
from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...utils import logging
+from ...utils import (
+ add_start_docstrings_to_model_forward,
+ can_return_tuple,
+ logging,
+ replace_return_docstrings,
+)
if is_flash_attn_available():
@@ -59,6 +64,8 @@ if is_flash_attn_available():
logger = logging.get_logger(__name__)
+_CONFIG_FOR_DOC = "Qwen2_5_VLConfig"
+
def apply_rotary_pos_emb_flashatt(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
@@ -406,16 +413,12 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return hidden_states
-class Qwen2_5_VLModel(Qwen2VLModel):
- pass
-
-
@dataclass
-class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast):
+class Qwen2_5_VLModelOutputWithPast(Qwen2VLModelOutputWithPast):
pass
-class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
+class Qwen2_5_VLModel(Qwen2VLModel):
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
@@ -607,12 +610,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
+ input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -624,46 +626,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
- ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Returns:
-
- Example:
-
- ```python
- >>> from PIL import Image
- >>> import requests
- >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
-
- >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
- >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
-
- >>> messages = [
- {
- "role": "user",
- "content": [
- {"type": "image"},
- {"type": "text", "text": "What is shown in this image?"},
- ],
- },
- ]
- >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
- >>> image = Image.open(requests.get(url, stream=True).raw)
-
- >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
- ```"""
-
+ ) -> Union[Tuple, Qwen2_5_VLModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -671,7 +634,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
- inputs_embeds = self.model.embed_tokens(input_ids)
+ inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
@@ -742,7 +705,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
- outputs = self.model(
+ outputs = self.language_model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
@@ -751,6 +714,111 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ )
+
+ output = Qwen2_5_VLModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+@dataclass
+class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast):
+ pass
+
+
+QWEN2_5_VL_INPUTS_DOCSTRING = None
+
+
+class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+ >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": "What is shown in this image?"},
+ ],
+ },
+ ]
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
@@ -760,18 +828,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
loss = None
if labels is not None:
- # Upcast to float if we need to compute the loss to avoid potential precision issues
- logits = logits.float()
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -783,7 +840,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- rope_deltas=self.rope_deltas,
+ rope_deltas=outputs.rope_deltas,
)
def prepare_inputs_for_generation(
@@ -985,4 +1042,5 @@ __all__ = [
"Qwen2_5_VLModel",
"Qwen2_5_VLPreTrainedModel",
"Qwen2_5_VLProcessor",
+ "Qwen2_5_VLTextModel", # noqa: F822
]
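Throughout the patch the inline shift-and-flatten cross entropy is replaced by `self.loss_function(logits=..., labels=..., vocab_size=...)`. A sketch of the equivalent computation the removed code performed, ignoring extras such as batch-size normalization that the shared loss helper may also handle:

```python
import torch
import torch.nn.functional as F

# Causal-LM loss with the same shift-by-one scheme as the removed inline code:
# the logits at position i are scored against the label at position i + 1,
# and -100 entries are ignored.
def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    logits = logits.float()                              # upcast for numerical stability
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
    return F.cross_entropy(
        shift_logits.view(-1, vocab_size),
        shift_labels.view(-1),
        ignore_index=-100,
    )

logits = torch.randn(2, 5, 100)
labels = torch.randint(0, 100, (2, 5))
labels[:, :2] = -100                                     # e.g. prompt tokens are ignored in the loss
print(causal_lm_loss(logits, labels, vocab_size=100))
```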
diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index 4496ef73b8c..03700451ce8 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -799,27 +799,21 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
raise ValueError(f"{padding_side} is not `left` or `right`.")
self._padding_side = padding_side
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
def get_decoder(self):
return self.language_model.get_decoder()
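The accessors above simply delegate to the wrapped language model; the converted VL classes in this patch follow the same idea and additionally expose the moved sub-modules through read-only properties, so the old attribute paths (`.language_model`, `.visual`) keep working after the refactor. A stripped-down sketch with stand-in modules:

```python
import torch.nn as nn

class InnerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = nn.Linear(4, 4)  # stand-in for the text decoder
        self.visual = nn.Linear(4, 4)          # stand-in for the vision tower

class ForConditionalGeneration(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = InnerModel()
        self.lm_head = nn.Linear(4, 8, bias=False)

    # Make modules available through the conditional class for BC
    @property
    def language_model(self):
        return self.model.language_model

    @property
    def visual(self):
        return self.model.visual

m = ForConditionalGeneration()
assert m.language_model is m.model.language_model  # old access path still resolves
assert m.visual is m.model.visual
```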
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index ce008fdaf9f..2e46c8927a7 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -27,7 +27,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
-from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn import LayerNorm
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
@@ -40,7 +40,9 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ can_return_tuple,
is_torch_flex_attn_available,
+ is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@@ -61,6 +63,42 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Qwen2VLConfig"
+@dataclass
+class Qwen2VLModelOutputWithPast(ModelOutput):
+ """
+ Base class for Qwen2VL model outputs, with hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: Optional[List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
@dataclass
class Qwen2VLCausalLMOutputWithPast(ModelOutput):
"""
@@ -1027,7 +1065,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
"The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.",
QWEN2VL_START_DOCSTRING,
)
-class Qwen2VLModel(Qwen2VLPreTrainedModel):
+class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
config_class = Qwen2VLTextConfig
def __init__(self, config: Qwen2VLTextConfig):
@@ -1403,39 +1441,23 @@ QWEN2_VL_INPUTS_DOCSTRING = r"""
"""
-class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
- _tied_weights_keys = ["lm_head.weight"]
+class Qwen2VLModel(Qwen2VLPreTrainedModel):
+ _checkpoint_conversion_mapping = {"^model": "language_model"}
- def __init__(self, config):
+ def __init__(self, config: Qwen2VLConfig):
super().__init__(config)
self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
-
- text_config = config.get_text_config()
- self.model = Qwen2VLModel._from_config(text_config)
- self.vocab_size = text_config.vocab_size
- self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
+ self.language_model = Qwen2VLTextModel._from_config(config.text_config)
self.rope_deltas = None # cache rope_deltas here
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
- return self.model.embed_tokens
+ return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
+ self.language_model.set_input_embeddings(value)
def get_rope_index(
self,
@@ -1586,11 +1608,166 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
return position_ids, mrope_position_deltas
+ @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, Qwen2VLModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ if pixel_values is not None:
+ pixel_values = pixel_values.type(self.visual.get_dtype())
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_embeds.shape[0]
+ if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_mask = (
+ (input_ids == self.config.image_token_id)
+ .unsqueeze(-1)
+ .expand_as(inputs_embeds)
+ .to(inputs_embeds.device)
+ )
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if pixel_values_videos is not None:
+ pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+ n_video_tokens = (input_ids == self.config.video_token_id).sum()
+ n_video_features = video_embeds.shape[0]
+ if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
+ raise ValueError(
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+ )
+ video_mask = (
+ (input_ids == self.config.video_token_id)
+ .unsqueeze(-1)
+ .expand_as(inputs_embeds)
+ .to(inputs_embeds.device)
+ )
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(inputs_embeds.device)
+
+ # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+ if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
+ # calculate RoPE index once per generation in the pre-fill stage only
+ if (
+ (cache_position is not None and cache_position[0] == 0)
+ or self.rope_deltas is None
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
+ ):
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids, image_grid_thw, video_grid_thw, attention_mask
+ )
+ self.rope_deltas = rope_deltas
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ if cache_position is not None: # otherwise `deltas` is an int `0`
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ delta = delta.to(position_ids.device)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ )
+
+ output = Qwen2VLModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^visual": "model.visual",
+ r"^model(?!\.(language_model|visual))": "model.language_model",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Qwen2VLModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def visual(self):
+ return self.model.visual
+
+ @can_return_tuple
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
+ input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -1652,70 +1829,12 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if inputs_embeds is None:
- inputs_embeds = self.model.embed_tokens(input_ids)
- if pixel_values is not None:
- pixel_values = pixel_values.type(self.visual.get_dtype())
- image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
- n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
- n_image_features = image_embeds.shape[0]
- if n_image_tokens != n_image_features:
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_mask = (
- (input_ids == self.config.image_token_id)
- .unsqueeze(-1)
- .expand_as(inputs_embeds)
- .to(inputs_embeds.device)
- )
- image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
- if pixel_values_videos is not None:
- pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
- video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
- n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
- n_video_features = video_embeds.shape[0]
- if n_video_tokens != n_video_features:
- raise ValueError(
- f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
- )
- video_mask = (
- (input_ids == self.config.video_token_id)
- .unsqueeze(-1)
- .expand_as(inputs_embeds)
- .to(inputs_embeds.device)
- )
- video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
- # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
- if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
- # calculate RoPE index once per generation in the pre-fill stage only
- if (
- (cache_position is not None and cache_position[0] == 0)
- or self.rope_deltas is None
- or (past_key_values is None or past_key_values.get_seq_length() == 0)
- ):
- position_ids, rope_deltas = self.get_rope_index(
- input_ids, image_grid_thw, video_grid_thw, attention_mask
- )
- self.rope_deltas = rope_deltas
- # then use the prev pre-calculated rope-deltas to get the correct position ids
- else:
- batch_size, seq_length, _ = inputs_embeds.shape
- delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
- position_ids = torch.arange(seq_length, device=inputs_embeds.device)
- position_ids = position_ids.view(1, -1).expand(batch_size, -1)
- if cache_position is not None: # otherwise `deltas` is an int `0`
- delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
- delta = delta.to(position_ids.device)
- position_ids = position_ids.add(delta)
- position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
outputs = self.model(
- input_ids=None,
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
@@ -1732,22 +1851,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
- # Upcast to float if we need to compute the loss to avoid potential precision issues
- logits = logits.float()
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
return Qwen2VLCausalLMOutputWithPast(
loss=loss,
@@ -1755,7 +1859,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- rope_deltas=self.rope_deltas,
+ rope_deltas=outputs.rope_deltas,
)
def prepare_inputs_for_generation(
@@ -1921,5 +2025,61 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
return input_ids, model_kwargs
+ @staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-__all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel"]
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel", "Qwen2VLTextModel"]
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index 70ebef344c3..d40dd4db887 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -28,11 +28,12 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_video_llava import VideoLlavaConfig
@@ -41,6 +42,47 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VideoLlavaConfig"
+@dataclass
+class VideoLlavaModelOutputWithPast(ModelOutput):
+ """
+ Base class for VideoLlava base model outputs.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+            A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
+ video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: Optional[List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+ video_hidden_states: Optional[torch.FloatTensor] = None
+
+
@dataclass
class VideoLlavaCausalLMOutputWithPast(ModelOutput):
"""
@@ -130,7 +172,7 @@ VIDEO_LLAVA_START_DOCSTRING = r"""
)
class VideoLlavaPreTrainedModel(PreTrainedModel):
config_class = VideoLlavaConfig
- base_model_prefix = "model"
+ base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["VideoLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
@@ -242,10 +284,12 @@ VIDEO_LLAVA_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
- """The VideoLlava model which consists of a vision backbone and a language model.""",
+ """The VideoLlava model which consists of a vision backbone and a language model without language modeling head.""",
VIDEO_LLAVA_START_DOCSTRING,
)
-class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
+class VideoLlavaModel(VideoLlavaPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: VideoLlavaConfig):
super().__init__(config)
self.video_tower = AutoModel.from_config(config.vision_config)
@@ -253,10 +297,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
self.multi_modal_projector = VideoLlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
@@ -266,18 +307,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def get_image_features(
self,
pixel_values_images: torch.FloatTensor,
@@ -358,13 +387,165 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
return video_features, num_frames
+ @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values_images: torch.FloatTensor = None,
+ pixel_values_videos: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, List[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, VideoLlavaModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same "
+ "time, and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values_images is not None:
+ image_features = self.get_image_features(
+ pixel_values_images,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ if pixel_values_videos is not None:
+ video_features, num_frames = self.get_video_features(
+ pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
+ )
+
+ special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_id).sum()
+ n_video_features = video_features.shape[0] * video_features.shape[1]
+ raise ValueError(
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+ )
+ video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ output = VideoLlavaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values_images is not None else None,
+ video_hidden_states=video_features if pixel_values_videos is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
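# Minimal sketch (assumption, not library code) of the `masked_scatter` step in
# `VideoLlavaModel.forward` above: the projected image features are written, row by row,
# into the embedding slots that correspond to the placeholder image tokens in `input_ids`.
# The token id below is hypothetical.
import torch

image_token_id = 32000                                   # hypothetical placeholder id
input_ids = torch.tensor([[1, image_token_id, image_token_id, 2]])
inputs_embeds = torch.zeros(1, 4, 8)                     # (batch, seq, hidden)
image_features = torch.ones(2, 8)                        # one projected row per image token

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, :, 0])                            # tensor([0., 1., 1., 0.])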
+@add_start_docstrings(
+ """The VideoLlava model which consists of a vision backbone and a language model.""",
+ VIDEO_LLAVA_START_DOCSTRING,
+)
+class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^image_tower": "model.image_tower",
+ "^video_tower": "model.video_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: VideoLlavaConfig):
+ super().__init__(config)
+ self.model = VideoLlavaModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def video_tower(self):
+ return self.model.video_tower
+
+ @property
+ def image_tower(self):
+ return self.model.image_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
@add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values_images: Optional[torch.FloatTensor] = None,
- pixel_values_videos: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values_images: torch.FloatTensor = None,
+ pixel_values_videos: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -475,88 +656,32 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
else self.config.vision_feature_select_strategy
)
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same "
- "time, and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values_images is not None:
- image_features = self.get_image_features(
- pixel_values_images,
- vision_feature_layer=vision_feature_layer,
- vision_feature_select_strategy=vision_feature_select_strategy,
- )
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- n_image_tokens = (input_ids == self.config.image_token_id).sum()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- if pixel_values_videos is not None:
- video_features, num_frames = self.get_video_features(
- pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
- )
-
- special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
- n_video_tokens = (input_ids == self.config.video_token_id).sum()
- n_video_features = video_features.shape[0] * video_features.shape[1]
- raise ValueError(
- f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
- )
- video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values_images=pixel_values_images,
+ pixel_values_videos=pixel_values_videos,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
**lm_kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return VideoLlavaCausalLMOutputWithPast(
loss=loss,
@@ -564,8 +689,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values_images is not None else None,
- video_hidden_states=video_features if pixel_values_videos is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
+ video_hidden_states=outputs.video_hidden_states,
)
def prepare_inputs_for_generation(
@@ -582,7 +707,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -600,5 +725,61 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
return model_inputs
+ @staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-__all__ = ["VideoLlavaPreTrainedModel", "VideoLlavaForConditionalGeneration"]
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = ["VideoLlavaPreTrainedModel", "VideoLlavaModel", "VideoLlavaForConditionalGeneration"]
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 9243bbe9e2d..49169593320 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -1,3 +1,9 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/vipllava/modular_vipllava.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_vipllava.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
#
@@ -12,37 +18,65 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""PyTorch VipLlava model."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
-import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
-from ...modeling_outputs import ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ can_return_tuple,
is_torchdynamo_compiling,
- logging,
replace_return_docstrings,
)
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_vipllava import VipLlavaConfig
-logger = logging.get_logger(__name__)
-
_CONFIG_FOR_DOC = "VipLlavaConfig"
@dataclass
-# Copied from transformers.models.llava.modeling_llava.LlavaCausalLMOutputWithPast with Llava->VipLlava
+class VipLlavaModelOutputWithPast(BaseModelOutputWithPast):
+ """
+ Base class for VipLlava outputs, with hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
class VipLlavaCausalLMOutputWithPast(ModelOutput):
"""
Base class for VipLlava causal language model (or autoregressive) outputs.
@@ -70,7 +104,7 @@ class VipLlavaCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@@ -129,9 +163,8 @@ VIPLLAVA_START_DOCSTRING = r"""
)
class VipLlavaPreTrainedModel(PreTrainedModel):
config_class = VipLlavaConfig
- base_model_prefix = "model"
+ base_model_prefix = ""
supports_gradient_checkpointing = True
- _no_split_modules = ["VipLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@@ -140,6 +173,9 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
def _init_weights(self, module):
+ # important: this ported version of VipLlava isn't meant for training from scratch - only
+ # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
+ # https://github.com/haotian-liu/VipLlava/tree/main/vipllava should serve for that purpose
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
@@ -147,8 +183,8 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
- module.bias.data.zero_()
module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
VIPLLAVA_INPUTS_DOCSTRING = r"""
@@ -163,7 +199,7 @@ VIPLLAVA_INPUTS_DOCSTRING = r"""
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
- [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses
+            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`VipLlavaProcessor`] uses
[`CLIPImageProcessor`] for processing images).
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -203,6 +239,13 @@ VIPLLAVA_INPUTS_DOCSTRING = r"""
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
+        vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
@@ -222,24 +265,18 @@ VIPLLAVA_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
- """The VIPLLAVA model which consists of a vision backbone and a language model.""",
+ """The VipLlava model which consists of a vision backbone and a language model, without a language modeling head.""",
VIPLLAVA_START_DOCSTRING,
)
-# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration with LLAVA->VIPLLAVA,Llava->VipLlava
-class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
+class VipLlavaModel(VipLlavaPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: VipLlavaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = VipLlavaMultiModalProjector(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@@ -248,19 +285,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
- # Ignore copy
def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: Union[int, List[int]]):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
@@ -288,12 +312,132 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
return image_features
@add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- # Ignore copy
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layers: Optional[Union[int, List[int]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, VipLlavaModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layers = (
+ vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
+ )
+
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ output = VipLlavaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+@add_start_docstrings(
+ """The VIPLLAVA model which consists of a vision backbone and a language model.""",
+ VIPLLAVA_START_DOCSTRING,
+)
+class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: VipLlavaConfig):
+ super().__init__(config)
+ self.model = VipLlavaModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -358,68 +502,30 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
)
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None:
- image_features = self.get_image_features(
- pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
- )
-
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- n_image_tokens = (input_ids == self.config.image_token_id).sum()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
+ vision_feature_layers=vision_feature_layers,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
**lm_kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return VipLlavaCausalLMOutputWithPast(
loss=loss,
@@ -427,7 +533,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@@ -443,7 +549,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -460,5 +566,60 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
return model_inputs
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-__all__ = ["VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"]
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = ["VipLlavaModel", "VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"]
diff --git a/src/transformers/models/vipllava/modular_vipllava.py b/src/transformers/models/vipllava/modular_vipllava.py
new file mode 100644
index 00000000000..93fdecffec9
--- /dev/null
+++ b/src/transformers/models/vipllava/modular_vipllava.py
@@ -0,0 +1,294 @@
+# coding=utf-8
+# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from transformers.models.llava.modeling_llava import (
+ LlavaCausalLMOutputWithPast,
+ LlavaForConditionalGeneration,
+ LlavaModel,
+ LlavaModelOutputWithPast,
+ LlavaPreTrainedModel,
+)
+
+from ...activations import ACT2FN
+from ...utils import (
+ add_start_docstrings_to_model_forward,
+ can_return_tuple,
+ is_torchdynamo_compiling,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_vipllava import VipLlavaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+VIPLLAVA_INPUTS_DOCSTRING = None
+_CONFIG_FOR_DOC = "VipLlavaConfig"
+
+
+class VipLlavaModelOutputWithPast(LlavaModelOutputWithPast):
+ pass
+
+
+class VipLlavaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
+ pass
+
+
+class VipLlavaMultiModalProjector(nn.Module):
+ def __init__(self, config: VipLlavaConfig):
+ super().__init__()
+ num_feature_layers = 1 if isinstance(config.vision_feature_layers, int) else len(config.vision_feature_layers)
+ self.projector_layernorm = nn.LayerNorm(
+ num_feature_layers * config.vision_config.hidden_size, eps=config.projector_layernorm_eps
+ )
+
+ self.linear_1 = nn.Linear(
+ num_feature_layers * config.vision_config.hidden_size,
+ config.text_config.hidden_size,
+ bias=True,
+ )
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
+
+ def forward(self, hidden_states):
+ hidden_states = self.projector_layernorm(hidden_states)
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class VipLlavaPreTrainedModel(LlavaPreTrainedModel):
+ pass
+
+
+class VipLlavaModel(LlavaModel):
+ def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: Union[int, List[int]]):
+ """
+        Obtains image last hidden states from the vision tower and applies multimodal projection.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+ vision_feature_layers (`Union[int, List[int]]`):
+ The vision feature layer, or the list of indexes of the layers to select
+ the vision feature.
+ Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+
+ # If multiple feature layers are provided (which is usually the case)
+ # then the image features are concatenated after the CLS is removed.
+ if isinstance(vision_feature_layers, int):
+ image_features = image_outputs.hidden_states[vision_feature_layers][:, 1:]
+ else:
+ # Usually, we select the features from index 1: the layers -2, -5, -8, -11 and 6
+ image_features = [image_outputs.hidden_states[index][:, 1:] for index in vision_feature_layers]
+ image_features = torch.cat(image_features, dim=-1)
+ image_features = self.multi_modal_projector(image_features)
+ return image_features
+
+ @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layers: Optional[Union[int, List[int]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **lm_kwargs,
+ ) -> Union[Tuple, VipLlavaModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layers = (
+ vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
+ )
+
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ output = VipLlavaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+
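# Minimal sketch (assumption, not library code) of the multi-layer selection in
# `VipLlavaModel.get_image_features` above: several vision-tower hidden states are taken,
# the CLS token is dropped from each, and they are concatenated along the hidden dimension
# before the projector. Shapes below are toy values.
import torch

num_patches, hidden = 4, 6
# hypothetical hidden states from the vision tower: one tensor per layer, CLS token first
hidden_states = [torch.randn(1, 1 + num_patches, hidden) for _ in range(13)]
vision_feature_layers = [-2, -5, -8, -11, 6]

image_features = [hidden_states[index][:, 1:] for index in vision_feature_layers]
image_features = torch.cat(image_features, dim=-1)
print(image_features.shape)                               # torch.Size([1, 4, 30]) = 5 layers * hidden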
+class VipLlavaForConditionalGeneration(LlavaForConditionalGeneration):
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ # Ignore copy
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layers: Optional[Union[int, List[int]]] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Union[Tuple, VipLlavaCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, VipLlavaForConditionalGeneration
+
+ >>> model = VipLlavaForConditionalGeneration.from_pretrained("llava-hf/vip-llava-7b-hf", device_map="auto", torch_dtype=torch.float16)
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")
+
+ >>> prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: \n{}###Assistant:"
+ >>> question = "Can you please describe this image?"
+ >>> prompt = prompt.format(question)
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(text=prompt, images=image, return_tensors="pt").to(0, torch.float16)
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_new_tokens=20)
+ >>> processor.decode(generate_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
+ The image features a brown and white cat sitting on a green surface, with a red ball in its
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layers = (
+ vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
+ )
+
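+        # The base VipLlavaModel merges the image features into the text embeddings and runs the language model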
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ vision_feature_layers=vision_feature_layers,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return VipLlavaCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+
+__all__ = ["VipLlavaModel", "VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"]
diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py
index 2f802d9b70e..3baa74a8bdc 100644
--- a/tests/models/aria/test_modeling_aria.py
+++ b/tests/models/aria/test_modeling_aria.py
@@ -21,6 +21,7 @@ import requests
from transformers import (
AriaConfig,
AriaForConditionalGeneration,
+ AriaModel,
AriaTextConfig,
AutoProcessor,
AutoTokenizer,
@@ -175,7 +176,7 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
Model tester for `AriaForConditionalGeneration`.
"""
- all_model_classes = (AriaForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (AriaModel, AriaForConditionalGeneration) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
_is_composite = True
@@ -281,6 +282,18 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
+ @unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
+ def test_cpu_offload(self):
+ pass
+
+ @unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
+ def test_disk_offload_bin(self):
+ pass
+
+ @unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
+ def test_disk_offload_safetensors(self):
+ pass
+
@require_torch
class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
diff --git a/tests/models/aya_vision/test_modeling_aya_vision.py b/tests/models/aya_vision/test_modeling_aya_vision.py
index c35058abd62..f4a464ed621 100644
--- a/tests/models/aya_vision/test_modeling_aya_vision.py
+++ b/tests/models/aya_vision/test_modeling_aya_vision.py
@@ -46,6 +46,7 @@ if is_torch_available():
from transformers import (
AyaVisionForConditionalGeneration,
+ AyaVisionModel,
)
@@ -158,7 +159,14 @@ class AyaVisionVisionText2TextModelTester:
@require_torch
class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
- all_model_classes = (AyaVisionForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ AyaVisionModel,
+ AyaVisionForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
all_generative_model_classes = (AyaVisionForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
diff --git a/tests/models/emu3/test_modeling_emu3.py b/tests/models/emu3/test_modeling_emu3.py
index b27b1d4c708..b6d93ef73f2 100644
--- a/tests/models/emu3/test_modeling_emu3.py
+++ b/tests/models/emu3/test_modeling_emu3.py
@@ -46,6 +46,7 @@ if is_torch_available():
from transformers import (
Emu3ForCausalLM,
Emu3ForConditionalGeneration,
+ Emu3Model,
Emu3Processor,
Emu3TextModel,
)
@@ -310,7 +311,14 @@ class Emu3Vision2TextModelTester:
@require_torch
class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
- all_model_classes = (Emu3ForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ Emu3Model,
+ Emu3ForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
pipeline_model_mapping = {}
test_headmasking = False
test_pruning = False
@@ -395,6 +403,10 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
def test_generate_with_static_cache(self):
pass
+ @unittest.skip("Emu3 doesn't support Flex attn yet!")
+ def test_flex_attention_with_grads(self):
+ pass
+
@require_torch
class Emu3IntegrationTest(unittest.TestCase):
diff --git a/tests/models/fuyu/test_modeling_fuyu.py b/tests/models/fuyu/test_modeling_fuyu.py
index 06f0171e46a..1be8d9fcca8 100644
--- a/tests/models/fuyu/test_modeling_fuyu.py
+++ b/tests/models/fuyu/test_modeling_fuyu.py
@@ -38,7 +38,7 @@ if is_torch_available() and is_vision_available():
if is_torch_available():
- from transformers import FuyuForCausalLM
+ from transformers import FuyuForCausalLM, FuyuModel
class FuyuModelTester:
@@ -145,7 +145,14 @@ class FuyuModelTester:
@require_torch
class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
- all_model_classes = (FuyuForCausalLM,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ FuyuModel,
+ FuyuForCausalLM,
+ )
+ if is_torch_available()
+ else ()
+ )
pipeline_model_mapping = (
{"text-generation": FuyuForCausalLM, "image-text-to-text": FuyuForCausalLM} if is_torch_available() else {}
)
diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py
index 9638ff4a87f..efd6f4095cd 100644
--- a/tests/models/gemma3/test_modeling_gemma3.py
+++ b/tests/models/gemma3/test_modeling_gemma3.py
@@ -50,6 +50,7 @@ if is_torch_available():
from transformers import (
Gemma3ForCausalLM,
Gemma3ForConditionalGeneration,
+ Gemma3Model,
Gemma3Processor,
Gemma3TextModel,
)
@@ -148,9 +149,9 @@ class Gemma3Vision2TextModelTester:
self,
parent,
mm_tokens_per_image=2,
- image_token_index=1,
- boi_token_index=2,
- eoi_token_index=3,
+ image_token_index=4,
+ boi_token_index=5,
+ eoi_token_index=6,
seq_length=25,
is_training=True,
vision_config={
@@ -242,7 +243,14 @@ class Gemma3Vision2TextModelTester:
@require_torch
class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
- all_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ Gemma3Model,
+ Gemma3ForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
all_generative_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else ()
test_headmasking = False
test_pruning = False
diff --git a/tests/models/got_ocr2/test_modeling_got_ocr2.py b/tests/models/got_ocr2/test_modeling_got_ocr2.py
index ed0a25f7b19..e4706b6db0c 100644
--- a/tests/models/got_ocr2/test_modeling_got_ocr2.py
+++ b/tests/models/got_ocr2/test_modeling_got_ocr2.py
@@ -34,6 +34,7 @@ if is_torch_available():
from transformers import (
GotOcr2ForConditionalGeneration,
+ GotOcr2Model,
)
@@ -140,7 +141,14 @@ class GotOcr2VisionText2TextModelTester:
@require_torch
class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
- all_model_classes = (GotOcr2ForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ GotOcr2Model,
+ GotOcr2ForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
pipeline_model_mapping = (
{
"image-to-text": GotOcr2ForConditionalGeneration,
@@ -228,6 +236,10 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_past_key_values_format(self):
pass
+ @unittest.skip(reason="Vision backbone doesn't support FLEX yet!")
+ def test_flex_attention_with_grads(self):
+ pass
+
@require_torch
class GotOcr2IntegrationTest(unittest.TestCase):
diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py
index f47330ec673..b154a09c2bc 100644
--- a/tests/models/instructblip/test_modeling_instructblip.py
+++ b/tests/models/instructblip/test_modeling_instructblip.py
@@ -54,7 +54,7 @@ if is_torch_available():
import torch
from torch import nn
- from transformers import InstructBlipForConditionalGeneration, InstructBlipVisionModel
+ from transformers import InstructBlipForConditionalGeneration, InstructBlipModel, InstructBlipVisionModel
if is_vision_available():
@@ -460,14 +460,20 @@ class InstructBlipForConditionalGenerationDecoderOnlyModelTester:
"attention_mask": attention_mask,
"qformer_input_ids": qformer_input_ids,
"qformer_attention_mask": qformer_attention_mask,
- "labels": input_ids,
}
return config, inputs_dict
@require_torch
class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
- all_model_classes = (InstructBlipForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ InstructBlipModel,
+ InstructBlipForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
pipeline_model_mapping = {"image-text-to-text": InstructBlipForConditionalGeneration}
fx_compatible = False
test_head_masking = False
diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py
index a4b7f25e347..9bd617b4666 100644
--- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py
+++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py
@@ -54,7 +54,11 @@ if is_torch_available():
import torch
from torch import nn
- from transformers import InstructBlipVideoForConditionalGeneration, InstructBlipVideoVisionModel
+ from transformers import (
+ InstructBlipVideoForConditionalGeneration,
+ InstructBlipVideoModel,
+ InstructBlipVideoVisionModel,
+ )
class InstructBlipVideoVisionModelTester:
@@ -477,7 +481,6 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyModelTester:
"attention_mask": attention_mask,
"qformer_input_ids": qformer_input_ids,
"qformer_attention_mask": qformer_attention_mask,
- "labels": input_ids,
}
return config, inputs_dict
@@ -486,7 +489,9 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyModelTester:
class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
):
- all_model_classes = (InstructBlipVideoForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel) if is_torch_available() else ()
+ )
fx_compatible = False
test_head_masking = False
test_pruning = False
diff --git a/tests/models/internvl/test_modeling_internvl.py b/tests/models/internvl/test_modeling_internvl.py
index e51126f14ea..bb85f0cba41 100644
--- a/tests/models/internvl/test_modeling_internvl.py
+++ b/tests/models/internvl/test_modeling_internvl.py
@@ -47,9 +47,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
- from transformers import (
- InternVLForConditionalGeneration,
- )
+ from transformers import InternVLForConditionalGeneration, InternVLModel
if is_vision_available():
@@ -191,7 +189,7 @@ class InternVLVisionText2TextModelTester:
@require_torch
class InternVLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
- all_model_classes = (InternVLForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (InternVLForConditionalGeneration, InternVLModel) if is_torch_available() else ()
all_generative_model_classes = (InternVLForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py
index a692566340c..5d60f6248b1 100644
--- a/tests/models/llava/test_modeling_llava.py
+++ b/tests/models/llava/test_modeling_llava.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Llava model."""
+import copy
import unittest
import requests
@@ -23,6 +24,7 @@ from transformers import (
AutoTokenizer,
LlavaConfig,
LlavaForConditionalGeneration,
+ LlavaModel,
is_torch_available,
is_vision_available,
)
@@ -166,7 +168,14 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
Model tester for `LlavaForConditionalGeneration`.
"""
- all_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ LlavaModel,
+ LlavaForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
pipeline_model_mapping = (
{"image-to-text": LlavaForConditionalGeneration, "image-text-to-text": LlavaForConditionalGeneration}
if is_torch_available()
@@ -238,16 +247,17 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
- _ = model(**input_dict) # successful forward with no modifications
+            curr_input_dict = copy.deepcopy(input_dict)  # deepcopy since the inputs are modified in place below
+ _ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
- input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
with self.assertRaises(ValueError):
- _ = model(**input_dict)
+ _ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
- input_ids = input_dict["input_ids"][:1]
- pixel_values = input_dict["pixel_values"][:1]
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@@ -281,7 +291,8 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
model = model_class(config).to(torch_device)
# We should have the right number of input features,
# and should be able to run a forward pass without exploding
- assert model.multi_modal_projector.linear_1.in_features == expected_features
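+            # With the `*Model`/`*ForConditionalGeneration` split, the projector lives on the wrapped base model;
+            # fall back to the model itself for the base class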
+ base_model = getattr(model, "model", model)
+ assert base_model.multi_modal_projector.linear_1.in_features == expected_features
model(**input_dict)
@unittest.skip(
diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py
index 95f5eaa083c..c8789f0ba38 100644
--- a/tests/models/llava_next/test_modeling_llava_next.py
+++ b/tests/models/llava_next/test_modeling_llava_next.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Llava-NeXT model."""
+import copy
import unittest
import requests
@@ -23,6 +24,7 @@ from transformers import (
AutoProcessor,
LlavaNextConfig,
LlavaNextForConditionalGeneration,
+ LlavaNextModel,
is_torch_available,
is_vision_available,
)
@@ -181,7 +183,14 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
Model tester for `LlavaNextForConditionalGeneration`.
"""
- all_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ LlavaNextModel,
+ LlavaNextForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
pipeline_model_mapping = {"image-text-to-text": LlavaNextForConditionalGeneration} if is_torch_available() else {}
test_pruning = False
test_head_masking = False
@@ -265,18 +274,19 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
- _ = model(**input_dict) # successful forward with no modifications
+            curr_input_dict = copy.deepcopy(input_dict)  # deepcopy since the inputs are modified in place below
+ _ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
- input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
- input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
+ curr_input_dict["image_sizes"] = curr_input_dict["image_sizes"][-1:, ...]
with self.assertRaises(ValueError):
- _ = model(**input_dict)
+ _ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
- input_ids = input_dict["input_ids"][:1]
- pixel_values = input_dict["pixel_values"][:1]
- image_sizes = input_dict["image_sizes"][:1]
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:1]
+ image_sizes = curr_input_dict["image_sizes"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@@ -324,7 +334,8 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
model = model_class(config).to(torch_device)
# We should have the right number of input features,
# and should be able to run a forward pass without exploding
- assert model.multi_modal_projector.linear_1.in_features == expected_features
+ base_model = getattr(model, "model", model)
+ assert base_model.multi_modal_projector.linear_1.in_features == expected_features
model(**input_dict)
@unittest.skip(
diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py
index 628611d689a..e68a1e4362e 100644
--- a/tests/models/llava_next_video/test_modeling_llava_next_video.py
+++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Llava-NeXT-Video model."""
+import copy
import unittest
import numpy as np
@@ -23,6 +24,7 @@ from transformers import (
AutoProcessor,
LlavaNextVideoConfig,
LlavaNextVideoForConditionalGeneration,
+ LlavaNextVideoModel,
is_torch_available,
is_vision_available,
)
@@ -196,7 +198,14 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
Model tester for `LlavaNextVideoForConditionalGeneration`.
"""
- all_model_classes = (LlavaNextVideoForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ LlavaNextVideoModel,
+ LlavaNextVideoForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
test_pruning = False
test_head_masking = False
_is_composite = True
@@ -281,18 +290,19 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
- _ = model(**input_dict) # successful forward with no modifications
+            curr_input_dict = copy.deepcopy(input_dict)  # deepcopy since the inputs are modified in place below
+ _ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
- input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
- input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
+ curr_input_dict["image_sizes"] = curr_input_dict["image_sizes"][-1:, ...]
with self.assertRaises(ValueError):
- _ = model(**input_dict)
+ _ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
- input_ids = input_dict["input_ids"][:1]
- pixel_values = input_dict["pixel_values"][:1]
- image_sizes = input_dict["image_sizes"][:1]
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:1]
+ image_sizes = curr_input_dict["image_sizes"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@@ -340,7 +350,8 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
model = model_class(config).to(torch_device)
# We should have the right number of input features,
# and should be able to run a forward pass without exploding
- assert model.multi_modal_projector.linear_1.in_features == expected_features
+ base_model = getattr(model, "model", model)
+ assert base_model.multi_modal_projector.linear_1.in_features == expected_features
model(**input_dict)
@unittest.skip(
diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py
index 4ec719d3c54..ba95c330dbd 100644
--- a/tests/models/llava_onevision/test_modeling_llava_onevision.py
+++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py
@@ -24,6 +24,7 @@ from transformers import (
AutoProcessor,
LlavaOnevisionConfig,
LlavaOnevisionForConditionalGeneration,
+ LlavaOnevisionModel,
is_torch_available,
is_vision_available,
)
@@ -182,7 +183,14 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
Model tester for `LlavaOnevisionForConditionalGeneration`.
"""
- all_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ LlavaOnevisionModel,
+ LlavaOnevisionForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
pipeline_model_mapping = (
{"image-text-to-text": LlavaOnevisionForConditionalGeneration} if is_torch_available() else {}
)
@@ -296,7 +304,8 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
model = model_class(config).to(torch_device)
# We should have the right number of input features,
# and should be able to run a forward pass without exploding
- assert model.multi_modal_projector.linear_1.in_features == expected_features
+ base_model = getattr(model, "model", model)
+ assert base_model.multi_modal_projector.linear_1.in_features == expected_features
model(**input_dict)
@unittest.skip(
diff --git a/tests/models/mistral3/test_modeling_mistral3.py b/tests/models/mistral3/test_modeling_mistral3.py
index 2555126dbaa..7c8f60fdc0a 100644
--- a/tests/models/mistral3/test_modeling_mistral3.py
+++ b/tests/models/mistral3/test_modeling_mistral3.py
@@ -42,6 +42,7 @@ if is_torch_available():
from transformers import (
Mistral3ForConditionalGeneration,
+ Mistral3Model,
)
@@ -162,7 +163,14 @@ class Mistral3VisionText2TextModelTester:
@require_torch
class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
- all_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ Mistral3Model,
+ Mistral3ForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
all_generative_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
@@ -278,6 +286,10 @@ class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
def test_sdpa_can_dispatch_on_flash(self):
pass
+ @unittest.skip("Pixtral does not support attention interfaces.")
+ def test_flex_attention_with_grads(self):
+ pass
+
@slow
@require_torch_gpu
diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py
index 34e4d4e4896..e67e0455e1f 100644
--- a/tests/models/mllama/test_modeling_mllama.py
+++ b/tests/models/mllama/test_modeling_mllama.py
@@ -25,6 +25,7 @@ from transformers import (
MllamaConfig,
MllamaForCausalLM,
MllamaForConditionalGeneration,
+ MllamaModel,
is_torch_available,
is_vision_available,
)
@@ -262,7 +263,14 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
Model tester for `MllamaForConditionalGeneration`.
"""
- all_model_classes = (MllamaForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ MllamaModel,
+ MllamaForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
pipeline_model_mapping = {"image-text-to-text": MllamaForConditionalGeneration} if is_torch_available() else ()
test_pruning = False
test_head_masking = False
@@ -325,19 +333,18 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
# resizing embeddings should result in successful loss computation
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
- for model_class in self.all_model_classes:
- model = model_class(config)
- model_vocab_size = config.get_text_config().vocab_size
- inputs = self._prepare_for_class(inputs, model_class, return_labels=True)
- # Resize embeddings and call forward
- model.resize_token_embeddings(model_vocab_size + 10)
- output = model(
- input_ids=inputs["input_ids"],
- attention_mask=inputs["attention_mask"],
- labels=inputs["labels"],
- return_dict=True,
- )
- self.assertTrue("loss" in output)
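+        # Use the generative model here: the base `MllamaModel` does not take `labels` or compute a loss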
+ model = MllamaForConditionalGeneration(config).to(torch_device)
+ model_vocab_size = config.get_text_config().vocab_size
+ inputs = self._prepare_for_class(inputs, MllamaForConditionalGeneration, return_labels=True)
+ # Resize embeddings and call forward
+ model.resize_token_embeddings(model_vocab_size + 10)
+ output = model(
+ input_ids=inputs["input_ids"],
+ attention_mask=inputs["attention_mask"],
+ labels=inputs["labels"],
+ return_dict=True,
+ )
+ self.assertTrue("loss" in output)
def _check_attentions_for_generate(
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
@@ -409,6 +416,18 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
def test_assisted_decoding_with_num_logits_to_keep(self):
pass
+    @unittest.skip(reason="Mllama uses self.weights directly, causing a device mismatch when offloading")
+ def test_cpu_offload(self):
+ pass
+
+    @unittest.skip(reason="Mllama uses self.weights directly, causing a device mismatch when offloading")
+ def test_disk_offload_bin(self):
+ pass
+
+    @unittest.skip(reason="Mllama uses self.weights directly, causing a device mismatch when offloading")
+ def test_disk_offload_safetensors(self):
+ pass
+
@pytest.mark.generate
# overridden because mllama is not an encoder-decoder model, but has encoder-decoder-like cache
def test_past_key_values_format(self):
@@ -501,7 +520,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- for model_class in self.all_model_classes:
+ for model_class in self.all_generative_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py
index 4b9db1d75d1..a4d323baa31 100644
--- a/tests/models/paligemma/test_modeling_paligemma.py
+++ b/tests/models/paligemma/test_modeling_paligemma.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch PaliGemma model."""
+import copy
import unittest
import requests
@@ -20,6 +21,7 @@ import requests
from transformers import (
PaliGemmaConfig,
PaliGemmaForConditionalGeneration,
+ PaliGemmaModel,
PaliGemmaProcessor,
is_torch_available,
is_vision_available,
@@ -177,7 +179,14 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
Model tester for `PaliGemmaForConditionalGeneration`.
"""
- all_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ PaliGemmaModel,
+ PaliGemmaForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
pipeline_model_mapping = {"image-text-to-text": PaliGemmaForConditionalGeneration}
fx_compatible = False
test_pruning = False
@@ -242,16 +251,17 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
- _ = model(**input_dict) # successful forward with no modifications
+            curr_input_dict = copy.deepcopy(input_dict)  # deepcopy since the inputs are modified in place below
+ _ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
- input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
with self.assertRaises(ValueError):
- _ = model(**input_dict)
+ _ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
- input_ids = input_dict["input_ids"][:1]
- pixel_values = input_dict["pixel_values"][:1]
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
diff --git a/tests/models/paligemma2/test_modeling_paligemma2.py b/tests/models/paligemma2/test_modeling_paligemma2.py
index e13430f2b73..bc62c527e29 100644
--- a/tests/models/paligemma2/test_modeling_paligemma2.py
+++ b/tests/models/paligemma2/test_modeling_paligemma2.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch PaliGemma model."""
+import copy
import unittest
import pytest
@@ -239,16 +240,17 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
- _ = model(**input_dict) # successful forward with no modifications
+            curr_input_dict = copy.deepcopy(input_dict)  # deepcopy since the inputs are modified in place below
+ _ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
- input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
with self.assertRaises(ValueError):
- _ = model(**input_dict)
+ _ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
- input_ids = input_dict["input_ids"][:1]
- pixel_values = input_dict["pixel_values"][:1]
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py
index 21947dca355..5f06e1c84be 100644
--- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py
+++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Qwen2.5-VL model."""
+import copy
import gc
import tempfile
import unittest
@@ -23,6 +24,7 @@ from transformers import (
AutoProcessor,
Qwen2_5_VLConfig,
Qwen2_5_VLForConditionalGeneration,
+ Qwen2_5_VLModel,
is_torch_available,
is_vision_available,
)
@@ -180,17 +182,11 @@ class Qwen2_5_VLVisionText2TextModelTester:
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
input_ids[:, self.num_image_tokens] = self.image_token_id
input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
- labels = torch.zeros(
- (self.batch_size, self.seq_length),
- dtype=torch.long,
- device=torch_device,
- )
inputs_dict = {
"pixel_values": pixel_values,
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size),
"input_ids": input_ids,
"attention_mask": attention_mask,
- "labels": labels,
}
return config, inputs_dict
@@ -201,7 +197,14 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
Model tester for `Qwen2_5_VLForConditionalGeneration`.
"""
- all_model_classes = (Qwen2_5_VLForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ Qwen2_5_VLModel,
+ Qwen2_5_VLForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
test_pruning = False
test_head_masking = False
@@ -236,19 +239,20 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications
+ curr_input_dict = copy.deepcopy(input_dict)
# remove one image but leave the image token in text
patch_size = config.vision_config.patch_size
one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
- input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...]
- input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...]
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-one_img_length:, ...]
+ curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][-1:, ...]
with self.assertRaises(ValueError):
- _ = model(**input_dict)
+ _ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
- input_ids = input_dict["input_ids"][:1]
- pixel_values = input_dict["pixel_values"][:one_img_length]
- image_grid_thw = input_dict["image_grid_thw"][:1]
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:one_img_length]
+ image_grid_thw = curr_input_dict["image_grid_thw"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@@ -375,6 +379,29 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
def test_save_load_fast_init_from_base(self):
pass
+    # The multimodal base model's embeds will not match the ids due to pixel values. We can't change the base test
+    # because `pixel_values` are required in some models. Will be fixed when we add support for merging `embeds+pixels`
+ # TODO: @raushan
+ def test_inputs_embeds_matches_input_ids(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ inputs = self._prepare_for_class(inputs_dict, model_class)
+ input_ids = inputs["input_ids"]
+ del inputs["input_ids"]
+ del inputs["pixel_values"]
+
+ inputs_embeds = model.get_input_embeddings()(input_ids)
+
+ with torch.no_grad():
+ out_ids = model(input_ids=input_ids, **inputs)[0]
+ out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
+ torch.testing.assert_close(out_embeds, out_ids)
+
@require_torch
class Qwen2_5_VLIntegrationTest(unittest.TestCase):
diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
index e213ccd819b..57e112790cc 100644
--- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Qwen2-VL model."""
+import copy
import gc
import unittest
@@ -22,6 +23,7 @@ from transformers import (
AutoProcessor,
Qwen2VLConfig,
Qwen2VLForConditionalGeneration,
+ Qwen2VLModel,
is_torch_available,
is_vision_available,
)
@@ -169,17 +171,12 @@ class Qwen2VLVisionText2TextModelTester:
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
input_ids[:, self.num_image_tokens] = self.image_token_id
input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
- labels = torch.zeros(
- (self.batch_size, self.seq_length),
- dtype=torch.long,
- device=torch_device,
- )
+
inputs_dict = {
"pixel_values": pixel_values,
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size),
"input_ids": input_ids,
"attention_mask": attention_mask,
- "labels": labels,
}
return config, inputs_dict
@@ -190,7 +187,14 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
Model tester for `Qwen2VLForConditionalGeneration`.
"""
- all_model_classes = (Qwen2VLForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ Qwen2VLModel,
+ Qwen2VLForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
pipeline_model_mapping = {"image-text-to-text": Qwen2VLForConditionalGeneration}
test_pruning = False
test_head_masking = False
@@ -226,20 +230,21 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
- _ = model(**input_dict) # successful forward with no modifications
+ curr_input_dict = copy.deepcopy(input_dict)
+            _ = model(**curr_input_dict)  # successful forward with no modifications
# remove one image but leave the image token in text
patch_size = config.vision_config.patch_size
one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
- input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...]
- input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...]
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-one_img_length:, ...]
+ curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][-1:, ...]
with self.assertRaises(ValueError):
- _ = model(**input_dict)
+ _ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
- input_ids = input_dict["input_ids"][:1]
- pixel_values = input_dict["pixel_values"][:one_img_length]
- image_grid_thw = input_dict["image_grid_thw"][:1]
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:one_img_length]
+ image_grid_thw = curr_input_dict["image_grid_thw"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@@ -262,11 +267,11 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
model = model_class(config).to(torch_device)
# Generate and make sure rope_deltas are not `None`
- self.assertTrue(model.rope_deltas is None)
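+            # `rope_deltas` is now tracked on the base model (`model.model`)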
+ self.assertTrue(model.model.rope_deltas is None)
generation_output = model.generate(
**input_dict, max_new_tokens=4, return_dict_in_generate=True, output_logits=True
)
- self.assertTrue(model.rope_deltas is not None)
+ self.assertTrue(model.model.rope_deltas is not None)
# Now if we try to do forward pass, we should get new rope logits, because cache is not passed
forward_output = model(**input_dict)
@@ -320,6 +325,29 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
def test_save_load_fast_init_from_base(self):
pass
+    # The multimodal base model's embeds will not match the ids due to pixel values. We can't change the base test
+    # because `pixel_values` are required in some models. Will be fixed when we add support for merging `embeds+pixels`
+ # TODO: @raushan
+ def test_inputs_embeds_matches_input_ids(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ inputs = self._prepare_for_class(inputs_dict, model_class)
+ input_ids = inputs["input_ids"]
+ del inputs["input_ids"]
+ del inputs["pixel_values"]
+
+ inputs_embeds = model.get_input_embeddings()(input_ids)
+
+ with torch.no_grad():
+ out_ids = model(input_ids=input_ids, **inputs)[0]
+ out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
+ torch.testing.assert_close(out_embeds, out_ids)
+
@require_torch
class Qwen2VLIntegrationTest(unittest.TestCase):
diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py
index bda728b3919..92ad5a193bf 100644
--- a/tests/models/video_llava/test_modeling_video_llava.py
+++ b/tests/models/video_llava/test_modeling_video_llava.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch VideoLlava model."""
+import copy
import unittest
import numpy as np
@@ -23,6 +24,7 @@ from parameterized import parameterized
from transformers import (
VideoLlavaConfig,
VideoLlavaForConditionalGeneration,
+ VideoLlavaModel,
VideoLlavaProcessor,
is_torch_available,
is_vision_available,
@@ -190,7 +192,14 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
Model tester for `VideoLlavaForConditionalGeneration`.
"""
- all_model_classes = (VideoLlavaForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ VideoLlavaModel,
+ VideoLlavaForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
fx_compatible = False
test_pruning = False
test_resize_embeddings = True
@@ -235,46 +244,49 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
def test_mixed_input(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
+ curr_inputs = copy.deepcopy(inputs)
model = model_class(config).to(torch_device).eval()
# test that the forward does not fail
with torch.no_grad():
- _ = model(**inputs)
+ _ = model(**curr_inputs)
# if we remove some images from inputs leaving only one
# image number mismatch error should raise
- inputs["pixel_values_images"] = inputs["pixel_values_images"][:1]
+ curr_inputs["pixel_values_images"] = curr_inputs["pixel_values_images"][:1]
with self.assertRaises(ValueError):
- _ = model(**inputs)
+ _ = model(**curr_inputs)
def test_video_only_input(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
+ curr_inputs = copy.deepcopy(inputs)
model = model_class(config).to(torch_device).eval()
# replace image token id with dummy id
# Error will be raised as num-image-tokens and num-of-image-embeds mismatch
- inputs["input_ids"][:, : self.model_tester.num_image_tokens] = 2
+ curr_inputs["input_ids"][:, : self.model_tester.num_image_tokens] = 2
with self.assertRaises(ValueError):
- _ = model(**inputs)
+ _ = model(**curr_inputs)
- inputs["pixel_values_images"] = None
- _ = model(**inputs)
+ curr_inputs["pixel_values_images"] = None
+ _ = model(**curr_inputs)
def test_image_only_input(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
+ curr_inputs = copy.deepcopy(inputs)
model = model_class(config).to(torch_device).eval()
# set dummy id, which is not video token id
# Error will be raised as num-video-tokens and num-of-video-embeds mismatch
- inputs["input_ids"][
+ curr_inputs["input_ids"][
:,
self.model_tester.num_image_tokens : self.model_tester.num_image_tokens
+ self.model_tester.num_video_tokens,
] = 2
with self.assertRaises(ValueError):
- _ = model(**inputs)
+ _ = model(**curr_inputs)
- inputs["pixel_values_videos"] = None
- _ = model(**inputs)
+ curr_inputs["pixel_values_videos"] = None
+ _ = model(**curr_inputs)
def test_batching_equivalence(self):
def recursive_check(batched_object, single_row_object, model_name, key):
@@ -386,16 +398,17 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
- _ = model(**input_dict) # successful forward with no modifications
+ curr_input_dict = copy.deepcopy(input_dict)
+            _ = model(**curr_input_dict)  # successful forward with no modifications
# remove one image but leave the image token in text
- input_dict["pixel_values_images"] = input_dict["pixel_values_images"][-1:, ...]
+ curr_input_dict["pixel_values_images"] = curr_input_dict["pixel_values_images"][-1:, ...]
with self.assertRaises(ValueError):
- _ = model(**input_dict)
+ _ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
- input_ids = input_dict["input_ids"][:1]
- pixel_values = input_dict["pixel_values_images"][:1]
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values_images"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@@ -429,7 +442,8 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
model = model_class(config).to(torch_device)
# We should have the right number of input features,
# and should be able to run a forward pass without exploding
- assert model.multi_modal_projector.linear_1.in_features == expected_features
+ base_model = getattr(model, "model", model)
+ assert base_model.multi_modal_projector.linear_1.in_features == expected_features
model(**input_dict)
diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py
index 60460ddfb5b..79b57e12ffa 100644
--- a/tests/models/vipllava/test_modeling_vipllava.py
+++ b/tests/models/vipllava/test_modeling_vipllava.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch VipLlava model."""
+import copy
import unittest
import requests
@@ -22,6 +23,7 @@ from transformers import (
AutoProcessor,
VipLlavaConfig,
VipLlavaForConditionalGeneration,
+ VipLlavaModel,
is_torch_available,
is_vision_available,
)
@@ -165,7 +167,14 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
Model tester for `VipLlavaForConditionalGeneration`.
"""
- all_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (
+ (
+ VipLlavaModel,
+ VipLlavaForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
pipeline_model_mapping = {"image-text-to-text": VipLlavaForConditionalGeneration} if is_torch_available() else {}
fx_compatible = False
test_pruning = False
@@ -236,16 +245,17 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
- _ = model(**input_dict) # successful forward with no modifications
+            curr_input_dict = copy.deepcopy(input_dict)  # deepcopy since the inputs are modified in place below
+ _ = model(**curr_input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
- input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
with self.assertRaises(ValueError):
- _ = model(**input_dict)
+ _ = model(**curr_input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
- input_ids = input_dict["input_ids"][:1]
- pixel_values = input_dict["pixel_values"][:1]
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
@@ -284,7 +294,8 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
model = model_class(config).to(torch_device)
# We should have the right number of input features,
# and should be able to run a forward pass without exploding
- assert model.multi_modal_projector.linear_1.in_features == expected_features
+ base_model = getattr(model, "model", model)
+ assert base_model.multi_modal_projector.linear_1.in_features == expected_features
model(**input_dict)
@unittest.skip(
@@ -311,6 +322,10 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
+    @unittest.skip("LLaVA's vision backbone doesn't support flex attention yet")
+ def test_flex_attention_with_grads(self):
+ pass
+
@require_torch
class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 974bfb7b5a7..b8a4fda96d3 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3494,8 +3494,8 @@ class ModelTesterMixin:
vision_model_name = [name for name in vision_model_names if hasattr(model_sdpa, name)][0]
language_model_name = [name for name in language_model_names if hasattr(model_sdpa, name)][0]
- vision_model_sdpa = getattr(model, vision_model_name)
- language_model_sdpa = getattr(model, language_model_name)
+ vision_model_sdpa = getattr(model_sdpa, vision_model_name)
+ language_model_sdpa = getattr(model_sdpa, language_model_name)
text_attn = "sdpa" if language_model_sdpa._supports_sdpa else "eager"
vision_attn = "sdpa" if vision_model_sdpa._supports_sdpa else "eager"
@@ -4489,7 +4489,8 @@ class ModelTesterMixin:
@require_torch_gpu
def test_flex_attention_with_grads(self):
for model_class in self.all_model_classes:
- if not model_class._supports_flex_attn:
+ # TODO: raushan, fix for composite models after making VLMs support new attn API
+ if not model_class._supports_flex_attn or self._is_composite:
self.skipTest(reason="This model does not support flex attention")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config._attn_implementation = "flex_attention"
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index dc0dac82da2..cea9ed693d5 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -378,8 +378,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
"rope_theta",
"partial_rotary_factor",
"pretraining_tp",
- "boi_token_index",
- "eoi_token_index",
+ "boi_token_id",
+ "eoi_token_id",
]
attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 960a73e734b..9eb5ccdf06a 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -156,6 +156,8 @@ IGNORE_NON_TESTED = (
"Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests
"Emu3VQVAE", # Building part of bigger (tested) model
"Emu3TextModel", # Building part of bigger (tested) model
+ "Qwen2VLTextModel", # Building part of bigger (tested) model
+ "Qwen2_5_VLTextModel", # Building part of bigger (tested) model
"InternVLVisionModel", # Building part of bigger (tested) model
"JanusVisionModel", # Building part of bigger (tested) model
"TimesFmModel", # Building part of bigger (tested) model