mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 23:08:57 +06:00
Add SiglipForImageClassification and CLIPForImageClassification (#28952)
* First draft * Add CLIPForImageClassification * Remove scripts * Fix doctests
This commit is contained in:
parent
de6029a059
commit
63ffd56d02
@ -172,6 +172,11 @@ The resource should ideally demonstrate something new instead of duplicating an
|
||||
[[autodoc]] CLIPVisionModel
|
||||
- forward
|
||||
|
||||
## CLIPForImageClassification
|
||||
|
||||
[[autodoc]] CLIPForImageClassification
|
||||
- forward
|
||||
|
||||
</pt>
|
||||
<tf>
|
||||
|
||||
|
@ -140,3 +140,9 @@ If you want to do the pre- and postprocessing yourself, here's how to do that:
|
||||
|
||||
[[autodoc]] SiglipVisionModel
|
||||
- forward
|
||||
|
||||
|
||||
## SiglipForImageClassification
|
||||
|
||||
[[autodoc]] SiglipForImageClassification
|
||||
- forward
|
@ -34,7 +34,7 @@ The task illustrated in this tutorial is supported by the following model archit
|
||||
|
||||
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
|
||||
|
||||
[BEiT](../model_doc/beit), [BiT](../model_doc/bit), [ConvNeXT](../model_doc/convnext), [ConvNeXTV2](../model_doc/convnextv2), [CvT](../model_doc/cvt), [Data2VecVision](../model_doc/data2vec-vision), [DeiT](../model_doc/deit), [DiNAT](../model_doc/dinat), [DINOv2](../model_doc/dinov2), [EfficientFormer](../model_doc/efficientformer), [EfficientNet](../model_doc/efficientnet), [FocalNet](../model_doc/focalnet), [ImageGPT](../model_doc/imagegpt), [LeViT](../model_doc/levit), [MobileNetV1](../model_doc/mobilenet_v1), [MobileNetV2](../model_doc/mobilenet_v2), [MobileViT](../model_doc/mobilevit), [MobileViTV2](../model_doc/mobilevitv2), [NAT](../model_doc/nat), [Perceiver](../model_doc/perceiver), [PoolFormer](../model_doc/poolformer), [PVT](../model_doc/pvt), [RegNet](../model_doc/regnet), [ResNet](../model_doc/resnet), [SegFormer](../model_doc/segformer), [SwiftFormer](../model_doc/swiftformer), [Swin Transformer](../model_doc/swin), [Swin Transformer V2](../model_doc/swinv2), [VAN](../model_doc/van), [ViT](../model_doc/vit), [ViT Hybrid](../model_doc/vit_hybrid), [ViTMSN](../model_doc/vit_msn)
|
||||
[BEiT](../model_doc/beit), [BiT](../model_doc/bit), [CLIP](../model_doc/clip), [ConvNeXT](../model_doc/convnext), [ConvNeXTV2](../model_doc/convnextv2), [CvT](../model_doc/cvt), [Data2VecVision](../model_doc/data2vec-vision), [DeiT](../model_doc/deit), [DiNAT](../model_doc/dinat), [DINOv2](../model_doc/dinov2), [EfficientFormer](../model_doc/efficientformer), [EfficientNet](../model_doc/efficientnet), [FocalNet](../model_doc/focalnet), [ImageGPT](../model_doc/imagegpt), [LeViT](../model_doc/levit), [MobileNetV1](../model_doc/mobilenet_v1), [MobileNetV2](../model_doc/mobilenet_v2), [MobileViT](../model_doc/mobilevit), [MobileViTV2](../model_doc/mobilevitv2), [NAT](../model_doc/nat), [Perceiver](../model_doc/perceiver), [PoolFormer](../model_doc/poolformer), [PVT](../model_doc/pvt), [RegNet](../model_doc/regnet), [ResNet](../model_doc/resnet), [SegFormer](../model_doc/segformer), [SigLIP](../model_doc/siglip), [SwiftFormer](../model_doc/swiftformer), [Swin Transformer](../model_doc/swin), [Swin Transformer V2](../model_doc/swinv2), [VAN](../model_doc/van), [ViT](../model_doc/vit), [ViT Hybrid](../model_doc/vit_hybrid), [ViTMSN](../model_doc/vit_msn)
|
||||
|
||||
<!--End of the generated tip-->
|
||||
|
||||
|
@ -1762,6 +1762,7 @@ else:
|
||||
_import_structure["models.clip"].extend(
|
||||
[
|
||||
"CLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"CLIPForImageClassification",
|
||||
"CLIPModel",
|
||||
"CLIPPreTrainedModel",
|
||||
"CLIPTextModel",
|
||||
@ -3200,6 +3201,7 @@ else:
|
||||
_import_structure["models.siglip"].extend(
|
||||
[
|
||||
"SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"SiglipForImageClassification",
|
||||
"SiglipModel",
|
||||
"SiglipPreTrainedModel",
|
||||
"SiglipTextModel",
|
||||
@ -6447,6 +6449,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.clip import (
|
||||
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
CLIPForImageClassification,
|
||||
CLIPModel,
|
||||
CLIPPreTrainedModel,
|
||||
CLIPTextModel,
|
||||
@ -7625,6 +7628,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.siglip import (
|
||||
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
SiglipForImageClassification,
|
||||
SiglipModel,
|
||||
SiglipPreTrainedModel,
|
||||
SiglipTextModel,
|
||||
|
@ -498,6 +498,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
# Model for Image Classification mapping
|
||||
("beit", "BeitForImageClassification"),
|
||||
("bit", "BitForImageClassification"),
|
||||
("clip", "CLIPForImageClassification"),
|
||||
("convnext", "ConvNextForImageClassification"),
|
||||
("convnextv2", "ConvNextV2ForImageClassification"),
|
||||
("cvt", "CvtForImageClassification"),
|
||||
@ -540,6 +541,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("regnet", "RegNetForImageClassification"),
|
||||
("resnet", "ResNetForImageClassification"),
|
||||
("segformer", "SegformerForImageClassification"),
|
||||
("siglip", "SiglipForImageClassification"),
|
||||
("swiftformer", "SwiftFormerForImageClassification"),
|
||||
("swin", "SwinForImageClassification"),
|
||||
("swinv2", "Swinv2ForImageClassification"),
|
||||
|
@ -67,6 +67,7 @@ else:
|
||||
"CLIPTextModelWithProjection",
|
||||
"CLIPVisionModel",
|
||||
"CLIPVisionModelWithProjection",
|
||||
"CLIPForImageClassification",
|
||||
]
|
||||
|
||||
try:
|
||||
@ -136,6 +137,7 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
from .modeling_clip import (
|
||||
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
CLIPForImageClassification,
|
||||
CLIPModel,
|
||||
CLIPPreTrainedModel,
|
||||
CLIPTextModel,
|
||||
|
@ -21,13 +21,15 @@ from typing import Any, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
@ -38,8 +40,14 @@ from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# General docstring
|
||||
_CONFIG_FOR_DOC = "CLIPConfig"
|
||||
_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
|
||||
|
||||
# Image classification docstring
|
||||
_IMAGE_CLASS_CHECKPOINT = "openai/clip-vit-base-patch32"
|
||||
_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
|
||||
|
||||
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"openai/clip-vit-base-patch32",
|
||||
# See all CLIP models at https://huggingface.co/models?filter=clip
|
||||
@ -1306,3 +1314,105 @@ class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
|
||||
hidden_states=vision_outputs.hidden_states,
|
||||
attentions=vision_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
CLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
|
||||
the patch tokens) e.g. for ImageNet.
|
||||
""",
|
||||
CLIP_START_DOCSTRING,
|
||||
)
|
||||
class CLIPForImageClassification(CLIPPreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(self, config: CLIPConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.vision_model = CLIPVisionTransformer(config.vision_config)
|
||||
|
||||
# Classifier head
|
||||
self.classifier = (
|
||||
nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
||||
output_type=ImageClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, ImageClassifierOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.vision_model(
|
||||
pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
# average pool the patch tokens
|
||||
sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
|
||||
# apply classifier
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return ImageClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
@ -61,6 +61,7 @@ else:
|
||||
"SiglipPreTrainedModel",
|
||||
"SiglipTextModel",
|
||||
"SiglipVisionModel",
|
||||
"SiglipForImageClassification",
|
||||
]
|
||||
|
||||
|
||||
@ -97,6 +98,7 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
from .modeling_siglip import (
|
||||
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
SiglipForImageClassification,
|
||||
SiglipModel,
|
||||
SiglipPreTrainedModel,
|
||||
SiglipTextModel,
|
||||
|
@ -24,14 +24,16 @@ import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
@ -42,8 +44,15 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionCo
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# General docstring
|
||||
_CONFIG_FOR_DOC = "SiglipConfig"
|
||||
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
|
||||
|
||||
# Image classification docstring
|
||||
_IMAGE_CLASS_CHECKPOINT = "google/siglip-base-patch16-224"
|
||||
_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_1"
|
||||
|
||||
|
||||
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"google/siglip-base-patch16-224",
|
||||
# See all SigLIP models at https://huggingface.co/models?filter=siglip
|
||||
@ -1185,3 +1194,105 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
text_model_output=text_outputs,
|
||||
vision_model_output=vision_outputs,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
|
||||
the patch tokens) e.g. for ImageNet.
|
||||
""",
|
||||
SIGLIP_START_DOCSTRING,
|
||||
)
|
||||
class SiglipForImageClassification(SiglipPreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(self, config: SiglipConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.vision_model = SiglipVisionTransformer(config.vision_config)
|
||||
|
||||
# Classifier head
|
||||
self.classifier = (
|
||||
nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
||||
output_type=ImageClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, ImageClassifierOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.vision_model(
|
||||
pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
# average pool the patch tokens
|
||||
sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
|
||||
# apply classifier
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return ImageClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
@ -1901,6 +1901,13 @@ class ClapTextModelWithProjection(metaclass=DummyObject):
|
||||
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class CLIPForImageClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class CLIPModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@ -7583,6 +7590,13 @@ class SEWDPreTrainedModel(metaclass=DummyObject):
|
||||
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class SiglipForImageClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SiglipModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -51,6 +51,7 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
CLIPForImageClassification,
|
||||
CLIPModel,
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
@ -744,6 +745,65 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class CLIPForImageClassificationModelTester(CLIPModelTester):
|
||||
def __init__(self, parent):
|
||||
super().__init__(parent)
|
||||
self.batch_size = self.vision_model_tester.batch_size
|
||||
self.num_hidden_layers = self.vision_model_tester.num_hidden_layers
|
||||
self.hidden_size = self.vision_model_tester.hidden_size
|
||||
self.seq_length = self.vision_model_tester.seq_length
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
_, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class CLIPForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (CLIPForImageClassification,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-classification": CLIPForImageClassification} if is_torch_available() else {}
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_attention_outputs = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = CLIPForImageClassificationModelTester(self)
|
||||
|
||||
@unittest.skip(reason="CLIPForImageClassification does not support inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="CLIPForImageClassification does not support inputs_embeds")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="CLIPForImageClassification does not support gradient checkpointing yet")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="CLIPForImageClassification does not support gradient checkpointing yet")
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="CLIPForImageClassification does not support gradient checkpointing yet")
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="CLIP uses the same initialization scheme as the Flax original implementation")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
|
@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Siglip model. """
|
||||
""" Testing suite for the PyTorch SigLIP model. """
|
||||
|
||||
|
||||
import inspect
|
||||
@ -47,7 +47,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import SiglipModel, SiglipTextModel, SiglipVisionModel
|
||||
from transformers import SiglipForImageClassification, SiglipModel, SiglipTextModel, SiglipVisionModel
|
||||
from transformers.models.siglip.modeling_siglip import SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
@ -584,6 +584,65 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class SiglipForImageClassificationModelTester(SiglipModelTester):
|
||||
def __init__(self, parent):
|
||||
super().__init__(parent)
|
||||
self.batch_size = self.vision_model_tester.batch_size
|
||||
self.num_hidden_layers = self.vision_model_tester.num_hidden_layers
|
||||
self.hidden_size = self.vision_model_tester.hidden_size
|
||||
self.seq_length = self.vision_model_tester.seq_length
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
_, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (SiglipForImageClassification,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-classification": SiglipForImageClassification} if is_torch_available() else {}
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_attention_outputs = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SiglipForImageClassificationModelTester(self)
|
||||
|
||||
@unittest.skip(reason="SiglipForImageClassification does not support inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SiglipForImageClassification does not support inputs_embeds")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet")
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet")
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
|
Loading…
Reference in New Issue
Block a user