transformers/tests/pipelines/test_pipelines_zero_shot_image_classification.py
NielsRogge 3b742ea84c
Add SigLIP (#26522)
* Add first draft

* Use appropriate gelu function

* More improvements

* More improvements

* More improvements

* Convert checkpoint

* More improvements

* Improve docs, remove print statements

* More improvements

* Add link

* remove unused masking function

* begin tokenizer

* do_lower_case

* debug

* set split_special_tokens=True

* Remove script

* Fix style

* Fix rebase

* Use same design as CLIP

* Add fast tokenizer

* Add SiglipTokenizer to init, remove extra_ids

* Improve conversion script

* Use smaller inputs in conversion script

* Update conversion script

* More improvements

* Add processor to conversion script

* Add tests

* Remove print statements

* Add tokenizer tests

* Fix more tests

* More improvements related to weight initialization

* More improvements

* Make more tests pass

* More improvements

* More improvements

* Add copied from

* Add canonicalize_text

* Enable fast tokenizer tests

* More improvements

* Fix most slow tokenizer tests

* Address comments

* Fix style

* Remove script

* Address some comments

* Add copied from to tests

* Add more copied from

* Add more copied from

* Add more copied from

* Remove is_flax_available

* More updates

* Address comment

* Remove SiglipTokenizerFast for now

* Add caching

* Remove umt5 test

* Add canonicalize_text inside _tokenize, thanks Arthur

* Fix image processor tests

* Skip tests which are not applicable

* Skip test_initialization

* More improvements

* Compare pixel values

* Fix doc tests, add integration test

* Add do_normalize

* Remove causal mask and leverage ignore copy

* Fix attention_mask

* Fix remaining tests

* Fix dummies

* Rename temperature and bias

* Address comments

* Add copied from to tokenizer tests

* Add SiglipVisionModel to auto mapping

* Add copied from to image processor tests

* Improve doc

* Remove SiglipVisionModel from index

* Address comments

* Improve docs

* Simplify config

* Add first draft

* Make it like mistral

* More improvements

* Fix attention_mask

* Fix output_attentions

* Add note in docs

* Convert multilingual model

* Convert large checkpoint

* Convert more checkpoints

* Add pipeline support, correct image_mean and image_std

* Use padding=max_length by default

* Make processor like llava

* Add code snippet

* Convert more checkpoints

* Set keep_punctuation_string=None as in OpenCLIP

* Set normalized=False for special tokens

* Fix doc test

* Update integration test

* Add figure

* Update organization

* Happy new year

* Use AutoModel everywhere

---------

Co-authored-by: patil-suraj <surajp815@gmail.com>
2024-01-08 18:17:16 +01:00

278 lines
10 KiB
Python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_vision_available
from transformers.pipelines import pipeline
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_torch,
require_vision,
slow,
)
from .test_pipelines_common import ANY
# `Image` must always exist at module level because the tests below call
# `Image.open(...)`. When the vision extras are missing we fall back to a
# minimal stand-in whose `open` accepts any arguments and returns None.
if not is_vision_available():

    class Image:
        # Placeholder used only when PIL cannot be imported.
        @staticmethod
        def open(*args, **kwargs):
            return None

else:
    from PIL import Image
@is_pipeline_test
@require_vision
class ZeroShotImageClassificationPipelineTests(unittest.TestCase):
    # Exercises the zero-shot image classification pipeline with three checkpoints:
    #   * a tiny random CLIP model (PyTorch and TensorFlow),
    #   * `openai/clip-vit-base-patch32` (PyTorch and TensorFlow, slow),
    #   * `google/siglip-base-patch16-224` (PyTorch, slow).
    # Every test classifies the same COCO fixture image, first on its own and then
    # as a list of five copies processed with `batch_size=2`.

    # Fixture image shared by all tests.
    _COCO_SAMPLE = "./tests/fixtures/tests_samples/COCO/000000039769.png"
    # Checkpoint used by the fast (non-slow) tests.
    _TINY_CLIP = "hf-internal-testing/tiny-random-clip-zero-shot-image-classification"

    # The automatic pipeline tests are deactivated because there is no good
    # MODEL_FOR_XX mapping, and only CLIP would be in it for now. The disabled
    # code is kept below for reference.
    #
    # model_mapping = {CLIPConfig: CLIPModel}
    #
    # def get_test_pipeline(self, model, tokenizer, processor):
    #     if tokenizer is None:
    #         # These models have no fast tokenizer class, so skip here.
    #         # The slow tokenizer tests should still run since they are quite small.
    #         self.skipTest("No tokenizer available")
    #         return
    #         # return None, None
    #
    #     image_classifier = ZeroShotImageClassificationPipeline(
    #         model=model, tokenizer=tokenizer, feature_extractor=processor
    #     )
    #
    #     # Two handles on the same COCO sample image
    #     image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    #     image2 = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    #     return image_classifier, [image, image2]
    #
    # def run_pipeline_test(self, pipe, examples):
    #     image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    #     outputs = pipe(image, candidate_labels=["A", "B"])
    #     self.assertEqual(outputs, {"text": ANY(str)})
    #
    #     # Batching
    #     outputs = pipe([image] * 3, batch_size=2, candidate_labels=["A", "B"])

    @require_torch
    def test_small_model_pt(self):
        classifier = pipeline(model=self._TINY_CLIP)
        picture = Image.open(self._COCO_SAMPLE)

        single_result = classifier(picture, candidate_labels=["a", "b", "c"])

        # The three scores are so close that floating point error decides the
        # ranking, so the order is not guaranteed across python and torch
        # versions. Accept any of the orderings listed here.
        accepted_rankings = [
            [{"score": 0.333, "label": name} for name in ranking]
            for ranking in (("a", "b", "c"), ("a", "c", "b"), ("b", "a", "c"))
        ]
        self.assertIn(nested_simplify(single_result), accepted_rankings)

        batched_result = classifier([picture for _ in range(5)], candidate_labels=["A", "B", "C"], batch_size=2)

        # Pipeline outputs are supposed to be deterministic, so in theory the
        # labels could be pinned to "A", "B", "C" instead of ANY(str). In this
        # particular case the scores are so close that floating point error
        # makes the order unpredictable once batching is involved.
        one_image = [{"score": 0.333, "label": ANY(str)} for _ in range(3)]
        self.assertEqual(nested_simplify(batched_result), [one_image for _ in range(5)])

    @require_tf
    def test_small_model_tf(self):
        classifier = pipeline(model=self._TINY_CLIP, framework="tf")
        picture = Image.open(self._COCO_SAMPLE)

        single_result = classifier(picture, candidate_labels=["a", "b", "c"])
        self.assertEqual(
            nested_simplify(single_result),
            [{"score": 0.333, "label": name} for name in ("a", "b", "c")],
        )

        batched_result = classifier([picture for _ in range(5)], candidate_labels=["A", "B", "C"], batch_size=2)

        # Pipeline outputs are supposed to be deterministic, so in theory the
        # labels could be pinned to "A", "B", "C" instead of ANY(str). In this
        # particular case the scores are so close that floating point error
        # makes the order unpredictable once batching is involved.
        one_image = [{"score": 0.333, "label": ANY(str)} for _ in range(3)]
        self.assertEqual(nested_simplify(batched_result), [one_image for _ in range(5)])

    @slow
    @require_torch
    def test_large_model_pt(self):
        classifier = pipeline(
            task="zero-shot-image-classification",
            model="openai/clip-vit-base-patch32",
        )
        # This is an image of 2 cats with remotes and no planes
        picture = Image.open(self._COCO_SAMPLE)
        labels = ["cat", "plane", "remote"]
        ranked = [
            {"score": 0.511, "label": "remote"},
            {"score": 0.485, "label": "cat"},
            {"score": 0.004, "label": "plane"},
        ]

        single_result = classifier(picture, candidate_labels=labels)
        self.assertEqual(nested_simplify(single_result), ranked)

        # Each of the five batched copies must produce the same ranking.
        batched_result = classifier([picture] * 5, candidate_labels=labels, batch_size=2)
        self.assertEqual(nested_simplify(batched_result), [ranked] * 5)

    @slow
    @require_tf
    def test_large_model_tf(self):
        classifier = pipeline(
            task="zero-shot-image-classification", model="openai/clip-vit-base-patch32", framework="tf"
        )
        # This is an image of 2 cats with remotes and no planes
        picture = Image.open(self._COCO_SAMPLE)
        labels = ["cat", "plane", "remote"]
        ranked = [
            {"score": 0.511, "label": "remote"},
            {"score": 0.485, "label": "cat"},
            {"score": 0.004, "label": "plane"},
        ]

        single_result = classifier(picture, candidate_labels=labels)
        self.assertEqual(nested_simplify(single_result), ranked)

        # Each of the five batched copies must produce the same ranking.
        batched_result = classifier([picture] * 5, candidate_labels=labels, batch_size=2)
        self.assertEqual(nested_simplify(batched_result), [ranked] * 5)

    @slow
    @require_torch
    def test_siglip_model_pt(self):
        classifier = pipeline(
            task="zero-shot-image-classification",
            model="google/siglip-base-patch16-224",
        )
        # This is an image of 2 cats with remotes and no planes
        picture = Image.open(self._COCO_SAMPLE)
        labels = ["2 cats", "a plane", "a remote"]
        ranked = [
            {"score": 0.198, "label": "2 cats"},
            {"score": 0.0, "label": "a remote"},
            {"score": 0.0, "label": "a plane"},
        ]

        single_result = classifier(picture, candidate_labels=labels)
        self.assertEqual(nested_simplify(single_result), ranked)

        # Each of the five batched copies must produce the same ranking.
        batched_result = classifier([picture] * 5, candidate_labels=labels, batch_size=2)
        self.assertEqual(nested_simplify(batched_result), [ranked] * 5)