[tests] fix slow bart cnn test, faster marian tests (#7888)

This commit is contained in:
Sam Shleifer 2020-10-18 20:18:08 -04:00 committed by GitHub
parent ba8c4d0ac0
commit b86a71ea38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 12 deletions

View File

@ -594,7 +594,9 @@ class BartModelIntegrationTests(unittest.TestCase):
"Bronx on Friday. If convicted, she faces up to four years in prison.",
]
generated_summaries = [tok.batch_decode(hypotheses_batch.tolist())]
generated_summaries = tok.batch_decode(
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
)
assert generated_summaries == EXPECTED

View File

@ -16,7 +16,7 @@
import unittest
from transformers import is_torch_available
from transformers import AutoConfig, AutoTokenizer, MarianConfig, MarianTokenizer, is_torch_available
from transformers.file_utils import cached_property
from transformers.hf_api import HfApi
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@ -25,14 +25,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
if is_torch_available():
import torch
from transformers import (
AutoConfig,
AutoModelWithLMHead,
AutoTokenizer,
MarianConfig,
MarianMTModel,
MarianTokenizer,
)
from transformers import AutoModelWithLMHead, MarianMTModel
from transformers.convert_marian_to_pytorch import (
ORG_NAME,
convert_hf_name_to_opus_name,
@ -79,10 +72,16 @@ class MarianIntegrationTest(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"
cls.tokenizer: MarianTokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.eos_token_id = cls.tokenizer.eos_token_id
return cls
@cached_property
def tokenizer(self) -> MarianTokenizer:
return AutoTokenizer.from_pretrained(self.model_name)
@property
def eos_token_id(self) -> int:
return self.tokenizer.eos_token_id
@cached_property
def model(self):
model: MarianMTModel = AutoModelWithLMHead.from_pretrained(self.model_name).to(torch_device)