update florence2 model

ducviet00 2025-07-01 15:42:28 +07:00
parent aa88d1f3f2
commit 19978e827e
7 changed files with 393 additions and 62 deletions

View File

@ -26,7 +26,7 @@ The abstract from the paper is the following:
*We introduce Florence-2, a novel vision foundation model with a unified, prompt-based representation for a variety of computer vision and vision-language tasks. While existing large vision models excel in transfer learning, they struggle to perform a diversity of tasks with simple instructions, a capability that implies handling the complexity of various spatial hierarchy and semantic granularity. Florence-2 was designed to take text-prompt as task instructions and generate desirable results in text forms, whether it be captioning, object detection, grounding or segmentation. This multi-task learning setup demands large-scale, high-quality annotated data. To this end, we co-developed FLD-5B that consists of 5.4 billion comprehensive visual annotations on 126 million images, using an iterative strategy of automated image annotation and model refinement. We adopted a sequence-to-sequence structure to train Florence-2 to perform versatile and comprehensive vision tasks. Extensive evaluations on numerous tasks demonstrated Florence-2 to be a strong vision foundation model contender with unprecedented zero-shot and fine-tuning capabilities.*
-This model was contributed by [hlky](https://huggingface.co/hlky).
+This model was contributed by [ducviet00](https://huggingface.co/ducviet00).
The original code can be found [here](https://huggingface.co/microsoft/Florence-2-base/tree/main).
## Florence2VisionConfig
@ -41,6 +41,11 @@ The original code can be found [here](https://huggingface.co/microsoft/Florence-
[[autodoc]] Florence2Processor
## Florence2Model
[[autodoc]] Florence2Model
- forward
## Florence2ForConditionalGeneration
[[autodoc]] Florence2ForConditionalGeneration
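For orientation, here is a minimal usage sketch of the prompt-based API this doc page documents. It is not part of the diff; it simply mirrors the `forward()` docstring example added later in this commit, and the checkpoint name, image URL, and `<CAPTION>` task prompt are all taken from that example.

```python
import requests
from PIL import Image

from transformers import AutoProcessor, Florence2ForConditionalGeneration

# Checkpoint and task prompt copied from the docstring example in this commit.
model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large")
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# The task is selected purely through the text prompt, e.g. "<CAPTION>".
inputs = processor(text="<CAPTION>", images=image, return_tensors="pt")
generate_ids = model.generate(**inputs, max_length=100)
print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
```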

View File

@ -124,7 +124,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"), ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
("flaubert", "FlaubertModel"), ("flaubert", "FlaubertModel"),
("flava", "FlavaModel"), ("flava", "FlavaModel"),
("florence2", "Florence2ForConditionalGeneration"), ("florence2", "Florence2Model"),
("fnet", "FNetModel"), ("fnet", "FNetModel"),
("focalnet", "FocalNetModel"), ("focalnet", "FocalNetModel"),
("fsmt", "FSMTModel"), ("fsmt", "FSMTModel"),

View File

@ -5,7 +5,7 @@
# modular_florence2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
-# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -115,14 +115,14 @@ class Florence2VisionConfig(PretrainedConfig):
super().__init__(**kwargs)
self.in_channels = in_channels
-self.depths = depths
-self.patch_size = patch_size
-self.patch_stride = patch_stride
-self.patch_padding = patch_padding
-self.patch_prenorm = patch_prenorm
-self.embed_dim = embed_dim
-self.num_heads = num_heads
-self.num_groups = num_groups
+self.depths = list(depths)
+self.patch_size = list(patch_size)
+self.patch_stride = list(patch_stride)
+self.patch_padding = list(patch_padding)
+self.patch_prenorm = list(patch_prenorm)
+self.embed_dim = list(embed_dim)
+self.num_heads = list(num_heads)
+self.num_groups = list(num_groups)
self.window_size = window_size
self.drop_path_rate = drop_path_rate
self.mlp_ratio = mlp_ratio
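One plausible motivation for coercing these tuple-or-list arguments to `list` (the diff itself does not say): `PretrainedConfig` serializes to JSON, which has no tuple type, so values passed as tuples come back as lists after a save/load round trip and the reloaded config would no longer compare equal to the in-memory one. A minimal, library-free sketch of that effect:

```python
import json

original = {"depths": (1, 1, 9, 1)}  # a tuple, as a caller might pass it
round_tripped = json.loads(json.dumps(original))

print(round_tripped["depths"])                        # [1, 1, 9, 1] -- JSON turned the tuple into a list
print(original["depths"] == round_tripped["depths"])  # False: (1, 1, 9, 1) != [1, 1, 9, 1]

normalized = {"depths": list(original["depths"])}     # what the config now does at __init__ time
print(normalized["depths"] == round_tripped["depths"])  # True
```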

View File

@ -5,7 +5,7 @@
# modular_florence2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
-# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -31,11 +31,11 @@ from transformers.activations import ACT2FN
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import Seq2SeqLMOutput
+from ...modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple
-from ..auto import AutoModelForSeq2SeqLM
+from ..auto import AutoModel, AutoModelForSeq2SeqLM
from .configuration_florence2 import Florence2Config, Florence2VisionConfig
@ -622,6 +622,23 @@ class Florence2VisionProjector(nn.Module):
return image_features
@dataclass
@auto_docstring(
custom_intro="""
Base class for Florence-2 base model's outputs that also contains : pre-computed hidden states that can speed up sequential
decoding.
"""
)
class Florence2Seq2SeqModelOutput(Seq2SeqModelOutput):
r"""
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
custom_intro="""
@ -635,12 +652,6 @@ class Florence2Seq2SeqLMOutput(Seq2SeqLMOutput):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
@ -679,6 +690,152 @@ class Florence2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
@auto_docstring(
custom_intro="""
Florence-2 is an advanced vision foundation model that uses a prompt-based approach to handle a wide range of vision and vision-language tasks. Florence-2 can interpret simple text prompts to perform tasks like captioning, object detection, and segmentation.
"""
)
class Florence2Model(Florence2PreTrainedModel):
_tied_weights_keys = [
"language_model.shared.weight",
"language_model.encoder.embed_tokens.weight",
"language_model.decoder.embed_tokens.weight",
]
def __init__(self, config: Florence2Config):
super().__init__(config)
self.config = config
self.vocab_size = config.text_config.vocab_size
self.vision_tower = Florence2VisionBackbone(config=config.vision_config)
self.vision_projector = Florence2VisionProjector(config=config.vision_config)
self.language_model = AutoModel.from_config(config=config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
def get_encoder(self):
return self.language_model.get_encoder()
def get_decoder(self):
return self.language_model.get_decoder()
def set_input_embeddings(self, value: nn.Module) -> None:
self.language_model.set_input_embeddings(value)
def get_input_embeddings(self) -> nn.Module:
return self.language_model.get_input_embeddings()
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
self.language_model.set_output_embeddings(new_embeddings)
def get_output_embeddings(self) -> nn.Linear:
return self.language_model.get_output_embeddings()
@auto_docstring
def get_image_features(self, pixel_values: torch.Tensor, **kwargs):
image_features = self.vision_tower(pixel_values, **kwargs)
image_embeds = self.vision_projector(image_features)
return image_embeds
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, Florence2Seq2SeqModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Florence2ForConditionalGeneration
>>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large")
>>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")
>>> prompt = "<CAPTION>"
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=100)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"A green car parked in front of a yellow building."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long, device=inputs_embeds.device)
image_embeds = None
if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values)
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
if decoder_input_ids is None:
decoder_start_token_id = self.config.text_config.decoder_start_token_id
decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device)
decoder_input_ids *= decoder_start_token_id
outputs = self.language_model(
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
return Florence2Seq2SeqModelOutput(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
image_hidden_states=image_embeds,
)
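The core multimodal step in `Florence2Model.forward` is the packing above: projected image features are prepended to the text embeddings and the attention mask is widened to cover the image tokens. A toy, self-contained illustration of just that step (not part of the diff; the sizes are made up for illustration):

```python
import torch

batch, num_image_tokens, seq_len, hidden = 2, 577, 8, 32  # illustrative sizes only
image_embeds = torch.randn(batch, num_image_tokens, hidden)   # stands in for vision tower + projector output
inputs_embeds = torch.randn(batch, seq_len, hidden)           # stands in for text token embeddings
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)

# Same concatenation pattern as in forward(): image tokens first, then text tokens.
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)

print(inputs_embeds.shape)   # torch.Size([2, 585, 32])
print(attention_mask.shape)  # torch.Size([2, 585])
```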
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@ -696,12 +853,13 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi
def __init__(self, config: Florence2Config):
super().__init__(config)
-self.vocab_size = config.vocab_size
+self.vocab_size = config.text_config.vocab_size
self.vision_tower = Florence2VisionBackbone(config=config.vision_config)
self.vision_projector = Florence2VisionProjector(config=config.vision_config)
self.language_model = AutoModelForSeq2SeqLM.from_config(config=config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
def get_encoder(self):
@ -787,6 +945,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
@ -800,6 +961,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi
inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
if decoder_input_ids is None:
decoder_start_token_id = self.config.text_config.decoder_start_token_id
decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device)
decoder_input_ids *= decoder_start_token_id
outputs = self.language_model(
attention_mask=attention_mask,
labels=labels,
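The same `decoder_input_ids` fallback now appears in both `Florence2Model.forward` and `Florence2ForConditionalGeneration.forward`: if the caller supplies no `decoder_input_ids`, decoding starts from a single `decoder_start_token_id` per batch element. A tiny illustration of the fallback tensor, not part of the diff, with a made-up token id (the real value comes from `config.text_config`):

```python
import torch

batch_size = 2
decoder_start_token_id = 2  # illustrative only

decoder_input_ids = torch.ones((batch_size, 1), dtype=torch.long) * decoder_start_token_id
print(decoder_input_ids)  # tensor([[2], [2]])
```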
@ -856,6 +1022,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi
__all__ = [
"Florence2Model",
"Florence2ForConditionalGeneration", "Florence2ForConditionalGeneration",
"Florence2PreTrainedModel", "Florence2PreTrainedModel",
"Florence2VisionBackbone", "Florence2VisionBackbone",

View File

@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -26,7 +26,7 @@ from ...cache_utils import Cache
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import Seq2SeqLMOutput
+from ...modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
@ -35,7 +35,7 @@ from ...utils import (
can_return_tuple,
logging,
)
-from ..auto import CONFIG_MAPPING, AutoConfig, AutoModelForSeq2SeqLM
+from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForSeq2SeqLM
from ..beit.modeling_beit import BeitDropPath
from ..detr.modeling_detr import DetrLearnedPositionEmbedding
@ -132,14 +132,14 @@ class Florence2VisionConfig(PretrainedConfig):
super().__init__(**kwargs)
self.in_channels = in_channels
-self.depths = depths
-self.patch_size = patch_size
-self.patch_stride = patch_stride
-self.patch_padding = patch_padding
-self.patch_prenorm = patch_prenorm
-self.embed_dim = embed_dim
-self.num_heads = num_heads
-self.num_groups = num_groups
+self.depths = list(depths)
+self.patch_size = list(patch_size)
+self.patch_stride = list(patch_stride)
+self.patch_padding = list(patch_padding)
+self.patch_prenorm = list(patch_prenorm)
+self.embed_dim = list(embed_dim)
+self.num_heads = list(num_heads)
+self.num_groups = list(num_groups)
self.window_size = window_size
self.drop_path_rate = drop_path_rate
self.mlp_ratio = mlp_ratio
@ -760,6 +760,23 @@ class Florence2VisionProjector(nn.Module):
return image_features
@dataclass
@auto_docstring(
custom_intro="""
Base class for Florence-2 base model's outputs that also contains : pre-computed hidden states that can speed up sequential
decoding.
"""
)
class Florence2Seq2SeqModelOutput(Seq2SeqModelOutput):
r"""
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
custom_intro="""
@ -773,12 +790,6 @@ class Florence2Seq2SeqLMOutput(Seq2SeqLMOutput):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
@ -817,6 +828,152 @@ class Florence2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
@auto_docstring(
custom_intro="""
Florence-2 is an advanced vision foundation model that uses a prompt-based approach to handle a wide range of vision and vision-language tasks. Florence-2 can interpret simple text prompts to perform tasks like captioning, object detection, and segmentation.
"""
)
class Florence2Model(Florence2PreTrainedModel):
_tied_weights_keys = [
"language_model.shared.weight",
"language_model.encoder.embed_tokens.weight",
"language_model.decoder.embed_tokens.weight",
]
def __init__(self, config: Florence2Config):
super().__init__(config)
self.config = config
self.vocab_size = config.text_config.vocab_size
self.vision_tower = Florence2VisionBackbone(config=config.vision_config)
self.vision_projector = Florence2VisionProjector(config=config.vision_config)
self.language_model = AutoModel.from_config(config=config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
def get_encoder(self):
return self.language_model.get_encoder()
def get_decoder(self):
return self.language_model.get_decoder()
def set_input_embeddings(self, value: nn.Module) -> None:
self.language_model.set_input_embeddings(value)
def get_input_embeddings(self) -> nn.Module:
return self.language_model.get_input_embeddings()
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
self.language_model.set_output_embeddings(new_embeddings)
def get_output_embeddings(self) -> nn.Linear:
return self.language_model.get_output_embeddings()
@auto_docstring
def get_image_features(self, pixel_values: torch.Tensor, **kwargs):
image_features = self.vision_tower(pixel_values, **kwargs)
image_embeds = self.vision_projector(image_features)
return image_embeds
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, Florence2Seq2SeqModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Florence2ForConditionalGeneration
>>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large")
>>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")
>>> prompt = "<CAPTION>"
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=100)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"A green car parked in front of a yellow building."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long, device=inputs_embeds.device)
image_embeds = None
if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values)
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
if decoder_input_ids is None:
decoder_start_token_id = self.config.text_config.decoder_start_token_id
decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device)
decoder_input_ids *= decoder_start_token_id
outputs = self.language_model(
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
return Florence2Seq2SeqModelOutput(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
image_hidden_states=image_embeds,
)
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@ -834,12 +991,13 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi
def __init__(self, config: Florence2Config):
super().__init__(config)
-self.vocab_size = config.vocab_size
+self.vocab_size = config.text_config.vocab_size
self.vision_tower = Florence2VisionBackbone(config=config.vision_config)
self.vision_projector = Florence2VisionProjector(config=config.vision_config)
self.language_model = AutoModelForSeq2SeqLM.from_config(config=config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
def get_encoder(self):
@ -925,6 +1083,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
@ -938,6 +1099,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi
inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
if decoder_input_ids is None:
decoder_start_token_id = self.config.text_config.decoder_start_token_id
decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device)
decoder_input_ids *= decoder_start_token_id
outputs = self.language_model(
attention_mask=attention_mask,
labels=labels,
@ -996,6 +1162,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi
__all__ = [
"Florence2Config",
"Florence2VisionConfig",
"Florence2Model",
"Florence2ForConditionalGeneration", "Florence2ForConditionalGeneration",
"Florence2PreTrainedModel", "Florence2PreTrainedModel",
"Florence2VisionBackbone", "Florence2VisionBackbone",

View File

@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2025 Microsoft and The HuggingFace Inc. team.
+# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,10 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for FLORENCE2.
"""
import math
import re
from typing import Union
@ -89,13 +85,6 @@ class Florence2Processor(ProcessorMixin):
image_processor=None,
tokenizer=None,
):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
if not hasattr(image_processor, "image_seq_length"):
raise ValueError("Image processor is missing an `image_seq_length` attribute.")
self.image_seq_length = image_processor.image_seq_length
tokens_to_add = {

View File

@ -22,6 +22,7 @@ from transformers import (
AutoProcessor,
Florence2Config,
Florence2ForConditionalGeneration,
Florence2Model,
is_torch_available,
is_vision_available,
)
@ -212,19 +213,27 @@ class Florence2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
Model tester for `Florence2ForConditionalGeneration`.
"""
additional_model_inputs = ["pixel_values"] all_model_classes = (Florence2ForConditionalGeneration, Florence2Model) if is_torch_available() else ()
all_model_classes = (Florence2ForConditionalGeneration,) if is_torch_available() else () pipeline_model_mapping = (
all_generative_model_classes = (Florence2ForConditionalGeneration,) if is_torch_available() else () {
pipeline_model_mapping = {"image-to-text": Florence2ForConditionalGeneration} if is_torch_available() else {} "image-to-text": Florence2ForConditionalGeneration,
"image-text-to-text": Florence2ForConditionalGeneration,
}
if is_torch_available()
else {}
)
test_pruning = False
test_head_masking = False
test_attention_outputs = False
-test_torchscript = False
+_is_composite = True
def setUp(self):
self.model_tester = Florence2VisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=Florence2Config, has_text_modality=False)
def test_config(self):
self.config_tester.run_common_tests()
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -286,12 +295,6 @@ class Florence2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
# @unittest.skip(
# reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
# )
# def test_contrastive_generate_low_memory(self):
# pass
@unittest.skip(
reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)