mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix VisionEncoderDecoder
Positional Arg (#29497)
* 🐛 Fix vision encoder decoder positional arg * ✅ Add test for VisionEncoderDecoder with LayoutLMv3 encoder --------- Co-authored-by: Nick DeGroot <1966472+nickthegroot@users.noreply.github.com>
This commit is contained in:
parent
ddf177ee4a
commit
b338a6c3b8
@ -573,7 +573,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
pixel_values,
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
|
@ -38,6 +38,7 @@ from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_
|
||||
from ..bart.test_modeling_bart import BartModelTester
|
||||
from ..bert.test_modeling_bert import BertModelTester
|
||||
from ..deit.test_modeling_deit import DeiTModelTester
|
||||
from ..layoutlmv3.test_modeling_layoutlmv3 import LayoutLMv3ModelTester
|
||||
from ..swin.test_modeling_swin import SwinModelTester
|
||||
from ..trocr.test_modeling_trocr import TrOCRStandaloneDecoderModelTester
|
||||
from ..vit.test_modeling_vit import ViTModelTester
|
||||
@ -52,6 +53,7 @@ if is_torch_available():
|
||||
BartForCausalLM,
|
||||
BertLMHeadModel,
|
||||
DeiTModel,
|
||||
LayoutLMv3Model,
|
||||
SwinModel,
|
||||
TrOCRForCausalLM,
|
||||
VisionEncoderDecoderConfig,
|
||||
@ -680,6 +682,128 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class LayoutLMv32TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = LayoutLMv3Model(config).eval()
|
||||
decoder_model = TrOCRForCausalLM(decoder_config).eval()
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
model_tester_encoder = LayoutLMv3ModelTester(self, batch_size=13, image_size=4, patch_size=2)
|
||||
model_tester_decoder = TrOCRStandaloneDecoderModelTester(
|
||||
self, batch_size=13, d_model=32, max_position_embeddings=512
|
||||
)
|
||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
bbox,
|
||||
pixel_values,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
) = encoder_config_and_inputs
|
||||
(decoder_config, decoder_input_ids, decoder_attention_mask, _) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
# disable cache for now
|
||||
decoder_config.use_cache = False
|
||||
return {
|
||||
"config": config,
|
||||
"pixel_values": pixel_values,
|
||||
"input_ids": input_ids,
|
||||
"bbox": bbox,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"labels": decoder_input_ids,
|
||||
}
|
||||
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
input_ids,
|
||||
pixel_values,
|
||||
labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
||||
|
||||
# LayoutLMv3's sequence length equals the number of text tokens + number of patches + 1 (we add 1 for the CLS token)
|
||||
text_seq_length = input_ids.shape[-1]
|
||||
image_seq_length = (encoder_model.config.input_size // encoder_model.config.patch_size) ** 2 + 1
|
||||
seq_len = text_seq_length + image_seq_length
|
||||
|
||||
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
|
||||
num_decoder_layers = (
|
||||
decoder_config.num_decoder_layers
|
||||
if hasattr(decoder_config, "num_decoder_layers")
|
||||
else decoder_config.num_hidden_layers
|
||||
)
|
||||
self.assertEqual(len(decoder_attentions), num_decoder_layers)
|
||||
|
||||
self.assertEqual(
|
||||
decoder_attentions[0].shape[-3:],
|
||||
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
cross_attentions = outputs_encoder_decoder["cross_attentions"]
|
||||
self.assertEqual(len(cross_attentions), num_decoder_layers)
|
||||
|
||||
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
|
||||
self.assertEqual(
|
||||
cross_attentions[0].shape[-3:],
|
||||
(decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len),
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
|
||||
# Generate until max length
|
||||
if hasattr(enc_dec_model.config, "eos_token_id"):
|
||||
enc_dec_model.config.eos_token_id = None
|
||||
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
||||
enc_dec_model.config.decoder.eos_token_id = None
|
||||
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
|
||||
enc_dec_model.generation_config.eos_token_id = None
|
||||
enc_dec_model.to(torch_device)
|
||||
|
||||
generated_output = enc_dec_model.generate(
|
||||
pixel_values=pixel_values,
|
||||
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
@unittest.skip("There are no published pretrained TrOCR checkpoints for now")
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
class TrOCRModelIntegrationTest(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user