transformers/tests/models/siglip/test_modeling_siglip.py
Raushan Turganbay 21d5025826
Attn implementation for composite models (#32238)
* first try

* codestyle

* idefics2 is happy

* [run-slow] llava, llava_next, video_llava, vipllava, llava_next_video, idefics, idefics2, kosmos2, fuyu, blip, blip_2, instructblip, instructblipvideo, paligemma

* fix-copies

* [run-slow] llava, llava_next, video_llava, vipllava, llava_next_video, idefics, idefics2, kosmos2, fuyu, blip, blip_2, instructblip, instructblipvideo

* blip-2 needs to init vision from config

* when was this removed O_o

* minor fix

* tests

* this way?

* tests

* model-agnostic code

* codestyle

* add tests for idefics

* modify general test for VLMs

* no generation test for vlm yet!

* no generation test here also

* warn in ViT-SDPA if output attn

* add more tests

* user can pass dict as attn impl

* repo consistency

* update

* musicgen

* no prints

* forgot speech enc-dec and clip

* how many composite models we have?

* musicgen melody is same as musicgen

* +siglip

* fix tests + add some more

* remove idefics custom overriden code

* make idefics2 automappable

* nits

* skip tests

* doctests

* Update src/transformers/models/idefics2/configuration_idefics2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/clip/test_modeling_clip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/idefics2/test_modeling_idefics2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/idefics2/test_modeling_idefics2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/configuration_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* major update, no need for automap

* clean up

* add FA2 test

* more tests

* style

* skip tests

* why did these start failing now?

* no attributes for FA2 needed

* one tiny test

* address comment about FA2 false warning

* style

* add new models and resolve conflicts

* fix copies

* let it be this way for now, come back tomorrow to review

* some more fixes

* update

* more updates

* update

* fix copies

* style and tests

* another big update

* fix tests

* fix tests

* update

* another update

* fix tests

* fix copies

* fix tests

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2024-10-22 06:54:44 +02:00

1042 lines
42 KiB
Python

# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch SigLIP model."""
import inspect
import os
import tempfile
import unittest
from typing import Tuple
import numpy as np
import requests
from parameterized import parameterized
from pytest import mark
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_sdpa,
require_vision,
slow,
torch_device,
)
from transformers.utils import (
is_torch_available,
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
is_torch_sdpa_available,
is_vision_available,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
ids_tensor,
is_flaky,
random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from torch import nn
from transformers import SiglipForImageClassification, SiglipModel, SiglipTextModel, SiglipVisionModel
if is_torch_sdpa_available():
from torch.nn.attention import SDPBackend, sdpa_kernel
if is_vision_available():
from PIL import Image
from transformers import SiglipProcessor
class SiglipModelTesterMixin(ModelTesterMixin):
    """SDPA-specific checks shared by the SigLIP test cases in this file.

    Both tests save a freshly initialised model to a temporary directory, reload it once without an explicit
    attention implementation and once with `attn_implementation="eager"`, and then compare the two copies.
    """

    @staticmethod
    def _has_sdpa_attention_layer(model):
        """Return True if the class name of any sub-module of `model` marks it as an SDPA attention layer."""
        return any(
            "SdpaAttention" in type(submodule).__name__ or "SdpaSelfAttention" in type(submodule).__name__
            for submodule in model.modules()
        )

    def test_sdpa_can_dispatch_composite_models(self):
        """Check that the requested attention implementation reaches the config, the sub-configs and the layers.

        Raises:
            ValueError: if the eager copy contains SDPA attention layers, or the default copy contains none.
        """
        for model_class in self.all_model_classes:
            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                # No `attn_implementation` is passed here; the assertions below expect this copy to use SDPA
                model_sdpa = model_class.from_pretrained(tmpdirname)
                model_sdpa = model_sdpa.eval().to(torch_device)

                # Load model with eager attention
                model_eager = model_class.from_pretrained(
                    tmpdirname,
                    attn_implementation="eager",
                )
                model_eager = model_eager.eval().to(torch_device)

            # SigLip has one shared cls attr for all models, so we assign both submodels here
            vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager"

            # Only the composite model exposes both towers; stand-alone towers skip the sub-config checks
            if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"):
                self.assertEqual(model_sdpa.vision_model.config._attn_implementation, vision_attn)
                self.assertEqual(model_sdpa.text_model.config._attn_implementation, text_attn)
                self.assertEqual(model_eager.vision_model.config._attn_implementation, "eager")
                self.assertEqual(model_eager.text_model.config._attn_implementation, "eager")

            self.assertEqual(model_sdpa.config._attn_implementation, "sdpa")
            self.assertEqual(model_eager.config._attn_implementation, "eager")

            if self._has_sdpa_attention_layer(model_eager):
                raise ValueError("The eager model should not have SDPA attention layers")

            if not self._has_sdpa_attention_layer(model_sdpa):
                raise ValueError("The SDPA model should have SDPA attention layers")

    def test_eager_matches_sdpa_inference(
        self,
        torch_dtype: str,
        use_attention_mask_options: Tuple[bool, ...] = (True, False),
        logit_keys: Tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
    ):
        """Compare the outputs of the eager and SDPA copies of each model over a grid of configurations.

        Args:
            torch_dtype: one of `"float16"`, `"bfloat16"` or `"float32"`; selects the dtype the models are loaded
                in and the tolerances used for the comparison.
            use_attention_mask_options: for each value, the grid is run with (True) or without (False) an
                attention mask.
            logit_keys: names of the model outputs that are compared between the two copies.

        The test is skipped when the first model class does not support SDPA or when the requested half-precision
        dtype is not available on `torch_device`. Every mismatch and every exception raised by a forward pass is
        collected, and all of them are reported together at the end.
        """
        if not self.all_model_classes[0]._supports_sdpa:
            self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")

        if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
            self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")

        if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
            self.skipTest(
                f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
            )

        # Convert to torch dtype
        dtypes = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        torch_dtype = dtypes[torch_dtype]

        atols = {
            torch.float32: 1e-5,
            torch.bfloat16: 3e-2,
            torch.float16: 5e-3,
        }
        rtols = {
            torch.float32: 1e-4,
            torch.bfloat16: 3e-2,
            torch.float16: 5e-3,
        }

        atol = atols[torch_dtype]
        rtol = rtols[torch_dtype]

        def get_mean_reldiff(msg, current_case, x, ref, atol, rtol):
            return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                # Load the model with SDPA
                model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
                model_sdpa = model_sdpa.eval().to(torch_device)

                # Load model with eager attention
                model_eager = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch_dtype,
                    attn_implementation="eager",
                )
                model_eager = model_eager.eval().to(torch_device)

            # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time,
            # but it would be nicer to have an efficient way to use parameterized.expand
            cases = [
                (use_mask, output_attentions, sdpa_backend, batch_size)
                for use_mask in use_attention_mask_options
                for output_attentions in [True, False]
                for sdpa_backend in [
                    SDPBackend.MATH,
                    [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH],
                    [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
                    [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
                ]
                for batch_size in [1, 5]
            ]
            fail_cases = []

            for use_mask, output_attentions, sdpa_backend, batch_size in cases:
                # Shallow copy: the dict is new, but the tensors are still shared with `inputs_dict`
                processed_inputs = inputs_dict.copy()

                # convert to torch_dtype
                if "pixel_values" in processed_inputs:
                    processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype)

                # slice for different batch sizes
                for key in ["pixel_values", "input_ids", "attention_mask"]:
                    if key in processed_inputs:
                        processed_inputs[key] = processed_inputs[key][:batch_size]

                # set attention mask with left padding
                if not use_mask:
                    processed_inputs.pop("attention_mask", None)
                else:
                    # Build a fresh mask instead of writing into the slice above: the slice is a view, so an
                    # in-place write would silently modify the shared `inputs_dict["attention_mask"]`.
                    dummy_attention_mask = torch.ones_like(processed_inputs["attention_mask"])
                    dummy_attention_mask[:, :1] = 0
                    processed_inputs["attention_mask"] = dummy_attention_mask

                processed_inputs["output_attentions"] = output_attentions
                processed_inputs["output_hidden_states"] = True

                current_case = (
                    f"padding_side=left, use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}"
                )

                prepared_inputs = self._prepare_for_class(processed_inputs, model_class)

                with torch.no_grad():
                    try:
                        with sdpa_kernel(sdpa_backend):
                            outputs_eager = model_eager(**prepared_inputs)
                            outputs_sdpa = model_sdpa(**prepared_inputs)
                    except Exception as e:
                        # Deliberately broad: record the failure for this case and keep going with the others
                        fail_cases.append(f"{current_case}: {e}")
                        continue

                for key in logit_keys:
                    eager_logits = outputs_eager[key]
                    sdpa_logits = outputs_sdpa[key]

                    # With a mask, index 0 along dim 1 is dropped from both outputs before comparing
                    if use_mask:
                        eager_logits = eager_logits[:, 1:]
                        sdpa_logits = sdpa_logits[:, 1:]

                    is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol)
                    if not is_close:
                        fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol))

            self.assertFalse(fail_cases, "\n".join(fail_cases))
class SiglipVisionModelTester:
    """Builds a small `SiglipVisionConfig` plus random inputs and runs a basic shape check on the vision model."""

    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        """Store the test case (`parent`, used for assertions) and the hyper-parameters of the tiny model."""
        self.parent = parent
        self.scope = scope
        self.is_training = is_training

        # input geometry
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size

        # model hyper-parameters
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range

        # in ViT, the seq length equals the number of patches
        self.seq_length = (image_size // patch_size) ** 2

    # Copied from tests.models.clip.test_modeling_clip.CLIPVisionModelTester.prepare_config_and_inputs
    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        """Return a `SiglipVisionConfig` built from the hyper-parameters stored on this tester."""
        forwarded_fields = (
            "image_size",
            "patch_size",
            "num_channels",
            "hidden_size",
            "num_hidden_layers",
            "num_attention_heads",
            "intermediate_size",
            "dropout",
            "attention_dropout",
            "initializer_range",
        )
        return SiglipVisionConfig(**{field: getattr(self, field) for field in forwarded_fields})

    def create_and_check_model(self, config, pixel_values):
        """Run `SiglipVisionModel` in eval mode on `pixel_values` and check the shapes of its two outputs."""
        model = SiglipVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)

        # The expected sequence length is exactly the number of patches of the square input image
        patches_per_side = self.image_size // self.patch_size
        num_patches = patches_per_side * patches_per_side
        expected_hidden_shape = (self.batch_size, num_patches, self.hidden_size)
        expected_pooled_shape = (self.batch_size, self.hidden_size)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_hidden_shape)
        self.parent.assertEqual(result.pooler_output.shape, expected_pooled_shape)

    # Copied from tests.models.clip.test_modeling_clip.CLIPVisionModelTester.prepare_config_and_inputs_for_common
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict
@require_torch
class SiglipVisionModelTest(SiglipModelTesterMixin, unittest.TestCase):
    """
    Test case for the stand-alone SigLIP vision model.

    Some of the tests of test_modeling_common.py are overwritten here, as SIGLIP does not use input_ids,
    inputs_embeds, attention_mask and seq_length.
    """

    all_model_classes = (SiglipVisionModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    # MP works but offload doesn't work when the MultiheadAttention is offloaded
    # TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
    # in the dispatch_model function
    test_cpu_offload = False
    test_disk_offload_safetensors = False
    test_disk_offload_bin = False

    def setUp(self):
        self.model_tester = SiglipVisionModelTester(self)
        self.config_tester = ConfigTester(
            self,
            config_class=SiglipVisionConfig,
            has_text_modality=False,
            hidden_size=37,
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="SIGLIP does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    # The input embeddings must be a module; output embeddings, when present, must be a linear layer
    def test_model_get_set_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), nn.Module)
            output_embeddings = model.get_output_embeddings()
            if output_embeddings is not None:
                self.assertIsInstance(output_embeddings, nn.Linear)

    # `pixel_values` has to be the first argument of `forward`
    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            # signature.parameters is an OrderedDict => so the order of the names is deterministic
            forward_params = list(inspect.signature(model.forward).parameters)
            self.assertListEqual(forward_params[:1], ["pixel_values"])

    def test_model(self):
        self.model_tester.create_and_check_model(*self.model_tester.prepare_config_and_inputs())

    @unittest.skip(reason="SiglipVisionModel does not support standalone training")
    def test_training(self):
        pass

    @unittest.skip(reason="SiglipVisionModel does not support standalone training")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="SiglipVisionModel does not support standalone training")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="SiglipVisionModel does not support standalone training")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="SiglipVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="SiglipVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
    def test_initialization(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        pretrained = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        self.assertIsNotNone(pretrained)

    # The vision model takes no attention mask, so only the mask-free configuration is compared
    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
    @require_torch_sdpa
    @slow
    @is_flaky()
    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        super().test_eager_matches_sdpa_inference(
            torch_dtype=torch_dtype,
            use_attention_mask_options=(False,),
            logit_keys=("pooler_output", "last_hidden_state"),
        )

    # Overridden only to attach the SDPA requirement to the shared mixin test
    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        super().test_sdpa_can_dispatch_composite_models()
class SiglipTextModelTester:
    """Builds a small `SiglipTextConfig` plus random inputs and runs a basic shape check on the text model."""

    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        scope=None,
    ):
        """Store the test case (`parent`, used for assertions) and the hyper-parameters of the tiny model."""
        self.parent = parent
        self.scope = scope
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels

        # input geometry
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size

        # model hyper-parameters
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range

    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTester.prepare_config_and_inputs
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        if input_mask is not None:
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        config = self.get_config()

        return config, input_ids, input_mask

    def get_config(self):
        """Return a `SiglipTextConfig` built from the hyper-parameters stored on this tester."""
        forwarded_fields = (
            "vocab_size",
            "hidden_size",
            "num_hidden_layers",
            "num_attention_heads",
            "intermediate_size",
            "dropout",
            "attention_dropout",
            "max_position_embeddings",
            "initializer_range",
        )
        return SiglipTextConfig(**{field: getattr(self, field) for field in forwarded_fields})

    def create_and_check_model(self, config, input_ids, input_mask):
        """Run `SiglipTextModel` in eval mode with and without a mask and check the unmasked output shapes."""
        model = SiglipTextModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            # The masked call only has to run without error; the shape checks use the unmasked call
            model(input_ids, attention_mask=input_mask)
            result = model(input_ids)

        expected_hidden_shape = (self.batch_size, self.seq_length, self.hidden_size)
        expected_pooled_shape = (self.batch_size, self.hidden_size)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_hidden_shape)
        self.parent.assertEqual(result.pooler_output.shape, expected_pooled_shape)

    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTester.prepare_config_and_inputs_for_common
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict
# Test case for the stand-alone SigLIP text model. Methods marked "# Copied from" are kept in sync with their
# CLIP counterparts, so their bodies are intentionally left untouched.
@require_torch
class SiglipTextModelTest(SiglipModelTesterMixin, unittest.TestCase):
    all_model_classes = (SiglipTextModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_head_masking = False
    model_split_percents = [0.5, 0.8, 0.9]

    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.setUp with CLIP->Siglip
    def setUp(self):
        self.model_tester = SiglipTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=SiglipTextConfig, hidden_size=37)

    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_config
    def test_config(self):
        self.config_tester.run_common_tests()

    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_model
    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    # The four training tests below are disabled with the same reason: no stand-alone training for the text model
    @unittest.skip(reason="SiglipTextModel does not support standalone training")
    def test_training(self):
        pass

    @unittest.skip(reason="SiglipTextModel does not support standalone training")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="SiglipTextModel does not support standalone training")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="SiglipTextModel does not support standalone training")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="Siglip does not use inputs_embeds")
    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_inputs_embeds
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="SiglipTextModel has no base class and is not available in MODEL_MAPPING")
    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_save_load_fast_init_from_base
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="SiglipTextModel has no base class and is not available in MODEL_MAPPING")
    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_save_load_fast_init_to_base
    def test_save_load_fast_init_to_base(self):
        pass

    @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
    def test_initialization(self):
        pass

    # Slow test: loads the pretrained checkpoint named below and only checks that a model object is returned
    @slow
    def test_model_from_pretrained(self):
        model_name = "google/siglip-base-patch16-224"
        model = SiglipTextModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    # Runs the shared eager-vs-SDPA comparison once per dtype, both with and without an attention mask,
    # comparing the `pooler_output` and `last_hidden_state` outputs
    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
    @require_torch_sdpa
    @slow
    @is_flaky()
    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        super().test_eager_matches_sdpa_inference(
            torch_dtype=torch_dtype,
            logit_keys=("pooler_output", "last_hidden_state"),
            use_attention_mask_options=(False, True),
        )

    # Overridden only to attach the SDPA requirement to the shared mixin test
    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        super().test_sdpa_can_dispatch_composite_models()
class SiglipModelTester:
    """Combines a text tester and a vision tester to build configs and inputs for the full `SiglipModel`."""

    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
        """Create the two sub-testers, forwarding `text_kwargs` / `vision_kwargs` (both optional) to them."""
        text_kwargs = {} if text_kwargs is None else text_kwargs
        vision_kwargs = {} if vision_kwargs is None else vision_kwargs

        self.parent = parent
        self.text_model_tester = SiglipTextModelTester(parent, **text_kwargs)
        self.vision_model_tester = SiglipVisionModelTester(parent, **vision_kwargs)
        self.batch_size = self.text_model_tester.batch_size  # need bs for batching_equivalence test
        self.is_training = is_training

    # Copied from tests.models.clip.test_modeling_clip.CLIPModelTester.prepare_config_and_inputs
    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()

        config = self.get_config()

        return config, input_ids, attention_mask, pixel_values

    def get_config(self):
        """Return a `SiglipConfig` assembled from the text and vision configs of the two sub-testers."""
        text_config = self.text_model_tester.get_config()
        vision_config = self.vision_model_tester.get_config()
        return SiglipConfig.from_text_vision_configs(text_config, vision_config)

    def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
        """Run `SiglipModel` in eval mode and check the shapes of the two logit matrices."""
        model = SiglipModel(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids, pixel_values, attention_mask)

        num_images = self.vision_model_tester.batch_size
        num_texts = self.text_model_tester.batch_size
        self.parent.assertEqual(result.logits_per_image.shape, (num_images, num_texts))
        self.parent.assertEqual(result.logits_per_text.shape, (num_texts, num_images))

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` with keyword inputs for the common tests; `return_loss` is False."""
        config, input_ids, attention_mask, pixel_values = self.prepare_config_and_inputs()
        return config, {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "return_loss": False,
        }
@require_torch
class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (SiglipModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": SiglipModel} if is_torch_available() else {}
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
# MP works but offload doesn't work when the MultiheadAttention is offloaded
# TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
# in the dispatch_model function
test_cpu_offload = False
test_disk_offload_safetensors = False
test_disk_offload_bin = False
_is_composite = True
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.setUp with CLIP->Siglip
def setUp(self):
self.model_tester = SiglipModelTester(self)
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_model
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Hidden_states is tested in individual model tests")
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_hidden_states_output
def test_hidden_states_output(self):
pass
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_inputs_embeds
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Retain_grad is tested in individual model tests")
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_retain_grad_hidden_states_attentions
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="SiglipModel does not have input/output embeddings")
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_model_get_set_embeddings
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
def test_initialization(self):
pass
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest._create_and_check_torchscript with CLIP->Siglip
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
self.skipTest(reason="test_torchscript is set to False")
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
configs_no_init.return_dict = False
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
try:
input_ids = inputs_dict["input_ids"]
pixel_values = inputs_dict["pixel_values"] # Siglip needs pixel_values
traced_model = torch.jit.trace(model, (input_ids, pixel_values))
except RuntimeError:
self.fail("Couldn't trace module.")
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device)
model.eval()
loaded_model.to(torch_device)
loaded_model.eval()
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
self.assertTrue(found_buffer)
model_buffers.pop(i)
models_equal = True
for layer_name, p1 in model_state_dict.items():
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_load_vision_text_config with CLIP->Siglip
def test_load_vision_text_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Save SiglipConfig and check if we can load SiglipVisionConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
vision_config = SiglipVisionConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
# Save SiglipConfig and check if we can load SiglipTextConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
text_config = SiglipTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
@slow
def test_model_from_pretrained(self):
model_name = "google/siglip-base-patch16-224"
model = SiglipModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
dummy_input_ids = inputs_dict["input_ids"]
outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
outputs_fa = model_fa(
pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
)
self.assertTrue(
torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
)
self.assertTrue(
torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
)
# Test with attention mask
dummy_attention_mask = inputs_dict["attention_mask"]
if dummy_attention_mask is not None:
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
outputs = model(
pixel_values=dummy_pixel_values,
input_ids=dummy_input_ids,
attention_mask=dummy_attention_mask,
output_hidden_states=True,
)
outputs_fa = model_fa(
pixel_values=dummy_pixel_values,
input_ids=dummy_input_ids,
attention_mask=dummy_attention_mask,
output_hidden_states=True,
)
self.assertTrue(
torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
f"Logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
)
self.assertTrue(
torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
f"Logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
)
# check with inference + dropout
model.train()
_ = model_fa(
pixel_values=dummy_pixel_values,
input_ids=dummy_input_ids,
attention_mask=dummy_attention_mask,
output_hidden_states=True,
)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("SigLIP does not support right padding")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
    # Runs once per parameterized dtype and defers to the parent
    # implementation, passing the four output names as `logit_keys` and
    # both attention-mask settings as `use_attention_mask_options`.
    compared_outputs = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds")
    mask_settings = (False, True)
    super().test_eager_matches_sdpa_inference(
        torch_dtype=torch_dtype,
        logit_keys=compared_outputs,
        use_attention_mask_options=mask_settings,
    )
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
    # Pure pass-through: the parent implementation is invoked unchanged,
    # re-declared here under the `require_torch_sdpa` decorator.
    inherited_check = super().test_sdpa_can_dispatch_composite_models
    inherited_check()
class SiglipForImageClassificationModelTester(SiglipModelTester):
    """Model tester for the image-classification head.

    Builds on `SiglipModelTester` but only produces `pixel_values` as model
    input; the config still comes from the parent's `get_config()`.
    """

    def __init__(self, parent):
        super().__init__(parent)
        # Mirror the vision tester's dimensions onto this tester.
        vision_tester = self.vision_model_tester
        self.batch_size = vision_tester.batch_size
        self.num_hidden_layers = vision_tester.num_hidden_layers
        self.hidden_size = vision_tester.hidden_size
        self.seq_length = vision_tester.seq_length

    def prepare_config_and_inputs(self):
        """Return `(config, pixel_values)`.

        The pixel values are taken from the vision tester (its own config is
        discarded) and paired with the config from `self.get_config()`.
        """
        _vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
        return self.get_config(), pixel_values

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` where `inputs_dict` holds only `pixel_values`."""
        config, pixel_values = self.prepare_config_and_inputs()
        return config, {"pixel_values": pixel_values}
@require_torch
class SiglipForImageClassificationModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Test case for `SiglipForImageClassification`.

    Combines the SigLIP tester mixin and the pipeline tester mixin, disables
    several optional checks through the class-level flags below, and skips
    the tests listed further down with an explicit reason each.
    """

    all_model_classes = (SiglipForImageClassification,) if is_torch_available() else ()
    pipeline_model_mapping = {"image-classification": SiglipForImageClassification} if is_torch_available() else {}

    # Optional checks switched off for this model (presumably consumed by the
    # inherited mixins -- confirm against their definitions).
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    # MP works, but offloading breaks when the MultiheadAttention is offloaded.
    # TODO: a possible fix is to set
    # preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
    # in the dispatch_model function.
    test_cpu_offload = False
    test_disk_offload_safetensors = False
    test_disk_offload_bin = False

    _is_composite = True

    def setUp(self):
        # Every test gets a fresh tester bound to this test case.
        self.model_tester = SiglipForImageClassificationModelTester(self)

    @unittest.skip(reason="SiglipForImageClassification does not support inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="SiglipForImageClassification does not support inputs_embeds")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
    def test_initialization(self):
        pass

    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
    @require_torch_sdpa
    @slow
    @is_flaky()
    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        # Defers to the parent implementation once per dtype, passing only
        # the `logits` key and a single attention-mask option (False).
        compared_outputs = ("logits",)
        mask_settings = (False,)
        super().test_eager_matches_sdpa_inference(
            torch_dtype=torch_dtype,
            logit_keys=compared_outputs,
            use_attention_mask_options=mask_settings,
        )

    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        # Pass-through to the parent implementation.
        inherited_check = super().test_sdpa_can_dispatch_composite_models
        inherited_check()
# We will verify our results on an image of cute cats
def prepare_img():
    """Download the COCO val2017 sample image and open it with PIL.

    Returns:
        The image returned by `Image.open` over the streamed HTTP response
        body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within the timeout.
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # A timeout keeps a stalled connection from hanging the test run
    # indefinitely; the status check surfaces HTTP errors directly instead of
    # handing a non-image error body to PIL.
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    return Image.open(response.raw)
@require_vision
@require_torch
class SiglipModelIntegrationTest(unittest.TestCase):
    """Slow integration checks against the `google/siglip-base-patch16-224` checkpoint."""

    @slow
    def test_inference(self):
        # Scores one image against two captions and compares the resulting
        # logits and sigmoid probabilities with stored reference values.
        model_name = "google/siglip-base-patch16-224"
        model = SiglipModel.from_pretrained(model_name).to(torch_device)
        processor = SiglipProcessor.from_pretrained(model_name)

        image = prepare_img()
        inputs = processor(
            text=["a photo of 2 cats", "a photo of 2 dogs"], images=image, padding="max_length", return_tensors="pt"
        ).to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text

        # verify the logits
        self.assertEqual(
            logits_per_image.shape,
            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
        )
        self.assertEqual(
            logits_per_text.shape,
            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
        )

        expected_logits = torch.tensor([[-0.7567, -10.3354]], device=torch_device)
        # The message makes a failure report the actual deviation instead of
        # the bare "False is not true" that assertTrue prints by default.
        self.assertTrue(
            torch.allclose(logits_per_image, expected_logits, atol=1e-3),
            f"Image logits max diff: {torch.max(torch.abs(logits_per_image - expected_logits))}",
        )

        # verify the probs
        probs = torch.sigmoid(logits_per_image)  # these are the probabilities
        expected_probs = torch.tensor([[3.1937e-01, 3.2463e-05]], device=torch_device)
        self.assertTrue(
            torch.allclose(probs, expected_probs, atol=1e-3),
            f"Probabilities max diff: {torch.max(torch.abs(probs - expected_probs))}",
        )

    @slow
    def test_inference_interpolate_pos_encoding(self):
        # Runs the model with `interpolate_pos_encoding=True` on an image the
        # processor is told not to resize, then checks the vision output shape.
        model_name = "google/siglip-base-patch16-224"
        model = SiglipModel.from_pretrained(model_name).to(torch_device)

        # 640 x 480 image
        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
        processor = SiglipProcessor.from_pretrained(model_name, do_resize=False, size={"height": 480, "width": 640})

        inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs, interpolate_pos_encoding=True)

        # verify the shape
        # patch size = 16
        # batch size 1, (640/16) * (480/16) = 1200 patches, 768 hidden size
        expected_shape = torch.Size((1, 1200, 768))

        self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)