Fix PyTorch RAG tests GPU OOM (#16881)

* add torch.cuda.empty_cache in some PT RAG tests

* torch.cuda.empty_cache in tearDownModule()

* tearDown()

* add gc.collect()

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-04-25 17:33:56 +02:00 committed by GitHub
parent 3e47d19cfc
commit 32adbb26d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -14,6 +14,7 @@
# limitations under the License.
import gc
import json
import os
import shutil
@ -195,6 +196,10 @@ class RagTestMixin:
def tearDown(self):
shutil.rmtree(self.tmpdirname)
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def get_retriever(self, config):
dataset = Dataset.from_dict(
{
@ -677,6 +682,12 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
@require_tokenizers
@require_torch_non_multi_gpu
class RagModelIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
@cached_property
def sequence_model(self):
return (
@ -1024,6 +1035,12 @@ class RagModelIntegrationTests(unittest.TestCase):
@require_torch
@require_retrieval
class RagModelSaveLoadTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def get_rag_config(self):
question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")