diff --git a/.circleci/create_circleci_config.py b/.circleci/create_circleci_config.py
index 00c213c86eb..156765d30f3 100644
--- a/.circleci/create_circleci_config.py
+++ b/.circleci/create_circleci_config.py
@@ -28,7 +28,6 @@ COMMON_ENV_VARIABLES = {
     "TRANSFORMERS_IS_CI": True,
     "PYTEST_TIMEOUT": 120,
     "RUN_PIPELINE_TESTS": False,
-    "RUN_PT_FLAX_CROSS_TESTS": False,
 }
 # Disable the use of {"s": None} as the output is way too long, causing the navigation on CircleCI impractical
 COMMON_PYTEST_OPTIONS = {"max-worker-restart": 0, "dist": "loadfile", "vvv": None, "rsfE":None}
@@ -176,14 +175,6 @@ class CircleCIJob:
 
 
 # JOBS
-torch_and_flax_job = CircleCIJob(
-    "torch_and_flax",
-    additional_env={"RUN_PT_FLAX_CROSS_TESTS": True},
-    docker_image=[{"image":"huggingface/transformers-torch-jax-light"}],
-    marker="is_pt_flax_cross_test",
-    pytest_options={"rA": None, "durations": 0},
-)
-
 torch_job = CircleCIJob(
     "torch",
     docker_image=[{"image": "huggingface/transformers-torch-light"}],
@@ -343,7 +334,7 @@ doc_test_job = CircleCIJob(
     pytest_num_workers=1,
 )
 
-REGULAR_TESTS = [torch_and_flax_job, torch_job, tf_job, flax_job, hub_job, onnx_job, tokenization_job, processor_job, generate_job, non_model_job]  # fmt: skip
+REGULAR_TESTS = [torch_job, tf_job, flax_job, hub_job, onnx_job, tokenization_job, processor_job, generate_job, non_model_job]  # fmt: skip
 EXAMPLES_TESTS = [examples_torch_job, examples_tensorflow_job]
 PIPELINE_TESTS = [pipelines_torch_job, pipelines_tf_job]
 REPO_UTIL_TESTS = [repo_utils_job]
diff --git a/.github/workflows/build-ci-docker-images.yml b/.github/workflows/build-ci-docker-images.yml
index 9d947684ee8..5606668531d 100644
--- a/.github/workflows/build-ci-docker-images.yml
+++ b/.github/workflows/build-ci-docker-images.yml
@@ -26,7 +26,7 @@ jobs:
     strategy:
       matrix:
-        file: ["quality", "consistency", "custom-tokenizers", "torch-light", "tf-light", "exotic-models", "torch-tf-light", "torch-jax-light", "jax-light", "examples-torch", "examples-tf"]
+        file: ["quality", "consistency", "custom-tokenizers", "torch-light", "tf-light", "exotic-models", "torch-tf-light", "jax-light", "examples-torch", "examples-tf"]
 
     continue-on-error: true
 
     steps:
@@ -34,11 +34,11 @@
         name: Set tag
         run: |
          if ${{contains(github.event.head_commit.message, '[build-ci-image]')}}; then
-            echo "TAG=huggingface/transformers-${{ matrix.file }}:dev" >> "$GITHUB_ENV"
+            echo "TAG=huggingface/transformers-${{ matrix.file }}:dev" >> "$GITHUB_ENV"
             echo "setting it to DEV!"
          else
             echo "TAG=huggingface/transformers-${{ matrix.file }}" >> "$GITHUB_ENV"
-            
+
          fi
 
       - name: Set up Docker Buildx
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index dee40cfbe41..c6047b0e1cc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -343,7 +343,6 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/t
 
 Like the slow tests, there are other environment variables available which are not enabled by default during testing:
 
 - `RUN_CUSTOM_TOKENIZERS`: Enables tests for custom tokenizers.
-- `RUN_PT_FLAX_CROSS_TESTS`: Enables tests for PyTorch + Flax integration.
 
 More environment variables and additional information can be found in the [testing_utils.py](https://github.com/huggingface/transformers/blob/main/src/transformers/testing_utils.py).
diff --git a/conftest.py b/conftest.py
index 6c4b1f3d2bf..ef520380443 100644
--- a/conftest.py
+++ b/conftest.py
@@ -84,9 +84,6 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
 
 
 def pytest_configure(config):
-    config.addinivalue_line(
-        "markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
-    )
     config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested")
     config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
     config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
diff --git a/docs/source/de/contributing.md b/docs/source/de/contributing.md
index 08ef100571c..61ee8c3fc4e 100644
--- a/docs/source/de/contributing.md
+++ b/docs/source/de/contributing.md
@@ -283,7 +283,6 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/t
 
 Wie bei den langsamen Tests gibt es auch andere Umgebungsvariablen, die standardmäßig beim Testen nicht gesetzt sind:
 
 * `RUN_CUSTOM_TOKENIZERS`: Aktiviert Tests für benutzerdefinierte Tokenizer.
-* `RUN_PT_FLAX_CROSS_TESTS`: Aktiviert Tests für die Integration von PyTorch + Flax.
 
 Weitere Umgebungsvariablen und zusätzliche Informationen finden Sie in der [testing_utils.py](src/transformers/testing_utils.py).
diff --git a/docs/source/ko/contributing.md b/docs/source/ko/contributing.md
index 78e2f2a2169..f1c0a84ef32 100644
--- a/docs/source/ko/contributing.md
+++ b/docs/source/ko/contributing.md
@@ -282,7 +282,6 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/t
 
 느린 테스트와 마찬가지로, 다음과 같이 테스트 중에 기본적으로 활성화되지 않는 다른 환경 변수도 있습니다:
 
 - `RUN_CUSTOM_TOKENIZERS`: 사용자 정의 토크나이저 테스트를 활성화합니다.
-- `RUN_PT_FLAX_CROSS_TESTS`: PyTorch + Flax 통합 테스트를 활성화합니다.
 
 더 많은 환경 변수와 추가 정보는 [testing_utils.py](src/transformers/testing_utils.py)에서 찾을 수 있습니다.
diff --git a/docs/source/zh/contributing.md b/docs/source/zh/contributing.md
index 215dd653ad8..045b58af086 100644
--- a/docs/source/zh/contributing.md
+++ b/docs/source/zh/contributing.md
@@ -281,7 +281,6 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/t
 
 和时间较长的测试一样，还有其他环境变量在测试过程中，在默认情况下是未启用的：
 
 - `RUN_CUSTOM_TOKENIZERS`: 启用自定义分词器的测试。
-- `RUN_PT_FLAX_CROSS_TESTS`: 启用 PyTorch + Flax 整合的测试。
 
 更多环境变量和额外信息可以在 [testing_utils.py](src/transformers/testing_utils.py) 中找到。
diff --git a/setup.py b/setup.py
index 53dac683160..9c207fb8161 100644
--- a/setup.py
+++ b/setup.py
@@ -473,7 +473,6 @@ setup(
 extras["tests_torch"] = deps_list()
 extras["tests_tf"] = deps_list()
 extras["tests_flax"] = deps_list()
-extras["tests_torch_and_flax"] = deps_list()
 extras["tests_hub"] = deps_list()
 extras["tests_pipelines_torch"] = deps_list()
 extras["tests_pipelines_tf"] = deps_list()
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 6f7808a9b1d..1d575ad4a3a 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -230,10 +230,8 @@ def parse_int_from_env(key, default=None):
 
 
 _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
-_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=True)
 _run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
 _run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
-_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
 _run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
 _run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)
 _run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
@@ -250,25 +248,6 @@ def get_device_count():
     return num_devices
 
 
-def is_pt_flax_cross_test(test_case):
-    """
-    Decorator marking a test as a test that control interactions between PyTorch and Flax
-
-    PT+FLAX tests are skipped by default and we can run only them by setting RUN_PT_FLAX_CROSS_TESTS environment
-    variable to a truthy value and selecting the is_pt_flax_cross_test pytest mark.
-
-    """
-    if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
-        return unittest.skip(reason="test is PT+FLAX test")(test_case)
-    else:
-        try:
-            import pytest  # We don't need a hard dependency on pytest in the main library
-        except ImportError:
-            return test_case
-        else:
-            return pytest.mark.is_pt_flax_cross_test()(test_case)
-
-
 def is_staging_test(test_case):
     """
     Decorator marking a test as a staging test.
diff --git a/tests/models/big_bird/test_modeling_big_bird.py b/tests/models/big_bird/test_modeling_big_bird.py
index 6aca3cbc410..ffad4459ec9 100644
--- a/tests/models/big_bird/test_modeling_big_bird.py
+++ b/tests/models/big_bird/test_modeling_big_bird.py
@@ -624,15 +624,6 @@ class BigBirdModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
-    # overwrite from common in order to skip the check on `attentions`
-    def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
-        # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
-        # an effort was done to return `attention_probs` (yet to be verified).
- if name.startswith("outputs.attentions"): - return - else: - super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes) - @require_torch @slow diff --git a/tests/models/big_bird/test_modeling_flax_big_bird.py b/tests/models/big_bird/test_modeling_flax_big_bird.py index 8beb12b8c6c..fe1790bf75c 100644 --- a/tests/models/big_bird/test_modeling_flax_big_bird.py +++ b/tests/models/big_bird/test_modeling_flax_big_bird.py @@ -212,12 +212,3 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase): self.assertEqual(len(outputs), len(jitted_outputs)) for jitted_output, output in zip(jitted_outputs, outputs): self.assertEqual(jitted_output.shape, output.shape) - - # overwrite from common in order to skip the check on `attentions` - def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): - # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version, - # an effort was done to return `attention_probs` (yet to be verified). - if name.startswith("outputs.attentions"): - return - else: - super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes) diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index 75ee9a189ad..63723bfe3b9 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -25,11 +25,8 @@ import requests from parameterized import parameterized from pytest import mark -import transformers from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig from transformers.testing_utils import ( - is_flax_available, - is_pt_flax_cross_test, require_flash_attn, require_torch, require_torch_gpu, @@ -82,15 +79,6 @@ if is_vision_available(): from transformers import CLIPProcessor -if is_flax_available(): - import jax.numpy as jnp - - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) - - class CLIPVisionModelTester: def __init__( self, @@ -883,126 +871,6 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase text_config = CLIPTextConfig.from_pretrained(tmp_dir_name) self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict()) - # overwrite from common since FlaxCLIPModel returns nested output - # which is not supported in the common test - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # load PyTorch class - pt_model = model_class(config).eval() - pt_model.to(torch_device) - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - - fx_model_class_name = "Flax" + model_class.__name__ - - if not hasattr(transformers, fx_model_class_name): - self.skipTest(reason="No Flax model exists for this class") - - fx_model_class = getattr(transformers, fx_model_class_name) - - # load Flax class - fx_model = fx_model_class(config, dtype=jnp.float32) - # make sure only flax inputs are forward that actually exist in function args - fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() - - # prepare inputs - pt_inputs = self._prepare_for_class(inputs_dict, model_class) - - # remove function args that don't exist in Flax - pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - # convert inputs to Flax - fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} - fx_outputs = fx_model(**fx_inputs).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple() - self.assertEqual( - len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" - ) - for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2) - - # overwrite from common since FlaxCLIPModel returns nested output - # which is not supported in the common test - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # load corresponding PyTorch class - pt_model = model_class(config).eval() - - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - - fx_model_class_name = "Flax" + model_class.__name__ - - if not hasattr(transformers, fx_model_class_name): - self.skipTest(reason="No Flax model exists for this class") - - fx_model_class = getattr(transformers, fx_model_class_name) - - # load Flax class - fx_model = fx_model_class(config, dtype=jnp.float32) - # make sure only flax inputs are forward that actually exist in function args - fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - pt_model.to(torch_device) - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - # prepare inputs - pt_inputs = self._prepare_for_class(inputs_dict, model_class) - - # remove function args that don't exist in Flax - pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} - - fx_outputs = fx_model(**fx_inputs).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - - for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) - pt_model_loaded.to(torch_device) - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - - self.assertEqual( - len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" - ) - for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2) - @slow def test_model_from_pretrained(self): model_name = "openai/clip-vit-base-patch32" diff --git a/tests/models/clip/test_modeling_flax_clip.py b/tests/models/clip/test_modeling_flax_clip.py index c1d05081ca5..d499f4bf7dc 100644 --- a/tests/models/clip/test_modeling_flax_clip.py +++ b/tests/models/clip/test_modeling_flax_clip.py @@ -4,21 +4,15 @@ import unittest import numpy as np -import transformers -from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig, is_flax_available, is_torch_available -from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow +from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig, is_flax_available +from transformers.testing_utils import require_flax, slow from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask if is_flax_available(): import jax - import jax.numpy as jnp - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) from transformers.models.clip.modeling_flax_clip import ( FlaxCLIPModel, FlaxCLIPTextModel, @@ -26,9 +20,6 @@ if is_flax_available(): FlaxCLIPVisionModel, ) -if is_torch_available(): - import torch - class FlaxCLIPVisionModelTester: def __init__( @@ -223,21 +214,6 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase): def test_save_load_to_base(self): pass - # FlaxCLIPVisionModel does not have any base model - @is_pt_flax_cross_test - def test_save_load_from_base_pt(self): - pass - - # FlaxCLIPVisionModel does not have any base model - @is_pt_flax_cross_test 
- def test_save_load_to_base_pt(self): - pass - - # FlaxCLIPVisionModel does not have any base model - @is_pt_flax_cross_test - def test_save_load_bf16_to_base_pt(self): - pass - @slow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: @@ -333,21 +309,6 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase): def test_save_load_to_base(self): pass - # FlaxCLIPVisionModel does not have any base model - @is_pt_flax_cross_test - def test_save_load_from_base_pt(self): - pass - - # FlaxCLIPVisionModel does not have any base model - @is_pt_flax_cross_test - def test_save_load_to_base_pt(self): - pass - - # FlaxCLIPVisionModel does not have any base model - @is_pt_flax_cross_test - def test_save_load_bf16_to_base_pt(self): - pass - @slow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: @@ -472,92 +433,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase): outputs = model(input_ids=np.ones((1, 1)), pixel_values=np.ones((1, 3, 224, 224))) self.assertIsNotNone(outputs) - # overwrite from common since FlaxCLIPModel returns nested output - # which is not supported in the common test - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - pt_model = pt_model_class(config).eval() - fx_model = model_class(config, dtype=jnp.float32) - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple() - self.assertEqual( - len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" - ) - for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) - - # overwrite from common since FlaxCLIPModel returns nested output - # which is not supported in the common test - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = 
getattr(transformers, pt_model_class_name) - - pt_model = pt_model_class(config).eval() - fx_model = model_class(config, dtype=jnp.float32) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - - self.assertEqual( - len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" - ) - for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) - # overwrite from common since FlaxCLIPModel returns nested output # which is not supported in the common test def test_from_pretrained_save_pretrained(self): diff --git a/tests/models/clipseg/test_modeling_clipseg.py b/tests/models/clipseg/test_modeling_clipseg.py index 4b712f19900..a17b2b6a4fe 100644 --- a/tests/models/clipseg/test_modeling_clipseg.py +++ b/tests/models/clipseg/test_modeling_clipseg.py @@ -22,11 +22,8 @@ import unittest import numpy as np import requests -import transformers from transformers import CLIPSegConfig, CLIPSegProcessor, CLIPSegTextConfig, CLIPSegVisionConfig from transformers.testing_utils import ( - is_flax_available, - is_pt_flax_cross_test, require_torch, require_vision, slow, @@ -57,15 +54,6 @@ if is_vision_available(): from PIL import Image -if is_flax_available(): - import jax.numpy as jnp - - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) - - class CLIPSegVisionModelTester: def __init__( self, @@ -635,123 +623,6 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) text_config = CLIPSegTextConfig.from_pretrained(tmp_dir_name) self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict()) - # overwrite from common since FlaxCLIPSegModel returns nested output - # which is not supported in the common test - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # load PyTorch class - pt_model = model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - - fx_model_class_name = "Flax" + model_class.__name__ - - if not hasattr(transformers, fx_model_class_name): - self.skipTest(reason="No Flax model exists for this class") - - fx_model_class = getattr(transformers, fx_model_class_name) - - # load Flax class - fx_model = fx_model_class(config, dtype=jnp.float32) - # make sure only flax inputs are forward that actually exist in function args - fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() - - # prepare inputs - pt_inputs = self._prepare_for_class(inputs_dict, model_class) - - # remove function args that don't exist in Flax - pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - # convert inputs to Flax - fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} - fx_outputs = fx_model(**fx_inputs).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple() - self.assertEqual( - len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" - ) - for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) - - # overwrite from common since FlaxCLIPSegModel returns nested output - # which is not supported in the common test - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # load corresponding PyTorch class - pt_model = model_class(config).eval() - - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - - fx_model_class_name = "Flax" + model_class.__name__ - - if not hasattr(transformers, fx_model_class_name): - self.skipTest(reason="No Flax model exists for this class") - - fx_model_class = getattr(transformers, fx_model_class_name) - - # load Flax class - fx_model = fx_model_class(config, dtype=jnp.float32) - # make sure only flax inputs are forward that actually exist in function args - fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - # prepare inputs - pt_inputs = self._prepare_for_class(inputs_dict, model_class) - - # remove function args that don't exist in Flax - pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} - - fx_outputs = fx_model(**fx_inputs).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - - for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - - self.assertEqual( - len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" - ) - for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) - def test_training(self): if not self.model_tester.is_training: self.skipTest(reason="Training test is skipped as the model was not trained") diff --git a/tests/models/data2vec/test_modeling_data2vec_audio.py b/tests/models/data2vec/test_modeling_data2vec_audio.py index d4312828685..4b7e24a23ed 100644 --- a/tests/models/data2vec/test_modeling_data2vec_audio.py +++ b/tests/models/data2vec/test_modeling_data2vec_audio.py @@ -22,7 +22,7 @@ from datasets import load_dataset from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from transformers import Data2VecAudioConfig, is_torch_available -from transformers.testing_utils import is_pt_flax_cross_test, require_soundfile, require_torch, slow, torch_device +from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init @@ -442,16 +442,6 @@ class Data2VecAudioModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes def test_model_get_set_embeddings(self): pass - @is_pt_flax_cross_test - # non-robust architecture does not exist in Flax - def test_equivalence_flax_to_pt(self): - pass - - @is_pt_flax_cross_test - # non-robust architecture does not exist in Flax - def test_equivalence_pt_to_flax(self): - pass - def test_retain_grad_hidden_states_attentions(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.output_hidden_states = True diff --git a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py index 35434a280e9..36298182002 100644 --- 
a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py @@ -19,8 +19,8 @@ import unittest import numpy as np -from transformers import is_flax_available, is_torch_available -from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device +from transformers import is_flax_available +from transformers.testing_utils import require_flax, slow from ...test_modeling_flax_common import ids_tensor from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester @@ -38,15 +38,6 @@ if is_flax_available(): FlaxEncoderDecoderModel, FlaxGPT2LMHeadModel, ) - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) - -if is_torch_available(): - import torch - - from transformers import EncoderDecoderModel @require_flax @@ -291,68 +282,6 @@ class FlaxEncoderDecoderMixin: generated_sequences = generated_output.sequences self.assertEqual(generated_sequences.shape, (input_ids.shape[0],) + (decoder_config.max_length,)) - def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): - pt_model.to(torch_device) - pt_model.eval() - - # prepare inputs - flax_inputs = inputs_dict - pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()} - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5) - - # PT -> Flax - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = FlaxEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5) - - # Flax -> PT - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = EncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True) - - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - - self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded): - self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5) - - def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict): - encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) - - pt_model = EncoderDecoderModel(encoder_decoder_config) - fx_model = FlaxEncoderDecoderModel(encoder_decoder_config) - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict) - - def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict): - encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) - - pt_model 
= EncoderDecoderModel(encoder_decoder_config) - fx_model = FlaxEncoderDecoderModel(encoder_decoder_config) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict) - def test_encoder_decoder_model_from_pretrained_configs(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict) @@ -385,40 +314,6 @@ class FlaxEncoderDecoderMixin: diff = np.abs((a - b)).max() self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") - @is_pt_flax_cross_test - def test_pt_flax_equivalence(self): - config_inputs_dict = self.prepare_config_and_inputs() - config = config_inputs_dict.pop("config") - decoder_config = config_inputs_dict.pop("decoder_config") - - inputs_dict = config_inputs_dict - # `encoder_hidden_states` is not used in model call/forward - del inputs_dict["encoder_hidden_states"] - - # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask) - batch_size = inputs_dict["decoder_attention_mask"].shape[0] - inputs_dict["decoder_attention_mask"] = np.concatenate( - [np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1 - ) - - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. - decoder_config.use_cache = False - - self.assertTrue(decoder_config.cross_attention_hidden_size is None) - - # check without `enc_to_dec_proj` projection - decoder_config.hidden_size = config.hidden_size - self.assertTrue(config.hidden_size == decoder_config.hidden_size) - self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict) - self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict) - - # check `enc_to_dec_proj` work as expected - decoder_config.hidden_size = decoder_config.hidden_size * 2 - self.assertTrue(config.hidden_size != decoder_config.hidden_size) - self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict) - self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict) - @slow def test_real_model_save_load_from_pretrained(self): model_2 = self.get_pretrained_model() diff --git a/tests/models/gpt2/test_modeling_flax_gpt2.py b/tests/models/gpt2/test_modeling_flax_gpt2.py index 3b80cd52bde..5beb43d90e6 100644 --- a/tests/models/gpt2/test_modeling_flax_gpt2.py +++ b/tests/models/gpt2/test_modeling_flax_gpt2.py @@ -13,14 +13,12 @@ # limitations under the License. 
-import tempfile import unittest import numpy as np -import transformers -from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available -from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow +from transformers import GPT2Config, GPT2Tokenizer, is_flax_available +from transformers.testing_utils import require_flax, slow from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask @@ -29,15 +27,8 @@ if is_flax_available(): import jax import jax.numpy as jnp - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model -if is_torch_available(): - import torch - class FlaxGPT2ModelTester: def __init__( @@ -255,105 +246,6 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase): self.assertListEqual(output_string, expected_string) - # overwrite from common since `attention_mask` in combination - # with `causal_mask` behaves slighly differently - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - batch_size, seq_length = pt_inputs["input_ids"].shape - rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,)) - for batch_idx, start_index in enumerate(rnd_start_indices): - pt_inputs["attention_mask"][batch_idx, :start_index] = 0 - pt_inputs["attention_mask"][batch_idx, start_index:] = 1 - prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 - prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 - pt_model = pt_model_class(config).eval() - fx_model = model_class(config, dtype=jnp.float32) - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple() - self.assertEqual( - len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" - ) - for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2) - - # overwrite from common since `attention_mask` in combination - # with `causal_mask` behaves slighly differently - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() 
- for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - pt_model = pt_model_class(config).eval() - fx_model = model_class(config, dtype=jnp.float32) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - batch_size, seq_length = pt_inputs["input_ids"].shape - rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,)) - for batch_idx, start_index in enumerate(rnd_start_indices): - pt_inputs["attention_mask"][batch_idx, :start_index] = 0 - pt_inputs["attention_mask"][batch_idx, start_index:] = 1 - prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 - prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - - self.assertEqual( - len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" - ) - for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded): - self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) - @slow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: diff --git a/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py b/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py index 6875a46299f..170a261b412 100644 --- a/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py +++ b/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py @@ -13,14 +13,12 @@ # limitations under the License. 
-import tempfile import unittest import numpy as np -import transformers -from transformers import GPT2Tokenizer, GPTNeoConfig, is_flax_available, is_torch_available -from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow +from transformers import GPT2Tokenizer, GPTNeoConfig, is_flax_available +from transformers.testing_utils import require_flax, slow from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask @@ -29,15 +27,8 @@ if is_flax_available(): import jax import jax.numpy as jnp - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) from transformers.models.gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel -if is_torch_available(): - import torch - class FlaxGPTNeoModelTester: def __init__( @@ -224,105 +215,6 @@ class FlaxGPTNeoModelTest(FlaxModelTesterMixin, unittest.TestCase): self.assertListEqual(output_string, expected_string) - # overwrite from common since `attention_mask` in combination - # with `causal_mask` behaves slighly differently - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - batch_size, seq_length = pt_inputs["input_ids"].shape - rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,)) - for batch_idx, start_index in enumerate(rnd_start_indices): - pt_inputs["attention_mask"][batch_idx, :start_index] = 0 - pt_inputs["attention_mask"][batch_idx, start_index:] = 1 - prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 - prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 - pt_model = pt_model_class(config).eval() - fx_model = model_class(config, dtype=jnp.float32) - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple() - self.assertEqual( - len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" - ) - for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2) - - # overwrite from common since `attention_mask` in combination - # with `causal_mask` behaves slighly differently - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - pt_model = pt_model_class(config).eval() - fx_model = model_class(config, dtype=jnp.float32) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - batch_size, seq_length = pt_inputs["input_ids"].shape - rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,)) - for batch_idx, start_index in enumerate(rnd_start_indices): - pt_inputs["attention_mask"][batch_idx, :start_index] = 0 - pt_inputs["attention_mask"][batch_idx, start_index:] = 1 - prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 - prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - - self.assertEqual( - len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" - ) - for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded): - self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) - @slow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: diff --git a/tests/models/gptj/test_modeling_flax_gptj.py b/tests/models/gptj/test_modeling_flax_gptj.py index 09f2aa99d7e..305a86ece1e 100644 --- a/tests/models/gptj/test_modeling_flax_gptj.py +++ b/tests/models/gptj/test_modeling_flax_gptj.py @@ -13,14 +13,12 @@ # limitations under the License. 
-import tempfile import unittest import numpy as np -import transformers -from transformers import GPT2Tokenizer, GPTJConfig, is_flax_available, is_torch_available -from transformers.testing_utils import is_pt_flax_cross_test, require_flax, tooslow +from transformers import GPT2Tokenizer, GPTJConfig, is_flax_available +from transformers.testing_utils import require_flax, tooslow from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask @@ -29,15 +27,8 @@ if is_flax_available(): import jax import jax.numpy as jnp - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) from transformers.models.gptj.modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel -if is_torch_available(): - import torch - class FlaxGPTJModelTester: def __init__( @@ -221,105 +212,6 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, unittest.TestCase): self.assertListEqual(output_string, expected_string) - # overwrite from common since `attention_mask` in combination - # with `causal_mask` behaves slighly differently - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - batch_size, seq_length = pt_inputs["input_ids"].shape - rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,)) - for batch_idx, start_index in enumerate(rnd_start_indices): - pt_inputs["attention_mask"][batch_idx, :start_index] = 0 - pt_inputs["attention_mask"][batch_idx, start_index:] = 1 - prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 - prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 - pt_model = pt_model_class(config).eval() - fx_model = model_class(config, dtype=jnp.float32) - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple() - self.assertEqual( - len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" - ) - for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2) - - # overwrite from common since `attention_mask` in combination - # with `causal_mask` behaves slighly differently - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for 
model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - pt_model = pt_model_class(config).eval() - fx_model = model_class(config, dtype=jnp.float32) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - batch_size, seq_length = pt_inputs["input_ids"].shape - rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,)) - for batch_idx, start_index in enumerate(rnd_start_indices): - pt_inputs["attention_mask"][batch_idx, :start_index] = 0 - pt_inputs["attention_mask"][batch_idx, start_index:] = 1 - prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 - prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - - self.assertEqual( - len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" - ) - for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded): - self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) - @tooslow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: diff --git a/tests/models/longt5/test_modeling_flax_longt5.py b/tests/models/longt5/test_modeling_flax_longt5.py index fa8673ec439..ec5432f9efd 100644 --- a/tests/models/longt5/test_modeling_flax_longt5.py +++ b/tests/models/longt5/test_modeling_flax_longt5.py @@ -17,11 +17,9 @@ import unittest import numpy as np -import transformers from transformers import is_flax_available from transformers.models.auto import get_values from transformers.testing_utils import ( - is_pt_flax_cross_test, require_flax, require_sentencepiece, require_tokenizers, @@ -46,7 +44,6 @@ if is_flax_available(): from flax.traverse_util import flatten_dict from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING, AutoTokenizer, LongT5Config - from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model from transformers.models.longt5.modeling_flax_longt5 import ( FlaxLongT5ForConditionalGeneration, FlaxLongT5Model, @@ -467,95 +464,6 @@ class FlaxLongT5ModelTest(FlaxModelTesterMixin, unittest.TestCase): [self.model_tester.num_attention_heads, block_len, 3 * block_len], ) - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_from_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if 
model_class == base_class: - continue - - model = base_class(config) - base_params = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - # save pt model - pt_model.save_pretrained(tmpdirname) - head_model = model_class.from_pretrained(tmpdirname, from_pt=True) - - base_param_from_head = flatten_dict(unfreeze(head_model.params)) - - for key in base_param_from_head.keys(): - max_diff = (base_params[key] - base_param_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - base_params_from_head = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_bf16_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - model.params = model.to_bf16(model.params) - base_params_from_head = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - class FlaxLongT5TGlobalModelTest(FlaxLongT5ModelTest): def setUp(self): diff --git a/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py index 8f210a07d27..5348315c7c8 100644 --- 
a/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py +++ b/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py @@ -19,7 +19,7 @@ import unittest import numpy as np from transformers import is_flax_available, is_torch_available -from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device +from transformers.testing_utils import require_flax, slow from ...test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester @@ -43,14 +43,8 @@ if is_flax_available(): SpeechEncoderDecoderConfig, ) from transformers.modeling_flax_outputs import FlaxBaseModelOutput - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) if is_torch_available(): - import torch - from transformers import SpeechEncoderDecoderModel @@ -406,68 +400,6 @@ class FlaxEncoderDecoderMixin: for grad, grad_frozen in zip(grads, grads_frozen): self.assertTrue((grad == grad_frozen).all()) - def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): - pt_model.to(torch_device) - pt_model.eval() - - # prepare inputs - flax_inputs = inputs_dict - pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()} - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5) - - # PT -> Flax - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5) - - # Flax -> PT - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True) - - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - - self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded): - self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5) - - def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict): - encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) - - pt_model = SpeechEncoderDecoderModel(encoder_decoder_config) - fx_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config) - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict) - - def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict): - encoder_decoder_config = 
SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) - - pt_model = SpeechEncoderDecoderModel(encoder_decoder_config) - fx_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict) - def test_encoder_decoder_model_from_pretrained_configs(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict) @@ -504,46 +436,6 @@ class FlaxEncoderDecoderMixin: diff = np.abs((a - b)).max() self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") - @is_pt_flax_cross_test - def test_pt_flax_equivalence(self): - config_inputs_dict = self.prepare_config_and_inputs() - config = config_inputs_dict.pop("config") - decoder_config = config_inputs_dict.pop("decoder_config") - - inputs_dict = config_inputs_dict - # `encoder_hidden_states` is not used in model call/forward - del inputs_dict["encoder_hidden_states"] - - # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask) - batch_size = inputs_dict["decoder_attention_mask"].shape[0] - inputs_dict["decoder_attention_mask"] = np.concatenate( - [np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1 - ) - - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. - decoder_config.use_cache = False - - self.assertTrue(decoder_config.cross_attention_hidden_size is None) - - # check without `enc_to_dec_proj` projection - decoder_config.hidden_size = config.hidden_size - self.assertTrue(config.hidden_size == decoder_config.hidden_size) - self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict) - self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict) - - # check `enc_to_dec_proj` work as expected - decoder_config.hidden_size = decoder_config.hidden_size * 2 - self.assertTrue(config.hidden_size != decoder_config.hidden_size) - self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict) - self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict) - - # check `add_adapter` works as expected - config.add_adapter = True - self.assertTrue(config.add_adapter) - self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict) - self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict) - @slow def test_real_model_save_load_from_pretrained(self): model_2 = self.get_pretrained_model() @@ -625,71 +517,6 @@ class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): "encoder_hidden_states": encoder_hidden_states, } - @slow - def test_flaxwav2vec2gpt2_pt_flax_equivalence(self): - pt_model = SpeechEncoderDecoderModel.from_pretrained("jsnfly/wav2vec2-large-xlsr-53-german-gpt2") - fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained( - "jsnfly/wav2vec2-large-xlsr-53-german-gpt2", from_pt=True - ) - - pt_model.to(torch_device) - pt_model.eval() - - # prepare inputs - batch_size = 13 - input_values = floats_tensor([batch_size, 512], scale=1.0) - attention_mask = random_attention_mask([batch_size, 512]) - decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size) - decoder_attention_mask = random_attention_mask([batch_size, 4]) - inputs_dict = { - "inputs": input_values, - "attention_mask": attention_mask, - "decoder_input_ids": 
decoder_input_ids, - "decoder_attention_mask": decoder_attention_mask, - } - - flax_inputs = inputs_dict - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - pt_logits = pt_outputs.logits - pt_outputs = pt_outputs.to_tuple() - - fx_outputs = fx_model(**inputs_dict) - fx_logits = fx_outputs.logits - fx_outputs = fx_outputs.to_tuple() - - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2) - - # PT -> Flax - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**inputs_dict) - fx_logits_loaded = fx_outputs_loaded.logits - fx_outputs_loaded = fx_outputs_loaded.to_tuple() - self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2) - - # Flax -> PT - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True) - - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs) - pt_logits_loaded = pt_outputs_loaded.logits - pt_outputs_loaded = pt_outputs_loaded.to_tuple() - - self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") - self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2) - @require_flax class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): @@ -742,71 +569,6 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): "encoder_hidden_states": encoder_hidden_states, } - @slow - def test_flaxwav2vec2bart_pt_flax_equivalence(self): - pt_model = SpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large") - fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained( - "patrickvonplaten/wav2vec2-2-bart-large", from_pt=True - ) - - pt_model.to(torch_device) - pt_model.eval() - - # prepare inputs - batch_size = 13 - input_values = floats_tensor([batch_size, 512], scale=1.0) - attention_mask = random_attention_mask([batch_size, 512]) - decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size) - decoder_attention_mask = random_attention_mask([batch_size, 4]) - inputs_dict = { - "inputs": input_values, - "attention_mask": attention_mask, - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": decoder_attention_mask, - } - - flax_inputs = inputs_dict - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - pt_logits = pt_outputs.logits - pt_outputs = pt_outputs.to_tuple() - - fx_outputs = fx_model(**inputs_dict) - fx_logits = fx_outputs.logits - fx_outputs = fx_outputs.to_tuple() - - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2) - - # PT -> Flax - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = 
fx_model_loaded(**inputs_dict) - fx_logits_loaded = fx_outputs_loaded.logits - fx_outputs_loaded = fx_outputs_loaded.to_tuple() - self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2) - - # Flax -> PT - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True) - - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs) - pt_logits_loaded = pt_outputs_loaded.logits - pt_outputs_loaded = pt_outputs_loaded.to_tuple() - - self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") - self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2) - @require_flax class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): @@ -858,66 +620,3 @@ class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): "decoder_attention_mask": decoder_attention_mask, "encoder_hidden_states": encoder_hidden_states, } - - @slow - def test_flaxwav2vec2bert_pt_flax_equivalence(self): - pt_model = SpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large") - fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large", from_pt=True) - - pt_model.to(torch_device) - pt_model.eval() - - # prepare inputs - batch_size = 13 - input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size) - attention_mask = random_attention_mask([batch_size, 512]) - decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size) - decoder_attention_mask = random_attention_mask([batch_size, 4]) - inputs_dict = { - "inputs": input_values, - "attention_mask": attention_mask, - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": decoder_attention_mask, - } - - flax_inputs = inputs_dict - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - pt_logits = pt_outputs.logits - pt_outputs = pt_outputs.to_tuple() - - fx_outputs = fx_model(**inputs_dict) - fx_logits = fx_outputs.logits - fx_outputs = fx_outputs.to_tuple() - - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2) - - # PT -> Flax - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**inputs_dict) - fx_logits_loaded = fx_outputs_loaded.logits - fx_outputs_loaded = fx_outputs_loaded.to_tuple() - self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2) - - # Flax -> PT - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True) - - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs) - pt_logits_loaded = pt_outputs_loaded.logits - pt_outputs_loaded = pt_outputs_loaded.to_tuple() - - 
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") - self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2) diff --git a/tests/models/t5/test_modeling_flax_t5.py b/tests/models/t5/test_modeling_flax_t5.py index 963bf91716d..dc372c57694 100644 --- a/tests/models/t5/test_modeling_flax_t5.py +++ b/tests/models/t5/test_modeling_flax_t5.py @@ -17,15 +17,8 @@ import unittest import numpy as np -import transformers from transformers import is_flax_available -from transformers.testing_utils import ( - is_pt_flax_cross_test, - require_flax, - require_sentencepiece, - require_tokenizers, - slow, -) +from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow from ...test_configuration_common import ConfigTester from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor @@ -47,7 +40,6 @@ if is_flax_available(): from flax.traverse_util import flatten_dict from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer - from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model from transformers.models.t5.modeling_flax_t5 import ( FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, @@ -373,95 +365,6 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, unittest.TestCase): max_diff = (base_params[key] - base_params_from_head[key]).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_from_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = base_class(config) - base_params = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - # save pt model - pt_model.save_pretrained(tmpdirname) - head_model = model_class.from_pretrained(tmpdirname, from_pt=True) - - base_param_from_head = flatten_dict(unfreeze(head_model.params)) - - for key in base_param_from_head.keys(): - max_diff = (base_params[key] - base_param_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - base_params_from_head = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, 
from_pt=True) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_bf16_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - model.params = model.to_bf16(model.params) - base_params_from_head = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - class FlaxT5EncoderOnlyModelTester: def __init__( @@ -663,95 +566,6 @@ class FlaxT5EncoderOnlyModelTest(FlaxModelTesterMixin, unittest.TestCase): max_diff = (base_params[key] - base_params_from_head[key]).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_from_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = base_class(config) - base_params = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - # save pt model - pt_model.save_pretrained(tmpdirname) - head_model = model_class.from_pretrained(tmpdirname, from_pt=True) - - base_param_from_head = flatten_dict(unfreeze(head_model.params)) - - for key in base_param_from_head.keys(): - max_diff = (base_params[key] - base_param_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - base_params_from_head = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = 
load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_bf16_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - model.params = model.to_bf16(model.params) - base_params_from_head = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - @require_sentencepiece @require_tokenizers diff --git a/tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py index fabef4b8c6d..bec79869ae3 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py @@ -19,8 +19,8 @@ import unittest import numpy as np -from transformers import is_flax_available, is_torch_available, is_vision_available -from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_vision, slow, torch_device +from transformers import is_flax_available, is_vision_available +from transformers.testing_utils import require_flax, require_vision, slow from ...test_modeling_flax_common import floats_tensor, ids_tensor from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester @@ -35,15 +35,7 @@ if is_flax_available(): FlaxViTModel, VisionEncoderDecoderConfig, ) - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) -if is_torch_available(): - import torch - - from transformers import VisionEncoderDecoderModel if is_vision_available(): from PIL import Image @@ -235,68 +227,6 @@ class FlaxEncoderDecoderMixin: generated_sequences = generated_output.sequences self.assertEqual(generated_sequences.shape, (pixel_values.shape[0],) + (decoder_config.max_length,)) - def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): - pt_model.to(torch_device) - pt_model.eval() - - # prepare inputs - flax_inputs = inputs_dict - pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()} - - with torch.no_grad(): - pt_outputs = 
pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5) - - # PT -> Flax - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = FlaxVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5) - - # Flax -> PT - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True) - - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - - self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded): - self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5) - - def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict): - encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) - - pt_model = VisionEncoderDecoderModel(encoder_decoder_config) - fx_model = FlaxVisionEncoderDecoderModel(encoder_decoder_config) - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict) - - def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict): - encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) - - pt_model = VisionEncoderDecoderModel(encoder_decoder_config) - fx_model = FlaxVisionEncoderDecoderModel(encoder_decoder_config) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict) - def test_encoder_decoder_model_from_pretrained_configs(self): config_inputs_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_from_pretrained_configs(**config_inputs_dict) @@ -325,39 +255,6 @@ class FlaxEncoderDecoderMixin: diff = np.abs((a - b)).max() self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") - @is_pt_flax_cross_test - def test_pt_flax_equivalence(self): - config_inputs_dict = self.prepare_config_and_inputs() - config = config_inputs_dict.pop("config") - decoder_config = config_inputs_dict.pop("decoder_config") - - inputs_dict = config_inputs_dict - # `encoder_hidden_states` is not used in model call/forward - del inputs_dict["encoder_hidden_states"] - - # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask) - batch_size = inputs_dict["decoder_attention_mask"].shape[0] - inputs_dict["decoder_attention_mask"] = np.concatenate( - [np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1 - ) - - # Flax models don't use the `use_cache` option and cache is not 
returned as a default. - # So we disable `use_cache` here for PyTorch model. - decoder_config.use_cache = False - - self.assertTrue(decoder_config.cross_attention_hidden_size is None) - - # check without `enc_to_dec_proj` projection - self.assertTrue(config.hidden_size == decoder_config.hidden_size) - self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict) - self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict) - - # check `enc_to_dec_proj` work as expected - decoder_config.hidden_size = decoder_config.hidden_size * 2 - self.assertTrue(config.hidden_size != decoder_config.hidden_size) - self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict) - self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict) - @slow def test_real_model_save_load_from_pretrained(self): model_2 = self.get_pretrained_model() diff --git a/tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py b/tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py index e1e8eb4076c..115cdf444fe 100644 --- a/tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py +++ b/tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py @@ -20,15 +20,8 @@ import unittest import numpy as np -from transformers.testing_utils import ( - is_pt_flax_cross_test, - require_flax, - require_torch, - require_vision, - slow, - torch_device, -) -from transformers.utils import is_flax_available, is_torch_available, is_vision_available +from transformers.testing_utils import require_flax, require_torch, require_vision, slow +from transformers.utils import is_flax_available, is_vision_available from ...test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask from ..bert.test_modeling_flax_bert import FlaxBertModelTester @@ -45,17 +38,8 @@ if is_flax_available(): VisionTextDualEncoderConfig, VisionTextDualEncoderProcessor, ) - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) -if is_torch_available(): - import torch - - from transformers import VisionTextDualEncoderModel - if is_vision_available(): from PIL import Image @@ -154,68 +138,6 @@ class VisionTextDualEncoderMixin: (text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]), ) - def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): - pt_model.to(torch_device) - pt_model.eval() - - # prepare inputs - flax_inputs = inputs_dict - pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()} - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2) - - # PT -> Flax - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output_loaded, 
pt_output.numpy(force=True), 4e-2) - - # Flax -> PT - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(tmpdirname, from_flax=True) - - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - - self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]): - self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2) - - def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict): - config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config) - - pt_model = VisionTextDualEncoderModel(config) - fx_model = FlaxVisionTextDualEncoderModel(config) - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict) - - def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict): - config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config) - - pt_model = VisionTextDualEncoderModel(config) - fx_model = FlaxVisionTextDualEncoderModel(config) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict) - def test_model_from_pretrained_configs(self): inputs_dict = self.prepare_config_and_inputs() self.check_model_from_pretrained_configs(**inputs_dict) @@ -232,17 +154,6 @@ class VisionTextDualEncoderMixin: inputs_dict = self.prepare_config_and_inputs() self.check_vision_text_output_attention(**inputs_dict) - @is_pt_flax_cross_test - def test_pt_flax_equivalence(self): - config_inputs_dict = self.prepare_config_and_inputs() - vision_config = config_inputs_dict.pop("vision_config") - text_config = config_inputs_dict.pop("text_config") - - inputs_dict = config_inputs_dict - - self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict) - self.check_equivalence_flax_to_pt(vision_config, text_config, inputs_dict) - @slow def test_real_model_save_load_from_pretrained(self): model_2, inputs = self.get_pretrained_model_and_inputs() diff --git a/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py b/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py index ab4adeb5d46..ab10f3e93fa 100644 --- a/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py +++ b/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py @@ -20,8 +20,8 @@ import unittest import numpy as np -from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_vision, slow, torch_device -from transformers.utils import is_flax_available, is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import is_torch_available, is_vision_available from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from ..bert.test_modeling_bert import BertModelTester @@ -44,12 +44,6 @@ if is_torch_available(): ViTModel, ) -if is_flax_available(): - from transformers import FlaxVisionTextDualEncoderModel - from transformers.modeling_flax_pytorch_utils import ( - 
convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) if is_vision_available(): from PIL import Image @@ -172,69 +166,6 @@ class VisionTextDualEncoderMixin: diff = np.abs((a - b)).max() self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") - def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mask, pixel_values, **kwargs): - pt_model.to(torch_device) - pt_model.eval() - - # prepare inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values} - pt_inputs = inputs_dict - flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()} - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**flax_inputs).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2) - - # PT -> Flax - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple() - self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2) - - # Flax -> PT - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(tmpdirname, from_flax=True) - - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - - self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]): - self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2) - - def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict): - config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config) - - pt_model = VisionTextDualEncoderModel(config) - fx_model = FlaxVisionTextDualEncoderModel(config) - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict) - - def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict): - config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config) - - pt_model = VisionTextDualEncoderModel(config) - fx_model = FlaxVisionTextDualEncoderModel(config) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict) - def test_vision_text_dual_encoder_model(self): inputs_dict = self.prepare_config_and_inputs() self.check_vision_text_dual_encoder_model(**inputs_dict) @@ -255,17 +186,6 @@ class VisionTextDualEncoderMixin: inputs_dict = self.prepare_config_and_inputs() self.check_vision_text_output_attention(**inputs_dict) - @is_pt_flax_cross_test - def test_pt_flax_equivalence(self): - config_inputs_dict = self.prepare_config_and_inputs() - vision_config = 
config_inputs_dict.pop("vision_config") - text_config = config_inputs_dict.pop("text_config") - - inputs_dict = config_inputs_dict - - self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict) - self.check_equivalence_flax_to_pt(vision_config, text_config, inputs_dict) - @slow def test_real_model_save_load_from_pretrained(self): model_2, inputs = self.get_pretrained_model_and_inputs() @@ -429,10 +349,6 @@ class DeiTRobertaModelTest(VisionTextDualEncoderMixin, unittest.TestCase): "text_choice_labels": choice_labels, } - @unittest.skip(reason="DeiT is not available in Flax") - def test_pt_flax_equivalence(self): - pass - @require_torch class CLIPVisionBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase): diff --git a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py index b91d66654de..e888cc5ff3c 100644 --- a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py @@ -24,9 +24,7 @@ from datasets import load_dataset from transformers import Wav2Vec2Config, is_flax_available from transformers.testing_utils import ( CaptureLogger, - is_flaky, is_librosa_available, - is_pt_flax_cross_test, is_pyctcdecode_available, require_flax, require_librosa, @@ -350,11 +348,6 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase): outputs = model(np.ones((1, 1024), dtype="f4")) self.assertIsNotNone(outputs) - @is_pt_flax_cross_test - @is_flaky() - def test_equivalence_pt_to_flax(self): - super().test_equivalence_pt_to_flax() - @require_flax class FlaxWav2Vec2UtilsTest(unittest.TestCase): diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index 10ca9a22e43..e71c2d677a7 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -31,7 +31,6 @@ from transformers.testing_utils import ( CaptureLogger, cleanup, is_flaky, - is_pt_flax_cross_test, is_pyctcdecode_available, is_torchaudio_available, require_flash_attn, @@ -569,16 +568,6 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase def test_model_get_set_embeddings(self): pass - @is_pt_flax_cross_test - @unittest.skip(reason="Non-rubst architecture does not exist in Flax") - def test_equivalence_flax_to_pt(self): - pass - - @is_pt_flax_cross_test - @unittest.skip(reason="Non-rubst architecture does not exist in Flax") - def test_equivalence_pt_to_flax(self): - pass - def test_retain_grad_hidden_states_attentions(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.output_hidden_states = True diff --git a/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py b/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py index 0fbc18165eb..fc563df8f19 100644 --- a/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py +++ b/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py @@ -21,7 +21,6 @@ from datasets import load_dataset from transformers import Wav2Vec2BertConfig, is_torch_available from transformers.testing_utils import ( - is_pt_flax_cross_test, require_torch, require_torch_accelerator, require_torch_fp16, @@ -559,18 +558,6 @@ class Wav2Vec2BertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test def test_model_get_set_embeddings(self): pass - # Ignore copy - @unittest.skip(reason="non-robust architecture does not exist in Flax") - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - 
pass - - # Ignore copy - @unittest.skip(reason="non-robust architecture does not exist in Flax") - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - pass - def test_retain_grad_hidden_states_attentions(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.output_hidden_states = True diff --git a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py index 2f1e5a8e341..0c406bfbc82 100644 --- a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py +++ b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py @@ -24,7 +24,6 @@ from datasets import load_dataset from transformers import Wav2Vec2ConformerConfig, is_torch_available from transformers.testing_utils import ( is_flaky, - is_pt_flax_cross_test, require_torch, require_torch_accelerator, require_torch_fp16, @@ -535,16 +534,6 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest def test_model_get_set_embeddings(self): pass - @is_pt_flax_cross_test - @unittest.skip(reason="Non-robust architecture does not exist in Flax") - def test_equivalence_flax_to_pt(self): - pass - - @is_pt_flax_cross_test - @unittest.skip(reason="Non-robust architecture does not exist in Flax") - def test_equivalence_pt_to_flax(self): - pass - def test_retain_grad_hidden_states_attentions(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.output_hidden_states = True diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index f018d0d4198..274582be233 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -17,9 +17,8 @@ import inspect import tempfile import unittest -import transformers from transformers import WhisperConfig, is_flax_available -from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow +from transformers.testing_utils import require_flax, slow from transformers.utils import cached_property from transformers.utils.import_utils import is_datasets_available @@ -45,7 +44,6 @@ if is_flax_available(): WhisperFeatureExtractor, WhisperProcessor, ) - from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init @@ -245,99 +243,6 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase): for jitted_output, output in zip(jitted_outputs, outputs): self.assertEqual(jitted_output.shape, output.shape) - def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None): - # We override with a slightly higher tol value, as test recently became flaky - super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes) - - # overwrite because of `input_features` - @is_pt_flax_cross_test - def test_save_load_bf16_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) - - for model_class in self.all_model_classes: - if model_class.__name__ == base_class.__name__: - continue - - model = model_class(config) - model.params = model.to_bf16(model.params) - base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix])) - - # convert Flax model to 
PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite because of `input_features` - @is_pt_flax_cross_test - def test_save_load_from_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) - - for model_class in self.all_model_classes: - if model_class.__name__ == base_class.__name__: - continue - - model = base_class(config) - base_params = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - # save pt model - pt_model.save_pretrained(tmpdirname) - head_model = model_class.from_pretrained(tmpdirname, from_pt=True) - - base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix])) - - for key in base_param_from_head.keys(): - max_diff = (base_params[key] - base_param_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite because of `input_features` - @is_pt_flax_cross_test - def test_save_load_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) - - for model_class in self.all_model_classes: - if model_class.__name__ == base_class.__name__: - continue - - model = model_class(config) - base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix])) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - # overwrite because of `input_features` def test_save_load_from_base(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -899,18 +804,3 @@ class WhisperEncoderModelTest(FlaxModelTesterMixin, unittest.TestCase): # WhisperEncoder does not have any base model def test_save_load_from_base(self): pass - - # WhisperEncoder does not have any base 
model - @is_pt_flax_cross_test - def test_save_load_from_base_pt(self): - pass - - # WhisperEncoder does not have any base model - @is_pt_flax_cross_test - def test_save_load_to_base_pt(self): - pass - - # WhisperEncoder does not have any base model - @is_pt_flax_cross_test - def test_save_load_bf16_to_base_pt(self): - pass diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index c5c1ba971d3..1a38b5b225f 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -32,7 +32,6 @@ import transformers from transformers import WhisperConfig from transformers.testing_utils import ( is_flaky, - is_pt_flax_cross_test, require_flash_attn, require_non_xpu, require_torch, @@ -44,7 +43,7 @@ from transformers.testing_utils import ( slow, torch_device, ) -from transformers.utils import cached_property, is_flax_available, is_torch_available, is_torchaudio_available +from transformers.utils import cached_property, is_torch_available, is_torchaudio_available from transformers.utils.import_utils import is_datasets_available from ...generation.test_utils import GenerationTesterMixin @@ -155,15 +154,6 @@ if is_torchaudio_available(): import torchaudio -if is_flax_available(): - import jax.numpy as jnp - - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) - - def prepare_whisper_inputs_dict( config, input_features, @@ -1069,161 +1059,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi self.assertTrue(models_equal) - def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None): - # We override with a slightly higher tol value, as test recently became flaky - super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes) - - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - init_shape = (1,) + inputs_dict["input_features"].shape[1:] - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - fx_model_class_name = "Flax" + model_class.__name__ - - if not hasattr(transformers, fx_model_class_name): - self.skipTest(reason="No Flax model exists for this class") - - # Output all for aggressive testing - config.output_hidden_states = True - config.output_attentions = self.has_attentions - - fx_model_class = getattr(transformers, fx_model_class_name) - - # load PyTorch class - pt_model = model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
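# Illustrative annotation, not part of the original test: if `use_cache` stayed True here,
# the PyTorch forward pass would additionally return `past_key_values`, which has no Flax
# counterpart in this test, so roughly:
#   pt_keys would contain "past_key_values" while fx_keys would not,
#   and the `self.assertEqual(fx_keys, pt_keys)` check further below would fail.
# Turning the cache off keeps the PyTorch and Flax output structures aligned before comparison.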
- pt_model.config.use_cache = False - - # load Flax class - fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32) - - # make sure only flax inputs are forward that actually exist in function args - fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() - - # prepare inputs - pt_inputs = self._prepare_for_class(inputs_dict, model_class) - - # remove function args that don't exist in Flax - pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} - - # send pytorch inputs to the correct device - pt_inputs = { - k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() - } - - # convert inputs to Flax - fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - # send pytorch model to the correct device - pt_model.to(torch_device) - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - fx_outputs = fx_model(**fx_inputs) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**fx_inputs) - - fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) - - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - init_shape = (1,) + inputs_dict["input_features"].shape[1:] - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - fx_model_class_name = "Flax" + model_class.__name__ - - if not hasattr(transformers, fx_model_class_name): - self.skipTest(reason="No Flax model exists for this class") - - # Output all for aggressive testing - config.output_hidden_states = True - config.output_attentions = self.has_attentions - - fx_model_class = getattr(transformers, fx_model_class_name) - - # load PyTorch class - pt_model = model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - - # load Flax class - fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32) - - # make sure only flax inputs are forward that actually exist in function args - fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() - - # prepare inputs - pt_inputs = self._prepare_for_class(inputs_dict, model_class) - - # remove function args that don't exist in Flax - pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} - - # send pytorch inputs to the correct device - pt_inputs = { - k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() - } - - # convert inputs to Flax - fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - # send pytorch model to the correct device - pt_model.to(torch_device) - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - fx_outputs = fx_model(**fx_inputs) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) - - # send pytorch model to the correct device - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) - def test_mask_feature_prob(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config.mask_feature_prob = 0.2 @@ -3622,157 +3457,6 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest. def test_resize_tokens_embeddings(self): pass - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - init_shape = (1,) + inputs_dict["input_features"].shape[1:] - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - fx_model_class_name = "Flax" + model_class.__name__ - - if not hasattr(transformers, fx_model_class_name): - self.skipTest(reason="Flax model does not exist") - - # Output all for aggressive testing - config.output_hidden_states = True - config.output_attentions = self.has_attentions - - fx_model_class = getattr(transformers, fx_model_class_name) - - # load PyTorch class - pt_model = model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - - # load Flax class - fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32) - - # make sure only flax inputs are forward that actually exist in function args - fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() - - # prepare inputs - pt_inputs = self._prepare_for_class(inputs_dict, model_class) - - # remove function args that don't exist in Flax - pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} - - # send pytorch inputs to the correct device - pt_inputs = { - k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() - } - - # convert inputs to Flax - fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - # send pytorch model to the correct device - pt_model.to(torch_device) - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - fx_outputs = fx_model(**fx_inputs) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**fx_inputs) - - fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) - - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - init_shape = (1,) + inputs_dict["input_features"].shape[1:] - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - fx_model_class_name = "Flax" + model_class.__name__ - - if not hasattr(transformers, fx_model_class_name): - self.skipTest("Flax model does not exist") - - # Output all for aggressive testing - config.output_hidden_states = True - config.output_attentions = self.has_attentions - - fx_model_class = getattr(transformers, fx_model_class_name) - - # load PyTorch class - pt_model = model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - - # load Flax class - fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32) - - # make sure only flax inputs are forward that actually exist in function args - fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() - - # prepare inputs - pt_inputs = self._prepare_for_class(inputs_dict, model_class) - - # remove function args that don't exist in Flax - pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} - - # send pytorch inputs to the correct device - pt_inputs = { - k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() - } - - # convert inputs to Flax - fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - # send pytorch model to the correct device - pt_model.to(torch_device) - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - fx_outputs = fx_model(**fx_inputs) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) - - # send pytorch model to the correct device - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) - class WhisperStandaloneDecoderModelTester: def __init__( diff --git a/tests/models/xglm/test_modeling_flax_xglm.py b/tests/models/xglm/test_modeling_flax_xglm.py index 8dcdb8ae073..0eaf5d46af9 100644 --- a/tests/models/xglm/test_modeling_flax_xglm.py +++ b/tests/models/xglm/test_modeling_flax_xglm.py @@ -14,12 +14,10 @@ # limitations under the License. 
-import tempfile import unittest -import transformers -from transformers import XGLMConfig, XGLMTokenizer, is_flax_available, is_torch_available -from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_sentencepiece, slow +from transformers import XGLMConfig, XGLMTokenizer, is_flax_available +from transformers.testing_utils import require_flax, require_sentencepiece, slow from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask @@ -29,17 +27,9 @@ if is_flax_available(): import jax.numpy as jnp import numpy as np - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) from transformers.models.xglm.modeling_flax_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel -if is_torch_available(): - import torch - - @require_flax class FlaxXGLMModelTester: def __init__( @@ -220,108 +210,6 @@ class FlaxXGLMModelTest(FlaxModelTesterMixin, unittest.TestCase): self.assertListEqual(output_string, expected_string) - # overwrite from common since `attention_mask` in combination - # with `causal_mask` behaves slighly differently - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - batch_size, seq_length = pt_inputs["input_ids"].shape - rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,)) - for batch_idx, start_index in enumerate(rnd_start_indices): - pt_inputs["attention_mask"][batch_idx, :start_index] = 0 - pt_inputs["attention_mask"][batch_idx, start_index:] = 1 - prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 - prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 - pt_model = pt_model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - fx_model = model_class(config, dtype=jnp.float32) - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple() - self.assertEqual( - len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" - ) - for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2) - - # overwrite from common since `attention_mask` in combination - # with `causal_mask` behaves slighly differently - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - pt_model = pt_model_class(config).eval() - pt_model.config.use_cache = False - fx_model = model_class(config, dtype=jnp.float32) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - batch_size, seq_length = pt_inputs["input_ids"].shape - rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,)) - for batch_idx, start_index in enumerate(rnd_start_indices): - pt_inputs["attention_mask"][batch_idx, :start_index] = 0 - pt_inputs["attention_mask"][batch_idx, start_index:] = 1 - prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 - prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() - - fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - - self.assertEqual( - len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" - ) - for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded): - self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) - @slow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: diff --git 
a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f1f3eaeebda..8fc7a409a45 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -32,7 +32,6 @@ from packaging import version from parameterized import parameterized from pytest import mark -import transformers from transformers import ( AutoModel, AutoModelForCausalLM, @@ -75,7 +74,6 @@ from transformers.models.auto.modeling_auto import ( from transformers.testing_utils import ( CaptureLogger, is_flaky, - is_pt_flax_cross_test, require_accelerate, require_bitsandbytes, require_deepspeed, @@ -100,14 +98,12 @@ from transformers.utils import ( GENERATION_CONFIG_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available, - is_flax_available, - is_tf_available, is_torch_bf16_available_on_device, is_torch_fp16_available_on_device, is_torch_fx_available, is_torch_sdpa_available, ) -from transformers.utils.generic import ContextManagers, ModelOutput +from transformers.utils.generic import ContextManagers if is_accelerate_available(): @@ -126,19 +122,6 @@ if is_torch_available(): from transformers.modeling_utils import load_state_dict, no_init_weights from transformers.pytorch_utils import id_tensor_storage - -if is_tf_available(): - pass - -if is_flax_available(): - import jax.numpy as jnp - - from tests.utils.test_modeling_flax_utils import check_models_equal - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) - if is_torch_fx_available(): from transformers.utils.fx import _FX_SUPPORTED_MODELS_WITH_KV_CACHE, symbolic_trace @@ -2552,249 +2535,6 @@ class ModelTesterMixin: diff = np.abs((a - b)).max() self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") - def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None): - """ - Args: - model_class: The class of the model that is currently testing. For example, ..., etc. - Currently unused, but it could make debugging easier and faster. - - names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs. - Currently unused, but in the future, we could use this information to make the error message clearer - by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax. - """ - self.assertEqual(type(name), str) - if attributes is not None: - self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`") - - # Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`). - if isinstance(fx_outputs, ModelOutput): - self.assertTrue( - isinstance(pt_outputs, ModelOutput), - f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is", - ) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch") - - # convert to the case of `tuple` - # appending each key to the current (string) `name` - attributes = tuple([f"{name}.{k}" for k in fx_keys]) - self.check_pt_flax_outputs( - fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes - ) - - # Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.) 
- elif type(fx_outputs) in [tuple, list]: - self.assertEqual( - type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch" - ) - self.assertEqual( - len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch" - ) - - if attributes is not None: - # case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`) - self.assertEqual( - len(attributes), - len(fx_outputs), - f"{name}: The tuple `attributes` should have the same length as `fx_outputs`", - ) - else: - # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name` - attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))]) - - for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes): - self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr) - - elif isinstance(fx_outputs, jnp.ndarray): - self.assertTrue( - isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is" - ) - - # Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`. - fx_outputs = np.array(fx_outputs) - pt_outputs = pt_outputs.detach().to("cpu").numpy() - - self.assertEqual( - fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch" - ) - - # deal with NumPy's scalars to make replacing nan values by 0 work. - if np.isscalar(fx_outputs): - fx_outputs = np.array([fx_outputs]) - pt_outputs = np.array([pt_outputs]) - - fx_nans = np.isnan(fx_outputs) - pt_nans = np.isnan(pt_outputs) - - pt_outputs[fx_nans] = 0 - fx_outputs[fx_nans] = 0 - pt_outputs[pt_nans] = 0 - fx_outputs[pt_nans] = 0 - - max_diff = np.amax(np.abs(fx_outputs - pt_outputs)) - self.assertLessEqual( - max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})." - ) - else: - raise ValueError( - "`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got" - f" {type(fx_outputs)} instead." - ) - - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - fx_model_class_name = "Flax" + model_class.__name__ - - if not hasattr(transformers, fx_model_class_name): - self.skipTest(reason="No Flax model exists for this class") - - # Output all for aggressive testing - config.output_hidden_states = True - config.output_attentions = self.has_attentions - - fx_model_class = getattr(transformers, fx_model_class_name) - - # load PyTorch class - pt_model = model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - - # load Flax class - fx_model = fx_model_class(config, dtype=jnp.float32) - - # make sure only flax inputs are forward that actually exist in function args - fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() - - # prepare inputs - pt_inputs = self._prepare_for_class(inputs_dict, model_class) - - # remove function args that don't exist in Flax - pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} - - # send pytorch inputs to the correct device - pt_inputs = { - k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() - } - - # convert inputs to Flax - fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - # send pytorch model to the correct device - pt_model.to(torch_device) - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - fx_outputs = fx_model(**fx_inputs) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**fx_inputs) - - fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) - - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - fx_model_class_name = "Flax" + model_class.__name__ - - if not hasattr(transformers, fx_model_class_name): - self.skipTest(reason="No Flax model exists for this class") - - # Output all for aggressive testing - config.output_hidden_states = True - config.output_attentions = self.has_attentions - - fx_model_class = getattr(transformers, fx_model_class_name) - - # load PyTorch class - pt_model = model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - - # load Flax class - fx_model = fx_model_class(config, dtype=jnp.float32) - - # make sure only flax inputs are forward that actually exist in function args - fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() - - # prepare inputs - pt_inputs = self._prepare_for_class(inputs_dict, model_class) - - # remove function args that don't exist in Flax - pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} - - # send pytorch inputs to the correct device - pt_inputs = { - k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() - } - - # convert inputs to Flax - fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - # send pytorch model to the correct device - pt_model.to(torch_device) - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - fx_outputs = fx_model(**fx_inputs) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = model_class.from_pretrained( - tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation - ) - - # send pytorch model to the correct device - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) - def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -4413,29 +4153,6 @@ class ModelTesterMixin: tol = torch.finfo(torch.float16).eps torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) - @is_pt_flax_cross_test - def test_flax_from_pt_safetensors(self): - for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - flax_model_class_name = "Flax" + model_class.__name__ # Add the "Flax at the beginning - if not hasattr(transformers, flax_model_class_name): - self.skipTest(reason="transformers does not have this model in Flax version yet") - - flax_model_class = getattr(transformers, flax_model_class_name) - - pt_model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname, safe_serialization=True) - flax_model_1 = flax_model_class.from_pretrained(tmpdirname, from_pt=True) - - pt_model.save_pretrained(tmpdirname, safe_serialization=False) - flax_model_2 = flax_model_class.from_pretrained(tmpdirname, from_pt=True) - - # Check models are equal - self.assertTrue(check_models_equal(flax_model_1, flax_model_2)) - @require_flash_attn @require_torch_gpu @mark.flash_attn_test diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index e6e3a860772..ab126357f58 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -21,13 +21,10 @@ from typing 
import List, Tuple import numpy as np -import transformers -from transformers import is_flax_available, is_torch_available -from transformers.cache_utils import DynamicCache +from transformers import is_flax_available from transformers.models.auto import get_values -from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, torch_device +from transformers.testing_utils import CaptureLogger, require_flax from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging -from transformers.utils.generic import ModelOutput if is_flax_available(): @@ -47,17 +44,10 @@ if is_flax_available(): FlaxAutoModelForSequenceClassification, FlaxBertModel, ) - from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, - load_flax_weights_in_pytorch_model, - ) from transformers.modeling_flax_utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8 -if is_torch_available(): - import torch - def ids_tensor(shape, vocab_size, rng=None): """Creates a random int32 tensor of the shape within the vocab size.""" @@ -184,216 +174,6 @@ class FlaxModelTesterMixin: dict_inputs = self._prepare_for_class(inputs_dict, model_class) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) - # (Copied from tests.test_modeling_common.ModelTesterMixin.check_pt_flax_outputs) - def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None): - """ - Args: - model_class: The class of the model that is currently testing. For example, ..., etc. - Currently unused, but it could make debugging easier and faster. - - names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs. - Currently unused, but in the future, we could use this information to make the error message clearer - by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax. - """ - self.assertEqual(type(name), str) - if attributes is not None: - self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`") - - # Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`). - if isinstance(fx_outputs, ModelOutput): - self.assertTrue( - isinstance(pt_outputs, ModelOutput), - f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is", - ) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch") - - # convert to the case of `tuple` - # appending each key to the current (string) `name` - attributes = tuple([f"{name}.{k}" for k in fx_keys]) - self.check_pt_flax_outputs( - fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes - ) - - # Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.) - elif type(fx_outputs) in [tuple, list]: - self.assertEqual( - type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch" - ) - self.assertEqual( - len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch" - ) - - if attributes is not None: - # case 1: each output has assigned name (e.g. 
a tuple form of a `ModelOutput`) - self.assertEqual( - len(attributes), - len(fx_outputs), - f"{name}: The tuple `attributes` should have the same length as `fx_outputs`", - ) - else: - # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name` - attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))]) - - for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes): - if isinstance(pt_output, DynamicCache): - pt_output = pt_output.to_legacy_cache() - self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr) - - elif isinstance(fx_outputs, jnp.ndarray): - self.assertTrue( - isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is" - ) - - # Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`. - fx_outputs = np.array(fx_outputs) - pt_outputs = pt_outputs.detach().to("cpu").numpy() - - self.assertEqual( - fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch" - ) - - # deal with NumPy's scalars to make replacing nan values by 0 work. - if np.isscalar(fx_outputs): - fx_outputs = np.array([fx_outputs]) - pt_outputs = np.array([pt_outputs]) - - fx_nans = np.isnan(fx_outputs) - pt_nans = np.isnan(pt_outputs) - - pt_outputs[fx_nans] = 0 - fx_outputs[fx_nans] = 0 - pt_outputs[pt_nans] = 0 - fx_outputs[pt_nans] = 0 - - max_diff = np.amax(np.abs(fx_outputs - pt_outputs)) - self.assertLessEqual( - max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})." - ) - else: - raise ValueError( - "`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got" - f" {type(fx_outputs)} instead." - ) - - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - # It might be better to put this inside the for loop below (because we modify the config there). - # But logically, it is fine. - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # Output all for aggressive testing - config.output_hidden_states = True - config.output_attentions = self.has_attentions - - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - pt_model = pt_model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - fx_model = model_class(config, dtype=jnp.float32) - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - # send pytorch model to the correct device - pt_model.to(torch_device) - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - fx_outputs = fx_model(**prepared_inputs_dict) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict) - - fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) - - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # Output all for aggressive testing - config.output_hidden_states = True - config.output_attentions = self.has_attentions - - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - pt_model = pt_model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - fx_model = model_class(config, dtype=jnp.float32) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - # send pytorch model to the correct device - pt_model.to(torch_device) - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - fx_outputs = fx_model(**prepared_inputs_dict) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = pt_model_class.from_pretrained( - tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation - ) - - # send pytorch model to the correct device - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) - def test_from_pretrained_save_pretrained(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -474,92 +254,6 @@ class FlaxModelTesterMixin: max_diff = (base_params[key] - base_params_from_head[key]).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - @is_pt_flax_cross_test - def test_save_load_from_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = base_class(config) - base_params = get_params(model.params) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - # save pt model - pt_model.save_pretrained(tmpdirname) - head_model = model_class.from_pretrained(tmpdirname, from_pt=True) - - base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix) - - for key in base_param_from_head.keys(): - max_diff = (base_params[key] - base_param_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - @is_pt_flax_cross_test - def test_save_load_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with 
tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = get_params(base_model.params) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - @is_pt_flax_cross_test - def test_save_load_bf16_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - model.params = model.to_bf16(model.params) - base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = get_params(base_model.params) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - def test_jit_compilation(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -1119,14 +813,6 @@ class FlaxModelTesterMixin: for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()): self.assertTrue(np.allclose(np.array(p1), np.array(p2))) - @is_pt_flax_cross_test - def test_from_sharded_pt(self): - model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) - ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only") - for key, ref_val in flatten_dict(ref_model.params).items(): - val = flatten_dict(model.params)[key] - assert np.allclose(np.array(val), np.array(ref_val)) - def test_gradient_checkpointing(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 1981859048b..fb6860bcc31 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -31,13 +31,11 @@ from datasets import Dataset from transformers import is_tf_available from transformers.models.auto import get_values -from transformers.testing_utils import ( # noqa: F401 +from transformers.testing_utils import ( CaptureLogger, - _tf_gpu_memory_limit, require_tf, require_tf2onnx, slow, - torch_device, ) from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging from transformers.utils.generic import ModelOutput @@ -73,20 +71,6 @@ if is_tf_available(): tf.config.experimental.enable_tensor_float_32_execution(False) - if _tf_gpu_memory_limit is not None: - gpus = tf.config.list_physical_devices("GPU") - for gpu in gpus: - # Restrict TensorFlow to only allocate x GB of memory on the GPUs - try: - tf.config.set_logical_device_configuration( - gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)] - ) - logical_gpus = tf.config.list_logical_devices("GPU") - print("Logical GPUs", 
logical_gpus) - except RuntimeError as e: - # Virtual devices must be set before GPUs have been initialized - print(e) - def _config_zero_init(config): configs_no_init = copy.deepcopy(config) diff --git a/tests/utils/test_modeling_flax_utils.py b/tests/utils/test_modeling_flax_utils.py index 7f66944446a..7a2c516132b 100644 --- a/tests/utils/test_modeling_flax_utils.py +++ b/tests/utils/test_modeling_flax_utils.py @@ -18,16 +18,14 @@ import unittest import numpy as np from huggingface_hub import HfFolder, snapshot_download -from transformers import BertConfig, BertModel, is_flax_available, is_torch_available +from transformers import BertConfig, is_flax_available from transformers.testing_utils import ( TOKEN, CaptureLogger, TemporaryHubRepo, - is_pt_flax_cross_test, is_staging_test, require_flax, require_safetensors, - require_torch, ) from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging @@ -42,9 +40,6 @@ if is_flax_available(): os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8 -if is_torch_available(): - import torch - @require_flax @is_staging_test @@ -205,23 +200,6 @@ class FlaxModelUtilsTest(unittest.TestCase): self.assertTrue(check_models_equal(model, new_model)) - @require_flax - @require_torch - @is_pt_flax_cross_test - def test_safetensors_save_and_load_pt_to_flax(self): - model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True) - pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") - with tempfile.TemporaryDirectory() as tmp_dir: - pt_model.save_pretrained(tmp_dir) - - # Check we have a model.safetensors file - self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME))) - - new_model = FlaxBertModel.from_pretrained(tmp_dir) - - # Check models are equal - self.assertTrue(check_models_equal(model, new_model)) - @require_safetensors def test_safetensors_load_from_hub(self): """ @@ -248,58 +226,6 @@ class FlaxModelUtilsTest(unittest.TestCase): self.assertTrue(check_models_equal(flax_model, safetensors_model)) - @require_safetensors - @is_pt_flax_cross_test - def test_safetensors_load_from_hub_from_safetensors_pt(self): - """ - This test checks that we can load safetensors from a checkpoint that only has those on the Hub. - saved in the "pt" format. - """ - flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-msgpack") - - # Can load from the PyTorch-formatted checkpoint - safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors") - self.assertTrue(check_models_equal(flax_model, safetensors_model)) - - @require_safetensors - @require_torch - @is_pt_flax_cross_test - def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self): - """ - This test checks that we can load safetensors from a checkpoint that only has those on the Hub. - saved in the "pt" format. 
- """ - import torch - - model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors") - model.to(torch.bfloat16) - - with tempfile.TemporaryDirectory() as tmp: - model.save_pretrained(tmp) - flax_model = FlaxBertModel.from_pretrained(tmp) - - # Can load from the PyTorch-formatted checkpoint - safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16") - self.assertTrue(check_models_equal(flax_model, safetensors_model)) - - @require_safetensors - @is_pt_flax_cross_test - def test_safetensors_load_from_local_from_safetensors_pt(self): - """ - This test checks that we can load safetensors from a checkpoint that only has those on the Hub. - saved in the "pt" format. - """ - with tempfile.TemporaryDirectory() as tmp: - location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp) - flax_model = FlaxBertModel.from_pretrained(location) - - # Can load from the PyTorch-formatted checkpoint - with tempfile.TemporaryDirectory() as tmp: - location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp) - safetensors_model = FlaxBertModel.from_pretrained(location) - - self.assertTrue(check_models_equal(flax_model, safetensors_model)) - @require_safetensors def test_safetensors_load_from_hub_msgpack_before_safetensors(self): """ @@ -328,19 +254,6 @@ class FlaxModelUtilsTest(unittest.TestCase): self.assertTrue(check_models_equal(model, new_model)) - @require_safetensors - @require_torch - @is_pt_flax_cross_test - def test_safetensors_flax_from_torch(self): - hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") - model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only") - - with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(tmp_dir, safe_serialization=True) - new_model = FlaxBertModel.from_pretrained(tmp_dir) - - self.assertTrue(check_models_equal(hub_model, new_model)) - @require_safetensors def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_local(self): with tempfile.TemporaryDirectory() as tmp_dir: @@ -370,27 +283,3 @@ class FlaxModelUtilsTest(unittest.TestCase): "Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint" in cl.out ) - - @require_torch - @require_safetensors - @is_pt_flax_cross_test - def test_from_pt_bf16(self): - model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only") - model.to(torch.bfloat16) - - with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(tmp_dir, safe_serialization=False) - - logger = logging.get_logger("transformers.modeling_flax_utils") - - with CaptureLogger(logger) as cl: - new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16") - - self.assertTrue( - "Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint" - in cl.out - ) - - flat_params_1 = flatten_dict(new_model.params) - for value in flat_params_1.values(): - self.assertEqual(value.dtype, "bfloat16") diff --git a/tests/utils/test_modeling_tf_core.py b/tests/utils/test_modeling_tf_core.py index 53f3edede7a..c1cfc469697 100644 --- a/tests/utils/test_modeling_tf_core.py +++ b/tests/utils/test_modeling_tf_core.py @@ -24,7 +24,7 @@ from math import isnan from transformers import is_tf_available from transformers.models.auto import get_values -from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow +from transformers.testing_utils 
import require_tf, slow from ..test_modeling_tf_common import ids_tensor @@ -48,20 +48,6 @@ if is_tf_available(): ) from transformers.modeling_tf_utils import keras - if _tf_gpu_memory_limit is not None: - gpus = tf.config.list_physical_devices("GPU") - for gpu in gpus: - # Restrict TensorFlow to only allocate x GB of memory on the GPUs - try: - tf.config.set_logical_device_configuration( - gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)] - ) - logical_gpus = tf.config.list_logical_devices("GPU") - print("Logical GPUs", logical_gpus) - except RuntimeError as e: - # Virtual devices must be set before GPUs have been initialized - print(e) - @require_tf class TFCoreModelTesterMixin: diff --git a/tests/utils/test_modeling_tf_utils.py b/tests/utils/test_modeling_tf_utils.py index abd728c43ae..116af748e6f 100644 --- a/tests/utils/test_modeling_tf_utils.py +++ b/tests/utils/test_modeling_tf_utils.py @@ -33,7 +33,6 @@ from transformers.testing_utils import ( # noqa: F401 USER, CaptureLogger, TemporaryHubRepo, - _tf_gpu_memory_limit, is_staging_test, require_safetensors, require_tf, @@ -68,20 +67,6 @@ if is_tf_available(): tf.config.experimental.enable_tensor_float_32_execution(False) - if _tf_gpu_memory_limit is not None: - gpus = tf.config.list_physical_devices("GPU") - for gpu in gpus: - # Restrict TensorFlow to only allocate x GB of memory on the GPUs - try: - tf.config.set_logical_device_configuration( - gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)] - ) - logical_gpus = tf.config.list_logical_devices("GPU") - print("Logical GPUs", logical_gpus) - except RuntimeError as e: - # Virtual devices must be set before GPUs have been initialized - print(e) - @require_tf class TFModelUtilsTest(unittest.TestCase): diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index 7fb24cd540f..3543b18f9c2 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -1148,7 +1148,6 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]: JOB_TO_TEST_FILE = { - "tests_torch_and_flax": r"tests/models/.*/test_modeling_(?:flax|(?!tf)).*", "tests_tf": r"tests/models/.*/test_modeling_tf_.*", "tests_torch": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*", "tests_generate": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*",
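
With the `is_pt_flax_cross_test` marker and the cross-framework jobs gone, the PT -> Flax equivalence check deleted above can still be reproduced by hand when needed. Below is a minimal sketch based on the removed `test_equivalence_pt_to_flax` logic; the BERT classes, the tiny config values and the 1e-4 tolerance are illustrative choices rather than part of the original tests, and it assumes torch, jax and flax are installed locally.

# Minimal PT -> Flax equivalence check, following the removed test logic.
import numpy as np
import torch
import jax.numpy as jnp

from transformers import BertConfig, BertModel, FlaxBertModel
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

# Small, randomly initialised config (sizes are arbitrary for illustration).
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)

pt_model = BertModel(config).eval()
# Flax models don't return a cache by default, so disable it on the PyTorch side.
pt_model.config.use_cache = False

# Load the Flax counterpart and copy the PyTorch weights into it.
fx_model = FlaxBertModel(config, dtype=jnp.float32)
fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

# Shared inputs: PyTorch tensors for the PT model, NumPy arrays for the Flax model.
input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    pt_out = pt_model(input_ids=input_ids, attention_mask=attention_mask)
fx_out = fx_model(input_ids=np.array(input_ids), attention_mask=np.array(attention_mask))

# Compare the final hidden states within the tolerance the old common test used.
max_diff = np.amax(np.abs(np.array(fx_out.last_hidden_state) - pt_out.last_hidden_state.numpy()))
assert max_diff < 1e-4, f"Difference between PyTorch and Flax is {max_diff}"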