diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py
index e8f93e75267..c822f11e981 100644
--- a/src/transformers/modeling_flax_pytorch_utils.py
+++ b/src/transformers/modeling_flax_pytorch_utils.py
@@ -21,6 +21,7 @@ from typing import Dict, Tuple
 
 import numpy as np
 
+import jax
 import jax.numpy as jnp
 import transformers
 from flax.serialization import from_bytes
@@ -189,6 +190,19 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
         )
         raise
 
+    # check if we have bf16 weights
+    is_type_bf16 = flatten_dict(jax.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
+    if any(is_type_bf16):
+        # convert all weights to fp32 if they are bf16 since torch.from_numpy cannot handle bf16
+        # and bf16 is not fully supported in PT yet.
+        logger.warning(
+            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
+            "before loading them into the PyTorch model."
+        )
+        flax_state = jax.tree_map(
+            lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
+        )
+
     flax_state_dict = flatten_dict(flax_state)
     pt_model_dict = pt_model.state_dict()
 
diff --git a/tests/test_modeling_flax_clip.py b/tests/test_modeling_flax_clip.py
index fabc5fd25a5..00f5c47a090 100644
--- a/tests/test_modeling_flax_clip.py
+++ b/tests/test_modeling_flax_clip.py
@@ -227,6 +227,11 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
     def test_save_load_to_base_pt(self):
         pass
 
+    # FlaxCLIPVisionModel does not have any base model
+    @is_pt_flax_cross_test
+    def test_save_load_bf16_to_base_pt(self):
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         for model_class_name in self.all_model_classes:
@@ -332,6 +337,11 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
     def test_save_load_to_base_pt(self):
         pass
 
+    # FlaxCLIPTextModel does not have any base model
+    @is_pt_flax_cross_test
+    def test_save_load_bf16_to_base_pt(self):
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         for model_class_name in self.all_model_classes:
diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index 26888a605f6..ca138a33708 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -384,6 +384,35 @@ class FlaxModelTesterMixin:
             max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
             self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
+    @is_pt_flax_cross_test
+    def test_save_load_bf16_to_base_pt(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            model = model_class(config)
+            model.params = model.to_bf16(model.params)
+            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
+
+            # convert Flax model to PyTorch model
+            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
+            pt_model = pt_model_class(config).eval()
+            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+            # check that all base model weights are loaded correctly
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_model.save_pretrained(tmpdirname)
+                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
+
+                base_params = flatten_dict(unfreeze(base_model.params))
+
+                for key in base_params_from_head.keys():
+                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
diff --git a/tests/test_modeling_flax_t5.py b/tests/test_modeling_flax_t5.py
index 424ec0b420b..8c59c849199 100644
--- a/tests/test_modeling_flax_t5.py
+++ b/tests/test_modeling_flax_t5.py
@@ -430,6 +430,36 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.
             max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
             self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
+    # overwrite since special base model prefix is used
+    @is_pt_flax_cross_test
+    def test_save_load_bf16_to_base_pt(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            model = model_class(config)
+            model.params = model.to_bf16(model.params)
+            base_params_from_head = flatten_dict(unfreeze(model.params))
+
+            # convert Flax model to PyTorch model
+            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
+            pt_model = pt_model_class(config).eval()
+            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+            # check that all base model weights are loaded correctly
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_model.save_pretrained(tmpdirname)
+                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
+
+                base_params = flatten_dict(unfreeze(base_model.params))
+
+                for key in base_params_from_head.keys():
+                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
 
 @require_sentencepiece
 @require_tokenizers
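
For reviewers who want to poke at the conversion outside the library, below is a minimal standalone sketch of the cast this patch performs. It is an illustration under assumptions, not the library's code path: `cast_bf16_params_to_fp32` and the toy `dense`/`kernel` pytree are hypothetical names, and it uses `jax.tree_util.tree_map` (the stable spelling of the `jax.tree_map` call used in the patch). It shows why the cast is needed: `torch.from_numpy` cannot consume `bfloat16` arrays, so bf16 leaves must be upcast to `float32` on the JAX side first.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def cast_bf16_params_to_fp32(params):
    """Upcast every bfloat16 leaf of a param pytree to float32.

    torch.from_numpy cannot consume bfloat16 arrays, so the cast has to
    happen on the JAX side before handing weights to PyTorch.
    """
    # Mirror the patch's is_type_bf16 check: is any leaf bf16?
    has_bf16 = any(leaf.dtype == jnp.bfloat16 for leaf in jax.tree_util.tree_leaves(params))
    if not has_bf16:
        return params
    return jax.tree_util.tree_map(
        lambda p: p.astype(np.float32) if p.dtype == jnp.bfloat16 else p, params
    )


# Toy pytree standing in for real Flax params (hypothetical names).
params = {"dense": {"kernel": jnp.ones((2, 2), dtype=jnp.bfloat16)}}
params = cast_bf16_params_to_fp32(params)

# After the cast, the numpy copy has a dtype torch understands.
tensor = torch.from_numpy(np.array(params["dense"]["kernel"]))
print(tensor.dtype)  # torch.float32
```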
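The new `test_save_load_bf16_to_base_pt` tests all exercise the same round trip. Here is a hedged end-to-end usage sketch of that flow, assuming a randomly initialized tiny BERT pair rather than the test mixin's machinery; the config sizes are arbitrary and only meant to keep the run cheap, and `safe_serialization=False` (a kwarg of recent `save_pretrained` versions) forces a `pytorch_model.bin` that `from_pt=True` can always read.

```python
import tempfile

import numpy as np
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict

from transformers import BertConfig, BertModel, FlaxBertModel
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model

# Tiny config so the round trip runs quickly (sizes are arbitrary).
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64)

flax_model = FlaxBertModel(config)
flax_model.params = flax_model.to_bf16(flax_model.params)

# bf16 Flax weights -> PyTorch; the patched loader upcasts bf16 to fp32 internally.
pt_model = BertModel(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)

# Save the PyTorch checkpoint and reload it back into Flax.
with tempfile.TemporaryDirectory() as tmpdirname:
    pt_model.save_pretrained(tmpdirname, safe_serialization=False)
    flax_model_2 = FlaxBertModel.from_pretrained(tmpdirname, from_pt=True)

# Compare leaf by leaf; bf16 -> fp32 is an exact cast, so diffs stay tiny.
params_1 = flatten_dict(unfreeze(flax_model.params))
params_2 = flatten_dict(unfreeze(flax_model_2.params))
for key in params_1:
    p1 = np.asarray(params_1[key]).astype(np.float32)
    p2 = np.asarray(params_2[key]).astype(np.float32)
    assert np.abs(p1 - p2).max() <= 1e-3, f"{key} not identical"
```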