mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00

* Chameleon model integration Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Leonid Shamis <leonid.shamis@gmail.com> * fix 7B, again. mask away image tokens * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * remove pretrained_config_map * make fixup passing up to utils/check_config_docstrings.py; vqgan moved to the modeling file * remove tokenizer (use llama's); remove codechameleon tests * a few copied from statements and minor changes * copied from in ChameleonModel * some copies in ChameleonForCausalLM * a few more copies * VQModel moved to ChameleonModel (as opposed to being in the processor) * ChameleonProcessor ready * Fix chameleon weights convert * update conversion script * clean-up processing * update modeling a bit * update * update (throws error...) * correct conversion ready * fix tests * fix docs * docs * ve swin norm * fix device for vocab map * add normalization * update * update script with rope rotations * final fix on model conversion * add slow tests * more info in docs * fix repo consistency tests * fix repo tests * fix-copies * hope this will make CI happy * fix for 30b model * Update docs/source/en/index.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/model_doc/chameleon.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/chameleon/modeling_chameleon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/model_doc/chameleon.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/model_doc/chameleon.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/model_doc/chameleon.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/model_doc/chameleon.md Co-authored-by: 
amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/auto/configuration_auto.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/chameleon/image_processing_chameleon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/chameleon/image_processing_chameleon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/chameleon/image_processing_chameleon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/chameleon/image_processing_chameleon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/chameleon/modeling_chameleon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/chameleon/processing_chameleon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/chameleon/processing_chameleon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/chameleon/test_modeling_chameleon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/chameleon/test_modeling_chameleon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/chameleon/test_modeling_chameleon.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * address comments * remove assertion in conversion script * add image processor test * not copied * port changes for qk layernorm * fix-copies * read token decorator for tests * [run-slow] chameleon * one more read-token * address some comments * qk norm changes * tests and repo check * moved rope permutations to conversion, YAY! 
* fix past kv check * docs * layernorm done! * let's be consistent in naming * fix slow tests * weird thing with slow CI, but let's see * once more try * remove past-kv as tuple following llama * ignore * style --------- Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Co-authored-by: ArthurZucker <arthur.zucker@gmail.com> Co-authored-by: jacobkahn <jacobkahn1@gmail.com> Co-authored-by: Leonid Shamis <leonid.shamis@gmail.com> Co-authored-by: Leonid Shamis <lshamis@meta.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Joao Gante <joao@huggingface.co> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
206 lines
8.8 KiB
Python
206 lines
8.8 KiB
Python
# coding=utf-8
|
|
# Copyright 2024 HuggingFace Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
from transformers.testing_utils import require_torch, require_vision
|
|
from transformers.utils import is_torch_available, is_vision_available
|
|
|
|
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
if is_vision_available():
|
|
from PIL import Image
|
|
|
|
from transformers import ChameleonImageProcessor
|
|
|
|
|
|
class ChameleonImageProcessingTester(unittest.TestCase):
    """Configuration holder for the Chameleon image-processor tests.

    It produces the kwargs used to build the image processor
    (``prepare_image_processor_dict``), the expected per-image output shape
    (``expected_output_image_shape``) and random input images
    (``prepare_image_inputs``). Although it subclasses ``unittest.TestCase``, it
    defines no ``test_*`` methods and its ``__init__`` does not call
    ``TestCase.__init__``, so it only serves as a helper object.
    """

    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=200,
        do_resize=True,
        size=None,
        do_center_crop=True,
        crop_size=None,
        do_normalize=True,
        image_mean=None,
        image_std=None,
        do_convert_rgb=True,
    ):
        """Store the test configuration.

        Args:
            parent: the owning test case; stored as ``self.parent``.
            batch_size: number of images produced by ``prepare_image_inputs``.
            num_channels: channel count of generated images and of the expected output.
            image_size: stored as ``self.image_size``.
            min_resolution / max_resolution: bounds forwarded to ``prepare_image_inputs``.
            do_resize, size, do_center_crop, crop_size, do_normalize, image_mean,
            image_std, do_convert_rgb: image-processor kwargs returned by
                ``prepare_image_processor_dict``. ``None`` for ``size``, ``crop_size``,
                ``image_mean`` or ``image_std`` selects the defaults below.
        """
        # Build fresh containers per instance: list/dict defaults in the signature
        # would be a single object shared by every tester, so an in-place change made
        # through one tester (or the processor built from it) would leak into others.
        size = size if size is not None else {"shortest_edge": 18}
        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
        image_mean = image_mean if image_mean is not None else [1.0, 1.0, 1.0]
        image_std = image_std if image_std is not None else [1.0, 1.0, 1.0]
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    def prepare_image_processor_dict(self):
        """Return the kwargs used to instantiate the image processor under test."""
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_center_crop": self.do_center_crop,
            "crop_size": self.crop_size,
            "do_normalize": self.do_normalize,
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_convert_rgb": self.do_convert_rgb,
        }

    # Per-image output shape (channels, crop height, crop width); `images` is ignored.
    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.expected_output_image_shape
    def expected_output_image_shape(self, images):
        return self.num_channels, self.crop_size["height"], self.crop_size["width"]

    # Delegates to the shared `prepare_image_inputs` helper using this tester's config.
    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )
|
|
|
|
|
|
@require_torch
@require_vision
class ChameleonImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    """Tests for ``ChameleonImageProcessor`` driven by ``ChameleonImageProcessingTester``."""

    image_processing_class = ChameleonImageProcessor if is_vision_available() else None

    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->Chameleon
    def setUp(self):
        super().setUp()
        self.image_processor_tester = ChameleonImageProcessingTester(self)

    @property
    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def _assert_pixel_values_shapes(self, image_processing, image_inputs):
        """Check ``pixel_values`` shapes for one image and for the whole batch.

        Expected shapes are derived from the tester's configuration
        (``batch_size`` and ``expected_output_image_shape``) instead of literals,
        so the assertions stay correct if that configuration changes.
        """
        tester = self.image_processor_tester

        # Test not batched input: a single image still gets a leading batch dimension of 1.
        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
        expected_output_image_shape = (1, *tester.expected_output_image_shape([image_inputs[0]]))
        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

        # Test batched
        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
        expected_output_image_shape = (tester.batch_size, *tester.expected_output_image_shape(image_inputs))
        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

    def test_image_processor_properties(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        self.assertTrue(hasattr(image_processing, "do_resize"))
        self.assertTrue(hasattr(image_processing, "size"))
        self.assertTrue(hasattr(image_processing, "do_center_crop"))
        # `center_crop` is the processing method; `crop_size` is the configured attribute.
        self.assertTrue(hasattr(image_processing, "center_crop"))
        self.assertTrue(hasattr(image_processing, "crop_size"))
        self.assertTrue(hasattr(image_processing, "do_normalize"))
        self.assertTrue(hasattr(image_processing, "image_mean"))
        self.assertTrue(hasattr(image_processing, "image_std"))
        self.assertTrue(hasattr(image_processing, "do_convert_rgb"))

    def test_image_processor_from_dict_with_kwargs(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(image_processor.size, {"shortest_edge": 18})
        self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})

        # Integer kwargs passed to from_dict override the dict values and are normalized.
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
        self.assertEqual(image_processor.size, {"shortest_edge": 42})
        self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})

    def test_call_pil(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PIL images
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
        for image in image_inputs:
            self.assertIsInstance(image, Image.Image)

        self._assert_pixel_values_shapes(image_processing, image_inputs)

    def test_call_numpy(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random numpy tensors
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
        for image in image_inputs:
            self.assertIsInstance(image, np.ndarray)

        self._assert_pixel_values_shapes(image_processing, image_inputs)

    def test_call_pytorch(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PyTorch tensors
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
        for image in image_inputs:
            self.assertIsInstance(image, torch.Tensor)

        self._assert_pixel_values_shapes(image_processing, image_inputs)

    def test_nested_input(self):
        """A flat list and a nested list of the same images must give identical pixel values."""
        image_processing = self.image_processing_class(**self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
        tester = self.image_processor_tester
        expected_output_image_shape = (tester.batch_size, *tester.expected_output_image_shape(image_inputs))

        # Test batched as a list of images
        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

        # Test batched as a nested list of images, where each sublist is one batch
        image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
        encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
        self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)

        # Image processor should return same pixel values, independently of input format
        self.assertTrue(torch.equal(encoded_images_nested, encoded_images))
|