Decorate test_codegen_sample_max_time as flaky (#22953)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar 2023-04-24 15:27:31 +02:00 committed by GitHub
parent edb6d950cb
commit 3f6a4b5bd7

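The decorator applied here, is_flaky(max_attempts=3, description=...), comes from transformers.testing_utils and re-runs the decorated test a few times before reporting a failure, which papers over the timing-sensitive assertions in test_codegen_sample_max_time. A minimal sketch of how such a retry decorator can work is shown below; retry_flaky and its body are illustrative assumptions, not the actual is_flaky implementation.

import functools


def retry_flaky(max_attempts=3, description=None):
    """Sketch of an is_flaky-style decorator: re-run a test up to `max_attempts`
    times and pass as soon as one attempt succeeds (illustrative only)."""

    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            last_error = None
            for _ in range(max_attempts):
                try:
                    return test_func(*args, **kwargs)
                except AssertionError as err:  # only retry genuine test failures
                    last_error = err
            raise last_error  # every attempt failed: surface the last failure

        return wrapper

    return decorator

A test method would then be decorated as @retry_flaky(max_attempts=3, description="measure of timing is somehow flaky."), mirroring how is_flaky is applied in the last hunk of the diff below.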

@@ -18,7 +18,8 @@ import datetime
 import unittest
 
 from transformers import CodeGenConfig, is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.file_utils import cached_property
+from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -462,11 +463,19 @@ class CodeGenModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
 
 @require_torch
 class CodeGenModelLanguageGenerationTest(unittest.TestCase):
+    @cached_property
+    def cached_tokenizer(self):
+        return AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
+
+    @cached_property
+    def cached_model(self):
+        return CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+
     @slow
     def test_lm_generate_codegen(self):
-        tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
+        tokenizer = self.cached_tokenizer
         for checkpointing in [True, False]:
-            model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+            model = self.cached_model
 
             if checkpointing:
                 model.gradient_checkpointing_enable()
@@ -484,8 +493,8 @@ class CodeGenModelLanguageGenerationTest(unittest.TestCase):
 
     @slow
     def test_codegen_sample(self):
-        tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
-        model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+        tokenizer = self.cached_tokenizer
+        model = self.cached_model
         model.to(torch_device)
 
         torch.manual_seed(0)
@@ -515,10 +524,11 @@ class CodeGenModelLanguageGenerationTest(unittest.TestCase):
             all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
         )  # token_type_ids should change output
 
+    @is_flaky(max_attempts=3, description="measure of timing is somehow flaky.")
     @slow
     def test_codegen_sample_max_time(self):
-        tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
-        model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+        tokenizer = self.cached_tokenizer
+        model = self.cached_model
         model.to(torch_device)
 
         torch.manual_seed(0)
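The tokenizer and model are now exposed as cached_property attributes, so each test instance loads them from the Hub at most once even when a test body asks for them repeatedly (for example across the gradient-checkpointing loop in test_lm_generate_codegen). The snippet below is a small, self-contained illustration of that per-instance caching using the standard-library functools.cached_property (Python 3.8+), which provides the same behaviour as the cached_property helper imported in the diff; the ExpensiveResources class and its stand-in loader are hypothetical.

from functools import cached_property


class ExpensiveResources:
    @cached_property
    def model(self):
        # Runs only on the first access for a given instance; later accesses reuse the result.
        print("loading model ...")
        return object()  # stand-in for CodeGenForCausalLM.from_pretrained(...)


resources = ExpensiveResources()
first = resources.model   # triggers the (expensive) load
second = resources.model  # served from the per-instance cache, no reload
assert first is second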