mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Convert image to rgb for clip model (#17101)
Co-authored-by: kuanwee.heng <kuanwee.heng@aaqua.live>
This commit is contained in:
parent
0a2bea4752
commit
6bc6797e04
@ -54,6 +54,8 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
The sequence of means for each channel, to be used when normalizing images.
|
The sequence of means for each channel, to be used when normalizing images.
|
||||||
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
|
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
|
||||||
The sequence of standard deviations for each channel, to be used when normalizing images.
|
The sequence of standard deviations for each channel, to be used when normalizing images.
|
||||||
|
convert_rgb (`bool`, defaults to `True`):
|
||||||
|
Whether or not to convert `PIL.Image.Image` into `RGB` format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_input_names = ["pixel_values"]
|
model_input_names = ["pixel_values"]
|
||||||
@ -68,6 +70,7 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
do_normalize=True,
|
do_normalize=True,
|
||||||
image_mean=None,
|
image_mean=None,
|
||||||
image_std=None,
|
image_std=None,
|
||||||
|
do_convert_rgb=True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@ -79,6 +82,7 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
self.do_normalize = do_normalize
|
self.do_normalize = do_normalize
|
||||||
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
|
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
|
||||||
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
|
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
|
||||||
|
self.do_convert_rgb = do_convert_rgb
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -141,7 +145,9 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
if not is_batched:
|
if not is_batched:
|
||||||
images = [images]
|
images = [images]
|
||||||
|
|
||||||
# transformations (resizing + center cropping + normalization)
|
# transformations (convert rgb + resizing + center cropping + normalization)
|
||||||
|
if self.do_convert_rgb:
|
||||||
|
images = [self.convert_rgb(image) for image in images]
|
||||||
if self.do_resize and self.size is not None and self.resample is not None:
|
if self.do_resize and self.size is not None and self.resample is not None:
|
||||||
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
|
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
|
||||||
if self.do_center_crop and self.crop_size is not None:
|
if self.do_center_crop and self.crop_size is not None:
|
||||||
@ -155,6 +161,20 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
|
|
||||||
return encoded_inputs
|
return encoded_inputs
|
||||||
|
|
||||||
|
def convert_rgb(self, image):
|
||||||
|
"""
|
||||||
|
Converts `image` to RGB format. Note that this will trigger a conversion of `image` to a PIL Image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
||||||
|
The image to convert.
|
||||||
|
"""
|
||||||
|
self._ensure_format_supported(image)
|
||||||
|
if not isinstance(image, Image.Image):
|
||||||
|
return image
|
||||||
|
|
||||||
|
return image.convert("RGB")
|
||||||
|
|
||||||
def center_crop(self, image, size):
|
def center_crop(self, image, size):
|
||||||
"""
|
"""
|
||||||
Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
|
Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
|
||||||
|
@ -49,6 +49,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
|
|||||||
do_normalize=True,
|
do_normalize=True,
|
||||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||||
|
do_convert_rgb=True,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@ -63,6 +64,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
|
|||||||
self.do_normalize = do_normalize
|
self.do_normalize = do_normalize
|
||||||
self.image_mean = image_mean
|
self.image_mean = image_mean
|
||||||
self.image_std = image_std
|
self.image_std = image_std
|
||||||
|
self.do_convert_rgb = do_convert_rgb
|
||||||
|
|
||||||
def prepare_feat_extract_dict(self):
|
def prepare_feat_extract_dict(self):
|
||||||
return {
|
return {
|
||||||
@ -73,6 +75,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
|
|||||||
"do_normalize": self.do_normalize,
|
"do_normalize": self.do_normalize,
|
||||||
"image_mean": self.image_mean,
|
"image_mean": self.image_mean,
|
||||||
"image_std": self.image_std,
|
"image_std": self.image_std,
|
||||||
|
"do_convert_rgb": self.do_convert_rgb,
|
||||||
}
|
}
|
||||||
|
|
||||||
def prepare_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
def prepare_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||||
@ -128,6 +131,7 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
|
|||||||
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
|
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
|
||||||
self.assertTrue(hasattr(feature_extractor, "image_mean"))
|
self.assertTrue(hasattr(feature_extractor, "image_mean"))
|
||||||
self.assertTrue(hasattr(feature_extractor, "image_std"))
|
self.assertTrue(hasattr(feature_extractor, "image_std"))
|
||||||
|
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
|
||||||
|
|
||||||
def test_batch_feature(self):
|
def test_batch_feature(self):
|
||||||
pass
|
pass
|
||||||
@ -227,3 +231,64 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
|
|||||||
self.feature_extract_tester.crop_size,
|
self.feature_extract_tester.crop_size,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, unittest.TestCase):
|
||||||
|
|
||||||
|
feature_extraction_class = CLIPFeatureExtractor if is_vision_available() else None
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.feature_extract_tester = CLIPFeatureExtractionTester(self, num_channels=4)
|
||||||
|
self.expected_encoded_image_num_channels = 3
|
||||||
|
|
||||||
|
@property
|
||||||
|
def feat_extract_dict(self):
|
||||||
|
return self.feature_extract_tester.prepare_feat_extract_dict()
|
||||||
|
|
||||||
|
def test_feat_extract_properties(self):
|
||||||
|
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||||
|
self.assertTrue(hasattr(feature_extractor, "do_resize"))
|
||||||
|
self.assertTrue(hasattr(feature_extractor, "size"))
|
||||||
|
self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
|
||||||
|
self.assertTrue(hasattr(feature_extractor, "center_crop"))
|
||||||
|
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
|
||||||
|
self.assertTrue(hasattr(feature_extractor, "image_mean"))
|
||||||
|
self.assertTrue(hasattr(feature_extractor, "image_std"))
|
||||||
|
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
|
||||||
|
|
||||||
|
def test_batch_feature(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_call_pil_four_channels(self):
|
||||||
|
# Initialize feature_extractor
|
||||||
|
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||||
|
# create random PIL images
|
||||||
|
image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False)
|
||||||
|
for image in image_inputs:
|
||||||
|
self.assertIsInstance(image, Image.Image)
|
||||||
|
|
||||||
|
# Test not batched input
|
||||||
|
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||||
|
self.assertEqual(
|
||||||
|
encoded_images.shape,
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
self.expected_encoded_image_num_channels,
|
||||||
|
self.feature_extract_tester.crop_size,
|
||||||
|
self.feature_extract_tester.crop_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test batched
|
||||||
|
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||||
|
self.assertEqual(
|
||||||
|
encoded_images.shape,
|
||||||
|
(
|
||||||
|
self.feature_extract_tester.batch_size,
|
||||||
|
self.expected_encoded_image_num_channels,
|
||||||
|
self.feature_extract_tester.crop_size,
|
||||||
|
self.feature_extract_tester.crop_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user