VLMs: enable generation tests - last batch (#34484)

* add tests for 3 more vlms

* fix fuyu back

* skip test
Raushan Turganbay 2024-11-21 11:00:22 +01:00 committed by GitHub
parent 40821a2478
commit 28fb02fc05
6 changed files with 129 additions and 9 deletions


@@ -346,7 +346,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        if past_key_values:
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]

         position_ids = kwargs.get("position_ids", None)
@@ -355,7 +355,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -1:]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
@@ -377,3 +377,12 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
             }
         )
         return model_inputs
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
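
For readers unfamiliar with the legacy tuple cache, here is a small, self-contained sketch (toy shapes, not part of the commit) of the reordering `_reorder_cache` performs during beam search:

import torch

# Toy cache: one layer, (key, value), each [batch*beams, heads, seq, head_dim].
batch_beams, num_heads, seq_len, head_dim = 4, 2, 3, 8
layer_past = tuple(torch.randn(batch_beams, num_heads, seq_len, head_dim) for _ in range(2))
past_key_values = (layer_past,)

# Beams selected at the current step; row i must now hold the states of beam beam_idx[i].
beam_idx = torch.tensor([2, 2, 0, 1])

reordered_past = ()
for layer in past_key_values:
    reordered_past += (
        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer),
    )

assert torch.equal(reordered_past[0][0][0], layer_past[0][2])  # row 0 came from beam 2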


@@ -91,6 +91,10 @@ class Pix2StructTextConfig(PretrainedConfig):
         "hidden_size": "hidden_size",
         "num_attention_heads": "num_heads",
         "num_hidden_layers": "num_layers",
+        "decoder_attention_heads": "num_heads",
+        "encoder_attention_heads": "num_heads",
+        "encoder_layers": "num_layers",
+        "decoder_layers": "num_layers",
     }

     def __init__(
@@ -354,6 +358,8 @@ class Pix2StructConfig(PretrainedConfig):
             vision_config = {}
             logger.info("vision_config is None. Initializing the Pix2StructVisionConfig with default values.")

+        text_config["is_encoder_decoder"] = is_encoder_decoder
+        text_config["tie_word_embeddings"] = tie_word_embeddings
         self.text_config = Pix2StructTextConfig(**text_config)
         self.vision_config = Pix2StructVisionConfig(**vision_config)
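
The new `attribute_map` entries matter because `GenerationTesterMixin` reads generic seq2seq names (`encoder_layers`, `decoder_attention_heads`, ...) off the text config. A quick sanity sketch, with arbitrary values and assuming this commit is applied:

from transformers import Pix2StructTextConfig

cfg = Pix2StructTextConfig(num_layers=2, num_heads=4)

# attribute_map routes the generic seq2seq names to Pix2Struct's own fields.
assert cfg.encoder_layers == cfg.decoder_layers == 2                    # -> num_layers
assert cfg.encoder_attention_heads == cfg.decoder_attention_heads == 4  # -> num_heads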


@@ -1382,19 +1382,22 @@ class GenerationTesterMixin:
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+            text_config = config.get_text_config()
             # We want to test only encoder-decoder models
-            if not config.is_encoder_decoder:
+            if not text_config.is_encoder_decoder:
                 continue
             model = model_class(config).to(torch_device)

             head_masking = {
-                "head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads, device=torch_device),
+                "head_mask": torch.zeros(
+                    text_config.encoder_layers, text_config.encoder_attention_heads, device=torch_device
+                ),
                 "decoder_head_mask": torch.zeros(
-                    config.decoder_layers, config.decoder_attention_heads, device=torch_device
+                    text_config.decoder_layers, text_config.decoder_attention_heads, device=torch_device
                 ),
                 "cross_attn_head_mask": torch.zeros(
-                    config.decoder_layers, config.decoder_attention_heads, device=torch_device
+                    text_config.decoder_layers, text_config.decoder_attention_heads, device=torch_device
                 ),
             }
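
The switch to `config.get_text_config()` is what lets this test run on composite VLM configs, whose encoder/decoder sizes live on a nested text config rather than at the top level. A rough illustration (default configs, purely for demonstration):

from transformers import Pix2StructConfig, T5Config

composite = Pix2StructConfig()  # nested text config
flat = T5Config()               # everything at the top level

# Both calls return whichever config object actually carries the layer/head counts.
print(type(composite.get_text_config()).__name__)  # Pix2StructTextConfig
print(type(flat.get_text_config()).__name__)       # T5Config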


@@ -17,12 +17,15 @@
 import io
 import unittest

+import pytest
 import requests
+from parameterized import parameterized

 from transformers import FuyuConfig, is_torch_available, is_vision_available
 from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
 from transformers.utils import cached_property

+from ...generation.test_utils import GenerationTesterMixin
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
 from ...test_pipeline_mixin import PipelineTesterMixin
@@ -263,8 +266,9 @@ class FuyuModelTester:
 @require_torch
-class FuyuModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (FuyuForCausalLM,) if is_torch_available() else ()
+    all_generative_model_classes = (FuyuForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {"text-generation": FuyuForCausalLM, "image-text-to-text": FuyuForCausalLM} if is_torch_available() else {}
     )
@@ -296,6 +300,16 @@ class FuyuModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass

+    @pytest.mark.generate
+    @parameterized.expand([("random",), ("same",)])
+    @unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
+    def test_assisted_decoding_matches_greedy_search(self):
+        pass
+
+    @unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
+    def test_assisted_decoding_sample(self):
+        pass
+
     # TODO: Fix me (once this model gets more usage)
     @unittest.skip(reason="Does not work on the tiny model.")
     def test_disk_offload_bin(self):


@@ -21,7 +21,9 @@ import tempfile
 import unittest

 import numpy as np
+import pytest
 import requests
+from parameterized import parameterized

 from transformers import AutoModelForImageTextToText, AutoProcessor, Kosmos2Config
 from transformers.models.kosmos2.configuration_kosmos2 import Kosmos2TextConfig, Kosmos2VisionConfig
@@ -37,6 +39,7 @@ from transformers.utils import (
     is_vision_available,
 )

+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import (
     ModelTesterMixin,
@@ -205,6 +208,7 @@ class Kosmos2ModelTester:
         self.text_model_tester = Kosmos2TextModelTester(parent, **text_kwargs)
         self.vision_model_tester = Kosmos2VisionModelTester(parent, **vision_kwargs)
         self.batch_size = self.text_model_tester.batch_size  # need bs for batching_equivalence test
+        self.seq_length = self.text_model_tester.seq_length
         self.latent_query_num = latent_query_num
         self.is_training = is_training
@@ -253,7 +257,7 @@ class Kosmos2ModelTester:
 @require_torch
-class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (Kosmos2Model, Kosmos2ForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (Kosmos2ForConditionalGeneration,) if is_torch_available() else ()
     pipeline_model_mapping = (
@@ -451,6 +455,68 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
         # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
         # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))

+    @pytest.mark.generate
+    @parameterized.expand([("greedy", 1), ("beam search", 2)])
+    @unittest.skip(
+        "KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking input args because KOSMOS-2 has `generate()` overwritten"
+    )
+    def test_generate_from_inputs_embeds(self):
+        pass
+
+    @pytest.mark.generate
+    def test_left_padding_compatibility(self):
+        # Overwrite because Kosmos-2 needs to pad pixel values and the image-attn-mask
+        def _prepare_model_kwargs(input_ids, attention_mask, pad_size, signature):
+            model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
+            if "position_ids" in signature:
+                position_ids = torch.cumsum(attention_mask, dim=-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                model_kwargs["position_ids"] = position_ids
+            if "cache_position" in signature:
+                cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
+                model_kwargs["cache_position"] = cache_position
+            if "image_embeds_position_mask" in signature:
+                image_embeds_position_mask = torch.zeros_like(input_ids)
+                image_embeds_position_mask[:, (pad_size + 1) : pad_size + 1 + self.model_tester.latent_query_num] = 1
+                model_kwargs["image_embeds_position_mask"] = image_embeds_position_mask
+            return model_kwargs
+
+        for model_class in self.all_generative_model_classes:
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+            input_ids = inputs_dict["input_ids"]
+            pixel_values = inputs_dict["pixel_values"]
+            attention_mask = inputs_dict.get("attention_mask")
+            if attention_mask is None:
+                attention_mask = torch.ones_like(input_ids)
+
+            model = model_class(config).to(torch_device).eval()
+            signature = inspect.signature(model.forward).parameters.keys()
+
+            # no cache as some models require special cache classes to be init outside forward
+            model.generation_config.use_cache = False
+
+            # Without padding
+            model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, pad_size=0, signature=signature)
+            next_logits_wo_padding = model(**model_kwargs, pixel_values=pixel_values).logits[:, -1, :]
+
+            # With left-padding (length 32)
+            # can hardcode pad_token to be 0 as we'll do attn masking anyway
+            pad_token_id = (
+                config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
+            )
+            pad_size = (input_ids.shape[0], 32)
+            padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
+            padded_input_ids = torch.cat((padding, input_ids), dim=1)
+            padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
+            model_kwargs = _prepare_model_kwargs(
+                padded_input_ids, padded_attention_mask, pad_size=32, signature=signature
+            )
+            next_logits_with_padding = model(**model_kwargs, pixel_values=pixel_values).logits[:, -1, :]
+
+            # They should result in very similar logits
+            self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-3))
+
     @slow
     def test_model_from_pretrained(self):
         model_name = "microsoft/kosmos-2-patch14-224"


@@ -27,6 +27,7 @@ from transformers import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisio
 from transformers.testing_utils import require_torch, require_vision, slow, torch_device
 from transformers.utils import is_torch_available, is_vision_available

+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import (
     ModelTesterMixin,
@@ -388,6 +389,7 @@ class Pix2StructModelTester:
         self.batch_size = self.text_model_tester.batch_size  # need bs for batching_equivalence test
         self.seq_length = self.text_model_tester.seq_length  # need seq_length for common tests
         self.is_training = is_training
+        self.max_patches = self.vision_model_tester.max_patches

     def prepare_config_and_inputs(self):
         text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
@@ -417,7 +419,7 @@ class Pix2StructModelTester:
 @require_torch
-class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class Pix2StructModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
     all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
     pipeline_model_mapping = (
@@ -751,6 +753,26 @@ class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
             text_config = Pix2StructTextConfig.from_pretrained(tmp_dir_name)
             self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

+    def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
+        # overwrite because pix2struct seq length depends on image inputs
+        seq_length = self.model_tester.max_patches
+        encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [layer_attentions.shape for layer_attentions in attentions],
+            [encoder_expected_shape] * len(attentions),
+        )
+
+    def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
+        # overwrite because pix2struct seq length depends on image inputs
+        seq_length = self.model_tester.max_patches
+        encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
+        self.assertIsInstance(hidden_states, tuple)
+        self.assertListEqual(
+            [layer_hidden_states.shape for layer_hidden_states in hidden_states],
+            [encoder_expected_shape] * len(hidden_states),
+        )
+
 # We will verify our results on an image of a stop sign
 def prepare_img():
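
For context on the two overrides: Pix2Struct's encoder attends over the flattened image patches, so the expected shapes are driven by max_patches rather than the text sequence length the mixin passes in. A back-of-the-envelope check with toy sizes (not taken from the test suite):

batch_size, num_heads, hidden_size, max_patches = 2, 4, 8, 16

encoder_attention_shape = (batch_size, num_heads, max_patches, max_patches)
encoder_hidden_state_shape = (batch_size, max_patches, hidden_size)
print(encoder_attention_shape)     # (2, 4, 16, 16)
print(encoder_hidden_state_shape)  # (2, 16, 8)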