make test cases done - 16 left
This commit is contained in:
parent dc85010daa
commit 86d103da23
@@ -46,33 +46,12 @@ else:
 if is_vision_available():
     from PIL import Image
 
-
-def prepare_florence2_inputs_dict(
-    config,
-    input_ids,
-    pixel_values,
-    decoder_input_ids=None,
-    attention_mask=None,
-    decoder_attention_mask=None,
-):
-    if attention_mask is None:
-        attention_mask = input_ids.ne(config.text_config.pad_token_id)
-    if decoder_attention_mask is None:
-        decoder_attention_mask = decoder_input_ids.ne(config.text_config.pad_token_id)
-    return {
-        "input_ids": input_ids,
-        "pixel_values": pixel_values,
-        "decoder_input_ids": decoder_input_ids,
-        "attention_mask": attention_mask,
-        "decoder_attention_mask": attention_mask,
-    }
-
-
 class Florence2VisionText2TextModelTester:
     def __init__(
         self,
         parent,
-        seq_length=7,
+        seq_length=13,
+        encoder_seq_length=15,
         text_config={
             "vocab_size": 51289,
             "activation_dropout": 0.1,
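Note on this hunk: deleting prepare_florence2_inputs_dict also removes a latent bug visible above: the helper returned attention_mask under the "decoder_attention_mask" key instead of the decoder_attention_mask it had just computed. For reference, a minimal sketch of what the helper did, with that key fixed; the pad_token_id value here is hypothetical (the real one comes from config.text_config):

import torch

# Sketch of the removed helper's intent, with the key bug fixed.
pad_token_id = 1  # hypothetical; the tester reads it from config.text_config
input_ids = torch.tensor([[5, 6, 7, pad_token_id, pad_token_id]])
decoder_input_ids = torch.tensor([[5, 6, pad_token_id]])

attention_mask = input_ids.ne(pad_token_id)                  # True where not padding
decoder_attention_mask = decoder_input_ids.ne(pad_token_id)

inputs = {
    "input_ids": input_ids,
    "decoder_input_ids": decoder_input_ids,
    "attention_mask": attention_mask,
    "decoder_attention_mask": decoder_attention_mask,  # the helper wrongly used attention_mask here
}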
@@ -144,6 +123,7 @@ class Florence2VisionText2TextModelTester:
         self.num_channels = 3
         self.image_size = 8
         self.seq_length = seq_length
+        self.encoder_seq_length = encoder_seq_length
 
     def get_config(self):
         return Florence2Config(
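Note: the tester now tracks encoder_seq_length separately from seq_length. For an encoder-decoder model the shared shape checks need both, since encoder-side and decoder-side outputs can have different lengths. A rough illustration of the pattern, with hypothetical names (check_shapes is not a real helper from the test suite):

import torch

class ExampleTester:
    # Mirrors the two lengths the Florence2 tester now exposes.
    def __init__(self, seq_length=13, encoder_seq_length=15):
        self.seq_length = seq_length                  # decoder-side length
        self.encoder_seq_length = encoder_seq_length  # encoder-side length

def check_shapes(tester, encoder_hidden, decoder_hidden):
    # Hypothetical shape check: generic mixins typically fall back to
    # seq_length when a tester does not define encoder_seq_length.
    assert encoder_hidden.shape[1] == tester.encoder_seq_length
    assert decoder_hidden.shape[1] == tester.seq_length

tester = ExampleTester()
check_shapes(tester, torch.zeros(2, 15, 32), torch.zeros(2, 13, 32))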
@@ -168,8 +148,11 @@ class Florence2VisionText2TextModelTester:
         decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
         config = self.get_config()
 
-        inputs_dict = prepare_florence2_inputs_dict(config, input_ids, pixel_values, decoder_input_ids)
+        inputs_dict = {
+            "input_ids": input_ids,
+            "pixel_values": pixel_values,
+            "decoder_input_ids": decoder_input_ids,
+        }
         return config, inputs_dict
 
     def prepare_config_and_inputs_for_common(self):
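Note: with the helper gone, inputs_dict is built inline and no longer carries explicit attention masks; the assumption is that the model (or the shared test machinery) derives default masks when they are omitted. A self-contained sketch of the shapes involved, using the tester's own dimensions (vocab_size=51289, num_channels=3, image_size=8); the model call is commented out because it needs a real Florence2 instance:

import torch

batch_size, seq_length, vocab_size = 2, 13, 51289
input_ids = torch.randint(0, vocab_size, (batch_size, seq_length))
decoder_input_ids = torch.randint(0, vocab_size, (batch_size, seq_length))
pixel_values = torch.rand(batch_size, 3, 8, 8)  # num_channels=3, image_size=8

inputs_dict = {
    "input_ids": input_ids,
    "pixel_values": pixel_values,
    "decoder_input_ids": decoder_input_ids,
}
# outputs = model(**inputs_dict)  # masks assumed to default internally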
@@ -189,6 +172,12 @@ class Florence2VisionText2TextModelTester:
         )["logits"]
         self.parent.assertFalse(torch.isnan(logits).any().item())
 
+    @unittest.skip(
+        reason="This architecture (bart) has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 @require_torch
 class Florence2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
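Note: the added skip uses the standard unittest mechanism. Per the skip reason, the underlying BART-style language model ties its embedding weights by default and this cannot be disabled, so a save/load round trip without tied weights cannot apply here. A generic example of how the decorator behaves:

import unittest

class ExampleTest(unittest.TestCase):
    @unittest.skip(reason="tied weights cannot be disabled for this architecture")
    def test_load_save_without_tied_weights(self):
        # Never executed; unittest reports the test as skipped with the reason above.
        pass

if __name__ == "__main__":
    unittest.main()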