Qwen2-VL: clean-up and add more tests (#33354)

* clean-up on qwen2-vl and add generation tests

* add video tests

* Update tests/models/qwen2_vl/test_processing_qwen2_vl.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix and add better tests

* Update src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* update docs and address comments

* Update docs/source/en/model_doc/qwen2_vl.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/qwen2_vl.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* update

* remove size at all

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Raushan Turganbay 2024-09-12 18:24:04 +02:00 committed by GitHub
parent 8f8af0fb38
commit 2f611d30d9
6 changed files with 297 additions and 106 deletions

docs/source/en/model_doc/qwen2_vl.md

@@ -229,8 +229,6 @@ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", min_pixel
```
#### Multiple Image Inputs
By default, images and video content are directly included in the conversation. When handling multiple images, it's helpful to add labels to the images and videos for better reference. Users can control this behavior with the following settings:
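A minimal sketch of one such labeled multi-image conversation (the model id and labels are illustrative; the content format follows this model's chat template):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Label each image in the text so the model can refer to "Picture 1" vs "Picture 2".
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Picture 1:"},
            {"type": "image"},
            {"type": "text", "text": "Picture 2:"},
            {"type": "image"},
            {"type": "text", "text": "What is different between the two pictures?"},
        ],
    }
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
```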

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

@@ -844,6 +844,7 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
@@ -1190,9 +1191,12 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
# the hard coded `3` is for temporal, height and width.
if position_ids is None:
# the hard coded `3` is for temporal, height and width.
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
elif position_ids.dim() == 2:
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
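A shape-only sketch of the branch added above, assuming torch: position ids carry a leading dimension of 3 because Qwen2-VL's rotary embedding indexes temporal, height and width separately.

```python
import torch

bsz, seq_len = 2, 5
cache_position = torch.arange(seq_len)

# No position_ids passed: broadcast cache_position to (3, bsz, seq_len).
position_ids = cache_position.view(1, 1, -1).expand(3, bsz, -1)
assert position_ids.shape == (3, bsz, seq_len)

# A plain 2D (bsz, seq_len) position_ids is lifted the same way by the new elif.
flat = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)
lifted = flat[None, ...].expand(3, flat.shape[0], -1)
assert lifted.shape == (3, bsz, seq_len)
```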
@@ -1777,13 +1781,13 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel):
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids}
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
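The unpack fix above is needed because `inputs_embeds` is 3D, unlike the 2D `input_ids`, so a two-way unpack of its shape raises a ValueError. A minimal repro with illustrative sizes:

```python
import torch

inputs_embeds = torch.zeros(2, 5, 32)  # (batch, sequence, hidden)
batch_size, sequence_length, _ = inputs_embeds.shape  # 2-way unpack would raise
```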

src/transformers/models/qwen2_vl/processing_qwen2_vl.py

@@ -21,25 +21,40 @@
Processor class for Qwen2-VL.
"""
from typing import List, Optional, Union
from typing import List, Union
try:
from typing import Unpack
except ImportError:
from typing_extensions import Unpack
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, VideoInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType, logging
from ...processing_utils import (
ProcessingKwargs,
ProcessorMixin,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
logger = logging.get_logger(__name__)
class Qwen2VLProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
}
class Qwen2VLProcessor(ProcessorMixin):
r"""
Constructs a Qwen2-VL processor which wraps a Qwen2-VL image processor and a Qwen2 tokenizer into a single processor.
[`Qwen2VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
[`~Qwen2VLProcessor.__call__`] and [`~Qwen2VLProcessor.decode`] for more information.
Args:
image_processor ([`Qwen2VLImageProcessor`], *optional*):
The image processor is a required input.
@@ -62,10 +77,7 @@ class Qwen2VLProcessor(ProcessorMixin):
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
videos: VideoInput = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: int = None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
**kwargs: Unpack[Qwen2VLProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
@@ -84,22 +96,8 @@ class Qwen2VLProcessor(ProcessorMixin):
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
@@ -117,15 +115,20 @@
- **image_grid_thw** -- List of 3D image grids (temporal, height, width) as seen by the LLM. Returned when `images` is not `None`.
- **video_grid_thw** -- List of 3D video grids (temporal, height, width) as seen by the LLM. Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Qwen2VLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(images=images, videos=None, return_tensors=return_tensors)
image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
image_grid_thw = image_inputs["image_grid_thw"]
else:
image_inputs = {}
image_grid_thw = None
if videos is not None:
videos_inputs = self.image_processor(images=None, videos=videos, return_tensors=return_tensors)
videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["videos_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]
else:
videos_inputs = {}
@@ -156,9 +159,8 @@ class Qwen2VLProcessor(ProcessorMixin):
index += 1
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
text_inputs = self.tokenizer(
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
)
_ = output_kwargs["text_kwargs"].pop("padding_side", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
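A minimal usage sketch of the new kwargs routing (model id, text and image size are illustrative): flat call-site kwargs are split by `_merge_kwargs`, so `padding`/`max_length` reach the tokenizer while `return_tensors` applies to every component.

```python
import numpy as np
from PIL import Image
from transformers import Qwen2VLProcessor

processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
image = Image.fromarray(np.zeros((56, 56, 3), dtype=np.uint8))

inputs = processor(
    text="lower newer",
    images=image,
    return_tensors="pt",    # common_kwargs: applied to tokenizer and image processor
    padding="max_length",   # text_kwargs: forwarded to the tokenizer only
    max_length=32,
)
print(inputs.keys())  # input_ids, attention_mask, pixel_values, image_grid_thw
```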

tests/models/qwen2_vl/test_modeling_qwen2_vl.py

@@ -65,6 +65,7 @@ class Qwen2VLVisionText2TextModelTester:
image_size=28,
bos_token_id=0,
eos_token_id=1,
pad_token_id=2,
vision_start_token_id=151652,
image_token_id=151655,
video_token_id=151656,
@@ -76,7 +77,7 @@ class Qwen2VLVisionText2TextModelTester:
max_window_layers=3,
model_type="qwen2_vl",
num_attention_heads=4,
num_hidden_layers=3,
num_hidden_layers=4,
num_key_value_heads=2,
rope_theta=10000,
tie_word_embeddings=True,
@@ -98,6 +99,7 @@ class Qwen2VLVisionText2TextModelTester:
self.ignore_index = ignore_index
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.vision_start_token_id = vision_start_token_id
self.image_token_id = image_token_id
self.video_token_id = video_token_id
@@ -137,6 +139,7 @@ class Qwen2VLVisionText2TextModelTester:
tie_word_embeddings=self.tie_word_embeddings,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
vision_start_token_id=self.vision_start_token_id,
image_token_id=self.image_token_id,
video_token_id=self.video_token_id,
@@ -162,6 +165,8 @@ class Qwen2VLVisionText2TextModelTester:
vision_seqlen = pixel_values.shape[0] // self.batch_size // (self.vision_config["spatial_merge_size"] ** 2)
input_ids = ids_tensor([self.batch_size, self.seq_length - 1 + vision_seqlen], self.vocab_size)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[:, torch.arange(vision_seqlen, device=torch_device) + 1] = self.image_token_id
labels = torch.zeros(
(self.batch_size, self.seq_length - 1 + vision_seqlen),
@@ -221,6 +226,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
"""
all_model_classes = (Qwen2VLForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Qwen2VLForConditionalGeneration,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
@@ -300,6 +306,12 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
def test_model_is_small(self):
pass
@unittest.skip(
reason="Qwen2-VL can't do low-memory generation because its position IDs have an extra dimension and the split function doesn't work for that"
)
def test_beam_search_low_memory(self):
pass
@require_torch
class Qwen2VLIntegrationTest(unittest.TestCase):

tests/models/qwen2_vl/test_processing_qwen2_vl.py

@@ -0,0 +1,237 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
import pytest
from transformers import AutoProcessor, Qwen2Tokenizer
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import Qwen2VLImageProcessor, Qwen2VLProcessor
@require_vision
@require_torch
class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Qwen2VLProcessor
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", patch_size=4)
processor.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_save_load_pretrained_default(self):
tokenizer = self.get_tokenizer()
image_processor = self.get_image_processor()
processor = Qwen2VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
processor.save_pretrained(self.tmpdirname)
processor = Qwen2VLProcessor.from_pretrained(self.tmpdirname, use_fast=False)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())
self.assertIsInstance(processor.tokenizer, Qwen2Tokenizer)
self.assertIsInstance(processor.image_processor, Qwen2VLImageProcessor)
def test_image_processor(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = Qwen2VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
image_input = self.prepare_image_inputs()
input_image_proc = image_processor(image_input, return_tensors="np")
input_processor = processor(images=image_input, text="dummy", return_tensors="np")
for key in input_image_proc.keys():
self.assertAlmostEqual(input_image_proc[key].sum(), input_processor[key].sum(), delta=1e-2)
def test_processor(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = Qwen2VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"])
# test if it raises when no input is passed
with pytest.raises(ValueError):
processor()
# test if it raises when no text is passed
with pytest.raises(TypeError):
processor(images=image_input)
def test_model_input_names(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = Qwen2VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
video_inputs = self.prepare_video_inputs()
inputs = processor(text=input_str, images=image_input, videos=video_inputs)
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
# Qwen2-VL doesn't accept `size`; images are resized to an optimal size using image_processor attributes
# defined at `init`. Therefore, all tests are overwritten and don't actually test whether kwargs are passed
# to image processors
def test_image_processor_defaults_preserved_by_image_kwargs(self):
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertEqual(inputs["pixel_values"].shape[0], 800)
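The expected 800 can be reconstructed by hand, assuming the common-tests default 400x30 image, the `patch_size=4` override from `setUp`, and the Qwen2-VL default `merge_size=2` (so sides are resized to multiples of 8):

```python
# Back-of-envelope check for the 800 patches above (assumed sizes, not asserted by the test).
grid_t = 1          # a still image fills a single temporal patch
grid_h = 32 // 4    # height 30 -> resized to 32, divided by patch_size
grid_w = 400 // 4   # width 400 is already a multiple of 8
assert grid_t * grid_h * grid_w == 800
```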
def test_kwargs_overrides_default_image_processor_kwargs(self):
image_processor = self.get_component(
"image_processor",
)
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertEqual(inputs["pixel_values"].shape[0], 800)
def test_unstructured_kwargs(self):
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
padding="max_length",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[0], 800)
self.assertEqual(len(inputs["input_ids"][0]), 76)
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[0], 1600)
self.assertEqual(len(inputs["input_ids"][0]), 4)
def test_structured_kwargs_nested(self):
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs["pixel_values"].shape[0], 800)
self.assertEqual(len(inputs["input_ids"][0]), 76)
def test_structured_kwargs_nested_from_dict(self):
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[0], 800)
self.assertEqual(len(inputs["input_ids"][0]), 76)
def test_image_processor_defaults_preserved_by_video_kwargs(self):
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
video_input = self.prepare_video_inputs()
inputs = processor(text=input_str, videos=video_input)
self.assertEqual(inputs["pixel_values_videos"].shape[0], 9600)
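The 9600 here is consistent with the same arithmetic applied to `prepare_video_inputs` from the common tests (3 videos of 8 frames at 400x30), assuming the Qwen2-VL default `temporal_patch_size=2`:

```python
# Back-of-envelope check for the 9600 video patches (assumed defaults).
grid_t = 8 // 2                        # 8 frames / temporal_patch_size
grid_h = 32 // 4                       # height 30 -> resized to 32, divided by patch_size
grid_w = 400 // 4
per_video = grid_t * grid_h * grid_w   # 3200 patches per video
assert 3 * per_video == 9600           # batch of 3 videos
```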

tests/test_processing_common.py

@@ -23,15 +23,12 @@ try:
from typing import Unpack
except ImportError:
from typing_extensions import Unpack
import unittest
import numpy as np
from transformers import CLIPTokenizerFast, ProcessorMixin
from transformers.models.auto.processing_auto import processor_class_from_name
from transformers.testing_utils import (
check_json_file_has_correct_format,
require_tokenizers,
require_torch,
require_vision,
)
@@ -41,8 +38,6 @@ from transformers.utils import is_vision_available
if is_vision_available():
from PIL import Image
from transformers import CLIPImageProcessor
def prepare_image_inputs():
"""This function prepares a list of PIL images"""
@@ -53,7 +48,6 @@ def prepare_image_inputs():
@require_torch
@require_vision
@require_torch
class ProcessorTesterMixin:
processor_class = None
@@ -91,6 +85,13 @@ class ProcessorTesterMixin:
"""This function prepares a list of PIL images for testing"""
return prepare_image_inputs()
@require_vision
def prepare_video_inputs(self):
"""This function prepares a list of numpy videos."""
video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
image_inputs = [video_input] * 3 # batch-size=3
return image_inputs
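For reference, a self-contained sketch of the structure this helper returns: a batch of videos, each a list of channels-first frames.

```python
import numpy as np

# Same layout as prepare_video_inputs: 3 videos x 8 frames of (C, H, W) uint8.
videos = [[np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8] * 3
assert len(videos) == 3 and len(videos[0]) == 8
assert videos[0][0].shape == (3, 30, 400)
```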
def test_processor_to_json_string(self):
processor = self.get_processor()
obj = json.loads(processor.to_json_string())
@@ -125,8 +126,6 @@ class ProcessorTesterMixin:
if not is_kwargs_typed_dict:
self.skipTest(f"{self.processor_class} doesn't have typed kwargs.")
@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
@@ -141,8 +140,6 @@ class ProcessorTesterMixin:
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 117)
@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
@@ -158,8 +155,6 @@ class ProcessorTesterMixin:
inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
@@ -176,8 +171,6 @@ class ProcessorTesterMixin:
)
self.assertEqual(len(inputs["input_ids"][0]), 112)
@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
@@ -193,8 +186,6 @@ class ProcessorTesterMixin:
inputs = processor(text=input_str, images=image_input, size=[224, 224])
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
@@ -218,8 +209,6 @@ class ProcessorTesterMixin:
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
@@ -244,8 +233,6 @@ class ProcessorTesterMixin:
self.assertEqual(len(inputs["input_ids"][0]), 6)
@require_torch
@require_vision
def test_doubly_passed_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
@@ -265,8 +252,6 @@ class ProcessorTesterMixin:
size={"height": 214, "width": 214},
)
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
@@ -293,8 +278,6 @@ class ProcessorTesterMixin:
self.assertEqual(len(inputs["input_ids"][0]), 76)
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
@@ -318,48 +301,3 @@ class ProcessorTesterMixin:
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
class MyProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
image_processor_class = "CLIPImageProcessor"
tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
def __init__(self, image_processor=None, tokenizer=None, processor_attr_1=1, processor_attr_2=True):
super().__init__(image_processor, tokenizer)
self.processor_attr_1 = processor_attr_1
self.processor_attr_2 = processor_attr_2
@require_tokenizers
@require_vision
class ProcessorTest(unittest.TestCase):
processor_class = MyProcessor
def prepare_processor_dict(self):
return {"processor_attr_1": 1, "processor_attr_2": False}
def get_processor(self):
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
processor = MyProcessor(image_processor, tokenizer, **self.prepare_processor_dict())
return processor
def test_processor_to_json_string(self):
processor = self.get_processor()
obj = json.loads(processor.to_json_string())
for key, value in self.prepare_processor_dict().items():
self.assertEqual(obj[key], value)
self.assertEqual(getattr(processor, key, None), value)
def test_processor_from_and_save_pretrained(self):
processor_first = self.get_processor()
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = processor_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
processor_second = self.processor_class.from_pretrained(tmpdirname)
self.assertEqual(processor_second.to_dict(), processor_first.to_dict())