huggingface/transformers mirror

update florence2 model

parent: aa88d1f3f2
commit: 19978e827e
@@ -26,7 +26,7 @@ The abstract from the paper is the following:
 
 *We introduce Florence-2, a novel vision foundation model with a unified, prompt-based representation for a variety of computer vision and vision-language tasks. While existing large vision models excel in transfer learning, they struggle to perform a diversity of tasks with simple instructions, a capability that implies handling the complexity of various spatial hierarchy and semantic granularity. Florence-2 was designed to take text-prompt as task instructions and generate desirable results in text forms, whether it be captioning, object detection, grounding or segmentation. This multi-task learning setup demands large-scale, high-quality annotated data. To this end, we co-developed FLD-5B that consists of 5.4 billion comprehensive visual annotations on 126 million images, using an iterative strategy of automated image annotation and model refinement. We adopted a sequence-to-sequence structure to train Florence-2 to perform versatile and comprehensive vision tasks. Extensive evaluations on numerous tasks demonstrated Florence-2 to be a strong vision foundation model contender with unprecedented zero-shot and fine-tuning capabilities.*
 
-This model was contributed by [hlky](https://huggingface.co/hlky).
+This model was contributed by [ducviet00](https://huggingface.co/ducviet00).
 The original code can be found [here](https://huggingface.co/microsoft/Florence-2-base/tree/main).
 
 ## Florence2VisionConfig
@@ -41,6 +41,11 @@ The original code can be found [here](https://huggingface.co/microsoft/Florence-2-base/tree/main).
 
 [[autodoc]] Florence2Processor
 
+## Florence2Model
+
+[[autodoc]] Florence2Model
+    - forward
+
 ## Florence2ForConditionalGeneration
 
 [[autodoc]] Florence2ForConditionalGeneration
@@ -124,7 +124,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
         ("flaubert", "FlaubertModel"),
         ("flava", "FlavaModel"),
-        ("florence2", "Florence2ForConditionalGeneration"),
+        ("florence2", "Florence2Model"),
         ("fnet", "FNetModel"),
         ("focalnet", "FocalNetModel"),
         ("fsmt", "FSMTModel"),
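A quick note on what this remapping means in practice: `AutoModel` now resolves a Florence-2 config to the headless base model, while generation keeps going through `Florence2ForConditionalGeneration`. A minimal sketch (not part of the commit; the checkpoint id is an assumption):

```python
# Illustrative only: assumes the microsoft/Florence-2-base checkpoint carries a
# Florence2Config. With ("florence2", "Florence2Model") in MODEL_MAPPING_NAMES,
# AutoModel builds the base model rather than the conditional-generation head.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("microsoft/Florence-2-base")
model = AutoModel.from_config(config)
print(type(model).__name__)  # expected: Florence2Model
```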
@@ -5,7 +5,7 @@
 # modular_florence2.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
-# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -115,14 +115,14 @@ class Florence2VisionConfig(PretrainedConfig):
         super().__init__(**kwargs)
 
         self.in_channels = in_channels
-        self.depths = depths
-        self.patch_size = patch_size
-        self.patch_stride = patch_stride
-        self.patch_padding = patch_padding
-        self.patch_prenorm = patch_prenorm
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.num_groups = num_groups
+        self.depths = list(depths)
+        self.patch_size = list(patch_size)
+        self.patch_stride = list(patch_stride)
+        self.patch_padding = list(patch_padding)
+        self.patch_prenorm = list(patch_prenorm)
+        self.embed_dim = list(embed_dim)
+        self.num_heads = list(num_heads)
+        self.num_groups = list(num_groups)
         self.window_size = window_size
         self.drop_path_rate = drop_path_rate
         self.mlp_ratio = mlp_ratio
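The `list(...)` normalization above matters because `PretrainedConfig` round-trips through JSON, where tuples come back as lists; storing lists up front keeps a saved-and-reloaded config equal to the original. A small standalone illustration (not from the commit):

```python
# Why per-stage fields are stored as lists: JSON has no tuple type, so a tuple
# written by save_pretrained() is read back as a list.
import json

embed_dim = (128, 256, 512, 1024)  # a tuple the caller might pass
restored = json.loads(json.dumps({"embed_dim": embed_dim}))["embed_dim"]

assert restored == list(embed_dim)  # equal once both sides are lists
assert restored != embed_dim        # but a list never equals a tuple
```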
@@ -5,7 +5,7 @@
 # modular_florence2.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
-# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,11 +31,11 @@ from transformers.activations import ACT2FN
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import Seq2SeqLMOutput
+from ...modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import LossKwargs, auto_docstring, can_return_tuple
-from ..auto import AutoModelForSeq2SeqLM
+from ..auto import AutoModel, AutoModelForSeq2SeqLM
 from .configuration_florence2 import Florence2Config, Florence2VisionConfig
 
 
@@ -622,6 +622,23 @@ class Florence2VisionProjector(nn.Module):
         return image_features
 
 
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for Florence-2 base model's outputs that also contains : pre-computed hidden states that can speed up sequential
+    decoding.
+    """
+)
+class Florence2Seq2SeqModelOutput(Seq2SeqModelOutput):
+    r"""
+    image_hidden_states (`torch.FloatTensor`, *optional*):
+        A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`.
+        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+    """
+
+    image_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
 @dataclass
 @auto_docstring(
     custom_intro="""
@@ -635,12 +652,6 @@ class Florence2Seq2SeqLMOutput(Seq2SeqLMOutput):
         Language modeling loss (for next-token prediction).
     logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
         Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
-
-        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
-        `past_key_values` input) to speed up sequential decoding.
     image_hidden_states (`torch.FloatTensor`, *optional*):
         A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`.
         image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
@@ -679,6 +690,152 @@ class Florence2PreTrainedModel(PreTrainedModel):
                 module.bias.data.zero_()
 
 
+@auto_docstring(
+    custom_intro="""
+    Florence-2 is an advanced vision foundation model that uses a prompt-based approach to handle a wide range of vision and vision-language tasks. Florence-2 can interpret simple text prompts to perform tasks like captioning, object detection, and segmentation.
+    """
+)
+class Florence2Model(Florence2PreTrainedModel):
+    _tied_weights_keys = [
+        "language_model.shared.weight",
+        "language_model.encoder.embed_tokens.weight",
+        "language_model.decoder.embed_tokens.weight",
+    ]
+
+    def __init__(self, config: Florence2Config):
+        super().__init__(config)
+        self.config = config
+        self.vocab_size = config.text_config.vocab_size
+
+        self.vision_tower = Florence2VisionBackbone(config=config.vision_config)
+        self.vision_projector = Florence2VisionProjector(config=config.vision_config)
+        self.language_model = AutoModel.from_config(config=config.text_config)
+
+        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        self.post_init()
+
+    def get_encoder(self):
+        return self.language_model.get_encoder()
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
+    def set_input_embeddings(self, value: nn.Module) -> None:
+        self.language_model.set_input_embeddings(value)
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.language_model.get_input_embeddings()
+
+    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
+        self.language_model.set_output_embeddings(new_embeddings)
+
+    def get_output_embeddings(self) -> nn.Linear:
+        return self.language_model.get_output_embeddings()
+
+    @auto_docstring
+    def get_image_features(self, pixel_values: torch.Tensor, **kwargs):
+        image_features = self.vision_tower(pixel_values, **kwargs)
+        image_embeds = self.vision_projector(image_features)
+        return image_embeds
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Union[tuple, Florence2Seq2SeqModelOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration
+
+        >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large")
+        >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")
+
+        >>> prompt = "<CAPTION>"
+        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=100)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "A green car parked in front of a yellow building."
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if attention_mask is None:
+            attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long, device=inputs_embeds.device)
+
+        image_embeds = None
+        if pixel_values is not None:
+            image_embeds = self.get_image_features(pixel_values)
+            image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+            inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1)
+            attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
+
+        if decoder_input_ids is None:
+            decoder_start_token_id = self.config.text_config.decoder_start_token_id
+            decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device)
+            decoder_input_ids *= decoder_start_token_id
+
+        outputs = self.language_model(
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        return Florence2Seq2SeqModelOutput(
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+            image_hidden_states=image_embeds,
+        )
+
+
 class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
 
 
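For orientation, a hedged usage sketch of the new base class (not part of the diff; the checkpoint id and Hub availability are assumptions): `Florence2Model` returns a `Florence2Seq2SeqModelOutput` with decoder hidden states and the projected image features, leaving the language-modeling head to `Florence2ForConditionalGeneration`.

```python
# Hypothetical usage of the new base model; checkpoint id is assumed, not guaranteed by this commit.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Florence2Model

processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base")
model = Florence2Model.from_pretrained("microsoft/Florence-2-base")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text="<CAPTION>", images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # decoder_input_ids default to the start token, per the forward above

print(outputs.last_hidden_state.shape)    # decoder hidden states
print(outputs.image_hidden_states.shape)  # projected vision features prepended to the encoder input
```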
@@ -696,12 +853,13 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
 
     def __init__(self, config: Florence2Config):
         super().__init__(config)
-        self.vocab_size = config.vocab_size
+        self.vocab_size = config.text_config.vocab_size
 
         self.vision_tower = Florence2VisionBackbone(config=config.vision_config)
         self.vision_projector = Florence2VisionProjector(config=config.vision_config)
         self.language_model = AutoModelForSeq2SeqLM.from_config(config=config.text_config)
 
+        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         self.post_init()
 
     def get_encoder(self):
@@ -787,6 +945,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
@@ -800,6 +961,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
             inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1)
             attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
 
+        if decoder_input_ids is None:
+            decoder_start_token_id = self.config.text_config.decoder_start_token_id
+            decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device)
+            decoder_input_ids *= decoder_start_token_id
+
         outputs = self.language_model(
             attention_mask=attention_mask,
             labels=labels,
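The `decoder_input_ids` default added here (and in `Florence2Model.forward` above) just seeds the decoder with one start token per batch item. A standalone illustration with an assumed start-token id:

```python
# How the default decoder_input_ids tensor is built; the id value 2 is illustrative.
import torch

batch_size = 2
decoder_start_token_id = 2  # in the model this comes from config.text_config

decoder_input_ids = torch.ones((batch_size, 1), dtype=torch.long)
decoder_input_ids *= decoder_start_token_id
print(decoder_input_ids)  # tensor([[2], [2]]): one start token per sequence
```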
@@ -856,6 +1022,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
 
 
 __all__ = [
+    "Florence2Model",
     "Florence2ForConditionalGeneration",
     "Florence2PreTrainedModel",
     "Florence2VisionBackbone",
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,7 +26,7 @@ from ...cache_utils import Cache
 from ...configuration_utils import PretrainedConfig
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import Seq2SeqLMOutput
+from ...modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import (
@@ -35,7 +35,7 @@ from ...utils import (
     can_return_tuple,
     logging,
 )
-from ..auto import CONFIG_MAPPING, AutoConfig, AutoModelForSeq2SeqLM
+from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForSeq2SeqLM
 from ..beit.modeling_beit import BeitDropPath
 from ..detr.modeling_detr import DetrLearnedPositionEmbedding
 
@@ -132,14 +132,14 @@ class Florence2VisionConfig(PretrainedConfig):
         super().__init__(**kwargs)
 
         self.in_channels = in_channels
-        self.depths = depths
-        self.patch_size = patch_size
-        self.patch_stride = patch_stride
-        self.patch_padding = patch_padding
-        self.patch_prenorm = patch_prenorm
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.num_groups = num_groups
+        self.depths = list(depths)
+        self.patch_size = list(patch_size)
+        self.patch_stride = list(patch_stride)
+        self.patch_padding = list(patch_padding)
+        self.patch_prenorm = list(patch_prenorm)
+        self.embed_dim = list(embed_dim)
+        self.num_heads = list(num_heads)
+        self.num_groups = list(num_groups)
         self.window_size = window_size
         self.drop_path_rate = drop_path_rate
         self.mlp_ratio = mlp_ratio
@@ -760,6 +760,23 @@ class Florence2VisionProjector(nn.Module):
         return image_features
 
 
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for Florence-2 base model's outputs that also contains : pre-computed hidden states that can speed up sequential
+    decoding.
+    """
+)
+class Florence2Seq2SeqModelOutput(Seq2SeqModelOutput):
+    r"""
+    image_hidden_states (`torch.FloatTensor`, *optional*):
+        A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`.
+        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+    """
+
+    image_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
 @dataclass
 @auto_docstring(
     custom_intro="""
@@ -773,12 +790,6 @@ class Florence2Seq2SeqLMOutput(Seq2SeqLMOutput):
         Language modeling loss (for next-token prediction).
     logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
         Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
-
-        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
-        `past_key_values` input) to speed up sequential decoding.
     image_hidden_states (`torch.FloatTensor`, *optional*):
         A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`.
         image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
@@ -817,6 +828,152 @@ class Florence2PreTrainedModel(PreTrainedModel):
                 module.bias.data.zero_()
 
 
+@auto_docstring(
+    custom_intro="""
+    Florence-2 is an advanced vision foundation model that uses a prompt-based approach to handle a wide range of vision and vision-language tasks. Florence-2 can interpret simple text prompts to perform tasks like captioning, object detection, and segmentation.
+    """
+)
+class Florence2Model(Florence2PreTrainedModel):
+    _tied_weights_keys = [
+        "language_model.shared.weight",
+        "language_model.encoder.embed_tokens.weight",
+        "language_model.decoder.embed_tokens.weight",
+    ]
+
+    def __init__(self, config: Florence2Config):
+        super().__init__(config)
+        self.config = config
+        self.vocab_size = config.text_config.vocab_size
+
+        self.vision_tower = Florence2VisionBackbone(config=config.vision_config)
+        self.vision_projector = Florence2VisionProjector(config=config.vision_config)
+        self.language_model = AutoModel.from_config(config=config.text_config)
+
+        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        self.post_init()
+
+    def get_encoder(self):
+        return self.language_model.get_encoder()
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
+    def set_input_embeddings(self, value: nn.Module) -> None:
+        self.language_model.set_input_embeddings(value)
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.language_model.get_input_embeddings()
+
+    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
+        self.language_model.set_output_embeddings(new_embeddings)
+
+    def get_output_embeddings(self) -> nn.Linear:
+        return self.language_model.get_output_embeddings()
+
+    @auto_docstring
+    def get_image_features(self, pixel_values: torch.Tensor, **kwargs):
+        image_features = self.vision_tower(pixel_values, **kwargs)
+        image_embeds = self.vision_projector(image_features)
+        return image_embeds
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Union[tuple, Florence2Seq2SeqModelOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration
+
+        >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large")
+        >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")
+
+        >>> prompt = "<CAPTION>"
+        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=100)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "A green car parked in front of a yellow building."
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if attention_mask is None:
+            attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long, device=inputs_embeds.device)
+
+        image_embeds = None
+        if pixel_values is not None:
+            image_embeds = self.get_image_features(pixel_values)
+            image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+            inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1)
+            attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
+
+        if decoder_input_ids is None:
+            decoder_start_token_id = self.config.text_config.decoder_start_token_id
+            decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device)
+            decoder_input_ids *= decoder_start_token_id
+
+        outputs = self.language_model(
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        return Florence2Seq2SeqModelOutput(
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+            image_hidden_states=image_embeds,
+        )
+
+
 class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
 
 
@@ -834,12 +991,13 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
 
     def __init__(self, config: Florence2Config):
        super().__init__(config)
-        self.vocab_size = config.vocab_size
+        self.vocab_size = config.text_config.vocab_size
 
         self.vision_tower = Florence2VisionBackbone(config=config.vision_config)
         self.vision_projector = Florence2VisionProjector(config=config.vision_config)
         self.language_model = AutoModelForSeq2SeqLM.from_config(config=config.text_config)
 
+        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         self.post_init()
 
     def get_encoder(self):
@@ -925,6 +1083,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
@@ -938,6 +1099,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
             inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1)
             attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
 
+        if decoder_input_ids is None:
+            decoder_start_token_id = self.config.text_config.decoder_start_token_id
+            decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device)
+            decoder_input_ids *= decoder_start_token_id
+
         outputs = self.language_model(
             attention_mask=attention_mask,
             labels=labels,
@@ -996,6 +1162,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
 __all__ = [
     "Florence2Config",
     "Florence2VisionConfig",
+    "Florence2Model",
     "Florence2ForConditionalGeneration",
     "Florence2PreTrainedModel",
     "Florence2VisionBackbone",
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2025 Microsoft and The HuggingFace Inc. team.
+# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,10 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-Processor class for FLORENCE2.
-"""
-
 import math
 import re
 from typing import Union
@@ -89,13 +85,6 @@ class Florence2Processor(ProcessorMixin):
         image_processor=None,
         tokenizer=None,
     ):
-        if image_processor is None:
-            raise ValueError("You need to specify an `image_processor`.")
-        if tokenizer is None:
-            raise ValueError("You need to specify a `tokenizer`.")
-        if not hasattr(image_processor, "image_seq_length"):
-            raise ValueError("Image processor is missing an `image_seq_length` attribute.")
-
         self.image_seq_length = image_processor.image_seq_length
 
         tokens_to_add = {
@@ -22,6 +22,7 @@ from transformers import (
     AutoProcessor,
     Florence2Config,
     Florence2ForConditionalGeneration,
+    Florence2Model,
     is_torch_available,
     is_vision_available,
 )
@@ -212,19 +213,27 @@ class Florence2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin
     Model tester for `Florence2ForConditionalGeneration`.
     """
 
-    additional_model_inputs = ["pixel_values"]
-    all_model_classes = (Florence2ForConditionalGeneration,) if is_torch_available() else ()
-    all_generative_model_classes = (Florence2ForConditionalGeneration,) if is_torch_available() else ()
-    pipeline_model_mapping = {"image-to-text": Florence2ForConditionalGeneration} if is_torch_available() else {}
+    all_model_classes = (Florence2ForConditionalGeneration, Florence2Model) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "image-to-text": Florence2ForConditionalGeneration,
+            "image-text-to-text": Florence2ForConditionalGeneration,
+        }
+        if is_torch_available()
+        else {}
+    )
     test_pruning = False
     test_head_masking = False
     test_attention_outputs = False
-    test_torchscript = False
+    _is_composite = True
 
     def setUp(self):
         self.model_tester = Florence2VisionText2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=Florence2Config, has_text_modality=False)
 
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
     # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
     def test_inputs_embeds(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -286,12 +295,6 @@ class Florence2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
-    # @unittest.skip(
-    #     reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
-    # )
-    # def test_contrastive_generate_low_memory(self):
-    #     pass
-
     @unittest.skip(
         reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
     )