Convert image to rgb for clip model (#17101)

Co-authored-by: kuanwee.heng <kuanwee.heng@aaqua.live>
Heng Kuan Wee 2022-05-11 20:09:54 +08:00 committed by GitHub
parent 0a2bea4752
commit 6bc6797e04
2 changed files with 86 additions and 1 deletion
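For context, the change makes CLIPFeatureExtractor convert incoming PIL images to RGB before the existing resize / center-crop / normalize steps, so inputs such as 4-channel RGBA images come out as 3-channel `pixel_values`. A minimal usage sketch of the new default behaviour (illustrative only, not part of the commit; it assumes the class defaults `size=224` and `crop_size=224`, which are not visible in this diff, and a local PyTorch install for `return_tensors="pt"`):

from PIL import Image
from transformers import CLIPFeatureExtractor

# do_convert_rgb defaults to True after this change.
feature_extractor = CLIPFeatureExtractor()

# A 4-channel RGBA image is converted to RGB before resizing, cropping and normalizing.
rgba_image = Image.new("RGBA", (256, 256), (255, 0, 0, 255))
encoded = feature_extractor(images=rgba_image, return_tensors="pt")

print(encoded.pixel_values.shape)  # torch.Size([1, 3, 224, 224])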

View File

@@ -54,6 +54,8 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
do_convert_rgb (`bool`, defaults to `True`):
Whether or not to convert `PIL.Image.Image` into `RGB` format.
"""
model_input_names = ["pixel_values"]
@@ -68,6 +70,7 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
do_normalize=True,
image_mean=None,
image_std=None,
do_convert_rgb=True,
**kwargs
):
super().__init__(**kwargs)
@@ -79,6 +82,7 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
self.do_convert_rgb = do_convert_rgb
def __call__(
self,
@@ -141,7 +145,9 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
if not is_batched:
images = [images]
# transformations (resizing + center cropping + normalization)
# transformations (convert rgb + resizing + center cropping + normalization)
if self.do_convert_rgb:
images = [self.convert_rgb(image) for image in images]
if self.do_resize and self.size is not None and self.resample is not None:
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
if self.do_center_crop and self.crop_size is not None:
@@ -155,6 +161,20 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
return encoded_inputs
def convert_rgb(self, image):
"""
Converts `image` to RGB format. Images that are not `PIL.Image.Image` instances are returned unchanged.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to convert.
"""
self._ensure_format_supported(image)
if not isinstance(image, Image.Image):
return image
return image.convert("RGB")
def center_crop(self, image, size):
"""
Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the

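Note that the new `convert_rgb` helper added above only acts on `PIL.Image.Image` inputs; numpy arrays and torch tensors are passed through untouched, as the early return shows. A small sketch of that behaviour (illustrative, not part of the commit):

import numpy as np
from PIL import Image
from transformers import CLIPFeatureExtractor

feature_extractor = CLIPFeatureExtractor()

# PIL images are converted to RGB mode ...
rgba = Image.new("RGBA", (32, 32))
print(feature_extractor.convert_rgb(rgba).mode)  # "RGB"

# ... while non-PIL inputs (here a numpy array) are returned as-is.
array = np.zeros((3, 32, 32), dtype=np.float32)
print(feature_extractor.convert_rgb(array) is array)  # True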
View File

@@ -49,6 +49,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
do_normalize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
do_convert_rgb=True,
):
self.parent = parent
self.batch_size = batch_size
@@ -63,6 +64,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
def prepare_feat_extract_dict(self):
return {
@@ -73,6 +75,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_convert_rgb": self.do_convert_rgb,
}
def prepare_inputs(self, equal_resolution=False, numpify=False, torchify=False):
@@ -128,6 +131,7 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
def test_batch_feature(self):
pass
@@ -227,3 +231,64 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
self.feature_extract_tester.crop_size,
),
)
@require_torch
@require_vision
class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, unittest.TestCase):
feature_extraction_class = CLIPFeatureExtractor if is_vision_available() else None
def setUp(self):
self.feature_extract_tester = CLIPFeatureExtractionTester(self, num_channels=4)
self.expected_encoded_image_num_channels = 3
@property
def feat_extract_dict(self):
return self.feature_extract_tester.prepare_feat_extract_dict()
def test_feat_extract_properties(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
self.assertTrue(hasattr(feature_extractor, "center_crop"))
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
def test_batch_feature(self):
pass
def test_call_pil_four_channels(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random PIL images
image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
1,
self.expected_encoded_image_num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
),
)
# Test batched
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.feature_extract_tester.batch_size,
self.expected_encoded_image_num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
),
)
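The new CLIPFeatureExtractionTestFourChannels class above checks the end-to-end property with randomly generated 4-channel inputs. The same check can be reproduced by hand along these lines (a sketch, assuming the default `size=224` and `crop_size=224`, which this diff does not show):

from PIL import Image
from transformers import CLIPFeatureExtractor

feature_extractor = CLIPFeatureExtractor(do_convert_rgb=True)

# A batch of 4-channel (RGBA) images with unequal resolutions.
images = [Image.new("RGBA", (30 + 4 * i, 40 + 4 * i)) for i in range(7)]

pixel_values = feature_extractor(images, return_tensors="pt").pixel_values

# Every encoded image has 3 channels, matching expected_encoded_image_num_channels.
assert pixel_values.shape == (7, 3, 224, 224)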