mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
skip gptj slow generate tests for now (#13809)
This commit is contained in:
parent
41436d3dfb
commit
8bbb53e20b
@ -396,8 +396,9 @@ class GPTJModelTest(unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
|
||||
|
||||
@slow
|
||||
@tooslow
|
||||
def test_batch_generation(self):
|
||||
# Marked as @tooslow due to GPU OOM
|
||||
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16)
|
||||
model.to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
|
||||
@ -464,8 +465,9 @@ class GPTJModelTest(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
class GPTJModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
@tooslow
|
||||
def test_lm_generate_gptj(self):
|
||||
# Marked as @tooslow due to GPU OOM
|
||||
for checkpointing in [True, False]:
|
||||
model = GPTJForCausalLM.from_pretrained(
|
||||
"EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16
|
||||
|
Loading…
Reference in New Issue
Block a user