mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
update
This commit is contained in:
parent
dc245e76db
commit
8e7aa374cf
@ -1,4 +1,4 @@
|
|||||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
the License. You may obtain a copy of the License at
|
the License. You may obtain a copy of the License at
|
||||||
@ -18,50 +18,51 @@ rendered properly in your Markdown viewer.
|
|||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
The Florence2 model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
|
The Florence2 model was proposed in [Florence-2: Advancing a Unified Representation for a Variety of Vision Tasks](https://arxiv.org/abs/2311.06242) by Microsoft.
|
||||||
<INSERT SHORT SUMMARY HERE>
|
|
||||||
|
Florence-2 is an advanced vision foundation model that uses a prompt-based approach to handle a wide range of vision and vision-language tasks. Florence-2 can interpret simple text prompts to perform tasks like captioning, object detection, and segmentation. It leverages our FLD-5B dataset, containing 5.4 billion annotations across 126 million images, to master multi-task learning. The model's sequence-to-sequence architecture enables it to excel in both zero-shot and fine-tuned settings, proving to be a competitive vision foundation model.
|
||||||
|
|
||||||
The abstract from the paper is the following:
|
The abstract from the paper is the following:
|
||||||
|
|
||||||
*<INSERT PAPER ABSTRACT HERE>*
|
*We introduce Florence-2, a novel vision foundation model with a unified, prompt-based representation for a variety of computer vision and vision-language tasks. While existing large vision models excel in transfer learning, they struggle to perform a diversity of tasks with simple instructions, a capability that implies handling the complexity of various spatial hierarchy and semantic granularity. Florence-2 was designed to take text-prompt as task instructions and generate desirable results in text forms, whether it be captioning, object detection, grounding or segmentation. This multi-task learning setup demands large-scale, high-quality annotated data. To this end, we co-developed FLD-5B that consists of 5.4 billion comprehensive visual annotations on 126 million images, using an iterative strategy of automated image annotation and model refinement. We adopted a sequence-to-sequence structure to train Florence-2 to perform versatile and comprehensive vision tasks. Extensive evaluations on numerous tasks demonstrated Florence-2 to be a strong vision foundation model contender with unprecedented zero-shot and fine-tuning capabilities.*
|
||||||
|
|
||||||
Tips:
|
|
||||||
|
|
||||||
<INSERT TIPS ABOUT MODEL HERE>
|
|
||||||
|
|
||||||
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
|
|
||||||
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
|
|
||||||
|
|
||||||
|
This model was contributed by [hlky](https://huggingface.co/hlky).
|
||||||
|
The original code can be found [here](https://huggingface.co/microsoft/Florence-2-base/tree/main).
|
||||||
|
|
||||||
## Florence2Config
|
## Florence2Config
|
||||||
|
|
||||||
[[autodoc]] Florence2Config
|
[[autodoc]] Florence2Config
|
||||||
- all
|
|
||||||
|
|
||||||
## Florence2Model
|
## Florence2Processor
|
||||||
|
|
||||||
[[autodoc]] Florence2Model
|
[[autodoc]] Florence2Processor
|
||||||
- forward
|
|
||||||
|
|
||||||
## Florence2ForConditionalGeneration
|
## Florence2ForConditionalGeneration
|
||||||
|
|
||||||
[[autodoc]] Florence2ForConditionalGeneration
|
[[autodoc]] Florence2ForConditionalGeneration
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
## Florence2ForSequenceClassification
|
## Florence2LanguageForConditionalGeneration
|
||||||
|
|
||||||
[[autodoc]] Florence2ForSequenceClassification
|
[[autodoc]] Florence2LanguageForConditionalGeneration
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
## Florence2ForQuestionAnswering
|
## Florence2LanguageModel
|
||||||
|
|
||||||
[[autodoc]] Florence2ForQuestionAnswering
|
[[autodoc]] Florence2LanguageModel
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
## Florence2ForCausalLM
|
## Florence2Vision
|
||||||
|
|
||||||
[[autodoc]] Florence2ForCausalLM
|
[[autodoc]] Florence2Vision
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
</pt>
|
## Florence2VisionModel
|
||||||
<tf>
|
|
||||||
|
[[autodoc]] Florence2VisionModel
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## Florence2VisionModelWithProjection
|
||||||
|
|
||||||
|
[[autodoc]] Florence2VisionModelWithProjection
|
||||||
|
- forward
|
||||||
|
@ -119,6 +119,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("fastspeech2_conformer", "FastSpeech2ConformerModel"),
|
("fastspeech2_conformer", "FastSpeech2ConformerModel"),
|
||||||
("flaubert", "FlaubertModel"),
|
("flaubert", "FlaubertModel"),
|
||||||
("flava", "FlavaModel"),
|
("flava", "FlavaModel"),
|
||||||
|
("florence2", "Florence2ForConditionalGeneration"),
|
||||||
("fnet", "FNetModel"),
|
("fnet", "FNetModel"),
|
||||||
("focalnet", "FocalNetModel"),
|
("focalnet", "FocalNetModel"),
|
||||||
("fsmt", "FSMTModel"),
|
("fsmt", "FSMTModel"),
|
||||||
@ -884,8 +885,8 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
|||||||
("blip-2", "Blip2ForConditionalGeneration"),
|
("blip-2", "Blip2ForConditionalGeneration"),
|
||||||
("chameleon", "ChameleonForConditionalGeneration"),
|
("chameleon", "ChameleonForConditionalGeneration"),
|
||||||
("emu3", "Emu3ForConditionalGeneration"),
|
("emu3", "Emu3ForConditionalGeneration"),
|
||||||
("fuyu", "FuyuForCausalLM"),
|
|
||||||
("florence2", "Florence2ForConditionalGeneration"),
|
("florence2", "Florence2ForConditionalGeneration"),
|
||||||
|
("fuyu", "FuyuForCausalLM"),
|
||||||
("gemma3", "Gemma3ForConditionalGeneration"),
|
("gemma3", "Gemma3ForConditionalGeneration"),
|
||||||
("git", "GitForCausalLM"),
|
("git", "GitForCausalLM"),
|
||||||
("got_ocr2", "GotOcr2ForConditionalGeneration"),
|
("got_ocr2", "GotOcr2ForConditionalGeneration"),
|
||||||
|
@ -322,3 +322,6 @@ class Florence2Config(PretrainedConfig):
|
|||||||
self.text_config = Florence2LanguageConfig(**text_config)
|
self.text_config = Florence2LanguageConfig(**text_config)
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Florence2Config"]
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -28,7 +28,7 @@ from ...utils import ( # noqa: F401
|
|||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ..bart.modeling_bart import BartForConditionalGeneration
|
from ..bart.modeling_bart import BartForConditionalGeneration, BartPreTrainedModel
|
||||||
from .configuration_florence2 import Florence2Config, Florence2LanguageConfig, Florence2VisionConfig # noqa: F401
|
from .configuration_florence2 import Florence2Config, Florence2LanguageConfig, Florence2VisionConfig # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
@ -984,6 +984,14 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Florence2LanguagePreTrainedModel(BartPreTrainedModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Florence2LanguageForConditionalGeneration(BartForConditionalGeneration):
|
class Florence2LanguageForConditionalGeneration(BartForConditionalGeneration):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -1296,3 +1304,15 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|||||||
|
|
||||||
def _reorder_cache(self, *args, **kwargs):
|
def _reorder_cache(self, *args, **kwargs):
|
||||||
return self.language_model._reorder_cache(*args, **kwargs)
|
return self.language_model._reorder_cache(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Florence2ForConditionalGeneration",
|
||||||
|
"Florence2LanguageForConditionalGeneration",
|
||||||
|
"Florence2LanguageModel",
|
||||||
|
"Florence2LanguagePreTrainedModel",
|
||||||
|
"Florence2PreTrainedModel",
|
||||||
|
"Florence2Vision",
|
||||||
|
"Florence2VisionModel",
|
||||||
|
"Florence2VisionModelWithProjection",
|
||||||
|
]
|
||||||
|
@ -159,7 +159,7 @@ class Florence2Processor(ProcessorMixin):
|
|||||||
"<REGION_TO_OCR>": "What text is in the region {input}?",
|
"<REGION_TO_OCR>": "What text is in the region {input}?",
|
||||||
}
|
}
|
||||||
|
|
||||||
self.post_processor = Florence2PostProcesser(tokenizer=tokenizer)
|
self.post_processor = Florence2PostProcessor(tokenizer=tokenizer)
|
||||||
|
|
||||||
super().__init__(image_processor, tokenizer)
|
super().__init__(image_processor, tokenizer)
|
||||||
|
|
||||||
@ -513,7 +513,7 @@ class CoordinatesQuantizer(object):
|
|||||||
return dequantized_coordinates
|
return dequantized_coordinates
|
||||||
|
|
||||||
|
|
||||||
class Florence2PostProcesser(object):
|
class Florence2PostProcessor(object):
|
||||||
r"""
|
r"""
|
||||||
Florence-2 post process for converting text prediction to various tasks results.
|
Florence-2 post process for converting text prediction to various tasks results.
|
||||||
|
|
||||||
@ -1220,3 +1220,6 @@ class Florence2PostProcesser(object):
|
|||||||
raise ValueError("task {} is not supported".format(task))
|
raise ValueError("task {} is not supported".format(task))
|
||||||
|
|
||||||
return parsed_dict
|
return parsed_dict
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Florence2Processor"]
|
||||||
|
File diff suppressed because it is too large
Load Diff
63
tests/models/florence2/test_processor_florence2.py
Normal file
63
tests/models/florence2/test_processor_florence2.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import AutoProcessor, AutoTokenizer, BartTokenizerFast, Florence2Processor
|
||||||
|
from transformers.testing_utils import require_vision
|
||||||
|
from transformers.utils import is_vision_available
|
||||||
|
|
||||||
|
from ...test_processing_common import ProcessorTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from transformers import CLIPImageProcessor
|
||||||
|
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
class Florence2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||||
|
processor_class = Florence2Processor
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
image_processor = CLIPImageProcessor.from_pretrained("microsoft/Florence-2-base")
|
||||||
|
tokenizer = BartTokenizerFast.from_pretrained("microsoft/Florence-2-base")
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
processor = Florence2Processor(image_processor, tokenizer, **processor_kwargs)
|
||||||
|
processor.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_tokenizer(self, **kwargs):
|
||||||
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
|
||||||
|
|
||||||
|
def get_image_processor(self, **kwargs):
|
||||||
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.tmpdirname)
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"Skip because the model has no processor kwargs except for chat template and"
|
||||||
|
"chat template is saved as a separate file. Stop skipping this test when the processor"
|
||||||
|
"has new kwargs saved in config file."
|
||||||
|
)
|
||||||
|
def test_processor_to_json_string(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_can_load_various_tokenizers(self):
|
||||||
|
for checkpoint in ["microsoft/Florence-2-base"]:
|
||||||
|
processor = Florence2Processor.from_pretrained(checkpoint)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||||
|
self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__)
|
@ -164,6 +164,11 @@ IGNORE_NON_TESTED = (
|
|||||||
"CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
|
"CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
|
||||||
"CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
|
"CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
|
||||||
"CsmBackboneModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
|
"CsmBackboneModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
|
||||||
|
"Florence2LanguageModel", # Building part of bigger (tested) model. Tested implicitly through Florence2ForConditionalGeneration.
|
||||||
|
"Florence2Vision", # Building part of bigger (tested) model. Tested implicitly through Florence2ForConditionalGeneration.
|
||||||
|
"Florence2VisionModel", # Building part of bigger (tested) model. Tested implicitly through Florence2ForConditionalGeneration.
|
||||||
|
"Florence2VisionModelWithProjection", # Building part of bigger (tested) model. Tested implicitly through Florence2ForConditionalGeneration.
|
||||||
|
"Florence2LanguageForConditionalGeneration", # Building part of bigger (tested) model. Tested implicitly through Florence2ForConditionalGeneration.
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -377,6 +382,11 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
|||||||
"CsmDepthDecoderModel", # Building part of a bigger model
|
"CsmDepthDecoderModel", # Building part of a bigger model
|
||||||
"CsmDepthDecoderForCausalLM", # Building part of a bigger model
|
"CsmDepthDecoderForCausalLM", # Building part of a bigger model
|
||||||
"CsmForConditionalGeneration", # Building part of a bigger model
|
"CsmForConditionalGeneration", # Building part of a bigger model
|
||||||
|
"Florence2LanguageForConditionalGeneration", # Building part of a bigger model
|
||||||
|
"Florence2LanguageModel", # Building part of a bigger model
|
||||||
|
"Florence2Vision", # Building part of a bigger model
|
||||||
|
"Florence2VisionModel", # Building part of a bigger model
|
||||||
|
"Florence2VisionModelWithProjection", # Building part of a bigger model
|
||||||
]
|
]
|
||||||
|
|
||||||
# DO NOT edit this list!
|
# DO NOT edit this list!
|
||||||
|
Loading…
Reference in New Issue
Block a user