From cb555af2c7737a750366a2829cc0145a731c55ee Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 21 Apr 2022 09:09:00 -0400 Subject: [PATCH] Return input_ids in ImageGPT feature extractor (#16872) --- .../models/imagegpt/feature_extraction_imagegpt.py | 7 +++---- tests/imagegpt/test_feature_extraction_imagegpt.py | 12 ++++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/imagegpt/feature_extraction_imagegpt.py b/src/transformers/models/imagegpt/feature_extraction_imagegpt.py index b49d5e521e4..f129f1d4c19 100644 --- a/src/transformers/models/imagegpt/feature_extraction_imagegpt.py +++ b/src/transformers/models/imagegpt/feature_extraction_imagegpt.py @@ -68,7 +68,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix Whether or not to normalize the input to the range between -1 and +1. """ - model_input_names = ["pixel_values"] + model_input_names = ["input_ids"] def __init__(self, clusters, do_resize=True, size=32, resample=Image.BILINEAR, do_normalize=True, **kwargs): super().__init__(**kwargs) @@ -128,8 +128,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: - - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height, - width). + - **input_ids** -- Input IDs to be fed to a model, of shape `(batch_size, height * width)`. """ # Input type checking for clearer error valid_images = False @@ -171,7 +170,7 @@ class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix images = images.reshape(batch_size, -1) # return as BatchFeature - data = {"pixel_values": images} + data = {"input_ids": images} encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) return encoded_inputs diff --git a/tests/imagegpt/test_feature_extraction_imagegpt.py b/tests/imagegpt/test_feature_extraction_imagegpt.py index dd0fdfa89a7..4d1ca087d80 100644 --- a/tests/imagegpt/test_feature_extraction_imagegpt.py +++ b/tests/imagegpt/test_feature_extraction_imagegpt.py @@ -161,17 +161,17 @@ class ImageGPTFeatureExtractorIntegrationTest(unittest.TestCase): # test non-batched encoding = feature_extractor(images[0], return_tensors="pt") - self.assertIsInstance(encoding.pixel_values, torch.LongTensor) - self.assertEqual(encoding.pixel_values.shape, (1, 1024)) + self.assertIsInstance(encoding.input_ids, torch.LongTensor) + self.assertEqual(encoding.input_ids.shape, (1, 1024)) expected_slice = [306, 191, 191] - self.assertEqual(encoding.pixel_values[0, :3].tolist(), expected_slice) + self.assertEqual(encoding.input_ids[0, :3].tolist(), expected_slice) # test batched encoding = feature_extractor(images, return_tensors="pt") - self.assertIsInstance(encoding.pixel_values, torch.LongTensor) - self.assertEqual(encoding.pixel_values.shape, (2, 1024)) + self.assertIsInstance(encoding.input_ids, torch.LongTensor) + self.assertEqual(encoding.input_ids.shape, (2, 1024)) expected_slice = [303, 13, 13] - self.assertEqual(encoding.pixel_values[1, -3:].tolist(), expected_slice) + self.assertEqual(encoding.input_ids[1, -3:].tolist(), expected_slice)