Avoid GPT-2 daily CI job OOM (in TF tests) (#24106)

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-06-08 18:21:09 +02:00 committed by GitHub
parent 9322c24476
commit 2e2088f24b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -15,6 +15,7 @@
import datetime
import gc
import math
import unittest
@ -500,6 +501,12 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
self.model_tester = GPT2ModelTester(self)
self.config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37)
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def test_config(self):
self.config_tester.run_common_tests()
@ -683,6 +690,12 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
@require_torch
class GPT2ModelLanguageGenerationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def _test_lm_generate_gpt2_helper(
self,
gradient_checkpointing=False,