Fix Pipeline CI OOM issue (#24124)

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar, 2023-06-09 16:49:02 +02:00, committed by GitHub
parent a7501f6fc6
commit d0d1632958
3 changed files with 44 additions and 1 deletion
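
All three files apply the same fix: after tests that load full pipeline models onto the GPU, run Python's garbage collector and then release PyTorch's cached CUDA blocks, so successive pipeline tests stop accumulating allocations until the CI runner hits an out-of-memory error. A minimal sketch of that idiom outside this test suite (the GpuCleanupTests class name is hypothetical; torch.cuda.empty_cache() does nothing unless CUDA has actually been initialized, so the pattern is safe on CPU-only runs):

import gc
import unittest


class GpuCleanupTests(unittest.TestCase):  # hypothetical class, for illustration only
    def tearDown(self):
        super().tearDown()
        # Drop Python-level references first so GPU tensors become collectable
        # and their memory returns to PyTorch's caching allocator.
        gc.collect()
        try:
            import torch
        except ImportError:
            return  # PyTorch not installed: nothing is cached on the GPU
        # Hand the cached, now-unoccupied CUDA blocks back to the driver so
        # the next test starts with as much free GPU memory as possible.
        torch.cuda.empty_cache()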

Changed file 1 of 3 (PipelineUtilsTest):

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import logging
 import os
 import sys
@@ -507,6 +508,10 @@ class PipelineUtilsTest(unittest.TestCase):
             self.check_default_pipeline(task, "pt", set_seed_fn, self.check_models_equal_pt)
+            # clean-up as much as possible GPU memory occupied by PyTorch
+            gc.collect()
+            torch.cuda.empty_cache()
+
     @slow
     @require_tf
     def test_load_default_pipelines_tf(self):
@@ -522,6 +527,9 @@ class PipelineUtilsTest(unittest.TestCase):
             self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
+            # clean-up as much as possible GPU memory occupied by PyTorch
+            gc.collect()
+
     @slow
     @require_torch
     def test_load_default_pipelines_pt_table_qa(self):
@@ -530,6 +538,10 @@ class PipelineUtilsTest(unittest.TestCase):
         set_seed_fn = lambda: torch.manual_seed(0)  # noqa: E731
         self.check_default_pipeline("table-question-answering", "pt", set_seed_fn, self.check_models_equal_pt)
+        # clean-up as much as possible GPU memory occupied by PyTorch
+        gc.collect()
+        torch.cuda.empty_cache()
+
     @slow
     @require_tf
     @require_tensorflow_probability
@@ -539,6 +551,9 @@ class PipelineUtilsTest(unittest.TestCase):
         set_seed_fn = lambda: tf.random.set_seed(0)  # noqa: E731
         self.check_default_pipeline("table-question-answering", "tf", set_seed_fn, self.check_models_equal_tf)
+        # clean-up as much as possible GPU memory occupied by PyTorch
+        gc.collect()
+
     def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equal_fn):
         from transformers.pipelines import SUPPORTED_TASKS, pipeline
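
Note the shape of the change in PipelineUtilsTest: the cleanup runs inside the per-task loop, immediately after each check_default_pipeline call, so memory is reclaimed between tasks rather than only once at the end of the test. The TF variants call only gc.collect(), since torch.cuda.empty_cache() targets PyTorch's CUDA caching allocator and does nothing for TensorFlow's memory pool.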

Changed file 2 of 3 (ConversationalPipelineTests):

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import unittest

 from transformers import (
@@ -29,7 +30,14 @@ from transformers import (
     TFAutoModelForCausalLM,
     pipeline,
 )
-from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, slow, torch_device
+from transformers.testing_utils import (
+    is_pipeline_test,
+    is_torch_available,
+    require_tf,
+    require_torch,
+    slow,
+    torch_device,
+)

 from .test_pipelines_common import ANY
@@ -39,6 +47,15 @@ DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
 @is_pipeline_test
 class ConversationalPipelineTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        # clean-up as much as possible GPU memory occupied by PyTorch
+        gc.collect()
+        if is_torch_available():
+            import torch
+
+            torch.cuda.empty_cache()
+
     model_mapping = dict(
         list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items())
         if MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
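
In the two pipeline-specific test classes the cleanup lives in tearDown() instead, so it runs after every test method. The torch import is deferred and guarded by is_torch_available(), which keeps TF-only CI environments (where PyTorch is not installed) from failing on the import; gc.collect() still runs unconditionally.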

Changed file 3 of 3 (FillMaskPipelineTests):

@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import unittest

 from transformers import MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline
 from transformers.pipelines import PipelineException
 from transformers.testing_utils import (
     is_pipeline_test,
+    is_torch_available,
     nested_simplify,
     require_tf,
     require_torch,
@@ -33,6 +35,15 @@ class FillMaskPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_MASKED_LM_MAPPING
     tf_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING

+    def tearDown(self):
+        super().tearDown()
+        # clean-up as much as possible GPU memory occupied by PyTorch
+        gc.collect()
+        if is_torch_available():
+            import torch
+
+            torch.cuda.empty_cache()
+
     @require_tf
     def test_small_model_tf(self):
         unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", top_k=2, framework="tf")
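
Not part of the commit, but a quick way to see what empty_cache() actually buys: PyTorch's caching allocator keeps freed blocks reserved until told to let go, and torch.cuda.memory_reserved() makes that visible. A small sketch, assuming a CUDA-capable machine:

import gc

import torch

# Reserve roughly 1 GiB on the GPU, drop the tensor, then empty the cache.
x = torch.empty(1024, 1024, 256, device="cuda")  # 256M float32 values = 1 GiB
print("reserved after alloc:  ", torch.cuda.memory_reserved())
del x
gc.collect()
# The caching allocator keeps the freed block reserved until empty_cache()
# returns it to the driver.
torch.cuda.empty_cache()
print("reserved after cleanup:", torch.cuda.memory_reserved())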