From 17742bd9c8852ab35986dcaa3e68415342ae7eef Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 7 May 2025 17:47:51 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=B4=20[VLM]=20Add=20base=20model=20wit?= =?UTF-8?q?hout=20head=20=20(#37033)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * i guessreverted all CdGen classes * style * llava onevision * fix copies * fix some tests * some more tests * dump * skip these * nevermind, i am dumb * revert fix not needed * fixup * fixup * another fixup * more fixup to make ci finally happy * fixup after rebasing * fix qwen tests * add internVL + typos here and there * image token index -> id * style * fix init weights * revert blip-2 not supported * address comments * fix copies * revert blip2 test file as well * as discussed internally, revert back CdGen models * fix some tests * fix more tests for compile * CI red * fix copies * enumerate explicitly allowed models * address comments * fix tests * fixup * style again * add tests for new model class * another fixup ( x _ x ) * [fixup] unused attributes can be removed post-deprecation --- docs/source/en/model_doc/aria.md | 4 + docs/source/en/model_doc/aya_vision.md | 4 + docs/source/en/model_doc/emu3.md | 4 + docs/source/en/model_doc/fuyu.md | 4 + docs/source/en/model_doc/gemma3.md | 4 + docs/source/en/model_doc/got_ocr2.md | 4 + docs/source/en/model_doc/instructblip.md | 4 + docs/source/en/model_doc/instructblipvideo.md | 4 + docs/source/en/model_doc/internvl.md | 5 + docs/source/en/model_doc/llava.md | 4 + docs/source/en/model_doc/llava_next.md | 4 + docs/source/en/model_doc/llava_next_video.md | 4 + docs/source/en/model_doc/llava_onevision.md | 4 + docs/source/en/model_doc/mistral3.md | 3 + docs/source/en/model_doc/mllama.md | 4 + docs/source/en/model_doc/paligemma.md | 4 + docs/source/en/model_doc/qwen2_5_vl.md | 4 + docs/source/en/model_doc/qwen2_vl.md | 5 + docs/source/en/model_doc/video_llava.md | 4 + docs/source/en/model_doc/vipllava.md | 4 + src/transformers/modeling_utils.py | 47 +- src/transformers/models/aria/modeling_aria.py | 360 ++++++-- src/transformers/models/aria/modular_aria.py | 224 +++-- src/transformers/models/auto/modeling_auto.py | 25 +- .../models/aya_vision/modeling_aya_vision.py | 345 ++++++-- .../models/aya_vision/modular_aya_vision.py | 107 +-- .../models/colpali/modeling_colpali.py | 20 +- src/transformers/models/emu3/modeling_emu3.py | 214 ++++- src/transformers/models/emu3/modular_emu3.py | 206 ++++- src/transformers/models/fuyu/modeling_fuyu.py | 206 +++-- .../models/gemma3/modeling_gemma3.py | 355 ++++++-- .../models/gemma3/modular_gemma3.py | 245 +++--- .../models/got_ocr2/modeling_got_ocr2.py | 305 +++++-- .../models/got_ocr2/modular_got_ocr2.py | 155 ++-- .../modeling_granitemoehybrid.py | 8 +- .../instructblip/modeling_instructblip.py | 160 +++- .../modeling_instructblipvideo.py | 168 +++- .../modular_instructblipvideo.py | 108 ++- .../models/internvl/modeling_internvl.py | 419 ++++++--- .../models/internvl/modular_internvl.py | 21 +- .../models/llava/modeling_llava.py | 345 ++++++-- .../models/llava_next/modeling_llava_next.py | 361 +++++--- .../modeling_llava_next_video.py | 539 ++++++++---- .../modular_llava_next_video.py | 352 +++++--- .../image_processing_llava_onevision_fast.py | 20 +- .../modeling_llava_onevision.py | 821 +++++++++++------- .../modular_llava_onevision.py | 478 +++++++++- .../models/mistral3/modeling_mistral3.py | 352 +++++--- .../models/mistral3/modular_mistral3.py | 201 +++-- 
.../models/mllama/modeling_mllama.py | 285 +++--- .../paligemma/configuration_paligemma.py | 24 - .../models/paligemma/modeling_paligemma.py | 438 ++++++---- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 8 +- .../qwen2_5_omni/modular_qwen2_5_omni.py | 6 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 402 ++++++--- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 188 ++-- .../qwen2_audio/modeling_qwen2_audio.py | 6 - .../models/qwen2_vl/modeling_qwen2_vl.py | 372 +++++--- .../video_llava/modeling_video_llava.py | 371 ++++++-- .../models/vipllava/modeling_vipllava.py | 345 ++++++-- .../models/vipllava/modular_vipllava.py | 294 +++++++ tests/models/aria/test_modeling_aria.py | 15 +- .../aya_vision/test_modeling_aya_vision.py | 10 +- tests/models/emu3/test_modeling_emu3.py | 14 +- tests/models/fuyu/test_modeling_fuyu.py | 11 +- tests/models/gemma3/test_modeling_gemma3.py | 16 +- .../models/got_ocr2/test_modeling_got_ocr2.py | 14 +- .../test_modeling_instructblip.py | 12 +- .../test_modeling_instructblipvideo.py | 11 +- .../models/internvl/test_modeling_internvl.py | 6 +- tests/models/llava/test_modeling_llava.py | 25 +- .../llava_next/test_modeling_llava_next.py | 29 +- .../test_modeling_llava_next_video.py | 29 +- .../test_modeling_llava_onevision.py | 13 +- .../models/mistral3/test_modeling_mistral3.py | 14 +- tests/models/mllama/test_modeling_mllama.py | 49 +- .../paligemma/test_modeling_paligemma.py | 22 +- .../paligemma2/test_modeling_paligemma2.py | 12 +- .../qwen2_5_vl/test_modeling_qwen2_5_vl.py | 53 +- .../models/qwen2_vl/test_modeling_qwen2_vl.py | 60 +- .../video_llava/test_modeling_video_llava.py | 50 +- .../models/vipllava/test_modeling_vipllava.py | 29 +- tests/test_modeling_common.py | 7 +- utils/check_config_attributes.py | 4 +- utils/check_repo.py | 2 + 85 files changed, 7590 insertions(+), 2904 deletions(-) create mode 100644 src/transformers/models/vipllava/modular_vipllava.py diff --git a/docs/source/en/model_doc/aria.md b/docs/source/en/model_doc/aria.md index 7b58f59cab7..89cd26db649 100644 --- a/docs/source/en/model_doc/aria.md +++ b/docs/source/en/model_doc/aria.md @@ -102,6 +102,10 @@ response = processor.decode(output_ids, skip_special_tokens=True) [[autodoc]] AriaTextModel +## AriaModel + +[[autodoc]] AriaModel + ## AriaTextForCausalLM [[autodoc]] AriaTextForCausalLM diff --git a/docs/source/en/model_doc/aya_vision.md b/docs/source/en/model_doc/aya_vision.md index 17daf494920..f2a82089506 100644 --- a/docs/source/en/model_doc/aya_vision.md +++ b/docs/source/en/model_doc/aya_vision.md @@ -237,6 +237,10 @@ for i, output in enumerate(batch_outputs): [[autodoc]] AyaVisionConfig +## AyaVisionModel + +[[autodoc]] AyaVisionModel + ## AyaVisionForConditionalGeneration [[autodoc]] AyaVisionForConditionalGeneration diff --git a/docs/source/en/model_doc/emu3.md b/docs/source/en/model_doc/emu3.md index 4ac7d0b0c4f..20b8a5e1cdb 100644 --- a/docs/source/en/model_doc/emu3.md +++ b/docs/source/en/model_doc/emu3.md @@ -174,6 +174,10 @@ for i, image in enumerate(images['pixel_values']): [[autodoc]] Emu3TextModel - forward +## Emu3Model + +[[autodoc]] Emu3Model + ## Emu3ForCausalLM [[autodoc]] Emu3ForCausalLM diff --git a/docs/source/en/model_doc/fuyu.md b/docs/source/en/model_doc/fuyu.md index c0ea89ad19f..60ae9efdf3f 100644 --- a/docs/source/en/model_doc/fuyu.md +++ b/docs/source/en/model_doc/fuyu.md @@ -103,6 +103,10 @@ The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece. 
[[autodoc]] FuyuConfig +## FuyuModel + +[[autodoc]] FuyuModel + ## FuyuForCausalLM [[autodoc]] FuyuForCausalLM diff --git a/docs/source/en/model_doc/gemma3.md b/docs/source/en/model_doc/gemma3.md index f3864012c84..8372fd9ed15 100644 --- a/docs/source/en/model_doc/gemma3.md +++ b/docs/source/en/model_doc/gemma3.md @@ -254,6 +254,10 @@ visualizer("What is shown in this image?") [[autodoc]] Gemma3TextModel - forward +## Gemma3Model + +[[autodoc]] Gemma3Model + ## Gemma3ForCausalLM [[autodoc]] Gemma3ForCausalLM diff --git a/docs/source/en/model_doc/got_ocr2.md b/docs/source/en/model_doc/got_ocr2.md index bb14cdbcf84..c7a73659e88 100644 --- a/docs/source/en/model_doc/got_ocr2.md +++ b/docs/source/en/model_doc/got_ocr2.md @@ -277,6 +277,10 @@ alt="drawing" width="600"/> [[autodoc]] GotOcr2Processor +## GotOcr2Model + +[[autodoc]] GotOcr2Model + ## GotOcr2ForConditionalGeneration [[autodoc]] GotOcr2ForConditionalGeneration diff --git a/docs/source/en/model_doc/instructblip.md b/docs/source/en/model_doc/instructblip.md index 4f2feb015f1..944e8888fcf 100644 --- a/docs/source/en/model_doc/instructblip.md +++ b/docs/source/en/model_doc/instructblip.md @@ -69,6 +69,10 @@ The attributes can be obtained from model config, as `model.config.num_query_tok [[autodoc]] InstructBlipQFormerModel - forward +## InstructBlipModel + +[[autodoc]] InstructBlipModel + ## InstructBlipForConditionalGeneration [[autodoc]] InstructBlipForConditionalGeneration diff --git a/docs/source/en/model_doc/instructblipvideo.md b/docs/source/en/model_doc/instructblipvideo.md index c26562a8530..c021a4c7afa 100644 --- a/docs/source/en/model_doc/instructblipvideo.md +++ b/docs/source/en/model_doc/instructblipvideo.md @@ -73,6 +73,10 @@ The attributes can be obtained from model config, as `model.config.num_query_tok [[autodoc]] InstructBlipVideoQFormerModel - forward +## InstructBlipVideoModel +[[autodoc]] InstructBlipVideoModel + - forward + ## InstructBlipVideoForConditionalGeneration [[autodoc]] InstructBlipVideoForConditionalGeneration diff --git a/docs/source/en/model_doc/internvl.md b/docs/source/en/model_doc/internvl.md index 4ac56c85377..8a19726e3e0 100644 --- a/docs/source/en/model_doc/internvl.md +++ b/docs/source/en/model_doc/internvl.md @@ -340,6 +340,11 @@ This example showcases how to handle a batch of chat conversations with interlea [[autodoc]] InternVLVisionModel - forward +## InternVLModel + +[[autodoc]] InternVLModel + - forward + ## InternVLForConditionalGeneration [[autodoc]] InternVLForConditionalGeneration diff --git a/docs/source/en/model_doc/llava.md b/docs/source/en/model_doc/llava.md index 79033ec5a18..dcf2cd2f3f6 100644 --- a/docs/source/en/model_doc/llava.md +++ b/docs/source/en/model_doc/llava.md @@ -256,6 +256,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] LlavaProcessor +## LlavaModel + +[[autodoc]] LlavaModel + ## LlavaForConditionalGeneration [[autodoc]] LlavaForConditionalGeneration diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md index 7d85ab8b696..2af882b6118 100644 --- a/docs/source/en/model_doc/llava_next.md +++ b/docs/source/en/model_doc/llava_next.md @@ -315,6 +315,10 @@ model = AutoModelForImageTextToText.from_pretrained( [[autodoc]] LlavaNextProcessor +## LlavaNextModel + +[[autodoc]] LlavaNextModel + ## LlavaNextForConditionalGeneration [[autodoc]] LlavaNextForConditionalGeneration diff --git a/docs/source/en/model_doc/llava_next_video.md b/docs/source/en/model_doc/llava_next_video.md index 
e435861cbe2..d3306256846 100644 --- a/docs/source/en/model_doc/llava_next_video.md +++ b/docs/source/en/model_doc/llava_next_video.md @@ -262,6 +262,10 @@ model = LlavaNextVideoForConditionalGeneration.from_pretrained( [[autodoc]] LlavaNextVideoImageProcessor +## LlavaNextVideoModel + +[[autodoc]] LlavaNextVideoModel + ## LlavaNextVideoForConditionalGeneration [[autodoc]] LlavaNextVideoForConditionalGeneration diff --git a/docs/source/en/model_doc/llava_onevision.md b/docs/source/en/model_doc/llava_onevision.md index 77fe807d46d..a00dd5a0e12 100644 --- a/docs/source/en/model_doc/llava_onevision.md +++ b/docs/source/en/model_doc/llava_onevision.md @@ -313,6 +313,10 @@ model = LlavaOnevisionForConditionalGeneration.from_pretrained( [[autodoc]] LlavaOnevisionVideoProcessor +## LlavaOnevisionModel + +[[autodoc]] LlavaOnevisionModel + ## LlavaOnevisionForConditionalGeneration [[autodoc]] LlavaOnevisionForConditionalGeneration diff --git a/docs/source/en/model_doc/mistral3.md b/docs/source/en/model_doc/mistral3.md index 4efdb641559..8eedb5de6bd 100644 --- a/docs/source/en/model_doc/mistral3.md +++ b/docs/source/en/model_doc/mistral3.md @@ -227,6 +227,9 @@ This example also how to use `BitsAndBytes` to load the model in 4bit quantizati [[autodoc]] Mistral3Config +## Mistral3Model + +[[autodoc]] Mistral3Model ## Mistral3ForConditionalGeneration diff --git a/docs/source/en/model_doc/mllama.md b/docs/source/en/model_doc/mllama.md index 77f5e211f17..cdd4da240af 100644 --- a/docs/source/en/model_doc/mllama.md +++ b/docs/source/en/model_doc/mllama.md @@ -130,6 +130,10 @@ print(processor.decode(output[0], skip_special_tokens=True)) [[autodoc]] MllamaTextModel - forward +## MllamaModel + +[[autodoc]] MllamaModel + ## MllamaForCausalLM [[autodoc]] MllamaForCausalLM diff --git a/docs/source/en/model_doc/paligemma.md b/docs/source/en/model_doc/paligemma.md index fa119a5f836..a0a0c1b714f 100644 --- a/docs/source/en/model_doc/paligemma.md +++ b/docs/source/en/model_doc/paligemma.md @@ -174,6 +174,10 @@ visualizer(" What is in this image?") [[autodoc]] PaliGemmaProcessor +## PaliGemmaModel + +[[autodoc]] PaliGemmaModel + ## PaliGemmaForConditionalGeneration [[autodoc]] PaliGemmaForConditionalGeneration diff --git a/docs/source/en/model_doc/qwen2_5_vl.md b/docs/source/en/model_doc/qwen2_5_vl.md index c414b41faac..57b88d1b8da 100644 --- a/docs/source/en/model_doc/qwen2_5_vl.md +++ b/docs/source/en/model_doc/qwen2_5_vl.md @@ -240,6 +240,10 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained( [[autodoc]] Qwen2_5_VLProcessor +## Qwen2_5_VLTextModel + +[[autodoc]] Qwen2_5_VLTextModel + - forward ## Qwen2_5_VLModel diff --git a/docs/source/en/model_doc/qwen2_vl.md b/docs/source/en/model_doc/qwen2_vl.md index 3d1845b6015..7fef4e2fdbd 100644 --- a/docs/source/en/model_doc/qwen2_vl.md +++ b/docs/source/en/model_doc/qwen2_vl.md @@ -296,6 +296,11 @@ model = Qwen2VLForConditionalGeneration.from_pretrained( [[autodoc]] Qwen2VLProcessor +## Qwen2VLTextModel + +[[autodoc]] Qwen2VLTextModel + - forward + ## Qwen2VLModel [[autodoc]] Qwen2VLModel diff --git a/docs/source/en/model_doc/video_llava.md b/docs/source/en/model_doc/video_llava.md index f407b4dc5eb..ca1a06d4cdc 100644 --- a/docs/source/en/model_doc/video_llava.md +++ b/docs/source/en/model_doc/video_llava.md @@ -215,6 +215,10 @@ model = VideoLlavaForConditionalGeneration.from_pretrained( [[autodoc]] VideoLlavaProcessor +## VideoLlavaModel + +[[autodoc]] VideoLlavaModel + ## VideoLlavaForConditionalGeneration [[autodoc]] VideoLlavaForConditionalGeneration diff 
--git a/docs/source/en/model_doc/vipllava.md b/docs/source/en/model_doc/vipllava.md index 9438893dfb1..8edf1540268 100644 --- a/docs/source/en/model_doc/vipllava.md +++ b/docs/source/en/model_doc/vipllava.md @@ -101,6 +101,10 @@ A chat between a curious human and an artificial intelligence assistant. The ass [[autodoc]] VipLlavaConfig +## VipLlavaModel + +[[autodoc]] VipLlavaModel + ## VipLlavaForConditionalGeneration [[autodoc]] VipLlavaForConditionalGeneration diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index be84fdee54b..344990c2c9f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -216,6 +216,28 @@ TORCH_INIT_FUNCTIONS = { "kaiming_normal": nn.init.kaiming_normal, } +# DO NOT MODIFY, KEPT FOR BC ONLY +VLMS = [ + "aria", + "aya_vision", + "emu3", + "fuyu", + "got_ocr2", + "gemma3", + "internvl", + "llava", + "llava_next", + "llava_next_video", + "llava_onevision", + "mistral3", + "mllama", + "paligemma", + "qwen2_vl", + "qwem2_5_vl", + "video_llava", + "vipllava", +] + @contextmanager def no_init_weights(): @@ -1778,6 +1800,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi main_input_name = "input_ids" model_tags = None + _checkpoint_conversion_mapping = {} # used for BC support in VLMs, not meant to be used by new models + _auto_class = None _no_split_modules = None _skip_keys_device_placement = None @@ -3484,6 +3508,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi module_map[name + f".{key}"] = module state_dict = model_to_save.state_dict() + if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS): + reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()} + + original_state_dict = {} + for key, value in state_dict.items(): + for pattern, replacement in reverse_key_mapping.items(): + replacement = replacement.lstrip("^") # strip off un-needed chars and patterns + replacement = re.sub(r"\(.*?\)", "", pattern) + key, n_replace = re.subn(pattern, replacement, key) + # Early exit of the loop + if n_replace > 0: + break + original_state_dict[key] = value + state_dict = original_state_dict + # Translate state_dict from smp to hf if saving with smp >= 1.10 if IS_SAGEMAKER_MP_POST_1_10: for smp_to_hf, _ in smp.state.module_manager.translate_functions: @@ -4071,7 +4110,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi gguf_file = kwargs.pop("gguf_file", None) tp_plan = kwargs.pop("tp_plan", None) tp_size = kwargs.pop("tp_size", None) - key_mapping = kwargs.pop("key_mapping", None) + + # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model + if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS): + key_mapping = kwargs.pop("key_mapping", cls._checkpoint_conversion_mapping) + else: + key_mapping = kwargs.pop("key_mapping", None) + # Not used anymore -- remove them from the kwargs _ = kwargs.pop("resume_download", None) _ = kwargs.pop("trust_remote_code", None) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index f6a1861d308..31e0980fdeb 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -42,7 +42,7 @@ from ...utils import ( replace_return_docstrings, ) from ...utils.import_utils import is_torch_available -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import 
AutoModel from .configuration_aria import AriaConfig, AriaTextConfig @@ -58,7 +58,9 @@ if is_torch_flex_attn_available(): logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "AriaTextConfig" + + +_CONFIG_FOR_DOC = "AriaConfig" @use_kernel_forward_from_hub("RMSNorm") @@ -659,7 +661,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ - config_class = AriaConfig + config_class = AriaTextConfig base_model_prefix = "model" _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"] supports_gradient_checkpointing = True @@ -706,8 +708,8 @@ ARIA_TEXT_START_DOCSTRING = r""" ARIA_TEXT_START_DOCSTRING, ) class AriaPreTrainedModel(PreTrainedModel): - config_class = AriaTextConfig - base_model_prefix = "model" + config_class = AriaConfig + base_model_prefix = "" supports_gradient_checkpointing = True _no_split_modules = ["AriaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] @@ -1097,6 +1099,9 @@ class AriaTextModel(AriaTextPreTrainedModel): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... +_CONFIG_FOR_TEXT_DOC = "AriaTextConfig" + + class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): """ Aria model for causal language modeling tasks. @@ -1112,7 +1117,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - config_class = AriaTextConfig def __init__(self, config: AriaTextConfig): super().__init__(config) @@ -1141,9 +1145,8 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model - @can_return_tuple @add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_TEXT_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1255,7 +1258,7 @@ class AriaCausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ @@ -1267,6 +1270,39 @@ class AriaCausalLMOutputWithPast(ModelOutput): image_hidden_states: Optional[torch.FloatTensor] = None +@dataclass +class AriaModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for Aria outputs, with hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + ARIA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor`, *optional*): @@ -1320,30 +1356,131 @@ ARIA_START_DOCSTRING = r""" @add_start_docstrings( - """Aria model for conditional generation tasks. - - This model combines a vision tower, a multi-modal projector, and a language model - to perform tasks that involve both image and text inputs.""", + """The Aria model which consists of a vision backbone and a language model, without a language modeling head.""", ARIA_START_DOCSTRING, ) -class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): - config_class = AriaConfig - _supports_flash_attn_2 = False - _supports_flex_attn = False - _supports_sdpa = False - _tied_weights_keys = ["language_model.lm_head.weight"] +class AriaModel(AriaPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} def __init__(self, config: AriaConfig): super().__init__(config) - self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = AriaProjector(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" + self.language_model = AutoModel.from_config(config.text_config) self.post_init() + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.FloatTensor = None, + vision_feature_layer: int = -1, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. 
+ + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + pixel_mask (`torch.FloatTensor]`, *optional*): + The tensors corresponding to the input image mask. + vision_feature_layer (`Union[int, List[int]]`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + patch_attention_mask = self._create_patch_attention_mask(pixel_mask) + image_outputs = self.vision_tower( + pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True + ) + image_attn_mask = None + if patch_attention_mask is not None: + flattened_mask = patch_attention_mask.flatten(1) + image_attn_mask = torch.logical_not(flattened_mask) + + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) + return image_features + + @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + pixel_mask: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, AriaModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. 
Merge text and images + if pixel_values is not None and inputs_embeds.shape[1] != 1: + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + else: + image_embeds = input_ids == self.config.image_token_id + special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) + image_features = self.get_image_features( + pixel_values=pixel_values, + pixel_mask=pixel_mask, + vision_feature_layer=self.config.vision_feature_layer, + ) + n_images, n_features_per_image = image_features.shape[0], image_features.shape[1] + n_image_features = n_images * n_features_per_image + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + ) + + output = AriaModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values if use_cache else None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + def _create_patch_attention_mask(self, pixel_mask): if pixel_mask is None: return None @@ -1360,51 +1497,61 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): ) return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + +@add_start_docstrings( + """Aria model for conditional generation tasks. 
+ This model combines a vision tower, a multi-modal projector, and a language model + to perform tasks that involve both image and text inputs.""", + ARIA_START_DOCSTRING, +) +class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: AriaConfig): + super().__init__(config) + self.model = AriaModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + def get_input_embeddings(self): - return self.language_model.get_input_embeddings() + return self.model.get_input_embeddings() def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) + self.model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() + def get_output_embeddings(self) -> nn.Module: + return self.lm_head def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) + self.lm_head = new_embeddings - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model - def get_decoder(self): - return self.language_model.get_decoder() + @property + def vision_tower(self): + return self.model.vision_tower - def get_image_features( - self, - pixel_values: torch.FloatTensor, - pixel_mask: Optional[torch.FloatTensor] = None, - vision_feature_layer: int = -1, - ): - patch_attention_mask = self._create_patch_attention_mask(pixel_mask) - image_outputs = self.vision_tower( - pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True - ) - image_attn_mask = None - if patch_attention_mask is not None: - flattened_mask = patch_attention_mask.flatten(1) - image_attn_mask = torch.logical_not(flattened_mask) - - selected_image_feature = image_outputs.hidden_states[vision_feature_layer] - image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) - return image_features + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector @can_return_tuple @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) + @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_mask: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + pixel_mask: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -1413,10 +1560,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, cache_position: Optional[torch.LongTensor] = None, 
**loss_kwargs, - ) -> AriaCausalLMOutputWithPast: + ) -> Union[Tuple, AriaCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1482,37 +1630,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - # 2. Merge text and images - if pixel_values is not None and inputs_embeds.shape[1] != 1: - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] - else: - image_embeds = input_ids == self.config.image_token_id - special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) - image_features = self.get_image_features( - pixel_values=pixel_values, - pixel_mask=pixel_mask, - vision_feature_layer=self.config.vision_feature_layer, - ) - n_images, n_features_per_image = image_features.shape[0], image_features.shape[1] - n_image_features = n_images * n_features_per_image - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_mask=pixel_mask, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -1520,11 +1643,14 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - logits_to_keep=logits_to_keep, + return_dict=return_dict, cache_position=cache_position, ) - logits = outputs.logits + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1552,7 +1678,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): logits_to_keep=None, **kwargs, ): - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1570,11 +1696,67 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): return model_inputs + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 
2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + __all__ = [ "AriaForConditionalGeneration", "AriaPreTrainedModel", "AriaTextPreTrainedModel", "AriaTextModel", + "AriaModel", "AriaTextForCausalLM", ] diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index e4d063f7827..a42c7227772 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -18,7 +18,6 @@ import numpy as np from ...activations import ACT2FN from ...configuration_utils import PretrainedConfig -from ...generation import GenerationMixin from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format from ...image_utils import ( @@ -49,7 +48,7 @@ from ...utils import ( replace_return_docstrings, ) from ...utils.import_utils import is_torch_available -from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer +from ..auto import CONFIG_MAPPING, AutoConfig, AutoTokenizer from ..llama.configuration_llama import LlamaConfig from ..llama.modeling_llama import ( LlamaDecoderLayer, @@ -59,7 +58,12 @@ from ..llama.modeling_llama import ( LlamaPreTrainedModel, LlamaRMSNorm, ) -from ..llava.modeling_llava import LlavaCausalLMOutputWithPast +from ..llava.modeling_llava import ( + LlavaCausalLMOutputWithPast, + LlavaForConditionalGeneration, + LlavaModel, + LlavaModelOutputWithPast, +) from ..llava_next.image_processing_llava_next import divide_to_patches @@ -70,6 +74,11 @@ if is_torch_available(): from torch import nn 
+_CONFIG_FOR_DOC = "AriaConfig" +_CONFIG_FOR_TEXT_DOC = "AriaTextConfig" +ARIA_TEXT_INPUTS_DOCSTRING = None + + def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert): """ Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts. @@ -1223,7 +1232,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ - config_class = AriaConfig + config_class = AriaTextConfig base_model_prefix = "model" _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"] supports_gradient_checkpointing = True @@ -1249,6 +1258,8 @@ class AriaTextPreTrainedModel(PreTrainedModel): class AriaPreTrainedModel(LlamaPreTrainedModel): + config_class = AriaConfig + base_model_prefix = "" _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) _supports_attention_backend = False @@ -1292,7 +1303,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM): """ _tied_weights_keys = ["lm_head.weight"] - config_class = AriaTextConfig def __init__(self, config: AriaTextConfig): super().__init__(config) @@ -1303,11 +1313,20 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM): # Initialize weights and apply final processing self.post_init() + @add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_TEXT_DOC) + def forward(self, **super_kwargs): + super().forward(self, **super_kwargs) + class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): pass +class AriaModelOutputWithPast(LlavaModelOutputWithPast): + pass + + ARIA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor`, *optional*): @@ -1360,30 +1379,10 @@ ARIA_START_DOCSTRING = r""" """ -@add_start_docstrings( - """Aria model for conditional generation tasks. 
- - This model combines a vision tower, a multi-modal projector, and a language model - to perform tasks that involve both image and text inputs.""", - ARIA_START_DOCSTRING, -) -class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): - config_class = AriaConfig - _supports_flash_attn_2 = False - _supports_flex_attn = False - _supports_sdpa = False - _tied_weights_keys = ["language_model.lm_head.weight"] - +class AriaModel(LlavaModel): def __init__(self, config: AriaConfig): super().__init__(config) - - self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = AriaProjector(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" - self.post_init() def _create_patch_attention_mask(self, pixel_mask): if pixel_mask is None: @@ -1401,30 +1400,27 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): ) return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - def get_input_embeddings(self): - return self.language_model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, - pixel_mask: Optional[torch.FloatTensor] = None, + pixel_mask: torch.FloatTensor = None, vision_feature_layer: int = -1, ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + pixel_mask (`torch.FloatTensor]`, *optional*): + The tensors corresponding to the input image mask. + vision_feature_layer (`Union[int, List[int]]`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
+ """ patch_attention_mask = self._create_patch_attention_mask(pixel_mask) image_outputs = self.vision_tower( pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True @@ -1438,14 +1434,94 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) return image_features - @can_return_tuple @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_mask: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + pixel_mask: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, AriaModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. Merge text and images + if pixel_values is not None and inputs_embeds.shape[1] != 1: + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + else: + image_embeds = input_ids == self.config.image_token_id + special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) + image_features = self.get_image_features( + pixel_values=pixel_values, + pixel_mask=pixel_mask, + vision_feature_layer=self.config.vision_feature_layer, + ) + n_images, n_features_per_image = image_features.shape[0], image_features.shape[1] + n_image_features = n_images * n_features_per_image + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + ) + + output = AriaModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values if use_cache else None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if 
pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + + +@add_start_docstrings( + """Aria model for conditional generation tasks. + This model combines a vision tower, a multi-modal projector, and a language model + to perform tasks that involve both image and text inputs.""", + ARIA_START_DOCSTRING, +) +class AriaForConditionalGeneration(LlavaForConditionalGeneration): + @can_return_tuple + @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + pixel_mask: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -1454,10 +1530,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, cache_position: Optional[torch.LongTensor] = None, **loss_kwargs, - ) -> AriaCausalLMOutputWithPast: + ) -> Union[Tuple, AriaCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1523,37 +1600,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - # 2. 
Merge text and images - if pixel_values is not None and inputs_embeds.shape[1] != 1: - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] - else: - image_embeds = input_ids == self.config.image_token_id - special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) - image_features = self.get_image_features( - pixel_values=pixel_values, - pixel_mask=pixel_mask, - vision_feature_layer=self.config.vision_feature_layer, - ) - n_images, n_features_per_image = image_features.shape[0], image_features.shape[1] - n_image_features = n_images * n_features_per_image - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_mask=pixel_mask, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -1561,11 +1613,14 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - logits_to_keep=logits_to_keep, + return_dict=return_dict, cache_position=cache_position, ) - logits = outputs.logits + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1593,7 +1648,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): logits_to_keep=None, **kwargs, ): - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1621,5 +1676,6 @@ __all__ = [ "AriaPreTrainedModel", "AriaTextPreTrainedModel", "AriaTextModel", + "AriaModel", "AriaTextForCausalLM", ] diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index b196a7718f7..bd7dff185b2 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -35,10 +35,11 @@ MODEL_MAPPING_NAMES = OrderedDict( ("albert", "AlbertModel"), ("align", "AlignModel"), ("altclip", "AltCLIPModel"), - ("aria", "AriaForConditionalGeneration"), + ("aria", "AriaModel"), ("aria_text", "AriaTextModel"), ("audio-spectrogram-transformer", "ASTModel"), ("autoformer", "AutoformerModel"), + ("aya_vision", "AyaVisionModel"), ("bamba", "BambaModel"), ("bark", "BarkModel"), ("bart", "BartModel"), @@ -108,6 +109,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("efficientformer", "EfficientFormerModel"), ("efficientnet", "EfficientNetModel"), ("electra", "ElectraModel"), + ("emu3", "Emu3Model"), ("encodec", "EncodecModel"), ("ernie", "ErnieModel"), ("ernie_m", "ErnieMModel"), @@ -121,14 +123,16 @@ MODEL_MAPPING_NAMES = 
OrderedDict( ("focalnet", "FocalNetModel"), ("fsmt", "FSMTModel"), ("funnel", ("FunnelModel", "FunnelBaseModel")), + ("fuyu", "FuyuModel"), ("gemma", "GemmaModel"), ("gemma2", "Gemma2Model"), + ("gemma3", "Gemma3Model"), ("gemma3_text", "Gemma3TextModel"), ("git", "GitModel"), ("glm", "GlmModel"), ("glm4", "Glm4Model"), ("glpn", "GLPNModel"), - ("got_ocr2", "GotOcr2ForConditionalGeneration"), + ("got_ocr2", "GotOcr2Model"), ("gpt-sw3", "GPT2Model"), ("gpt2", "GPT2Model"), ("gpt_bigcode", "GPTBigCodeModel"), @@ -156,6 +160,9 @@ MODEL_MAPPING_NAMES = OrderedDict( ("ijepa", "IJepaModel"), ("imagegpt", "ImageGPTModel"), ("informer", "InformerModel"), + ("instructblip", "InstructBlipModel"), + ("instructblipvideo", "InstructBlipVideoModel"), + ("internvl", "InternVLModel"), ("internvl_vision", "InternVLVisionModel"), ("jamba", "JambaModel"), ("janus", "JanusModel"), @@ -170,6 +177,10 @@ MODEL_MAPPING_NAMES = OrderedDict( ("lilt", "LiltModel"), ("llama", "LlamaModel"), ("llama4", "Llama4ForConditionalGeneration"), + ("llava", "LlavaModel"), + ("llava_next", "LlavaNextModel"), + ("llava_next_video", "LlavaNextVideoModel"), + ("llava_onevision", "LlavaOnevisionModel"), ("longformer", "LongformerModel"), ("longt5", "LongT5Model"), ("luke", "LukeModel"), @@ -189,8 +200,10 @@ MODEL_MAPPING_NAMES = OrderedDict( ("mgp-str", "MgpstrForSceneTextRecognition"), ("mimi", "MimiModel"), ("mistral", "MistralModel"), + ("mistral3", "Mistral3Model"), ("mixtral", "MixtralModel"), ("mlcd", "MLCDVisionModel"), + ("mllama", "MllamaModel"), ("mobilebert", "MobileBertModel"), ("mobilenet_v1", "MobileNetV1Model"), ("mobilenet_v2", "MobileNetV2Model"), @@ -221,6 +234,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("opt", "OPTModel"), ("owlv2", "Owlv2Model"), ("owlvit", "OwlViTModel"), + ("paligemma", "PaliGemmaModel"), ("patchtsmixer", "PatchTSMixerModel"), ("patchtst", "PatchTSTModel"), ("pegasus", "PegasusModel"), @@ -240,11 +254,11 @@ MODEL_MAPPING_NAMES = OrderedDict( ("qdqbert", "QDQBertModel"), ("qwen2", "Qwen2Model"), ("qwen2_5_vl", "Qwen2_5_VLModel"), - ("qwen2_5_vl_text", "Qwen2_5_VLModel"), + ("qwen2_5_vl_text", "Qwen2_5_VLTextModel"), ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoeModel"), ("qwen2_vl", "Qwen2VLModel"), - ("qwen2_vl_text", "Qwen2VLModel"), + ("qwen2_vl_text", "Qwen2VLTextModel"), ("qwen3", "Qwen3Model"), ("qwen3_moe", "Qwen3MoeModel"), ("recurrent_gemma", "RecurrentGemmaModel"), @@ -306,8 +320,10 @@ MODEL_MAPPING_NAMES = OrderedDict( ("unispeech-sat", "UniSpeechSatModel"), ("univnet", "UnivNetModel"), ("van", "VanModel"), + ("video_llava", "VideoLlavaModel"), ("videomae", "VideoMAEModel"), ("vilt", "ViltModel"), + ("vipllava", "VipLlavaModel"), ("vision-text-dual-encoder", "VisionTextDualEncoderModel"), ("visual_bert", "VisualBertModel"), ("vit", "ViTModel"), @@ -879,6 +895,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict( ("llama4", "Llama4ForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"), + ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), ("mistral3", "Mistral3ForConditionalGeneration"), ("mllama", "MllamaForConditionalGeneration"), diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 45c2ab66e3e..042ff0c05a7 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ 
b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -27,15 +27,16 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_outputs import ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torchdynamo_compiling, replace_return_docstrings, ) -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_aya_vision import AyaVisionConfig @@ -115,9 +116,8 @@ AYA_VISION_START_DOCSTRING = r""" ) class AyaVisionPreTrainedModel(PreTrainedModel): config_class = AyaVisionConfig - base_model_prefix = "model" + base_model_prefix = "" supports_gradient_checkpointing = True - _no_split_modules = ["AyaVisionVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True @@ -169,7 +169,7 @@ class AyaVisionCausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ @@ -181,7 +181,40 @@ class AyaVisionCausalLMOutputWithPast(ModelOutput): image_hidden_states: Optional[torch.FloatTensor] = None -AYA_VISION_INPUTS_DOCSTRING = """ +@dataclass +class AyaVisionModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for AyaVision outputs, with hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. 
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + +AYA_VISION_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide @@ -193,8 +226,8 @@ AYA_VISION_INPUTS_DOCSTRING = """ [What are input IDs?](../glossary#input-ids) pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`GotOcr2ImageProcessor.__call__`] for details. [`AyaVisionProcessor`] uses - [`GotOcr2ImageProcessor`] for processing images. + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`AyaVisionProcessor`] uses + [`CLIPImageProcessor`] for processing images). attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: @@ -259,23 +292,18 @@ AYA_VISION_INPUTS_DOCSTRING = """ @add_start_docstrings( - """The AyaVision model which consists of a vision backbone and a language model.""", + """The AyaVision model which consists of a vision backbone and a language model, without a language modeling head.""", AYA_VISION_START_DOCSTRING, ) -class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin): +class AyaVisionModel(AyaVisionPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + def __init__(self, config: AyaVisionConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = AyaVisionMultiModalProjector(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - + self.language_model = AutoModel.from_config(config.text_config) self.post_init() def get_input_embeddings(self): @@ -284,18 +312,6 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -307,7 +323,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi Obtains image last hidden states from the vision tower and apply multimodal projection. Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. vision_feature_layer (`Union[int, List[int]]`): The index of the layer to select the vision feature. 
If multiple indices are provided, @@ -342,6 +358,140 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi image_features = self.multi_modal_projector(selected_image_feature) return image_features + @add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + image_sizes: torch.Tensor = None, + **lm_kwargs, + ) -> Union[Tuple, AyaVisionModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + image_sizes=image_sizes, + ) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_tokens = (input_ids == self.config.image_token_id).sum() + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + output = 
AyaVisionModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + + +@add_start_docstrings( + """The AyaVision model which consists of a vision backbone and a language model.""", + AYA_VISION_START_DOCSTRING, +) +class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: AyaVisionConfig): + super().__init__(config) + self.model = AyaVisionModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple @add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=AyaVisionCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -410,7 +560,6 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi >>> gen_tokens = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3) >>> processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -425,73 +574,32 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi else self.config.vision_feature_select_strategy ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - image_sizes=image_sizes, - ) - - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids 
== self.config.image_token_id).sum() - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - outputs = self.language_model( + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, - logits_to_keep=logits_to_keep, + image_sizes=image_sizes, **lm_kwargs, ) - logits = outputs[0] + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return AyaVisionCausalLMOutputWithPast( loss=loss, @@ -499,7 +607,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, + image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( @@ -515,7 +623,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -532,15 +640,60 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi return model_inputs - def tie_weights(self): - return self.language_model.tie_weights() + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + 
cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: - model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) - # update vocab size - self.config.text_config.vocab_size = model_embeds.num_embeddings - self.vocab_size = model_embeds.num_embeddings - return model_embeds + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask -__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel"] +__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"] diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py index 96f0de888dd..5d7acecff88 100644 --- a/src/transformers/models/aya_vision/modular_aya_vision.py +++ b/src/transformers/models/aya_vision/modular_aya_vision.py @@ -22,6 +22,7 @@ from torch import nn from transformers.models.llava.modeling_llava import ( LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, + LlavaModel, LlavaPreTrainedModel, ) @@ -88,21 +89,8 @@ class AyaVisionMultiModalProjector(nn.Module): return image_features -AYA_VISION_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) 
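[Editor's aside] To make the `_prepare_4d_causal_attention_mask_with_cache_position` recipe added in the hunk above concrete, here is a minimal, self-contained sketch that replays the same steps for a toy case. All sizes, the cache positions, and the padding pattern are made up for illustration; this is not the library call itself.

```python
import torch

# Toy setup: 2 new query tokens against a static cache of length 4,
# where the first cached key position is padding (values are invented).
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length, batch_size = 2, 4, 1
cache_position = torch.tensor([2, 3])          # positions of the 2 query tokens
attention_mask = torch.tensor([[0, 1, 1, 1]])  # 2D mask: key position 0 is padding

# Same recipe as the method above: start fully masked, keep only the strictly
# upper-triangular part, then unmask everything at or before each cache position.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

# Fold the 2D padding mask into the 4D causal mask.
mask_length = attention_mask.shape[-1]
padding_mask = (causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]) == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)

print(causal_mask[0, 0])
# One row per query token; `min` below stands for torch.finfo(float32).min:
# [[min, 0., 0., min],   query at cache position 2 sees keys 1..2, not the pad or the future
#  [min, 0., 0., 0. ]]   query at cache position 3 sees keys 1..3
```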
- - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`AyaVisionConfig`] or [`AyaVisionVisionConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" +AYA_VISION_START_DOCSTRING = None +AYA_VISION_INPUTS_DOCSTRING = None @add_start_docstrings( @@ -133,81 +121,8 @@ class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): pass -AYA_VISION_INPUTS_DOCSTRING = """ - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): - The tensors corresponding to the input images. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`GotOcr2ImageProcessor.__call__`] for details. [`AyaVisionProcessor`] uses - [`GotOcr2ImageProcessor`] for processing images. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
- - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`): - The index of the layer to select the vision feature. If multiple indices are provided, - the vision feature of the corresponding indices will be concatenated to form the - vision features. - vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): - The feature selection strategy used to select the vision feature from the vision backbone. - Can be one of `"default"` or `"full"`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. 
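[Editor's aside] The `_checkpoint_conversion_mapping` regexes introduced for `AyaVisionForConditionalGeneration` above are what keep old checkpoints loadable after the layout change. The loop below is only a simplified illustration of that renaming; the real logic in `modeling_utils.py` handles more cases, and the key names are invented examples rather than keys read from an actual checkpoint.

```python
import re

# Regex -> replacement pairs copied from the patch above.
mapping = {
    "^language_model.model": "model.language_model",
    "^vision_tower": "model.vision_tower",
    "^multi_modal_projector": "model.multi_modal_projector",
    "^language_model.lm_head": "lm_head",
}

# Invented examples of keys as they appear in pre-refactor checkpoints.
old_keys = [
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "vision_tower.vision_model.embeddings.patch_embedding.weight",
    "multi_modal_projector.linear_1.weight",
    "language_model.lm_head.weight",
]

for key in old_keys:
    new_key = key
    for pattern, replacement in mapping.items():
        new_key, n_subs = re.subn(pattern, replacement, new_key)
        if n_subs:  # first matching rule wins in this simplified sketch
            break
    print(f"{key}  ->  {new_key}")
# language_model.model.*  ->  model.language_model.*
# vision_tower.*          ->  model.vision_tower.*
# multi_modal_projector.* ->  model.multi_modal_projector.*
# language_model.lm_head.weight -> lm_head.weight
```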
-""" +class AyaVisionModel(LlavaModel): + pass @add_start_docstrings( @@ -215,16 +130,6 @@ AYA_VISION_INPUTS_DOCSTRING = """ AYA_VISION_START_DOCSTRING, ) class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration): - def tie_weights(self): - return self.language_model.tie_weights() - - def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: - model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) - # update vocab size - self.config.text_config.vocab_size = model_embeds.num_embeddings - self.vocab_size = model_embeds.num_embeddings - return model_embeds - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -312,4 +217,4 @@ class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration): ) -__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel"] +__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"] diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py index f1d6b3a0476..c3a669d4c98 100644 --- a/src/transformers/models/colpali/modeling_colpali.py +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -175,8 +175,8 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel): self.vocab_size = config.vlm_config.text_config.vocab_size vlm = AutoModelForImageTextToText.from_config(config.vlm_config) - if vlm.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"vlm.language_model.{k}" for k in vlm.language_model._tied_weights_keys] + if vlm._tied_weights_keys is not None: + self._tied_weights_keys = [f"vlm.{k}" for k in vlm._tied_weights_keys] self.vlm = vlm self.embedding_dim = self.config.embedding_dim @@ -246,25 +246,25 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel): ) def get_input_embeddings(self): - return self.vlm.language_model.get_input_embeddings() + return self.vlm.get_input_embeddings() def set_input_embeddings(self, value): - self.vlm.language_model.set_input_embeddings(value) + self.vlm.set_input_embeddings(value) def get_output_embeddings(self): - return self.vlm.language_model.get_output_embeddings() + return self.vlm.get_output_embeddings() def set_output_embeddings(self, new_embeddings): - self.vlm.language_model.set_output_embeddings(new_embeddings) + self.vlm.set_output_embeddings(new_embeddings) def set_decoder(self, decoder): - self.vlm.language_model.set_decoder(decoder) + self.vlm.set_decoder(decoder) def get_decoder(self): - return self.vlm.language_model.get_decoder() + return self.vlm.get_decoder() def tie_weights(self): - return self.vlm.language_model.tie_weights() + return self.vlm.tie_weights() def resize_token_embeddings( self, @@ -272,7 +272,7 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel): pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True, ) -> nn.Embedding: - model_embeds = self.vlm.language_model.resize_token_embeddings( + model_embeds = self.vlm.resize_token_embeddings( new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of, mean_resizing=mean_resizing, diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 67a643b273e..45031a1a647 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1778,13 +1778,16 @@ EMU3_INPUTS_DOCSTRING = r""" """ -class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): - 
_tied_weights_keys = ["text_model.lm_head.weight"] - _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable +class Emu3Model(Emu3PreTrainedModel): + _checkpoint_conversion_mapping = {"text_model.model": "text_model"} + _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable def __init__(self, config): super().__init__(config) - self.text_model = Emu3ForCausalLM._from_config(config.text_config) + self.text_model = Emu3TextModel._from_config(config.text_config) + if self.text_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys] + self.vqmodel = Emu3VQVAE(config.vq_config) self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) @@ -1833,14 +1836,12 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): image = self.vqmodel.decode(image_tokens) return image - @can_return_tuple @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - image_sizes: Optional[torch.Tensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: torch.Tensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, @@ -1848,10 +1849,99 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None: + image_tokens = self.get_image_tokens(pixel_values, image_sizes) + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) + input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + return outputs + + +class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): + base_model_prefix = "" + _checkpoint_conversion_mapping = { 
+ "^text_model.model": "model.text_model", + "^vqmodel": "model.vqmodel", + "^text_model.lm_head": "lm_head", + } + _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable + + def __init__(self, config): + super().__init__(config) + self.model = Emu3Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + # Make modules available throught conditional class for BC + @property + def text_model(self): + return self.model.text_model + + @property + def vqmodel(self): + return self.model.vqmodel + + @can_return_tuple + @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - ) -> CausalLMOutputWithPast: + ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., @@ -1906,25 +1996,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None: - image_tokens = self.get_image_tokens(pixel_values, image_sizes) - special_image_mask = input_ids == self.vocabulary_mapping.image_token_id - image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) - input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - return self.text_model( + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1933,8 +2007,25 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict, cache_position=cache_position, - logits_to_keep=logits_to_keep, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) def prepare_inputs_for_generation( @@ -1968,5 +2059,68 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): return model_inputs + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. -__all__ = ["Emu3ForConditionalGeneration", "Emu3ForCausalLM", "Emu3TextModel", "Emu3PreTrainedModel", "Emu3VQVAE"] + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. 
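[Editor's aside] The `logits_to_keep` slicing used by the rewritten head classes (here in `Emu3ForConditionalGeneration.forward`, and in the other `ForConditionalGeneration`/`ForCausalLM` rewrites in this patch) is easy to sanity-check in isolation. The sketch below uses made-up tensor sizes and a toy `lm_head`, not the model's own.

```python
import torch
from torch import nn

hidden_states = torch.randn(2, 7, 16)      # (batch, seq_len, hidden) -- made-up sizes
lm_head = nn.Linear(16, 32, bias=False)    # made-up vocab_size of 32

def head_logits(logits_to_keep):
    # Same expression as in the patched forward passes.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return lm_head(hidden_states[:, slice_indices, :])

print(head_logits(0).shape)                        # torch.Size([2, 7, 32]) -- 0 keeps every position
print(head_logits(1).shape)                        # torch.Size([2, 1, 32]) -- typical decoding step
print(head_logits(torch.tensor([0, 3, 6])).shape)  # torch.Size([2, 3, 32]) -- explicit positions
```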
+ cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = [ + "Emu3ForConditionalGeneration", + "Emu3ForCausalLM", + "Emu3TextModel", + "Emu3PreTrainedModel", + "Emu3VQVAE", + "Emu3Model", +] diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 52d32dbdeea..6075b8d7370 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -1121,13 +1121,16 @@ class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin): super().forward() -class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["text_model.lm_head.weight"] - _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable +class Emu3Model(Emu3PreTrainedModel): + _checkpoint_conversion_mapping = {"text_model.model": "text_model"} + _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable def __init__(self, config): super().__init__(config) - self.text_model = Emu3ForCausalLM._from_config(config.text_config) + self.text_model = Emu3TextModel._from_config(config.text_config) + if self.text_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys] + self.vqmodel = Emu3VQVAE(config.vq_config) self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) @@ -1176,14 +1179,12 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): image = self.vqmodel.decode(image_tokens) return image - @can_return_tuple @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - image_sizes: Optional[torch.Tensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: torch.Tensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, @@ -1191,10 +1192,99 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): use_cache: 
Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None: + image_tokens = self.get_image_tokens(pixel_values, image_sizes) + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) + input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + return outputs + + +class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): + base_model_prefix = "" + _checkpoint_conversion_mapping = { + "^text_model.model": "model.text_model", + "^vqmodel": "model.vqmodel", + "^text_model.lm_head": "lm_head", + } + _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable + + def __init__(self, config): + super().__init__(config) + self.model = Emu3Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + # Make modules available throught conditional class for BC + @property + def text_model(self): + return self.model.text_model + + @property + def vqmodel(self): + return self.model.vqmodel + + @can_return_tuple + @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - ) -> CausalLMOutputWithPast: + ) -> Union[Tuple, CausalLMOutputWithPast]: 
r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1249,25 +1339,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None: - image_tokens = self.get_image_tokens(pixel_values, image_sizes) - special_image_mask = input_ids == self.vocabulary_mapping.image_token_id - image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) - input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - return self.text_model( + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1276,8 +1350,25 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict, cache_position=cache_position, - logits_to_keep=logits_to_keep, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) def prepare_inputs_for_generation( @@ -1311,6 +1402,62 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): return model_inputs + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. 
+ dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + __all__ = [ "Emu3ForConditionalGeneration", @@ -1318,4 +1465,5 @@ __all__ = [ "Emu3TextModel", "Emu3PreTrainedModel", "Emu3VQVAE", + "Emu3Model", ] diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index fd19ff7b8d4..ec74b9be3d5 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -21,9 +21,9 @@ import torch.utils.checkpoint from torch import nn from ...generation import GenerationMixin -from ...modeling_outputs import CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel -from ...models.auto.modeling_auto import AutoModelForCausalLM +from ...models.auto.modeling_auto import AutoModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_fuyu import FuyuConfig @@ -143,18 +143,17 @@ FUYU_INPUTS_DOCSTRING = r""" @add_start_docstrings( - "Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.", + """The Fuyu model which consists of a vision backbone and a language model, without a language modeling head.""", FUYU_START_DOCSTRING, ) -class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): +class FuyuModel(FuyuPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + def __init__(self, config: FuyuConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - + self.language_model = AutoModel.from_config(config.text_config) self.vision_embed_tokens = nn.Linear( config.patch_size * config.patch_size * config.num_channels, config.hidden_size ) @@ -169,18 +168,6 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - 
def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def gather_continuous_embeddings( self, word_embeddings: torch.Tensor, @@ -224,56 +211,21 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): return output_embeddings @add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - image_patches: Optional[ - torch.Tensor - ] = None, # [batch_size, num_total_patches, patch_size_ x patch_size x num_channels ] - image_patches_indices: Optional[torch.Tensor] = None, + input_ids: torch.LongTensor = None, + image_patches: torch.Tensor = None, # [batch_size, num_total_patches, patch_size_ x patch_size x num_channels ] + image_patches_indices: torch.Tensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - Returns: - - Examples: - - ```python - >>> from transformers import FuyuProcessor, FuyuForCausalLM - >>> from PIL import Image - >>> import requests - - >>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b") - >>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b") - - >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png" - >>> image = Image.open(requests.get(url, stream=True).raw) - >>> prompt = "Generate a coco-style caption.\n" - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - >>> outputs = model(**inputs) - - >>> generated_ids = model.generate(**inputs, max_new_tokens=7) - >>> generation_text = processor.batch_decode(generated_ids[:, -7:], skip_special_tokens=True) - >>> print(generation_text[0]) - A blue bus parked on the side of a road. 
- ```""" - + ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -327,7 +279,6 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): past_key_values=past_key_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - labels=labels, use_cache=use_cache, return_dict=return_dict, **kwargs, @@ -335,6 +286,139 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): return outputs + +@add_start_docstrings( + "Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.", + FUYU_START_DOCSTRING, +) +class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_embed_tokens": "model.vision_embed_tokens", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: FuyuConfig): + super().__init__(config) + self.model = FuyuModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + @add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + image_patches: torch.Tensor = None, # [batch_size, num_total_patches, patch_size_ x patch_size x num_channels ] + image_patches_indices: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. 
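[Editor's aside] A hedged end-to-end sketch of the resulting split, assuming a transformers build that already includes this patch: `FuyuForCausalLM` now wraps the new headless `FuyuModel` under `.model` and applies `lm_head` on top of its hidden states. The checkpoint and image are the ones used in the Fuyu docstring example in this patch; actual outputs depend on the released weights.

```python
import requests
import torch
from PIL import Image
from transformers import FuyuForCausalLM, FuyuModel, FuyuProcessor

checkpoint = "adept/fuyu-8b"
processor = FuyuProcessor.from_pretrained(checkpoint)

url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="Generate a coco-style caption.\n", return_tensors="pt")

model = FuyuForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

with torch.no_grad():
    # The head class applies `lm_head` on top of the base model's hidden states ...
    logits = model(**inputs).logits                   # (batch, seq_len, vocab_size)
    # ... and the headless base model added in this PR is reachable as `.model`.
    hidden = model.model(**inputs).last_hidden_state  # (batch, seq_len, hidden_size)

assert isinstance(model.model, FuyuModel)
assert model.get_output_embeddings() is model.lm_head
```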
+ + Returns: + + Examples: + + ```python + >>> from transformers import FuyuProcessor, FuyuForCausalLM + >>> from PIL import Image + >>> import requests + + >>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b") + >>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b") + + >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> prompt = "Generate a coco-style caption.\n" + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + >>> outputs = model(**inputs) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=7) + >>> generation_text = processor.batch_decode(generated_ids[:, -7:], skip_special_tokens=True) + >>> print(generation_text[0]) + A blue bus parked on the side of a road. + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + image_patches=image_patches, + image_patches_indices=image_patches_indices, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + return_dict=return_dict, + # don't pass kwargs because Persimmon-backbone doesn't accept FA2 kwargs yet, TODO: raushan + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + def prepare_inputs_for_generation( self, input_ids, @@ -373,4 +457,4 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): return reordered_past -__all__ = ["FuyuForCausalLM", "FuyuPreTrainedModel"] +__all__ = ["FuyuForCausalLM", "FuyuPreTrainedModel", "FuyuModel"] diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 892b8898b62..74642ea6baa 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -32,11 +32,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import 
ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( + ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, @@ -46,7 +47,7 @@ from ...utils import ( replace_return_docstrings, ) from ...utils.deprecation import deprecate_kwarg -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig @@ -60,6 +61,39 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Gemma3Config" +@dataclass +class Gemma3ModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for Gemma3 outputs, with hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + @dataclass class Gemma3CausalLMOutputWithPast(ModelOutput): """ @@ -88,7 +122,7 @@ class Gemma3CausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`. + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder after projecting last hidden state. 
""" @@ -480,7 +514,7 @@ GEMMA3_START_DOCSTRING = r""" ) class Gemma3PreTrainedModel(PreTrainedModel): config_class = Gemma3Config - base_model_prefix = "language_model" + base_model_prefix = "" supports_gradient_checkpointing = True _no_split_modules = [ "Gemma3DecoderLayer", @@ -1066,20 +1100,19 @@ class Gemma3MultiModalProjector(nn.Module): @add_start_docstrings( - """The GEMMA3 model which consists of a vision backbone and a language model.""", + """Base Gemma3 model which consists of a vision backbone and a language model withou language modeling head.""", GEMMA3_START_DOCSTRING, ) -class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): +class Gemma3Model(Gemma3PreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + def __init__(self, config: Gemma3Config): super().__init__(config) self.vision_tower = AutoModel.from_config(config=config.vision_config) self.multi_modal_projector = Gemma3MultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - language_model = AutoModelForCausalLM.from_config(config=config.text_config) - - if language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + language_model = AutoModel.from_config(config=config.text_config) self.language_model = language_model self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 @@ -1091,18 +1124,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def _update_causal_mask( self, attention_mask, @@ -1188,11 +1209,10 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): @can_return_tuple @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, @@ -1203,6 +1223,145 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **lm_kwargs, + ) -> Union[Tuple, Gemma3ModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels 
is not None + + # Replace image id with PAD if the image token is OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_id >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training + ) + outputs = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + return Gemma3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values if use_cache else None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +@add_start_docstrings( + """The Gemma3 model which consists of a vision backbone and a language model.""", + GEMMA3_START_DOCSTRING, +) +class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Gemma3Config): + super().__init__(config) + self.model = Gemma3Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self,
new_embeddings): + self.lm_head = new_embeddings + + # Make modules available through the conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **lm_kwargs, ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: @@ -1260,80 +1419,34 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): ``` """ - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - is_training = token_type_ids is not None and labels is not None - - # Replace image id with PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_id >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_id - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 - else: - llm_input_ids = input_ids - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - # Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings."
- ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - # mask out pad-token-ids in labels for BC - if labels is not None and self.pad_token_id in labels: - logger.warning_once( - "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " - "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", - ) - labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) - - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training - ) - outputs: CausalLMOutputWithPast = self.language_model( - attention_mask=causal_mask, + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + token_type_ids=token_type_ids, + attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, + labels=labels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict, cache_position=cache_position, - logits_to_keep=logits_to_keep, **lm_kwargs, ) - logits = outputs.logits + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues @@ -1356,13 +1469,17 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): flat_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(flat_logits, flat_labels) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + return Gemma3CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, + image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( @@ -1381,7 +1498,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): **kwargs, ): # Overwritten -- custom `position_ids` and `pixel_values` handling - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1401,15 +1518,73 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): is_training = token_type_ids is not None and labels is not None if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self._update_causal_mask( + causal_mask = self.model._update_causal_mask( attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training ) model_inputs["attention_mask"] = causal_mask return model_inputs - def tie_weights(self): - return self.language_model.tie_weights() + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + 
cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask -__all__ = ["Gemma3PreTrainedModel", "Gemma3TextModel", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration"] +__all__ = [ + "Gemma3PreTrainedModel", + "Gemma3TextModel", + "Gemma3ForCausalLM", + "Gemma3ForConditionalGeneration", + "Gemma3Model", +] diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index ecac4921d2e..24206e6f838 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -26,11 +26,12 @@ import torch.utils.checkpoint from ...cache_utils import Cache, HybridCache, StaticCache from ...configuration_utils import PretrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import ( + add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, is_torchdynamo_compiling, @@ -50,7 +51,12 @@ from ..gemma2.modeling_gemma2 import ( apply_rotary_pos_emb, eager_attention_forward, ) -from ..paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +from ..paligemma.modeling_paligemma import ( + PaligemmaCausalLMOutputWithPast, + PaliGemmaForConditionalGeneration, + 
PaliGemmaModel, + PaligemmaModelOutputWithPast, +) from ..siglip import SiglipVisionConfig @@ -302,43 +308,13 @@ class Gemma3Config(PretrainedConfig): @dataclass -class Gemma3CausalLMOutputWithPast(ModelOutput): - """ - Base class for Gemma3 causal language model (or autoregressive) outputs. +class Gemma3ModelOutputWithPast(PaligemmaModelOutputWithPast): + pass - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder after projecting last hidden state. - """ - - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[torch.FloatTensor] = None +@dataclass +class Gemma3CausalLMOutputWithPast(PaligemmaCausalLMOutputWithPast): + pass class Gemma3TextScaledWordEmbedding(nn.Embedding): @@ -545,7 +521,7 @@ GEMMA3_START_DOCSTRING = None class Gemma3PreTrainedModel(Gemma2PreTrainedModel): - base_model_prefix = "language_model" + base_model_prefix = "" _no_split_modules = [ "Gemma3DecoderLayer", "SiglipVisionEmbeddings", @@ -755,10 +731,7 @@ class Gemma3MultiModalProjector(nn.Module): return projected_vision_outputs.type_as(vision_outputs) -class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): - def tie_weights(self): - return self.language_model.tie_weights() - +class Gemma3Model(PaliGemmaModel): def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: """ Projects the last hidden state from the vision model into language model space. 
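A note on the `_checkpoint_conversion_mapping` dictionaries declared on the new `*ForConditionalGeneration` classes in this patch: they let checkpoints saved with the old layout (top-level `language_model`, `vision_tower`, `multi_modal_projector`) load into the new layout, where those submodules live under `model` and the head is a standalone `lm_head`. The sketch below only illustrates how such regex patterns rewrite old state-dict keys; the `convert_old_key` helper is a hypothetical stand-in for the remapping done at load time, not the actual `from_pretrained` implementation, and only the mapping itself is taken from the classes above.

```python
import re

# Mapping as declared on Gemma3ForConditionalGeneration above; convert_old_key is a
# simplified, illustrative stand-in for the key remapping applied when loading.
_checkpoint_conversion_mapping = {
    "^language_model.model": "model.language_model",
    "^vision_tower": "model.vision_tower",
    "^multi_modal_projector": "model.multi_modal_projector",
    "^language_model.lm_head": "lm_head",
}


def convert_old_key(key: str) -> str:
    """Rewrite an old-layout state-dict key into the new `model.*` / `lm_head` layout."""
    for pattern, replacement in _checkpoint_conversion_mapping.items():
        new_key, n_subs = re.subn(pattern, replacement, key)
        if n_subs:
            return new_key
    return key


assert convert_old_key("language_model.model.embed_tokens.weight") == "model.language_model.embed_tokens.weight"
assert convert_old_key("language_model.lm_head.weight") == "lm_head.weight"
assert convert_old_key("vision_tower.vision_model.embeddings.patch_embedding.weight") == (
    "model.vision_tower.vision_model.embeddings.patch_embedding.weight"
)
```

Weight tying changes in the same spirit: instead of prefixing the inner language model's `_tied_weights_keys`, the head classes now declare `_tied_weights_keys = ["lm_head.weight"]` directly.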
@@ -844,11 +817,10 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): @can_return_tuple @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, @@ -859,6 +831,106 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **lm_kwargs, + ) -> Union[Tuple, Gemma3ModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + # Replace image id with PAD if the image token is OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_id >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings."
+ ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training + ) + outputs = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + return Gemma3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values if use_cache else None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +@add_start_docstrings( + """The Gemma3 model which consists of a vision backbone and a language model.""", + GEMMA3_START_DOCSTRING, +) +class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **lm_kwargs, ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: @@ -916,80 +988,34 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): ``` """ - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - is_training = token_type_ids is not None and labels is not None - - # Replace image id with PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_id >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_id - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 - else: - llm_input_ids = input_ids - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - # Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = 
inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - # mask out pad-token-ids in labels for BC - if labels is not None and self.pad_token_id in labels: - logger.warning_once( - "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " - "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", - ) - labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) - - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training - ) - outputs: CausalLMOutputWithPast = self.language_model( - attention_mask=causal_mask, + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + token_type_ids=token_type_ids, + attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, + labels=labels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict, cache_position=cache_position, - logits_to_keep=logits_to_keep, **lm_kwargs, ) - logits = outputs.logits + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues @@ -1012,13 +1038,17 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): flat_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(flat_logits, flat_labels) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + return Gemma3CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, + image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( @@ -1037,7 +1067,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): **kwargs, ): # Overwritten -- custom `position_ids` and `pixel_values` handling - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1057,7 +1087,7 @@ class 
Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): is_training = token_type_ids is not None and labels is not None if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self._update_causal_mask( + causal_mask = self.model._update_causal_mask( attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training ) model_inputs["attention_mask"] = causal_mask @@ -1072,4 +1102,5 @@ __all__ = [ "Gemma3TextModel", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration", + "Gemma3Model", ] diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index d3a8a637ede..15013677c23 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -28,11 +28,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers.modeling_outputs import CausalLMOutputWithPast - from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_outputs import ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -40,7 +38,7 @@ from ...utils import ( can_return_tuple, replace_return_docstrings, ) -from ..auto import AutoModelForCausalLM +from ..auto import AutoModel from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig @@ -545,7 +543,7 @@ class GotOcr2CausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ @@ -557,6 +555,39 @@ class GotOcr2CausalLMOutputWithPast(ModelOutput): image_hidden_states: Optional[torch.FloatTensor] = None +@dataclass +class GotOcr2ModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for GotOcr2 outputs, with hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + GOT_OCR2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -575,14 +606,13 @@ GOT_OCR2_START_DOCSTRING = r""" @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + "The bare GotOcr2 Model outputting raw hidden-states without any specific head on top.", GOT_OCR2_START_DOCSTRING, ) class GotOcr2PreTrainedModel(PreTrainedModel): config_class = GotOcr2Config - base_model_prefix = "model" + base_model_prefix = "" supports_gradient_checkpointing = True - _no_split_modules = ["GotOcr2VisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True @@ -680,23 +710,18 @@ GOT_OCR2_INPUTS_DOCSTRING = r""" @add_start_docstrings( - """The GOT_OCR2 model which consists of a vision backbone and a language model.""", + """The GotOcr2 model which consists of a vision backbone and a language model, without a language modeling head.""", GOT_OCR2_START_DOCSTRING, ) -class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): +class GotOcr2Model(GotOcr2PreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + def __init__(self, config: GotOcr2Config): super().__init__(config) self.vision_tower = GotOcr2VisionEncoder(config.vision_config) self.multi_modal_projector = GotOcr2MultiModalProjector(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - - self.pad_token_id = config.pad_token_id - + self.language_model = AutoModel.from_config(config.text_config) self.post_init() def get_input_embeddings(self): @@ -705,18 +730,6 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -732,13 +745,124 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): image_outputs = self.vision_tower(pixel_values).last_hidden_state return 
self.multi_modal_projector(image_outputs) + @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, GotOcr2ModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + ) + + output = GotOcr2ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + + +@add_start_docstrings( + """The GOT_OCR2 model which consists of a vision backbone and a language model.""", + GOT_OCR2_START_DOCSTRING, +) +class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: GotOcr2Config): + super().__init__(config) + self.model = GotOcr2Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, 
config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Make modules available through the conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + @can_return_tuple @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -747,9 +871,10 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - ) -> GotOcr2CausalLMOutputWithPast: + ) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -794,37 +919,15 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): "You should keep in mind what features from the module should be used, especially when you're planning to sell a template."
```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None: - image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) - n_image_tokens = (input_ids == self.config.image_token_id).sum() - n_image_features = image_features.shape[0] * image_features.shape[1] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -832,29 +935,18 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=True, cache_position=cache_position, - logits_to_keep=logits_to_keep, ) - logits = outputs.logits + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return GotOcr2CausalLMOutputWithPast( loss=loss, @@ -862,7 +954,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, + image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( @@ -878,7 +970,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -895,5 +987,60 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): return model_inputs + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. -__all__ = ["GotOcr2PreTrainedModel", "GotOcr2ForConditionalGeneration"] + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"] diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index aec8c5e1749..f485146f2e0 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -14,16 +14,17 @@ # limitations under the License. -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn import torch.utils.checkpoint -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.llava.modeling_llava import ( LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, + LlavaModel, + LlavaModelOutputWithPast, LlavaPreTrainedModel, ) from transformers.models.sam.modeling_sam import SamMLPBlock, SamVisionAttention, SamVisionEncoder, SamVisionLayer @@ -36,7 +37,7 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ..auto import CONFIG_MAPPING, AutoConfig, AutoModelForCausalLM +from ..auto import CONFIG_MAPPING, AutoConfig if is_vision_available(): @@ -278,6 +279,10 @@ class GotOcr2CausalLMOutputWithPast(LlavaCausalLMOutputWithPast): pass +class GotOcr2ModelOutputWithPast(LlavaModelOutputWithPast): + pass + + class GotOcr2PreTrainedModel(LlavaPreTrainedModel): def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) @@ -368,22 +373,11 @@ GOT_OCR2_INPUTS_DOCSTRING = r""" """ -class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): +class GotOcr2Model(LlavaModel): def __init__(self, config: GotOcr2Config): super().__init__(config) self.vision_tower = GotOcr2VisionEncoder(config.vision_config) - self.multi_modal_projector = GotOcr2MultiModalProjector(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - - self.pad_token_id = config.pad_token_id - - self.post_init() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -399,13 +393,81 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): image_outputs = self.vision_tower(pixel_values).last_hidden_state return self.multi_modal_projector(image_outputs) + @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: 
torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, GotOcr2ModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + ) + + output = GotOcr2ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + + +class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): @can_return_tuple @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -414,9 +476,10 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, 
cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - ) -> GotOcr2CausalLMOutputWithPast: + ) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -461,37 +524,15 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): "You should keep in mind what features from the module should be used, especially when you're planning to sell a template." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None: - image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) - n_image_tokens = (input_ids == self.config.image_token_id).sum() - n_image_features = image_features.shape[0] * image_features.shape[1] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -499,29 +540,18 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=True, cache_position=cache_position, - logits_to_keep=logits_to_keep, ) - logits = outputs.logits + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
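
# --- Editor's note: illustrative sketch, not part of the patch -------------------------------
# The refactored GotOcr2 head above applies lm_head only to the last `logits_to_keep` hidden
# states (0 means "keep all", since slice(-0, None) == slice(0, None)), so full-vocabulary
# logits are not materialized for positions that are never used. Toy sizes below:
import torch

hidden_states = torch.randn(2, 10, 16)                 # (batch, seq_len, hidden)
lm_head = torch.nn.Linear(16, 32, bias=False)          # hidden -> toy vocabulary of 32

for logits_to_keep in (0, 1, 3):
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = lm_head(hidden_states[:, slice_indices, :])
    print(logits_to_keep, logits.shape)                # seq dim: 10, 1, 3 respectively
# ---------------------------------------------------------------------------------------------
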
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return GotOcr2CausalLMOutputWithPast( loss=loss, @@ -529,7 +559,7 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, + image_hidden_states=outputs.image_hidden_states, ) @@ -537,5 +567,6 @@ __all__ = [ "GotOcr2VisionConfig", "GotOcr2Config", "GotOcr2PreTrainedModel", + "GotOcr2Model", "GotOcr2ForConditionalGeneration", ] diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 016b2af2fa0..932c57ba0a6 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -28,7 +28,7 @@ from torch import nn import transformers.models.jamba.modeling_jamba as modeling_jamba from transformers.activations import ACT2FN -from ...cache_utils import Cache, StaticCache +from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_layers import GradientCheckpointingLayer @@ -1511,10 +1511,10 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. 
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1525,7 +1525,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index cdfb59b5804..6abed48c86a 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -41,7 +41,7 @@ from ...utils import ( replace_return_docstrings, torch_int, ) -from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM +from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig @@ -315,6 +315,9 @@ class InstructBlipPreTrainedModel(PreTrainedModel): config_class = InstructBlipConfig base_model_prefix = "blip" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = True + _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _no_split_modules = [ "InstructBlipQFormerEmbeddings", @@ -339,7 +342,7 @@ class InstructBlipPreTrainedModel(PreTrainedModel): elif isinstance(module, InstructBlipVisionEmbeddings): nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) - elif isinstance(module, InstructBlipForConditionalGeneration): + elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)): module.query_tokens.data.zero_() @@ -1274,6 +1277,156 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel): ) +@add_start_docstrings( + """ + InstructBLIP base Model consisting of language model, qformer and vision encoder. 
+ """, + INSTRUCTBLIP_START_DOCSTRING, +) +class InstructBlipModel(InstructBlipPreTrainedModel): + main_input_name = "pixel_values" + _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 + + def __init__(self, config: InstructBlipConfig): + super().__init__(config) + + self.vision_model = InstructBlipVisionModel(config.vision_config) + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = InstructBlipQFormerModel(config.qformer_config) + + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + self.language_model = AutoModel.from_config(config.text_config) + + if self.language_model._no_split_modules is not None: + self._no_split_modules.extend(self.language_model._no_split_modules) + + if self.language_model._keep_in_fp32_modules is not None: + self._keep_in_fp32_modules.extend(self.language_model._keep_in_fp32_modules) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def _tie_weights(self): + if not self.config.use_decoder_only_language_model: + self.language_model.encoder.embed_tokens = self.language_model.shared + self.language_model.decoder.embed_tokens = self.language_model.shared + + def _preprocess_accelerate(self): + r""" + Some pre-processing hacks to make the model `accelerate` compatible. Check + https://github.com/huggingface/transformers/pull/21707 for more details. + """ + hf_device_map = self.hf_device_map + + if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1: + # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`. + logger.warning( + "The `language_model` is not in the `hf_device_map` dictionary and you are running your script" + " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`." + " Please pass a `device_map` that contains `language_model` to remove this warning." 
+ " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for" + " more details on creating a `device_map` for large models.", + ) + + if hasattr(self.language_model, "_hf_hook"): + self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + + @add_start_docstrings_to_model_forward(INSTRUCTBLIP_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.FloatTensor, + qformer_attention_mask: Optional[torch.LongTensor] = None, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + use_cache: Optional[bool] = None, + ) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # step 1: forward the images through the vision encoder, + # to get image embeddings of shape (batch_size, seq_len, hidden_size) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) + if qformer_attention_mask is None: + qformer_attention_mask = torch.ones_like(qformer_input_ids) + qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) + query_outputs = self.qformer( + input_ids=qformer_input_ids, + attention_mask=qformer_attention_mask, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + query_output = query_outputs[0][:, : query_tokens.size(1), :] + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() + + if self.config.use_decoder_only_language_model: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + ) + else: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + ) + + if not return_dict: + return (vision_outputs, query_outputs, outputs) + + return InstructBlipForConditionalGenerationModelOutput( + vision_outputs=vision_outputs, + qformer_outputs=query_outputs, + language_model_outputs=outputs, + ) + + @add_start_docstrings( """ InstructBLIP Model for generating text given an image and an optional text prompt. The model consists of a vision @@ -1336,11 +1489,13 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati def get_decoder(self): return self.language_model.get_decoder() + # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._tie_weights def _tie_weights(self): if not self.config.use_decoder_only_language_model: self.language_model.encoder.embed_tokens = self.language_model.shared self.language_model.decoder.embed_tokens = self.language_model.shared + # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._preprocess_accelerate def _preprocess_accelerate(self): r""" Some pre-processing hacks to make the model `accelerate` compatible. Check @@ -1645,6 +1800,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati __all__ = [ "InstructBlipQFormerModel", "InstructBlipPreTrainedModel", + "InstructBlipModel", "InstructBlipForConditionalGeneration", "InstructBlipVisionModel", ] diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index e9d9e4938a9..0ce752bed8b 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -45,7 +45,7 @@ from ...utils import ( replace_return_docstrings, torch_int, ) -from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM +from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM from .configuration_instructblipvideo import ( InstructBlipVideoConfig, InstructBlipVideoQFormerConfig, @@ -945,6 +945,9 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): config_class = InstructBlipVideoConfig base_model_prefix = "blip" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = True + _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _no_split_modules = [ "InstructBlipVideoQFormerEmbeddings", @@ -969,7 +972,7 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): elif isinstance(module, InstructBlipVideoVisionEmbeddings): nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) - elif isinstance(module, InstructBlipVideoForConditionalGeneration): + elif isinstance(module, (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel)): module.query_tokens.data.zero_() @@ -1269,6 +1272,166 @@ class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput): ) +@add_start_docstrings( + """ + InstructBlipVideo base Model consisting of language model, qformer and vision encoder. 
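
# --- Editor's note: illustrative sketch, not part of the patch -------------------------------
# Shape bookkeeping of the InstructBlipModel.forward above: the projected Q-Former output
# (one hidden vector per learned query token) is flattened and written into the embedding
# positions reserved by image placeholder tokens. All sizes and ids below are toy values:
import torch

num_query_tokens, hidden, image_token_id = 4, 8, 50
input_ids = torch.cat(
    [torch.full((2, num_query_tokens), image_token_id), torch.tensor([[3, 4], [5, 6]])], dim=1
)                                                      # (batch=2, seq_len=6)
inputs_embeds = torch.zeros(2, input_ids.shape[1], hidden)
language_model_inputs = torch.randn(2, num_query_tokens, hidden)  # stand-in for language_projection(query_output)

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
# Boolean-mask assignment fills the masked positions from the flattened projector output in order.
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
assert torch.equal(inputs_embeds[:, :num_query_tokens], language_model_inputs)
# ---------------------------------------------------------------------------------------------
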
+ """, + INSTRUCTBLIPVIDEO_START_DOCSTRING, +) +class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel): + main_input_name = "pixel_values" + _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 + + def __init__(self, config: InstructBlipVideoConfig): + super().__init__(config) + + self.vision_model = InstructBlipVideoVisionModel(config.vision_config) + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = InstructBlipVideoQFormerModel(config.qformer_config) + + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + self.language_model = AutoModel.from_config(config.text_config) + + if self.language_model._no_split_modules is not None: + self._no_split_modules.extend(self.language_model._no_split_modules) + + if self.language_model._keep_in_fp32_modules is not None: + self._keep_in_fp32_modules.extend(self.language_model._keep_in_fp32_modules) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def _tie_weights(self): + if not self.config.use_decoder_only_language_model: + self.language_model.encoder.embed_tokens = self.language_model.shared + self.language_model.decoder.embed_tokens = self.language_model.shared + + def _preprocess_accelerate(self): + r""" + Some pre-processing hacks to make the model `accelerate` compatible. Check + https://github.com/huggingface/transformers/pull/21707 for more details. + """ + hf_device_map = self.hf_device_map + + if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1: + # warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`. + logger.warning( + "The `language_model` is not in the `hf_device_map` dictionary and you are running your script" + " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`." + " Please pass a `device_map` that contains `language_model` to remove this warning." 
+ " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for" + " more details on creating a `device_map` for large models.", + ) + + if hasattr(self.language_model, "_hf_hook"): + self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + + @add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.FloatTensor, + qformer_attention_mask: Optional[torch.LongTensor] = None, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + use_cache: Optional[bool] = None, + ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # step 1: forward the images through the vision encoder, + # we process in a batched way, later unbatch it back (video has frames=4 always) + batch_size, frames, channel, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) + + if qformer_attention_mask is None: + qformer_attention_mask = torch.ones_like(qformer_input_ids) + + qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0) + qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0) + qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) + query_outputs = self.qformer( + input_ids=qformer_input_ids, + attention_mask=qformer_attention_mask, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + query_output = query_outputs[0][:, : query_tokens.size(1), :] + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + + # unbatch inputs back, each video-frame gets `num_query_tokens` seq length + language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = 
language_model_inputs.flatten() + + if self.config.use_decoder_only_language_model: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + ) + else: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + ) + + if not return_dict: + return (vision_outputs, query_outputs, outputs) + + return InstructBlipVideoForConditionalGenerationModelOutput( + vision_outputs=vision_outputs, + qformer_outputs=query_outputs, + language_model_outputs=outputs, + ) + + @add_start_docstrings( """ InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision @@ -1682,5 +1845,6 @@ __all__ = [ "InstructBlipVideoVisionModel", "InstructBlipVideoPreTrainedModel", "InstructBlipVideoQFormerModel", + "InstructBlipVideoModel", "InstructBlipVideoForConditionalGeneration", ] diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py index 212050877a4..d28485545f7 100644 --- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py @@ -27,6 +27,7 @@ from transformers.models.instructblip.configuration_instructblip import ( from transformers.models.instructblip.modeling_instructblip import ( InstructBlipForConditionalGeneration, InstructBlipForConditionalGenerationModelOutput, + InstructBlipModel, InstructBlipPreTrainedModel, InstructBlipQFormerModel, InstructBlipVisionModel, @@ -34,7 +35,7 @@ from transformers.models.instructblip.modeling_instructblip import ( from ...configuration_utils import PretrainedConfig from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES -from ...utils import logging +from ...utils import add_start_docstrings_to_model_forward, logging from ..auto import CONFIG_MAPPING, AutoConfig @@ -191,6 +192,110 @@ class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForCondit pass +INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = None + + +class InstructBlipVideoModel(InstructBlipModel): + @add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.FloatTensor, + qformer_attention_mask: Optional[torch.LongTensor] = None, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + use_cache: Optional[bool] = None, + ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # step 1: forward the images through the vision encoder, + # we process in a batched way, later unbatch it back (video has frames=4 always) + batch_size, frames, channel, height, width = pixel_values.shape + 
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) + + if qformer_attention_mask is None: + qformer_attention_mask = torch.ones_like(qformer_input_ids) + + qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0) + qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0) + qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) + query_outputs = self.qformer( + input_ids=qformer_input_ids, + attention_mask=qformer_attention_mask, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + query_output = query_outputs[0][:, : query_tokens.size(1), :] + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + + # unbatch inputs back, each video-frame gets `num_query_tokens` seq length + language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = language_model_inputs.flatten() + + if self.config.use_decoder_only_language_model: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + ) + else: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + ) + + if not return_dict: + return (vision_outputs, query_outputs, outputs) + + return InstructBlipVideoForConditionalGenerationModelOutput( + vision_outputs=vision_outputs, + qformer_outputs=query_outputs, + language_model_outputs=outputs, + ) + + class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration): def forward( self, @@ -508,5 +613,6 @@ __all__ = [ "InstructBlipVideoVisionModel", "InstructBlipVideoPreTrainedModel", "InstructBlipVideoQFormerModel", + "InstructBlipVideoModel", "InstructBlipVideoForConditionalGeneration", ] diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py 
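
# --- Editor's note: illustrative sketch, not part of the patch -------------------------------
# Frame bookkeeping of the InstructBlipVideoModel.forward above: frames are folded into the
# batch for the vision encoder and Q-Former, prompts are repeated per frame, and the projected
# query outputs are unfolded back so each video occupies frames * num_query_tokens slots.
# All sizes below are toy values:
import torch

batch_size, frames, channel, height, width = 2, 4, 3, 8, 8
num_query_tokens, hidden = 5, 16

pixel_values = torch.randn(batch_size, frames, channel, height, width)
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)   # fold frames into batch

qformer_input_ids = torch.randint(0, 100, (batch_size, 7))
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)             # (batch*frames, 7)

language_model_inputs = torch.randn(batch_size * frames, num_query_tokens, hidden)  # per-frame query outputs
language_model_inputs = language_model_inputs.reshape(batch_size, num_query_tokens * frames, -1)
print(pixel_values.shape, qformer_input_ids.shape, language_model_inputs.shape)
# ---------------------------------------------------------------------------------------------
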
index c181976e0ec..6c59d06d2ee 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -45,7 +45,7 @@ from ...utils import ( replace_return_docstrings, torch_int, ) -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_internvl import InternVLConfig, InternVLVisionConfig @@ -608,14 +608,13 @@ INTERNVL_START_DOCSTRING = r""" @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + "The bare InternVL Model outputting raw hidden-states without any specific head on top.", INTERNVL_START_DOCSTRING, ) class InternVLPreTrainedModel(PreTrainedModel): config_class = InternVLConfig - base_model_prefix = "model" + base_model_prefix = "" supports_gradient_checkpointing = True - _no_split_modules = ["InternVLVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True @@ -654,15 +653,13 @@ class InternVLMultiModalProjector(nn.Module): @dataclass -class InternVLCausalLMOutputWithPast(ModelOutput): +class InternVLModelOutputWithPast(BaseModelOutputWithPast): """ - Base class for InternVL causal language model (or autoregressive) outputs. + Base class for InternVL outputs, with hidden states and attentions. Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -681,15 +678,10 @@ class InternVLCausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
""" - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - past_key_values: Optional[List[torch.FloatTensor]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None @@ -771,23 +763,18 @@ INTERNVL_INPUTS_DOCSTRING = r""" @add_start_docstrings( - """The INTERNVL model which consists of a vision backbone and a language model.""", + """The InternVL model which consists of a vision backbone and a language model, without a language modeling head.""", INTERNVL_START_DOCSTRING, ) -class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin): +class InternVLModel(InternVLPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + def __init__(self, config: InternVLConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = InternVLMultiModalProjector(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - + self.language_model = AutoModel.from_config(config.text_config) self.post_init() def get_input_embeddings(self): @@ -796,18 +783,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -853,12 +828,221 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) return vision_features + @add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + image_sizes: torch.Tensor = None, + **lm_kwargs, + ) -> Union[Tuple, InternVLModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + 
vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + image_sizes=image_sizes, + ) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_tokens = (input_ids == self.config.image_token_id).sum() + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + output = InternVLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + + def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5): + """Perform pixel shuffle downsampling on vision features. + + Args: + vision_features (`torch.Tensor`): + Input tensor of shape (batch_size, width, height, channels). + scale_factor (`float`, *optional*, defaults to `0.5`): + Factor by which to downsample. Default is 0.5, which halves the dimensions. + + Returns: + vision_features (`torch.Tensor`): + Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)). 
+ """ + batch_size, width, height, channels = vision_features.size() + + if height % scale_factor != 0 or width % scale_factor != 0: + raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.") + + # Reshape to allow downsampling + vision_features = vision_features.view( + batch_size, width, int(height * scale_factor), int(channels / scale_factor) + ) + # Permute dimensions to align downsampled axis correctly + vision_features = vision_features.permute(0, 2, 1, 3).contiguous() + + # Reshape to achieve final downsampled dimensions + vision_features = vision_features.view( + batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2)) + ) + + # Swap height and width back for proper orientation + vision_features = vision_features.permute(0, 2, 1, 3).contiguous() + + return vision_features + + +@dataclass +class InternVLCausalLMOutputWithPast(ModelOutput): + """ + Base class for InternVL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +@add_start_docstrings( + """The INTERNVL model which consists of a vision backbone and a language model.""", + INTERNVL_START_DOCSTRING, +) +class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: InternVLConfig): + super().__init__(config) + self.model = InternVLModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple @add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -872,7 +1056,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: Optional[torch.Tensor] = None, + image_sizes: torch.Tensor = None, **lm_kwargs, ) -> Union[Tuple, InternVLCausalLMOutputWithPast]: r""" @@ -925,7 +1109,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)) The images depict the Statue of Liberty and the Golden Gate Bridge. 
```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -940,73 +1123,32 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) else self.config.vision_feature_select_strategy ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - image_sizes=image_sizes, - ) - - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - outputs = self.language_model( + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, - logits_to_keep=logits_to_keep, + image_sizes=image_sizes, **lm_kwargs, ) - logits = outputs[0] + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return InternVLCausalLMOutputWithPast( loss=loss, @@ -1014,7 +1156,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, + image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( @@ -1030,7 +1172,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1047,45 +1189,66 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) return model_inputs - def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5): - """Perform pixel shuffle downsampling on vision features. + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: - vision_features (`torch.Tensor`): - Input tensor of shape (batch_size, width, height, channels). - scale_factor (`float`, *optional*, defaults to `0.5`): - Factor by which to downsample. Default is 0.5, which halves the dimensions. - - Returns: - vision_features (`torch.Tensor`): - Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)). + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
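
# --- Editor's note: illustrative sketch, not part of the patch -------------------------------
# A tiny end-to-end run of what the helper documented above builds: an additive 4D mask that is
# 0.0 on positions a query may attend to and the dtype minimum on future or padded positions.
# All sizes and the padding pattern are made up for the demonstration:
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length, batch_size = 3, 5, 1
cache_position = torch.arange(2, 2 + sequence_length)   # queries sit at cache positions 2..4
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])         # last key position is padding

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask == 0, min_dtype)
print(causal_mask[0, 0])   # 0.0 where visible, min_dtype on future and padded keys
# ---------------------------------------------------------------------------------------------
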
""" - batch_size, width, height, channels = vision_features.size() + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) - if height % scale_factor != 0 or width % scale_factor != 0: - raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.") - - # Reshape to allow downsampling - vision_features = vision_features.view( - batch_size, width, int(height * scale_factor), int(channels / scale_factor) - ) - # Permute dimensions to align downsampled axis correctly - vision_features = vision_features.permute(0, 2, 1, 3).contiguous() - - # Reshape to achieve final downsampled dimensions - vision_features = vision_features.view( - batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2)) - ) - - # Swap height and width back for proper orientation - vision_features = vision_features.permute(0, 2, 1, 3).contiguous() - - return vision_features + return causal_mask __all__ = [ "InternVLVisionPreTrainedModel", "InternVLVisionModel", "InternVLPreTrainedModel", + "InternVLModel", "InternVLForConditionalGeneration", ] diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index 00b516cf041..bc3c3bfc6da 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -39,7 +39,12 @@ from ...utils import ( from ..clip.modeling_clip import CLIPMLP from ..janus.modeling_janus import JanusVisionAttention from ..llama.modeling_llama import LlamaRMSNorm -from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel +from ..llava.modeling_llava import ( + LlavaCausalLMOutputWithPast, + LlavaForConditionalGeneration, + LlavaModel, + LlavaPreTrainedModel, +) from .configuration_internvl import InternVLConfig, InternVLVisionConfig @@ -573,11 +578,7 @@ class InternVLMultiModalProjector(nn.Module): return hidden_states -class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): - pass - - -class InternVLForConditionalGeneration(LlavaForConditionalGeneration): +class InternVLModel(LlavaModel): def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5): """Perform pixel shuffle downsampling on vision features. 
@@ -658,6 +659,13 @@ class InternVLForConditionalGeneration(LlavaForConditionalGeneration): return vision_features + +class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): + pass + + +class InternVLForConditionalGeneration(LlavaForConditionalGeneration): + @can_return_tuple @add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward(**super_kwargs): @@ -718,5 +726,6 @@ __all__ = [ "InternVLVisionPreTrainedModel", "InternVLVisionModel", "InternVLPreTrainedModel", + "InternVLModel", "InternVLForConditionalGeneration", ] diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index bc78b571d95..3273d595a5a 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -23,16 +23,17 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_outputs import ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torchdynamo_compiling, logging, replace_return_docstrings, ) -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_llava import LlavaConfig @@ -44,6 +45,39 @@ _CONFIG_FOR_DOC = "LlavaConfig" _CHECKPOINT_FOR_DOC = "llava-hf/llava-1.5-7b-hf" +@dataclass +class LlavaModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for Llava outputs, with hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
+ """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + @dataclass class LlavaCausalLMOutputWithPast(ModelOutput): """ @@ -72,7 +106,7 @@ class LlavaCausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ @@ -124,14 +158,13 @@ LLAVA_START_DOCSTRING = r""" @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + "The bare Llava Model outputting raw hidden-states without any specific head on top.", LLAVA_START_DOCSTRING, ) class LlavaPreTrainedModel(PreTrainedModel): config_class = LlavaConfig - base_model_prefix = "model" + base_model_prefix = "" supports_gradient_checkpointing = True - _no_split_modules = ["LlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True @@ -149,6 +182,9 @@ class LlavaPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() LLAVA_INPUTS_DOCSTRING = r""" @@ -229,23 +265,18 @@ LLAVA_INPUTS_DOCSTRING = r""" @add_start_docstrings( - """The LLAVA model which consists of a vision backbone and a language model.""", + """The Llava model which consists of a vision backbone and a language model, without a language modeling head.""", LLAVA_START_DOCSTRING, ) -class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): +class LlavaModel(LlavaPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + def __init__(self, config: LlavaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = LlavaMultiModalProjector(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - + self.language_model = AutoModel.from_config(config.text_config) self.post_init() def get_input_embeddings(self): @@ -254,18 +285,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -277,7 +296,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): Obtains image last hidden states from the vision tower and apply multimodal projection. 
Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. vision_feature_layer (`Union[int, List[int]]`): The index of the layer to select the vision feature. If multiple indices are provided, @@ -312,12 +331,146 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): image_features = self.multi_modal_projector(selected_image_feature) return image_features + @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + image_sizes: torch.Tensor = None, + **lm_kwargs, + ) -> Union[Tuple, LlavaModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + image_sizes=image_sizes, + ) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_tokens = (input_ids == self.config.image_token_id).sum() + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + 
attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + output = LlavaModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + + +@add_start_docstrings( + """The LLAVA model which consists of a vision backbone and a language model.""", + LLAVA_START_DOCSTRING, +) +class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: LlavaConfig): + super().__init__(config) + self.model = LlavaModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -331,7 +484,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: Optional[torch.Tensor] = None, + image_sizes: torch.Tensor = None, **lm_kwargs, ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: r""" @@ -371,7 +524,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "USER: \nWhat's the content of the image? 
ASSISTANT: The image features a busy city street with a stop sign prominently displayed" ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -386,73 +538,32 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): else self.config.vision_feature_select_strategy ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - image_sizes=image_sizes, - ) - - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - outputs = self.language_model( + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, - logits_to_keep=logits_to_keep, + image_sizes=image_sizes, **lm_kwargs, ) - logits = outputs[0] + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return LlavaCausalLMOutputWithPast( loss=loss, @@ -460,7 +571,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, + image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( @@ -476,7 +587,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -493,5 +604,61 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): return model_inputs + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. -__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel"] + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel", "LlavaModel"] diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index fa7c04bdf40..c17eb9622c8 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -26,16 +26,17 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution -from ...modeling_outputs import ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torchdynamo_compiling, logging, replace_return_docstrings, ) -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_llava_next import LlavaNextConfig @@ -151,6 +152,39 @@ def unpad_image(tensor, original_size): return unpadded_tensor +@dataclass +class LlavaNextModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for Llava outputs, with hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + @dataclass class LlavaNextCausalLMOutputWithPast(ModelOutput): """ @@ -237,9 +271,9 @@ LLAVA_NEXT_START_DOCSTRING = r""" ) class LlavaNextPreTrainedModel(PreTrainedModel): config_class = LlavaNextConfig - base_model_prefix = "model" + base_model_prefix = "" supports_gradient_checkpointing = True - _no_split_modules = ["LlavaNextVisionAttention"] + _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True @@ -254,7 +288,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, LlavaNextForConditionalGeneration): + elif isinstance(module, LlavaNextModel): embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) module.image_newline.data.normal_(mean=0.0, std=embed_std) @@ -340,10 +374,12 @@ LLAVA_NEXT_INPUTS_DOCSTRING = r""" @add_start_docstrings( - """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", + """The Llava-Next model which consists of a vision backbone and a language model without language modeling head.""", LLAVA_NEXT_START_DOCSTRING, ) -class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin): +class LlavaNextModel(LlavaNextPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + def __init__(self, config: LlavaNextConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) @@ -353,48 +389,16 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - + self.language_model = AutoModel.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides self.post_init() - @property - def padding_side(self): - return self._padding_side - - @padding_side.setter - def padding_side(self, padding_side: str): - if padding_side not in ["left", "right"]: - raise ValueError(f"{padding_side} is not `left` or `right`.") - self._padding_side = padding_side - - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings def get_input_embeddings(self): return self.language_model.get_input_embeddings() - # Copied from 
transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder - def get_decoder(self): - return self.language_model.get_decoder() - def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): """ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. @@ -524,12 +528,152 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi image_features = torch.split(image_features, image_num_patches, dim=0) return image_features + @add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **lm_kwargs, + ) -> Union[Tuple, LlavaNextModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None and pixel_values.size(0) > 0: + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + # NOTE we only support 
multimodal_patch_merge_type == "spatial_unpad" + image_features, feature_lens = self.pack_image_features( + image_features, + image_sizes, + vision_feature_select_strategy=vision_feature_select_strategy, + image_newline=self.image_newline, + ) + + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + output = LlavaNextModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + + +@add_start_docstrings( + """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", + LLAVA_NEXT_START_DOCSTRING, +) +class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^image_newline": "model.image_newline", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: LlavaNextConfig): + super().__init__(config) + self.model = LlavaNextModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple @add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, image_sizes: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, 
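The `_checkpoint_conversion_mapping` added to the new `ForConditionalGeneration` classes describes how old checkpoint keys, where the full causal LM lived under `language_model`, map onto the wrapped layout with `model.language_model` and a top-level `lm_head`. The sketch below only illustrates what those regex prefixes mean; the actual renaming is handled inside `from_pretrained`, not by user code.

```python
import re

# Copied from the LlavaNextForConditionalGeneration hunk above.
_checkpoint_conversion_mapping = {
    "^language_model.model": "model.language_model",
    "^vision_tower": "model.vision_tower",
    "^multi_modal_projector": "model.multi_modal_projector",
    "^image_newline": "model.image_newline",
    "^language_model.lm_head": "lm_head",
}


def rename_legacy_key(key: str) -> str:
    """Rewrite one old-style state-dict key into the new wrapped layout (illustration only)."""
    for pattern, replacement in _checkpoint_conversion_mapping.items():
        key = re.sub(pattern, replacement, key)
    return key


print(rename_legacy_key("language_model.model.layers.0.self_attn.q_proj.weight"))
# model.language_model.layers.0.self_attn.q_proj.weight
print(rename_legacy_key("language_model.lm_head.weight"))
# lm_head.weight
```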
@@ -597,45 +741,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi else self.config.vision_feature_select_strategy ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None and pixel_values.size(0) > 0: - image_features = self.get_image_features( - pixel_values, - image_sizes, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - ) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - vision_feature_select_strategy=vision_feature_select_strategy, - image_newline=self.image_newline, - ) - - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() - n_image_features = image_features.shape[0] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - outputs = self.language_model( + outputs = self.model( + input_ids, + pixel_values=pixel_values, + image_sizes=image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -645,33 +756,17 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - logits_to_keep=logits_to_keep, **lm_kwargs, ) - logits = outputs[0] + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return LlavaNextCausalLMOutputWithPast( loss=loss, @@ -679,7 +774,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, + image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( @@ -696,7 +791,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -714,5 +809,61 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi return model_inputs + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. -__all__ = ["LlavaNextForConditionalGeneration", "LlavaNextPreTrainedModel"] + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["LlavaNextForConditionalGeneration", "LlavaNextPreTrainedModel", "LlavaNextModel"] diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 113441fd2aa..6b6f14bc342 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -30,24 +30,65 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution -from ...modeling_outputs import ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torchdynamo_compiling, logging, replace_return_docstrings, ) -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_llava_next_video import LlavaNextVideoConfig logger = logging.get_logger(__name__) + _CONFIG_FOR_DOC = "LlavaNextVideoConfig" +@dataclass +class LlavaNextVideoModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for Llava outputs, with hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. + video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + video_hidden_states: Optional[torch.FloatTensor] = None + + @dataclass class LlavaNextVideoCausalLMOutputWithPast(ModelOutput): """ @@ -167,6 +208,34 @@ LLAVA_NEXT_VIDEO_START_DOCSTRING = r""" """ +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAVA_NEXT_VIDEO_START_DOCSTRING, +) +class LlavaNextVideoPreTrainedModel(PreTrainedModel): + config_class = LlavaNextVideoConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, LlavaNextVideoModel): + embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) + module.image_newline.data.normal_(mean=0.0, std=embed_std) + + def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ Calculate the shape of the image patch grid after the preprocessing for images of any resolution. 
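Because the head classes in this patch no longer delegate logits computation to a causal LM, their `forward` methods take the base model's last hidden state and apply `lm_head` only to the positions selected by `logits_to_keep`, as in the Llava and LlavaNext hunks above and the LlavaNextVideo hunk below. A small self-contained sketch of that slicing, with made-up sizes:

```python
import torch
from torch import nn

# Toy tensors; all sizes are arbitrary.
batch, seq_len, hidden_size, vocab_size = 2, 6, 4, 10
hidden_states = torch.randn(batch, seq_len, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# During generation only the last position is needed, so logits_to_keep=1
# avoids materializing logits for the whole prompt.
logits_to_keep = 1
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(lm_head(hidden_states[:, slice_indices, :]).shape)  # torch.Size([2, 1, 10])

# The default logits_to_keep=0 yields slice(0, None), i.e. logits for every position.
print(lm_head(hidden_states[:, slice(-0, None), :]).shape)  # torch.Size([2, 6, 10])
```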
@@ -355,38 +424,12 @@ LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r""" @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + """The Llava-Next model which consists of a vision backbone and a language model without language modeling head.""", LLAVA_NEXT_VIDEO_START_DOCSTRING, ) -class LlavaNextVideoPreTrainedModel(PreTrainedModel): - config_class = LlavaNextVideoConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlavaNextVideoVisionAttention"] - _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_quantized_cache = True - _supports_static_cache = True +class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} - def _init_weights(self, module): - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, LlavaNextVideoForConditionalGeneration): - embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) - module.image_newline.data.normal_(mean=0.0, std=embed_std) - - -@add_start_docstrings( - """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", - LLAVA_NEXT_VIDEO_START_DOCSTRING, -) -class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin): def __init__( self, config: LlavaNextVideoConfig, @@ -399,43 +442,17 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - + self.language_model = AutoModel.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides self.vision_resampler = LlavaNextVideoPooler(config) self.post_init() - @property - def padding_side(self): - return self._padding_side - - @padding_side.setter - def padding_side(self, padding_side: str): - if padding_side not in ["left", "right"]: - raise ValueError(f"{padding_side} is not `left` or `right`.") - self._padding_side = padding_side - def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): """ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. 
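The base-model `forward` methods in this patch (Llava and LlavaNext above, LlavaNextVideo in the hunk that follows) splice the projected visual features into the text embeddings at the placeholder-token positions with `masked_scatter`. A toy illustration of that mechanism; the token id, shapes, and zero/one values are invented so the effect is visible:

```python
import torch

image_token_id = 32000                               # placeholder id, made up for the example
input_ids = torch.tensor([[1, 32000, 32000, 5, 6]])  # two image placeholder positions
inputs_embeds = torch.zeros(1, 5, 8)                 # text embeddings (zeros so the scatter is visible)
image_features = torch.ones(2, 8)                    # two projected image-patch embeddings

# Broadcast the token-level mask over the hidden dimension, then overwrite the
# placeholder rows with the image features, in order.
special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0., 0.])
```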
@@ -564,13 +581,222 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene image_features = torch.split(image_features, image_num_patches, dim=0) return image_features + @add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + pixel_values_videos: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **lm_kwargs, + ) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self.vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + self.vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: + raise ValueError( + "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, " + "and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None and pixel_values.size(0) > 0: + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=self.vision_feature_layer, + vision_feature_select_strategy=self.vision_feature_select_strategy, + ) + image_features, feature_lens = self.pack_image_features( + image_features, + image_sizes, + self.vision_feature_select_strategy, + image_newline=self.image_newline, + ) + + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: + video_features = self.get_video_features( + pixel_values_videos, + vision_feature_layer=self.vision_feature_layer, + 
vision_feature_select_strategy=self.vision_feature_select_strategy, + ) + video_features = [feature.flatten(0, 1) for feature in video_features] + video_feature_lens = [feature.size(0) for feature in video_features] + video_features = torch.cat(video_features, dim=0) + video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) + + special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_features.shape[0] + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + output = LlavaNextVideoModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + video_hidden_states=video_features if pixel_values_videos is not None else None, + ) + return output if return_dict else output.to_tuple() + + def get_video_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Union[int, List[int]], + vision_feature_select_strategy: str, + ): + """ + Obtains video last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) + The tensors corresponding to the input video. + vision_feature_layer (`Union[int, List[int]]`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches + and are of shape `(num_videos, video_length, embed_dim)`). 
+ """ + batch_size, frames, channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width) + video_features = self.vision_tower(pixel_values, output_hidden_states=True) + + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_video_features = video_features.hidden_states[vision_feature_layer] + else: + hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + selected_video_features = torch.cat(hs_pool, dim=-1) + + if vision_feature_select_strategy == "default": + selected_video_features = selected_video_features[:, 1:] + elif vision_feature_select_strategy == "full": + selected_video_features = selected_video_features + + # Same as image features except that video has pooling layer + video_features = self.vision_resampler(selected_video_features) + video_features = self.multi_modal_projector(video_features) + video_features = torch.split(video_features, frames, dim=0) + return video_features + + +@add_start_docstrings( + """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", + LLAVA_NEXT_VIDEO_START_DOCSTRING, +) +class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^image_newline": "model.image_newline", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: LlavaNextVideoConfig): + super().__init__(config) + self.model = LlavaNextVideoModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple @add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + pixel_values_videos: torch.FloatTensor = None, image_sizes: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -662,117 +888,47 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "USER: \nWhat's the content of the image? 
ASSISTANT: The image shows a red stop sign on a pole, with a traditional Chinese archway (...)" ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - self.vision_feature_layer = ( + vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) - self.vision_feature_select_strategy = ( + vision_feature_select_strategy = ( vision_feature_select_strategy if vision_feature_select_strategy is not None else self.config.vision_feature_select_strategy ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: - raise ValueError( - "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, " - "and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None and pixel_values.size(0) > 0: - image_features = self.get_image_features( - pixel_values, - image_sizes, - vision_feature_layer=self.vision_feature_layer, - vision_feature_select_strategy=self.vision_feature_select_strategy, - ) - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - self.vision_feature_select_strategy, - image_newline=self.image_newline, - ) - - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() - n_image_features = image_features.shape[0] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: - video_features = self.get_video_features( - pixel_values_videos, - vision_feature_layer=self.vision_feature_layer, - vision_feature_select_strategy=self.vision_feature_select_strategy, - ) - video_features = [feature.flatten(0, 1) for feature in video_features] - video_feature_lens = [feature.size(0) for feature in video_features] - video_features = torch.cat(video_features, dim=0) - video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - - special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_features.shape[0] - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_features = video_features.to(inputs_embeds.device, 
inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) - - outputs = self.language_model( + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - logits_to_keep=logits_to_keep, + image_sizes=image_sizes, **lm_kwargs, ) - logits = outputs[0] + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return LlavaNextVideoCausalLMOutputWithPast( loss=loss, @@ -780,8 +936,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - video_hidden_states=video_features if pixel_values_videos is not None else None, + image_hidden_states=outputs.image_hidden_states, + video_hidden_states=outputs.video_hidden_states, ) def prepare_inputs_for_generation( @@ -799,7 +955,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene ): # Overwritten -- extra custom processing - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -818,51 +974,60 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene return model_inputs - def get_video_features( - self, - pixel_values: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, ): """ - Obtains video last hidden states 
from the vision tower and apply multimodal projection. + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) - The tensors corresponding to the input video. - vision_feature_layer (`Union[int, List[int]]`): - The index of the layer to select the vision feature. If multiple indices are provided, - the vision feature of the corresponding indices will be concatenated to form the - vision features. - vision_feature_select_strategy (`str`): - The feature selection strategy used to select the vision feature from the vision backbone. - Can be one of `"default"` or `"full"` - Returns: - video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches - and are of shape `(num_videos, video_length, embed_dim)`). + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. """ - batch_size, frames, channels, height, width = pixel_values.shape - pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width) - video_features = self.vision_tower(pixel_values, output_hidden_states=True) - - # If we have one vision feature layer, return the corresponding hidden states, - # otherwise, select the hidden states of each feature layer and concatenate them - if isinstance(vision_feature_layer, int): - selected_video_features = video_features.hidden_states[vision_feature_layer] + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
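As a toy illustration (standalone, made-up shapes, not tied to any checkpoint), the additive mask built in the non-4D branch below keeps `0` at positions a query may attend to and the most negative value of `dtype` everywhere else:

```python
import torch

# Toy case: 2 new tokens being decoded at absolute positions 2 and 3,
# with a static-cache target length of 4.
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 2, 4
cache_position = torch.tensor([2, 3])

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)  # block columns after the row index
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)  # keep cached positions visible
causal_mask = causal_mask[None, None, :, :]  # (1, 1, 2, 4): 0 = attend, min_dtype = blocked
```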
+ causal_mask = attention_mask else: - hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer] - selected_video_features = torch.cat(hs_pool, dim=-1) + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) - if vision_feature_select_strategy == "default": - selected_video_features = selected_video_features[:, 1:] - elif vision_feature_select_strategy == "full": - selected_video_features = selected_video_features - - # Same as image features except that video has pooling layer - video_features = self.vision_resampler(selected_video_features) - video_features = self.multi_modal_projector(video_features) - video_features = torch.split(video_features, frames, dim=0) - return video_features + return causal_mask -__all__ = ["LlavaNextVideoForConditionalGeneration", "LlavaNextVideoPreTrainedModel"] +__all__ = ["LlavaNextVideoForConditionalGeneration", "LlavaNextVideoModel", "LlavaNextVideoPreTrainedModel"] diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index b0b744c5b32..985a69a68ec 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -24,15 +24,19 @@ from torch import nn from transformers.models.llava_next.modeling_llava_next import ( LlavaNextCausalLMOutputWithPast, LlavaNextForConditionalGeneration, + LlavaNextModel, + LlavaNextModelOutputWithPast, LlavaNextMultiModalProjector, - LlavaNextPreTrainedModel, image_size_to_num_patches, ) from ...configuration_utils import PretrainedConfig from ...utils import ( + add_start_docstrings_to_model_forward, + can_return_tuple, is_torchdynamo_compiling, logging, + replace_return_docstrings, ) from ..auto import CONFIG_MAPPING, AutoConfig @@ -40,6 +44,9 @@ from ..auto import CONFIG_MAPPING, AutoConfig logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "LlavaNextVideoConfig" + + class LlavaNextVideoConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`LlavaNextVideoForConditionalGeneration`]. It is used to instantiate an @@ -182,6 +189,17 @@ class LlavaNextVideoConfig(PretrainedConfig): super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) +@dataclass +class LlavaNextVideoModelOutputWithPast(LlavaNextModelOutputWithPast): + """ + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. + video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
+ """ + + video_hidden_states: Optional[torch.FloatTensor] = None + + @dataclass class LlavaNextVideoCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast): """ @@ -231,20 +249,7 @@ class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector): pass -class LlavaNextVideoPreTrainedModel(LlavaNextPreTrainedModel): - def _init_weights(self, module): - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, LlavaNextVideoForConditionalGeneration): - embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) - module.image_newline.data.normal_(mean=0.0, std=embed_std) - - -class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): +class LlavaNextVideoModel(LlavaNextModel): def __init__(self, config: LlavaNextVideoConfig, **super_kwargs): super().__init__(config, **super_kwargs) self.vision_resampler = LlavaNextVideoPooler(config) @@ -358,9 +363,209 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + pixel_values_videos: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **lm_kwargs, + ) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self.vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + self.vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: + raise ValueError( + "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, " + "and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None and pixel_values.size(0) > 0: + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=self.vision_feature_layer, + 
vision_feature_select_strategy=self.vision_feature_select_strategy, + ) + image_features, feature_lens = self.pack_image_features( + image_features, + image_sizes, + self.vision_feature_select_strategy, + image_newline=self.image_newline, + ) + + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: + video_features = self.get_video_features( + pixel_values_videos, + vision_feature_layer=self.vision_feature_layer, + vision_feature_select_strategy=self.vision_feature_select_strategy, + ) + video_features = [feature.flatten(0, 1) for feature in video_features] + video_feature_lens = [feature.size(0) for feature in video_features] + video_features = torch.cat(video_features, dim=0) + video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) + + special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_features.shape[0] + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + output = LlavaNextVideoModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + video_hidden_states=video_features if pixel_values_videos is not None else None, + ) + return output if return_dict else output.to_tuple() + + +LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. 
Pixel values can be obtained using + [`AutoImageProcessor`]. See [`LlavaNextVideoImageProcessor.__call__`] for details. [`LlavaProcessor`] uses + [`LlavaNextVideoImageProcessor`] for processing images. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*): + The sizes of the images in the batch, being (height, width) for each image. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. + If `"full"`, the full vision features are used. 
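For intuition, a minimal standalone sketch (toy sizes, hypothetical placeholder id) of how the projected image features replace the image placeholder tokens via `masked_scatter`, as done in the forward pass above:

```python
import torch

image_token_id = 32000                 # hypothetical placeholder id
input_ids = torch.tensor([[1, 32000, 32000, 5]])
inputs_embeds = torch.zeros(1, 4, 8)   # text embeddings, hidden_size = 8
image_features = torch.randn(2, 8)     # one projected vector per placeholder token

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# positions 1 and 2 of the sequence now hold the image features; positions 0 and 3 are untouched
```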
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): + @can_return_tuple + @add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + pixel_values_videos: torch.FloatTensor = None, image_sizes: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -452,117 +657,47 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "USER: \nWhat's the content of the image? 
ASSISTANT: The image shows a red stop sign on a pole, with a traditional Chinese archway (...)" ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - self.vision_feature_layer = ( + vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) - self.vision_feature_select_strategy = ( + vision_feature_select_strategy = ( vision_feature_select_strategy if vision_feature_select_strategy is not None else self.config.vision_feature_select_strategy ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: - raise ValueError( - "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, " - "and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None and pixel_values.size(0) > 0: - image_features = self.get_image_features( - pixel_values, - image_sizes, - vision_feature_layer=self.vision_feature_layer, - vision_feature_select_strategy=self.vision_feature_select_strategy, - ) - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - self.vision_feature_select_strategy, - image_newline=self.image_newline, - ) - - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() - n_image_features = image_features.shape[0] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: - video_features = self.get_video_features( - pixel_values_videos, - vision_feature_layer=self.vision_feature_layer, - vision_feature_select_strategy=self.vision_feature_select_strategy, - ) - video_features = [feature.flatten(0, 1) for feature in video_features] - video_feature_lens = [feature.size(0) for feature in video_features] - video_features = torch.cat(video_features, dim=0) - video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - - special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_features.shape[0] - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_features = video_features.to(inputs_embeds.device, 
inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) - - outputs = self.language_model( + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - logits_to_keep=logits_to_keep, + image_sizes=image_sizes, **lm_kwargs, ) - logits = outputs[0] + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return LlavaNextVideoCausalLMOutputWithPast( loss=loss, @@ -570,8 +705,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - video_hidden_states=video_features if pixel_values_videos is not None else None, + image_hidden_states=outputs.image_hidden_states, + video_hidden_states=outputs.video_hidden_states, ) def prepare_inputs_for_generation( @@ -589,7 +724,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): ): # Overwritten -- extra custom processing - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -609,4 +744,9 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): return model_inputs -__all__ = ["LlavaNextVideoConfig", "LlavaNextVideoForConditionalGeneration", "LlavaNextVideoPreTrainedModel"] +__all__ = [ + "LlavaNextVideoConfig", + "LlavaNextVideoForConditionalGeneration", + "LlavaNextVideoModel", + "LlavaNextVideoPreTrainedModel", # noqa: F822 +] diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py 
b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py index f5e2da2cd9e..df49004eb24 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py @@ -4,9 +4,25 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_llava_onevision.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import List, Optional, Union +import torch + from ...image_processing_utils import BatchFeature, get_patch_output_size, select_best_resolution from ...image_processing_utils_fast import ( BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, @@ -28,11 +44,9 @@ from ...image_utils import ( make_flat_list_of_images, ) from ...processing_utils import Unpack -from ...utils import TensorType, add_start_docstrings, is_torch_available, is_torchvision_v2_available +from ...utils import TensorType, add_start_docstrings, is_torchvision_v2_available -if is_torch_available(): - import torch if is_torchvision_v2_available(): from torchvision.transforms.v2 import functional as F else: diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index be67df9b3af..5ef23387e9f 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/llava_onevision/modular_llava_onevision.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_llava_onevision.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 the HuggingFace Inc. team. All rights reserved. # @@ -12,7 +18,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""PyTorch Llava-Onevision model.""" import math from dataclasses import dataclass @@ -20,140 +25,66 @@ from typing import List, Optional, Tuple, Union import numpy as np import torch -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution -from ...modeling_outputs import ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, + can_return_tuple, is_torchdynamo_compiling, logging, ) -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_llava_onevision import LlavaOnevisionConfig logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "LlavaNextConfig" - -# Copied from transformers.models.llava_next.modeling_llava_next.get_anyres_image_grid_shape -def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): +@dataclass +class LlavaOnevisionModelOutputWithPast(BaseModelOutputWithPast): """ - Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + Base class for Llava outputs, with hidden states and attentions. Args: - image_size (`tuple`): - The size of the input image in the format (width, height). - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - Returns: - tuple: The shape of the image patch grid in the format (width, height). + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
+ + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. + video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ - if not isinstance(grid_pinpoints, list): - raise TypeError("grid_pinpoints should be a list of tuples or lists") - # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate - if not isinstance(image_size, (list, tuple)): - if not isinstance(image_size, (torch.Tensor, np.ndarray)): - raise TypeError( - f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor" - ) - image_size = image_size.tolist() + image_hidden_states: Optional[torch.FloatTensor] = None - height, width = select_best_resolution(image_size, grid_pinpoints) - return height // patch_size, width // patch_size - - -# Copied from transformers.models.llava_next.modeling_llava_next.image_size_to_num_patches -def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): - """ - Calculate the number of patches after the preprocessing for images of any resolution. - - Args: - image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`): - The size of the input image in the format (height, width). ? - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. - - Returns: - int: the number of patches - """ - if not isinstance(grid_pinpoints, list): - raise TypeError("grid_pinpoints should be a list of tuples or lists") - - # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate - if not isinstance(image_size, (list, tuple)): - if not isinstance(image_size, (torch.Tensor, np.ndarray)): - raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}") - image_size = image_size.tolist() - - best_resolution = select_best_resolution(image_size, grid_pinpoints) - height, width = best_resolution - num_patches = 0 - # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 - for i in range(0, height, patch_size): - for j in range(0, width, patch_size): - num_patches += 1 - # add the base patch - num_patches += 1 - return num_patches - - -# Copied from transformers.models.llava_next.modeling_llava_next.unpad_image -def unpad_image(tensor, original_size): - """ - Unpads a PyTorch tensor of a padded and resized image. - - Args: - tensor (`torch.Tensor`): - The image tensor, assumed to be of shape (num_channels, height, width). - original_size (`tuple`): - The original size of the image (height, width). - - Returns: - `torch.Tensor`: The unpadded image tensor. 
- """ - if not isinstance(original_size, (list, tuple)): - if not isinstance(original_size, (torch.Tensor, np.ndarray)): - raise TypeError( - f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor" - ) - original_size = original_size.tolist() - original_height, original_width = original_size - current_height, current_width = tensor.shape[1:] - - original_aspect_ratio = original_width / original_height - current_aspect_ratio = current_width / current_height - - if original_aspect_ratio > current_aspect_ratio: - scale_factor = current_width / original_width - new_height = min(math.ceil(original_height * scale_factor), current_height) - padding, r = divmod(current_height - new_height, 2) - unpadded_tensor = tensor[:, padding : current_height - (padding + r), :] - else: - scale_factor = current_height / original_height - new_width = min(math.ceil(original_width * scale_factor), current_width) - padding, r = divmod(current_width - new_width, 2) - unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)] - - return unpadded_tensor + video_hidden_states: Optional[torch.FloatTensor] = None @dataclass -# Copied from transformers.models.llava_next_video.modeling_llava_next_video.LlavaNextVideoCausalLMOutputWithPast with LlavaNextVideo->LlavaOnevision class LlavaOnevisionCausalLMOutputWithPast(ModelOutput): """ Base class for LlavaOnevision causal language model (or autoregressive) outputs. @@ -195,10 +126,44 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput): hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None + video_hidden_states: Optional[torch.FloatTensor] = None -# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaOnevision +class LlavaOnevisionPooler(nn.Module): + def __init__(self, config): + super().__init__() + + mode = config.spatial_pool_mode + stride = config.spatial_pool_stride + out_channels = getattr(config, "spatial_pool_out_channels", config.vision_config.hidden_size) + self.image_size = (config.vision_config.image_size // config.vision_config.patch_size) ** 2 + + if mode == "average": + self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride) + elif mode == "max": + self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + elif mode == "conv": + self.pool = nn.Conv2d( + in_channels=config.vision_config.hidden_size, + out_channels=out_channels, + kernel_size=stride, + stride=stride, + ) + else: + raise ValueError(f"Unknown pooling mode: {mode}. Has to be one of [`average`, `max`, `conv`]") + + def forward(self, image_features): + ori_width = int(math.sqrt(image_features.shape[1] * self.image_size // self.image_size)) + ori_height = int(ori_width * self.image_size // self.image_size) + + batch_size, _, dim = image_features.shape + image_features_spatial = image_features.view(batch_size, ori_height, ori_height, dim).permute(0, 3, 1, 2) + image_features_spatial_pool = self.pool(image_features_spatial) + + return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() + + class LlavaOnevisionMultiModalProjector(nn.Module): def __init__(self, config: LlavaOnevisionConfig): super().__init__() @@ -231,40 +196,118 @@ LLAVA_ONEVISION_START_DOCSTRING = r""" and behavior. 
Parameters: - config ([`LlavaNextConfig`] or [`LlavaNextVisionConfig`]): + config ([`LlavaOnevisionConfig`] or [`LlavaOnevisionVisionConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ -@add_start_docstrings( - "The bare LLaVA-Onevision Model outputting raw hidden-states without any specific head on top.", - LLAVA_ONEVISION_START_DOCSTRING, -) -class LlavaOnevisionPreTrainedModel(PreTrainedModel): - config_class = LlavaOnevisionConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlavaOnevisionVisionAttention"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_cache_class = True - _supports_static_cache = True - _supports_quantized_cache = True - _supports_sdpa = True +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. - # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextPreTrainedModel._init_weights with LlavaNext->LlavaOnevision - def _init_weights(self, module): - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + Args: + image_size (`tuple`): + The size of the input image in the format (width, height). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, LlavaOnevisionForConditionalGeneration): - embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) - module.image_newline.data.normal_(mean=0.0, std=embed_std) + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor" + ) + image_size = image_size.tolist() + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): + """ + Calculate the number of patches after the preprocessing for images of any resolution. + + Args: + image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`): + The size of the input image in the format (height, width). ? + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + int: the number of patches + """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! 
VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}") + image_size = image_size.tolist() + + best_resolution = select_best_resolution(image_size, grid_pinpoints) + height, width = best_resolution + num_patches = 0 + # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + num_patches += 1 + # add the base patch + num_patches += 1 + return num_patches + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). + original_size (`tuple`): + The original size of the image (height, width). + + Returns: + `torch.Tensor`: The unpadded image tensor. + """ + if not isinstance(original_size, (list, tuple)): + if not isinstance(original_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor" + ) + original_size = original_size.tolist() + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = min(math.ceil(original_height * scale_factor), current_height) + padding, r = divmod(current_height - new_height, 2) + unpadded_tensor = tensor[:, padding : current_height - (padding + r), :] + else: + scale_factor = current_height / original_height + new_width = min(math.ceil(original_width * scale_factor), current_width) + padding, r = divmod(current_width - new_width, 2) + unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)] + + return unpadded_tensor LLAVA_ONEVISION_INPUTS_DOCSTRING = r""" @@ -279,16 +322,10 @@ LLAVA_ONEVISION_INPUTS_DOCSTRING = r""" [What are input IDs?](../glossary#input-ids) pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`LlavaNextImageProcessor.__call__`] for details. [`LlavaProcessor`] uses - [`LlavaNextImageProcessor`] for processing images. + [`AutoImageProcessor`]. See [`LlavaOnevisionImageProcessor.__call__`] for details. [`LlavaProcessor`] uses + [`LlavaOnevisionImageProcessor`] for processing images. image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*): The sizes of the images in the batch, being (height, width) for each image. - pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, frames, num_channels, image_size, image_size)): - The tensors corresponding to the input videos. Pixel values can be obtained using - [`LlavaNextVideoProcessor`]. See [`LlavaNextVideoProcessor.__call__`] for details. [`LlavaProcessor`] uses - [`LlavaNextVideoProcessor`] for processing videos. - image_sizes_videos (`torch.LongTensor` of shape `(batch_size, frames, 2)`, *optional*): - The sizes of the videos in the batch, being (height, width) for each frame in the video. 
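A worked example for `image_size_to_num_patches` above, with hypothetical `grid_pinpoints` (it assumes the helper and `select_best_resolution` are in scope):

```python
grid_pinpoints = [[384, 384], [384, 768], [768, 384], [768, 768]]  # hypothetical anyres resolutions
patch_size = 384

# select_best_resolution picks (768, 768) for a 500x640 (height, width) image,
# which tiles into a 2 x 2 grid of 384px patches, plus one base (downscaled) patch.
num_patches = image_size_to_num_patches((500, 640), grid_pinpoints, patch_size)
print(num_patches)  # 5
```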
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: @@ -335,8 +372,6 @@ LLAVA_ONEVISION_INPUTS_DOCSTRING = r""" The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. If `"full"`, the full vision features are used. - vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`): - Aspect ratio used when processong image features. The default value is "anyres_max_9". use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -356,11 +391,41 @@ LLAVA_ONEVISION_INPUTS_DOCSTRING = r""" @add_start_docstrings( - """The LLaVA-Onevision model which consists of a vision backbone and a language model.""", + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAVA_ONEVISION_START_DOCSTRING, ) -class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin): - def __init__(self, config: LlavaOnevisionConfig): +class LlavaOnevisionPreTrainedModel(PreTrainedModel): + config_class = LlavaOnevisionConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, LlavaOnevisionModel): + embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) + module.image_newline.data.normal_(mean=0.0, std=embed_std) + + +@add_start_docstrings( + """The Llava-Next model which consists of a vision backbone and a language model without language modeling head.""", + LLAVA_ONEVISION_START_DOCSTRING, +) +class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + + def __init__(self, config): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) @@ -369,36 +434,16 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - + self.language_model = AutoModel.from_config(config.text_config) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() - # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings def get_input_embeddings(self): return self.language_model.get_input_embeddings() - # Copied from 
transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder - def get_decoder(self): - return self.language_model.get_decoder() - def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"): """ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. @@ -465,20 +510,6 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device) return image_features, feature_lens - def apply_pooling(self, image_features): - height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size - batch_frames, seq_len, dim = image_features.shape - image_features = image_features.view(batch_frames, height, width, -1) - image_features = image_features.permute(0, 3, 1, 2).contiguous() - - height, width = image_features.shape[2:] - scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)] - image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear") - - image_features = image_features.permute(0, 2, 3, 1) - image_features = image_features.view(batch_frames, -1, dim) - return image_features - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -494,7 +525,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene The tensors corresponding to the input images. image_sizes (`torch.Tensor` of shape `(num_images, 2)`) Actual image size of each images (H, W). - vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`): + vision_feature_layer (`Union[int, List[int]]`): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. @@ -539,59 +570,13 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene image_features = torch.split(image_features, image_num_patches, dim=0) return image_features - def get_video_features( - self, - pixel_values: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, - ): - """ - Obtains video last hidden states from the vision tower, apply multimodal projection and pooling. - - Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) - The tensors corresponding to the input video. - vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`): - The index of the layer to select the vision feature. 
If multiple indices are provided, - the vision feature of the corresponding indices will be concatenated to form the - vision features. - vision_feature_select_strategy (`str`): - The feature selection strategy used to select the vision feature from the vision backbone. - Can be one of `"default"` or `"full"` - Returns: - video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches - and are of shape `(num_videos, video_length, embed_dim)`). - """ - batch_size, frames, channels, height, width = pixel_values.shape - pixel_values = pixel_values.view(batch_size * frames, channels, height, width) - video_features = self.vision_tower(pixel_values, output_hidden_states=True) - - # If we have one vision feature layer, return the corresponding hidden states, - # otherwise, select the hidden states of each feature layer and concatenate them - if isinstance(vision_feature_layer, int): - selected_video_feature = video_features.hidden_states[vision_feature_layer] - else: - hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer] - selected_video_feature = torch.cat(hs_pool, dim=-1) - - if vision_feature_select_strategy == "default": - selected_video_feature = selected_video_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_video_feature = selected_video_feature - video_features = self.multi_modal_projector(selected_video_feature) - - video_features = self.apply_pooling(video_features) - video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1) - - return video_features - @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, image_sizes: Optional[torch.LongTensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, + pixel_values_videos: torch.FloatTensor = None, image_sizes_videos: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -600,62 +585,13 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene vision_feature_layer: Optional[Union[int, List[int]]] = None, vision_feature_select_strategy: Optional[str] = None, vision_aspect_ratio: Optional[str] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, **lm_kwargs, - ) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). 
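The replacement pattern (see the new `forward` bodies above) only runs the LM head on the positions that are actually needed; a minimal sketch with assumed `hidden_states` and `lm_head`:

```python
import torch
import torch.nn as nn

hidden_states = torch.randn(2, 10, 64)      # assumed decoder output: (batch, seq, hidden)
lm_head = nn.Linear(64, 32000, bias=False)  # assumed head: hidden -> vocab

logits_to_keep = 1                          # during generation only the last position is needed
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])  # (2, 1, 32000); logits_to_keep=0 keeps all positions
```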
Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - - Returns: - [`~LlavaOnevisionCausalLMOutputWithPast`] (if `return_dict=True`) or a `tuple`. - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> import torch - >>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration - - >>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype="float16", device_map="cuda:0") - >>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf") - - >>> conversation = [ - ... { - ... "role": "user", - ... "content": [ - ... {"type": "text", "text": "What is shown in this image?"}, - ... {"type": "image"}, - ... ], - ... }, - ... ] - >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) - - >>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> raw_image = Image.open(requests.get(image_file, stream=True).raw) - >>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16) - - >>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False) - >>> processor.batch_decode(output, skip_special_tokens=True)[0] - "user\n\nWhat is shown in this image?\nassistant\ncat" - ```""" + ) -> Union[Tuple, LlavaOnevisionModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -743,35 +679,247 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + output = LlavaOnevisionModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + video_hidden_states=video_features if pixel_values_videos is not None else None, + ) + + return output if return_dict else output.to_tuple() + + def get_video_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Union[int, List[int]], + vision_feature_select_strategy: str, + ): + """ + Obtains video last hidden states from the vision tower, apply multimodal projection and pooling. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) + The tensors corresponding to the input video. + vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. 
+ Can be one of `"default"` or `"full"` + Returns: + video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches + and are of shape `(num_videos, video_length, embed_dim)`). + """ + batch_size, frames, channels, height, width = pixel_values.shape + pixel_values = pixel_values.view(batch_size * frames, channels, height, width) + video_features = self.vision_tower(pixel_values, output_hidden_states=True) + + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_video_feature = video_features.hidden_states[vision_feature_layer] + else: + hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + selected_video_feature = torch.cat(hs_pool, dim=-1) + + if vision_feature_select_strategy == "default": + selected_video_feature = selected_video_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_video_feature = selected_video_feature + video_features = self.multi_modal_projector(selected_video_feature) + + video_features = self.apply_pooling(video_features) + video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1) + + return video_features + + def apply_pooling(self, image_features): + height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + batch_frames, seq_len, dim = image_features.shape + image_features = image_features.view(batch_frames, height, width, -1) + image_features = image_features.permute(0, 3, 1, 2).contiguous() + + height, width = image_features.shape[2:] + scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)] + image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear") + + image_features = image_features.permute(0, 2, 3, 1) + image_features = image_features.view(batch_frames, -1, dim) + return image_features + + +@add_start_docstrings( + """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", + LLAVA_ONEVISION_START_DOCSTRING, +) +class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^image_newline": "model.image_newline", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: LlavaOnevisionConfig): + super().__init__(config) + self.model = LlavaOnevisionModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING) + 
def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + pixel_values_videos: torch.FloatTensor = None, + image_sizes_videos: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + vision_aspect_ratio: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + + Returns: + [`~LlavaOnevisionCausalLMOutputWithPast`] (if `return_dict=True`) or a `tuple`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> import torch + >>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration + + >>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype="float16", device_map="cuda:0") + >>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf") + + >>> conversation = [ + ... { + ... "role": "user", + ... "content": [ + ... {"type": "text", "text": "What is shown in this image?"}, + ... {"type": "image"}, + ... ], + ... }, + ... 
] + >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + + >>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> raw_image = Image.open(requests.get(image_file, stream=True).raw) + >>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16) + + >>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + >>> processor.batch_decode(output, skip_special_tokens=True)[0] + "user\n\nWhat is shown in this image?\nassistant\ncat" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + vision_aspect_ratio = ( + vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_sizes=image_sizes, + image_sizes_videos=image_sizes_videos, + vision_aspect_ratio=vision_aspect_ratio, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, cache_position=cache_position, logits_to_keep=logits_to_keep, **lm_kwargs, ) - logits = outputs[0] + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return LlavaOnevisionCausalLMOutputWithPast( loss=loss, @@ -779,8 +927,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - video_hidden_states=video_features if pixel_values_videos is not None else None, + image_hidden_states=outputs.image_hidden_states, + video_hidden_states=outputs.video_hidden_states, ) def prepare_inputs_for_generation( @@ -799,7 +947,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -819,5 +967,60 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene return model_inputs + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. -__all__ = ["LlavaOnevisionForConditionalGeneration", "LlavaOnevisionPreTrainedModel"] + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["LlavaOnevisionModel", "LlavaOnevisionForConditionalGeneration", "LlavaOnevisionPreTrainedModel"] diff --git a/src/transformers/models/llava_onevision/modular_llava_onevision.py b/src/transformers/models/llava_onevision/modular_llava_onevision.py index 5a25124e58c..bc692c10a64 100644 --- a/src/transformers/models/llava_onevision/modular_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modular_llava_onevision.py @@ -1,4 +1,34 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + from transformers.models.llava_next.image_processing_llava_next_fast import LlavaNextImageProcessorFast +from transformers.models.llava_next_video.modeling_llava_next_video import ( + LlavaNextVideoCausalLMOutputWithPast, + LlavaNextVideoForConditionalGeneration, + LlavaNextVideoModel, + LlavaNextVideoModelOutputWithPast, + LlavaNextVideoPreTrainedModel, + get_anyres_image_grid_shape, + unpad_image, +) from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING from ...image_utils import ( @@ -6,11 +36,13 @@ from ...image_utils import ( OPENAI_CLIP_STD, PILImageResampling, ) -from ...utils import add_start_docstrings, logging +from ...utils import add_start_docstrings, can_return_tuple, is_torchdynamo_compiling, logging logger = logging.get_logger(__name__) +LLAVA_ONEVISION_INPUTS_DOCSTRING = None + @add_start_docstrings( "Constructs a fast ConvNeXT image processor. 
Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.", @@ -42,4 +74,446 @@ class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast): model_input_names = ["pixel_values_videos"] -__all__ = ["LlavaOnevisionImageProcessorFast"] +class LlavaOnevisionModelOutputWithPast(LlavaNextVideoModelOutputWithPast): + pass + + +class LlavaOnevisionCausalLMOutputWithPast(LlavaNextVideoCausalLMOutputWithPast): + pass + + +class LlavaOnevisionPreTrainedModel(LlavaNextVideoPreTrainedModel): + pass + + +class LlavaOnevisionModel(LlavaNextVideoModel): + def __init__(self, config): + super().__init__(config) + del self.vision_resampler + + def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"): + """ + Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. + + Args: + image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`) + List of image feature tensor, each contains all the visual feature of all patches. + image_sizes (`torch.Tensor` of shape `(num_images, 2)`) + Actual image size of each images (H, W). + image_newline (`torch.Tensor` of shape `(embed_dim)`) + New line embedding vector. + vision_aspect_ratio (`str`, *optional*, "anyres_max_9"): + Aspect ratio used when processong image features. The default value is "anyres_max_9". + Returns: + image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`) + feature_lens (`List[int]`) + token length of each image in image_features + """ + new_image_features = [] + feature_lens = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + if height * width != base_image_feature.shape[0]: + raise ValueError("The number of patches is not consistent with the image size.") + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + max_num_patches = int(vision_aspect_ratio.strip("anyres_max_")) + channels, curr_height, curr_width = image_feature.shape + ratio = math.sqrt(curr_height * curr_width / (max_num_patches * height**2)) + if ratio > 1.1: + image_feature = image_feature[None] + image_feature = nn.functional.interpolate( + image_feature, [int(curr_height // ratio), int(curr_width // ratio)], mode="bilinear" + )[0] + if image_newline is not None: + image_feature = torch.cat( + ( + image_feature, + image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device, image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + if image_newline is not None: + image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) + new_image_features.append(image_feature) + feature_lens.append(image_feature.size(0)) + image_features = 
torch.cat(new_image_features, dim=0) + feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device) + return image_features, feature_lens + + def apply_pooling(self, image_features): + height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + batch_frames, seq_len, dim = image_features.shape + image_features = image_features.view(batch_frames, height, width, -1) + image_features = image_features.permute(0, 3, 1, 2).contiguous() + + height, width = image_features.shape[2:] + scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)] + image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear") + + image_features = image_features.permute(0, 2, 3, 1) + image_features = image_features.view(batch_frames, -1, dim) + return image_features + + def get_video_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Union[int, List[int]], + vision_feature_select_strategy: str, + ): + """ + Obtains video last hidden states from the vision tower, apply multimodal projection and pooling. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) + The tensors corresponding to the input video. + vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches + and are of shape `(num_videos, video_length, embed_dim)`). 
+ """ + batch_size, frames, channels, height, width = pixel_values.shape + pixel_values = pixel_values.view(batch_size * frames, channels, height, width) + video_features = self.vision_tower(pixel_values, output_hidden_states=True) + + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_video_feature = video_features.hidden_states[vision_feature_layer] + else: + hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + selected_video_feature = torch.cat(hs_pool, dim=-1) + + if vision_feature_select_strategy == "default": + selected_video_feature = selected_video_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_video_feature = selected_video_feature + video_features = self.multi_modal_projector(selected_video_feature) + + video_features = self.apply_pooling(video_features) + video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1) + + return video_features + + @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + pixel_values_videos: torch.FloatTensor = None, + image_sizes_videos: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + vision_aspect_ratio: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **lm_kwargs, + ) -> Union[Tuple, LlavaOnevisionModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + vision_aspect_ratio = ( + vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: + raise ValueError( + "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, " + "and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # Images are processed with Anyres + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=vision_feature_layer, + 
vision_feature_select_strategy=vision_feature_select_strategy, + ) + image_features, feature_lens = self.pack_image_features( + image_features, + image_sizes, + image_newline=self.image_newline, + vision_aspect_ratio=vision_aspect_ratio, + ) + + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + # Video are simply embedded and further pooled to decrease seq len + if pixel_values_videos is not None: + video_features = self.get_video_features( + pixel_values_videos, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + image_newline = ( + self.image_newline[None, None, :].repeat(video_features.shape[0], 1, 1).to(video_features.device) + ) + video_features = torch.cat((video_features, image_newline), dim=1) + video_features = video_features.flatten(0, 1) + + special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) + special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel(): + n_video_tokens = (input_ids == self.config.video_token_id).sum() + n_video_features = video_features.shape[0] + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + output = LlavaOnevisionModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + video_hidden_states=video_features if pixel_values_videos is not None else None, + ) + + return output if return_dict else output.to_tuple() + + +class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGeneration): + @can_return_tuple + @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + pixel_values_videos: torch.FloatTensor = None, + image_sizes_videos: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = 
None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + vision_aspect_ratio: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + + Returns: + [`~LlavaOnevisionCausalLMOutputWithPast`] (if `return_dict=True`) or a `tuple`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> import torch + >>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration + + >>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype="float16", device_map="cuda:0") + >>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf") + + >>> conversation = [ + ... { + ... "role": "user", + ... "content": [ + ... {"type": "text", "text": "What is shown in this image?"}, + ... {"type": "image"}, + ... ], + ... }, + ... 
] + >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + + >>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> raw_image = Image.open(requests.get(image_file, stream=True).raw) + >>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16) + + >>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + >>> processor.batch_decode(output, skip_special_tokens=True)[0] + "user\n\nWhat is shown in this image?\nassistant\ncat" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + vision_aspect_ratio = ( + vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_sizes=image_sizes, + image_sizes_videos=image_sizes_videos, + vision_aspect_ratio=vision_aspect_ratio, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return LlavaOnevisionCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + video_hidden_states=outputs.video_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + pixel_values_videos=None, + image_sizes_videos=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be 
passed to model + model_inputs["pixel_values"] = pixel_values + model_inputs["image_sizes"] = image_sizes + model_inputs["pixel_values_videos"] = pixel_values_videos + model_inputs["image_sizes_videos"] = image_sizes_videos + + return model_inputs + + +__all__ = [ + "LlavaOnevisionImageProcessorFast", + "LlavaOnevisionModel", + "LlavaOnevisionForConditionalGeneration", + "LlavaOnevisionPreTrainedModel", +] diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 5ce7763dd7c..7078631552f 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -28,19 +28,20 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_outputs import ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torchdynamo_compiling, replace_return_docstrings, ) -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_mistral3 import Mistral3Config -_CONFIG_FOR_DOC = "Mistral3Config" +_CONFIG_FOR_DOC = "Mistra3Config" @use_kernel_forward_from_hub("RMSNorm") @@ -156,7 +157,7 @@ class Mistral3CausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ @@ -168,6 +169,39 @@ class Mistral3CausalLMOutputWithPast(ModelOutput): image_hidden_states: Optional[torch.FloatTensor] = None +@dataclass +class Mistral3ModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for Mistral3 outputs, with hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + MISTRAL3_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -186,14 +220,13 @@ MISTRAL3_START_DOCSTRING = r""" @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + "The bare Mistral3 Model outputting raw hidden-states without any specific head on top.", MISTRAL3_START_DOCSTRING, ) class Mistral3PreTrainedModel(PreTrainedModel): config_class = Mistral3Config - base_model_prefix = "model" + base_model_prefix = "" supports_gradient_checkpointing = True - _no_split_modules = ["Mistral3VisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True @@ -202,12 +235,18 @@ class Mistral3PreTrainedModel(PreTrainedModel): _supports_static_cache = True def _init_weights(self, module): + # important: this ported version of Mistral3 isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + # https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() elif isinstance(module, Mistral3RMSNorm): module.weight.data.fill_(1.0) @@ -290,23 +329,18 @@ MISTRAL3_INPUTS_DOCSTRING = r""" @add_start_docstrings( - """The MISTRAL3 model which consists of a vision backbone and a language model.""", + """The Mistral3 model which consists of a vision backbone and a language model, without a language modeling head.""", MISTRAL3_START_DOCSTRING, ) -class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin): +class Mistral3Model(Mistral3PreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + def __init__(self, config: Mistral3Config): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = Mistral3MultiModalProjector(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - + self.language_model = 
AutoModel.from_config(config.text_config) self.post_init() def get_input_embeddings(self): @@ -315,18 +349,6 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -364,64 +386,23 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) return image_features @add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, List[int]]] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: Optional[torch.Tensor] = None, + image_sizes: torch.Tensor = None, **lm_kwargs, - ) -> Union[Tuple, Mistral3CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). 
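
A minimal sketch, not part of the patch, of how `logits_to_keep` is consumed downstream. It mirrors the `slice_indices` dispatch used elsewhere in this diff; the tensor shapes and values are illustrative only:

```python
import torch

# Hypothetical shapes for illustration: batch=1, seq_len=5, hidden=8, vocab=32.
hidden_states = torch.randn(1, 5, 8)
lm_head = torch.nn.Linear(8, 32, bias=False)

def project(hidden_states, logits_to_keep):
    # Same dispatch as in the patch: an int keeps the last N positions,
    # a 1D tensor is used directly as an index into the sequence dimension.
    slice_indices = (
        slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    )
    return lm_head(hidden_states[:, slice_indices, :])

print(project(hidden_states, 0).shape)                     # all positions: (1, 5, 32)
print(project(hidden_states, 1).shape)                     # last token only: (1, 1, 32)
print(project(hidden_states, torch.tensor([0, 4])).shape)  # explicit indices: (1, 2, 32)
```
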
- - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration - - >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") - >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") - - >>> prompt = "[INST][IMG]What is the image?[/INST]" - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is the image?The image depicts two cats lying on a pink blanket." - ```""" - + ) -> Union[Tuple, Mistral3ModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -468,35 +449,153 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, - logits_to_keep=logits_to_keep, **lm_kwargs, ) - logits = outputs[0] + output = Mistral3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + + +@add_start_docstrings( + """The MISTRAL3 model which consists of a vision backbone and a language model.""", + MISTRAL3_START_DOCSTRING, +) +class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Mistral3Config): + super().__init__(config) + self.model = Mistral3Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + 
attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: torch.Tensor = None, + **lm_kwargs, + ) -> Union[Tuple, Mistral3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration + + >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + + >>> prompt = "[INST][IMG]What is the image?[/INST]" + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is the image?The image depicts two cats lying on a pink blanket." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + image_sizes=image_sizes, + **lm_kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return Mistral3CausalLMOutputWithPast( loss=loss, @@ -504,7 +603,7 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, + image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( @@ -520,7 +619,7 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -537,5 +636,60 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) return model_inputs + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 
-__all__ = ["Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"] + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["Mistral3Model", "Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"] diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py index 36fd4526838..5ef6663bde0 100644 --- a/src/transformers/models/mistral3/modular_mistral3.py +++ b/src/transformers/models/mistral3/modular_mistral3.py @@ -19,14 +19,29 @@ import torch from torch import nn from ...activations import ACT2FN -from ...utils import is_torchdynamo_compiling, logging -from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel +from ...utils import ( + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ..llava.modeling_llava import ( + LlavaCausalLMOutputWithPast, + LlavaForConditionalGeneration, + LlavaModel, + LlavaModelOutputWithPast, + LlavaPreTrainedModel, +) from ..mistral.modeling_mistral import MistralRMSNorm from .configuration_mistral3 import Mistral3Config logger = logging.get_logger(__name__) +MISTRAL3_INPUTS_DOCSTRING = None +_CONFIG_FOR_DOC = "Mistra3Config" + class Mistral3RMSNorm(MistralRMSNorm): pass @@ -100,19 +115,29 @@ class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast): pass +class Mistral3ModelOutputWithPast(LlavaModelOutputWithPast): + pass + + class Mistral3PreTrainedModel(LlavaPreTrainedModel): def _init_weights(self, module): + # important: this ported version of Mistral3 isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + # 
https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() elif isinstance(module, Mistral3RMSNorm): module.weight.data.fill_(1.0) -class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): +class Mistral3Model(LlavaModel): def get_image_features( self, pixel_values: torch.FloatTensor, @@ -151,61 +176,21 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, List[int]]] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: Optional[torch.Tensor] = None, + image_sizes: torch.Tensor = None, **lm_kwargs, - ) -> Union[Tuple, Mistral3CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration - - >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") - >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") - - >>> prompt = "[INST][IMG]What is the image?[/INST]" - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is the image?The image depicts two cats lying on a pink blanket." 
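Purely illustrative usage of the headless class defined above: `Mistral3Model` now stops at the last hidden state, and logits live only in `Mistral3ForConditionalGeneration`. The checkpoint id is reused from the surrounding docstring for illustration, and text-only processor behaviour is assumed here rather than taken from the patch:

```python
import torch
from transformers import AutoProcessor, Mistral3Model

model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"  # id borrowed from the docstring above
model = Mistral3Model.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

inputs = processor(text="[INST]Describe multimodal projectors.[/INST]", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Base-model output: hidden states (plus optional image_hidden_states), no logits.
print(outputs.last_hidden_state.shape)
```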
- ```""" - + ) -> Union[Tuple, Mistral3ModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -252,35 +237,110 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, - logits_to_keep=logits_to_keep, **lm_kwargs, ) - logits = outputs[0] + output = Mistral3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + + +class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): + @can_return_tuple + @add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: torch.Tensor = None, + **lm_kwargs, + ) -> Union[Tuple, Mistral3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). 
+ + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration + + >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + + >>> prompt = "[INST][IMG]What is the image?[/INST]" + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is the image?The image depicts two cats lying on a pink blanket." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + image_sizes=image_sizes, + **lm_kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
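The hand-rolled shift-and-mask loss being removed here is what the shared `self.loss_function(...)` call replaces. As a reminder of what that amounts to, a toy next-token cross-entropy with the usual one-position shift and `-100` masking (toy tensors only):

```python
import torch
import torch.nn.functional as F

vocab_size = 32
logits = torch.randn(2, 5, vocab_size)          # [batch, seq, vocab]
labels = torch.randint(0, vocab_size, (2, 5))
labels[:, 0] = -100                             # ignored positions (prompt/padding)

shift_logits = logits[:, :-1, :].contiguous()   # tokens < n predict token n
shift_labels = labels[:, 1:].contiguous()
loss = F.cross_entropy(
    shift_logits.view(-1, vocab_size),
    shift_labels.view(-1),
    ignore_index=-100,
)
print(loss)
```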
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return Mistral3CausalLMOutputWithPast( loss=loss, @@ -288,11 +348,12 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, + image_hidden_states=outputs.image_hidden_states, ) __all__ = [ + "Mistral3Model", "Mistral3PreTrainedModel", # noqa "Mistral3ForConditionalGeneration", ] diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index d842bd7c131..79278c9892e 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -32,6 +32,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -1968,10 +1969,11 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): @add_start_docstrings( - """The Mllama model which consists of a vision encoder and a language model.""", + """The Mllama model which consists of a vision encoder and a language model without language modeling head.""", MLLAMA_START_DOCSTRING, ) -class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): +class MllamaModel(MllamaPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} _supports_quantized_cache = False # quant cache not supported in encoder-decoder setting def __init__(self, config: MllamaConfig): @@ -1983,10 +1985,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.vision_model = MllamaVisionModel._from_config(config.vision_config) - self.language_model = MllamaForCausalLM._from_config(config.text_config) - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - + self.language_model = MllamaTextModel._from_config(config.text_config) self.multi_modal_projector = nn.Linear( config.vision_config.vision_output_dim, config.text_config.hidden_size, @@ -2000,18 +1999,139 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) + @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, 
+ pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") + # get vision tokens from vision model + vision_outputs = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.hidden_size + ) + + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None and cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + cache_position=cache_position, + ) + + output = BaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + return output if return_dict else output.to_tuple() + + +@add_start_docstrings( + """The Mllama model which consists of a 
vision encoder and a language model.""", + MLLAMA_START_DOCSTRING, +) +class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_model": "model.vision_model", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _supports_quantized_cache = False # quant cache not supported in encoder-decoder setting + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: MllamaConfig): + super().__init__(config) + self.model = MllamaModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + def get_output_embeddings(self): - return self.language_model.get_output_embeddings() + return self.lm_head def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) + self.lm_head = new_embeddings - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model - def get_decoder(self): - return self.language_model.get_decoder() + @property + def vision_model(self): + return self.model.vision_model + @can_return_tuple @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaConfig") def forward( @@ -2084,78 +2204,36 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None and cross_attention_states is not None: - raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") - - if pixel_values is not None: - if aspect_ratio_ids is None: - raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") - # get vision tokens from vision model - vision_outputs = self.vision_model( - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - ) - cross_attention_states = vision_outputs[0] - cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( - -1, cross_attention_states.shape[-2], self.hidden_size - ) - - if cross_attention_mask is not None: - cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( - cross_attention_mask, - num_vision_tokens=self.vision_model.num_patches, - dtype=self.dtype, - ) - else: - full_text_row_masked_out_mask = None - - if cross_attention_mask is not None and cache_position is not None: - cross_attention_mask = cross_attention_mask[:, :, cache_position] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] - - outputs = 
self.language_model( + outputs = self.model( input_ids=input_ids, + pixel_values=pixel_values, + aspect_ratio_mask=aspect_ratio_mask, + aspect_ratio_ids=aspect_ratio_ids, + cross_attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, attention_mask=attention_mask, position_ids=position_ids, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, - use_cache=use_cache, inputs_embeds=inputs_embeds, - output_hidden_states=output_hidden_states, + use_cache=use_cache, output_attentions=output_attentions, - return_dict=return_dict, + output_hidden_states=output_hidden_states, + return_dict=True, cache_position=cache_position, - logits_to_keep=logits_to_keep, - **loss_kwargs, ) - # Temporary fix to calculate the loss in main class, as the model's vocab size may be resized + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + loss = None - logits = outputs[0] - if labels is not None: - loss = self.loss_function(logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs) - - if not return_dict: - return (loss,) + outputs if loss is not None else outputs + loss = self.loss_function(logits, labels, self.config.text_config.vocab_size, **loss_kwargs) return CausalLMOutputWithPast( loss=loss, - logits=outputs.logits, + logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -2179,58 +2257,28 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - if past_key_values is not None: - if ( - inputs_embeds is not None # Exception 1 - or cache_position[-1] >= input_ids.shape[1] # Exception 3 - ): - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - - # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. 
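The `_checkpoint_conversion_mapping` dictionaries added for these wrapper classes map old checkpoint key prefixes onto the new `model.`-nested layout so previously saved weights keep loading. A rough sketch of the idea as plain regex renaming over state-dict keys, in the style used above for `MllamaForConditionalGeneration`; the real remapping is done by the library's loading code, this is only an illustration:

```python
import re

conversion_mapping = {
    "^language_model.model": "model.language_model",
    "^vision_model": "model.vision_model",
    "^multi_modal_projector": "model.multi_modal_projector",
    "^language_model.lm_head": "lm_head",
}

old_keys = [
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "language_model.lm_head.weight",
    "vision_model.patch_embedding.weight",
    "multi_modal_projector.weight",
]

def convert_key(key: str) -> str:
    for pattern, replacement in conversion_mapping.items():
        key, n_replaced = re.subn(pattern, replacement, key)
        if n_replaced > 0:
            break  # each key matches at most one prefix rule
    return key

for key in old_keys:
    print(f"{key} -> {convert_key(key)}")
```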
Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. - model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - - if logits_to_keep is not None: - model_inputs["logits_to_keep"] = logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "cross_attention_mask": cross_attention_mask, - } + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + cross_attention_mask=cross_attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, ) # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios # to compute image hidden states, otherwise they are cached within each cross attn layer - if cache_position[0] == 0: - model_inputs["pixel_values"] = pixel_values - model_inputs["aspect_ratio_ids"] = aspect_ratio_ids - model_inputs["aspect_ratio_mask"] = aspect_ratio_mask + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["aspect_ratio_ids"] = None + model_inputs["aspect_ratio_mask"] = None return model_inputs @@ -2257,4 +2305,5 @@ __all__ = [ "MllamaTextModel", "MllamaVisionModel", "MllamaPreTrainedModel", + "MllamaModel", ] diff --git a/src/transformers/models/paligemma/configuration_paligemma.py b/src/transformers/models/paligemma/configuration_paligemma.py index 4551b85bcd5..f32ad303bf1 100644 --- a/src/transformers/models/paligemma/configuration_paligemma.py +++ b/src/transformers/models/paligemma/configuration_paligemma.py @@ -13,8 +13,6 @@ # limitations under the License. """PaliGemmamodel configuration""" -import warnings - from ...configuration_utils import PretrainedConfig from ...utils import logging from ..auto import CONFIG_MAPPING, AutoConfig @@ -39,8 +37,6 @@ class PaliGemmaConfig(PretrainedConfig): Custom vision config or dict text_config (`Union[AutoConfig, dict]`, *optional*): The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. - ignore_index (`int`, *optional*, defaults to -100): - The ignore index for the loss function. image_token_index (`int`, *optional*, defaults to 256000): The image token index to encode the image prompt. 
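The rewritten `prepare_inputs_for_generation` above now defers to `super()` and merely blanks the vision inputs once prefill is done, since the image contribution is already held in the cache for later decoding steps. A condensed sketch of that control flow, with a hypothetical helper name and toy inputs:

```python
import torch

def prune_vision_inputs(model_inputs: dict, cache_position: torch.Tensor) -> dict:
    """Hypothetical helper: drop pixel inputs once the prompt (and its images) has been prefilled."""
    if cache_position[0] != 0:  # no longer the first forward pass
        model_inputs["pixel_values"] = None
        model_inputs["aspect_ratio_ids"] = None
        model_inputs["aspect_ratio_mask"] = None
    return model_inputs

base = {"pixel_values": torch.ones(1, 3, 8, 8), "aspect_ratio_ids": None, "aspect_ratio_mask": None}
prefill = prune_vision_inputs(dict(base), torch.arange(0, 5))   # first pass keeps the images
decode = prune_vision_inputs(dict(base), torch.tensor([5]))     # later passes drop them
print(prefill["pixel_values"] is not None, decode["pixel_values"] is None)  # True True
```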
vocab_size (`int`, *optional*, defaults to 257152): @@ -83,16 +79,13 @@ class PaliGemmaConfig(PretrainedConfig): self, vision_config=None, text_config=None, - ignore_index=-100, image_token_index=256000, vocab_size=257152, projection_dim=2048, hidden_size=2048, **kwargs, ): - self._ignore_index = ignore_index self.image_token_index = image_token_index - self._vocab_size = vocab_size self.projection_dim = projection_dim self.hidden_size = hidden_size self.vision_config = vision_config @@ -133,22 +126,5 @@ class PaliGemmaConfig(PretrainedConfig): self.vision_config.projection_dim = projection_dim super().__init__(**kwargs) - @property - def ignore_index(self): - warnings.warn( - "The `ignore_index` attribute is deprecated and will be removed in v4.47.", - FutureWarning, - ) - return self._ignore_index - - @ignore_index.setter - def ignore_index(self, value): - self._ignore_index = value - - def to_dict(self): - output = super().to_dict() - output.pop("_ignore_index", None) - return output - __all__ = ["PaliGemmaConfig"] diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 924e3b1dcad..9561f0f0628 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -23,17 +23,18 @@ from torch import nn from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin -from ...modeling_outputs import CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torchdynamo_compiling, logging, replace_return_docstrings, ) -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_paligemma import PaliGemmaConfig @@ -42,78 +43,43 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "PaliGemmaConfig" -# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position -# But Paligemma has no causal mask on prefix -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, - is_training: bool = False, - token_type_ids: Optional[torch.Tensor] = None, - **kwargs, -): +@dataclass +class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + Base class for Paligemma outputs, with hidden states and attentions. Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. 
- cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - is_training (`bool`): - Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels` - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below - if sequence_length != 1: - if is_training: - causal_mask = torch.triu(causal_mask, diagonal=1) - else: - causal_mask[:, :sequence_length] = 0.0 + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - # we are training thus we need to create a full mask on the image + prefix but causal on suffix - if is_training: - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 - ) - return causal_mask + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. 
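The helper being removed here encoded PaliGemma's prefix-LM attention: the image-plus-prompt prefix is attended bidirectionally, while only the suffix stays causal during training. A toy illustration of that pattern, assuming the same convention as the code above (`token_type_ids == 0` marks the prefix):

```python
import torch

seq_len = 6
token_type_ids = torch.tensor([[0, 0, 0, 1, 1, 1]])   # 3 prefix tokens, 3 suffix tokens
min_value = torch.finfo(torch.float32).min

causal_mask = torch.full((seq_len, seq_len), min_value)
causal_mask = torch.triu(causal_mask, diagonal=1)      # start from a standard causal mask
causal_mask = causal_mask[None, None, :, :].clone()    # [batch, 1, q_len, kv_len]

# Open up every key position that belongs to the prefix (training-time behaviour);
# at inference the whole processed sequence is un-masked instead.
causal_mask = causal_mask.masked_fill(token_type_ids[:, None, None, :] == 0, 0.0)

print((causal_mask == 0).int()[0, 0])   # 1 where attention is allowed
```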
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None @dataclass class PaliGemmaCausalLMOutputWithPast(ModelOutput): """ - Base class for PaliGemmacausal language model (or autoregressive) outputs. + Base class for PaliGemma causal language model (or autoregressive) outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): @@ -184,7 +150,7 @@ PALIGEMMA_START_DOCSTRING = r""" ) class PaliGemmaPreTrainedModel(PreTrainedModel): config_class = PaliGemmaConfig - base_model_prefix = "model" + base_model_prefix = "" supports_gradient_checkpointing = True _no_split_modules = ["PaliGemmaMultiModalProjector"] _skip_keys_device_placement = "past_key_values" @@ -276,49 +242,32 @@ PALIGEMMA_INPUTS_DOCSTRING = r""" @add_start_docstrings( - """The PALIGEMMA model which consists of a vision backbone and a language model.""", + """Base Paligemma model which consists of a vision backbone and a language model withou language modeling head.""", PALIGEMMA_START_DOCSTRING, ) -class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): +class PaliGemmaModel(PaliGemmaPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + def __init__(self, config: PaliGemmaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config=config.vision_config) self.multi_modal_projector = PaliGemmaMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - language_model = AutoModelForCausalLM.from_config(config=config.text_config) - - if language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + language_model = AutoModel.from_config(config=config.text_config) self.language_model = language_model self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma + # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma def get_input_embeddings(self): return self.language_model.get_input_embeddings() - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma + # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma - def get_decoder(self): - return 
self.language_model.get_decoder() - def _update_causal_mask( self, attention_mask, @@ -326,7 +275,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi past_key_values=None, cache_position=None, input_tensor=None, - is_training: Optional[bool] = None, + is_training: bool = None, ): if self.config.text_config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: @@ -403,12 +352,154 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features + @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **lm_kwargs, + ) -> Union[Tuple, PaligemmaModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + # Replace image id with PAD if the image token is OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_id >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. 
" + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training + ) + outputs = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + output = PaligemmaModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + + +@add_start_docstrings( + """Base Paligemma model which consists of a vision backbone and a language model withou language modeling head.""", + PALIGEMMA_START_DOCSTRING, +) +class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: PaliGemmaConfig): + super().__init__(config) + self.model = PaliGemmaModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, @@ -459,117 +550,46 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Where is the cat standing?\nsnow" ```""" - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - is_training = token_type_ids is not None and labels is not None - - # Replace image id with PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_id >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_id - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 - else: - llm_input_ids = input_ids - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed - - # Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - # mask out pad-token-ids in labels for BC - if labels is not None and self.pad_token_id in labels: - logger.warning_once( - "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. 
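Both the block removed here and the new `PaliGemmaModel.forward` above splice the projected image features into the token embeddings at the placeholder positions via `masked_scatter`. A toy version of that merge with made-up sizes and a made-up placeholder id:

```python
import torch

hidden_size, image_token_id = 4, 9           # toy values, not the real config
input_ids = torch.tensor([[9, 9, 5, 6]])     # two image placeholders followed by text
inputs_embeds = torch.zeros(1, 4, hidden_size)
image_features = torch.ones(2, hidden_size)  # one projected feature per placeholder

special_image_mask = (input_ids == image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)

inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, :, 0])  # tensor([1., 1., 0., 0.]) -> image slots filled, text slots untouched
```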
" - "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", - ) - labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) - - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training - ) - outputs: CausalLMOutputWithPast = self.language_model( - attention_mask=causal_mask, + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + token_type_ids=token_type_ids, + attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, + labels=labels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - logits_to_keep=logits_to_keep, **lm_kwargs, ) - logits = outputs[0] + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:] - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) - shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() - else: - shift_logits = shift_logits.contiguous() - shift_labels = shift_labels.contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) - flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) - flat_labels = shift_labels.view(-1).to(shift_logits.device) - loss = loss_fct(flat_logits, flat_labels) - - output = PaliGemmaCausalLMOutputWithPast( + return PaliGemmaCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, + image_hidden_states=outputs.image_hidden_states, ) - return output if return_dict else output.to_tuple() def prepare_inputs_for_generation( self, @@ -587,7 +607,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi **kwargs, ): # Overwritten -- custom `position_ids` and `pixel_values` handling - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -610,12 +630,68 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi is_training = token_type_ids is not None and labels is not None if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self._update_causal_mask( + causal_mask = self.model._update_causal_mask( 
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training ) model_inputs["attention_mask"] = causal_mask return model_inputs + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. -__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel"] + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"] diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 9737e1437e8..b1b875abcd1 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -2056,7 +2056,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniThinkerText. Make sure to " + " this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniThinker. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
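For intuition, here is what the static mask helper added above produces on a tiny decode-time example: three query tokens sitting at cache positions 2-4 against five cached key slots, the last of which is padding (all numbers are toy values):

```python
import torch

batch, q_len, kv_len = 1, 3, 5
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.tensor([2, 3, 4])          # absolute positions of the current queries
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])  # last key slot is padding

causal_mask = torch.full((q_len, kv_len), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(kv_len) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch, 1, -1, -1).clone()

padding_mask = (causal_mask[:, :, :, :kv_len] + attention_mask[:, None, None, :]) == 0
causal_mask[:, :, :, :kv_len] = causal_mask[:, :, :, :kv_len].masked_fill(padding_mask, min_dtype)

print((causal_mask == 0).int()[0, 0])  # 1 where a query may attend; the padded column is fully masked
```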
" ) if attention_mask is not None and 0.0 in attention_mask: @@ -2156,7 +2156,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. - config (`Qwen25OmniThinkerTextConfig`): + config (`Qwen25OmniThinkerConfig`): The model's configuration class past_key_values (`Cache`): The cache class that is being used currently to generate @@ -2772,7 +2772,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniTalker. Make sure to " + " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5Omni. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) if attention_mask is not None and 0.0 in attention_mask: @@ -2872,7 +2872,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. - config (`Qwen25OmniTalkerConfig`): + config (`Qwen2_5OmniConfig`): The model's configuration class past_key_values (`Cache`): The cache class that is being used currently to generate diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 9b5c4167646..5a51e787577 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -33,8 +33,8 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLAttention, Qwen2_5_VLMLP, - Qwen2_5_VLModel, Qwen2_5_VLPreTrainedModel, + Qwen2_5_VLTextModel, Qwen2_5_VLVisionBlock, Qwen2RMSNorm, ) @@ -2165,7 +2165,7 @@ QWEN2_5OMNI_START_DOCSTRING = r""" "The bare Qwen2.5OmniThinker Model outputting raw hidden-states without any specific head on top.", QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTextConfig"), ) -class Qwen2_5OmniThinkerTextModel(Qwen2_5_VLModel): +class Qwen2_5OmniThinkerTextModel(Qwen2_5_VLTextModel): config_class = Qwen2_5OmniTextConfig _no_split_modules = ["Qwen2_5OmniDecoderLayer"] @@ -2591,7 +2591,7 @@ class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput): "The bare Qwen2.5OmniTalker Model outputting raw hidden-states without any specific head on top.", QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTalkerConfig"), ) -class Qwen2_5OmniTalkerModel(Qwen2_5_VLModel): +class Qwen2_5OmniTalkerModel(Qwen2_5_VLTextModel): config_class = Qwen2_5OmniTalkerConfig _no_split_modules = ["Qwen2_5OmniTalkerDecoderLayer"] diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 6c19b9fdcfb..f87260c69b7 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -31,7 +31,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache @@ -44,6 +43,7 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, 
add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -565,6 +565,42 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): return hidden_states +@dataclass +class Qwen2_5_VLModelOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL outputs, with hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + class Qwen2_5_VLRotaryEmbedding(nn.Module): def __init__(self, config: Qwen2_5_VLTextConfig, device=None): super().__init__() @@ -1076,7 +1112,7 @@ class Qwen2_5_VLDecoderLayer(nn.Module): "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", Qwen2_5_VL_START_DOCSTRING, ) -class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): +class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel): config_class = Qwen2_5_VLTextConfig def __init__(self, config: Qwen2_5_VLTextConfig): @@ -1374,45 +1410,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): return causal_mask -@dataclass -class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): - """ - Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
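The new `*ModelOutputWithPast` dataclasses such as the one above follow the usual `ModelOutput` semantics: attribute access, integer indexing over the non-`None` fields, and `to_tuple()` for callers that pass `return_dict=False`. A toy subclass (not the real class) to show the behaviour:

```python
import torch
from dataclasses import dataclass
from typing import Optional
from transformers.utils import ModelOutput

@dataclass
class ToyModelOutputWithPast(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    rope_deltas: Optional[torch.LongTensor] = None

out = ToyModelOutputWithPast(last_hidden_state=torch.zeros(1, 4, 8), rope_deltas=torch.tensor([3]))
print(out.last_hidden_state.shape)  # attribute access
print(out[0].shape)                 # integer indexing walks the non-None fields in order
print(len(out.to_tuple()))          # 2 -> this is what return_dict=False callers receive
```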
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): - The rope index difference between sequence length and multimodal rope. - """ - - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - past_key_values: Optional[List[torch.FloatTensor]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - rope_deltas: Optional[torch.LongTensor] = None - - QWEN2_5_VL_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -1489,41 +1486,25 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r""" """ -class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] +class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): + _checkpoint_conversion_mapping = {"^model": "language_model"} config_class = Qwen2_5_VLConfig _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] def __init__(self, config): super().__init__(config) self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) - - text_config = config.get_text_config() - self.model = Qwen2_5_VLModel._from_config(text_config) - self.vocab_size = text_config.vocab_size - self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False) + self.language_model = Qwen2_5_VLTextModel._from_config(config.text_config) self.rope_deltas = None # cache rope_deltas here # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): - return self.model.embed_tokens + return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model + self.language_model.set_input_embeddings(value) def get_rope_index( self, @@ -1708,15 +1689,13 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi return position_ids, 
mrope_position_deltas @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -1728,46 +1707,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, - ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration - - >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - - >>> messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What is shown in this image?"}, - ], - }, - ] - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." 
- ```""" - + ) -> Union[Tuple, Qwen2_5_VLModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1775,7 +1715,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) + inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) @@ -1846,7 +1786,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi position_ids = position_ids.add(delta) position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - outputs = self.model( + outputs = self.language_model( input_ids=None, position_ids=position_ids, attention_mask=attention_mask, @@ -1855,6 +1795,182 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + ) + + output = Qwen2_5_VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + return output if return_dict else output.to_tuple() + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
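# --- Illustrative sketch, not part of the diff: how the cached `rope_deltas` is reused. ---
# `rope_deltas` is the per-sample offset between the plain text position and the multimodal
# (t/h/w) RoPE index computed by `get_rope_index`. The base model caches it at prefill so that
# decode steps can rebuild `position_ids` from `cache_position` alone, roughly as follows
# (the concrete numbers are made up):
import torch

cache_position = torch.tensor([37])   # index of the token currently being decoded (assumed)
rope_deltas = torch.tensor([[-5]])    # cached during prefill, shape (batch, 1) (assumed value)
batch_size, seq_length = 1, 1

position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
position_ids = position_ids + (cache_position[0] + rope_deltas).to(position_ids.device)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)   # one plane per mRoPE axis (t, h, w)
# --- end of sketch ---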
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^visual": "model.visual", + r"^model(?!\.(language_model|visual))": "model.language_model", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2_5_VLModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def visual(self): + return self.model.visual + + @can_return_tuple + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
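# --- Illustrative sketch, not part of the diff: what `_checkpoint_conversion_mapping` encodes. ---
# The regex mapping defined on this class tells the loading code (wired up by this patch's
# changes to `modeling_utils.py`) how to rename keys from checkpoints saved with the old flat
# layout into the new nested one. A rough standalone rendition of that renaming, with example
# key names:
import re

mapping = {
    r"^visual": "model.visual",
    r"^model(?!\.(language_model|visual))": "model.language_model",
}

for key in ["visual.blocks.0.attn.qkv.weight", "model.layers.0.mlp.gate_proj.weight", "lm_head.weight"]:
    new_key = key
    for pattern, replacement in mapping.items():
        new_key = re.sub(pattern, replacement, new_key)
    print(f"{key} -> {new_key}")
# visual.blocks.0.attn.qkv.weight -> model.visual.blocks.0.attn.qkv.weight
# model.layers.0.mlp.gate_proj.weight -> model.language_model.layers.0.mlp.gate_proj.weight
# lm_head.weight -> lm_head.weight
# --- end of sketch ---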
+ + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) @@ -1864,18 +1980,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) if not return_dict: output = (logits,) + outputs[1:] @@ -1887,7 +1992,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - rope_deltas=self.rope_deltas, + rope_deltas=outputs.rope_deltas, ) def prepare_inputs_for_generation( @@ -2055,5 +2160,60 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi return input_ids, model_kwargs + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D 
mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. -__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLTextModel"] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index e34724c7906..2d220aa2a90 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -26,7 +26,6 @@ import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig from transformers.models.qwen2_vl.modeling_qwen2_vl import ( @@ -36,6 +35,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import ( Qwen2VLCausalLMOutputWithPast, Qwen2VLForConditionalGeneration, Qwen2VLModel, + Qwen2VLModelOutputWithPast, Qwen2VLPreTrainedModel, VisionAttention, VisionRotaryEmbedding, @@ -50,7 +50,12 @@ from ...image_utils import ImageInput, VideoInput from ...modeling_flash_attention_utils import is_flash_attn_available from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs from ...tokenization_utils_base import PreTokenizedInput, TextInput -from ...utils import logging +from ...utils import ( + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, + replace_return_docstrings, +) if is_flash_attn_available(): @@ -59,6 +64,8 @@ if 
is_flash_attn_available(): logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + def apply_rotary_pos_emb_flashatt( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor @@ -406,16 +413,12 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): return hidden_states -class Qwen2_5_VLModel(Qwen2VLModel): - pass - - @dataclass -class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast): +class Qwen2_5_VLModelOutputWithPast(Qwen2VLModelOutputWithPast): pass -class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): +class Qwen2_5_VLModel(Qwen2VLModel): config_class = Qwen2_5_VLConfig _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] @@ -607,12 +610,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): def forward( self, - input_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -624,46 +626,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, - ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration - - >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - - >>> messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What is shown in this image?"}, - ], - }, - ] - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." 
- ```""" - + ) -> Union[Tuple, Qwen2_5_VLModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -671,7 +634,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) + inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) @@ -742,7 +705,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): position_ids = position_ids.add(delta) position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - outputs = self.model( + outputs = self.language_model( input_ids=None, position_ids=position_ids, attention_mask=attention_mask, @@ -751,6 +714,111 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + ) + + output = Qwen2_5_VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + return output if return_dict else output.to_tuple() + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast): + pass + + +QWEN2_5_VL_INPUTS_DOCSTRING = None + + +class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): + @can_return_tuple + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
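# --- Illustrative sketch, not part of the diff: what `self.loss_function(...)` replaces. ---
# Further down in this hunk, the hand-written shift + CrossEntropyLoss block is swapped for the
# shared `loss_function` helper. For a causal LM it amounts to the usual next-token loss with
# the -100 ignore index mentioned in the `labels` docstring above; a standalone approximation:
import torch
import torch.nn.functional as F

def next_token_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    return F.cross_entropy(shift_logits.float(), shift_labels, ignore_index=-100)
# --- end of sketch ---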
+ + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) @@ -760,18 +828,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) if not return_dict: output = (logits,) + outputs[1:] @@ -783,7 +840,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - rope_deltas=self.rope_deltas, + rope_deltas=outputs.rope_deltas, ) def prepare_inputs_for_generation( @@ -985,4 +1042,5 @@ __all__ = [ "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLProcessor", + "Qwen2_5_VLTextModel", # noqa: F822 ] diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 4496ef73b8c..03700451ce8 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ 
b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -799,27 +799,21 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi raise ValueError(f"{padding_side} is not `left` or `right`.") self._padding_side = padding_side - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings def get_input_embeddings(self): return self.language_model.get_input_embeddings() - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings def get_output_embeddings(self): return self.language_model.get_output_embeddings() - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.language_model.set_output_embeddings(new_embeddings) - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder def set_decoder(self, decoder): self.language_model.set_decoder(decoder) - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder def get_decoder(self): return self.language_model.get_decoder() diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index ce008fdaf9f..2e46c8927a7 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn import LayerNorm from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache @@ -40,7 +40,9 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -61,6 +63,42 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Qwen2VLConfig" +@dataclass +class Qwen2VLModelOutputWithPast(ModelOutput): + """ + Base class for Qwen2VL outputs, with hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
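# --- Illustrative sketch, not part of the diff: the image-token merge used further down. ---
# In the reworked `Qwen2VLModel.forward` below, vision features are written into the embedding
# positions of the image placeholder tokens via `masked_scatter`. A toy version with made-up
# sizes and a made-up placeholder id:
import torch

hidden_size, image_token_id = 8, 99
input_ids = torch.tensor([[1, 99, 99, 2]])        # two image placeholder tokens
inputs_embeds = torch.zeros(1, 4, hidden_size)    # stand-in for the text embeddings
image_embeds = torch.randn(2, hidden_size)        # one feature row per placeholder token

image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds.to(inputs_embeds.dtype))
# sequence positions 1 and 2 now carry the two image feature vectors
# --- end of sketch ---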
+ + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + @dataclass class Qwen2VLCausalLMOutputWithPast(ModelOutput): """ @@ -1027,7 +1065,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.", QWEN2VL_START_DOCSTRING, ) -class Qwen2VLModel(Qwen2VLPreTrainedModel): +class Qwen2VLTextModel(Qwen2VLPreTrainedModel): config_class = Qwen2VLTextConfig def __init__(self, config: Qwen2VLTextConfig): @@ -1403,39 +1441,23 @@ QWEN2_VL_INPUTS_DOCSTRING = r""" """ -class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] +class Qwen2VLModel(Qwen2VLPreTrainedModel): + _checkpoint_conversion_mapping = {"^model": "language_model"} - def __init__(self, config): + def __init__(self, config: Qwen2VLConfig): super().__init__(config) self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config) - - text_config = config.get_text_config() - self.model = Qwen2VLModel._from_config(text_config) - self.vocab_size = text_config.vocab_size - self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False) + self.language_model = Qwen2VLTextModel._from_config(config.text_config) self.rope_deltas = None # cache rope_deltas here # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): - return self.model.embed_tokens + return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model + self.language_model.set_input_embeddings(value) def get_rope_index( self, @@ -1586,11 +1608,166 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): return position_ids, mrope_position_deltas + @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + 
image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, Qwen2VLModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_embeds.shape[0] + if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum() + n_video_features = video_embeds.shape[0] + if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + delta = delta.to(position_ids.device) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + ) + + output = Qwen2VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + return output if return_dict else output.to_tuple() + + +class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^visual": "model.visual", + r"^model(?!\.(language_model|visual))": "model.language_model", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2VLModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def visual(self): + return self.model.visual + + @can_return_tuple @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -1652,70 +1829,12 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if inputs_embeds is None: - inputs_embeds = 
self.model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.get_dtype()) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_embeds.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - - # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): - # calculate RoPE index once per generation in the pre-fill stage only - if ( - (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - or (past_key_values is None or past_key_values.get_seq_length() == 0) - ): - position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, attention_mask - ) - self.rope_deltas = rope_deltas - # then use the prev pre-calculated rope-deltas to get the correct position ids - else: - batch_size, seq_length, _ = inputs_embeds.shape - delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 - position_ids = torch.arange(seq_length, device=inputs_embeds.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) - delta = delta.to(position_ids.device) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - outputs = self.model( - input_ids=None, + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, @@ -1732,22 +1851,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels 
= shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) return Qwen2VLCausalLMOutputWithPast( loss=loss, @@ -1755,7 +1859,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - rope_deltas=self.rope_deltas, + rope_deltas=outputs.rope_deltas, ) def prepare_inputs_for_generation( @@ -1921,5 +2025,61 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): return input_ids, model_kwargs + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. -__all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel"] + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel", "Qwen2VLTextModel"] diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 70ebef344c3..d40dd4db887 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -28,11 +28,12 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torchdynamo_compiling, logging, replace_return_docstrings, ) -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_video_llava import VideoLlavaConfig @@ -41,6 +42,47 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "VideoLlavaConfig" +@dataclass +class VideoLlavaModelOutputWithPast(ModelOutput): + """ + Base class for VideoLlava base model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
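# --- Illustrative sketch, not part of the diff: the 4D causal mask helper re-added above. ---
# `_prepare_4d_causal_attention_mask_with_cache_position` turns a 2D padding mask into the
# additive 4D mask used with static caches. A trimmed standalone version of the same recipe
# (toy sizes; the real method also handles an already-4D mask and device placement):
import torch

def make_4d_causal_mask(attention_mask, sequence_length, target_length, dtype, cache_position):
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)  # block attention to future positions
    # unmask key positions at or before each query's cache position
    mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    mask = mask[None, None, :, :].expand(attention_mask.shape[0], 1, -1, -1).clone()
    limit = attention_mask.shape[-1]
    padded = mask[:, :, :, :limit] + attention_mask[:, None, None, :]
    mask[:, :, :, :limit] = mask[:, :, :, :limit].masked_fill(padded == 0, min_dtype)
    return mask

example = make_4d_causal_mask(
    attention_mask=torch.tensor([[1.0, 1.0, 1.0, 0.0]]),  # last key position is padding
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    cache_position=torch.arange(4),
)
# `example` is additive: 0.0 where attention is allowed, dtype-min where it is blocked
# --- end of sketch ---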
+ image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. + video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + video_hidden_states: Optional[torch.FloatTensor] = None + + @dataclass class VideoLlavaCausalLMOutputWithPast(ModelOutput): """ @@ -130,7 +172,7 @@ VIDEO_LLAVA_START_DOCSTRING = r""" ) class VideoLlavaPreTrainedModel(PreTrainedModel): config_class = VideoLlavaConfig - base_model_prefix = "model" + base_model_prefix = "" supports_gradient_checkpointing = True _no_split_modules = ["VideoLlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" @@ -242,10 +284,12 @@ VIDEO_LLAVA_INPUTS_DOCSTRING = r""" @add_start_docstrings( - """The VideoLlava model which consists of a vision backbone and a language model.""", + """The VideoLlava model which consists of a vision backbone and a language model without language modeling head.""", VIDEO_LLAVA_START_DOCSTRING, ) -class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin): +class VideoLlavaModel(VideoLlavaPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + def __init__(self, config: VideoLlavaConfig): super().__init__(config) self.video_tower = AutoModel.from_config(config.vision_config) @@ -253,10 +297,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi self.multi_modal_projector = VideoLlavaMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - + self.language_model = AutoModel.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() @@ -266,18 +307,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def get_image_features( self, pixel_values_images: torch.FloatTensor, @@ -358,13 +387,165 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi return video_features, num_frames + @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values_images: torch.FloatTensor = None, + pixel_values_videos: torch.FloatTensor = None, + 
attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **lm_kwargs, + ) -> Union[Tuple, VideoLlavaModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None: + raise ValueError( + "You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same " + "time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values_images is not None: + image_features = self.get_image_features( + pixel_values_images, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + if pixel_values_videos is not None: + video_features, num_frames = self.get_video_features( + pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer + ) + + special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): + n_video_tokens = (input_ids == self.config.video_token_id).sum() + n_video_features = video_features.shape[0] * video_features.shape[1] + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, 
video_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + output = VideoLlavaModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values_images is not None else None, + video_hidden_states=video_features if pixel_values_videos is not None else None, + ) + return output if return_dict else output.to_tuple() + + +@add_start_docstrings( + """The VideoLlava model which consists of a vision backbone and a language model.""", + VIDEO_LLAVA_START_DOCSTRING, +) +class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^image_tower": "model.image_tower", + "^video_tower": "model.video_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: VideoLlavaConfig): + super().__init__(config) + self.model = VideoLlavaModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def video_tower(self): + return self.model.video_tower + + @property + def image_tower(self): + return self.model.image_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values_images: Optional[torch.FloatTensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values_images: torch.FloatTensor = None, + pixel_values_videos: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -475,88 +656,32 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi else self.config.vision_feature_select_strategy ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None: - raise ValueError( - "You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same " - "time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = 
self.get_input_embeddings()(input_ids) - - if pixel_values_images is not None: - image_features = self.get_image_features( - pixel_values_images, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - ) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - if pixel_values_videos is not None: - video_features, num_frames = self.get_video_features( - pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer - ) - - special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): - n_video_tokens = (input_ids == self.config.video_token_id).sum() - n_video_features = video_features.shape[0] * video_features.shape[1] - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) - - outputs = self.language_model( + outputs = self.model( + input_ids=input_ids, + pixel_values_images=pixel_values_images, + pixel_values_videos=pixel_values_videos, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, - logits_to_keep=logits_to_keep, **lm_kwargs, ) - logits = outputs[0] + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return VideoLlavaCausalLMOutputWithPast( loss=loss, @@ -564,8 +689,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values_images is not None else None, - video_hidden_states=video_features if pixel_values_videos is not None else None, + image_hidden_states=outputs.image_hidden_states, + video_hidden_states=outputs.video_hidden_states, ) def prepare_inputs_for_generation( @@ -582,7 +707,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -600,5 +725,61 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi return model_inputs + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. -__all__ = ["VideoLlavaPreTrainedModel", "VideoLlavaForConditionalGeneration"] + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["VideoLlavaPreTrainedModel", "VideoLlavaModel", "VideoLlavaForConditionalGeneration"] diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 9243bbe9e2d..49169593320 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/vipllava/modular_vipllava.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_vipllava.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2023 the HuggingFace Inc. team. All rights reserved. # @@ -12,37 +18,65 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch VipLlava model.""" from dataclasses import dataclass from typing import List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_outputs import ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torchdynamo_compiling, - logging, replace_return_docstrings, ) -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_vipllava import VipLlavaConfig -logger = logging.get_logger(__name__) - _CONFIG_FOR_DOC = "VipLlavaConfig" @dataclass -# Copied from transformers.models.llava.modeling_llava.LlavaCausalLMOutputWithPast with Llava->VipLlava +class VipLlavaModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for VipLlava outputs, with hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass class VipLlavaCausalLMOutputWithPast(ModelOutput): """ Base class for VipLlava causal language model (or autoregressive) outputs. @@ -70,7 +104,7 @@ class VipLlavaCausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
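The new headless base classes return these enriched outputs directly, so downstream code can read the projected vision features without going through the generation head. A minimal usage sketch, assuming the `llava-hf/vip-llava-7b-hf` checkpoint referenced later in this patch, a simplified prompt format, and enough GPU memory:

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, VipLlavaModel  # VipLlavaModel is the new headless class

processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")
model = VipLlavaModel.from_pretrained(
    "llava-hf/vip-llava-7b-hf", torch_dtype=torch.float16, device_map="auto"
)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text="<image>\nDescribe this image.", images=image, return_tensors="pt")
inputs = inputs.to(model.device, torch.float16)

with torch.no_grad():
    outputs = model(**inputs)

# No logits here: the base model stops at the language model's last hidden layer.
print(outputs.last_hidden_state.shape)    # (batch_size, sequence_length, hidden_size)
print(outputs.image_hidden_states.shape)  # projected image features, as documented above
```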
""" @@ -129,9 +163,8 @@ VIPLLAVA_START_DOCSTRING = r""" ) class VipLlavaPreTrainedModel(PreTrainedModel): config_class = VipLlavaConfig - base_model_prefix = "model" + base_model_prefix = "" supports_gradient_checkpointing = True - _no_split_modules = ["VipLlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True @@ -140,6 +173,9 @@ class VipLlavaPreTrainedModel(PreTrainedModel): _supports_static_cache = True def _init_weights(self, module): + # important: this ported version of VipLlava isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + # https://github.com/haotian-liu/VipLlava/tree/main/vipllava should serve for that purpose std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, nn.Linear): @@ -147,8 +183,8 @@ class VipLlavaPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() module.weight.data.fill_(1.0) + module.bias.data.zero_() VIPLLAVA_INPUTS_DOCSTRING = r""" @@ -163,7 +199,7 @@ VIPLLAVA_INPUTS_DOCSTRING = r""" [What are input IDs?](../glossary#input-ids) pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`VipLlavaProcessor`] uses [`CLIPImageProcessor`] for processing images). attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: @@ -203,6 +239,13 @@ VIPLLAVA_INPUTS_DOCSTRING = r""" Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
@@ -222,24 +265,18 @@ VIPLLAVA_INPUTS_DOCSTRING = r""" @add_start_docstrings( - """The VIPLLAVA model which consists of a vision backbone and a language model.""", + """The VipLlava model which consists of a vision backbone and a language model, without a language modeling head.""", VIPLLAVA_START_DOCSTRING, ) -# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration with LLAVA->VIPLLAVA,Llava->VipLlava -class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin): +class VipLlavaModel(VipLlavaPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + def __init__(self, config: VipLlavaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = VipLlavaMultiModalProjector(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - + self.language_model = AutoModel.from_config(config.text_config) self.post_init() def get_input_embeddings(self): @@ -248,19 +285,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - - # Ignore copy def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: Union[int, List[int]]): """ Obtains image last hidden states from the vision tower and apply multimodal projection. 
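`_checkpoint_conversion_mapping` (here `{"language_model.model": "language_model"}` on the base class, and regex-anchored entries such as `"^language_model.model": "model.language_model"` on the head class) is what keeps checkpoints saved under the old single-class layout loadable. A simplified sketch of the idea as plain regex substitution over state-dict keys; the real remapping runs inside `from_pretrained` and covers more cases, and the listed keys are illustrative:

```python
import re

# Keys as they might appear in a checkpoint saved with the old single-class layout (illustrative).
old_state_dict_keys = [
    "language_model.model.embed_tokens.weight",
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "language_model.lm_head.weight",
    "vision_tower.vision_model.embeddings.patch_embedding.weight",
    "multi_modal_projector.linear_1.weight",
]

# Mapping carried by VipLlavaForConditionalGeneration in this patch.
conversion_mapping = {
    "^language_model.model": "model.language_model",
    "^vision_tower": "model.vision_tower",
    "^multi_modal_projector": "model.multi_modal_projector",
    "^language_model.lm_head": "lm_head",
}

def convert_key(key, mapping):
    # Apply the first matching rule, then stop.
    for pattern, replacement in mapping.items():
        key, n_replaced = re.subn(pattern, replacement, key)
        if n_replaced > 0:
            break
    return key

for key in old_state_dict_keys:
    print(f"{key}  ->  {convert_key(key, conversion_mapping)}")
# e.g. language_model.model.layers.0...  ->  model.language_model.layers.0...
#      language_model.lm_head.weight     ->  lm_head.weight
```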
@@ -288,12 +312,132 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) return image_features @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - # Ignore copy def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layers: Optional[Union[int, List[int]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **lm_kwargs, + ) -> Union[Tuple, VipLlavaModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layers = ( + vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, vision_feature_layers=vision_feature_layers + ) + + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + output = VipLlavaModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + + +@add_start_docstrings( + """The VIPLLAVA model which 
consists of a vision backbone and a language model.""", + VIPLLAVA_START_DOCSTRING, +) +class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: VipLlavaConfig): + super().__init__(config) + self.model = VipLlavaModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -358,68 +502,30 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, vision_feature_layers=vision_feature_layers - ) - - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - outputs = self.language_model( + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, + vision_feature_layers=vision_feature_layers, 
output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, - logits_to_keep=logits_to_keep, **lm_kwargs, ) - logits = outputs[0] + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return VipLlavaCausalLMOutputWithPast( loss=loss, @@ -427,7 +533,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, + image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( @@ -443,7 +549,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -460,5 +566,60 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) return model_inputs + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. -__all__ = ["VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"] + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. 
+ batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["VipLlavaModel", "VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"] diff --git a/src/transformers/models/vipllava/modular_vipllava.py b/src/transformers/models/vipllava/modular_vipllava.py new file mode 100644 index 00000000000..93fdecffec9 --- /dev/null +++ b/src/transformers/models/vipllava/modular_vipllava.py @@ -0,0 +1,294 @@ +# coding=utf-8 +# Copyright 2023 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
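Both generation classes above carry the same `_prepare_4d_causal_attention_mask_with_cache_position` helper. The standalone reimplementation below (device handling dropped, otherwise mirroring the code above) shows what it produces for a small left-padded batch:

```python
import torch

def build_4d_causal_mask(attention_mask, sequence_length, target_length, dtype, cache_position, batch_size):
    # Simplified copy of the helper above, for inspection outside the model (CPU only).
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # expand() returns a view; clone before the in-place edit
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )
    return causal_mask

# Batch of two length-4 prompts, the first one left-padded by one token; static cache of length 6.
attention_mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])
mask = build_4d_causal_mask(
    attention_mask, sequence_length=4, target_length=6,
    dtype=torch.float32, cache_position=torch.arange(4), batch_size=2,
)
print(mask.shape)        # torch.Size([2, 1, 4, 6])
print(mask[0, 0] == 0)   # True where a query position may attend, False where the additive mask blocks it
```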
+ +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from transformers.models.llava.modeling_llava import ( + LlavaCausalLMOutputWithPast, + LlavaForConditionalGeneration, + LlavaModel, + LlavaModelOutputWithPast, + LlavaPreTrainedModel, +) + +from ...activations import ACT2FN +from ...utils import ( + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from .configuration_vipllava import VipLlavaConfig + + +logger = logging.get_logger(__name__) + + +VIPLLAVA_INPUTS_DOCSTRING = None +_CONFIG_FOR_DOC = "VipLlavaConfig" + + +class VipLlavaModelOutputWithPast(LlavaModelOutputWithPast): + pass + + +class VipLlavaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): + pass + + +class VipLlavaMultiModalProjector(nn.Module): + def __init__(self, config: VipLlavaConfig): + super().__init__() + num_feature_layers = 1 if isinstance(config.vision_feature_layers, int) else len(config.vision_feature_layers) + self.projector_layernorm = nn.LayerNorm( + num_feature_layers * config.vision_config.hidden_size, eps=config.projector_layernorm_eps + ) + + self.linear_1 = nn.Linear( + num_feature_layers * config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + + def forward(self, hidden_states): + hidden_states = self.projector_layernorm(hidden_states) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class VipLlavaPreTrainedModel(LlavaPreTrainedModel): + pass + + +class VipLlavaModel(LlavaModel): + def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: Union[int, List[int]]): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_layers (`Union[int, List[int]]`): + The vision feature layer, or the list of indexes of the layers to select + the vision feature. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + + # If multiple feature layers are provided (which is usually the case) + # then the image features are concatenated after the CLS is removed. 
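`VipLlavaMultiModalProjector` above sizes its LayerNorm and first linear layer to `num_feature_layers * vision_hidden_size`, because it consumes the concatenation of several vision-tower layers that `get_image_features` builds. A quick shape check with stand-in dimensions (the numbers and the GELU activation are assumptions, not read from a real config):

```python
import torch
from torch import nn

vision_hidden, text_hidden = 1024, 4096       # stand-ins for the vision/text config hidden sizes
vision_feature_layers = [-2, -5, -8, -11, 6]  # the layers mentioned in the comment above
num_feature_layers = len(vision_feature_layers)

projector = nn.Sequential(
    nn.LayerNorm(num_feature_layers * vision_hidden),                       # projector_layernorm
    nn.Linear(num_feature_layers * vision_hidden, text_hidden, bias=True),  # linear_1
    nn.GELU(),                                                              # stand-in for ACT2FN[...]
    nn.Linear(text_hidden, text_hidden, bias=True),                         # linear_2
)

# Features from 5 vision layers, concatenated on the last dim, for 2 images with 576 patches each.
image_features = torch.randn(2, 576, num_feature_layers * vision_hidden)
print(projector(image_features).shape)  # torch.Size([2, 576, 4096]) -> ready to scatter into the text embeddings
```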
+ if isinstance(vision_feature_layers, int): + image_features = image_outputs.hidden_states[vision_feature_layers][:, 1:] + else: + # Usually, we select the features from index 1: the layers -2, -5, -8, -11 and 6 + image_features = [image_outputs.hidden_states[index][:, 1:] for index in vision_feature_layers] + image_features = torch.cat(image_features, dim=-1) + image_features = self.multi_modal_projector(image_features) + return image_features + + @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layers: Optional[Union[int, List[int]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **lm_kwargs, + ) -> Union[Tuple, VipLlavaModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layers = ( + vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, vision_feature_layers=vision_feature_layers + ) + + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + output = VipLlavaModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + 
return output if return_dict else output.to_tuple() + + +class VipLlavaForConditionalGeneration(LlavaForConditionalGeneration): + @can_return_tuple + @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layers: Optional[Union[int, List[int]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[Tuple, VipLlavaCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + + Returns: + + Example: + + ```python + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, VipLlavaForConditionalGeneration + + >>> model = VipLlavaForConditionalGeneration.from_pretrained("llava-hf/vip-llava-7b-hf", device_map="auto", torch_dtype=torch.float16) + >>> processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf") + + >>> prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: \n{}###Assistant:" + >>> question = "Can you please describe this image?" 
+ >>> prompt = prompt.format(question) + >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=text, images=image, return_tensors="pt").to(0, torch.float16) + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=20) + >>> processor.decode(generate_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True) + The image features a brown and white cat sitting on a green surface, with a red ball in its + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layers = ( + vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + vision_feature_layers=vision_feature_layers, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return VipLlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + +__all__ = ["VipLlavaModel", "VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"] diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py index 2f802d9b70e..3baa74a8bdc 100644 --- a/tests/models/aria/test_modeling_aria.py +++ b/tests/models/aria/test_modeling_aria.py @@ -21,6 +21,7 @@ import requests from transformers import ( AriaConfig, AriaForConditionalGeneration, + AriaModel, AriaTextConfig, AutoProcessor, AutoTokenizer, @@ -175,7 +176,7 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi Model tester for `AriaForConditionalGeneration`. 
""" - all_model_classes = (AriaForConditionalGeneration,) if is_torch_available() else () + all_model_classes = (AriaModel, AriaForConditionalGeneration) if is_torch_available() else () test_pruning = False test_head_masking = False _is_composite = True @@ -281,6 +282,18 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi def test_generate_from_inputs_embeds_with_static_cache(self): pass + @unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading") + def test_cpu_offload(self): + pass + + @unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading") + def test_disk_offload_bin(self): + pass + + @unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading") + def test_disk_offload_safetensors(self): + pass + @require_torch class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/aya_vision/test_modeling_aya_vision.py b/tests/models/aya_vision/test_modeling_aya_vision.py index c35058abd62..f4a464ed621 100644 --- a/tests/models/aya_vision/test_modeling_aya_vision.py +++ b/tests/models/aya_vision/test_modeling_aya_vision.py @@ -46,6 +46,7 @@ if is_torch_available(): from transformers import ( AyaVisionForConditionalGeneration, + AyaVisionModel, ) @@ -158,7 +159,14 @@ class AyaVisionVisionText2TextModelTester: @require_torch class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (AyaVisionForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + AyaVisionModel, + AyaVisionForConditionalGeneration, + ) + if is_torch_available() + else () + ) all_generative_model_classes = (AyaVisionForConditionalGeneration,) if is_torch_available() else () pipeline_model_mapping = ( { diff --git a/tests/models/emu3/test_modeling_emu3.py b/tests/models/emu3/test_modeling_emu3.py index b27b1d4c708..b6d93ef73f2 100644 --- a/tests/models/emu3/test_modeling_emu3.py +++ b/tests/models/emu3/test_modeling_emu3.py @@ -46,6 +46,7 @@ if is_torch_available(): from transformers import ( Emu3ForCausalLM, Emu3ForConditionalGeneration, + Emu3Model, Emu3Processor, Emu3TextModel, ) @@ -310,7 +311,14 @@ class Emu3Vision2TextModelTester: @require_torch class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (Emu3ForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + Emu3Model, + Emu3ForConditionalGeneration, + ) + if is_torch_available() + else () + ) pipeline_model_mapping = {} test_headmasking = False test_pruning = False @@ -395,6 +403,10 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline def test_generate_with_static_cache(self): pass + @unittest.skip("Emu3 doesn't support Flex attn yet!") + def test_flex_attention_with_grads(self): + pass + @require_torch class Emu3IntegrationTest(unittest.TestCase): diff --git a/tests/models/fuyu/test_modeling_fuyu.py b/tests/models/fuyu/test_modeling_fuyu.py index 06f0171e46a..1be8d9fcca8 100644 --- a/tests/models/fuyu/test_modeling_fuyu.py +++ b/tests/models/fuyu/test_modeling_fuyu.py @@ -38,7 +38,7 @@ if is_torch_available() and is_vision_available(): if is_torch_available(): - from transformers import FuyuForCausalLM + from transformers import FuyuForCausalLM, FuyuModel class FuyuModelTester: @@ -145,7 +145,14 @@ class FuyuModelTester: @require_torch class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, 
PipelineTesterMixin, unittest.TestCase): - all_model_classes = (FuyuForCausalLM,) if is_torch_available() else () + all_model_classes = ( + ( + FuyuModel, + FuyuForCausalLM, + ) + if is_torch_available() + else () + ) pipeline_model_mapping = ( {"text-generation": FuyuForCausalLM, "image-text-to-text": FuyuForCausalLM} if is_torch_available() else {} ) diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index 9638ff4a87f..efd6f4095cd 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -50,6 +50,7 @@ if is_torch_available(): from transformers import ( Gemma3ForCausalLM, Gemma3ForConditionalGeneration, + Gemma3Model, Gemma3Processor, Gemma3TextModel, ) @@ -148,9 +149,9 @@ class Gemma3Vision2TextModelTester: self, parent, mm_tokens_per_image=2, - image_token_index=1, - boi_token_index=2, - eoi_token_index=3, + image_token_index=4, + boi_token_index=5, + eoi_token_index=6, seq_length=25, is_training=True, vision_config={ @@ -242,7 +243,14 @@ class Gemma3Vision2TextModelTester: @require_torch class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): - all_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + Gemma3Model, + Gemma3ForConditionalGeneration, + ) + if is_torch_available() + else () + ) all_generative_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else () test_headmasking = False test_pruning = False diff --git a/tests/models/got_ocr2/test_modeling_got_ocr2.py b/tests/models/got_ocr2/test_modeling_got_ocr2.py index ed0a25f7b19..e4706b6db0c 100644 --- a/tests/models/got_ocr2/test_modeling_got_ocr2.py +++ b/tests/models/got_ocr2/test_modeling_got_ocr2.py @@ -34,6 +34,7 @@ if is_torch_available(): from transformers import ( GotOcr2ForConditionalGeneration, + GotOcr2Model, ) @@ -140,7 +141,14 @@ class GotOcr2VisionText2TextModelTester: @require_torch class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (GotOcr2ForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + GotOcr2Model, + GotOcr2ForConditionalGeneration, + ) + if is_torch_available() + else () + ) pipeline_model_mapping = ( { "image-to-text": GotOcr2ForConditionalGeneration, @@ -228,6 +236,10 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi def test_past_key_values_format(self): pass + @unittest.skip(reason="Vision backbone doesn't support FLEX yet!") + def test_flex_attention_with_grads(self): + pass + @require_torch class GotOcr2IntegrationTest(unittest.TestCase): diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index f47330ec673..b154a09c2bc 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -54,7 +54,7 @@ if is_torch_available(): import torch from torch import nn - from transformers import InstructBlipForConditionalGeneration, InstructBlipVisionModel + from transformers import InstructBlipForConditionalGeneration, InstructBlipModel, InstructBlipVisionModel if is_vision_available(): @@ -460,14 +460,20 @@ class InstructBlipForConditionalGenerationDecoderOnlyModelTester: "attention_mask": attention_mask, "qformer_input_ids": qformer_input_ids, "qformer_attention_mask": qformer_attention_mask, - "labels": 
input_ids, } return config, inputs_dict @require_torch class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): - all_model_classes = (InstructBlipForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + InstructBlipModel, + InstructBlipForConditionalGeneration, + ) + if is_torch_available() + else () + ) pipeline_model_mapping = {"image-text-to-text": InstructBlipForConditionalGeneration} fx_compatible = False test_head_masking = False diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py index a4b7f25e347..9bd617b4666 100644 --- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py @@ -54,7 +54,11 @@ if is_torch_available(): import torch from torch import nn - from transformers import InstructBlipVideoForConditionalGeneration, InstructBlipVideoVisionModel + from transformers import ( + InstructBlipVideoForConditionalGeneration, + InstructBlipVideoModel, + InstructBlipVideoVisionModel, + ) class InstructBlipVideoVisionModelTester: @@ -477,7 +481,6 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyModelTester: "attention_mask": attention_mask, "qformer_input_ids": qformer_input_ids, "qformer_attention_mask": qformer_attention_mask, - "labels": input_ids, } return config, inputs_dict @@ -486,7 +489,9 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyModelTester: class InstructBlipVideoForConditionalGenerationDecoderOnlyTest( ModelTesterMixin, GenerationTesterMixin, unittest.TestCase ): - all_model_classes = (InstructBlipVideoForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel) if is_torch_available() else () + ) fx_compatible = False test_head_masking = False test_pruning = False diff --git a/tests/models/internvl/test_modeling_internvl.py b/tests/models/internvl/test_modeling_internvl.py index e51126f14ea..bb85f0cba41 100644 --- a/tests/models/internvl/test_modeling_internvl.py +++ b/tests/models/internvl/test_modeling_internvl.py @@ -47,9 +47,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): import torch - from transformers import ( - InternVLForConditionalGeneration, - ) + from transformers import InternVLForConditionalGeneration, InternVLModel if is_vision_available(): @@ -191,7 +189,7 @@ class InternVLVisionText2TextModelTester: @require_torch class InternVLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (InternVLForConditionalGeneration,) if is_torch_available() else () + all_model_classes = (InternVLForConditionalGeneration, InternVLModel) if is_torch_available() else () all_generative_model_classes = (InternVLForConditionalGeneration,) if is_torch_available() else () pipeline_model_mapping = ( { diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index a692566340c..5d60f6248b1 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -13,6 +13,7 @@ # limitations under the License. 
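A recurring pattern in the test updates above is that `labels` are dropped from the shared inputs: the new `*Model` classes end at the language model's hidden states and accept no labels, while the `*ForConditionalGeneration` wrappers add `lm_head`, the optional `logits_to_keep` slicing, and the loss. The toy modules below sketch that split with hypothetical names and sizes (the shifted cross-entropy stands in for the library's `loss_function` and only lines up when all logits are kept):

```python
import torch
from torch import nn

vocab_size, hidden_size, seq_len, batch = 100, 32, 10, 2

class TinyBaseModel(nn.Module):
    """Stands in for e.g. LlavaModel: returns hidden states, knows nothing about labels or loss."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.layer = nn.TransformerEncoderLayer(hidden_size, nhead=4, batch_first=True)

    def forward(self, input_ids):
        return self.layer(self.embed(input_ids))  # (batch, seq_len, hidden_size)

class TinyForConditionalGeneration(nn.Module):
    """Stands in for the head class: wraps the base model and adds lm_head plus the loss."""
    def __init__(self):
        super().__init__()
        self.model = TinyBaseModel()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, labels=None, logits_to_keep=0):
        hidden_states = self.model(input_ids)
        # Same trick as in the patch: only run lm_head on the positions that are needed.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            # Next-token objective: token t predicts token t + 1 (valid here because all logits are kept).
            loss = nn.functional.cross_entropy(
                logits[:, :-1, :].reshape(-1, vocab_size), labels[:, 1:].reshape(-1)
            )
        return logits, loss

input_ids = torch.randint(0, vocab_size, (batch, seq_len))
hidden = TinyBaseModel()(input_ids)                                  # base model: hidden states only
logits, loss = TinyForConditionalGeneration()(input_ids, labels=input_ids)
print(hidden.shape, logits.shape, loss.item())
```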
"""Testing suite for the PyTorch Llava model.""" +import copy import unittest import requests @@ -23,6 +24,7 @@ from transformers import ( AutoTokenizer, LlavaConfig, LlavaForConditionalGeneration, + LlavaModel, is_torch_available, is_vision_available, ) @@ -166,7 +168,14 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM Model tester for `LlavaForConditionalGeneration`. """ - all_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + LlavaModel, + LlavaForConditionalGeneration, + ) + if is_torch_available() + else () + ) pipeline_model_mapping = ( {"image-to-text": LlavaForConditionalGeneration, "image-text-to-text": LlavaForConditionalGeneration} if is_torch_available() @@ -238,16 +247,17 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config).to(torch_device) - _ = model(**input_dict) # successful forward with no modifications + curr_input_dict = copy.deepcopy(input_dict) # in=place modifications further + _ = model(**curr_input_dict) # successful forward with no modifications # remove one image but leave the image token in text - input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...] + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...] with self.assertRaises(ValueError): - _ = model(**input_dict) + _ = model(**curr_input_dict) # simulate multi-image case by concatenating inputs where each has exactly one image/image-token - input_ids = input_dict["input_ids"][:1] - pixel_values = input_dict["pixel_values"][:1] + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:1] input_ids = torch.cat([input_ids, input_ids], dim=0) # one image and two image tokens raise an error @@ -281,7 +291,8 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM model = model_class(config).to(torch_device) # We should have the right number of input features, # and should be able to run a forward pass without exploding - assert model.multi_modal_projector.linear_1.in_features == expected_features + base_model = getattr(model, "model", model) + assert base_model.multi_modal_projector.linear_1.in_features == expected_features model(**input_dict) @unittest.skip( diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index 95f5eaa083c..c8789f0ba38 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -13,6 +13,7 @@ # limitations under the License. """Testing suite for the PyTorch Llava-NeXT model.""" +import copy import unittest import requests @@ -23,6 +24,7 @@ from transformers import ( AutoProcessor, LlavaNextConfig, LlavaNextForConditionalGeneration, + LlavaNextModel, is_torch_available, is_vision_available, ) @@ -181,7 +183,14 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes Model tester for `LlavaNextForConditionalGeneration`. 
""" - all_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + LlavaNextModel, + LlavaNextForConditionalGeneration, + ) + if is_torch_available() + else () + ) pipeline_model_mapping = {"image-text-to-text": LlavaNextForConditionalGeneration} if is_torch_available() else {} test_pruning = False test_head_masking = False @@ -265,18 +274,19 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config).to(torch_device) - _ = model(**input_dict) # successful forward with no modifications + curr_input_dict = copy.deepcopy(input_dict) # in=place modifications further + _ = model(**curr_input_dict) # successful forward with no modifications # remove one image but leave the image token in text - input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...] - input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...] + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...] + curr_input_dict["image_sizes"] = curr_input_dict["image_sizes"][-1:, ...] with self.assertRaises(ValueError): - _ = model(**input_dict) + _ = model(**curr_input_dict) # simulate multi-image case by concatenating inputs where each has exactly one image/image-token - input_ids = input_dict["input_ids"][:1] - pixel_values = input_dict["pixel_values"][:1] - image_sizes = input_dict["image_sizes"][:1] + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:1] + image_sizes = curr_input_dict["image_sizes"][:1] input_ids = torch.cat([input_ids, input_ids], dim=0) # one image and two image tokens raise an error @@ -324,7 +334,8 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes model = model_class(config).to(torch_device) # We should have the right number of input features, # and should be able to run a forward pass without exploding - assert model.multi_modal_projector.linear_1.in_features == expected_features + base_model = getattr(model, "model", model) + assert base_model.multi_modal_projector.linear_1.in_features == expected_features model(**input_dict) @unittest.skip( diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index 628611d689a..e68a1e4362e 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -13,6 +13,7 @@ # limitations under the License. """Testing suite for the PyTorch Llava-NeXT-Video model.""" +import copy import unittest import numpy as np @@ -23,6 +24,7 @@ from transformers import ( AutoProcessor, LlavaNextVideoConfig, LlavaNextVideoForConditionalGeneration, + LlavaNextVideoModel, is_torch_available, is_vision_available, ) @@ -196,7 +198,14 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati Model tester for `LlavaNextVideoForConditionalGeneration`. 
""" - all_model_classes = (LlavaNextVideoForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + LlavaNextVideoModel, + LlavaNextVideoForConditionalGeneration, + ) + if is_torch_available() + else () + ) test_pruning = False test_head_masking = False _is_composite = True @@ -281,18 +290,19 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config).to(torch_device) - _ = model(**input_dict) # successful forward with no modifications + curr_input_dict = copy.deepcopy(input_dict) # in=place modifications further + _ = model(**curr_input_dict) # successful forward with no modifications # remove one image but leave the image token in text - input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...] - input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...] + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...] + curr_input_dict["image_sizes"] = curr_input_dict["image_sizes"][-1:, ...] with self.assertRaises(ValueError): - _ = model(**input_dict) + _ = model(**curr_input_dict) # simulate multi-image case by concatenating inputs where each has exactly one image/image-token - input_ids = input_dict["input_ids"][:1] - pixel_values = input_dict["pixel_values"][:1] - image_sizes = input_dict["image_sizes"][:1] + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:1] + image_sizes = curr_input_dict["image_sizes"][:1] input_ids = torch.cat([input_ids, input_ids], dim=0) # one image and two image tokens raise an error @@ -340,7 +350,8 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati model = model_class(config).to(torch_device) # We should have the right number of input features, # and should be able to run a forward pass without exploding - assert model.multi_modal_projector.linear_1.in_features == expected_features + base_model = getattr(model, "model", model) + assert base_model.multi_modal_projector.linear_1.in_features == expected_features model(**input_dict) @unittest.skip( diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py index 4ec719d3c54..ba95c330dbd 100644 --- a/tests/models/llava_onevision/test_modeling_llava_onevision.py +++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py @@ -24,6 +24,7 @@ from transformers import ( AutoProcessor, LlavaOnevisionConfig, LlavaOnevisionForConditionalGeneration, + LlavaOnevisionModel, is_torch_available, is_vision_available, ) @@ -182,7 +183,14 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati Model tester for `LlavaOnevisionForConditionalGeneration`. 
""" - all_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + LlavaOnevisionModel, + LlavaOnevisionForConditionalGeneration, + ) + if is_torch_available() + else () + ) pipeline_model_mapping = ( {"image-text-to-text": LlavaOnevisionForConditionalGeneration} if is_torch_available() else {} ) @@ -296,7 +304,8 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati model = model_class(config).to(torch_device) # We should have the right number of input features, # and should be able to run a forward pass without exploding - assert model.multi_modal_projector.linear_1.in_features == expected_features + base_model = getattr(model, "model", model) + assert base_model.multi_modal_projector.linear_1.in_features == expected_features model(**input_dict) @unittest.skip( diff --git a/tests/models/mistral3/test_modeling_mistral3.py b/tests/models/mistral3/test_modeling_mistral3.py index 2555126dbaa..7c8f60fdc0a 100644 --- a/tests/models/mistral3/test_modeling_mistral3.py +++ b/tests/models/mistral3/test_modeling_mistral3.py @@ -42,6 +42,7 @@ if is_torch_available(): from transformers import ( Mistral3ForConditionalGeneration, + Mistral3Model, ) @@ -162,7 +163,14 @@ class Mistral3VisionText2TextModelTester: @require_torch class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + Mistral3Model, + Mistral3ForConditionalGeneration, + ) + if is_torch_available() + else () + ) all_generative_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else () pipeline_model_mapping = ( { @@ -278,6 +286,10 @@ class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM def test_sdpa_can_dispatch_on_flash(self): pass + @unittest.skip("Pixtral does not support attention interfaces.") + def test_flex_attention_with_grads(self): + pass + @slow @require_torch_gpu diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py index 34e4d4e4896..e67e0455e1f 100644 --- a/tests/models/mllama/test_modeling_mllama.py +++ b/tests/models/mllama/test_modeling_mllama.py @@ -25,6 +25,7 @@ from transformers import ( MllamaConfig, MllamaForCausalLM, MllamaForConditionalGeneration, + MllamaModel, is_torch_available, is_vision_available, ) @@ -262,7 +263,14 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester Model tester for `MllamaForConditionalGeneration`. 
""" - all_model_classes = (MllamaForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + MllamaModel, + MllamaForConditionalGeneration, + ) + if is_torch_available() + else () + ) pipeline_model_mapping = {"image-text-to-text": MllamaForConditionalGeneration} if is_torch_available() else () test_pruning = False test_head_masking = False @@ -325,19 +333,18 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester # resizing embeddings should result in successful loss computation config, inputs = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: - model = model_class(config) - model_vocab_size = config.get_text_config().vocab_size - inputs = self._prepare_for_class(inputs, model_class, return_labels=True) - # Resize embeddings and call forward - model.resize_token_embeddings(model_vocab_size + 10) - output = model( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - labels=inputs["labels"], - return_dict=True, - ) - self.assertTrue("loss" in output) + model = MllamaForConditionalGeneration(config).to(torch_device) + model_vocab_size = config.get_text_config().vocab_size + inputs = self._prepare_for_class(inputs, MllamaForConditionalGeneration, return_labels=True) + # Resize embeddings and call forward + model.resize_token_embeddings(model_vocab_size + 10) + output = model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + labels=inputs["labels"], + return_dict=True, + ) + self.assertTrue("loss" in output) def _check_attentions_for_generate( self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values @@ -409,6 +416,18 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester def test_assisted_decoding_with_num_logits_to_keep(self): pass + @unittest.skip(reason="Mllama uses self.weights dirrectly causing device mismatch when offloading`") + def test_cpu_offload(self): + pass + + @unittest.skip(reason="Mllama uses self.weights dirrectly causing device mismatch when offloading`") + def test_disk_offload_bin(self): + pass + + @unittest.skip(reason="Mllama uses self.weights dirrectly causing device mismatch when offloading`") + def test_disk_offload_safetensors(self): + pass + @pytest.mark.generate # overridden because mllama is not an encoder-decoder model, but has encoder-decoder-like cache def test_past_key_values_format(self): @@ -501,7 +520,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester """ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: + for model_class in self.all_generative_model_classes: model = model_class(config) model.to(torch_device) model.eval() diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index 4b9db1d75d1..a4d323baa31 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Testing suite for the PyTorch PaliGemma model.""" +import copy import unittest import requests @@ -20,6 +21,7 @@ import requests from transformers import ( PaliGemmaConfig, PaliGemmaForConditionalGeneration, + PaliGemmaModel, PaliGemmaProcessor, is_torch_available, is_vision_available, @@ -177,7 +179,14 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes Model tester for `PaliGemmaForConditionalGeneration`. """ - all_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + PaliGemmaModel, + PaliGemmaForConditionalGeneration, + ) + if is_torch_available() + else () + ) pipeline_model_mapping = {"image-text-to-text": PaliGemmaForConditionalGeneration} fx_compatible = False test_pruning = False @@ -242,16 +251,17 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config).to(torch_device) - _ = model(**input_dict) # successful forward with no modifications + curr_input_dict = copy.deepcopy(input_dict) # in=place modifications further + _ = model(**curr_input_dict) # successful forward with no modifications # remove one image but leave the image token in text - input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...] + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...] with self.assertRaises(ValueError): - _ = model(**input_dict) + _ = model(**curr_input_dict) # simulate multi-image case by concatenating inputs where each has exactly one image/image-token - input_ids = input_dict["input_ids"][:1] - pixel_values = input_dict["pixel_values"][:1] + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:1] input_ids = torch.cat([input_ids, input_ids], dim=0) # one image and two image tokens raise an error diff --git a/tests/models/paligemma2/test_modeling_paligemma2.py b/tests/models/paligemma2/test_modeling_paligemma2.py index e13430f2b73..bc62c527e29 100644 --- a/tests/models/paligemma2/test_modeling_paligemma2.py +++ b/tests/models/paligemma2/test_modeling_paligemma2.py @@ -13,6 +13,7 @@ # limitations under the License. """Testing suite for the PyTorch PaliGemma model.""" +import copy import unittest import pytest @@ -239,16 +240,17 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config).to(torch_device) - _ = model(**input_dict) # successful forward with no modifications + curr_input_dict = copy.deepcopy(input_dict) # in=place modifications further + _ = model(**curr_input_dict) # successful forward with no modifications # remove one image but leave the image token in text - input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...] + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...] 
with self.assertRaises(ValueError): - _ = model(**input_dict) + _ = model(**curr_input_dict) # simulate multi-image case by concatenating inputs where each has exactly one image/image-token - input_ids = input_dict["input_ids"][:1] - pixel_values = input_dict["pixel_values"][:1] + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:1] input_ids = torch.cat([input_ids, input_ids], dim=0) # one image and two image tokens raise an error diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index 21947dca355..5f06e1c84be 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -13,6 +13,7 @@ # limitations under the License. """Testing suite for the PyTorch Qwen2.5-VL model.""" +import copy import gc import tempfile import unittest @@ -23,6 +24,7 @@ from transformers import ( AutoProcessor, Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLModel, is_torch_available, is_vision_available, ) @@ -180,17 +182,11 @@ class Qwen2_5_VLVisionText2TextModelTester: input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id input_ids[:, self.num_image_tokens] = self.image_token_id input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id - labels = torch.zeros( - (self.batch_size, self.seq_length), - dtype=torch.long, - device=torch_device, - ) inputs_dict = { "pixel_values": pixel_values, "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size), "input_ids": input_ids, "attention_mask": attention_mask, - "labels": labels, } return config, inputs_dict @@ -201,7 +197,14 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test Model tester for `Qwen2_5_VLForConditionalGeneration`. """ - all_model_classes = (Qwen2_5_VLForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + Qwen2_5_VLModel, + Qwen2_5_VLForConditionalGeneration, + ) + if is_torch_available() + else () + ) test_pruning = False test_head_masking = False @@ -236,19 +239,20 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test for model_class in self.all_model_classes: model = model_class(config).to(torch_device) _ = model(**input_dict) # successful forward with no modifications + curr_input_dict = copy.deepcopy(input_dict) # remove one image but leave the image token in text patch_size = config.vision_config.patch_size one_img_length = (self.model_tester.image_size**2) // (patch_size**2) - input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...] - input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...] + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-one_img_length:, ...] + curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][-1:, ...] 
with self.assertRaises(ValueError): - _ = model(**input_dict) + _ = model(**curr_input_dict) # simulate multi-image case by concatenating inputs where each has exactly one image/image-token - input_ids = input_dict["input_ids"][:1] - pixel_values = input_dict["pixel_values"][:one_img_length] - image_grid_thw = input_dict["image_grid_thw"][:1] + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:one_img_length] + image_grid_thw = curr_input_dict["image_grid_thw"][:1] input_ids = torch.cat([input_ids, input_ids], dim=0) # one image and two image tokens raise an error @@ -375,6 +379,29 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test def test_save_load_fast_init_from_base(self): pass + # The multimodal base model embeds will not match ids, due to pixel values. We can't change base test + # because in some models `pixel_values` are required. Will be fixed when we add support for merging `embeds+pixels` + # TODO: @raushan + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + inputs_embeds = model.get_input_embeddings()(input_ids) + + with torch.no_grad(): + out_ids = model(input_ids=input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] + torch.testing.assert_close(out_embeds, out_ids) + @require_torch class Qwen2_5_VLIntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index e213ccd819b..57e112790cc 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -13,6 +13,7 @@ # limitations under the License. """Testing suite for the PyTorch Qwen2-VL model.""" +import copy import gc import unittest @@ -22,6 +23,7 @@ from transformers import ( AutoProcessor, Qwen2VLConfig, Qwen2VLForConditionalGeneration, + Qwen2VLModel, is_torch_available, is_vision_available, ) @@ -169,17 +171,12 @@ class Qwen2VLVisionText2TextModelTester: input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id input_ids[:, self.num_image_tokens] = self.image_token_id input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id - labels = torch.zeros( - (self.batch_size, self.seq_length), - dtype=torch.long, - device=torch_device, - ) + inputs_dict = { "pixel_values": pixel_values, "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size), "input_ids": input_ids, "attention_mask": attention_mask, - "labels": labels, } return config, inputs_dict @@ -190,7 +187,14 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas Model tester for `Qwen2VLForConditionalGeneration`. 
""" - all_model_classes = (Qwen2VLForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + Qwen2VLModel, + Qwen2VLForConditionalGeneration, + ) + if is_torch_available() + else () + ) pipeline_model_mapping = {"image-text-to-text": Qwen2VLForConditionalGeneration} test_pruning = False test_head_masking = False @@ -226,20 +230,21 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config).to(torch_device) - _ = model(**input_dict) # successful forward with no modifications + curr_input_dict = copy.deepcopy(input_dict) + _ = model(**curr_input_dict) # successfull forward with no modifications # remove one image but leave the image token in text patch_size = config.vision_config.patch_size one_img_length = (self.model_tester.image_size**2) // (patch_size**2) - input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...] - input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...] + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-one_img_length:, ...] + curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][-1:, ...] with self.assertRaises(ValueError): - _ = model(**input_dict) + _ = model(**curr_input_dict) # simulate multi-image case by concatenating inputs where each has exactly one image/image-token - input_ids = input_dict["input_ids"][:1] - pixel_values = input_dict["pixel_values"][:one_img_length] - image_grid_thw = input_dict["image_grid_thw"][:1] + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:one_img_length] + image_grid_thw = curr_input_dict["image_grid_thw"][:1] input_ids = torch.cat([input_ids, input_ids], dim=0) # one image and two image tokens raise an error @@ -262,11 +267,11 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas model = model_class(config).to(torch_device) # Generate and make sure rope_deltas are not `None` - self.assertTrue(model.rope_deltas is None) + self.assertTrue(model.model.rope_deltas is None) generation_output = model.generate( **input_dict, max_new_tokens=4, return_dict_in_generate=True, output_logits=True ) - self.assertTrue(model.rope_deltas is not None) + self.assertTrue(model.model.rope_deltas is not None) # Now if we try to do forward pass, we should get new rope logits, because cache is not passed forward_output = model(**input_dict) @@ -320,6 +325,29 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas def test_save_load_fast_init_from_base(self): pass + # The multimodal base model embeds will not match ids, due to pixel values. We can't change base test + # because in some models `pixel_values` are required. 
Will be fixed when we add support for merging `embeds+pixels` + # TODO: @raushan + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + inputs_embeds = model.get_input_embeddings()(input_ids) + + with torch.no_grad(): + out_ids = model(input_ids=input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] + torch.testing.assert_close(out_embeds, out_ids) + @require_torch class Qwen2VLIntegrationTest(unittest.TestCase): diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index bda728b3919..92ad5a193bf 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -13,6 +13,7 @@ # limitations under the License. """Testing suite for the PyTorch VideoLlava model.""" +import copy import unittest import numpy as np @@ -23,6 +24,7 @@ from parameterized import parameterized from transformers import ( VideoLlavaConfig, VideoLlavaForConditionalGeneration, + VideoLlavaModel, VideoLlavaProcessor, is_torch_available, is_vision_available, @@ -190,7 +192,14 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe Model tester for `VideoLlavaForConditionalGeneration`. """ - all_model_classes = (VideoLlavaForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + VideoLlavaModel, + VideoLlavaForConditionalGeneration, + ) + if is_torch_available() + else () + ) fx_compatible = False test_pruning = False test_resize_embeddings = True @@ -235,46 +244,49 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe def test_mixed_input(self): config, inputs = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: + curr_inputs = copy.deepcopy(inputs) model = model_class(config).to(torch_device).eval() # test that the forward does not fail with torch.no_grad(): - _ = model(**inputs) + _ = model(**curr_inputs) # if we remove some images from inputs leaving only one # image number mismatch error should raise - inputs["pixel_values_images"] = inputs["pixel_values_images"][:1] + curr_inputs["pixel_values_images"] = curr_inputs["pixel_values_images"][:1] with self.assertRaises(ValueError): - _ = model(**inputs) + _ = model(**curr_inputs) def test_video_only_input(self): config, inputs = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: + curr_inputs = copy.deepcopy(inputs) model = model_class(config).to(torch_device).eval() # replace image token id with dummy id # Error will be raised as num-image-tokens and num-of-image-embeds mismatch - inputs["input_ids"][:, : self.model_tester.num_image_tokens] = 2 + curr_inputs["input_ids"][:, : self.model_tester.num_image_tokens] = 2 with self.assertRaises(ValueError): - _ = model(**inputs) + _ = model(**curr_inputs) - inputs["pixel_values_images"] = None - _ = model(**inputs) + curr_inputs["pixel_values_images"] = None + _ = model(**curr_inputs) def test_image_only_input(self): config, inputs = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: + 
curr_inputs = copy.deepcopy(inputs) model = model_class(config).to(torch_device).eval() # set dummy id, which is not video token id # Error will be raised as num-video-tokens and num-of-video-embeds mismatch - inputs["input_ids"][ + curr_inputs["input_ids"][ :, self.model_tester.num_image_tokens : self.model_tester.num_image_tokens + self.model_tester.num_video_tokens, ] = 2 with self.assertRaises(ValueError): - _ = model(**inputs) + _ = model(**curr_inputs) - inputs["pixel_values_videos"] = None - _ = model(**inputs) + curr_inputs["pixel_values_videos"] = None + _ = model(**curr_inputs) def test_batching_equivalence(self): def recursive_check(batched_object, single_row_object, model_name, key): @@ -386,16 +398,17 @@ class VideoLlavaForConditionalGenerationModelTe config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config).to(torch_device) - _ = model(**input_dict) # successful forward with no modifications + curr_input_dict = copy.deepcopy(input_dict) + _ = model(**curr_input_dict) # successful forward with no modifications # remove one image but leave the image token in text - input_dict["pixel_values_images"] = input_dict["pixel_values_images"][-1:, ...] + curr_input_dict["pixel_values_images"] = curr_input_dict["pixel_values_images"][-1:, ...] with self.assertRaises(ValueError): - _ = model(**input_dict) + _ = model(**curr_input_dict) # simulate multi-image case by concatenating inputs where each has exactly one image/image-token - input_ids = input_dict["input_ids"][:1] - pixel_values = input_dict["pixel_values_images"][:1] + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values_images"][:1] input_ids = torch.cat([input_ids, input_ids], dim=0) # one image and two image tokens raise an error @@ -429,7 +442,8 @@ class VideoLlavaForConditionalGenerationModelTe model = model_class(config).to(torch_device) # We should have the right number of input features, # and should be able to run a forward pass without exploding - assert model.multi_modal_projector.linear_1.in_features == expected_features + base_model = getattr(model, "model", model) + assert base_model.multi_modal_projector.linear_1.in_features == expected_features model(**input_dict) diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py index 60460ddfb5b..79b57e12ffa 100644 --- a/tests/models/vipllava/test_modeling_vipllava.py +++ b/tests/models/vipllava/test_modeling_vipllava.py @@ -13,6 +13,7 @@ # limitations under the License. """Testing suite for the PyTorch VipLlava model.""" +import copy import unittest import requests @@ -22,6 +23,7 @@ from transformers import ( AutoProcessor, VipLlavaConfig, VipLlavaForConditionalGeneration, + VipLlavaModel, is_torch_available, is_vision_available, ) @@ -165,7 +167,14 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest Model tester for `VipLlavaForConditionalGeneration`.
""" - all_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else () + all_model_classes = ( + ( + VipLlavaModel, + VipLlavaForConditionalGeneration, + ) + if is_torch_available() + else () + ) pipeline_model_mapping = {"image-text-to-text": VipLlavaForConditionalGeneration} if is_torch_available() else {} fx_compatible = False test_pruning = False @@ -236,16 +245,17 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config).to(torch_device) - _ = model(**input_dict) # successful forward with no modifications + curr_input_dict = copy.deepcopy(input_dict) # in=place modifications further + _ = model(**curr_input_dict) # successful forward with no modifications # remove one image but leave the image token in text - input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...] + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...] with self.assertRaises(ValueError): - _ = model(**input_dict) + _ = model(**curr_input_dict) # simulate multi-image case by concatenating inputs where each has exactly one image/image-token - input_ids = input_dict["input_ids"][:1] - pixel_values = input_dict["pixel_values"][:1] + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:1] input_ids = torch.cat([input_ids, input_ids], dim=0) # one image and two image tokens raise an error @@ -284,7 +294,8 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest model = model_class(config).to(torch_device) # We should have the right number of input features, # and should be able to run a forward pass without exploding - assert model.multi_modal_projector.linear_1.in_features == expected_features + base_model = getattr(model, "model", model) + assert base_model.multi_modal_projector.linear_1.in_features == expected_features model(**input_dict) @unittest.skip( @@ -311,6 +322,10 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass + @unittest.skip("LLaVA vision backbones doesn't support flex attention yet") + def test_flex_attention_with_grads(self): + pass + @require_torch class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 974bfb7b5a7..b8a4fda96d3 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3494,8 +3494,8 @@ class ModelTesterMixin: vision_model_name = [name for name in vision_model_names if hasattr(model_sdpa, name)][0] language_model_name = [name for name in language_model_names if hasattr(model_sdpa, name)][0] - vision_model_sdpa = getattr(model, vision_model_name) - language_model_sdpa = getattr(model, language_model_name) + vision_model_sdpa = getattr(model_sdpa, vision_model_name) + language_model_sdpa = getattr(model_sdpa, language_model_name) text_attn = "sdpa" if language_model_sdpa._supports_sdpa else "eager" vision_attn = "sdpa" if vision_model_sdpa._supports_sdpa else "eager" @@ -4489,7 +4489,8 @@ class ModelTesterMixin: @require_torch_gpu def test_flex_attention_with_grads(self): for model_class in self.all_model_classes: - if not model_class._supports_flex_attn: + # TODO: raushan, fix for composite models after making VLMs support new attn API + if not 
model_class._supports_flex_attn or self._is_composite: self.skipTest(reason="This model does not support flex attention") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config._attn_implementation = "flex_attention" diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index dc0dac82da2..cea9ed693d5 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -378,8 +378,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s "rope_theta", "partial_rotary_factor", "pretraining_tp", - "boi_token_index", - "eoi_token_index", + "boi_token_id", + "eoi_token_id", ] attributes_used_in_generation = ["encoder_no_repeat_ngram_size"] diff --git a/utils/check_repo.py b/utils/check_repo.py index 960a73e734b..9eb5ccdf06a 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -156,6 +156,8 @@ IGNORE_NON_TESTED = ( "Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests "Emu3VQVAE", # Building part of bigger (tested) model "Emu3TextModel", # Building part of bigger (tested) model + "Qwen2VLTextModel", # Building part of bigger (tested) model + "Qwen2_5_VLTextModel", # Building part of bigger (tested) model "InternVLVisionModel", # Building part of bigger (tested) model "JanusVisionModel", # Building part of bigger (tested) model "TimesFmModel", # Building part of bigger (tested) model