mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Fix marian slow test (#6854)
This commit is contained in:
parent
bbdba0a76d
commit
8af1970e45
@ -38,6 +38,7 @@ if is_torch_available():
|
|||||||
convert_hf_name_to_opus_name,
|
convert_hf_name_to_opus_name,
|
||||||
convert_opus_name_to_hf_name,
|
convert_opus_name_to_hf_name,
|
||||||
)
|
)
|
||||||
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
from transformers.pipelines import TranslationPipeline
|
from transformers.pipelines import TranslationPipeline
|
||||||
|
|
||||||
|
|
||||||
@ -116,18 +117,21 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
|||||||
expected_ids = [38, 121, 14, 697, 38848, 0]
|
expected_ids = [38, 121, 14, 697, 38848, 0]
|
||||||
|
|
||||||
model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device)
|
model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device)
|
||||||
|
|
||||||
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
|
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
|
||||||
|
|
||||||
desired_keys = {
|
desired_keys = {
|
||||||
"input_ids",
|
"input_ids",
|
||||||
"attention_mask",
|
"attention_mask",
|
||||||
"decoder_input_ids",
|
"labels",
|
||||||
"decoder_attention_mask",
|
|
||||||
}
|
}
|
||||||
self.assertSetEqual(desired_keys, set(model_inputs.keys()))
|
self.assertSetEqual(desired_keys, set(model_inputs.keys()))
|
||||||
|
model_inputs["decoder_input_ids"] = shift_tokens_right(model_inputs.labels, self.tokenizer.pad_token_id)
|
||||||
|
model_inputs["return_dict"] = True
|
||||||
|
model_inputs["use_cache"] = False
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits, *enc_features = self.model(**model_inputs)
|
outputs = self.model(**model_inputs)
|
||||||
max_indices = logits.argmax(-1)
|
max_indices = outputs.logits.argmax(-1)
|
||||||
self.tokenizer.batch_decode(max_indices)
|
self.tokenizer.batch_decode(max_indices)
|
||||||
|
|
||||||
def test_unk_support(self):
|
def test_unk_support(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user