diff --git a/.circleci/config.yml b/.circleci/config.yml
index e67fdaa0263..f8040e7553f 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -91,6 +91,34 @@ jobs:
             - store_artifacts:
                   path: ~/transformers/reports
 
+    run_tests_torch_and_flax:
+        working_directory: ~/transformers
+        docker:
+            - image: circleci/python:3.6
+        environment:
+            OMP_NUM_THREADS: 1
+        resource_class: xlarge
+        parallelism: 1
+        steps:
+            - checkout
+            - restore_cache:
+                  keys:
+                      - v0.4-torch_and_flax-{{ checksum "setup.py" }}
+                      - v0.4-{{ checksum "setup.py" }}
+            - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
+            - run: pip install --upgrade pip
+            - run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech]
+            - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
+            - save_cache:
+                  key: v0.4-torch_and_flax-{{ checksum "setup.py" }}
+                  paths:
+                      - '~/.cache/pip'
+            - run: RUN_PT_FLAX_CROSS_TESTS=1 python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_torch_and_flax ./tests/ -m is_pt_flax_cross_test --durations=0 | tee tests_output.txt
+            - store_artifacts:
+                  path: ~/transformers/tests_output.txt
+            - store_artifacts:
+                  path: ~/transformers/reports
+
     run_tests_torch:
         working_directory: ~/transformers
         docker:
@@ -159,9 +187,8 @@ jobs:
                   keys:
                       - v0.4-flax-{{ checksum "setup.py" }}
                       - v0.4-{{ checksum "setup.py" }}
-            - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
             - run: pip install --upgrade pip
-            - run: sudo pip install .[flax,sklearn,torch,testing,sentencepiece,speech]
+            - run: sudo pip install .[flax,testing,sentencepiece]
             - save_cache:
                   key: v0.4-flax-{{ checksum "setup.py" }}
                   paths:
@@ -418,6 +445,7 @@ workflows:
             - run_examples_torch
             - run_tests_custom_tokenizers
             - run_tests_torch_and_tf
+            - run_tests_torch_and_flax
             - run_tests_torch
             - run_tests_tf
             - run_tests_flax
diff --git a/setup.py b/setup.py
index 16567d71c0e..c27c66ff8cd 100644
--- a/setup.py
+++ b/setup.py
@@ -97,7 +97,7 @@ _deps = [
     "fastapi",
     "filelock",
     "flake8>=3.8.3",
-    "flax>=0.2.2",
+    "flax>=0.3.2",
     "fugashi>=1.0",
     "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index 576fbe7cd6f..8e0f3773e94 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -10,7 +10,7 @@ deps = {
     "fastapi": "fastapi",
     "filelock": "filelock",
     "flake8": "flake8>=3.8.3",
-    "flax": "flax>=0.2.2",
+    "flax": "flax>=0.3.2",
     "fugashi": "fugashi>=1.0",
     "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 063aba5553a..ee1dc5277ec 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -80,6 +80,7 @@ def parse_int_from_env(key, default=None):
 
 _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
 _run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=False)
+_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=False)
 _run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
 _run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=False)
 _run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)
@@ -105,6 +106,25 @@ def is_pt_tf_cross_test(test_case):
             return pytest.mark.is_pt_tf_cross_test()(test_case)
 
 
+def is_pt_flax_cross_test(test_case):
+    """
+    Decorator marking a test as a test of the interactions between PyTorch and Flax.
+
+    PT+FLAX tests are skipped by default. They are run only when the RUN_PT_FLAX_CROSS_TESTS environment variable is
+    set to a truthy value and the is_pt_flax_cross_test pytest mark is selected.
+
+    """
+    if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
+        return unittest.skip("test is PT+FLAX test")(test_case)
+    else:
+        try:
+            import pytest  # We don't need a hard dependency on pytest in the main library
+        except ImportError:
+            return test_case
+        else:
+            return pytest.mark.is_pt_flax_cross_test()(test_case)
+
+
 def is_pipeline_test(test_case):
     """
     Decorator marking a test as a pipeline test.
diff --git a/tests/conftest.py b/tests/conftest.py
index c49a4d6a3e0..104a1394fdf 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -35,6 +35,9 @@ def pytest_configure(config):
     config.addinivalue_line(
         "markers", "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested"
     )
+    config.addinivalue_line(
+        "markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
+    )
 
 
 def pytest_addoption(parser):
diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index 5b5bf54bd88..19e900aef40 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -19,7 +19,7 @@ import numpy as np
 
 import transformers
 from transformers import is_flax_available, is_torch_available
-from transformers.testing_utils import require_flax, require_torch
+from transformers.testing_utils import is_pt_flax_cross_test, require_flax
 
 
 if is_flax_available():
@@ -60,7 +60,6 @@ def random_attention_mask(shape, rng=None):
     return attn_mask
 
 
-@require_flax
 class FlaxModelTesterMixin:
     model_tester = None
     all_model_classes = ()
@@ -69,7 +68,7 @@ class FlaxModelTesterMixin:
         diff = np.abs((a - b)).max()
         self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
 
-    @require_torch
+    @is_pt_flax_cross_test
    def test_equivalence_flax_pytorch(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -104,6 +103,7 @@ class FlaxModelTesterMixin:
         for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
             self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
 
+    @require_flax
     def test_from_pretrained_save_pretrained(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -121,6 +121,7 @@ class FlaxModelTesterMixin:
         for output_loaded, output in zip(outputs_loaded, outputs):
             self.assert_almost_equals(output_loaded, output, 5e-3)
 
+    @require_flax
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -143,6 +144,7 @@ class FlaxModelTesterMixin:
         for jitted_output, output in zip(jitted_outputs, outputs):
             self.assertEqual(jitted_output.shape, output.shape)
 
+    @require_flax
     def test_naming_convention(self):
         for model_class in self.all_model_classes:
             model_class_name = model_class.__name__
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 58b4eee3e8c..995b56b00e9 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -24,7 +24,13 @@ from collections import OrderedDict
 from itertools import takewhile
 from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, is_torch_available
+from transformers import (
+    PreTrainedTokenizer,
+    PreTrainedTokenizerBase,
+    PreTrainedTokenizerFast,
+    is_tf_available,
+    is_torch_available,
+)
 from transformers.testing_utils import (
     get_tests_dir,
     is_pt_tf_cross_test,
@@ -2283,7 +2289,12 @@ class TokenizerTesterMixin:
                 "{} ({}, {})".format(tokenizer.__class__.__name__, pretrained_name, tokenizer.__class__.__name__)
             ):
 
-                returned_tensor = "pt" if is_torch_available() else "tf"
+                if is_torch_available():
+                    returned_tensor = "pt"
+                elif is_tf_available():
+                    returned_tensor = "tf"
+                else:
+                    returned_tensor = "jax"
 
                 if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
                     return
diff --git a/tests/test_tokenization_marian.py b/tests/test_tokenization_marian.py
index d78d582f3c0..3d9146b11fb 100644
--- a/tests/test_tokenization_marian.py
+++ b/tests/test_tokenization_marian.py
@@ -21,7 +21,7 @@ from pathlib import Path
 from shutil import copyfile
 
 from transformers import BatchEncoding, MarianTokenizer
-from transformers.file_utils import is_sentencepiece_available, is_torch_available
+from transformers.file_utils import is_sentencepiece_available, is_tf_available, is_torch_available
 from transformers.testing_utils import require_sentencepiece
 
 
@@ -36,7 +36,13 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
 mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
 zh_code = ">>zh<<"
 ORG_NAME = "Helsinki-NLP/"
-FRAMEWORK = "pt" if is_torch_available() else "tf"
+
+if is_torch_available():
+    FRAMEWORK = "pt"
+elif is_tf_available():
+    FRAMEWORK = "tf"
+else:
+    FRAMEWORK = "jax"
 
 
 @require_sentencepiece
diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py
index 27cdf612cea..710b4ad9fcf 100644
--- a/tests/test_tokenization_t5.py
+++ b/tests/test_tokenization_t5.py
@@ -17,7 +17,7 @@ import unittest
 
 from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
-from transformers.file_utils import cached_property, is_torch_available
+from transformers.file_utils import cached_property, is_tf_available, is_torch_available
 from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
 
 from .test_tokenization_common import TokenizerTesterMixin
 
@@ -25,7 +25,12 @@ from .test_tokenization_common import TokenizerTesterMixin
 
 SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 
-FRAMEWORK = "pt" if is_torch_available() else "tf"
+if is_torch_available():
+    FRAMEWORK = "pt"
+elif is_tf_available():
+    FRAMEWORK = "tf"
+else:
+    FRAMEWORK = "jax"
 
 
 @require_sentencepiece
@@ -157,7 +162,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
         batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
         self.assertIsInstance(batch, BatchEncoding)
-        result = list(batch.input_ids.numpy()[0])
+
+        if FRAMEWORK != "jax":
+            result = list(batch.input_ids.numpy()[0])
+        else:
+            result = list(batch.input_ids.tolist()[0])
+
         self.assertListEqual(expected_src_tokens, result)
         self.assertEqual((2, 9), batch.input_ids.shape)
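
A minimal usage sketch of the new decorator follows (illustrative only, not part of the patch; the test class name and tensor values are hypothetical). is_pt_flax_cross_test skips the test unless both PyTorch and Flax are installed and RUN_PT_FLAX_CROSS_TESTS is set to a truthy value; otherwise it applies the is_pt_flax_cross_test pytest mark registered in tests/conftest.py above:

import unittest

import numpy as np

from transformers.testing_utils import is_pt_flax_cross_test


class ExamplePtFlaxCrossTest(unittest.TestCase):
    @is_pt_flax_cross_test
    def test_pt_and_flax_agree(self):
        # Lazy imports are safe here: the decorator only lets the test run
        # when both backends are available.
        import jax.numpy as jnp
        import torch

        pt_tensor = torch.arange(4, dtype=torch.float32)
        fx_tensor = jnp.arange(4, dtype=jnp.float32)

        # Compare through NumPy, mirroring assert_almost_equals in
        # FlaxModelTesterMixin.
        diff = np.abs(pt_tensor.numpy() - np.asarray(fx_tensor)).max()
        self.assertLessEqual(diff, 1e-5)

Under a plain pytest invocation such a test stays skipped; the CI job added above selects these tests explicitly via RUN_PT_FLAX_CROSS_TESTS=1 python -m pytest -m is_pt_flax_cross_test ./tests/.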