skip gptj slow generate tests for now (#13809)

This commit is contained in:
Suraj Patil 2021-10-01 01:14:33 +05:30 committed by GitHub
parent 41436d3dfb
commit 8bbb53e20b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -396,8 +396,9 @@ class GPTJModelTest(unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
@slow
@tooslow
def test_batch_generation(self):
# Marked as @tooslow due to GPU OOM
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16)
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
@ -464,8 +465,9 @@ class GPTJModelTest(unittest.TestCase):
@require_torch
class GPTJModelLanguageGenerationTest(unittest.TestCase):
@slow
@tooslow
def test_lm_generate_gptj(self):
# Marked as @tooslow due to GPU OOM
for checkpointing in [True, False]:
model = GPTJForCausalLM.from_pretrained(
"EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16