[tests] remove flax-pt equivalence and cross tests (#36283)

Joao Gante 2025-02-19 15:13:27 +00:00 committed by GitHub
parent fa8cdccd91
commit 99adc74462
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
39 changed files with 33 additions and 3103 deletions


@ -28,7 +28,6 @@ COMMON_ENV_VARIABLES = {
"TRANSFORMERS_IS_CI": True,
"PYTEST_TIMEOUT": 120,
"RUN_PIPELINE_TESTS": False,
"RUN_PT_FLAX_CROSS_TESTS": False,
}
# Disable the use of {"s": None} as the output is way too long, causing the navigation on CircleCI impractical
COMMON_PYTEST_OPTIONS = {"max-worker-restart": 0, "dist": "loadfile", "vvv": None, "rsfE":None}
@ -176,14 +175,6 @@ class CircleCIJob:
# JOBS
torch_and_flax_job = CircleCIJob(
"torch_and_flax",
additional_env={"RUN_PT_FLAX_CROSS_TESTS": True},
docker_image=[{"image":"huggingface/transformers-torch-jax-light"}],
marker="is_pt_flax_cross_test",
pytest_options={"rA": None, "durations": 0},
)
torch_job = CircleCIJob(
"torch",
docker_image=[{"image": "huggingface/transformers-torch-light"}],
@ -343,7 +334,7 @@ doc_test_job = CircleCIJob(
pytest_num_workers=1,
)
REGULAR_TESTS = [torch_and_flax_job, torch_job, tf_job, flax_job, hub_job, onnx_job, tokenization_job, processor_job, generate_job, non_model_job] # fmt: skip
REGULAR_TESTS = [torch_job, tf_job, flax_job, hub_job, onnx_job, tokenization_job, processor_job, generate_job, non_model_job] # fmt: skip
EXAMPLES_TESTS = [examples_torch_job, examples_tensorflow_job]
PIPELINE_TESTS = [pipelines_torch_job, pipelines_tf_job]
REPO_UTIL_TESTS = [repo_utils_job]


@ -26,7 +26,7 @@ jobs:
strategy:
matrix:
file: ["quality", "consistency", "custom-tokenizers", "torch-light", "tf-light", "exotic-models", "torch-tf-light", "torch-jax-light", "jax-light", "examples-torch", "examples-tf"]
file: ["quality", "consistency", "custom-tokenizers", "torch-light", "tf-light", "exotic-models", "torch-tf-light", "jax-light", "examples-torch", "examples-tf"]
continue-on-error: true
steps:
@ -34,11 +34,11 @@ jobs:
name: Set tag
run: |
if ${{contains(github.event.head_commit.message, '[build-ci-image]')}}; then
echo "TAG=huggingface/transformers-${{ matrix.file }}:dev" >> "$GITHUB_ENV"
echo "TAG=huggingface/transformers-${{ matrix.file }}:dev" >> "$GITHUB_ENV"
echo "setting it to DEV!"
else
echo "TAG=huggingface/transformers-${{ matrix.file }}" >> "$GITHUB_ENV"
fi
-
name: Set up Docker Buildx


@ -343,7 +343,6 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/t
Like the slow tests, there are other environment variables available which are not enabled by default during testing:
- `RUN_CUSTOM_TOKENIZERS`: Enables tests for custom tokenizers.
- `RUN_PT_FLAX_CROSS_TESTS`: Enables tests for PyTorch + Flax integration.
More environment variables and additional information can be found in the [testing_utils.py](https://github.com/huggingface/transformers/blob/main/src/transformers/testing_utils.py).
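For context, opting into one of these variables from a test module follows the usual env-flag pattern; below is a minimal, hypothetical sketch (the class name CustomTokenizerSmokeTest and the inline flag parsing are illustrative assumptions, not code from this commit; the real helpers live in transformers.testing_utils):

import os
import unittest

# Hypothetical illustration of the RUN_CUSTOM_TOKENIZERS opt-in pattern.
_run_custom_tokenizers = os.environ.get("RUN_CUSTOM_TOKENIZERS", "False").lower() in ("yes", "true", "1")


@unittest.skipUnless(_run_custom_tokenizers, "set RUN_CUSTOM_TOKENIZERS=yes to run custom tokenizer tests")
class CustomTokenizerSmokeTest(unittest.TestCase):
    def test_tokenizer_loads(self):
        self.assertTrue(True)  # placeholder body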


@ -84,9 +84,6 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_configure(config):
config.addinivalue_line(
"markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
)
config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested")
config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
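The remaining addinivalue_line calls register custom pytest markers; as a rough sketch (not part of this commit), a registered marker such as is_staging_test is typically declared in conftest.py and then selected on the command line with `pytest -m is_staging_test`:

# conftest.py (sketch): register the marker so pytest does not warn about it
def pytest_configure(config):
    config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")


# test_example.py (sketch): tag a test with the marker
import pytest


@pytest.mark.is_staging_test
def test_runs_only_when_selected():
    assert 1 + 1 == 2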


@ -283,7 +283,6 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/t
As with the slow tests, there are other environment variables that are not set by default during testing:
* `RUN_CUSTOM_TOKENIZERS`: Enables tests for custom tokenizers.
* `RUN_PT_FLAX_CROSS_TESTS`: Enables tests for PyTorch + Flax integration.
More environment variables and additional information can be found in [testing_utils.py](src/transformers/testing_utils.py).


@ -282,7 +282,6 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/t
As with the slow tests, there are other environment variables that are not enabled by default during testing:
- `RUN_CUSTOM_TOKENIZERS`: Enables tests for custom tokenizers.
- `RUN_PT_FLAX_CROSS_TESTS`: Enables tests for PyTorch + Flax integration.
More environment variables and additional information can be found in [testing_utils.py](src/transformers/testing_utils.py).


@ -281,7 +281,6 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/t
As with the longer-running tests, there are other environment variables that are not enabled by default during testing:
- `RUN_CUSTOM_TOKENIZERS`: Enables tests for custom tokenizers.
- `RUN_PT_FLAX_CROSS_TESTS`: Enables tests for PyTorch + Flax integration.
More environment variables and additional information can be found in [testing_utils.py](src/transformers/testing_utils.py).


@ -473,7 +473,6 @@ setup(
extras["tests_torch"] = deps_list()
extras["tests_tf"] = deps_list()
extras["tests_flax"] = deps_list()
extras["tests_torch_and_flax"] = deps_list()
extras["tests_hub"] = deps_list()
extras["tests_pipelines_torch"] = deps_list()
extras["tests_pipelines_tf"] = deps_list()


@ -230,10 +230,8 @@ def parse_int_from_env(key, default=None):
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=True)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
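The flags above come from parse_flag_from_env; as a rough sketch of what such a helper typically does (an assumption for illustration, using the hypothetical name parse_flag_from_env_sketch rather than the exact transformers implementation):

import os


def parse_flag_from_env_sketch(key: str, default: bool = False) -> bool:
    # Treat an unset variable as the default; otherwise accept common truthy spellings.
    value = os.environ.get(key)
    if value is None:
        return default
    return value.lower() in ("yes", "true", "t", "1")


RUN_SLOW = parse_flag_from_env_sketch("RUN_SLOW", default=False)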
@ -250,25 +248,6 @@ def get_device_count():
return num_devices
def is_pt_flax_cross_test(test_case):
"""
Decorator marking a test as a test that control interactions between PyTorch and Flax
PT+FLAX tests are skipped by default and we can run only them by setting RUN_PT_FLAX_CROSS_TESTS environment
variable to a truthy value and selecting the is_pt_flax_cross_test pytest mark.
"""
if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
return unittest.skip(reason="test is PT+FLAX test")(test_case)
else:
try:
import pytest # We don't need a hard dependency on pytest in the main library
except ImportError:
return test_case
else:
return pytest.mark.is_pt_flax_cross_test()(test_case)
def is_staging_test(test_case):
"""
Decorator marking a test as a staging test.


@ -624,15 +624,6 @@ class BigBirdModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
# overwrite from common in order to skip the check on `attentions`
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
# an effort was done to return `attention_probs` (yet to be verified).
if name.startswith("outputs.attentions"):
return
else:
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
@require_torch
@slow


@ -212,12 +212,3 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
# overwrite from common in order to skip the check on `attentions`
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
# an effort was done to return `attention_probs` (yet to be verified).
if name.startswith("outputs.attentions"):
return
else:
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)


@ -25,11 +25,8 @@ import requests
from parameterized import parameterized
from pytest import mark
import transformers
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.testing_utils import (
is_flax_available,
is_pt_flax_cross_test,
require_flash_attn,
require_torch,
require_torch_gpu,
@ -82,15 +79,6 @@ if is_vision_available():
from transformers import CLIPProcessor
if is_flax_available():
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
class CLIPVisionModelTester:
def __init__(
self,
@ -883,126 +871,6 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
text_config = CLIPTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load PyTorch class
pt_model = model_class(config).eval()
pt_model.to(torch_device)
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="No Flax model exists for this class")
fx_model_class = getattr(transformers, fx_model_class_name)
# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
# convert inputs to Flax
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load corresponding PyTorch class
pt_model = model_class(config).eval()
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="No Flax model exists for this class")
fx_model_class = getattr(transformers, fx_model_class_name)
# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
pt_model.to(torch_device)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
@slow
def test_model_from_pretrained(self):
model_name = "openai/clip-vit-base-patch32"


@ -4,21 +4,15 @@ import unittest
import numpy as np
import transformers
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.clip.modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPTextModel,
@ -26,9 +20,6 @@ if is_flax_available():
FlaxCLIPVisionModel,
)
if is_torch_available():
import torch
class FlaxCLIPVisionModelTester:
def __init__(
@ -223,21 +214,6 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_save_load_to_base(self):
pass
# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
pass
# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
pass
# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
pass
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
@ -333,21 +309,6 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_save_load_to_base(self):
pass
# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
pass
# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
pass
# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
pass
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
@ -472,92 +433,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
outputs = model(input_ids=np.ones((1, 1)), pixel_values=np.ones((1, 3, 224, 224)))
self.assertIsNotNone(outputs)
# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
def test_from_pretrained_save_pretrained(self):


@ -22,11 +22,8 @@ import unittest
import numpy as np
import requests
import transformers
from transformers import CLIPSegConfig, CLIPSegProcessor, CLIPSegTextConfig, CLIPSegVisionConfig
from transformers.testing_utils import (
is_flax_available,
is_pt_flax_cross_test,
require_torch,
require_vision,
slow,
@ -57,15 +54,6 @@ if is_vision_available():
from PIL import Image
if is_flax_available():
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
class CLIPSegVisionModelTester:
def __init__(
self,
@ -635,123 +623,6 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
text_config = CLIPSegTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
# overwrite from common since FlaxCLIPSegModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="No Flax model exists for this class")
fx_model_class = getattr(transformers, fx_model_class_name)
# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
# convert inputs to Flax
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
# overwrite from common since FlaxCLIPSegModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load corresponding PyTorch class
pt_model = model_class(config).eval()
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="No Flax model exists for this class")
fx_model_class = getattr(transformers, fx_model_class_name)
# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
def test_training(self):
if not self.model_tester.is_training:
self.skipTest(reason="Training test is skipped as the model was not trained")


@ -22,7 +22,7 @@ from datasets import load_dataset
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import Data2VecAudioConfig, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_soundfile, require_torch, slow, torch_device
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init
@ -442,16 +442,6 @@ class Data2VecAudioModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
def test_model_get_set_embeddings(self):
pass
@is_pt_flax_cross_test
# non-robust architecture does not exist in Flax
def test_equivalence_flax_to_pt(self):
pass
@is_pt_flax_cross_test
# non-robust architecture does not exist in Flax
def test_equivalence_pt_to_flax(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True


@ -19,8 +19,8 @@ import unittest
import numpy as np
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
from transformers import is_flax_available
from transformers.testing_utils import require_flax, slow
from ...test_modeling_flax_common import ids_tensor
from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester
@ -38,15 +38,6 @@ if is_flax_available():
FlaxEncoderDecoderModel,
FlaxGPT2LMHeadModel,
)
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
if is_torch_available():
import torch
from transformers import EncoderDecoderModel
@require_flax
@ -291,68 +282,6 @@ class FlaxEncoderDecoderMixin:
generated_sequences = generated_output.sequences
self.assertEqual(generated_sequences.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = EncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = EncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxEncoderDecoderModel(encoder_decoder_config)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = EncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxEncoderDecoderModel(encoder_decoder_config)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def test_encoder_decoder_model_from_pretrained_configs(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
@ -385,40 +314,6 @@ class FlaxEncoderDecoderMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@is_pt_flax_cross_test
def test_pt_flax_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
inputs_dict = config_inputs_dict
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = np.concatenate(
[np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1
)
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
decoder_config.use_cache = False
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
# check without `enc_to_dec_proj` projection
decoder_config.hidden_size = config.hidden_size
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
# check `enc_to_dec_proj` work as expected
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()


@ -13,14 +13,12 @@
# limitations under the License.
import tempfile
import unittest
import numpy as np
import transformers
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available
from transformers.testing_utils import require_flax, slow
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@ -29,15 +27,8 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
if is_torch_available():
import torch
class FlaxGPT2ModelTester:
def __init__(
@ -255,105 +246,6 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertListEqual(output_string, expected_string)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
# make sure weights are tied in PyTorch
pt_model.tie_weights()
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:


@ -13,14 +13,12 @@
# limitations under the License.
import tempfile
import unittest
import numpy as np
import transformers
from transformers import GPT2Tokenizer, GPTNeoConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from transformers import GPT2Tokenizer, GPTNeoConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
@ -29,15 +27,8 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel
if is_torch_available():
import torch
class FlaxGPTNeoModelTester:
def __init__(
@ -224,105 +215,6 @@ class FlaxGPTNeoModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertListEqual(output_string, expected_string)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
# make sure weights are tied in PyTorch
pt_model.tie_weights()
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:


@ -13,14 +13,12 @@
# limitations under the License.
import tempfile
import unittest
import numpy as np
import transformers
from transformers import GPT2Tokenizer, GPTJConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, tooslow
from transformers import GPT2Tokenizer, GPTJConfig, is_flax_available
from transformers.testing_utils import require_flax, tooslow
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
@ -29,15 +27,8 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.gptj.modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel
if is_torch_available():
import torch
class FlaxGPTJModelTester:
def __init__(
@ -221,105 +212,6 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertListEqual(output_string, expected_string)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
# make sure weights are tied in PyTorch
pt_model.tie_weights()
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
@tooslow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:


@ -17,11 +17,9 @@ import unittest
import numpy as np
import transformers
from transformers import is_flax_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
is_pt_flax_cross_test,
require_flax,
require_sentencepiece,
require_tokenizers,
@ -46,7 +44,6 @@ if is_flax_available():
from flax.traverse_util import flatten_dict
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING, AutoTokenizer, LongT5Config
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
from transformers.models.longt5.modeling_flax_longt5 import (
FlaxLongT5ForConditionalGeneration,
FlaxLongT5Model,
@ -467,95 +464,6 @@ class FlaxLongT5ModelTest(FlaxModelTesterMixin, unittest.TestCase):
[self.model_tester.num_attention_heads, block_len, 3 * block_len],
)
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
# save pt model
pt_model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
base_param_from_head = flatten_dict(unfreeze(head_model.params))
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
model.params = model.to_bf16(model.params)
base_params_from_head = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
class FlaxLongT5TGlobalModelTest(FlaxLongT5ModelTest):
def setUp(self):


@ -19,7 +19,7 @@ import unittest
import numpy as np
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
from transformers.testing_utils import require_flax, slow
from ...test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester
@ -43,14 +43,8 @@ if is_flax_available():
SpeechEncoderDecoderConfig,
)
from transformers.modeling_flax_outputs import FlaxBaseModelOutput
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
if is_torch_available():
import torch
from transformers import SpeechEncoderDecoderModel
@ -406,68 +400,6 @@ class FlaxEncoderDecoderMixin:
for grad, grad_frozen in zip(grads, grads_frozen):
self.assertTrue((grad == grad_frozen).all())
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def test_encoder_decoder_model_from_pretrained_configs(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
@ -504,46 +436,6 @@ class FlaxEncoderDecoderMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@is_pt_flax_cross_test
def test_pt_flax_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
inputs_dict = config_inputs_dict
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = np.concatenate(
[np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1
)
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
decoder_config.use_cache = False
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
# check without `enc_to_dec_proj` projection
decoder_config.hidden_size = config.hidden_size
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
# check `enc_to_dec_proj` works as expected
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
# check `add_adapter` works as expected
config.add_adapter = True
self.assertTrue(config.add_adapter)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
@ -625,71 +517,6 @@ class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
"encoder_hidden_states": encoder_hidden_states,
}
@slow
def test_flaxwav2vec2gpt2_pt_flax_equivalence(self):
pt_model = SpeechEncoderDecoderModel.from_pretrained("jsnfly/wav2vec2-large-xlsr-53-german-gpt2")
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(
"jsnfly/wav2vec2-large-xlsr-53-german-gpt2", from_pt=True
)
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
batch_size = 13
input_values = floats_tensor([batch_size, 512], scale=1.0)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs_dict = {
"inputs": input_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
pt_logits = pt_outputs.logits
pt_outputs = pt_outputs.to_tuple()
fx_outputs = fx_model(**inputs_dict)
fx_logits = fx_outputs.logits
fx_outputs = fx_outputs.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
fx_logits_loaded = fx_outputs_loaded.logits
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
pt_logits_loaded = pt_outputs_loaded.logits
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
@require_flax
class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
@ -742,71 +569,6 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
"encoder_hidden_states": encoder_hidden_states,
}
@slow
def test_flaxwav2vec2bart_pt_flax_equivalence(self):
pt_model = SpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(
"patrickvonplaten/wav2vec2-2-bart-large", from_pt=True
)
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
batch_size = 13
input_values = floats_tensor([batch_size, 512], scale=1.0)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs_dict = {
"inputs": input_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
pt_logits = pt_outputs.logits
pt_outputs = pt_outputs.to_tuple()
fx_outputs = fx_model(**inputs_dict)
fx_logits = fx_outputs.logits
fx_outputs = fx_outputs.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
fx_logits_loaded = fx_outputs_loaded.logits
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
pt_logits_loaded = pt_outputs_loaded.logits
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
@require_flax
class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
@ -858,66 +620,3 @@ class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
"decoder_attention_mask": decoder_attention_mask,
"encoder_hidden_states": encoder_hidden_states,
}
@slow
def test_flaxwav2vec2bert_pt_flax_equivalence(self):
pt_model = SpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large")
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large", from_pt=True)
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
batch_size = 13
input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs_dict = {
"inputs": input_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
pt_logits = pt_outputs.logits
pt_outputs = pt_outputs.to_tuple()
fx_outputs = fx_model(**inputs_dict)
fx_logits = fx_outputs.logits
fx_outputs = fx_outputs.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
fx_logits_loaded = fx_outputs_loaded.logits
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
pt_logits_loaded = pt_outputs_loaded.logits
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
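For reference, the weight round-trip exercised by the removed `check_equivalence_pt_to_flax` / `check_equivalence_flax_to_pt` helpers can be sketched in isolation. This is a minimal sketch, not part of the change: the BERT classes and the tiny config below are illustrative stand-ins, and any PyTorch/Flax model pair sharing a config follows the same pattern.

```python
import numpy as np
import torch

from transformers import BertConfig, BertModel, FlaxBertModel
from transformers.modeling_flax_pytorch_utils import (
    convert_pytorch_state_dict_to_flax,
    load_flax_weights_in_pytorch_model,
)

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
pt_model = BertModel(config).eval()
fx_model = FlaxBertModel(config)

# PT -> Flax: copy the PyTorch weights into the Flax parameter tree
fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

# Flax -> PT: load the Flax params back into a fresh PyTorch model
pt_from_flax = load_flax_weights_in_pytorch_model(BertModel(config), fx_model.params).eval()

# With identical weights, both frameworks should agree to within a small tolerance
input_ids = np.ones((1, 8), dtype="i4")
with torch.no_grad():
    pt_out = pt_model(torch.tensor(input_ids, dtype=torch.long)).last_hidden_state
fx_out = fx_model(input_ids).last_hidden_state
assert np.abs(np.asarray(fx_out) - pt_out.numpy()).max() < 1e-5
```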

View File

@ -17,15 +17,8 @@ import unittest
import numpy as np
import transformers
from transformers import is_flax_available
from transformers.testing_utils import (
is_pt_flax_cross_test,
require_flax,
require_sentencepiece,
require_tokenizers,
slow,
)
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@ -47,7 +40,6 @@ if is_flax_available():
from flax.traverse_util import flatten_dict
from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
from transformers.models.t5.modeling_flax_t5 import (
FlaxT5EncoderModel,
FlaxT5ForConditionalGeneration,
@ -373,95 +365,6 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, unittest.TestCase):
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
# save pt model
pt_model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
base_param_from_head = flatten_dict(unfreeze(head_model.params))
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
model.params = model.to_bf16(model.params)
base_params_from_head = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
class FlaxT5EncoderOnlyModelTester:
def __init__(
@ -663,95 +566,6 @@ class FlaxT5EncoderOnlyModelTest(FlaxModelTesterMixin, unittest.TestCase):
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
# save pt model
pt_model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
base_param_from_head = flatten_dict(unfreeze(head_model.params))
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
model.params = model.to_bf16(model.params)
base_params_from_head = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@require_sentencepiece
@require_tokenizers
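The save/reload direction covered by the removed T5 cross tests is the standard `from_pt` / `from_flax` checkpoint path. A minimal sketch, with GPT-2 classes and a tiny config chosen purely for illustration rather than taken from this diff:

```python
import tempfile

from transformers import FlaxGPT2LMHeadModel, GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=99, n_positions=64, n_embd=32, n_layer=2, n_head=2)
pt_model = GPT2LMHeadModel(config).eval()

# PyTorch checkpoint -> Flax model
with tempfile.TemporaryDirectory() as tmpdirname:
    pt_model.save_pretrained(tmpdirname)
    fx_model = FlaxGPT2LMHeadModel.from_pretrained(tmpdirname, from_pt=True)

# Flax checkpoint -> PyTorch model
with tempfile.TemporaryDirectory() as tmpdirname:
    fx_model.save_pretrained(tmpdirname)
    pt_reloaded = GPT2LMHeadModel.from_pretrained(tmpdirname, from_flax=True)
```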

View File

@ -19,8 +19,8 @@ import unittest
import numpy as np
from transformers import is_flax_available, is_torch_available, is_vision_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_vision, slow, torch_device
from transformers import is_flax_available, is_vision_available
from transformers.testing_utils import require_flax, require_vision, slow
from ...test_modeling_flax_common import floats_tensor, ids_tensor
from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester
@ -35,15 +35,7 @@ if is_flax_available():
FlaxViTModel,
VisionEncoderDecoderConfig,
)
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
if is_torch_available():
import torch
from transformers import VisionEncoderDecoderModel
if is_vision_available():
from PIL import Image
@ -235,68 +227,6 @@ class FlaxEncoderDecoderMixin:
generated_sequences = generated_output.sequences
self.assertEqual(generated_sequences.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = VisionEncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxVisionEncoderDecoderModel(encoder_decoder_config)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = VisionEncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxVisionEncoderDecoderModel(encoder_decoder_config)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def test_encoder_decoder_model_from_pretrained_configs(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_configs(**config_inputs_dict)
@ -325,39 +255,6 @@ class FlaxEncoderDecoderMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@is_pt_flax_cross_test
def test_pt_flax_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
inputs_dict = config_inputs_dict
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = np.concatenate(
[np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1
)
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
decoder_config.use_cache = False
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
# check without `enc_to_dec_proj` projection
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
# check `enc_to_dec_proj` works as expected
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()

View File

@ -20,15 +20,8 @@ import unittest
import numpy as np
from transformers.testing_utils import (
is_pt_flax_cross_test,
require_flax,
require_torch,
require_vision,
slow,
torch_device,
)
from transformers.utils import is_flax_available, is_torch_available, is_vision_available
from transformers.testing_utils import require_flax, require_torch, require_vision, slow
from transformers.utils import is_flax_available, is_vision_available
from ...test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
from ..bert.test_modeling_flax_bert import FlaxBertModelTester
@ -45,17 +38,8 @@ if is_flax_available():
VisionTextDualEncoderConfig,
VisionTextDualEncoderProcessor,
)
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
if is_torch_available():
import torch
from transformers import VisionTextDualEncoderModel
if is_vision_available():
from PIL import Image
@ -154,68 +138,6 @@ class VisionTextDualEncoderMixin:
(text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]),
)
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
pt_model = VisionTextDualEncoderModel(config)
fx_model = FlaxVisionTextDualEncoderModel(config)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
pt_model = VisionTextDualEncoderModel(config)
fx_model = FlaxVisionTextDualEncoderModel(config)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def test_model_from_pretrained_configs(self):
inputs_dict = self.prepare_config_and_inputs()
self.check_model_from_pretrained_configs(**inputs_dict)
@ -232,17 +154,6 @@ class VisionTextDualEncoderMixin:
inputs_dict = self.prepare_config_and_inputs()
self.check_vision_text_output_attention(**inputs_dict)
@is_pt_flax_cross_test
def test_pt_flax_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
vision_config = config_inputs_dict.pop("vision_config")
text_config = config_inputs_dict.pop("text_config")
inputs_dict = config_inputs_dict
self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
self.check_equivalence_flax_to_pt(vision_config, text_config, inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2, inputs = self.get_pretrained_model_and_inputs()

View File

@ -20,8 +20,8 @@ import unittest
import numpy as np
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_vision, slow, torch_device
from transformers.utils import is_flax_available, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..bert.test_modeling_bert import BertModelTester
@ -44,12 +44,6 @@ if is_torch_available():
ViTModel,
)
if is_flax_available():
from transformers import FlaxVisionTextDualEncoderModel
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
if is_vision_available():
from PIL import Image
@ -172,69 +166,6 @@ class VisionTextDualEncoderMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mask, pixel_values, **kwargs):
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
pt_inputs = inputs_dict
flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**flax_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
pt_model = VisionTextDualEncoderModel(config)
fx_model = FlaxVisionTextDualEncoderModel(config)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
pt_model = VisionTextDualEncoderModel(config)
fx_model = FlaxVisionTextDualEncoderModel(config)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
def test_vision_text_dual_encoder_model(self):
inputs_dict = self.prepare_config_and_inputs()
self.check_vision_text_dual_encoder_model(**inputs_dict)
@ -255,17 +186,6 @@ class VisionTextDualEncoderMixin:
inputs_dict = self.prepare_config_and_inputs()
self.check_vision_text_output_attention(**inputs_dict)
@is_pt_flax_cross_test
def test_pt_flax_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
vision_config = config_inputs_dict.pop("vision_config")
text_config = config_inputs_dict.pop("text_config")
inputs_dict = config_inputs_dict
self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
self.check_equivalence_flax_to_pt(vision_config, text_config, inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2, inputs = self.get_pretrained_model_and_inputs()
@ -429,10 +349,6 @@ class DeiTRobertaModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
"text_choice_labels": choice_labels,
}
@unittest.skip(reason="DeiT is not available in Flax")
def test_pt_flax_equivalence(self):
pass
@require_torch
class CLIPVisionBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase):

View File

@ -24,9 +24,7 @@ from datasets import load_dataset
from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import (
CaptureLogger,
is_flaky,
is_librosa_available,
is_pt_flax_cross_test,
is_pyctcdecode_available,
require_flax,
require_librosa,
@ -350,11 +348,6 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
outputs = model(np.ones((1, 1024), dtype="f4"))
self.assertIsNotNone(outputs)
@is_pt_flax_cross_test
@is_flaky()
def test_equivalence_pt_to_flax(self):
super().test_equivalence_pt_to_flax()
@require_flax
class FlaxWav2Vec2UtilsTest(unittest.TestCase):

View File

@ -31,7 +31,6 @@ from transformers.testing_utils import (
CaptureLogger,
cleanup,
is_flaky,
is_pt_flax_cross_test,
is_pyctcdecode_available,
is_torchaudio_available,
require_flash_attn,
@ -569,16 +568,6 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
def test_model_get_set_embeddings(self):
pass
@is_pt_flax_cross_test
@unittest.skip(reason="Non-rubst architecture does not exist in Flax")
def test_equivalence_flax_to_pt(self):
pass
@is_pt_flax_cross_test
@unittest.skip(reason="Non-rubst architecture does not exist in Flax")
def test_equivalence_pt_to_flax(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True

View File

@ -21,7 +21,6 @@ from datasets import load_dataset
from transformers import Wav2Vec2BertConfig, is_torch_available
from transformers.testing_utils import (
is_pt_flax_cross_test,
require_torch,
require_torch_accelerator,
require_torch_fp16,
@ -559,18 +558,6 @@ class Wav2Vec2BertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
def test_model_get_set_embeddings(self):
pass
# Ignore copy
@unittest.skip(reason="non-robust architecture does not exist in Flax")
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
pass
# Ignore copy
@unittest.skip(reason="non-robust architecture does not exist in Flax")
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True

View File

@ -24,7 +24,6 @@ from datasets import load_dataset
from transformers import Wav2Vec2ConformerConfig, is_torch_available
from transformers.testing_utils import (
is_flaky,
is_pt_flax_cross_test,
require_torch,
require_torch_accelerator,
require_torch_fp16,
@ -535,16 +534,6 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest
def test_model_get_set_embeddings(self):
pass
@is_pt_flax_cross_test
@unittest.skip(reason="Non-robust architecture does not exist in Flax")
def test_equivalence_flax_to_pt(self):
pass
@is_pt_flax_cross_test
@unittest.skip(reason="Non-robust architecture does not exist in Flax")
def test_equivalence_pt_to_flax(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True

View File

@ -17,9 +17,8 @@ import inspect
import tempfile
import unittest
import transformers
from transformers import WhisperConfig, is_flax_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from transformers.testing_utils import require_flax, slow
from transformers.utils import cached_property
from transformers.utils.import_utils import is_datasets_available
@ -45,7 +44,6 @@ if is_flax_available():
WhisperFeatureExtractor,
WhisperProcessor,
)
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init
@ -245,99 +243,6 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
# We override with a slightly higher tol value, as the test recently became flaky
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
# overwrite because of `input_features`
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
for model_class in self.all_model_classes:
if model_class.__name__ == base_class.__name__:
continue
model = model_class(config)
model.params = model.to_bf16(model.params)
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite because of `input_features`
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
for model_class in self.all_model_classes:
if model_class.__name__ == base_class.__name__:
continue
model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
# save pt model
pt_model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite because of `input_features`
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)
for model_class in self.all_model_classes:
if model_class.__name__ == base_class.__name__:
continue
model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite because of `input_features`
def test_save_load_from_base(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@ -899,18 +804,3 @@ class WhisperEncoderModelTest(FlaxModelTesterMixin, unittest.TestCase):
# WhisperEncoder does not have any base model
def test_save_load_from_base(self):
pass
# WhisperEncoder does not have any base model
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
pass
# WhisperEncoder does not have any base model
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
pass
# WhisperEncoder does not have any base model
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
pass
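The bf16 variants removed here (`test_save_load_bf16_to_base_pt`) additionally cast the Flax parameters to bfloat16 before crossing frameworks. A minimal sketch of that step, again with illustrative BERT classes rather than the Whisper ones used above:

```python
import tempfile

from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict
from transformers import BertConfig, BertForMaskedLM, FlaxBertForMaskedLM, FlaxBertModel
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
head_model = FlaxBertForMaskedLM(config)
head_model.params = head_model.to_bf16(head_model.params)  # cast Flax params to bfloat16

# Load the bf16 Flax weights into a PyTorch head model, save it, then reload the Flax base model
pt_model = load_flax_weights_in_pytorch_model(BertForMaskedLM(config).eval(), head_model.params)
with tempfile.TemporaryDirectory() as tmpdirname:
    pt_model.save_pretrained(tmpdirname)
    base_model = FlaxBertModel.from_pretrained(tmpdirname, from_pt=True)

# Base encoder weights should survive the round trip
base_params = flatten_dict(unfreeze(base_model.params))
```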

View File

@ -32,7 +32,6 @@ import transformers
from transformers import WhisperConfig
from transformers.testing_utils import (
is_flaky,
is_pt_flax_cross_test,
require_flash_attn,
require_non_xpu,
require_torch,
@ -44,7 +43,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import cached_property, is_flax_available, is_torch_available, is_torchaudio_available
from transformers.utils import cached_property, is_torch_available, is_torchaudio_available
from transformers.utils.import_utils import is_datasets_available
from ...generation.test_utils import GenerationTesterMixin
@ -155,15 +154,6 @@ if is_torchaudio_available():
import torchaudio
if is_flax_available():
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
def prepare_whisper_inputs_dict(
config,
input_features,
@ -1069,161 +1059,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.assertTrue(models_equal)
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
# We override with a slightly higher tol value, as the test recently became flaky
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="No Flax model exists for this class")
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
fx_model_class = getattr(transformers, fx_model_class_name)
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
# load Flax class
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
# make sure that only Flax inputs that actually exist in the function args are forwarded
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}
# convert inputs to Flax
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="No Flax model exists for this class")
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
fx_model_class = getattr(transformers, fx_model_class_name)
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
# load Flax class
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
# make sure that only Flax inputs that actually exist in the function args are forwarded
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}
# convert inputs to Flax
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
def test_mask_feature_prob(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.mask_feature_prob = 0.2
@ -3622,157 +3457,6 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
def test_resize_tokens_embeddings(self):
pass
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="Flax model does not exist")
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
fx_model_class = getattr(transformers, fx_model_class_name)
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
# load Flax class
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
# make sure that only Flax inputs that actually exist in the function args are forwarded
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}
# convert inputs to Flax
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
self.skipTest("Flax model does not exist")
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
fx_model_class = getattr(transformers, fx_model_class_name)
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
# load Flax class
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
# make sure that only Flax inputs that actually exist in the function args are forwarded
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}
# convert inputs to Flax
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
class WhisperStandaloneDecoderModelTester:
def __init__(

View File

@ -14,12 +14,10 @@
# limitations under the License.
import tempfile
import unittest
import transformers
from transformers import XGLMConfig, XGLMTokenizer, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_sentencepiece, slow
from transformers import XGLMConfig, XGLMTokenizer, is_flax_available
from transformers.testing_utils import require_flax, require_sentencepiece, slow
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
@ -29,17 +27,9 @@ if is_flax_available():
import jax.numpy as jnp
import numpy as np
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.xglm.modeling_flax_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel
if is_torch_available():
import torch
@require_flax
class FlaxXGLMModelTester:
def __init__(
@ -220,108 +210,6 @@ class FlaxXGLMModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertListEqual(output_string, expected_string)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slightly differently
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
pt_model = pt_model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slightly differently
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
pt_model.config.use_cache = False
fx_model = model_class(config, dtype=jnp.float32)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
# make sure weights are tied in PyTorch
pt_model.tie_weights()
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
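For context, the removed equivalence tests above lean on the PyTorch/Flax conversion helpers, which this commit keeps in the library. Below is a minimal standalone sketch of that round trip, using a tiny randomly initialised BERT purely as an illustrative stand-in for XGLM; the model, sizes and tolerance are assumptions, not taken from the test file:

import numpy as np
import torch
from transformers import BertConfig, BertModel, FlaxBertModel
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

# tiny config so the example runs quickly; the values are arbitrary
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
pt_model = BertModel(config).eval()

# copy the randomly initialised PyTorch weights into a Flax model of the same architecture
fx_model = FlaxBertModel(config)
fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

input_ids = np.arange(10, dtype="i4").reshape(1, -1)
with torch.no_grad():
    pt_out = pt_model(torch.tensor(input_ids, dtype=torch.long)).last_hidden_state.numpy()
fx_out = np.asarray(fx_model(input_ids).last_hidden_state)
print(np.abs(pt_out - fx_out).max())  # expected to be small, e.g. below ~1e-4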
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:

View File

@ -32,7 +32,6 @@ from packaging import version
from parameterized import parameterized
from pytest import mark
import transformers
from transformers import (
AutoModel,
AutoModelForCausalLM,
@ -75,7 +74,6 @@ from transformers.models.auto.modeling_auto import (
from transformers.testing_utils import (
CaptureLogger,
is_flaky,
is_pt_flax_cross_test,
require_accelerate,
require_bitsandbytes,
require_deepspeed,
@ -100,14 +98,12 @@ from transformers.utils import (
GENERATION_CONFIG_NAME,
SAFE_WEIGHTS_NAME,
is_accelerate_available,
is_flax_available,
is_tf_available,
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
is_torch_fx_available,
is_torch_sdpa_available,
)
from transformers.utils.generic import ContextManagers, ModelOutput
from transformers.utils.generic import ContextManagers
if is_accelerate_available():
@ -126,19 +122,6 @@ if is_torch_available():
from transformers.modeling_utils import load_state_dict, no_init_weights
from transformers.pytorch_utils import id_tensor_storage
if is_tf_available():
pass
if is_flax_available():
import jax.numpy as jnp
from tests.utils.test_modeling_flax_utils import check_models_equal
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
if is_torch_fx_available():
from transformers.utils.fx import _FX_SUPPORTED_MODELS_WITH_KV_CACHE, symbolic_trace
@ -2552,249 +2535,6 @@ class ModelTesterMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
"""
Args:
model_class: The class of the model that is currently being tested. For example, ..., etc.
Currently unused, but it could make debugging easier and faster.
name: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
self.assertEqual(type(name), str)
if attributes is not None:
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
if isinstance(fx_outputs, ModelOutput):
self.assertTrue(
isinstance(pt_outputs, ModelOutput),
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is",
)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")
# convert to the case of `tuple`
# appending each key to the current (string) `name`
attributes = tuple([f"{name}.{k}" for k in fx_keys])
self.check_pt_flax_outputs(
fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
)
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
elif type(fx_outputs) in [tuple, list]:
self.assertEqual(
type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
)
self.assertEqual(
len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
)
if attributes is not None:
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
self.assertEqual(
len(attributes),
len(fx_outputs),
f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
)
else:
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
elif isinstance(fx_outputs, jnp.ndarray):
self.assertTrue(
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is"
)
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
fx_outputs = np.array(fx_outputs)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
self.assertEqual(
fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
)
# deal with NumPy's scalars to make replacing nan values by 0 work.
if np.isscalar(fx_outputs):
fx_outputs = np.array([fx_outputs])
pt_outputs = np.array([pt_outputs])
fx_nans = np.isnan(fx_outputs)
pt_nans = np.isnan(pt_outputs)
pt_outputs[fx_nans] = 0
fx_outputs[fx_nans] = 0
pt_outputs[pt_nans] = 0
fx_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
self.assertLessEqual(
max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
)
else:
raise ValueError(
"`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
f" {type(fx_outputs)} instead."
)
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="No Flax model exists for this class")
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
fx_model_class = getattr(transformers, fx_model_class_name)
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and the cache is not returned by default.
# So we disable `use_cache` here for the PyTorch model.
pt_model.config.use_cache = False
# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only Flax inputs that actually exist in the function args are forwarded
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}
# convert inputs to Flax
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__
if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="No Flax model exists for this class")
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
fx_model_class = getattr(transformers, fx_model_class_name)
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and the cache is not returned by default.
# So we disable `use_cache` here for the PyTorch model.
pt_model.config.use_cache = False
# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only Flax inputs that actually exist in the function args are forwarded
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}
# convert inputs to Flax
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
)
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -4413,29 +4153,6 @@ class ModelTesterMixin:
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
@is_pt_flax_cross_test
def test_flax_from_pt_safetensors(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
flax_model_class_name = "Flax" + model_class.__name__ # Add the "Flax at the beginning
if not hasattr(transformers, flax_model_class_name):
self.skipTest(reason="transformers does not have this model in Flax version yet")
flax_model_class = getattr(transformers, flax_model_class_name)
pt_model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname, safe_serialization=True)
flax_model_1 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)
pt_model.save_pretrained(tmpdirname, safe_serialization=False)
flax_model_2 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)
# Check models are equal
self.assertTrue(check_models_equal(flax_model_1, flax_model_2))
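The safetensors cross-loading path exercised by the test above can also be driven directly through `from_pretrained(..., from_pt=True)`. A minimal sketch, assuming torch, flax and safetensors are all installed, and using a tiny random BERT purely for illustration:

import tempfile
from transformers import BertConfig, BertModel, FlaxBertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
pt_model = BertModel(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    # safe_serialization=True writes model.safetensors rather than pytorch_model.bin
    pt_model.save_pretrained(tmp_dir, safe_serialization=True)
    # from_pt=True converts the PyTorch checkpoint to Flax parameters on the fly
    fx_model = FlaxBertModel.from_pretrained(tmp_dir, from_pt=True)

print(type(fx_model).__name__)  # FlaxBertModel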
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test

View File

@ -21,13 +21,10 @@ from typing import List, Tuple
import numpy as np
import transformers
from transformers import is_flax_available, is_torch_available
from transformers.cache_utils import DynamicCache
from transformers import is_flax_available
from transformers.models.auto import get_values
from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, torch_device
from transformers.testing_utils import CaptureLogger, require_flax
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
from transformers.utils.generic import ModelOutput
if is_flax_available():
@ -47,17 +44,10 @@ if is_flax_available():
FlaxAutoModelForSequenceClassification,
FlaxBertModel,
)
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.modeling_flax_utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
if is_torch_available():
import torch
def ids_tensor(shape, vocab_size, rng=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
@ -184,216 +174,6 @@ class FlaxModelTesterMixin:
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
# (Copied from tests.test_modeling_common.ModelTesterMixin.check_pt_flax_outputs)
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
"""
Args:
model_class: The class of the model that is currently being tested. For example, ..., etc.
Currently unused, but it could make debugging easier and faster.
name: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
self.assertEqual(type(name), str)
if attributes is not None:
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
if isinstance(fx_outputs, ModelOutput):
self.assertTrue(
isinstance(pt_outputs, ModelOutput),
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is",
)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")
# convert to the case of `tuple`
# appending each key to the current (string) `name`
attributes = tuple([f"{name}.{k}" for k in fx_keys])
self.check_pt_flax_outputs(
fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
)
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
elif type(fx_outputs) in [tuple, list]:
self.assertEqual(
type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
)
self.assertEqual(
len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
)
if attributes is not None:
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
self.assertEqual(
len(attributes),
len(fx_outputs),
f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
)
else:
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
if isinstance(pt_output, DynamicCache):
pt_output = pt_output.to_legacy_cache()
self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
elif isinstance(fx_outputs, jnp.ndarray):
self.assertTrue(
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is"
)
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
fx_outputs = np.array(fx_outputs)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
self.assertEqual(
fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
)
# deal with NumPy's scalars to make replacing nan values by 0 work.
if np.isscalar(fx_outputs):
fx_outputs = np.array([fx_outputs])
pt_outputs = np.array([pt_outputs])
fx_nans = np.isnan(fx_outputs)
pt_nans = np.isnan(pt_outputs)
pt_outputs[fx_nans] = 0
fx_outputs[fx_nans] = 0
pt_outputs[pt_nans] = 0
fx_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
self.assertLessEqual(
max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
)
else:
raise ValueError(
"`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
f" {type(fx_outputs)} instead."
)
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
# It might be better to put this inside the for loop below (because we modify the config there).
# But logically, it is fine.
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
# Flax models don't use the `use_cache` option and the cache is not returned by default.
# So we disable `use_cache` here for the PyTorch model.
pt_model.config.use_cache = False
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**prepared_inputs_dict)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
# Flax models don't use the `use_cache` option and the cache is not returned by default.
# So we disable `use_cache` here for the PyTorch model.
pt_model.config.use_cache = False
fx_model = model_class(config, dtype=jnp.float32)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**prepared_inputs_dict)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
)
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -474,92 +254,6 @@ class FlaxModelTesterMixin:
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = base_class(config)
base_params = get_params(model.params)
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
# save pt model
pt_model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = get_params(base_model.params)
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
model.params = model.to_bf16(model.params)
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = get_params(base_model.params)
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -1119,14 +813,6 @@ class FlaxModelTesterMixin:
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
@is_pt_flax_cross_test
def test_from_sharded_pt(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")
for key, ref_val in flatten_dict(ref_model.params).items():
val = flatten_dict(model.params)[key]
assert np.allclose(np.array(val), np.array(ref_val))
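The removed `*_pt` tests above also rely on the reverse direction, loading Flax parameters into a PyTorch module. A minimal sketch with a tiny random BERT; the sizes and the expected tolerance are illustrative assumptions:

import numpy as np
import torch
from transformers import BertConfig, BertModel, FlaxBertModel
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
fx_model = FlaxBertModel(config)
pt_model = BertModel(config).eval()

# overwrite the randomly initialised PyTorch weights with the Flax parameters
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

input_ids = np.arange(8, dtype="i4").reshape(1, -1)
fx_out = np.asarray(fx_model(input_ids).last_hidden_state)
with torch.no_grad():
    pt_out = pt_model(torch.tensor(input_ids, dtype=torch.long)).last_hidden_state.numpy()
print(np.abs(fx_out - pt_out).max())  # expected to be small once the weights match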
def test_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -31,13 +31,11 @@ from datasets import Dataset
from transformers import is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import ( # noqa: F401
from transformers.testing_utils import (
CaptureLogger,
_tf_gpu_memory_limit,
require_tf,
require_tf2onnx,
slow,
torch_device,
)
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
from transformers.utils.generic import ModelOutput
@ -73,20 +71,6 @@ if is_tf_available():
tf.config.experimental.enable_tensor_float_32_execution(False)
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
# Restrict TensorFlow to only allocate x GB of memory on the GPUs
try:
tf.config.set_logical_device_configuration(
gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
)
logical_gpus = tf.config.list_logical_devices("GPU")
print("Logical GPUs", logical_gpus)
except RuntimeError as e:
# Virtual devices must be set before GPUs have been initialized
print(e)
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)

View File

@ -18,16 +18,14 @@ import unittest
import numpy as np
from huggingface_hub import HfFolder, snapshot_download
from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
from transformers import BertConfig, is_flax_available
from transformers.testing_utils import (
TOKEN,
CaptureLogger,
TemporaryHubRepo,
is_pt_flax_cross_test,
is_staging_test,
require_flax,
require_safetensors,
require_torch,
)
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging
@ -42,9 +40,6 @@ if is_flax_available():
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
if is_torch_available():
import torch
@require_flax
@is_staging_test
@ -205,23 +200,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
self.assertTrue(check_models_equal(model, new_model))
@require_flax
@require_torch
@is_pt_flax_cross_test
def test_safetensors_save_and_load_pt_to_flax(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
pt_model.save_pretrained(tmp_dir)
# Check we have a model.safetensors file
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
new_model = FlaxBertModel.from_pretrained(tmp_dir)
# Check models are equal
self.assertTrue(check_models_equal(model, new_model))
@require_safetensors
def test_safetensors_load_from_hub(self):
"""
@ -248,58 +226,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
@is_pt_flax_cross_test
def test_safetensors_load_from_hub_from_safetensors_pt(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub,
saved in the "pt" format.
"""
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-msgpack")
# Can load from the PyTorch-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
@require_torch
@is_pt_flax_cross_test
def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub,
saved in the "pt" format.
"""
import torch
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
model.to(torch.bfloat16)
with tempfile.TemporaryDirectory() as tmp:
model.save_pretrained(tmp)
flax_model = FlaxBertModel.from_pretrained(tmp)
# Can load from the PyTorch-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
@is_pt_flax_cross_test
def test_safetensors_load_from_local_from_safetensors_pt(self):
"""
This test checks that we can load safetensors from a local copy of a checkpoint that only has those,
saved in the "pt" format.
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
flax_model = FlaxBertModel.from_pretrained(location)
# Can load from the PyTorch-formatted checkpoint
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
safetensors_model = FlaxBertModel.from_pretrained(location)
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
"""
@ -328,19 +254,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
self.assertTrue(check_models_equal(model, new_model))
@require_safetensors
@require_torch
@is_pt_flax_cross_test
def test_safetensors_flax_from_torch(self):
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
new_model = FlaxBertModel.from_pretrained(tmp_dir)
self.assertTrue(check_models_equal(hub_model, new_model))
@require_safetensors
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_local(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@ -370,27 +283,3 @@ class FlaxModelUtilsTest(unittest.TestCase):
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
in cl.out
)
@require_torch
@require_safetensors
@is_pt_flax_cross_test
def test_from_pt_bf16(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
model.to(torch.bfloat16)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=False)
logger = logging.get_logger("transformers.modeling_flax_utils")
with CaptureLogger(logger) as cl:
new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
self.assertTrue(
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
in cl.out
)
flat_params_1 = flatten_dict(new_model.params)
for value in flat_params_1.values():
self.assertEqual(value.dtype, "bfloat16")
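The assertions in this file compare models through `check_models_equal`, whose definition is outside the lines shown here. As a rough sketch of what such a comparison can look like (an assumed equivalent for illustration, not the helper's actual code):

import numpy as np
from flax.traverse_util import flatten_dict

def params_close(model_1, model_2, atol=1e-6):
    # flatten the nested Flax parameter dicts so the leaves can be compared key by key
    p1, p2 = flatten_dict(model_1.params), flatten_dict(model_2.params)
    if p1.keys() != p2.keys():
        return False
    return all(np.allclose(np.array(p1[k]), np.array(p2[k]), atol=atol) for k in p1)

# usage: params_close(flax_model_1, flax_model_2) returns True when every parameter leaf matches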

View File

@ -24,7 +24,7 @@ from math import isnan
from transformers import is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
from transformers.testing_utils import require_tf, slow
from ..test_modeling_tf_common import ids_tensor
@ -48,20 +48,6 @@ if is_tf_available():
)
from transformers.modeling_tf_utils import keras
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
# Restrict TensorFlow to only allocate x GB of memory on the GPUs
try:
tf.config.set_logical_device_configuration(
gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
)
logical_gpus = tf.config.list_logical_devices("GPU")
print("Logical GPUs", logical_gpus)
except RuntimeError as e:
# Virtual devices must be set before GPUs have been initialized
print(e)
@require_tf
class TFCoreModelTesterMixin:

View File

@ -33,7 +33,6 @@ from transformers.testing_utils import ( # noqa: F401
USER,
CaptureLogger,
TemporaryHubRepo,
_tf_gpu_memory_limit,
is_staging_test,
require_safetensors,
require_tf,
@ -68,20 +67,6 @@ if is_tf_available():
tf.config.experimental.enable_tensor_float_32_execution(False)
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
# Restrict TensorFlow to only allocate x GB of memory on the GPUs
try:
tf.config.set_logical_device_configuration(
gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
)
logical_gpus = tf.config.list_logical_devices("GPU")
print("Logical GPUs", logical_gpus)
except RuntimeError as e:
# Virtual devices must be set before GPUs have been initialized
print(e)
@require_tf
class TFModelUtilsTest(unittest.TestCase):

View File

@ -1148,7 +1148,6 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:
JOB_TO_TEST_FILE = {
"tests_torch_and_flax": r"tests/models/.*/test_modeling_(?:flax|(?!tf)).*",
"tests_tf": r"tests/models/.*/test_modeling_tf_.*",
"tests_torch": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*",
"tests_generate": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*",