fix loading flax bf16 weights in pt (#14369)

* fix loading flax bf16 weights in pt

* fix clip test

* fix t5 test

* add logging statement

* Update src/transformers/modeling_flax_pytorch_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* switch back to native any

* fix check for bf16 weights

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil 2021-11-11 21:20:49 +05:30 committed by GitHub
parent 7f20bf0d43
commit 3d607df8f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 83 additions and 0 deletions

View File

@ -21,6 +21,7 @@ from typing import Dict, Tuple
import numpy as np
import jax
import jax.numpy as jnp
import transformers
from flax.serialization import from_bytes
@ -189,6 +190,19 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
)
raise
# check if we have bf16 weights
is_type_bf16 = flatten_dict(jax.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
if any(is_type_bf16):
# convert all weights to fp32 if the are bf16 since torch.from_numpy can-not handle bf16
# and bf16 is not fully supported in PT yet.
logger.warning(
"Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
"before loading those in PyTorch model."
)
flax_state = jax.tree_map(
lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
)
flax_state_dict = flatten_dict(flax_state)
pt_model_dict = pt_model.state_dict()

View File

@ -227,6 +227,11 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_save_load_to_base_pt(self):
pass
# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
pass
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
@ -332,6 +337,11 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_save_load_to_base_pt(self):
pass
# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
pass
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:

View File

@ -384,6 +384,35 @@ class FlaxModelTesterMixin:
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
model.params = model.to_bf16(model.params)
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -430,6 +430,36 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
model.params = model.to_bf16(model.params)
base_params_from_head = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@require_sentencepiece
@require_tokenizers