mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Fix TF Rag OOM issue (#24122)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
f2b918356c
commit
707023d155
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
@ -550,6 +551,11 @@ class TFRagDPRBartTest(TFRagTestMixin, unittest.TestCase):
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class TFRagModelIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
|
||||
@cached_property
|
||||
def token_model(self):
|
||||
return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
|
||||
|
Loading…
Reference in New Issue
Block a user