mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
make test cases done - 16 left
This commit is contained in:
parent
dc85010daa
commit
86d103da23
@ -46,33 +46,12 @@ else:
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def prepare_florence2_inputs_dict(
|
||||
config,
|
||||
input_ids,
|
||||
pixel_values,
|
||||
decoder_input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.text_config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.text_config.pad_token_id)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"pixel_values": pixel_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
|
||||
class Florence2VisionText2TextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
seq_length=7,
|
||||
seq_length=13,
|
||||
encoder_seq_length=15,
|
||||
text_config={
|
||||
"vocab_size": 51289,
|
||||
"activation_dropout": 0.1,
|
||||
@ -144,6 +123,7 @@ class Florence2VisionText2TextModelTester:
|
||||
self.num_channels = 3
|
||||
self.image_size = 8
|
||||
self.seq_length = seq_length
|
||||
self.encoder_seq_length = encoder_seq_length
|
||||
|
||||
def get_config(self):
|
||||
return Florence2Config(
|
||||
@ -168,8 +148,11 @@ class Florence2VisionText2TextModelTester:
|
||||
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
config = self.get_config()
|
||||
|
||||
inputs_dict = prepare_florence2_inputs_dict(config, input_ids, pixel_values, decoder_input_ids)
|
||||
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"pixel_values": pixel_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
@ -189,6 +172,12 @@ class Florence2VisionText2TextModelTester:
|
||||
)["logits"]
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecture (bart) has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
|
||||
)
|
||||
def test_load_save_without_tied_weights(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class Florence2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user