Add a decorator for flaky tests (#19498)

* Add a decorator for flaky tests

* Quality

* Don't break the rest

* Address review comments

* Fix test name

* Fix typo and print to stderr
This commit is contained in:
Sylvain Gugger 2022-10-12 14:00:17 -04:00 committed by GitHub
parent 1973b7716b
commit 209bec4636
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 50 additions and 3 deletions

View File

@ -14,6 +14,7 @@
import collections
import contextlib
import functools
import inspect
import logging
import os
@ -23,12 +24,13 @@ import shutil
import subprocess
import sys
import tempfile
import time
import unittest
from collections.abc import Mapping
from distutils.util import strtobool
from io import StringIO
from pathlib import Path
from typing import Iterator, List, Union
from typing import Iterator, List, Optional, Union
from unittest import mock
import huggingface_hub
@ -1635,3 +1637,36 @@ class RequestCounter:
self.other_request_count += 1
return self.old_request(method=method, **kwargs)
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None):
"""
To decorate flaky tests. They will be retried on failures.
Args:
max_attempts (`int`, *optional*, defaults to 5):
The maximum number of attempts to retry the flaky test.
wait_before_retry (`float`, *optional*):
If provided, will wait that number of seconds before retrying the test.
"""
def decorator(test_func_ref):
@functools.wraps(test_func_ref)
def wrapper(*args, **kwargs):
retry_count = 1
while retry_count < max_attempts:
try:
return test_func_ref(*args, **kwargs)
except Exception as err:
print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1
return test_func_ref(*args, **kwargs)
return wrapper
return decorator

View File

@ -20,7 +20,7 @@ import unittest
from huggingface_hub import hf_hub_download
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@ -359,6 +359,10 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, unittest.TestCase):
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
)
@is_flaky()
def test_retain_grad_hidden_states_attentions(self):
super().test_retain_grad_hidden_states_attentions()
def prepare_batch(filename="train-batch.pt"):
file = hf_hub_download(repo_id="kashif/tourism-monthly-batch", filename=filename, repo_type="dataset")

View File

@ -21,7 +21,9 @@ from datasets import load_dataset
from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import (
is_flaky,
is_librosa_available,
is_pt_flax_cross_test,
is_pyctcdecode_available,
require_flax,
require_librosa,
@ -302,6 +304,11 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
outputs = model(np.ones((1, 1024), dtype="f4"))
self.assertIsNotNone(outputs)
@is_pt_flax_cross_test
@is_flaky()
def test_equivalence_pt_to_flax(self):
super().test_equivalence_pt_to_flax()
@require_flax
class FlaxWav2Vec2UtilsTest(unittest.TestCase):

View File

@ -26,7 +26,7 @@ from datasets import load_dataset
from huggingface_hub import snapshot_download
from transformers import Wav2Vec2Config, is_tf_available
from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
from transformers.testing_utils import is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
from transformers.utils import is_librosa_available, is_pyctcdecode_available
from ...test_configuration_common import ConfigTester
@ -309,6 +309,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
@is_flaky()
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)