avoid calling gc.collect and cuda.empty_cache (#34514)

* update

* update

* update

* update

* update

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar authored 2024-10-31 16:36:13 +01:00 (committed by GitHub)
parent dca93ca076
commit ab98f0b0a1
24 changed files with 77 additions and 94 deletions
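
The commit replaces the gc.collect() / torch.cuda.empty_cache() (or backend_empty_cache(torch_device)) boilerplate scattered through the test suite with a single cleanup() helper added to transformers.testing_utils. A minimal sketch of the resulting tearDown pattern, assuming the post-commit testing_utils API; the MyModelIntegrationTest class is illustrative only, not part of this commit:

    import unittest

    from transformers.testing_utils import cleanup, torch_device


    class MyModelIntegrationTest(unittest.TestCase):
        def tearDown(self):
            super().tearDown()
            # Empty the accelerator cache; gc_collect=True additionally runs a
            # full Python garbage collection pass, which the slow integration
            # tests opt into while cheap unit tests leave it off.
            cleanup(torch_device, gc_collect=True)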

src/transformers/testing_utils.py

@@ -16,6 +16,7 @@ import collections
 import contextlib
 import doctest
 import functools
+import gc
 import importlib
 import inspect
 import logging
@@ -2679,3 +2680,10 @@ def compare_pipeline_output_to_hub_spec(output, hub_spec):
     if unexpected_keys:
         error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}")
         raise KeyError("\n".join(error))
+
+
+@require_torch
+def cleanup(device: str, gc_collect=False):
+    if gc_collect:
+        gc.collect()
+    backend_empty_cache(device)
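
Since backend_empty_cache() dispatches on the device string (CUDA and the other supported accelerator backends) rather than calling torch.cuda.empty_cache() directly, the helper stays device-agnostic; gc_collect defaults to False so ordinary unit tests only empty the cache, while integration tests pass gc_collect=True to also force a full collection, as the diffs below show.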

tests/models/clvp/test_feature_extraction_clvp.py

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import itertools
 import os
 import random
@@ -24,7 +23,13 @@ import numpy as np
 from datasets import Audio, load_dataset
 
 from transformers import ClvpFeatureExtractor
-from transformers.testing_utils import check_json_file_has_correct_format, require_torch, slow
+from transformers.testing_utils import (
+    check_json_file_has_correct_format,
+    cleanup,
+    require_torch,
+    slow,
+    torch_device,
+)
 from transformers.utils.import_utils import is_torch_available
 
 from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
@@ -116,8 +121,7 @@ class ClvpFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device)
 
     # Copied from transformers.tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_feat_extract_from_and_save_pretrained
     def test_feat_extract_from_and_save_pretrained(self):

tests/models/clvp/test_modeling_clvp.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Clvp model."""
 
-import gc
 import tempfile
 import unittest
 
@@ -23,6 +22,7 @@ import numpy as np
 from transformers import ClvpConfig, ClvpDecoderConfig, ClvpEncoderConfig
 from transformers.testing_utils import (
+    cleanup,
     require_torch,
     slow,
     torch_device,
@@ -174,8 +174,7 @@ class ClvpEncoderTest(ModelTesterMixin, unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device)
 
     def test_config(self):
         self.encoder_config_tester.run_common_tests()
@@ -294,8 +293,7 @@ class ClvpDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device)
 
     def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -421,8 +419,7 @@ class ClvpModelForConditionalGenerationTest(ModelTesterMixin, unittest.TestCase)
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device)
 
     def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -571,8 +568,7 @@ class ClvpIntegrationTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     def test_conditional_encoder(self):
         with torch.no_grad():

tests/models/ctrl/test_modeling_ctrl.py

@@ -13,11 +13,10 @@
 # limitations under the License.
 
-import gc
 import unittest
 
 from transformers import CTRLConfig, is_torch_available
-from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
+from transformers.testing_utils import cleanup, require_torch, slow, torch_device
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -235,8 +234,7 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device)
 
     def test_config(self):
         self.config_tester.run_common_tests()
@@ -261,8 +259,7 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     def test_lm_generate_ctrl(self):

tests/models/gpt2/test_modeling_gpt2.py

@@ -15,7 +15,6 @@
 
 import datetime
-import gc
 import math
 import unittest
 
@@ -23,7 +22,7 @@ import pytest
 from transformers import GPT2Config, is_torch_available
 from transformers.testing_utils import (
-    backend_empty_cache,
+    cleanup,
     require_flash_attn,
     require_torch,
     require_torch_gpu,
@@ -542,8 +541,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device)
 
     def test_config(self):
         self.config_tester.run_common_tests()
@@ -753,8 +751,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)
 
     def _test_lm_generate_gpt2_helper(
         self,

tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py

@@ -18,7 +18,7 @@ import unittest
 from parameterized import parameterized
 
 from transformers import GPTBigCodeConfig, is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import cleanup, require_torch, slow, torch_device
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -422,9 +422,9 @@ class GPTBigCodeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         self.config_tester = ConfigTester(self, config_class=GPTBigCodeConfig, n_embd=37)
 
     def tearDown(self):
-        import gc
-
-        gc.collect()
+        super().tearDown()
+        # clean-up as much as possible GPU memory occupied by PyTorch
+        cleanup(torch_device)
 
     def test_config(self):
         self.config_tester.run_common_tests()

tests/models/idefics2/test_modeling_idefics2.py

@@ -15,7 +15,6 @@
 """Testing suite for the PyTorch Idefics2 model."""
 
 import copy
-import gc
 import tempfile
 import unittest
 from io import BytesIO
@@ -31,6 +30,7 @@ from transformers import (
     is_vision_available,
 )
 from transformers.testing_utils import (
+    cleanup,
     require_bitsandbytes,
     require_flash_attn,
     require_torch,
@@ -583,8 +583,7 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
         )
 
     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     @require_torch_multi_gpu

tests/models/idefics3/test_modeling_idefics3.py

@@ -15,7 +15,6 @@
 """Testing suite for the PyTorch Idefics3 model."""
 
 import copy
-import gc
 import unittest
 from io import BytesIO
 
@@ -26,7 +25,7 @@ from transformers import (
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import cleanup, require_bitsandbytes, require_torch, slow, torch_device
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -497,8 +496,7 @@ class Idefics3ForConditionalGenerationIntegrationTest(unittest.TestCase):
         )
 
     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     @unittest.skip("multi-gpu tests are disabled for now")

tests/models/llama/test_modeling_llama.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch LLaMA model."""
 
-import gc
 import tempfile
 import unittest
 
@@ -25,7 +24,7 @@ from parameterized import parameterized
 from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.testing_utils import (
-    backend_empty_cache,
+    cleanup,
     require_flash_attn,
     require_read_token,
     require_torch,
@@ -891,8 +890,7 @@ class LlamaIntegrationTest(unittest.TestCase):
 @require_torch_accelerator
 class Mask4DTestHard(unittest.TestCase):
     def tearDown(self):
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)
 
     def setUp(self):
         model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tests/models/llava/test_modeling_llava.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Llava model."""
 
-import gc
 import unittest
 
 import requests
@@ -28,6 +27,7 @@ from transformers import (
     is_vision_available,
 )
 from transformers.testing_utils import (
+    cleanup,
     require_bitsandbytes,
     require_torch,
     require_torch_gpu,
@@ -307,8 +307,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
         self.processor = AutoProcessor.from_pretrained("llava-hf/bakLlava-v1-hf")
 
     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     @require_bitsandbytes

tests/models/llava_next/test_modeling_llava_next.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Llava-NeXT model."""
 
-import gc
 import unittest
 
 import requests
@@ -28,6 +27,7 @@ from transformers import (
     is_vision_available,
 )
 from transformers.testing_utils import (
+    cleanup,
     require_bitsandbytes,
     require_torch,
     slow,
@@ -370,8 +370,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
         self.prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
 
     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     @require_bitsandbytes

tests/models/llava_next_video/test_modeling_llava_next_video.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Llava-NeXT-Video model."""
 
-import gc
 import unittest
 
 import numpy as np
@@ -29,6 +28,7 @@ from transformers import (
     is_vision_available,
 )
 from transformers.testing_utils import (
+    cleanup,
     require_bitsandbytes,
     require_torch,
     slow,
@@ -400,8 +400,7 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
         self.prompt_video = "USER: <video>\nWhy is this video funny? ASSISTANT:"
 
     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     @require_bitsandbytes

tests/models/llava_onevision/test_modeling_llava_onevision.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Llava-NeXT model."""
 
-import gc
 import unittest
 
 import numpy as np
@@ -29,6 +28,7 @@ from transformers import (
     is_vision_available,
 )
 from transformers.testing_utils import (
+    cleanup,
     require_bitsandbytes,
     require_torch,
     slow,
@@ -336,8 +336,7 @@ class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase):
         self.prompt_video = "user\n<video>\nWhat do you see in this video?<|im_end|>\n<|im_start|>assistant\n"
 
     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     @require_bitsandbytes

tests/models/mistral/test_modeling_mistral.py

@@ -23,6 +23,7 @@ from packaging import version
 from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
     backend_empty_cache,
+    cleanup,
     require_bitsandbytes,
     require_flash_attn,
     require_read_token,
@@ -436,8 +437,7 @@ class MistralIntegrationTest(unittest.TestCase):
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
     def tearDown(self):
-        torch.cuda.empty_cache()
-        gc.collect()
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     def test_model_7b_logits(self):
@@ -656,8 +656,7 @@ class Mask4DTestHard(unittest.TestCase):
     _model = None
 
     def tearDown(self):
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)
 
     @property
     def model(self):

tests/models/mllama/test_modeling_mllama.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Mllama model."""
 
-import gc
 import unittest
 
 import requests
@@ -30,6 +29,7 @@ from transformers import (
 )
 from transformers.models.mllama.configuration_mllama import MllamaTextConfig
 from transformers.testing_utils import (
+    cleanup,
     is_flaky,
     require_bitsandbytes,
     require_read_token,
@@ -396,8 +396,7 @@ class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase):
         self.instruct_model_checkpoint = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     @require_torch_gpu

tests/models/paligemma/test_modeling_paligemma.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch PaliGemma model."""
 
-import gc
 import unittest
 
 import requests
@@ -28,6 +27,7 @@ from transformers import (
     is_vision_available,
 )
 from transformers.testing_utils import (
+    cleanup,
     require_read_token,
     require_torch,
     require_torch_sdpa,
@@ -365,8 +365,7 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
         self.processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")
 
     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     def test_small_model_integration_test(self):
         # Let' s make sure we test the preprocessing to replace what is used

tests/models/qwen2_audio/test_modeling_qwen2_audio.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Qwen2Audio model."""
 
-import gc
 import tempfile
 import unittest
 from io import BytesIO
@@ -29,6 +28,7 @@ from transformers import (
     is_torch_available,
 )
 from transformers.testing_utils import (
+    cleanup,
     require_torch,
     require_torch_sdpa,
     slow,
@@ -222,8 +222,7 @@ class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase):
         self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
 
     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     def test_small_model_integration_test_single(self):

tests/models/rag/test_modeling_rag.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 
-import gc
 import json
 import os
 import shutil
@@ -29,6 +28,7 @@ from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
 from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
 from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
 from transformers.testing_utils import (
+    cleanup,
     get_tests_dir,
     require_sentencepiece,
     require_tokenizers,
@@ -196,8 +196,7 @@ class RagTestMixin:
         shutil.rmtree(self.tmpdirname)
 
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device)
 
     def get_retriever(self, config):
         dataset = Dataset.from_dict(
@@ -684,8 +683,7 @@ class RagModelIntegrationTests(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     @cached_property
     def sequence_model(self):
@@ -1043,8 +1041,7 @@ class RagModelSaveLoadTests(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     def get_rag_config(self):
         question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

tests/models/sam/test_modeling_sam.py

@@ -14,13 +14,12 @@
 # limitations under the License.
 """Testing suite for the PyTorch SAM model."""
 
-import gc
 import unittest
 
 import requests
 
 from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
-from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
+from transformers.testing_utils import cleanup, require_torch, slow, torch_device
 from transformers.utils import is_torch_available, is_vision_available
 
 from ...test_configuration_common import ConfigTester
@@ -469,8 +468,7 @@ class SamModelIntegrationTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)
 
     def test_inference_mask_generation_no_point(self):
         model = SamModel.from_pretrained("facebook/sam-vit-base")

tests/models/univnet/test_modeling_univnet.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import inspect
 import random
 import unittest
@@ -21,7 +20,7 @@ from datasets import Audio, load_dataset
 from transformers import UnivNetConfig, UnivNetFeatureExtractor
 from transformers.testing_utils import (
-    backend_empty_cache,
+    cleanup,
     is_torch_available,
     require_torch,
     require_torch_accelerator,
@@ -211,8 +210,7 @@ class UnivNetModelTest(ModelTesterMixin, unittest.TestCase):
 class UnivNetModelIntegrationTests(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)
 
     def _load_datasamples(self, num_samples, sampling_rate=24000):
         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

tests/models/video_llava/test_modeling_video_llava.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch VideoLlava model."""
 
-import gc
 import unittest
 
 import numpy as np
@@ -29,6 +28,7 @@ from transformers import (
     is_vision_available,
 )
 from transformers.testing_utils import (
+    cleanup,
     require_bitsandbytes,
     require_torch,
     require_torch_gpu,
@@ -437,8 +437,7 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
         self.processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
 
     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     @require_bitsandbytes

tests/models/vipllava/test_modeling_vipllava.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch VipLlava model."""
 
-import gc
 import unittest
 
 import requests
@@ -26,7 +25,14 @@ from transformers import (
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
+from transformers.testing_utils import (
+    cleanup,
+    require_bitsandbytes,
+    require_torch,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -290,8 +296,7 @@ class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
         self.processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")
 
     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     @require_bitsandbytes

tests/models/wav2vec2/test_modeling_wav2vec2.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Wav2Vec2 model."""
 
-import gc
 import math
 import multiprocessing
 import os
@@ -30,7 +29,7 @@ from pytest import mark
 from transformers import Wav2Vec2Config, is_torch_available
 from transformers.testing_utils import (
     CaptureLogger,
-    backend_empty_cache,
+    cleanup,
     is_pt_flax_cross_test,
     is_pyctcdecode_available,
     is_torchaudio_available,
@@ -1460,8 +1459,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)
 
     def _load_datasamples(self, num_samples):
         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

tests/models/xglm/test_modeling_xglm.py

@@ -14,12 +14,12 @@
 # limitations under the License.
 
 import datetime
-import gc
 import math
 import unittest
 
 from transformers import XGLMConfig, is_torch_available
 from transformers.testing_utils import (
+    cleanup,
     require_torch,
     require_torch_accelerator,
     require_torch_fp16,
@@ -343,8 +343,7 @@ class XGLMModelLanguageGenerationTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)
 
     def _test_lm_generate_xglm_helper(
         self,