Fix marian slow test (#6854)

This commit is contained in:
Sam Shleifer 2020-08-31 16:10:43 -04:00 committed by GitHub
parent bbdba0a76d
commit 8af1970e45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -38,6 +38,7 @@ if is_torch_available():
convert_hf_name_to_opus_name,
convert_opus_name_to_hf_name,
)
from transformers.modeling_bart import shift_tokens_right
from transformers.pipelines import TranslationPipeline
@ -116,18 +117,21 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
expected_ids = [38, 121, 14, 697, 38848, 0]
model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device)
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
desired_keys = {
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
"labels",
}
self.assertSetEqual(desired_keys, set(model_inputs.keys()))
model_inputs["decoder_input_ids"] = shift_tokens_right(model_inputs.labels, self.tokenizer.pad_token_id)
model_inputs["return_dict"] = True
model_inputs["use_cache"] = False
with torch.no_grad():
logits, *enc_features = self.model(**model_inputs)
max_indices = logits.argmax(-1)
outputs = self.model(**model_inputs)
max_indices = outputs.logits.argmax(-1)
self.tokenizer.batch_decode(max_indices)
def test_unk_support(self):