[gpu slow tests] fix mbart-large-enro gpu tests (#4472)
commit 956c4c4eb4
parent 48c3a70b4e

Moves the lazily loaded model and the encoded inputs onto torch_device so the slow tests run on GPU, and replaces the per-sequence decode loops with tokenizer.batch_decode.

@@ -231,7 +231,7 @@ class BartTranslationTests(unittest.TestCase):
         """Only load the model if needed."""
         if self._model is None:
             model = BartForConditionalGeneration.from_pretrained("mbart-large-en-ro")
-            self._model = model
+            self._model = model.to(torch_device)
         return self._model

     @slow
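
Note: the hunk above caches the model lazily and now moves it to the active device on first load. A minimal standalone sketch of the same pattern (nn.Linear stands in for the real from_pretrained call; the class name is made up):

    import torch
    from torch import nn

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

    class LazyModel:
        """Cache the model and place it on torch_device on first access."""

        _model = None

        @property
        def model(self):
            if self._model is None:
                m = nn.Linear(4, 4)  # stand-in for from_pretrained(...)
                self._model = m.to(torch_device)  # move once, reuse thereafter
            return self._model
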
@@ -257,10 +257,7 @@ class BartTranslationTests(unittest.TestCase):
             )
         }
         translated_tokens = model.generate(input_ids=inputs["input_ids"].to(torch_device), num_beams=5,)
-        decoded = [
-            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-            for g in translated_tokens
-        ]
+        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
         self.assertEqual(expected_translation_romanian, decoded[0])

     def test_mbart_enro_config(self):
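
Note: the batch_decode swap above is safe because batch_decode is simply the per-sequence decode applied across the batch. A quick sketch of the equivalence, with the checkpoint name chosen arbitrarily:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("facebook/bart-large")  # any checkpoint works here
    ids = tok.batch_encode_plus(["Hello world"], return_tensors="pt")["input_ids"]

    # One call over the whole batch ...
    batch = tok.batch_decode(ids, skip_special_tokens=True)
    # ... matches decoding sequence by sequence with the same kwargs.
    loop = [tok.decode(g, skip_special_tokens=True) for g in ids]
    assert batch == loop
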
@@ -576,11 +573,13 @@ class BartModelIntegrationTests(unittest.TestCase):

         PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
         EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
-        dct = tok.batch_encode_plus([PGE_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt",)
+        dct = tok.batch_encode_plus([PGE_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt",).to(
+            torch_device
+        )

         hypotheses_batch = model.generate(
-            input_ids=dct["input_ids"].to(torch_device),
-            attention_mask=dct["attention_mask"].to(torch_device),
+            input_ids=dct["input_ids"],
+            attention_mask=dct["attention_mask"],
             num_beams=2,
             max_length=62,
             min_length=11,
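
Note: the encoding change relies on BatchEncoding.to, which moves every tensor in the encoded batch at once, so generate no longer needs per-tensor .to(torch_device) calls. Sketched with an assumed checkpoint:

    import torch
    from transformers import AutoTokenizer

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")  # assumed checkpoint

    dct = tok.batch_encode_plus(["some text"], return_tensors="pt").to(torch_device)
    # input_ids and attention_mask now both live on torch_device.
    assert dct["input_ids"].device.type == torch.device(torch_device).type
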
@@ -590,9 +589,7 @@ class BartModelIntegrationTests(unittest.TestCase):
             decoder_start_token_id=model.config.eos_token_id,
         )

-        decoded = [
-            tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
-        ]
+        decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True,)
         self.assertEqual(EXPECTED_SUMMARY, decoded[0])

     def test_xsum_config_generation_params(self):
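
Note: taken together, the post-commit shape of these tests is encode once, move the whole batch to the device, generate, then batch_decode. A condensed end-to-end sketch (checkpoint and input text are assumptions, not the test's values):

    import torch
    from transformers import AutoTokenizer, BartForConditionalGeneration

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    name = "facebook/bart-large-cnn"  # assumed checkpoint

    tok = AutoTokenizer.from_pretrained(name)
    model = BartForConditionalGeneration.from_pretrained(name).to(torch_device)

    dct = tok.batch_encode_plus(["PG&E scheduled the blackouts."], return_tensors="pt").to(torch_device)
    summary_ids = model.generate(
        input_ids=dct["input_ids"], attention_mask=dct["attention_mask"], num_beams=2, max_length=62,
    )
    print(tok.batch_decode(summary_ids, skip_special_tokens=True)[0])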