Add SiglipForImageClassification and CLIPForImageClassification (#28952)

* First draft

* Add CLIPForImageClassification

* Remove scripts

* Fix doctests
NielsRogge 2024-02-14 08:41:31 +01:00 committed by GitHub
parent de6029a059
commit 63ffd56d02
12 changed files with 380 additions and 5 deletions

docs/source/en/model_doc/clip.md

@@ -172,6 +172,11 @@ The resource should ideally demonstrate something new instead of duplicating an
 [[autodoc]] CLIPVisionModel
     - forward

+## CLIPForImageClassification
+
+[[autodoc]] CLIPForImageClassification
+    - forward
+
 </pt>
 <tf>

docs/source/en/model_doc/siglip.md

@@ -140,3 +140,9 @@ If you want to do the pre- and postprocessing yourself, here's how to do that:
 [[autodoc]] SiglipVisionModel
     - forward
+
+## SiglipForImageClassification
+
+[[autodoc]] SiglipForImageClassification
+    - forward
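The two documentation entries above only add autodoc directives. As a hedged illustration of what the new heads look like from the user side (not part of this commit), the `image-classification` pipeline can drive them once the auto-mapping further down is in place; since the classifier on top of the base checkpoint is freshly initialized, the returned label is a placeholder such as `LABEL_0` or `LABEL_1` until the model is fine-tuned.

```python
# Hedged sketch, not part of the commit: exercising SiglipForImageClassification
# through the image-classification pipeline, which resolves to the new head via
# the auto-mapping added in this PR.
from transformers import pipeline

classifier = pipeline("image-classification", model="google/siglip-base-patch16-224")

# The pipeline accepts a URL, a local path, or a PIL.Image.
results = classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
print(results)  # e.g. [{'label': 'LABEL_1', 'score': ...}, {'label': 'LABEL_0', 'score': ...}]
```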

docs/source/en/tasks/image_classification.md

@@ -34,7 +34,7 @@ The task illustrated in this tutorial is supported by the following model architectures:
 <!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
-[BEiT](../model_doc/beit), [BiT](../model_doc/bit), [ConvNeXT](../model_doc/convnext), [ConvNeXTV2](../model_doc/convnextv2), [CvT](../model_doc/cvt), [Data2VecVision](../model_doc/data2vec-vision), [DeiT](../model_doc/deit), [DiNAT](../model_doc/dinat), [DINOv2](../model_doc/dinov2), [EfficientFormer](../model_doc/efficientformer), [EfficientNet](../model_doc/efficientnet), [FocalNet](../model_doc/focalnet), [ImageGPT](../model_doc/imagegpt), [LeViT](../model_doc/levit), [MobileNetV1](../model_doc/mobilenet_v1), [MobileNetV2](../model_doc/mobilenet_v2), [MobileViT](../model_doc/mobilevit), [MobileViTV2](../model_doc/mobilevitv2), [NAT](../model_doc/nat), [Perceiver](../model_doc/perceiver), [PoolFormer](../model_doc/poolformer), [PVT](../model_doc/pvt), [RegNet](../model_doc/regnet), [ResNet](../model_doc/resnet), [SegFormer](../model_doc/segformer), [SwiftFormer](../model_doc/swiftformer), [Swin Transformer](../model_doc/swin), [Swin Transformer V2](../model_doc/swinv2), [VAN](../model_doc/van), [ViT](../model_doc/vit), [ViT Hybrid](../model_doc/vit_hybrid), [ViTMSN](../model_doc/vit_msn)
+[BEiT](../model_doc/beit), [BiT](../model_doc/bit), [CLIP](../model_doc/clip), [ConvNeXT](../model_doc/convnext), [ConvNeXTV2](../model_doc/convnextv2), [CvT](../model_doc/cvt), [Data2VecVision](../model_doc/data2vec-vision), [DeiT](../model_doc/deit), [DiNAT](../model_doc/dinat), [DINOv2](../model_doc/dinov2), [EfficientFormer](../model_doc/efficientformer), [EfficientNet](../model_doc/efficientnet), [FocalNet](../model_doc/focalnet), [ImageGPT](../model_doc/imagegpt), [LeViT](../model_doc/levit), [MobileNetV1](../model_doc/mobilenet_v1), [MobileNetV2](../model_doc/mobilenet_v2), [MobileViT](../model_doc/mobilevit), [MobileViTV2](../model_doc/mobilevitv2), [NAT](../model_doc/nat), [Perceiver](../model_doc/perceiver), [PoolFormer](../model_doc/poolformer), [PVT](../model_doc/pvt), [RegNet](../model_doc/regnet), [ResNet](../model_doc/resnet), [SegFormer](../model_doc/segformer), [SigLIP](../model_doc/siglip), [SwiftFormer](../model_doc/swiftformer), [Swin Transformer](../model_doc/swin), [Swin Transformer V2](../model_doc/swinv2), [VAN](../model_doc/van), [ViT](../model_doc/vit), [ViT Hybrid](../model_doc/vit_hybrid), [ViTMSN](../model_doc/vit_msn)
 <!--End of the generated tip-->

src/transformers/__init__.py

@@ -1762,6 +1762,7 @@ else:
     _import_structure["models.clip"].extend(
         [
             "CLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "CLIPForImageClassification",
             "CLIPModel",
             "CLIPPreTrainedModel",
             "CLIPTextModel",

@@ -3200,6 +3201,7 @@ else:
     _import_structure["models.siglip"].extend(
         [
             "SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "SiglipForImageClassification",
             "SiglipModel",
             "SiglipPreTrainedModel",
             "SiglipTextModel",

@@ -6447,6 +6449,7 @@ if TYPE_CHECKING:
         )
         from .models.clip import (
             CLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
+            CLIPForImageClassification,
             CLIPModel,
             CLIPPreTrainedModel,
             CLIPTextModel,

@@ -7625,6 +7628,7 @@ if TYPE_CHECKING:
         )
         from .models.siglip import (
             SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
+            SiglipForImageClassification,
             SiglipModel,
             SiglipPreTrainedModel,
             SiglipTextModel,

src/transformers/models/auto/modeling_auto.py

@@ -498,6 +498,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         # Model for Image Classification mapping
         ("beit", "BeitForImageClassification"),
         ("bit", "BitForImageClassification"),
+        ("clip", "CLIPForImageClassification"),
         ("convnext", "ConvNextForImageClassification"),
         ("convnextv2", "ConvNextV2ForImageClassification"),
         ("cvt", "CvtForImageClassification"),

@@ -540,6 +541,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("regnet", "RegNetForImageClassification"),
         ("resnet", "ResNetForImageClassification"),
         ("segformer", "SegformerForImageClassification"),
+        ("siglip", "SiglipForImageClassification"),
         ("swiftformer", "SwiftFormerForImageClassification"),
         ("swin", "SwinForImageClassification"),
         ("swinv2", "Swinv2ForImageClassification"),

src/transformers/models/clip/__init__.py

@@ -67,6 +67,7 @@ else:
         "CLIPTextModelWithProjection",
         "CLIPVisionModel",
         "CLIPVisionModelWithProjection",
+        "CLIPForImageClassification",
     ]

 try:

@@ -136,6 +137,7 @@ if TYPE_CHECKING:
     else:
         from .modeling_clip import (
             CLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
+            CLIPForImageClassification,
             CLIPModel,
             CLIPPreTrainedModel,
             CLIPTextModel,

src/transformers/models/clip/modeling_clip.py

@@ -21,13 +21,15 @@ from typing import Any, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
-from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     ModelOutput,
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,

@@ -38,8 +40,14 @@ from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig

 logger = logging.get_logger(__name__)

+# General docstring
+_CONFIG_FOR_DOC = "CLIPConfig"
+
 _CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"

+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "openai/clip-vit-base-patch32"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
+
 CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "openai/clip-vit-base-patch32",
     # See all CLIP models at https://huggingface.co/models?filter=clip
@@ -1306,3 +1314,105 @@ class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
             hidden_states=vision_outputs.hidden_states,
             attentions=vision_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    CLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden
+    states of the patch tokens) e.g. for ImageNet.
+    """,
+    CLIP_START_DOCSTRING,
+)
+class CLIPForImageClassification(CLIPPreTrainedModel):
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: CLIPConfig) -> None:
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.vision_model = CLIPVisionTransformer(config.vision_config)
+
+        # Classifier head
+        self.classifier = (
+            nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=ImageClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, ImageClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.vision_model(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        # average pool the patch tokens
+        sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
+        # apply classifier
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
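To make the pooling and loss branches above concrete, here is a hedged sketch (not from the commit) of wiring the new class into a small fine-tuning setup. The three label names are invented placeholders; loading the full CLIP checkpoint simply leaves the text tower unused and initializes the classifier randomly.

```python
# Sketch only: set up CLIPForImageClassification with a custom label set and
# compute a single-label classification loss, as in the forward pass above.
import torch
from transformers import CLIPForImageClassification

model = CLIPForImageClassification.from_pretrained(
    "openai/clip-vit-base-patch32",
    num_labels=3,                                   # hypothetical 3-class problem
    id2label={0: "cat", 1: "dog", 2: "bird"},
    label2id={"cat": 0, "dog": 1, "bird": 2},
)

# Dummy batch of two already-preprocessed 224x224 RGB images and their labels.
pixel_values = torch.randn(2, 3, 224, 224)
labels = torch.tensor([0, 2])

outputs = model(pixel_values=pixel_values, labels=labels)
print(outputs.loss)          # CrossEntropyLoss, since labels are integers and num_labels > 1
print(outputs.logits.shape)  # torch.Size([2, 3])
```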

src/transformers/models/siglip/__init__.py

@@ -61,6 +61,7 @@ else:
         "SiglipPreTrainedModel",
         "SiglipTextModel",
         "SiglipVisionModel",
+        "SiglipForImageClassification",
     ]

@@ -97,6 +98,7 @@ if TYPE_CHECKING:
     else:
         from .modeling_siglip import (
             SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
+            SiglipForImageClassification,
             SiglipModel,
             SiglipPreTrainedModel,
             SiglipTextModel,

src/transformers/models/siglip/modeling_siglip.py

@@ -24,14 +24,16 @@ import numpy as np

 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn.init import _calculate_fan_in_and_fan_out

 from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
-from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     ModelOutput,
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,

@@ -42,8 +44,15 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig

 logger = logging.get_logger(__name__)

+# General docstring
+_CONFIG_FOR_DOC = "SiglipConfig"
+
 _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"

+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "google/siglip-base-patch16-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_1"
+
 SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "google/siglip-base-patch16-224",
     # See all SigLIP models at https://huggingface.co/models?filter=siglip
@@ -1185,3 +1194,105 @@ class SiglipModel(SiglipPreTrainedModel):
             text_model_output=text_outputs,
             vision_model_output=vision_outputs,
         )
+
+
+@add_start_docstrings(
+    """
+    SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden
+    states of the patch tokens) e.g. for ImageNet.
+    """,
+    SIGLIP_START_DOCSTRING,
+)
+class SiglipForImageClassification(SiglipPreTrainedModel):
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: SiglipConfig) -> None:
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.vision_model = SiglipVisionTransformer(config.vision_config)
+
+        # Classifier head
+        self.classifier = (
+            nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=ImageClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, ImageClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.vision_model(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        # average pool the patch tokens
+        sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
+        # apply classifier
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
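A mirrored, hedged sketch for the SigLIP head (again not part of the commit), this time going through the image processor. The base checkpoint ships without a trained classifier, so the logits only become meaningful after fine-tuning.

```python
# Sketch only: run SiglipForImageClassification on one image; the head averages
# the patch tokens and projects them to config.num_labels logits.
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, SiglipForImageClassification

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)

processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.logits.shape)                                     # torch.Size([1, 2]) with the default config
print(model.config.id2label[outputs.logits.argmax(-1).item()])  # placeholder, e.g. "LABEL_1"
```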

src/transformers/utils/dummy_pt_objects.py

@@ -1901,6 +1901,13 @@ class ClapTextModelWithProjection(metaclass=DummyObject):
 CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None


+class CLIPForImageClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class CLIPModel(metaclass=DummyObject):
     _backends = ["torch"]

@@ -7583,6 +7590,13 @@ class SEWDPreTrainedModel(metaclass=DummyObject):
 SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None


+class SiglipForImageClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class SiglipModel(metaclass=DummyObject):
     _backends = ["torch"]

tests/models/clip/test_modeling_clip.py

@@ -51,6 +51,7 @@ if is_torch_available():
     from torch import nn

     from transformers import (
+        CLIPForImageClassification,
         CLIPModel,
         CLIPTextModel,
         CLIPTextModelWithProjection,

@@ -744,6 +745,65 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
             self.assertIsNotNone(model)


+class CLIPForImageClassificationModelTester(CLIPModelTester):
+    def __init__(self, parent):
+        super().__init__(parent)
+        self.batch_size = self.vision_model_tester.batch_size
+        self.num_hidden_layers = self.vision_model_tester.num_hidden_layers
+        self.hidden_size = self.vision_model_tester.hidden_size
+        self.seq_length = self.vision_model_tester.seq_length
+
+    def prepare_config_and_inputs(self):
+        _, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
+        config = self.get_config()
+        return config, pixel_values
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, pixel_values = config_and_inputs
+        inputs_dict = {"pixel_values": pixel_values}
+        return config, inputs_dict
+
+
+@require_torch
+class CLIPForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+    all_model_classes = (CLIPForImageClassification,) if is_torch_available() else ()
+    pipeline_model_mapping = {"image-classification": CLIPForImageClassification} if is_torch_available() else {}
+    fx_compatible = False
+    test_head_masking = False
+    test_pruning = False
+    test_resize_embeddings = False
+    test_attention_outputs = False
+
+    def setUp(self):
+        self.model_tester = CLIPForImageClassificationModelTester(self)
+
+    @unittest.skip(reason="CLIPForImageClassification does not support inputs_embeds")
+    def test_inputs_embeds(self):
+        pass
+
+    @unittest.skip(reason="CLIPForImageClassification does not support inputs_embeds")
+    def test_model_common_attributes(self):
+        pass
+
+    @unittest.skip(reason="CLIPForImageClassification does not support gradient checkpointing yet")
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(reason="CLIPForImageClassification does not support gradient checkpointing yet")
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(reason="CLIPForImageClassification does not support gradient checkpointing yet")
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
+    @unittest.skip(reason="CLIP uses the same initialization scheme as the Flax original implementation")
+    def test_initialization(self):
+        pass
+
+
 # We will verify our results on an image of cute cats
 def prepare_img():
     url = "http://images.cocodataset.org/val2017/000000039769.jpg"
tests/models/siglip/test_modeling_siglip.py

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Testing suite for the PyTorch Siglip model. """
+""" Testing suite for the PyTorch SigLIP model. """


 import inspect

@@ -47,7 +47,7 @@ if is_torch_available():
     import torch
     from torch import nn

-    from transformers import SiglipModel, SiglipTextModel, SiglipVisionModel
+    from transformers import SiglipForImageClassification, SiglipModel, SiglipTextModel, SiglipVisionModel
     from transformers.models.siglip.modeling_siglip import SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST

@@ -584,6 +584,65 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
             self.assertIsNotNone(model)


+class SiglipForImageClassificationModelTester(SiglipModelTester):
+    def __init__(self, parent):
+        super().__init__(parent)
+        self.batch_size = self.vision_model_tester.batch_size
+        self.num_hidden_layers = self.vision_model_tester.num_hidden_layers
+        self.hidden_size = self.vision_model_tester.hidden_size
+        self.seq_length = self.vision_model_tester.seq_length
+
+    def prepare_config_and_inputs(self):
+        _, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
+        config = self.get_config()
+        return config, pixel_values
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, pixel_values = config_and_inputs
+        inputs_dict = {"pixel_values": pixel_values}
+        return config, inputs_dict
+
+
+@require_torch
+class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+    all_model_classes = (SiglipForImageClassification,) if is_torch_available() else ()
+    pipeline_model_mapping = {"image-classification": SiglipForImageClassification} if is_torch_available() else {}
+    fx_compatible = False
+    test_head_masking = False
+    test_pruning = False
+    test_resize_embeddings = False
+    test_attention_outputs = False
+
+    def setUp(self):
+        self.model_tester = SiglipForImageClassificationModelTester(self)
+
+    @unittest.skip(reason="SiglipForImageClassification does not support inputs_embeds")
+    def test_inputs_embeds(self):
+        pass
+
+    @unittest.skip(reason="SiglipForImageClassification does not support inputs_embeds")
+    def test_model_common_attributes(self):
+        pass
+
+    @unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet")
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet")
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet")
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
+    @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
+    def test_initialization(self):
+        pass
+
+
 # We will verify our results on an image of cute cats
 def prepare_img():
     url = "http://images.cocodataset.org/val2017/000000039769.jpg"