mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Bart] Refactor - fix issues, consistency with the library, naming (#8900)
* remove make on the fly linear embedding * start refactor * big first refactor * save intermediate * save intermediat * correct mask issue * save tests * refactor padding masks * make all tests pass * further refactor * make pegasus test pass * fix bool if * fix leftover tests * continue * bart renaming * delete torchscript test hack * fix imports in tests * correct shift * fix docs and repo cons * re-add fix for FSTM * typo in test * fix typo * fix another typo * continue * hot fix 2 for tf * small fixes * refactor types linting * continue * finish refactor * fix import in tests * better bart names * further refactor and add test * delete hack * apply sylvains and lysandres commens * small perf improv * further perf improv * improv perf * fix typo * make style * small perf improv
This commit is contained in:
parent
75627148ee
commit
06971ac4f9
@ -105,8 +105,6 @@ BartModel
|
||||
.. autoclass:: transformers.BartModel
|
||||
:members: forward
|
||||
|
||||
.. autofunction:: transformers.models.bart.modeling_bart._prepare_bart_decoder_inputs
|
||||
|
||||
|
||||
BartForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -378,6 +378,7 @@ if is_torch_available():
|
||||
BartForQuestionAnswering,
|
||||
BartForSequenceClassification,
|
||||
BartModel,
|
||||
BartPretrainedModel,
|
||||
PretrainedBartModel,
|
||||
)
|
||||
from .models.bert import (
|
||||
|
@ -31,6 +31,7 @@ if is_torch_available():
|
||||
BartForQuestionAnswering,
|
||||
BartForSequenceClassification,
|
||||
BartModel,
|
||||
BartPretrainedModel,
|
||||
PretrainedBartModel,
|
||||
)
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -577,9 +577,9 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
encoder_padding_mask = invert_mask(encoder_padding_mask)
|
||||
|
||||
# embed positions
|
||||
positions = self.embed_positions(input_ids, use_cache=use_cache)
|
||||
positions = self.embed_positions(input_ids, use_cache=(use_cache and decoder_cached_states is not None))
|
||||
|
||||
if use_cache:
|
||||
if use_cache and decoder_cached_states is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
positions = positions[:, -1:]
|
||||
|
||||
@ -964,7 +964,7 @@ class TFBartModel(TFPretrainedBartModel):
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
if not inputs["use_cache"]:
|
||||
if not use_cache or past_key_values is None:
|
||||
inputs["decoder_input_ids"], decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs(
|
||||
inputs["input_ids"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
@ -1154,6 +1154,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
||||
assert (
|
||||
decoder_cached_states
|
||||
), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past"
|
||||
|
||||
assert isinstance(
|
||||
encoder_outputs, TFBaseModelOutput
|
||||
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
|
||||
|
@ -813,9 +813,6 @@ class T5Stack(T5PreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.embed_tokens = new_embeddings
|
||||
|
||||
|
@ -450,6 +450,15 @@ class BartModel:
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class BartPretrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class PretrainedBartModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@ -49,37 +50,68 @@ if is_torch_available():
|
||||
pipeline,
|
||||
)
|
||||
from transformers.models.bart.modeling_bart import (
|
||||
SinusoidalPositionalEmbedding,
|
||||
_prepare_bart_decoder_inputs,
|
||||
invert_mask,
|
||||
BartDecoder,
|
||||
BartEncoder,
|
||||
BartSinusoidalPositionalEmbedding,
|
||||
shift_tokens_right,
|
||||
)
|
||||
|
||||
|
||||
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
|
||||
|
||||
def prepare_bart_inputs_dict(
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModelTester:
|
||||
class BartModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_labels=False,
|
||||
vocab_size=99,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=4,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=20,
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = 13
|
||||
self.seq_length = 7
|
||||
self.is_training = True
|
||||
self.use_labels = False
|
||||
self.vocab_size = 99
|
||||
self.hidden_size = 16
|
||||
self.num_hidden_layers = 2
|
||||
self.num_attention_heads = 4
|
||||
self.intermediate_size = 4
|
||||
self.hidden_act = "gelu"
|
||||
self.hidden_dropout_prob = 0.1
|
||||
self.attention_probs_dropout_prob = 0.1
|
||||
self.max_position_embeddings = 20
|
||||
self.eos_token_id = 2
|
||||
self.pad_token_id = 1
|
||||
self.bos_token_id = 0
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
torch.manual_seed(0)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
@ -111,21 +143,67 @@ class ModelTester:
|
||||
config, inputs_dict = self.prepare_config_and_inputs()
|
||||
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
|
||||
inputs_dict["decoder_attention_mask"] = inputs_dict["attention_mask"]
|
||||
inputs_dict["use_cache"] = False
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
|
||||
model = BartModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
def prepare_bart_inputs_dict(
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
# first forward pass
|
||||
outputs = model(input_ids, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
# create hypothetical multiple next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||
|
||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
|
||||
|
||||
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
||||
model = BartModel(config=config).to(torch_device).eval()
|
||||
outputs = model(**inputs_dict)
|
||||
|
||||
encoder_last_hidden_state = outputs.encoder_last_hidden_state
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
encoder = model.get_encoder()
|
||||
encoder.save_pretrained(tmpdirname)
|
||||
encoder = BartEncoder.from_pretrained(tmpdirname).to(torch_device)
|
||||
|
||||
encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
|
||||
0
|
||||
]
|
||||
|
||||
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
decoder = model.get_decoder()
|
||||
decoder.save_pretrained(tmpdirname)
|
||||
decoder = BartDecoder.from_pretrained(tmpdirname).to(torch_device)
|
||||
|
||||
last_hidden_state_2 = decoder(
|
||||
input_ids=inputs_dict["decoder_input_ids"],
|
||||
attention_mask=inputs_dict["decoder_attention_mask"],
|
||||
encoder_hidden_states=encoder_last_hidden_state,
|
||||
encoder_attention_mask=inputs_dict["attention_mask"],
|
||||
)[0]
|
||||
|
||||
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -142,7 +220,7 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ModelTester(self)
|
||||
self.model_tester = BartModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BartConfig)
|
||||
|
||||
def test_config(self):
|
||||
@ -169,23 +247,25 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
config.use_cache = False
|
||||
inputs_dict["input_ids"][:, -2:] = config.pad_token_id
|
||||
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
|
||||
config, inputs_dict["input_ids"]
|
||||
)
|
||||
model = BartModel(config).to(torch_device).eval()
|
||||
|
||||
model = BartModel(config).to(torch_device).eval()
|
||||
decoder_features_with_created_mask = model(**inputs_dict)[0]
|
||||
|
||||
decoder_input_ids = shift_tokens_right(inputs_dict["input_ids"], config.pad_token_id)
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
decoder_attention_mask[:, 0] = decoder_attention_mask[:, 1]
|
||||
|
||||
decoder_features_with_passed_mask = model(
|
||||
decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict
|
||||
decoder_attention_mask=decoder_attention_mask, decoder_input_ids=decoder_input_ids, **inputs_dict
|
||||
)[0]
|
||||
assert_tensors_close(decoder_features_with_passed_mask, decoder_features_with_created_mask)
|
||||
useless_mask = torch.zeros_like(decoder_attn_mask)
|
||||
useless_mask = torch.zeros_like(decoder_attention_mask)
|
||||
decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
|
||||
self.assertTrue(isinstance(decoder_features, torch.Tensor)) # no hidden states or attentions
|
||||
self.assertEqual(
|
||||
decoder_features.size(), (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model)
|
||||
)
|
||||
if decoder_attn_mask.min().item() < -1e3: # some tokens were masked
|
||||
if decoder_attention_mask.min().item() == 0: # some tokens were masked
|
||||
self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item())
|
||||
|
||||
# Test different encoder attention masks
|
||||
@ -204,13 +284,43 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
||||
self.assertEqual(info["missing_keys"], [])
|
||||
|
||||
@unittest.skip("Passing inputs_embeds not implemented for Bart.")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
|
||||
def test_resize_embeddings_untied(self):
|
||||
pass
|
||||
def test_encoder_decoder_model_standalone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
|
||||
|
||||
# BartForSequenceClassification does not support inputs_embeds
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in (BartModel, BartForConditionalGeneration, BartForQuestionAnswering):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
if not self.is_encoder_decoder:
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
else:
|
||||
encoder_input_ids = inputs["input_ids"]
|
||||
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
|
||||
del inputs["input_ids"]
|
||||
inputs.pop("decoder_input_ids", None)
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
if not self.is_encoder_decoder:
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
else:
|
||||
inputs["inputs_embeds"] = wte(encoder_input_ids)
|
||||
inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@ -386,20 +496,6 @@ class BartHeadTests(unittest.TestCase):
|
||||
model = BartForConditionalGeneration(config).eval().to(torch_device)
|
||||
model(**model.dummy_inputs)
|
||||
|
||||
def test_prepare_bart_decoder_inputs(self):
|
||||
config, *_ = self._get_config_and_data()
|
||||
input_ids = _long_tensor(([4, 4, 2]))
|
||||
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
|
||||
ignore = float("-inf")
|
||||
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
|
||||
config, input_ids, decoder_input_ids
|
||||
)
|
||||
expected_causal_mask = torch.tensor(
|
||||
[[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because its pad
|
||||
).to(input_ids.device)
|
||||
self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size())
|
||||
self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all())
|
||||
|
||||
def test_resize_tokens_embeddings_more(self):
|
||||
config, input_ids, _ = self._get_config_and_data()
|
||||
|
||||
@ -470,14 +566,14 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
|
||||
@slow
|
||||
def test_bart_base_mask_filling(self):
|
||||
def test_base_mask_filling(self):
|
||||
pbase = pipeline(task="fill-mask", model="facebook/bart-base")
|
||||
src_text = [" I went to the <mask>."]
|
||||
results = [x["token_str"] for x in pbase(src_text)]
|
||||
assert "Ġbathroom" in results
|
||||
|
||||
@slow
|
||||
def test_bart_large_mask_filling(self):
|
||||
def test_large_mask_filling(self):
|
||||
plarge = pipeline(task="fill-mask", model="facebook/bart-large")
|
||||
src_text = [" I went to the <mask>."]
|
||||
results = [x["token_str"] for x in plarge(src_text)]
|
||||
@ -608,7 +704,7 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
class TestBartSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
desired_weights = [
|
||||
[0, 0, 0, 0, 0],
|
||||
[0.84147096, 0.82177866, 0.80180490, 0.78165019, 0.76140374],
|
||||
@ -616,38 +712,30 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
]
|
||||
|
||||
def test_positional_emb_cache_logic(self):
|
||||
pad = 1
|
||||
input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device)
|
||||
emb1 = SinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6, padding_idx=pad).to(torch_device)
|
||||
no_cache = emb1(input_ids, use_cache=False)
|
||||
yes_cache = emb1(input_ids, use_cache=True)
|
||||
self.assertEqual((1, 1, 6), yes_cache.shape) # extra dim to allow broadcasting, feel free to delete!
|
||||
self.assertListEqual(no_cache[-1].tolist(), yes_cache[0][0].tolist())
|
||||
emb1 = BartSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6, padding_idx=1).to(torch_device)
|
||||
no_cache = emb1((4, 10), past_key_values_length=0)
|
||||
yes_cache = emb1((4, 10), past_key_values_length=2)
|
||||
|
||||
self.assertTrue(no_cache.shape == yes_cache.shape == (10, 6))
|
||||
self.assertListEqual(no_cache[2:].tolist(), yes_cache[:-2].tolist())
|
||||
|
||||
def test_odd_embed_dim(self):
|
||||
# odd embedding_dim is allowed
|
||||
SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
|
||||
BartSinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
|
||||
|
||||
# odd num_positions is allowed
|
||||
SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)
|
||||
BartSinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)
|
||||
|
||||
def test_positional_emb_weights_against_marian(self):
|
||||
pad = 1
|
||||
emb1 = SinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512, padding_idx=pad).to(torch_device)
|
||||
emb1 = BartSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512, padding_idx=pad).to(
|
||||
torch_device
|
||||
)
|
||||
weights = emb1.weight.data[:3, :5].tolist()
|
||||
for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)):
|
||||
for j in range(5):
|
||||
self.assertAlmostEqual(expected_weight[j], actual_weight[j], places=3)
|
||||
|
||||
# test that forward pass is just a lookup, there is no ignore padding logic
|
||||
input_ids = torch.tensor([[4, 10, pad, pad, pad]], dtype=torch.long, device=torch_device)
|
||||
no_cache_pad_zero = emb1(input_ids)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
torch.tensor(self.desired_weights, device=torch_device), no_cache_pad_zero[:3, :5], atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
def test_child_config_equivalence(self):
|
||||
"""Test that configs associated with children of BartForConditionalGeneration are identical."""
|
||||
child_classes = [BlenderbotConfig, MBartConfig, MarianConfig, PegasusConfig]
|
||||
|
@ -104,9 +104,6 @@ class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
|
||||
self.model_tester = BlenderbotModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_initialization_module(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = BlenderbotForConditionalGeneration(config).model
|
||||
|
@ -302,6 +302,8 @@ class ModelTesterMixin:
|
||||
# Question Answering model returns start_logits and end_logits
|
||||
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||
if "past_key_values" in outputs:
|
||||
correct_outlen += 1 # past_key_values have been returned
|
||||
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
@ -386,7 +388,7 @@ class ModelTesterMixin:
|
||||
|
||||
try:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # TODO: this should be deleted after bug #7474 is solved
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
input_ids = inputs["input_ids"]
|
||||
attention_mask = inputs["attention_mask"]
|
||||
decoder_input_ids = inputs["decoder_input_ids"]
|
||||
@ -1020,7 +1022,6 @@ class ModelTesterMixin:
|
||||
)
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
|
@ -37,7 +37,7 @@ from transformers.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_modeling_bart import ModelTester as BartModelTester
|
||||
from .test_modeling_bart import BartModelTester
|
||||
from .test_modeling_dpr import DPRModelTester
|
||||
from .test_modeling_t5 import T5ModelTester
|
||||
|
||||
|
@ -344,13 +344,6 @@ class TFModelTesterMixin:
|
||||
tf_hidden_states[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
|
||||
# Debug info (remove when fixed)
|
||||
if max_diff >= 4e-2:
|
||||
print("===")
|
||||
print(model_class)
|
||||
print(config)
|
||||
print(inputs_dict)
|
||||
print(pt_inputs_dict)
|
||||
self.assertLessEqual(max_diff, 4e-2)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
|
@ -28,6 +28,8 @@ PATH_TO_DOC = "docs/source/model_doc"
|
||||
# Update this list for models that are not tested with a comment explaining the reason it should not be.
|
||||
# Being in this list is an exception and should **not** be the rule.
|
||||
IGNORE_NON_TESTED = [
|
||||
"BartDecoder", # Building part of bigger (tested) model.
|
||||
"BartEncoder", # Building part of bigger (tested) model.
|
||||
"BertLMHeadModel", # Needs to be setup as decoder.
|
||||
"DPREncoder", # Building part of bigger (tested) model.
|
||||
"DPRSpanPredictor", # Building part of bigger (tested) model.
|
||||
@ -58,9 +60,11 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
|
||||
# Update this list for models that are not documented with a comment explaining the reason it should not be.
|
||||
# Being in this list is an exception and should **not** be the rule.
|
||||
IGNORE_NON_DOCUMENTED = [
|
||||
"BartDecoder", # Building part of bigger (documented) model.
|
||||
"BartEncoder", # Building part of bigger (documented) model.
|
||||
"DPREncoder", # Building part of bigger (documented) model.
|
||||
"DPRSpanPredictor", # Building part of bigger (documented) model.
|
||||
"T5Stack", # Building part of bigger (tested) model.
|
||||
"T5Stack", # Building part of bigger (documented) model.
|
||||
"TFDPREncoder", # Building part of bigger (documented) model.
|
||||
"TFDPRSpanPredictor", # Building part of bigger (documented) model.
|
||||
]
|
||||
@ -78,6 +82,8 @@ MODEL_NAME_TO_DOC_FILE = {
|
||||
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
|
||||
# should **not** be the rule.
|
||||
IGNORE_NON_AUTO_CONFIGURED = [
|
||||
"BartDecoder",
|
||||
"BartEncoder",
|
||||
"DPRContextEncoder",
|
||||
"DPREncoder",
|
||||
"DPRReader",
|
||||
|
Loading…
Reference in New Issue
Block a user