diff --git a/docs/source/en/model_doc/florence2.md b/docs/source/en/model_doc/florence2.md index a7946f3f15c..f5843d1afb2 100644 --- a/docs/source/en/model_doc/florence2.md +++ b/docs/source/en/model_doc/florence2.md @@ -26,7 +26,7 @@ The abstract from the paper is the following: *We introduce Florence-2, a novel vision foundation model with a unified, prompt-based representation for a variety of computer vision and vision-language tasks. While existing large vision models excel in transfer learning, they struggle to perform a diversity of tasks with simple instructions, a capability that implies handling the complexity of various spatial hierarchy and semantic granularity. Florence-2 was designed to take text-prompt as task instructions and generate desirable results in text forms, whether it be captioning, object detection, grounding or segmentation. This multi-task learning setup demands large-scale, high-quality annotated data. To this end, we co-developed FLD-5B that consists of 5.4 billion comprehensive visual annotations on 126 million images, using an iterative strategy of automated image annotation and model refinement. We adopted a sequence-to-sequence structure to train Florence-2 to perform versatile and comprehensive vision tasks. Extensive evaluations on numerous tasks demonstrated Florence-2 to be a strong vision foundation model contender with unprecedented zero-shot and fine-tuning capabilities.* -This model was contributed by [hlky](https://huggingface.co/hlky). +This model was contributed by [ducviet00](https://huggingface.co/ducviet00). The original code can be found [here](https://huggingface.co/microsoft/Florence-2-base/tree/main). ## Florence2VisionConfig @@ -41,6 +41,11 @@ The original code can be found [here](https://huggingface.co/microsoft/Florence- [[autodoc]] Florence2Processor +## Florence2Model + +[[autodoc]] Florence2Model + - forward + ## Florence2ForConditionalGeneration [[autodoc]] Florence2ForConditionalGeneration diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 2cb886262df..ada4c6ecaa3 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -124,7 +124,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"), ("flaubert", "FlaubertModel"), ("flava", "FlavaModel"), - ("florence2", "Florence2ForConditionalGeneration"), + ("florence2", "Florence2Model"), ("fnet", "FNetModel"), ("focalnet", "FocalNetModel"), ("fsmt", "FSMTModel"), diff --git a/src/transformers/models/florence2/configuration_florence2.py b/src/transformers/models/florence2/configuration_florence2.py index 3dbb1452ec6..bcaae1a9059 100644 --- a/src/transformers/models/florence2/configuration_florence2.py +++ b/src/transformers/models/florence2/configuration_florence2.py @@ -5,7 +5,7 @@ # modular_florence2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 -# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -115,14 +115,14 @@ class Florence2VisionConfig(PretrainedConfig): super().__init__(**kwargs) self.in_channels = in_channels - self.depths = depths - self.patch_size = patch_size - self.patch_stride = patch_stride - self.patch_padding = patch_padding - self.patch_prenorm = patch_prenorm - self.embed_dim = embed_dim - self.num_heads = num_heads - self.num_groups = num_groups + self.depths = list(depths) + self.patch_size = list(patch_size) + self.patch_stride = list(patch_stride) + self.patch_padding = list(patch_padding) + self.patch_prenorm = list(patch_prenorm) + self.embed_dim = list(embed_dim) + self.num_heads = list(num_heads) + self.num_groups = list(num_groups) self.window_size = window_size self.drop_path_rate = drop_path_rate self.mlp_ratio = mlp_ratio diff --git a/src/transformers/models/florence2/modeling_florence2.py b/src/transformers/models/florence2/modeling_florence2.py index a90001ee8e2..17a1d2e617e 100644 --- a/src/transformers/models/florence2/modeling_florence2.py +++ b/src/transformers/models/florence2/modeling_florence2.py @@ -5,7 +5,7 @@ # modular_florence2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 -# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,11 +31,11 @@ from transformers.activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import Seq2SeqLMOutput +from ...modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple -from ..auto import AutoModelForSeq2SeqLM +from ..auto import AutoModel, AutoModelForSeq2SeqLM from .configuration_florence2 import Florence2Config, Florence2VisionConfig @@ -622,6 +622,23 @@ class Florence2VisionProjector(nn.Module): return image_features +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Florence-2 base model's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + """ +) +class Florence2Seq2SeqModelOutput(Seq2SeqModelOutput): + r""" + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + @dataclass @auto_docstring( custom_intro=""" @@ -635,12 +652,6 @@ class Florence2Seq2SeqLMOutput(Seq2SeqLMOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. @@ -679,6 +690,152 @@ class Florence2PreTrainedModel(PreTrainedModel): module.bias.data.zero_() +@auto_docstring( + custom_intro=""" + Florence-2 is an advanced vision foundation model that uses a prompt-based approach to handle a wide range of vision and vision-language tasks. Florence-2 can interpret simple text prompts to perform tasks like captioning, object detection, and segmentation. + """ +) +class Florence2Model(Florence2PreTrainedModel): + _tied_weights_keys = [ + "language_model.shared.weight", + "language_model.encoder.embed_tokens.weight", + "language_model.decoder.embed_tokens.weight", + ] + + def __init__(self, config: Florence2Config): + super().__init__(config) + self.config = config + self.vocab_size = config.text_config.vocab_size + + self.vision_tower = Florence2VisionBackbone(config=config.vision_config) + self.vision_projector = Florence2VisionProjector(config=config.vision_config) + self.language_model = AutoModel.from_config(config=config.text_config) + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def get_encoder(self): + return self.language_model.get_encoder() + + def get_decoder(self): + return self.language_model.get_decoder() + + def set_input_embeddings(self, value: nn.Module) -> None: + self.language_model.set_input_embeddings(value) + + def get_input_embeddings(self) -> nn.Module: + return self.language_model.get_input_embeddings() + + def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: + self.language_model.set_output_embeddings(new_embeddings) + + def get_output_embeddings(self) -> nn.Linear: + return self.language_model.get_output_embeddings() + + @auto_docstring + def get_image_features(self, pixel_values: torch.Tensor, **kwargs): + image_features = self.vision_tower(pixel_values, **kwargs) + image_embeds = self.vision_projector(image_features) + return image_embeds + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, Florence2Seq2SeqModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration + + >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large") + >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large") + + >>> prompt = "" + >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=100) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "A green car parked in front of a yellow building." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if attention_mask is None: + attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long, device=inputs_embeds.device) + + image_embeds = None + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values) + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1) + + if decoder_input_ids is None: + decoder_start_token_id = self.config.text_config.decoder_start_token_id + decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device) + decoder_input_ids *= decoder_start_token_id + + outputs = self.language_model( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + return Florence2Seq2SeqModelOutput( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + image_hidden_states=image_embeds, + ) + + class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
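
As a quick way to see what the new `Florence2Model` returns, here is a minimal sketch of a forward pass through the base model. It is not part of the patch: the `microsoft/Florence-2-base` checkpoint and the `<CAPTION>` task prompt are assumptions, and the printed shapes are illustrative rather than verified outputs.

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Florence2Model

# Assumed checkpoint; any Florence-2 checkpoint with a matching config should behave the same.
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base")
model = Florence2Model.from_pretrained("microsoft/Florence-2-base")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text="<CAPTION>", images=image, return_tensors="pt")

with torch.no_grad():
    # decoder_input_ids are omitted, so forward falls back to the decoder start token
    # as implemented in the new Florence2Model.forward above.
    outputs = model(**inputs)

# Florence2Seq2SeqModelOutput also exposes the projected image features.
print(outputs.last_hidden_state.shape)    # (batch_size, decoder_seq_len, hidden_size)
print(outputs.image_hidden_states.shape)  # (batch_size, num_image_tokens, hidden_size)
```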
@@ -696,12 +853,13 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi def __init__(self, config: Florence2Config): super().__init__(config) - self.vocab_size = config.vocab_size + self.vocab_size = config.text_config.vocab_size self.vision_tower = Florence2VisionBackbone(config=config.vision_config) self.vision_projector = Florence2VisionProjector(config=config.vision_config) self.language_model = AutoModelForSeq2SeqLM.from_config(config=config.text_config) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() def get_encoder(self): @@ -787,6 +945,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -800,6 +961,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1) attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1) + if decoder_input_ids is None: + decoder_start_token_id = self.config.text_config.decoder_start_token_id + decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device) + decoder_input_ids *= decoder_start_token_id + outputs = self.language_model( attention_mask=attention_mask, labels=labels, @@ -856,6 +1022,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi __all__ = [ + "Florence2Model", "Florence2ForConditionalGeneration", "Florence2PreTrainedModel", "Florence2VisionBackbone", diff --git a/src/transformers/models/florence2/modular_florence2.py b/src/transformers/models/florence2/modular_florence2.py index 359fd697ad6..0516bf63e57 100644 --- a/src/transformers/models/florence2/modular_florence2.py +++ b/src/transformers/models/florence2/modular_florence2.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -26,7 +26,7 @@ from ...cache_utils import Cache from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import Seq2SeqLMOutput +from ...modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -35,7 +35,7 @@ from ...utils import ( can_return_tuple, logging, ) -from ..auto import CONFIG_MAPPING, AutoConfig, AutoModelForSeq2SeqLM +from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForSeq2SeqLM from ..beit.modeling_beit import BeitDropPath from ..detr.modeling_detr import DetrLearnedPositionEmbedding @@ -132,14 +132,14 @@ class Florence2VisionConfig(PretrainedConfig): super().__init__(**kwargs) self.in_channels = in_channels - self.depths = depths - self.patch_size = patch_size - self.patch_stride = patch_stride - self.patch_padding = patch_padding - self.patch_prenorm = patch_prenorm - self.embed_dim = embed_dim - self.num_heads = num_heads - self.num_groups = num_groups + self.depths = list(depths) + self.patch_size = list(patch_size) + self.patch_stride = list(patch_stride) + self.patch_padding = list(patch_padding) + self.patch_prenorm = list(patch_prenorm) + self.embed_dim = list(embed_dim) + self.num_heads = list(num_heads) + self.num_groups = list(num_groups) self.window_size = window_size self.drop_path_rate = drop_path_rate self.mlp_ratio = mlp_ratio @@ -760,6 +760,23 @@ class Florence2VisionProjector(nn.Module): return image_features +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Florence-2 base model's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + """ +) +class Florence2Seq2SeqModelOutput(Seq2SeqModelOutput): + r""" + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + @dataclass @auto_docstring( custom_intro=""" @@ -773,12 +790,6 @@ class Florence2Seq2SeqLMOutput(Seq2SeqLMOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
@@ -817,6 +828,152 @@ class Florence2PreTrainedModel(PreTrainedModel): module.bias.data.zero_() +@auto_docstring( + custom_intro=""" + Florence-2 is an advanced vision foundation model that uses a prompt-based approach to handle a wide range of vision and vision-language tasks. Florence-2 can interpret simple text prompts to perform tasks like captioning, object detection, and segmentation. + """ +) +class Florence2Model(Florence2PreTrainedModel): + _tied_weights_keys = [ + "language_model.shared.weight", + "language_model.encoder.embed_tokens.weight", + "language_model.decoder.embed_tokens.weight", + ] + + def __init__(self, config: Florence2Config): + super().__init__(config) + self.config = config + self.vocab_size = config.text_config.vocab_size + + self.vision_tower = Florence2VisionBackbone(config=config.vision_config) + self.vision_projector = Florence2VisionProjector(config=config.vision_config) + self.language_model = AutoModel.from_config(config=config.text_config) + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def get_encoder(self): + return self.language_model.get_encoder() + + def get_decoder(self): + return self.language_model.get_decoder() + + def set_input_embeddings(self, value: nn.Module) -> None: + self.language_model.set_input_embeddings(value) + + def get_input_embeddings(self) -> nn.Module: + return self.language_model.get_input_embeddings() + + def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: + self.language_model.set_output_embeddings(new_embeddings) + + def get_output_embeddings(self) -> nn.Linear: + return self.language_model.get_output_embeddings() + + @auto_docstring + def get_image_features(self, pixel_values: torch.Tensor, **kwargs): + image_features = self.vision_tower(pixel_values, **kwargs) + image_embeds = self.vision_projector(image_features) + return image_embeds + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, Florence2Seq2SeqModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration + + >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large") + >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large") + + >>> prompt = "" + >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=100) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "A green car parked in front of a yellow building." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if attention_mask is None: + attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long, device=inputs_embeds.device) + + image_embeds = None + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values) + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1) + + if decoder_input_ids is None: + decoder_start_token_id = self.config.text_config.decoder_start_token_id + decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device) + decoder_input_ids *= decoder_start_token_id + + outputs = self.language_model( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + return Florence2Seq2SeqModelOutput( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + image_hidden_states=image_embeds, + ) + + class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
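
Both the generated and the modular file also coerce the `Florence2VisionConfig` sequence arguments to lists. Presumably this keeps a config stable across serialization, since JSON turns tuples into lists when a config is saved and reloaded. Below is a small sketch of that round-trip, not taken from the patch; it passes only `depths` and leaves every other field at its default.

```python
from transformers import Florence2VisionConfig

# A tuple passed by the caller is stored as a list...
config = Florence2VisionConfig(depths=(1, 1, 9, 1))
assert config.depths == [1, 1, 9, 1]

# ...so serializing and reloading the config reproduces the same values and types.
restored = Florence2VisionConfig.from_dict(config.to_dict())
assert restored.depths == config.depths
```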
@@ -834,12 +991,13 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi def __init__(self, config: Florence2Config): super().__init__(config) - self.vocab_size = config.vocab_size + self.vocab_size = config.text_config.vocab_size self.vision_tower = Florence2VisionBackbone(config=config.vision_config) self.vision_projector = Florence2VisionProjector(config=config.vision_config) self.language_model = AutoModelForSeq2SeqLM.from_config(config=config.text_config) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() def get_encoder(self): @@ -925,6 +1083,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -938,6 +1099,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1) attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1) + if decoder_input_ids is None: + decoder_start_token_id = self.config.text_config.decoder_start_token_id + decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device) + decoder_input_ids *= decoder_start_token_id + outputs = self.language_model( attention_mask=attention_mask, labels=labels, @@ -996,6 +1162,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi __all__ = [ "Florence2Config", "Florence2VisionConfig", + "Florence2Model", "Florence2ForConditionalGeneration", "Florence2PreTrainedModel", "Florence2VisionBackbone", diff --git a/src/transformers/models/florence2/processing_florence2.py b/src/transformers/models/florence2/processing_florence2.py index 722f2654a5f..faf33e0ed21 100644 --- a/src/transformers/models/florence2/processing_florence2.py +++ b/src/transformers/models/florence2/processing_florence2.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2025 Microsoft and The HuggingFace Inc. team. +# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Processor class for FLORENCE2. 
-""" - import math import re from typing import Union @@ -89,13 +85,6 @@ class Florence2Processor(ProcessorMixin): image_processor=None, tokenizer=None, ): - if image_processor is None: - raise ValueError("You need to specify an `image_processor`.") - if tokenizer is None: - raise ValueError("You need to specify a `tokenizer`.") - if not hasattr(image_processor, "image_seq_length"): - raise ValueError("Image processor is missing an `image_seq_length` attribute.") - self.image_seq_length = image_processor.image_seq_length tokens_to_add = { diff --git a/tests/models/florence2/test_modeling_florence2.py b/tests/models/florence2/test_modeling_florence2.py index 236f4598130..40c9aa5eef8 100644 --- a/tests/models/florence2/test_modeling_florence2.py +++ b/tests/models/florence2/test_modeling_florence2.py @@ -22,6 +22,7 @@ from transformers import ( AutoProcessor, Florence2Config, Florence2ForConditionalGeneration, + Florence2Model, is_torch_available, is_vision_available, ) @@ -212,19 +213,27 @@ class Florence2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes Model tester for `Florence2ForConditionalGeneration`. """ - additional_model_inputs = ["pixel_values"] - all_model_classes = (Florence2ForConditionalGeneration,) if is_torch_available() else () - all_generative_model_classes = (Florence2ForConditionalGeneration,) if is_torch_available() else () - pipeline_model_mapping = {"image-to-text": Florence2ForConditionalGeneration} if is_torch_available() else {} + all_model_classes = (Florence2ForConditionalGeneration, Florence2Model) if is_torch_available() else () + pipeline_model_mapping = ( + { + "image-to-text": Florence2ForConditionalGeneration, + "image-text-to-text": Florence2ForConditionalGeneration, + } + if is_torch_available() + else {} + ) test_pruning = False test_head_masking = False test_attention_outputs = False - test_torchscript = False + _is_composite = True def setUp(self): self.model_tester = Florence2VisionText2TextModelTester(self) self.config_tester = ConfigTester(self, config_class=Florence2Config, has_text_modality=False) + def test_config(self): + self.config_tester.run_common_tests() + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -286,12 +295,6 @@ class Florence2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes def test_training_gradient_checkpointing_use_reentrant_false(self): pass - # @unittest.skip( - # reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" - # ) - # def test_contrastive_generate_low_memory(self): - # pass - @unittest.skip( reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" )