mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
fix loading flax bf16 weights in pt (#14369)
* fix loading flax bf16 weights in pt * fix clip test * fix t5 test * add logging statement * Update src/transformers/modeling_flax_pytorch_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * switch back to native any * fix check for bf16 weights Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
7f20bf0d43
commit
3d607df8f4
@ -21,6 +21,7 @@ from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import transformers
|
||||
from flax.serialization import from_bytes
|
||||
@ -189,6 +190,19 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
||||
)
|
||||
raise
|
||||
|
||||
# check if we have bf16 weights
|
||||
is_type_bf16 = flatten_dict(jax.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
|
||||
if any(is_type_bf16):
|
||||
# convert all weights to fp32 if the are bf16 since torch.from_numpy can-not handle bf16
|
||||
# and bf16 is not fully supported in PT yet.
|
||||
logger.warning(
|
||||
"Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
|
||||
"before loading those in PyTorch model."
|
||||
)
|
||||
flax_state = jax.tree_map(
|
||||
lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
|
||||
)
|
||||
|
||||
flax_state_dict = flatten_dict(flax_state)
|
||||
pt_model_dict = pt_model.state_dict()
|
||||
|
||||
|
@ -227,6 +227,11 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
def test_save_load_to_base_pt(self):
|
||||
pass
|
||||
|
||||
# FlaxCLIPVisionModel does not have any base model
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
@ -332,6 +337,11 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
def test_save_load_to_base_pt(self):
|
||||
pass
|
||||
|
||||
# FlaxCLIPVisionModel does not have any base model
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
|
@ -384,6 +384,35 @@ class FlaxModelTesterMixin:
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -430,6 +430,36 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# overwrite since special base model prefix is used
|
||||
@is_pt_flax_cross_test
|
||||
def test_save_load_bf16_to_base_pt(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model.params = model.to_bf16(model.params)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# convert Flax model to PyTorch model
|
||||
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
|
Loading…
Reference in New Issue
Block a user