Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
fix some tests

commit 80f13786c7 (parent e4674c8228)
@@ -321,7 +321,7 @@ class Florence2Config(PretrainedConfig):
         if text_config is not None:
             self.text_config = Florence2LanguageConfig(**text_config)

-        super().__init__(**kwargs)
+        super().__init__(is_encoder_decoder=True, **kwargs)


 __all__ = ["Florence2Config"]
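Forcing `is_encoder_decoder=True` in the config matters because `GenerationMixin` consults `config.is_encoder_decoder` to decide whether to run the encoder once and then feed `decoder_input_ids` to the decoder. A minimal sketch of the pattern; the toy config class below is hypothetical, not part of this commit:

    from transformers import PretrainedConfig

    # Hypothetical stand-in that forces the flag on in __init__,
    # the way Florence2Config now does.
    class ToyEncoderDecoderConfig(PretrainedConfig):
        model_type = "toy-enc-dec"

        def __init__(self, **kwargs):
            super().__init__(is_encoder_decoder=True, **kwargs)

    cfg = ToyEncoderDecoderConfig()
    print(cfg.is_encoder_decoder)  # True: generate() takes the encoder-decoder path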
@@ -2826,7 +2826,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
             image_features, inputs_embeds
         )

-        if inputs_embeds is not None:
+        if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.dtype)
         outputs = self.language_model(
             attention_mask=attention_mask,
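The one-line change swaps the guard from `inputs_embeds` to `attention_mask`: by this point the image features have already been merged, so `inputs_embeds` is always populated, but `attention_mask` can legitimately be `None`, and the old guard let the cast run into `None.to(...)`. A minimal sketch of the failure mode and the fix, with made-up tensor shapes:

    import torch

    inputs_embeds = torch.randn(1, 4, 8, dtype=torch.float16)
    attention_mask = None  # callers such as generate() may omit the mask

    # Old guard `if inputs_embeds is not None:` always fired here, so the
    # cast raised AttributeError whenever attention_mask was None.
    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.dtype)

The identical hunk below lands the same fix at a second location, presumably once in the modular source and once in the generated modeling file.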
@@ -1260,7 +1260,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
             image_features, inputs_embeds
         )

-        if inputs_embeds is not None:
+        if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.dtype)
         outputs = self.language_model(
             attention_mask=attention_mask,
@@ -33,6 +33,7 @@ from transformers.testing_utils import (
     torch_device,
 )

+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

@@ -46,6 +47,27 @@ if is_vision_available():
     from PIL import Image


+def prepare_florence2_inputs_dict(
+    config,
+    input_ids,
+    pixel_values,
+    decoder_input_ids=None,
+    attention_mask=None,
+    decoder_attention_mask=None,
+):
+    if attention_mask is None:
+        attention_mask = input_ids.ne(config.text_config.pad_token_id)
+    if decoder_attention_mask is None:
+        decoder_attention_mask = decoder_input_ids.ne(config.text_config.pad_token_id)
+    return {
+        "input_ids": input_ids,
+        "pixel_values": pixel_values,
+        "decoder_input_ids": decoder_input_ids,
+        "attention_mask": attention_mask,
+        "decoder_attention_mask": decoder_attention_mask,
+    }
+
+
 class Florence2VisionText2TextModelTester:
     def __init__(
         self,
@@ -110,7 +132,9 @@ class Florence2VisionText2TextModelTester:
         self.parent = parent
         self.text_config = text_config
         self.vision_config = vision_config
+        self.bos_token_id = text_config["bos_token_id"]
         self.pad_token_id = text_config["pad_token_id"]
+        self.eos_token_id = text_config["eos_token_id"]

         self.num_hidden_layers = text_config["num_hidden_layers"]
         self.vocab_size = text_config["vocab_size"]
@@ -136,23 +160,19 @@ class Florence2VisionText2TextModelTester:
                 self.image_size,
             ]
         )
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
+            3,
+        )
+        input_ids[:, -1] = self.eos_token_id  # Eos Token
+        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
         config = self.get_config()

-        return config, pixel_values
+        inputs_dict = prepare_florence2_inputs_dict(config, input_ids, pixel_values, decoder_input_ids)
+
+        return config, inputs_dict

     def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        config, pixel_values = config_and_inputs
-        input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
-        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        attention_mask = input_ids.ne(1).to(torch_device)
-        inputs_dict = {
-            "decoder_input_ids": decoder_input_ids,
-            "pixel_values": pixel_values,
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-        }
+        config, inputs_dict = self.prepare_config_and_inputs()
         return config, inputs_dict

     def create_and_check_florence2_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
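With the helper in place, `prepare_config_and_inputs` builds the full inputs dict itself and `prepare_config_and_inputs_for_common` collapses to a pass-through, instead of rebuilding `input_ids`, `decoder_input_ids`, and the mask a second time under slightly different rules. The shape of that refactor, as a sketch with stand-in values:

    class SketchTester:
        def prepare_config_and_inputs(self):
            config = {"model_type": "sketch"}       # stand-in for a real config
            inputs_dict = {"input_ids": [0, 1, 2]}  # stand-in for real tensors
            return config, inputs_dict

        def prepare_config_and_inputs_for_common(self):
            # No second, divergent input-building path any more.
            config, inputs_dict = self.prepare_config_and_inputs()
            return config, inputs_dict

    config, inputs = SketchTester().prepare_config_and_inputs_for_common()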
@@ -163,18 +183,19 @@ class Florence2VisionText2TextModelTester:
             logits = model(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
-                pixel_values=pixel_values.to(torch.bfloat16),
+                pixel_values=pixel_values.to(torch.float16),
                 return_dict=True,
             )["logits"]
         self.parent.assertFalse(torch.isnan(logits).any().item())


 @require_torch
-class Florence2ForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
+class Florence2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     """
     Model tester for `Florence2ForConditionalGeneration`.
     """

+    additional_model_inputs = ["pixel_values"]
     all_model_classes = (Florence2ForConditionalGeneration,) if is_torch_available() else ()
     all_generative_model_classes = (Florence2ForConditionalGeneration,) if is_torch_available() else ()
     pipeline_model_mapping = {"image-to-text": Florence2ForConditionalGeneration} if is_torch_available() else {}
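Two things change here: the fp16 smoke test now casts `pixel_values` to `torch.float16` so the input dtype matches the half-precision model (previously it cast to `torch.bfloat16`, a different numerics path), and the test class gains `GenerationTesterMixin` plus an `additional_model_inputs` hint, evidently so the shared generation tests also forward the image tensor to this model. The NaN assertion the fp16 test relies on, shown in isolation with a stand-in tensor:

    import torch

    logits = torch.randn(2, 7, 11, dtype=torch.float16)  # stand-in for model output
    assert not torch.isnan(logits).any().item(), "fp16 forward produced NaNs"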