mirror of https://github.com/huggingface/transformers.git
Fix marian slow test (#6854)
commit 8af1970e45
parent bbdba0a76d
@@ -38,6 +38,7 @@ if is_torch_available():
         convert_hf_name_to_opus_name,
         convert_opus_name_to_hf_name,
     )
+    from transformers.modeling_bart import shift_tokens_right
     from transformers.pipelines import TranslationPipeline
 
 
@@ -116,18 +117,21 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
         expected_ids = [38, 121, 14, 697, 38848, 0]
 
         model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device)
+
         self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
 
         desired_keys = {
             "input_ids",
             "attention_mask",
-            "decoder_input_ids",
-            "decoder_attention_mask",
+            "labels",
         }
         self.assertSetEqual(desired_keys, set(model_inputs.keys()))
+        model_inputs["decoder_input_ids"] = shift_tokens_right(model_inputs.labels, self.tokenizer.pad_token_id)
+        model_inputs["return_dict"] = True
+        model_inputs["use_cache"] = False
         with torch.no_grad():
-            logits, *enc_features = self.model(**model_inputs)
-        max_indices = logits.argmax(-1)
+            outputs = self.model(**model_inputs)
+        max_indices = outputs.logits.argmax(-1)
         self.tokenizer.batch_decode(max_indices)
 
     def test_unk_support(self):
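For readers of the second hunk: prepare_seq2seq_batch here returns labels rather than ready-made decoder inputs, so the test builds decoder_input_ids itself with shift_tokens_right and, with return_dict=True, reads logits as an attribute of the returned output object instead of unpacking a tuple. Below is a minimal plain-torch sketch of the right-shift idea; it is not the library's code, and the detail of wrapping the last non-pad token to position 0 is an assumption about the 2020-era BART helper used in this test.

# Illustration only: a plain-torch sketch of the kind of right shift that
# shift_tokens_right(labels, pad_token_id) applies to the target ids before
# they are fed to the decoder. Not the library's implementation; wrapping the
# last non-pad token (the eos, id 0 in expected_ids) to position 0 is an
# assumption about the 2020-era BART helper.
import torch


def shift_tokens_right_sketch(labels: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    shifted = labels.clone()
    # index of the last non-pad token in each row
    last_non_pad = (labels.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    # seed position 0 with that token, then shift everything else right by one
    shifted[:, 0] = labels.gather(1, last_non_pad).squeeze(-1)
    shifted[:, 1:] = labels[:, :-1]
    return shifted


# The expected_ids from the test, reused as a toy labels batch.
labels = torch.tensor([[38, 121, 14, 697, 38848, 0]])
print(shift_tokens_right_sketch(labels, pad_token_id=58100))  # pad id assumed for illustration
# tensor([[    0,    38,   121,    14,   697, 38848]])

With a batch shifted this way, the forward call in the hunk (return_dict=True, use_cache=False) returns an output object whose logits attribute replaces the old tuple unpacking.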