mirror of https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[tests] remove flax-pt equivalence and cross tests (#36283)
This commit is contained in:
parent fa8cdccd91
commit 99adc74462
@@ -28,7 +28,6 @@ COMMON_ENV_VARIABLES = {
"TRANSFORMERS_IS_CI": True,
"PYTEST_TIMEOUT": 120,
"RUN_PIPELINE_TESTS": False,
"RUN_PT_FLAX_CROSS_TESTS": False,
}
# Disable the use of {"s": None} as the output is way too long, causing the navigation on CircleCI impractical
COMMON_PYTEST_OPTIONS = {"max-worker-restart": 0, "dist": "loadfile", "vvv": None, "rsfE":None}
@@ -176,14 +175,6 @@ class CircleCIJob:

# JOBS
torch_and_flax_job = CircleCIJob(
"torch_and_flax",
additional_env={"RUN_PT_FLAX_CROSS_TESTS": True},
docker_image=[{"image":"huggingface/transformers-torch-jax-light"}],
marker="is_pt_flax_cross_test",
pytest_options={"rA": None, "durations": 0},
)

torch_job = CircleCIJob(
"torch",
docker_image=[{"image": "huggingface/transformers-torch-light"}],
@@ -343,7 +334,7 @@ doc_test_job = CircleCIJob(
pytest_num_workers=1,
)

REGULAR_TESTS = [torch_and_flax_job, torch_job, tf_job, flax_job, hub_job, onnx_job, tokenization_job, processor_job, generate_job, non_model_job] # fmt: skip
REGULAR_TESTS = [torch_job, tf_job, flax_job, hub_job, onnx_job, tokenization_job, processor_job, generate_job, non_model_job] # fmt: skip
EXAMPLES_TESTS = [examples_torch_job, examples_tensorflow_job]
PIPELINE_TESTS = [pipelines_torch_job, pipelines_tf_job]
REPO_UTIL_TESTS = [repo_utils_job]
6 .github/workflows/build-ci-docker-images.yml (vendored)
@@ -26,7 +26,7 @@ jobs:

strategy:
matrix:
file: ["quality", "consistency", "custom-tokenizers", "torch-light", "tf-light", "exotic-models", "torch-tf-light", "torch-jax-light", "jax-light", "examples-torch", "examples-tf"]
file: ["quality", "consistency", "custom-tokenizers", "torch-light", "tf-light", "exotic-models", "torch-tf-light", "jax-light", "examples-torch", "examples-tf"]
continue-on-error: true

steps:
@@ -34,11 +34,11 @@ jobs:
name: Set tag
run: |
if ${{contains(github.event.head_commit.message, '[build-ci-image]')}}; then
echo "TAG=huggingface/transformers-${{ matrix.file }}:dev" >> "$GITHUB_ENV"
echo "TAG=huggingface/transformers-${{ matrix.file }}:dev" >> "$GITHUB_ENV"
echo "setting it to DEV!"
else
echo "TAG=huggingface/transformers-${{ matrix.file }}" >> "$GITHUB_ENV"

fi
-
name: Set up Docker Buildx
@@ -343,7 +343,6 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/t

Like the slow tests, there are other environment variables available which are not enabled by default during testing:
- `RUN_CUSTOM_TOKENIZERS`: Enables tests for custom tokenizers.
- `RUN_PT_FLAX_CROSS_TESTS`: Enables tests for PyTorch + Flax integration.

More environment variables and additional information can be found in the [testing_utils.py](https://github.com/huggingface/transformers/blob/main/src/transformers/testing_utils.py).
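For readers skimming this diff, a minimal sketch of how such a flag typically gates a test. The helper mirrors the parse_flag_from_env pattern shown further down in this diff; the decorator and test names here are illustrative, not part of the library:

import os
import unittest

def parse_flag_from_env(key, default=False):
    # Treat "yes"/"true"/"1" (any case) as truthy; anything else as falsy.
    return os.environ.get(key, str(default)).lower() in {"yes", "true", "1"}

_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)

def require_custom_tokenizers(test_case):
    # Skip the decorated test unless RUN_CUSTOM_TOKENIZERS is set to a truthy value.
    return unittest.skipUnless(_run_custom_tokenizers, "test requires RUN_CUSTOM_TOKENIZERS=yes")(test_case)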
@@ -84,9 +84,6 @@ warnings.simplefilter(action="ignore", category=FutureWarning)

def pytest_configure(config):
config.addinivalue_line(
"markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
)
config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested")
config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
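A hedged sketch of how a marker registered above is then used: a test opts in with pytest.mark, and a CI job selects that group with -m on the command line (the test name below is illustrative):

import pytest

@pytest.mark.is_staging_test  # marker registered via config.addinivalue_line above
def test_push_to_hub_on_staging():
    # CI selects this group with: python -m pytest -m is_staging_test tests/
    ...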
@@ -283,7 +283,6 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/t
Wie bei den langsamen Tests gibt es auch andere Umgebungsvariablen, die standardmäßig beim Testen nicht gesetzt sind:

* `RUN_CUSTOM_TOKENIZERS`: Aktiviert Tests für benutzerdefinierte Tokenizer.
* `RUN_PT_FLAX_CROSS_TESTS`: Aktiviert Tests für die Integration von PyTorch + Flax.

Weitere Umgebungsvariablen und zusätzliche Informationen finden Sie in der [testing_utils.py](src/transformers/testing_utils.py).

@@ -282,7 +282,6 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/t

느린 테스트와 마찬가지로, 다음과 같이 테스트 중에 기본적으로 활성화되지 않는 다른 환경 변수도 있습니다:
- `RUN_CUSTOM_TOKENIZERS`: 사용자 정의 토크나이저 테스트를 활성화합니다.
- `RUN_PT_FLAX_CROSS_TESTS`: PyTorch + Flax 통합 테스트를 활성화합니다.

더 많은 환경 변수와 추가 정보는 [testing_utils.py](src/transformers/testing_utils.py)에서 찾을 수 있습니다.

@@ -281,7 +281,6 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/t

和时间较长的测试一样,还有其他环境变量在测试过程中,在默认情况下是未启用的:
- `RUN_CUSTOM_TOKENIZERS`: 启用自定义分词器的测试。
- `RUN_PT_FLAX_CROSS_TESTS`: 启用 PyTorch + Flax 整合的测试。

更多环境变量和额外信息可以在 [testing_utils.py](src/transformers/testing_utils.py) 中找到。
1 setup.py
@@ -473,7 +473,6 @@ setup(
extras["tests_torch"] = deps_list()
extras["tests_tf"] = deps_list()
extras["tests_flax"] = deps_list()
extras["tests_torch_and_flax"] = deps_list()
extras["tests_hub"] = deps_list()
extras["tests_pipelines_torch"] = deps_list()
extras["tests_pipelines_tf"] = deps_list()
@@ -230,10 +230,8 @@ def parse_int_from_env(key, default=None):

_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=True)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
@@ -250,25 +248,6 @@ def get_device_count():
return num_devices

def is_pt_flax_cross_test(test_case):
"""
Decorator marking a test as a test that control interactions between PyTorch and Flax

PT+FLAX tests are skipped by default and we can run only them by setting RUN_PT_FLAX_CROSS_TESTS environment
variable to a truthy value and selecting the is_pt_flax_cross_test pytest mark.

"""
if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
return unittest.skip(reason="test is PT+FLAX test")(test_case)
else:
try:
import pytest # We don't need a hard dependency on pytest in the main library
except ImportError:
return test_case
else:
return pytest.mark.is_pt_flax_cross_test()(test_case)
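# A hedged, illustrative sketch of how the decorator removed above was applied in model
# tests (the method name is representative of the tests deleted later in this diff):
#
#     @is_pt_flax_cross_test
#     def test_equivalence_pt_to_flax(self):
#         ...
#
# Such a test only ran when RUN_PT_FLAX_CROSS_TESTS was truthy and the
# is_pt_flax_cross_test marker was selected.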

def is_staging_test(test_case):
"""
Decorator marking a test as a staging test.

@@ -624,15 +624,6 @@ class BigBirdModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass

# overwrite from common in order to skip the check on `attentions`
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
# an effort was done to return `attention_probs` (yet to be verified).
if name.startswith("outputs.attentions"):
return
else:
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)

@require_torch
@slow
@@ -212,12 +212,3 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)

# overwrite from common in order to skip the check on `attentions`
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
# an effort was done to return `attention_probs` (yet to be verified).
if name.startswith("outputs.attentions"):
return
else:
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
@@ -25,11 +25,8 @@ import requests
from parameterized import parameterized
from pytest import mark

import transformers
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.testing_utils import (
is_flax_available,
is_pt_flax_cross_test,
require_flash_attn,
require_torch,
require_torch_gpu,
@@ -82,15 +79,6 @@ if is_vision_available():
from transformers import CLIPProcessor

if is_flax_available():
import jax.numpy as jnp

from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)

class CLIPVisionModelTester:
def __init__(
self,
@@ -883,126 +871,6 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
text_config = CLIPTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load PyTorch class
pt_model = model_class(config).eval()
pt_model.to(torch_device)
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False

fx_model_class_name = "Flax" + model_class.__name__

if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="No Flax model exists for this class")

fx_model_class = getattr(transformers, fx_model_class_name)

# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)

# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

# convert inputs to Flax
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)

fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)

# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load corresponding PyTorch class
pt_model = model_class(config).eval()

# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False

fx_model_class_name = "Flax" + model_class.__name__

if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="No Flax model exists for this class")

fx_model_class = getattr(transformers, fx_model_class_name)

# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
pt_model.to(torch_device)

# make sure weights are tied in PyTorch
pt_model.tie_weights()

# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)

# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}

fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")

for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)

with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

@slow
def test_model_from_pretrained(self):
model_name = "openai/clip-vit-base-patch32"
@@ -4,21 +4,15 @@ import unittest

import numpy as np

import transformers
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig, is_flax_available
from transformers.testing_utils import require_flax, slow

from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask

if is_flax_available():
import jax
import jax.numpy as jnp

from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.clip.modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPTextModel,
@@ -26,9 +20,6 @@ if is_flax_available():
FlaxCLIPVisionModel,
)

if is_torch_available():
import torch

class FlaxCLIPVisionModelTester:
def __init__(
@@ -223,21 +214,6 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_save_load_to_base(self):
pass

# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
pass

# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
pass

# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
pass

@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
@@ -333,21 +309,6 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_save_load_to_base(self):
pass

# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
pass

# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
pass

# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
pass

@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
@@ -472,92 +433,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
outputs = model(input_ids=np.ones((1, 1)), pixel_values=np.ones((1, 3, 224, 224)))
self.assertIsNotNone(outputs)

# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)

pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)

fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)

fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)

# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)

pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)

pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

# make sure weights are tied in PyTorch
pt_model.tie_weights()

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
def test_from_pretrained_save_pretrained(self):
@@ -22,11 +22,8 @@ import unittest
import numpy as np
import requests

import transformers
from transformers import CLIPSegConfig, CLIPSegProcessor, CLIPSegTextConfig, CLIPSegVisionConfig
from transformers.testing_utils import (
is_flax_available,
is_pt_flax_cross_test,
require_torch,
require_vision,
slow,
@@ -57,15 +54,6 @@ if is_vision_available():
from PIL import Image

if is_flax_available():
import jax.numpy as jnp

from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)

class CLIPSegVisionModelTester:
def __init__(
self,
@@ -635,123 +623,6 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
text_config = CLIPSegTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

# overwrite from common since FlaxCLIPSegModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False

fx_model_class_name = "Flax" + model_class.__name__

if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="No Flax model exists for this class")

fx_model_class = getattr(transformers, fx_model_class_name)

# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)

# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

# convert inputs to Flax
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)

fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)

# overwrite from common since FlaxCLIPSegModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# load corresponding PyTorch class
pt_model = model_class(config).eval()

# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False

fx_model_class_name = "Flax" + model_class.__name__

if not hasattr(transformers, fx_model_class_name):
self.skipTest(reason="No Flax model exists for this class")

fx_model_class = getattr(transformers, fx_model_class_name)

# load Flax class
fx_model = fx_model_class(config, dtype=jnp.float32)
# make sure only flax inputs are forward that actually exist in function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

# make sure weights are tied in PyTorch
pt_model.tie_weights()

# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)

# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}

fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")

for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)

with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

def test_training(self):
if not self.model_tester.is_training:
self.skipTest(reason="Training test is skipped as the model was not trained")
@@ -22,7 +22,7 @@ from datasets import load_dataset

from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import Data2VecAudioConfig, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_soundfile, require_torch, slow, torch_device
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init
@@ -442,16 +442,6 @@ class Data2VecAudioModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
def test_model_get_set_embeddings(self):
pass

@is_pt_flax_cross_test
# non-robust architecture does not exist in Flax
def test_equivalence_flax_to_pt(self):
pass

@is_pt_flax_cross_test
# non-robust architecture does not exist in Flax
def test_equivalence_pt_to_flax(self):
pass

def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
@@ -19,8 +19,8 @@ import unittest

import numpy as np

from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
from transformers import is_flax_available
from transformers.testing_utils import require_flax, slow

from ...test_modeling_flax_common import ids_tensor
from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester
@@ -38,15 +38,6 @@ if is_flax_available():
FlaxEncoderDecoderModel,
FlaxGPT2LMHeadModel,
)
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)

if is_torch_available():
import torch

from transformers import EncoderDecoderModel

@require_flax
@@ -291,68 +282,6 @@ class FlaxEncoderDecoderMixin:
generated_sequences = generated_output.sequences
self.assertEqual(generated_sequences.shape, (input_ids.shape[0],) + (decoder_config.max_length,))

def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
pt_model.to(torch_device)
pt_model.eval()

# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)

# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)

fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)

# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = EncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)

pt_model_loaded.to(torch_device)
pt_model_loaded.eval()

with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)

def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

pt_model = EncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxEncoderDecoderModel(encoder_decoder_config)

fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state

self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)

def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

pt_model = EncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxEncoderDecoderModel(encoder_decoder_config)

pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)

def test_encoder_decoder_model_from_pretrained_configs(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
@@ -385,40 +314,6 @@ class FlaxEncoderDecoderMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")

@is_pt_flax_cross_test
def test_pt_flax_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")

inputs_dict = config_inputs_dict
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]

# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = np.concatenate(
[np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1
)

# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
decoder_config.use_cache = False

self.assertTrue(decoder_config.cross_attention_hidden_size is None)

# check without `enc_to_dec_proj` projection
decoder_config.hidden_size = config.hidden_size
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)

# check `enc_to_dec_proj` work as expected
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)

@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
@@ -13,14 +13,12 @@
# limitations under the License.

import tempfile
import unittest

import numpy as np

import transformers
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available
from transformers.testing_utils import require_flax, slow

from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask

@@ -29,15 +27,8 @@ if is_flax_available():
import jax
import jax.numpy as jnp

from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model

if is_torch_available():
import torch

class FlaxGPT2ModelTester:
def __init__(
@@ -255,105 +246,6 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):

self.assertListEqual(output_string, expected_string)

# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)

batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)

fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)

fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)

# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)

pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)

pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1

# make sure weights are tied in PyTorch
pt_model.tie_weights()

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
@@ -13,14 +13,12 @@
# limitations under the License.

import tempfile
import unittest

import numpy as np

import transformers
from transformers import GPT2Tokenizer, GPTNeoConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from transformers import GPT2Tokenizer, GPTNeoConfig, is_flax_available
from transformers.testing_utils import require_flax, slow

from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask

@@ -29,15 +27,8 @@ if is_flax_available():
import jax
import jax.numpy as jnp

from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel

if is_torch_available():
import torch

class FlaxGPTNeoModelTester:
def __init__(
@@ -224,105 +215,6 @@ class FlaxGPTNeoModelTest(FlaxModelTesterMixin, unittest.TestCase):

self.assertListEqual(output_string, expected_string)

# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)

batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)

fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)

fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)

# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)

pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)

pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1

# make sure weights are tied in PyTorch
pt_model.tie_weights()

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
@@ -13,14 +13,12 @@
# limitations under the License.

import tempfile
import unittest

import numpy as np

import transformers
from transformers import GPT2Tokenizer, GPTJConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, tooslow
from transformers import GPT2Tokenizer, GPTJConfig, is_flax_available
from transformers.testing_utils import require_flax, tooslow

from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask

@@ -29,15 +27,8 @@ if is_flax_available():
import jax
import jax.numpy as jnp

from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.gptj.modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel

if is_torch_available():
import torch

class FlaxGPTJModelTester:
def __init__(
@@ -221,105 +212,6 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, unittest.TestCase):

self.assertListEqual(output_string, expected_string)

# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)

batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)

fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)

fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)

# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)

pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)

pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1

# make sure weights are tied in PyTorch
pt_model.tie_weights()

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

@tooslow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
@ -17,11 +17,9 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from transformers import is_flax_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
is_pt_flax_cross_test,
|
||||
require_flax,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
@ -46,7 +44,6 @@ if is_flax_available():
|
||||
from flax.traverse_util import flatten_dict
|
||||
|
||||
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING, AutoTokenizer, LongT5Config
|
||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||
from transformers.models.longt5.modeling_flax_longt5 import (
|
||||
FlaxLongT5ForConditionalGeneration,
|
||||
FlaxLongT5Model,
@ -467,95 +464,6 @@ class FlaxLongT5ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
[self.model_tester.num_attention_heads, block_len, 3 * block_len],
|
||||
)
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_from_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# save pt model
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_param_from_head = flatten_dict(unfreeze(head_model.params))
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
|
||||
class FlaxLongT5TGlobalModelTest(FlaxLongT5ModelTest):
|
||||
def setUp(self):
@ -19,7 +19,7 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from ...test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester
@ -43,14 +43,8 @@ if is_flax_available():
|
||||
SpeechEncoderDecoderConfig,
|
||||
)
|
||||
from transformers.modeling_flax_outputs import FlaxBaseModelOutput
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import SpeechEncoderDecoderModel
|
||||
|
@ -406,68 +400,6 @@ class FlaxEncoderDecoderMixin:
|
||||
for grad, grad_frozen in zip(grads, grads_frozen):
|
||||
self.assertTrue((grad == grad_frozen).all())
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
|
||||
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
fx_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
|
||||
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
fx_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def test_encoder_decoder_model_from_pretrained_configs(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
@ -504,46 +436,6 @@ class FlaxEncoderDecoderMixin:
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
config = config_inputs_dict.pop("config")
|
||||
decoder_config = config_inputs_dict.pop("decoder_config")
|
||||
|
||||
inputs_dict = config_inputs_dict
|
||||
# `encoder_hidden_states` is not used in model call/forward
|
||||
del inputs_dict["encoder_hidden_states"]
|
||||
|
||||
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
||||
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
||||
inputs_dict["decoder_attention_mask"] = np.concatenate(
|
||||
[np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1
|
||||
)
|
||||
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
decoder_config.use_cache = False
|
||||
|
||||
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
|
||||
|
||||
# check without `enc_to_dec_proj` projection
|
||||
decoder_config.hidden_size = config.hidden_size
|
||||
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
|
||||
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
# check `enc_to_dec_proj` work as expected
|
||||
decoder_config.hidden_size = decoder_config.hidden_size * 2
|
||||
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
|
||||
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
# check `add_adapter` works as expected
|
||||
config.add_adapter = True
|
||||
self.assertTrue(config.add_adapter)
|
||||
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2 = self.get_pretrained_model()
@ -625,71 +517,6 @@ class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
@slow
|
||||
def test_flaxwav2vec2gpt2_pt_flax_equivalence(self):
|
||||
pt_model = SpeechEncoderDecoderModel.from_pretrained("jsnfly/wav2vec2-large-xlsr-53-german-gpt2")
|
||||
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(
|
||||
"jsnfly/wav2vec2-large-xlsr-53-german-gpt2", from_pt=True
|
||||
)
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
batch_size = 13
|
||||
input_values = floats_tensor([batch_size, 512], scale=1.0)
|
||||
attention_mask = random_attention_mask([batch_size, 512])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs_dict = {
|
||||
"inputs": input_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
pt_logits = pt_outputs.logits
|
||||
pt_outputs = pt_outputs.to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict)
|
||||
fx_logits = fx_outputs.logits
|
||||
fx_outputs = fx_outputs.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
||||
fx_logits_loaded = fx_outputs_loaded.logits
|
||||
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
pt_logits_loaded = pt_outputs_loaded.logits
|
||||
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
@ -742,71 +569,6 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
@slow
|
||||
def test_flaxwav2vec2bart_pt_flax_equivalence(self):
|
||||
pt_model = SpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
|
||||
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-2-bart-large", from_pt=True
|
||||
)
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
batch_size = 13
|
||||
input_values = floats_tensor([batch_size, 512], scale=1.0)
|
||||
attention_mask = random_attention_mask([batch_size, 512])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs_dict = {
|
||||
"inputs": input_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
pt_logits = pt_outputs.logits
|
||||
pt_outputs = pt_outputs.to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict)
|
||||
fx_logits = fx_outputs.logits
|
||||
fx_outputs = fx_outputs.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
||||
fx_logits_loaded = fx_outputs_loaded.logits
|
||||
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
pt_logits_loaded = pt_outputs_loaded.logits
|
||||
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
@ -858,66 +620,3 @@ class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
@slow
|
||||
def test_flaxwav2vec2bert_pt_flax_equivalence(self):
|
||||
pt_model = SpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large")
|
||||
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large", from_pt=True)
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
batch_size = 13
|
||||
input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 512])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs_dict = {
|
||||
"inputs": input_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
pt_logits = pt_outputs.logits
|
||||
pt_outputs = pt_outputs.to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict)
|
||||
fx_logits = fx_outputs.logits
|
||||
fx_outputs = fx_outputs.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
||||
fx_logits_loaded = fx_outputs_loaded.logits
|
||||
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
pt_logits_loaded = pt_outputs_loaded.logits
|
||||
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
@ -17,15 +17,8 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from transformers import is_flax_available
|
||||
from transformers.testing_utils import (
|
||||
is_pt_flax_cross_test,
|
||||
require_flax,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
slow,
|
||||
)
|
||||
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@ -47,7 +40,6 @@ if is_flax_available():
|
||||
from flax.traverse_util import flatten_dict
|
||||
|
||||
from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer
|
||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||
from transformers.models.t5.modeling_flax_t5 import (
|
||||
FlaxT5EncoderModel,
|
||||
FlaxT5ForConditionalGeneration,
@ -373,95 +365,6 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_from_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# save pt model
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_param_from_head = flatten_dict(unfreeze(head_model.params))
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
|
||||
class FlaxT5EncoderOnlyModelTester:
|
||||
def __init__(
@ -663,95 +566,6 @@ class FlaxT5EncoderOnlyModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_from_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# save pt model
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_param_from_head = flatten_dict(unfreeze(head_model.params))
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
@ -19,8 +19,8 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_flax_available, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_vision, slow, torch_device
|
||||
from transformers import is_flax_available, is_vision_available
|
||||
from transformers.testing_utils import require_flax, require_vision, slow
|
||||
|
||||
from ...test_modeling_flax_common import floats_tensor, ids_tensor
|
||||
from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester
@ -35,15 +35,7 @@ if is_flax_available():
|
||||
FlaxViTModel,
|
||||
VisionEncoderDecoderConfig,
|
||||
)
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import VisionEncoderDecoderModel
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
@ -235,68 +227,6 @@ class FlaxEncoderDecoderMixin:
|
||||
generated_sequences = generated_output.sequences
|
||||
self.assertEqual(generated_sequences.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
|
||||
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
pt_model = VisionEncoderDecoderModel(encoder_decoder_config)
|
||||
fx_model = FlaxVisionEncoderDecoderModel(encoder_decoder_config)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
|
||||
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
pt_model = VisionEncoderDecoderModel(encoder_decoder_config)
|
||||
fx_model = FlaxVisionEncoderDecoderModel(encoder_decoder_config)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def test_encoder_decoder_model_from_pretrained_configs(self):
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_from_pretrained_configs(**config_inputs_dict)
@ -325,39 +255,6 @@ class FlaxEncoderDecoderMixin:
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
config = config_inputs_dict.pop("config")
|
||||
decoder_config = config_inputs_dict.pop("decoder_config")
|
||||
|
||||
inputs_dict = config_inputs_dict
|
||||
# `encoder_hidden_states` is not used in model call/forward
|
||||
del inputs_dict["encoder_hidden_states"]
|
||||
|
||||
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
||||
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
||||
inputs_dict["decoder_attention_mask"] = np.concatenate(
|
||||
[np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1
|
||||
)
|
||||
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
decoder_config.use_cache = False
|
||||
|
||||
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
|
||||
|
||||
# check without `enc_to_dec_proj` projection
|
||||
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
|
||||
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
# check `enc_to_dec_proj` work as expected
|
||||
decoder_config.hidden_size = decoder_config.hidden_size * 2
|
||||
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
|
||||
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2 = self.get_pretrained_model()
@ -20,15 +20,8 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import (
|
||||
is_pt_flax_cross_test,
|
||||
require_flax,
|
||||
require_torch,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_flax_available, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_flax, require_torch, require_vision, slow
|
||||
from transformers.utils import is_flax_available, is_vision_available
|
||||
|
||||
from ...test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from ..bert.test_modeling_flax_bert import FlaxBertModelTester
@ -45,17 +38,8 @@ if is_flax_available():
|
||||
VisionTextDualEncoderConfig,
|
||||
VisionTextDualEncoderProcessor,
|
||||
)
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import VisionTextDualEncoderModel
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
@ -154,68 +138,6 @@ class VisionTextDualEncoderMixin:
|
||||
(text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
fx_model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
fx_model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def test_model_from_pretrained_configs(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_model_from_pretrained_configs(**inputs_dict)
@ -232,17 +154,6 @@ class VisionTextDualEncoderMixin:
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_vision_text_output_attention(**inputs_dict)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
vision_config = config_inputs_dict.pop("vision_config")
|
||||
text_config = config_inputs_dict.pop("text_config")
|
||||
|
||||
inputs_dict = config_inputs_dict
|
||||
|
||||
self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(vision_config, text_config, inputs_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2, inputs = self.get_pretrained_model_and_inputs()
@ -20,8 +20,8 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import is_flax_available, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from ..bert.test_modeling_bert import BertModelTester
@ -44,12 +44,6 @@ if is_torch_available():
|
||||
ViTModel,
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from transformers import FlaxVisionTextDualEncoderModel
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
@ -172,69 +166,6 @@ class VisionTextDualEncoderMixin:
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mask, pixel_values, **kwargs):
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
|
||||
pt_inputs = inputs_dict
|
||||
flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**flax_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
fx_model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
fx_model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
|
||||
|
||||
def test_vision_text_dual_encoder_model(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_vision_text_dual_encoder_model(**inputs_dict)
@ -255,17 +186,6 @@ class VisionTextDualEncoderMixin:
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_vision_text_output_attention(**inputs_dict)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
vision_config = config_inputs_dict.pop("vision_config")
|
||||
text_config = config_inputs_dict.pop("text_config")
|
||||
|
||||
inputs_dict = config_inputs_dict
|
||||
|
||||
self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(vision_config, text_config, inputs_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2, inputs = self.get_pretrained_model_and_inputs()
@ -429,10 +349,6 @@ class DeiTRobertaModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
||||
"text_choice_labels": choice_labels,
|
||||
}
|
||||
|
||||
@unittest.skip(reason="DeiT is not available in Flax")
|
||||
def test_pt_flax_equivalence(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class CLIPVisionBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
|
||||
from transformers import Wav2Vec2Config, is_flax_available
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_flaky,
|
||||
is_librosa_available,
|
||||
is_pt_flax_cross_test,
|
||||
is_pyctcdecode_available,
|
||||
require_flax,
|
||||
require_librosa,
@ -350,11 +348,6 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
outputs = model(np.ones((1, 1024), dtype="f4"))
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
@is_flaky()
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
super().test_equivalence_pt_to_flax()
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxWav2Vec2UtilsTest(unittest.TestCase):
|
|
||||
CaptureLogger,
|
||||
cleanup,
|
||||
is_flaky,
|
||||
is_pt_flax_cross_test,
|
||||
is_pyctcdecode_available,
|
||||
is_torchaudio_available,
|
||||
require_flash_attn,
@ -569,16 +568,6 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@is_pt_flax_cross_test
@unittest.skip(reason="Non-robust architecture does not exist in Flax")
def test_equivalence_flax_to_pt(self):
pass

@is_pt_flax_cross_test
@unittest.skip(reason="Non-robust architecture does not exist in Flax")
def test_equivalence_pt_to_flax(self):
pass
|
||||
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_hidden_states = True
|
|
||||
|
||||
from transformers import Wav2Vec2BertConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
is_pt_flax_cross_test,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_fp16,
@ -559,18 +558,6 @@ class Wav2Vec2BertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
# Ignore copy
|
||||
@unittest.skip(reason="non-robust architecture does not exist in Flax")
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
pass
|
||||
|
||||
# Ignore copy
|
||||
@unittest.skip(reason="non-robust architecture does not exist in Flax")
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
pass
|
||||
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_hidden_states = True
|
|
||||
from transformers import Wav2Vec2ConformerConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
is_pt_flax_cross_test,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_fp16,
@ -535,16 +534,6 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
@unittest.skip(reason="Non-robust architecture does not exist in Flax")
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
pass
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
@unittest.skip(reason="Non-robust architecture does not exist in Flax")
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
pass
|
||||
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_hidden_states = True
|
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import transformers
|
||||
from transformers import WhisperConfig, is_flax_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
from transformers.utils import cached_property
|
||||
from transformers.utils.import_utils import is_datasets_available
@ -45,7 +44,6 @@ if is_flax_available():
|
||||
WhisperFeatureExtractor,
|
||||
WhisperProcessor,
|
||||
)
|
||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||
from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init
@ -245,99 +243,6 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
||||
# We override with a slightly higher tol value, as the test recently became flaky
|
||||
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
||||
# overwrite because of `input_features`
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = model_class(config)
            model.params = model.to_bf16(model.params)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite because of `input_features`
    @is_pt_flax_cross_test
    def test_save_load_from_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = base_class(config)
            base_params = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                # save pt model
                pt_model.save_pretrained(tmpdirname)
                head_model = model_class.from_pretrained(tmpdirname, from_pt=True)

                base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))

                for key in base_param_from_head.keys():
                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite because of `input_features`
    @is_pt_flax_cross_test
    def test_save_load_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite because of `input_features`
    def test_save_load_from_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@ -899,18 +804,3 @@ class WhisperEncoderModelTest(FlaxModelTesterMixin, unittest.TestCase):
    # WhisperEncoder does not have any base model
    def test_save_load_from_base(self):
        pass

    # WhisperEncoder does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_from_base_pt(self):
        pass

    # WhisperEncoder does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_to_base_pt(self):
        pass

    # WhisperEncoder does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        pass
@ -32,7 +32,6 @@ import transformers
from transformers import WhisperConfig
from transformers.testing_utils import (
    is_flaky,
    is_pt_flax_cross_test,
    require_flash_attn,
    require_non_xpu,
    require_torch,
@ -44,7 +43,7 @@ from transformers.testing_utils import (
    slow,
    torch_device,
)
from transformers.utils import cached_property, is_flax_available, is_torch_available, is_torchaudio_available
from transformers.utils import cached_property, is_torch_available, is_torchaudio_available
from transformers.utils.import_utils import is_datasets_available

from ...generation.test_utils import GenerationTesterMixin
@ -155,15 +154,6 @@ if is_torchaudio_available():
    import torchaudio


if is_flax_available():
    import jax.numpy as jnp

    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )


def prepare_whisper_inputs_dict(
    config,
    input_features,
@ -1069,161 +1059,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
||||
# We override with a slightly higher tol value, as test recently became flaky
|
||||
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest(reason="No Flax model exists for this class")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest(reason="No Flax model exists for this class")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
def test_mask_feature_prob(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.mask_feature_prob = 0.2
|
||||
@ -3622,157 +3457,6 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest(reason="Flax model does not exist")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
init_shape = (1,) + inputs_dict["input_features"].shape[1:]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest("Flax model does not exist")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
|
||||
class WhisperStandaloneDecoderModelTester:
|
||||
def __init__(
@ -14,12 +14,10 @@
# limitations under the License.


import tempfile
import unittest

import transformers
from transformers import XGLMConfig, XGLMTokenizer, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_sentencepiece, slow
from transformers import XGLMConfig, XGLMTokenizer, is_flax_available
from transformers.testing_utils import require_flax, require_sentencepiece, slow

from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask

@ -29,17 +27,9 @@ if is_flax_available():
    import jax.numpy as jnp
    import numpy as np

    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )
    from transformers.models.xglm.modeling_flax_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel


if is_torch_available():
    import torch


@require_flax
class FlaxXGLMModelTester:
    def __init__(
@ -220,108 +210,6 @@ class FlaxXGLMModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertListEqual(output_string, expected_string)
|
||||
|
||||
# overwrite from common since `attention_mask` in combination
|
||||
# with `causal_mask` behaves slighly differently
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
batch_size, seq_length = pt_inputs["input_ids"].shape
|
||||
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
|
||||
for batch_idx, start_index in enumerate(rnd_start_indices):
|
||||
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
|
||||
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
|
||||
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
|
||||
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
|
||||
pt_model = pt_model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(
|
||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)
|
||||
|
||||
# overwrite from common since `attention_mask` in combination
|
||||
# with `causal_mask` behaves slighly differently
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model.config.use_cache = False
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
batch_size, seq_length = pt_inputs["input_ids"].shape
|
||||
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
|
||||
for batch_idx, start_index in enumerate(rnd_start_indices):
|
||||
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
|
||||
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
|
||||
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
|
||||
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
@ -32,7 +32,6 @@ from packaging import version
from parameterized import parameterized
from pytest import mark

import transformers
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
@ -75,7 +74,6 @@ from transformers.models.auto.modeling_auto import (
from transformers.testing_utils import (
    CaptureLogger,
    is_flaky,
    is_pt_flax_cross_test,
    require_accelerate,
    require_bitsandbytes,
    require_deepspeed,
@ -100,14 +98,12 @@ from transformers.utils import (
    GENERATION_CONFIG_NAME,
    SAFE_WEIGHTS_NAME,
    is_accelerate_available,
    is_flax_available,
    is_tf_available,
    is_torch_bf16_available_on_device,
    is_torch_fp16_available_on_device,
    is_torch_fx_available,
    is_torch_sdpa_available,
)
from transformers.utils.generic import ContextManagers, ModelOutput
from transformers.utils.generic import ContextManagers


if is_accelerate_available():
@ -126,19 +122,6 @@ if is_torch_available():
    from transformers.modeling_utils import load_state_dict, no_init_weights
    from transformers.pytorch_utils import id_tensor_storage


if is_tf_available():
    pass

if is_flax_available():
    import jax.numpy as jnp

    from tests.utils.test_modeling_flax_utils import check_models_equal
    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )

if is_torch_fx_available():
    from transformers.utils.fx import _FX_SUPPORTED_MODELS_WITH_KV_CACHE, symbolic_trace

@ -2552,249 +2535,6 @@ class ModelTesterMixin:
        diff = np.abs((a - b)).max()
        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")

def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
|
||||
"""
|
||||
Args:
|
||||
model_class: The class of the model that is currently testing. For example, ..., etc.
|
||||
Currently unused, but it could make debugging easier and faster.
|
||||
|
||||
names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
|
||||
Currently unused, but in the future, we could use this information to make the error message clearer
|
||||
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
|
||||
"""
|
||||
self.assertEqual(type(name), str)
|
||||
if attributes is not None:
|
||||
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
|
||||
|
||||
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
|
||||
if isinstance(fx_outputs, ModelOutput):
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, ModelOutput),
|
||||
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is",
|
||||
)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")
|
||||
|
||||
# convert to the case of `tuple`
|
||||
# appending each key to the current (string) `name`
|
||||
attributes = tuple([f"{name}.{k}" for k in fx_keys])
|
||||
self.check_pt_flax_outputs(
|
||||
fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
|
||||
)
|
||||
|
||||
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
|
||||
elif type(fx_outputs) in [tuple, list]:
|
||||
self.assertEqual(
|
||||
type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
|
||||
)
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
|
||||
if attributes is not None:
|
||||
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
|
||||
self.assertEqual(
|
||||
len(attributes),
|
||||
len(fx_outputs),
|
||||
f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
|
||||
)
|
||||
else:
|
||||
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
|
||||
attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
|
||||
|
||||
for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
|
||||
self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
|
||||
|
||||
elif isinstance(fx_outputs, jnp.ndarray):
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is"
|
||||
)
|
||||
|
||||
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
|
||||
fx_outputs = np.array(fx_outputs)
|
||||
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||
|
||||
self.assertEqual(
|
||||
fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
|
||||
)
|
||||
|
||||
# deal with NumPy's scalars to make replacing nan values by 0 work.
|
||||
if np.isscalar(fx_outputs):
|
||||
fx_outputs = np.array([fx_outputs])
|
||||
pt_outputs = np.array([pt_outputs])
|
||||
|
||||
fx_nans = np.isnan(fx_outputs)
|
||||
pt_nans = np.isnan(pt_outputs)
|
||||
|
||||
pt_outputs[fx_nans] = 0
|
||||
fx_outputs[fx_nans] = 0
|
||||
pt_outputs[pt_nans] = 0
|
||||
fx_outputs[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
|
||||
self.assertLessEqual(
|
||||
max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
|
||||
f" {type(fx_outputs)} instead."
|
||||
)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest(reason="No Flax model exists for this class")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
self.skipTest(reason="No Flax model exists for this class")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
||||
# load Flax class
|
||||
fx_model = fx_model_class(config, dtype=jnp.float32)
|
||||
|
||||
# make sure only flax inputs are forward that actually exist in function args
|
||||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
# prepare inputs
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# remove function args that don't exist in Flax
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
|
||||
}
|
||||
|
||||
# convert inputs to Flax
|
||||
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = model_class.from_pretrained(
|
||||
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
|
||||
)
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -4413,29 +4153,6 @@ class ModelTesterMixin:
        tol = torch.finfo(torch.float16).eps
        torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)

    @is_pt_flax_cross_test
    def test_flax_from_pt_safetensors(self):
        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            flax_model_class_name = "Flax" + model_class.__name__  # Add the "Flax at the beginning
            if not hasattr(transformers, flax_model_class_name):
                self.skipTest(reason="transformers does not have this model in Flax version yet")

            flax_model_class = getattr(transformers, flax_model_class_name)

            pt_model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname, safe_serialization=True)
                flax_model_1 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)

                pt_model.save_pretrained(tmpdirname, safe_serialization=False)
                flax_model_2 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)

                # Check models are equal
                self.assertTrue(check_models_equal(flax_model_1, flax_model_2))

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
@ -21,13 +21,10 @@ from typing import List, Tuple

import numpy as np

import transformers
from transformers import is_flax_available, is_torch_available
from transformers.cache_utils import DynamicCache
from transformers import is_flax_available
from transformers.models.auto import get_values
from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, torch_device
from transformers.testing_utils import CaptureLogger, require_flax
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
from transformers.utils.generic import ModelOutput


if is_flax_available():
@ -47,17 +44,10 @@ if is_flax_available():
        FlaxAutoModelForSequenceClassification,
        FlaxBertModel,
    )
    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )
    from transformers.modeling_flax_utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME

    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8

if is_torch_available():
    import torch


def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
@ -184,216 +174,6 @@ class FlaxModelTesterMixin:
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
# (Copied from tests.test_modeling_common.ModelTesterMixin.check_pt_flax_outputs)
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
|
||||
"""
|
||||
Args:
|
||||
model_class: The class of the model that is currently testing. For example, ..., etc.
|
||||
Currently unused, but it could make debugging easier and faster.
|
||||
|
||||
names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
|
||||
Currently unused, but in the future, we could use this information to make the error message clearer
|
||||
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
|
||||
"""
|
||||
self.assertEqual(type(name), str)
|
||||
if attributes is not None:
|
||||
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
|
||||
|
||||
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
|
||||
if isinstance(fx_outputs, ModelOutput):
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, ModelOutput),
|
||||
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is",
|
||||
)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")
|
||||
|
||||
# convert to the case of `tuple`
|
||||
# appending each key to the current (string) `name`
|
||||
attributes = tuple([f"{name}.{k}" for k in fx_keys])
|
||||
self.check_pt_flax_outputs(
|
||||
fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
|
||||
)
|
||||
|
||||
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
|
||||
elif type(fx_outputs) in [tuple, list]:
|
||||
self.assertEqual(
|
||||
type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
|
||||
)
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
|
||||
if attributes is not None:
|
||||
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
|
||||
self.assertEqual(
|
||||
len(attributes),
|
||||
len(fx_outputs),
|
||||
f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
|
||||
)
|
||||
else:
|
||||
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
|
||||
attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
|
||||
|
||||
for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
|
||||
if isinstance(pt_output, DynamicCache):
|
||||
pt_output = pt_output.to_legacy_cache()
|
||||
self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
|
||||
|
||||
elif isinstance(fx_outputs, jnp.ndarray):
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is"
|
||||
)
|
||||
|
||||
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
|
||||
fx_outputs = np.array(fx_outputs)
|
||||
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||
|
||||
self.assertEqual(
|
||||
fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
|
||||
)
|
||||
|
||||
# deal with NumPy's scalars to make replacing nan values by 0 work.
|
||||
if np.isscalar(fx_outputs):
|
||||
fx_outputs = np.array([fx_outputs])
|
||||
pt_outputs = np.array([pt_outputs])
|
||||
|
||||
fx_nans = np.isnan(fx_outputs)
|
||||
pt_nans = np.isnan(pt_outputs)
|
||||
|
||||
pt_outputs[fx_nans] = 0
|
||||
fx_outputs[fx_nans] = 0
|
||||
pt_outputs[pt_nans] = 0
|
||||
fx_outputs[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
|
||||
self.assertLessEqual(
|
||||
max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
|
||||
f" {type(fx_outputs)} instead."
|
||||
)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
# It might be better to put this inside the for loop below (because we modify the config there).
|
||||
# But logically, it is fine.
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
pt_model = pt_model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
pt_model = pt_model_class(config).eval()
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = pt_model_class.from_pretrained(
|
||||
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
|
||||
)
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@ -474,92 +254,6 @@ class FlaxModelTesterMixin:
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_from_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = get_params(model.params)
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# save pt model
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = get_params(base_model.params)
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = get_params(base_model.params)
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@ -1119,14 +813,6 @@ class FlaxModelTesterMixin:
|
||||
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
|
||||
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_from_sharded_pt(self):
|
||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
|
||||
ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")
|
||||
for key, ref_val in flatten_dict(ref_model.params).items():
|
||||
val = flatten_dict(model.params)[key]
|
||||
assert np.allclose(np.array(val), np.array(ref_val))
|
||||
|
||||
def test_gradient_checkpointing(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -31,13 +31,11 @@ from datasets import Dataset

from transformers import is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import (  # noqa: F401
from transformers.testing_utils import (
    CaptureLogger,
    _tf_gpu_memory_limit,
    require_tf,
    require_tf2onnx,
    slow,
    torch_device,
)
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
from transformers.utils.generic import ModelOutput
@ -73,20 +71,6 @@ if is_tf_available():

    tf.config.experimental.enable_tensor_float_32_execution(False)

    if _tf_gpu_memory_limit is not None:
        gpus = tf.config.list_physical_devices("GPU")
        for gpu in gpus:
            # Restrict TensorFlow to only allocate x GB of memory on the GPUs
            try:
                tf.config.set_logical_device_configuration(
                    gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
                )
                logical_gpus = tf.config.list_logical_devices("GPU")
                print("Logical GPUs", logical_gpus)
            except RuntimeError as e:
                # Virtual devices must be set before GPUs have been initialized
                print(e)


def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
@ -18,16 +18,14 @@ import unittest
import numpy as np
from huggingface_hub import HfFolder, snapshot_download

from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
from transformers import BertConfig, is_flax_available
from transformers.testing_utils import (
    TOKEN,
    CaptureLogger,
    TemporaryHubRepo,
    is_pt_flax_cross_test,
    is_staging_test,
    require_flax,
    require_safetensors,
    require_torch,
)
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging

@ -42,9 +40,6 @@ if is_flax_available():

    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8

if is_torch_available():
    import torch


@require_flax
@is_staging_test
@ -205,23 +200,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(check_models_equal(model, new_model))
|
||||
|
||||
@require_flax
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_save_and_load_pt_to_flax(self):
|
||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
|
||||
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pt_model.save_pretrained(tmp_dir)
|
||||
|
||||
# Check we have a model.safetensors file
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
|
||||
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
self.assertTrue(check_models_equal(model, new_model))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub(self):
|
||||
"""
|
||||
@ -248,58 +226,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
|
||||
saved in the "pt" format.
|
||||
"""
|
||||
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-msgpack")
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
    @require_safetensors
    @require_torch
    @is_pt_flax_cross_test
    def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self):
        """
        This test checks that we can load safetensors from a checkpoint that only has those on the Hub,
        saved in the "pt" format.
        """
        import torch

        model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
        model.to(torch.bfloat16)

        with tempfile.TemporaryDirectory() as tmp:
            model.save_pretrained(tmp)
            flax_model = FlaxBertModel.from_pretrained(tmp)

        # Can load from the PyTorch-formatted checkpoint
        safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
        self.assertTrue(check_models_equal(flax_model, safetensors_model))
    @require_safetensors
    @is_pt_flax_cross_test
    def test_safetensors_load_from_local_from_safetensors_pt(self):
        """
        This test checks that we can load safetensors from a checkpoint that only has those on the Hub,
        saved in the "pt" format.
        """
        with tempfile.TemporaryDirectory() as tmp:
            location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
            flax_model = FlaxBertModel.from_pretrained(location)

        # Can load from the PyTorch-formatted checkpoint
        with tempfile.TemporaryDirectory() as tmp:
            location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
            safetensors_model = FlaxBertModel.from_pretrained(location)

        self.assertTrue(check_models_equal(flax_model, safetensors_model))
    @require_safetensors
    def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
        """
@ -328,19 +254,6 @@ class FlaxModelUtilsTest(unittest.TestCase):

        self.assertTrue(check_models_equal(model, new_model))

    @require_safetensors
    @require_torch
    @is_pt_flax_cross_test
    def test_safetensors_flax_from_torch(self):
        hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
        model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            new_model = FlaxBertModel.from_pretrained(tmp_dir)

        self.assertTrue(check_models_equal(hub_model, new_model))

    @require_safetensors
    def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_local(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
@ -370,27 +283,3 @@ class FlaxModelUtilsTest(unittest.TestCase):
                "Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
                in cl.out
            )

    @require_torch
    @require_safetensors
    @is_pt_flax_cross_test
    def test_from_pt_bf16(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
        model.to(torch.bfloat16)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=False)

        logger = logging.get_logger("transformers.modeling_flax_utils")

        with CaptureLogger(logger) as cl:
            new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")

        self.assertTrue(
            "Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
            in cl.out
        )

        flat_params_1 = flatten_dict(new_model.params)
        for value in flat_params_1.values():
            self.assertEqual(value.dtype, "bfloat16")
@ -24,7 +24,7 @@ from math import isnan

from transformers import is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
from transformers.testing_utils import require_tf, slow

from ..test_modeling_tf_common import ids_tensor

@ -48,20 +48,6 @@ if is_tf_available():
    )
    from transformers.modeling_tf_utils import keras

    if _tf_gpu_memory_limit is not None:
        gpus = tf.config.list_physical_devices("GPU")
        for gpu in gpus:
            # Restrict TensorFlow to only allocate x GB of memory on the GPUs
            try:
                tf.config.set_logical_device_configuration(
                    gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
                )
                logical_gpus = tf.config.list_logical_devices("GPU")
                print("Logical GPUs", logical_gpus)
            except RuntimeError as e:
                # Virtual devices must be set before GPUs have been initialized
                print(e)


@require_tf
class TFCoreModelTesterMixin:
@ -33,7 +33,6 @@ from transformers.testing_utils import ( # noqa: F401
    USER,
    CaptureLogger,
    TemporaryHubRepo,
    _tf_gpu_memory_limit,
    is_staging_test,
    require_safetensors,
    require_tf,
@ -68,20 +67,6 @@ if is_tf_available():

    tf.config.experimental.enable_tensor_float_32_execution(False)

    if _tf_gpu_memory_limit is not None:
        gpus = tf.config.list_physical_devices("GPU")
        for gpu in gpus:
            # Restrict TensorFlow to only allocate x GB of memory on the GPUs
            try:
                tf.config.set_logical_device_configuration(
                    gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
                )
                logical_gpus = tf.config.list_logical_devices("GPU")
                print("Logical GPUs", logical_gpus)
            except RuntimeError as e:
                # Virtual devices must be set before GPUs have been initialized
                print(e)


@require_tf
class TFModelUtilsTest(unittest.TestCase):
@ -1148,7 +1148,6 @@ def parse_commit_message(commit_message: str) -> Dict[str, bool]:


JOB_TO_TEST_FILE = {
    "tests_torch_and_flax": r"tests/models/.*/test_modeling_(?:flax|(?!tf)).*",
    "tests_tf": r"tests/models/.*/test_modeling_tf_.*",
    "tests_torch": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*",
    "tests_generate": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*",
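A minimal sketch of how a `JOB_TO_TEST_FILE` pattern such as the `tests_torch` entry selects PyTorch modeling tests while excluding the Flax and TensorFlow variants; the example paths are hypothetical and `tests_fetcher.py` may apply the regex differently:

```python
import re

# Pattern kept for the torch-only job: the negative lookahead rejects
# filenames that start with "test_modeling_flax_" or "test_modeling_tf_".
TESTS_TORCH = r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*"

# Hypothetical test paths, used only for illustration.
candidates = [
    "tests/models/bert/test_modeling_bert.py",       # selected
    "tests/models/bert/test_modeling_flax_bert.py",  # rejected
    "tests/models/bert/test_modeling_tf_bert.py",    # rejected
]

for path in candidates:
    print(path, bool(re.match(TESTS_TORCH, path)))
```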