make test cases done - 16 left

This commit is contained in:
Duc-Viet Hoang 2025-05-20 19:16:53 +07:00
parent dc85010daa
commit 86d103da23

View File

@ -46,33 +46,12 @@ else:
if is_vision_available():
from PIL import Image
def prepare_florence2_inputs_dict(
config,
input_ids,
pixel_values,
decoder_input_ids=None,
attention_mask=None,
decoder_attention_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.text_config.pad_token_id)
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.text_config.pad_token_id)
return {
"input_ids": input_ids,
"pixel_values": pixel_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": attention_mask,
}
class Florence2VisionText2TextModelTester:
def __init__(
self,
parent,
seq_length=7,
seq_length=13,
encoder_seq_length=15,
text_config={
"vocab_size": 51289,
"activation_dropout": 0.1,
@ -144,6 +123,7 @@ class Florence2VisionText2TextModelTester:
self.num_channels = 3
self.image_size = 8
self.seq_length = seq_length
self.encoder_seq_length = encoder_seq_length
def get_config(self):
return Florence2Config(
@ -168,8 +148,11 @@ class Florence2VisionText2TextModelTester:
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = self.get_config()
inputs_dict = prepare_florence2_inputs_dict(config, input_ids, pixel_values, decoder_input_ids)
inputs_dict = {
"input_ids": input_ids,
"pixel_values": pixel_values,
"decoder_input_ids": decoder_input_ids,
}
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
@ -189,6 +172,12 @@ class Florence2VisionText2TextModelTester:
)["logits"]
self.parent.assertFalse(torch.isnan(logits).any().item())
@unittest.skip(
reason="This architecture (bart) has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
@require_torch
class Florence2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):