Fix TF Rag OOM issue (#24122)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-06-09 15:03:11 +02:00 committed by GitHub
parent f2b918356c
commit 707023d155
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import gc
import json
import os
import shutil
@ -550,6 +551,11 @@ class TFRagDPRBartTest(TFRagTestMixin, unittest.TestCase):
@require_sentencepiece
@require_tokenizers
class TFRagModelIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
@cached_property
def token_model(self):
return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(