mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
[tests] fix slow bart cnn test, faster marian tests (#7888)
This commit is contained in:
parent
ba8c4d0ac0
commit
b86a71ea38
@ -594,7 +594,9 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
"Bronx on Friday. If convicted, she faces up to four years in prison.",
|
||||
]
|
||||
|
||||
generated_summaries = [tok.batch_decode(hypotheses_batch.tolist())]
|
||||
generated_summaries = tok.batch_decode(
|
||||
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
|
||||
)
|
||||
assert generated_summaries == EXPECTED
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers import AutoConfig, AutoTokenizer, MarianConfig, MarianTokenizer, is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.hf_api import HfApi
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
@ -25,14 +25,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelWithLMHead,
|
||||
AutoTokenizer,
|
||||
MarianConfig,
|
||||
MarianMTModel,
|
||||
MarianTokenizer,
|
||||
)
|
||||
from transformers import AutoModelWithLMHead, MarianMTModel
|
||||
from transformers.convert_marian_to_pytorch import (
|
||||
ORG_NAME,
|
||||
convert_hf_name_to_opus_name,
|
||||
@ -79,10 +72,16 @@ class MarianIntegrationTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"
|
||||
cls.tokenizer: MarianTokenizer = AutoTokenizer.from_pretrained(cls.model_name)
|
||||
cls.eos_token_id = cls.tokenizer.eos_token_id
|
||||
return cls
|
||||
|
||||
@cached_property
|
||||
def tokenizer(self) -> MarianTokenizer:
|
||||
return AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return self.tokenizer.eos_token_id
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
model: MarianMTModel = AutoModelWithLMHead.from_pretrained(self.model_name).to(torch_device)
|
||||
|
Loading…
Reference in New Issue
Block a user