mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Load sharded pt to flax (#18419)
* initial commit * add small test * add cross pt tf flag to test * fix quality * style * update test with new repo * fix failing test * update * fix wrong param ordering * style * update based on review * update related to recent new caching mechanism * quality * Update based on review Co-authored-by: sgugger <sylvain.gugger@gmail.com> * quality and style * Update src/transformers/modeling_flax_utils.py Co-authored-by: sgugger <sylvain.gugger@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
c8b6ae858d
commit
bce36ee065
@ -38,7 +38,9 @@ logger = logging.get_logger(__name__)
|
||||
#####################
|
||||
|
||||
|
||||
def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
|
||||
def load_pytorch_checkpoint_in_flax_state_dict(
|
||||
flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
|
||||
):
|
||||
"""Load pytorch checkpoints in a flax model"""
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
@ -50,14 +52,17 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa
|
||||
)
|
||||
raise
|
||||
|
||||
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
||||
logger.info(f"Loading PyTorch weights from {pt_path}")
|
||||
if not is_sharded:
|
||||
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
||||
logger.info(f"Loading PyTorch weights from {pt_path}")
|
||||
|
||||
pt_state_dict = torch.load(pt_path, map_location="cpu")
|
||||
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
||||
|
||||
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
|
||||
pt_state_dict = torch.load(pt_path, map_location="cpu")
|
||||
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
||||
|
||||
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
|
||||
else:
|
||||
# model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
|
||||
flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
|
||||
return flax_state_dict
|
||||
|
||||
|
||||
@ -156,6 +161,61 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||
return unflatten_dict(flax_state_dict)
|
||||
|
||||
|
||||
############################
|
||||
# Sharded Pytorch => Flax #
|
||||
############################
|
||||
|
||||
|
||||
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
|
||||
import torch
|
||||
|
||||
# Load the index
|
||||
flax_state_dict = {}
|
||||
for shard_file in shard_filenames:
|
||||
# load using msgpack utils
|
||||
pt_state_dict = torch.load(shard_file)
|
||||
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||
|
||||
model_prefix = flax_model.base_model_prefix
|
||||
random_flax_state_dict = flatten_dict(flax_model.params)
|
||||
|
||||
load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
|
||||
model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
|
||||
)
|
||||
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
|
||||
model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
|
||||
)
|
||||
# Need to change some parameters name to match Flax names
|
||||
for pt_key, pt_tensor in pt_state_dict.items():
|
||||
|
||||
pt_tuple_key = tuple(pt_key.split("."))
|
||||
|
||||
# remove base model prefix if necessary
|
||||
has_base_model_prefix = pt_tuple_key[0] == model_prefix
|
||||
if load_model_with_head_into_base_model and has_base_model_prefix:
|
||||
pt_tuple_key = pt_tuple_key[1:]
|
||||
|
||||
# Correctly rename weight parameters
|
||||
flax_key, flax_tensor = rename_key_and_reshape_tensor(
|
||||
pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
|
||||
)
|
||||
# add model prefix if necessary
|
||||
require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
|
||||
if load_base_model_into_model_with_head and require_base_model_prefix:
|
||||
flax_key = (model_prefix,) + flax_key
|
||||
|
||||
if flax_key in random_flax_state_dict:
|
||||
if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
|
||||
raise ValueError(
|
||||
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
|
||||
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
||||
)
|
||||
|
||||
# also add unexpected weight so that warning is thrown
|
||||
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
|
||||
return unflatten_dict(flax_state_dict)
|
||||
|
||||
|
||||
#####################
|
||||
# Flax => PyTorch #
|
||||
#####################
|
||||
|
@ -40,6 +40,7 @@ from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_d
|
||||
from .utils import (
|
||||
FLAX_WEIGHTS_INDEX_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
PushToHubMixin,
|
||||
add_code_sample_docstrings,
|
||||
@ -650,6 +651,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
|
||||
# Load from a sharded pytorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
||||
# Load from a Flax checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
||||
@ -700,6 +705,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
)
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
# Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
|
||||
elif resolved_archive_file is None and from_pt:
|
||||
resolved_archive_file = cached_file(
|
||||
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||
)
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
if resolved_archive_file is None:
|
||||
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
||||
# message.
|
||||
@ -714,6 +726,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
|
||||
" load this model from those weights."
|
||||
)
|
||||
elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||
f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
|
||||
" `from_pt=True` to load this model from those weights."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||
@ -761,7 +779,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
|
||||
|
||||
if from_pt:
|
||||
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
|
||||
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
|
||||
else:
|
||||
|
||||
if is_sharded:
|
||||
|
@ -1099,6 +1099,14 @@ class FlaxModelTesterMixin:
|
||||
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
|
||||
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_from_sharded_pt(self):
|
||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
|
||||
ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")
|
||||
for key, ref_val in flatten_dict(ref_model.params).items():
|
||||
val = flatten_dict(model.params)[key]
|
||||
assert np.allclose(np.array(val), np.array(ref_val))
|
||||
|
||||
def test_gradient_checkpointing(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user