# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

from transformers import ImageGPTConfig
from transformers.testing_utils import require_torch, require_vision, run_test_using_subprocess, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        ImageGPTForCausalImageModeling,
        ImageGPTForImageClassification,
        ImageGPTModel,
    )

if is_vision_available():
    from PIL import Image

    from transformers import ImageGPTImageProcessor


class ImageGPTModelTester:
    """Builds small ImageGPT configs and inputs and runs shape checks on the model classes.

    The tester is owned by a `unittest.TestCase` (``parent``), whose assertion methods are
    used to report failures.
    """

    def __init__(
        self,
        parent,
        batch_size=14,
        seq_length=7,
        is_training=True,
        use_token_type_ids=True,
        use_input_mask=True,
        use_labels=True,
        use_mc_token_ids=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        """Store the hyper-parameters used to build configs and inputs.

        Args:
            parent: The test case whose `assert*` methods are used by the `create_and_check_*` helpers.
            scope: Stored as-is on the tester; nothing else in this file reads it.

        All remaining arguments are stored unchanged as attributes of the same name and
        are consumed by `prepare_config_and_inputs` and `get_config`.
        """
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_token_type_ids = use_token_type_ids
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.use_mc_token_ids = use_mc_token_ids
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        # Keep the caller-supplied value (previously this was hard-coded to None,
        # which silently discarded the `scope` argument).
        self.scope = scope

    def get_large_model_config(self):
        """Return the config loaded via `ImageGPTConfig.from_pretrained("imagegpt")`."""
        return ImageGPTConfig.from_pretrained("imagegpt")

    def prepare_config_and_inputs(
        self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
    ):
        """Create a config together with the input tensors used by the `create_and_check_*` helpers.

        The three keyword arguments are forwarded unchanged to `get_config`.

        Returns:
            A 9-tuple ``(config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids,
            sequence_labels, token_labels, choice_labels)``. Entries whose corresponding
            ``use_*`` flag is false are ``None``.
        """
        # Ids are drawn with an upper bound of `vocab_size - 1`, matching the logits
        # width asserted in `create_and_check_lm_head_model`.
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        mc_token_ids = None
        if self.use_mc_token_ids:
            mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config(
            gradient_checkpointing=gradient_checkpointing,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )

        # One entry per (layer, head), built with an upper bound of 2.
        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

        return (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(
        self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
    ):
        """Build an `ImageGPTConfig` from the tester's hyper-parameters.

        The three keyword arguments are passed straight through to the config;
        ``use_cache`` is always enabled.
        """
        return ImageGPTConfig(
            vocab_size=self.vocab_size,
            n_embd=self.hidden_size,
            n_layer=self.num_hidden_layers,
            n_head=self.num_attention_heads,
            n_inner=self.intermediate_size,
            activation_function=self.hidden_act,
            resid_pdrop=self.hidden_dropout_prob,
            attn_pdrop=self.attention_probs_dropout_prob,
            n_positions=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            use_cache=True,
            gradient_checkpointing=gradient_checkpointing,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )

    def get_pipeline_config(self):
        """Return the default test config with `vocab_size` set to 513 and `max_position_embeddings` to 1024."""
        config = self.get_config()
        config.vocab_size = 513
        config.max_position_embeddings = 1024
        return config

    def create_and_check_imagegpt_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        """Run `ImageGPTModel` with three argument combinations and check the output shapes.

        Only the result of the last call (``input_ids`` alone) is asserted on; the
        earlier calls only need to complete without raising.
        """
        model = ImageGPTModel(config=config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(len(result.past_key_values), config.n_layer)

    def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        """Run `ImageGPTForCausalImageModeling` with labels and check the loss and logits shapes."""
        model = ImageGPTForCausalImageModeling(config)
        model.to(torch_device)
        model.eval()

        labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1)
        result = model(input_ids, token_type_ids=token_type_ids, labels=labels)
        # The loss is expected to be a scalar (empty shape).
        self.parent.assertEqual(result.loss.shape, ())
        # ImageGPTForCausalImageModeling doesn't have tied input- and output embeddings
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size - 1))

    def create_and_check_imagegpt_for_image_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        """Run `ImageGPTForImageClassification` and check that logits are ``(batch_size, num_labels)``.

        Note that ``config.num_labels`` is overwritten in place with the tester's ``num_labels``.
        """
        config.num_labels = self.num_labels
        model = ImageGPTForImageClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` for the shared mixin tests.

        ``inputs_dict`` holds only ``input_ids``, ``token_type_ids`` and ``head_mask``;
        the other values produced by `prepare_config_and_inputs` are dropped.
        """
        config_and_inputs = self.prepare_config_and_inputs()

        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "head_mask": head_mask,
        }

        return config, inputs_dict


@require_torch
class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common, generation and pipeline mixin tests for the ImageGPT model classes, plus ImageGPT-specific checks."""

    all_model_classes = (
        (ImageGPTForCausalImageModeling, ImageGPTForImageClassification, ImageGPTModel) if is_torch_available() else ()
    )
    pipeline_model_mapping = (
        {"image-feature-extraction": ImageGPTModel, "image-classification": ImageGPTForImageClassification}
        if is_torch_available()
        else {}
    )
    test_missing_keys = False
    test_torch_exportable = True

    # as ImageGPTForImageClassification isn't included in any auto mapping, we add labels here
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        """Extend the mixin's input preparation with zero-valued labels for the classification model."""
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            if model_class.__name__ == "ImageGPTForImageClassification":
                inputs_dict["labels"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )

        return inputs_dict

    # we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausalImageModeling doesn't have
    # tied input- and output embeddings
    def _check_scores(self, batch_size, scores, generated_length, config):
        """Check that `scores` is a tuple of `generated_length` tensors of shape ``(batch_size, vocab_size - 1)``."""
        expected_shape = (batch_size, config.vocab_size - 1)
        self.assertIsInstance(scores, tuple)
        self.assertEqual(len(scores), generated_length)
        self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))

    @run_test_using_subprocess
    def test_beam_search_generate_dict_outputs_use_cache(self):
        """Run the inherited test unchanged, wrapped in `run_test_using_subprocess`."""
        super().test_beam_search_generate_dict_outputs_use_cache()

    def setUp(self):
        """Create the model tester and the config tester used by the tests below."""
        self.model_tester = ImageGPTModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ImageGPTConfig, n_embd=37)

    def test_config(self):
        """Run the common configuration tests."""
        self.config_tester.run_common_tests()

    def test_imagegpt_model(self):
        """Shape checks for the base `ImageGPTModel`."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_imagegpt_model(*config_and_inputs)

    def test_imagegpt_causal_lm(self):
        """Shape checks for `ImageGPTForCausalImageModeling`."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lm_head_model(*config_and_inputs)

    def test_imagegpt_image_classification(self):
        """Shape checks for `ImageGPTForImageClassification`."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_imagegpt_for_image_classification(*config_and_inputs)

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        """Check that the `openai/imagegpt-small` checkpoint loads into `ImageGPTModel`."""
        model_name = "openai/imagegpt-small"
        model = ImageGPTModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    def test_forward_signature(self):
        """Check that `input_ids` is the first parameter of `forward` for every model class."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_ids"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    @unittest.skip(reason="The model doesn't support left padding")  # and it's not used enough to be worth fixing :)
    def test_left_padding_compatibility(self):
        pass


# We will verify our results on an image of cute cats
def prepare_img():
    """Open and return the COCO fixture image used by the integration test.

    The path is relative, so this relies on the current working directory
    containing ``tests/fixtures``.
    """
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


@require_torch
@require_vision
class ImageGPTModelIntegrationTest(unittest.TestCase):
    """Integration test running the `openai/imagegpt-small` checkpoint on a fixture image."""

    @cached_property
    def default_image_processor(self):
        """Image processor for `openai/imagegpt-small`, or ``None`` when vision dependencies are unavailable."""
        return ImageGPTImageProcessor.from_pretrained("openai/imagegpt-small") if is_vision_available() else None

    @slow
    def test_inference_causal_lm_head(self):
        """Compare the logits shape and a 3x3 logits slice against stored reference values."""
        model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-small").to(torch_device)

        image_processor = self.default_image_processor
        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        expected_shape = torch.Size((1, 1024, 512))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_slice = torch.tensor(
            [[2.3445, 2.6889, 2.7313], [1.0530, 1.2416, 0.5699], [0.2205, 0.7749, 0.3953]]
        ).to(torch_device)

        torch.testing.assert_close(outputs.logits[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)