mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Convert image to rgb for clip model (#17101)
Co-authored-by: kuanwee.heng <kuanwee.heng@aaqua.live>
This commit is contained in:
parent
0a2bea4752
commit
6bc6797e04
@ -54,6 +54,8 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
The sequence of means for each channel, to be used when normalizing images.
|
||||
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
|
||||
The sequence of standard deviations for each channel, to be used when normalizing images.
|
||||
convert_rgb (`bool`, defaults to `True`):
|
||||
Whether or not to convert `PIL.Image.Image` into `RGB` format
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
@ -68,6 +70,7 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
do_normalize=True,
|
||||
image_mean=None,
|
||||
image_std=None,
|
||||
do_convert_rgb=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@ -79,6 +82,7 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
|
||||
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -141,7 +145,9 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
if not is_batched:
|
||||
images = [images]
|
||||
|
||||
# transformations (resizing + center cropping + normalization)
|
||||
# transformations (convert rgb + resizing + center cropping + normalization)
|
||||
if self.do_convert_rgb:
|
||||
images = [self.convert_rgb(image) for image in images]
|
||||
if self.do_resize and self.size is not None and self.resample is not None:
|
||||
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
|
||||
if self.do_center_crop and self.crop_size is not None:
|
||||
@ -155,6 +161,20 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def convert_rgb(self, image):
|
||||
"""
|
||||
Converts `image` to RGB format. Note that this will trigger a conversion of `image` to a PIL Image.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
||||
The image to convert.
|
||||
"""
|
||||
self._ensure_format_supported(image)
|
||||
if not isinstance(image, Image.Image):
|
||||
return image
|
||||
|
||||
return image.convert("RGB")
|
||||
|
||||
def center_crop(self, image, size):
|
||||
"""
|
||||
Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
|
||||
|
@ -49,6 +49,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
|
||||
do_normalize=True,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@ -63,6 +64,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def prepare_feat_extract_dict(self):
|
||||
return {
|
||||
@ -73,6 +75,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
def prepare_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
@ -128,6 +131,7 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
|
||||
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
|
||||
self.assertTrue(hasattr(feature_extractor, "image_mean"))
|
||||
self.assertTrue(hasattr(feature_extractor, "image_std"))
|
||||
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
|
||||
|
||||
def test_batch_feature(self):
|
||||
pass
|
||||
@ -227,3 +231,64 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
|
||||
self.feature_extract_tester.crop_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, unittest.TestCase):
|
||||
|
||||
feature_extraction_class = CLIPFeatureExtractor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.feature_extract_tester = CLIPFeatureExtractionTester(self, num_channels=4)
|
||||
self.expected_encoded_image_num_channels = 3
|
||||
|
||||
@property
|
||||
def feat_extract_dict(self):
|
||||
return self.feature_extract_tester.prepare_feat_extract_dict()
|
||||
|
||||
def test_feat_extract_properties(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
self.assertTrue(hasattr(feature_extractor, "do_resize"))
|
||||
self.assertTrue(hasattr(feature_extractor, "size"))
|
||||
self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
|
||||
self.assertTrue(hasattr(feature_extractor, "center_crop"))
|
||||
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
|
||||
self.assertTrue(hasattr(feature_extractor, "image_mean"))
|
||||
self.assertTrue(hasattr(feature_extractor, "image_std"))
|
||||
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
|
||||
|
||||
def test_batch_feature(self):
|
||||
pass
|
||||
|
||||
def test_call_pil_four_channels(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random PIL images
|
||||
image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
1,
|
||||
self.expected_encoded_image_num_channels,
|
||||
self.feature_extract_tester.crop_size,
|
||||
self.feature_extract_tester.crop_size,
|
||||
),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.expected_encoded_image_num_channels,
|
||||
self.feature_extract_tester.crop_size,
|
||||
self.feature_extract_tester.crop_size,
|
||||
),
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user