Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Decorate test_codegen_sample_max_time as flaky (#22953)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent edb6d950cb
commit 3f6a4b5bd7
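This change decorates `test_codegen_sample_max_time` with `is_flaky` from `transformers.testing_utils`, so an occasional timing-related failure makes the test re-run (up to `max_attempts=3`) instead of failing CI outright, and it caches the tokenizer and model used by the CodeGen generation tests so the checkpoint is not reloaded repeatedly within a test. As a minimal sketch of the retry idea only (not the actual implementation of `is_flaky`), such a decorator can look like this:

import functools


# Hypothetical stand-in for the retry behaviour; the real is_flaky in
# transformers.testing_utils differs in details.
def retry_flaky(max_attempts=3, description=None):
    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return test_func(*args, **kwargs)  # passed: stop retrying
                except AssertionError:
                    if attempt == max_attempts:
                        raise  # out of attempts: surface the failure
        return wrapper
    return decorator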
@@ -18,7 +18,8 @@ import datetime
 import unittest
 
 from transformers import CodeGenConfig, is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.file_utils import cached_property
+from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -462,11 +463,19 @@ class CodeGenModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
 
 @require_torch
 class CodeGenModelLanguageGenerationTest(unittest.TestCase):
+    @cached_property
+    def cached_tokenizer(self):
+        return AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
+
+    @cached_property
+    def cached_model(self):
+        return CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+
     @slow
     def test_lm_generate_codegen(self):
-        tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
+        tokenizer = self.cached_tokenizer
         for checkpointing in [True, False]:
-            model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+            model = self.cached_model
 
             if checkpointing:
                 model.gradient_checkpointing_enable()
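The new `cached_tokenizer` and `cached_model` properties rely on `cached_property` (imported above from `transformers.file_utils`), so the 350M checkpoint is loaded at most once per test instance; in `test_lm_generate_codegen`, for example, the `for checkpointing in [True, False]` loop no longer reloads the model on every iteration. A rough, standalone sketch of the descriptor behaviour this depends on (illustrative only, not the transformers implementation):

# Illustrative cached-property descriptor: compute the value on first access,
# store it on the instance, and return the stored value afterwards.
class cached_property_sketch:
    def __init__(self, fget):
        self.fget = fget
        self.attr_name = "_cached_" + fget.__name__

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self  # accessed on the class, not on an instance
        if not hasattr(obj, self.attr_name):
            setattr(obj, self.attr_name, self.fget(obj))  # first access: compute and cache
        return getattr(obj, self.attr_name)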
@@ -484,8 +493,8 @@ class CodeGenModelLanguageGenerationTest(unittest.TestCase):
 
     @slow
     def test_codegen_sample(self):
-        tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
-        model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+        tokenizer = self.cached_tokenizer
+        model = self.cached_model
         model.to(torch_device)
 
         torch.manual_seed(0)
@@ -515,10 +524,11 @@ class CodeGenModelLanguageGenerationTest(unittest.TestCase):
             all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
         )  # token_type_ids should change output
 
+    @is_flaky(max_attempts=3, description="measure of timing is somehow flaky.")
     @slow
     def test_codegen_sample_max_time(self):
-        tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
-        model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+        tokenizer = self.cached_tokenizer
+        model = self.cached_model
         model.to(torch_device)
 
         torch.manual_seed(0)
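Note that all three tests touched here carry the `@slow` decorator, so they are skipped by default and only run when the slow suite is enabled (in transformers this is typically done by setting `RUN_SLOW=1` in the environment); the `is_flaky` retry therefore only takes effect on those scheduled slow runs.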