mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Generate: Add GPTNeoX integration test (#22346)
This commit is contained in:
parent
b79607656b
commit
0fa46524ac
@ -17,8 +17,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import GPTNeoXConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@ -232,3 +232,28 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
@unittest.skip(reason="Feed forward chunking is not implemented")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
class GPTNeoXLanguageGenerationTest(unittest.TestCase):
    """Slow integration test: greedy generation with a real Pythia (GPT-NeoX)
    checkpoint, run both with and without gradient checkpointing to make sure
    the toggle does not change the generated text."""

    # NOTE(review): the method name says "codegen" but the checkpoint is
    # Pythia/GPT-NeoX — looks like a copy-paste leftover. Renaming would
    # change the public test id, so the name is kept as-is.
    @slow
    def test_lm_generate_codegen(self):
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
        # Greedy decoding is deterministic, so one reference string covers
        # both checkpointing modes; hoisted out of the loop.
        expected_output = (
            "My favorite food is the chicken and rice.\n\nI love to cook and bake. I love to cook and bake"
        )

        for use_grad_checkpointing in (True, False):
            model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")

            if use_grad_checkpointing:
                model.gradient_checkpointing_enable()
            else:
                model.gradient_checkpointing_disable()
            model.to(torch_device)

            inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)

            generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
            generated_text = tokenizer.batch_decode(generated_ids)[0]

            self.assertEqual(generated_text, expected_output)
|
||||
|
Loading…
Reference in New Issue
Block a user