fix loading flax bf16 weights in pt (#14369)

* fix loading flax bf16 weights in pt

* fix clip test

* fix t5 test

* add logging statement

* Update src/transformers/modeling_flax_pytorch_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* switch back to native any

* fix check for bf16 weights

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Suraj Patil 2021-11-11 21:20:49 +05:30 committed by GitHub
parent 7f20bf0d43
commit 3d607df8f4
4 changed files with 83 additions and 0 deletions


@@ -21,6 +21,7 @@ from typing import Dict, Tuple
import numpy as np

import jax
import jax.numpy as jnp
import transformers
from flax.serialization import from_bytes
@@ -189,6 +190,19 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
        )
        raise

    # check if we have bf16 weights
    is_type_bf16 = flatten_dict(jax.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
    if any(is_type_bf16):
        # convert all weights to fp32 if they are bf16, since torch.from_numpy cannot handle bf16
        # and bf16 is not fully supported in PT yet.
        logger.warning(
            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
            "before loading those in PyTorch model."
        )
        flax_state = jax.tree_map(
            lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
        )

    flax_state_dict = flatten_dict(flax_state)
    pt_model_dict = pt_model.state_dict()


@@ -227,6 +227,11 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
    def test_save_load_to_base_pt(self):
        pass

    # FlaxCLIPVisionModel does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
@@ -332,6 +337,11 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
    def test_save_load_to_base_pt(self):
        pass

    # FlaxCLIPTextModel does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:


@@ -384,6 +384,35 @@ class FlaxModelTesterMixin:
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            model.params = model.to_bf16(model.params)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
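The path this test covers can also be reproduced outside the test suite. A small end-to-end sketch, assuming both flax and torch are installed (BERT is used here only as an example model; it is not touched by this commit):

import tempfile

from transformers import BertConfig, BertModel, FlaxBertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)

# build a Flax model and cast its parameters to bf16
flax_model = FlaxBertModel(config)
flax_model.params = flax_model.to_bf16(flax_model.params)

with tempfile.TemporaryDirectory() as tmpdirname:
    flax_model.save_pretrained(tmpdirname)
    # from_flax=True routes through load_flax_weights_in_pytorch_model,
    # which now casts the bf16 weights to fp32 before building torch tensors
    pt_model = BertModel.from_pretrained(tmpdirname, from_flax=True)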


@@ -430,6 +430,36 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite since special base model prefix is used
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            model.params = model.to_bf16(model.params)
            base_params_from_head = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")


@require_sentencepiece
@require_tokenizers