mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix PyTorch RAG tests GPU OOM (#16881)
* add torch.cuda.empty_cache in some PT RAG tests
* torch.cuda.empty_cache in tearDownModule()
* tearDown()
* add gc.collect()

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 3e47d19cfc
commit 32adbb26d6
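
The fix applies a common pattern for keeping CUDA memory from piling up across unittest test cases: drop Python references to large objects, force a garbage-collection pass, then let PyTorch return its cached CUDA blocks to the driver. A minimal, self-contained sketch of that pattern (the GpuCleanupExample class and its dummy tensor are illustrative, not part of the actual test file):

import gc
import unittest

import torch


class GpuCleanupExample(unittest.TestCase):
    # Illustrative sketch of the tearDown pattern added by this commit.

    def test_allocate_on_gpu(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")
        # Keep the tensor on the instance so it survives until tearDown.
        self.activations = torch.zeros(1024, 1024, device="cuda")
        self.assertEqual(self.activations.numel(), 1024 * 1024)

    def tearDown(self):
        super().tearDown()
        # clean-up as much as possible GPU memory occupied by PyTorch:
        # drop references, collect, then release cached CUDA blocks.
        self.activations = None
        gc.collect()
        torch.cuda.empty_cache()  # no-op if CUDA was never initialized

Note that torch.cuda.empty_cache() only releases memory the caching allocator no longer needs; tensors still reachable from Python keep their allocations, which is why gc.collect() runs first.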
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 
+import gc
 import json
 import os
 import shutil
@@ -195,6 +196,10 @@ class RagTestMixin:
     def tearDown(self):
         shutil.rmtree(self.tmpdirname)
 
+        # clean-up as much as possible GPU memory occupied by PyTorch
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def get_retriever(self, config):
         dataset = Dataset.from_dict(
             {
@@ -677,6 +682,12 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
 @require_tokenizers
 @require_torch_non_multi_gpu
 class RagModelIntegrationTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        # clean-up as much as possible GPU memory occupied by PyTorch
+        gc.collect()
+        torch.cuda.empty_cache()
+
     @cached_property
     def sequence_model(self):
         return (
@@ -1024,6 +1035,12 @@ class RagModelIntegrationTests(unittest.TestCase):
 @require_torch
 @require_retrieval
 class RagModelSaveLoadTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        # clean-up as much as possible GPU memory occupied by PyTorch
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def get_rag_config(self):
         question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
         generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
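
The commit message also mentions experimenting with tearDownModule(); a module-level variant of the same cleanup, running once after all tests in a file, would look roughly like this (sketch only; the hunks shown above add per-class tearDown methods instead):

import gc

import torch


def tearDownModule():
    # unittest calls this once after every test in the module has finished;
    # same cleanup as the per-class tearDown, applied module-wide.
    gc.collect()
    torch.cuda.empty_cache()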