[Tests] Fix slow opt tests (#17282)

* fix opt tests

* remove unused tok

* make style

* make flake8 happy

* Update tests/models/opt/test_modeling_opt.py
Patrick von Platen, 2022-05-16 23:24:20 +02:00, committed by GitHub
parent f6a6388972
commit e705e1267c

--- a/tests/models/opt/test_modeling_opt.py
+++ b/tests/models/opt/test_modeling_opt.py

@@ -23,7 +23,6 @@ import timeout_decorator  # noqa
 from transformers import OPTConfig, is_torch_available
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
-from transformers.utils import cached_property
 
 from ...generation.test_generation_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -270,10 +269,6 @@ def _long_tensor(tok_lst):
 @require_sentencepiece
 @require_tokenizers
 class OPTModelIntegrationTests(unittest.TestCase):
-    @cached_property
-    def default_tokenizer(self):
-        return GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
-
     @slow
     def test_inference_no_head(self):
         model = OPTModel.from_pretrained("facebook/opt-350m").to(torch_device)
@@ -284,7 +279,7 @@ class OPTModelIntegrationTests(unittest.TestCase):
         expected_shape = torch.Size((1, 11, 512))
         self.assertEqual(output.shape, expected_shape)
         expected_slice = torch.tensor(
-            [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device
+            [[-0.2873, -1.9218, -0.3033], [-1.2710, -0.1338, -0.1902], [0.4095, 0.1214, -1.3121]], device=torch_device
         )
         self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3))
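
The hunk above updates a slice-based regression check: the test hard-codes a small corner of the model output and compares it within an absolute tolerance, so only the reference values need to change when the expected activations do. A minimal sketch of the pattern with placeholder values (not the test's real numbers):

import torch

# Stand-ins: in the real test, `output` is the last hidden state of
# OPTModel.from_pretrained("facebook/opt-350m") on a fixed input batch.
output = torch.randn(1, 11, 512)
expected_slice = torch.zeros(3, 3)  # placeholder reference corner

# An absolute tolerance absorbs small numerical drift across hardware
# and library versions; the comparison broadcasts (1, 3, 3) against (3, 3).
close = torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
print(close)  # False for random placeholder data; True in the real test
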
@@ -307,7 +302,6 @@ class OPTEmbeddingsTest(unittest.TestCase):
         model = OPTForCausalLM.from_pretrained(self.path_model)
         model = model.eval()
         tokenizer = GPT2Tokenizer.from_pretrained(self.path_model)
-        tokenizer.add_special_tokens({"pad_token": "<pad>"})
 
         prompts = [
             "Today is a beautiful day and I want to",
@@ -315,8 +309,9 @@ class OPTEmbeddingsTest(unittest.TestCase):
             "Paris is the capital of France and",
             "Computers and mobile phones have taken",
         ]
-        input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids
-        logits = model(input_ids)[0].mean(dim=-1)
+        # verify that prompt without BOS token is identical to Metaseq -> add_special_tokens=False
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, add_special_tokens=False)
+        logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(dim=-1)
         # logits_meta = torch.load(self.path_logits_meta)
         logits_meta = torch.Tensor(
             [
@@ -326,7 +321,6 @@ class OPTEmbeddingsTest(unittest.TestCase):
                 [6.4783, -1.9913, -10.7926, -2.3336, 1.5092, -0.9974, -6.8213, 1.3477, 1.3477],
             ]
         )
-
         assert torch.allclose(logits, logits_meta, atol=1e-4)
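
For context on the add_special_tokens=False change: OPT's GPT2Tokenizer configuration prepends a BOS token by default, while the Metaseq run that produced logits_meta consumed the raw prompts, so the extra BOS column has to be dropped for the two sets of logits to align; the attention_mask returned by the same call masks the positions that padding=True fills in, and the hub tokenizer config already defines a pad token, which is presumably why the manual add_special_tokens({"pad_token": "<pad>"}) line could go. A minimal sketch of the difference, assuming the facebook/opt-350m checkpoint used elsewhere in this file:

import torch
from transformers import GPT2Tokenizer, OPTForCausalLM

path_model = "facebook/opt-350m"  # assumption: same checkpoint as self.path_model
tokenizer = GPT2Tokenizer.from_pretrained(path_model)
model = OPTForCausalLM.from_pretrained(path_model).eval()

prompts = ["Today is a beautiful day and I want to", "Paris is the capital of France and"]

# The default call prepends the tokenizer's BOS token; Metaseq's reference run did not.
with_bos = tokenizer(prompts, return_tensors="pt", padding=True)
no_bos = tokenizer(prompts, return_tensors="pt", padding=True, add_special_tokens=False)
print(with_bos.input_ids.shape, no_bos.input_ids.shape)  # no_bos is one column shorter

# The attention mask keeps padded positions from influencing the per-position
# logits that get averaged over the vocab dimension and checked against logits_meta.
with torch.no_grad():
    logits = model(no_bos.input_ids, attention_mask=no_bos.attention_mask)[0].mean(dim=-1)
print(logits.shape)  # (batch_size, padded_sequence_length)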