mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
Flax testing should not run the full torch test suite (#10725)
* make flax tests pytorch independent
* fix typo
* finish
* improve circle ci
* fix return tensors
* correct flax test
* re-add sentencepiece
* last tokenizer fixes
* finish maybe now
This commit is contained in:
parent
87d685b8a9
commit
9f8619c6aa
@ -91,6 +91,34 @@ jobs:
|
|||||||
- store_artifacts:
|
- store_artifacts:
|
||||||
path: ~/transformers/reports
|
path: ~/transformers/reports
|
||||||
|
|
||||||
|
run_tests_torch_and_flax:
|
||||||
|
working_directory: ~/transformers
|
||||||
|
docker:
|
||||||
|
- image: circleci/python:3.6
|
||||||
|
environment:
|
||||||
|
OMP_NUM_THREADS: 1
|
||||||
|
resource_class: xlarge
|
||||||
|
parallelism: 1
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- restore_cache:
|
||||||
|
keys:
|
||||||
|
- v0.4-torch_and_flax-{{ checksum "setup.py" }}
|
||||||
|
- v0.4-{{ checksum "setup.py" }}
|
||||||
|
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
|
||||||
|
- run: pip install --upgrade pip
|
||||||
|
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech]
|
||||||
|
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||||
|
- save_cache:
|
||||||
|
key: v0.4-{{ checksum "setup.py" }}
|
||||||
|
paths:
|
||||||
|
- '~/.cache/pip'
|
||||||
|
- run: RUN_PT_FLAX_CROSS_TESTS=1 python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_torch_and_flax ./tests/ -m is_pt_flax_cross_test --durations=0 | tee tests_output.txt
|
||||||
|
- store_artifacts:
|
||||||
|
path: ~/transformers/tests_output.txt
|
||||||
|
- store_artifacts:
|
||||||
|
path: ~/transformers/reports
|
||||||
|
|
||||||
run_tests_torch:
|
run_tests_torch:
|
||||||
working_directory: ~/transformers
|
working_directory: ~/transformers
|
||||||
docker:
|
docker:
|
||||||
@ -159,9 +187,8 @@ jobs:
|
|||||||
keys:
|
keys:
|
||||||
- v0.4-flax-{{ checksum "setup.py" }}
|
- v0.4-flax-{{ checksum "setup.py" }}
|
||||||
- v0.4-{{ checksum "setup.py" }}
|
- v0.4-{{ checksum "setup.py" }}
|
||||||
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
|
|
||||||
- run: pip install --upgrade pip
|
- run: pip install --upgrade pip
|
||||||
- run: sudo pip install .[flax,sklearn,torch,testing,sentencepiece,speech]
|
- run: sudo pip install .[flax,testing,sentencepiece]
|
||||||
- save_cache:
|
- save_cache:
|
||||||
key: v0.4-flax-{{ checksum "setup.py" }}
|
key: v0.4-flax-{{ checksum "setup.py" }}
|
||||||
paths:
|
paths:
|
||||||
@ -418,6 +445,7 @@ workflows:
|
|||||||
- run_examples_torch
|
- run_examples_torch
|
||||||
- run_tests_custom_tokenizers
|
- run_tests_custom_tokenizers
|
||||||
- run_tests_torch_and_tf
|
- run_tests_torch_and_tf
|
||||||
|
- run_tests_torch_and_flax
|
||||||
- run_tests_torch
|
- run_tests_torch
|
||||||
- run_tests_tf
|
- run_tests_tf
|
||||||
- run_tests_flax
|
- run_tests_flax
|
||||||
|
2
setup.py
2
setup.py
@ -97,7 +97,7 @@ _deps = [
|
|||||||
"fastapi",
|
"fastapi",
|
||||||
"filelock",
|
"filelock",
|
||||||
"flake8>=3.8.3",
|
"flake8>=3.8.3",
|
||||||
"flax>=0.2.2",
|
"flax>=0.3.2",
|
||||||
"fugashi>=1.0",
|
"fugashi>=1.0",
|
||||||
"importlib_metadata",
|
"importlib_metadata",
|
||||||
"ipadic>=1.0.0,<2.0",
|
"ipadic>=1.0.0,<2.0",
|
||||||
|
@ -10,7 +10,7 @@ deps = {
|
|||||||
"fastapi": "fastapi",
|
"fastapi": "fastapi",
|
||||||
"filelock": "filelock",
|
"filelock": "filelock",
|
||||||
"flake8": "flake8>=3.8.3",
|
"flake8": "flake8>=3.8.3",
|
||||||
"flax": "flax>=0.2.2",
|
"flax": "flax>=0.3.2",
|
||||||
"fugashi": "fugashi>=1.0",
|
"fugashi": "fugashi>=1.0",
|
||||||
"importlib_metadata": "importlib_metadata",
|
"importlib_metadata": "importlib_metadata",
|
||||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||||
|
@ -80,6 +80,7 @@ def parse_int_from_env(key, default=None):
|
|||||||
|
|
||||||
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
|
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
|
||||||
_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=False)
|
_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=False)
|
||||||
|
_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=False)
|
||||||
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
|
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
|
||||||
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=False)
|
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=False)
|
||||||
_run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)
|
_run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)
|
||||||
@ -105,6 +106,25 @@ def is_pt_tf_cross_test(test_case):
|
|||||||
return pytest.mark.is_pt_tf_cross_test()(test_case)
|
return pytest.mark.is_pt_tf_cross_test()(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def is_pt_flax_cross_test(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test as a test that controls interactions between PyTorch and Flax.
|
||||||
|
|
||||||
|
PT+FLAX tests are skipped by default. To run only these tests, set the RUN_PT_FLAX_CROSS_TESTS environment
|
||||||
|
variable to a truthy value and select the is_pt_flax_cross_test pytest mark.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
|
||||||
|
return unittest.skip("test is PT+FLAX test")(test_case)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
import pytest # We don't need a hard dependency on pytest in the main library
|
||||||
|
except ImportError:
|
||||||
|
return test_case
|
||||||
|
else:
|
||||||
|
return pytest.mark.is_pt_flax_cross_test()(test_case)
|
||||||
|
|
||||||
|
|
||||||
def is_pipeline_test(test_case):
|
def is_pipeline_test(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test as a pipeline test.
|
Decorator marking a test as a pipeline test.
|
||||||
|
@ -35,6 +35,9 @@ def pytest_configure(config):
|
|||||||
config.addinivalue_line(
|
config.addinivalue_line(
|
||||||
"markers", "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested"
|
"markers", "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested"
|
||||||
)
|
)
|
||||||
|
config.addinivalue_line(
|
||||||
|
"markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser):
|
||||||
|
@ -19,7 +19,7 @@ import numpy as np
|
|||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import is_flax_available, is_torch_available
|
from transformers import is_flax_available, is_torch_available
|
||||||
from transformers.testing_utils import require_flax, require_torch
|
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
|
||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
@ -60,7 +60,6 @@ def random_attention_mask(shape, rng=None):
|
|||||||
return attn_mask
|
return attn_mask
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
|
||||||
class FlaxModelTesterMixin:
|
class FlaxModelTesterMixin:
|
||||||
model_tester = None
|
model_tester = None
|
||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
@ -69,7 +68,7 @@ class FlaxModelTesterMixin:
|
|||||||
diff = np.abs((a - b)).max()
|
diff = np.abs((a - b)).max()
|
||||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||||
|
|
||||||
@require_torch
|
@is_pt_flax_cross_test
|
||||||
def test_equivalence_flax_pytorch(self):
|
def test_equivalence_flax_pytorch(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@ -104,6 +103,7 @@ class FlaxModelTesterMixin:
|
|||||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
|
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
|
||||||
|
|
||||||
|
@require_flax
|
||||||
def test_from_pretrained_save_pretrained(self):
|
def test_from_pretrained_save_pretrained(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@ -121,6 +121,7 @@ class FlaxModelTesterMixin:
|
|||||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||||
self.assert_almost_equals(output_loaded, output, 5e-3)
|
self.assert_almost_equals(output_loaded, output, 5e-3)
|
||||||
|
|
||||||
|
@require_flax
|
||||||
def test_jit_compilation(self):
|
def test_jit_compilation(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@ -143,6 +144,7 @@ class FlaxModelTesterMixin:
|
|||||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||||
self.assertEqual(jitted_output.shape, output.shape)
|
self.assertEqual(jitted_output.shape, output.shape)
|
||||||
|
|
||||||
|
@require_flax
|
||||||
def test_naming_convention(self):
|
def test_naming_convention(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model_class_name = model_class.__name__
|
model_class_name = model_class.__name__
|
||||||
|
@ -24,7 +24,13 @@ from collections import OrderedDict
|
|||||||
from itertools import takewhile
|
from itertools import takewhile
|
||||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, is_torch_available
|
from transformers import (
|
||||||
|
PreTrainedTokenizer,
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
|
PreTrainedTokenizerFast,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
get_tests_dir,
|
get_tests_dir,
|
||||||
is_pt_tf_cross_test,
|
is_pt_tf_cross_test,
|
||||||
@ -2283,7 +2289,12 @@ class TokenizerTesterMixin:
|
|||||||
"{} ({}, {})".format(tokenizer.__class__.__name__, pretrained_name, tokenizer.__class__.__name__)
|
"{} ({}, {})".format(tokenizer.__class__.__name__, pretrained_name, tokenizer.__class__.__name__)
|
||||||
):
|
):
|
||||||
|
|
||||||
returned_tensor = "pt" if is_torch_available() else "tf"
|
if is_torch_available():
|
||||||
|
returned_tensor = "pt"
|
||||||
|
elif is_tf_available():
|
||||||
|
returned_tensor = "tf"
|
||||||
|
else:
|
||||||
|
returned_tensor = "jax"
|
||||||
|
|
||||||
if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
|
if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
|
||||||
return
|
return
|
||||||
|
@ -21,7 +21,7 @@ from pathlib import Path
|
|||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
|
|
||||||
from transformers import BatchEncoding, MarianTokenizer
|
from transformers import BatchEncoding, MarianTokenizer
|
||||||
from transformers.file_utils import is_sentencepiece_available, is_torch_available
|
from transformers.file_utils import is_sentencepiece_available, is_tf_available, is_torch_available
|
||||||
from transformers.testing_utils import require_sentencepiece
|
from transformers.testing_utils import require_sentencepiece
|
||||||
|
|
||||||
|
|
||||||
@ -36,7 +36,13 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
|
|||||||
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
|
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
|
||||||
zh_code = ">>zh<<"
|
zh_code = ">>zh<<"
|
||||||
ORG_NAME = "Helsinki-NLP/"
|
ORG_NAME = "Helsinki-NLP/"
|
||||||
FRAMEWORK = "pt" if is_torch_available() else "tf"
|
|
||||||
|
if is_torch_available():
|
||||||
|
FRAMEWORK = "pt"
|
||||||
|
elif is_tf_available():
|
||||||
|
FRAMEWORK = "tf"
|
||||||
|
else:
|
||||||
|
FRAMEWORK = "jax"
|
||||||
|
|
||||||
|
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
|
from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
|
||||||
from transformers.file_utils import cached_property, is_torch_available
|
from transformers.file_utils import cached_property, is_tf_available, is_torch_available
|
||||||
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
|
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
@ -25,7 +25,12 @@ from .test_tokenization_common import TokenizerTesterMixin
|
|||||||
|
|
||||||
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||||
|
|
||||||
FRAMEWORK = "pt" if is_torch_available() else "tf"
|
if is_torch_available():
|
||||||
|
FRAMEWORK = "pt"
|
||||||
|
elif is_tf_available():
|
||||||
|
FRAMEWORK = "tf"
|
||||||
|
else:
|
||||||
|
FRAMEWORK = "jax"
|
||||||
|
|
||||||
|
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@ -157,7 +162,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
|
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
|
||||||
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
|
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
|
||||||
self.assertIsInstance(batch, BatchEncoding)
|
self.assertIsInstance(batch, BatchEncoding)
|
||||||
|
|
||||||
|
if FRAMEWORK != "jax":
|
||||||
result = list(batch.input_ids.numpy()[0])
|
result = list(batch.input_ids.numpy()[0])
|
||||||
|
else:
|
||||||
|
result = list(batch.input_ids.tolist()[0])
|
||||||
|
|
||||||
self.assertListEqual(expected_src_tokens, result)
|
self.assertListEqual(expected_src_tokens, result)
|
||||||
|
|
||||||
self.assertEqual((2, 9), batch.input_ids.shape)
|
self.assertEqual((2, 9), batch.input_ids.shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user