mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix Pipeline CI OOM issue (#24124)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
a7501f6fc6
commit
d0d1632958
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@ -507,6 +508,10 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
|
||||
self.check_default_pipeline(task, "pt", set_seed_fn, self.check_models_equal_pt)
|
||||
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
def test_load_default_pipelines_tf(self):
|
||||
@ -522,6 +527,9 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
|
||||
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
|
||||
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_load_default_pipelines_pt_table_qa(self):
|
||||
@ -530,6 +538,10 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731
|
||||
self.check_default_pipeline("table-question-answering", "pt", set_seed_fn, self.check_models_equal_pt)
|
||||
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
@require_tensorflow_probability
|
||||
@ -539,6 +551,9 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
|
||||
self.check_default_pipeline("table-question-answering", "tf", set_seed_fn, self.check_models_equal_tf)
|
||||
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
|
||||
def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equal_fn):
|
||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
@ -29,7 +30,14 @@ from transformers import (
|
||||
TFAutoModelForCausalLM,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
is_torch_available,
|
||||
require_tf,
|
||||
require_torch,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY
|
||||
|
||||
@ -39,6 +47,15 @@ DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
|
||||
|
||||
@is_pipeline_test
|
||||
class ConversationalPipelineTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model_mapping = dict(
|
||||
list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items())
|
||||
if MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
|
@ -12,12 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from transformers import MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline
|
||||
from transformers.pipelines import PipelineException
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
is_torch_available,
|
||||
nested_simplify,
|
||||
require_tf,
|
||||
require_torch,
|
||||
@ -33,6 +35,15 @@ class FillMaskPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_MASKED_LM_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", top_k=2, framework="tf")
|
||||
|
Loading…
Reference in New Issue
Block a user