mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
CLI: convert sharded PT models (#17959)
* sharded conversion; add flag to control max hidden error * better hidden name matching * Add test: load TF from PT shards * fix test (PT data must be local)
This commit is contained in:
parent
f25457b273
commit
91e1f24ef3
@ -34,7 +34,7 @@ from .. import (
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
from ..utils import logging
|
||||
from ..utils import TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
@ -48,7 +48,6 @@ if is_torch_available():
|
||||
|
||||
|
||||
MAX_ERROR = 5e-5 # larger error tolerance than in our internal tests, to avoid flaky user-facing errors
|
||||
TF_WEIGHTS_NAME = "tf_model.h5"
|
||||
|
||||
|
||||
def convert_command_factory(args: Namespace):
|
||||
@ -58,7 +57,13 @@ def convert_command_factory(args: Namespace):
|
||||
Returns: ServeCommand
|
||||
"""
|
||||
return PTtoTFCommand(
|
||||
args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push, args.extra_commit_description
|
||||
args.model_name,
|
||||
args.local_dir,
|
||||
args.max_hidden_error,
|
||||
args.new_weights,
|
||||
args.no_pr,
|
||||
args.push,
|
||||
args.extra_commit_description,
|
||||
)
|
||||
|
||||
|
||||
@ -90,6 +95,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
default="",
|
||||
help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--max-hidden-error",
|
||||
type=float,
|
||||
default=MAX_ERROR,
|
||||
help=(
|
||||
f"Maximum error tolerance for hidden layer outputs. Defaults to {MAX_ERROR}. If you suspect the hidden"
|
||||
" layers outputs will be used for downstream applications, avoid increasing this tolerance."
|
||||
),
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--new-weights",
|
||||
action="store_true",
|
||||
@ -112,14 +126,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
train_parser.set_defaults(func=convert_command_factory)
|
||||
|
||||
@staticmethod
|
||||
def find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input):
|
||||
def find_pt_tf_differences(pt_outputs, tf_outputs):
|
||||
"""
|
||||
Compares the TensorFlow and PyTorch models, given their inputs, returning a dictionary with all tensor
|
||||
differences.
|
||||
Compares the TensorFlow and PyTorch outputs, returning a dictionary with all tensor differences.
|
||||
"""
|
||||
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
|
||||
tf_outputs = tf_model(**tf_input, output_hidden_states=True)
|
||||
|
||||
# 1. All output attributes must be the same
|
||||
pt_out_attrs = set(pt_outputs.keys())
|
||||
tf_out_attrs = set(tf_outputs.keys())
|
||||
@ -158,6 +168,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
self,
|
||||
model_name: str,
|
||||
local_dir: str,
|
||||
max_hidden_error: float,
|
||||
new_weights: bool,
|
||||
no_pr: bool,
|
||||
push: bool,
|
||||
@ -167,6 +178,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
|
||||
self._model_name = model_name
|
||||
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
|
||||
self._max_hidden_error = max_hidden_error
|
||||
self._new_weights = new_weights
|
||||
self._no_pr = no_pr
|
||||
self._push = push
|
||||
@ -260,34 +272,49 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
pt_model = pt_class.from_pretrained(self._local_dir)
|
||||
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
|
||||
pt_input, tf_input = self.get_inputs(pt_model, config)
|
||||
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
|
||||
del pt_model # will no longer be used, and may have a large memory footprint
|
||||
|
||||
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
|
||||
tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True)
|
||||
|
||||
# Confirms that cross loading PT weights into TF worked.
|
||||
crossload_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_from_pt_model, tf_input)
|
||||
max_crossload_diff = max(crossload_differences.values())
|
||||
if max_crossload_diff > MAX_ERROR:
|
||||
crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
|
||||
output_differences = {k: v for k, v in crossload_differences.items() if "hidden" not in k}
|
||||
hidden_differences = {k: v for k, v in crossload_differences.items() if "hidden" in k}
|
||||
max_crossload_output_diff = max(output_differences.values())
|
||||
max_crossload_hidden_diff = max(hidden_differences.values())
|
||||
if max_crossload_output_diff > MAX_ERROR or max_crossload_hidden_diff > self._max_hidden_error:
|
||||
raise ValueError(
|
||||
"The cross-loaded TensorFlow model has different outputs, something went wrong! Exaustive list of"
|
||||
f" maximum tensor differences above the error threshold ({MAX_ERROR}):\n"
|
||||
+ "\n".join(
|
||||
[f"{key}: {value:.3e}" for key, value in crossload_differences.items() if value > MAX_ERROR]
|
||||
)
|
||||
"The cross-loaded TensorFlow model has different outputs, something went wrong!\n"
|
||||
+ f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n"
|
||||
+ "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR])
|
||||
+ f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n"
|
||||
+ "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error])
|
||||
)
|
||||
|
||||
# Save the weights in a TF format (if needed) and confirms that the results are still good
|
||||
tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
|
||||
if not os.path.exists(tf_weights_path) or self._new_weights:
|
||||
tf_from_pt_model.save_weights(tf_weights_path)
|
||||
tf_weights_path = os.path.join(self._local_dir, TF2_WEIGHTS_NAME)
|
||||
tf_weights_index_path = os.path.join(self._local_dir, TF2_WEIGHTS_INDEX_NAME)
|
||||
if (not os.path.exists(tf_weights_path) and not os.path.exists(tf_weights_index_path)) or self._new_weights:
|
||||
tf_from_pt_model.save_pretrained(self._local_dir)
|
||||
del tf_from_pt_model # will no longer be used, and may have a large memory footprint
|
||||
|
||||
tf_model = tf_class.from_pretrained(self._local_dir)
|
||||
conversion_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input)
|
||||
max_conversion_diff = max(conversion_differences.values())
|
||||
if max_conversion_diff > MAX_ERROR:
|
||||
tf_outputs = tf_model(**tf_input, output_hidden_states=True)
|
||||
|
||||
conversion_differences = self.find_pt_tf_differences(pt_outputs, tf_outputs)
|
||||
output_differences = {k: v for k, v in conversion_differences.items() if "hidden" not in k}
|
||||
hidden_differences = {k: v for k, v in conversion_differences.items() if "hidden" in k}
|
||||
max_conversion_output_diff = max(output_differences.values())
|
||||
max_conversion_hidden_diff = max(hidden_differences.values())
|
||||
if max_conversion_output_diff > MAX_ERROR or max_conversion_hidden_diff > self._max_hidden_error:
|
||||
raise ValueError(
|
||||
"The converted TensorFlow model has different outputs, something went wrong! Exaustive list of maximum"
|
||||
f" tensor differences above the error threshold ({MAX_ERROR}):\n"
|
||||
+ "\n".join(
|
||||
[f"{key}: {value:.3e}" for key, value in conversion_differences.items() if value > MAX_ERROR]
|
||||
)
|
||||
"The converted TensorFlow model has different outputs, something went wrong!\n"
|
||||
+ f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n"
|
||||
+ "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR])
|
||||
+ f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n"
|
||||
+ "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error])
|
||||
)
|
||||
|
||||
commit_message = "Update TF weights" if self._new_weights else "Add TF weights"
|
||||
@ -300,16 +327,31 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
self._logger.warn("Uploading the weights into a new PR...")
|
||||
commit_descrition = (
|
||||
"Model converted by the [`transformers`' `pt_to_tf`"
|
||||
" CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py)."
|
||||
"\n\nAll converted model outputs and hidden layers were validated against its Pytorch counterpart."
|
||||
f" Maximum crossload output difference={max_crossload_diff:.3e}; Maximum converted output"
|
||||
f" difference={max_conversion_diff:.3e}."
|
||||
" CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py). "
|
||||
"All converted model outputs and hidden layers were validated against its Pytorch counterpart.\n\n"
|
||||
f"Maximum crossload output difference={max_crossload_output_diff:.3e}; "
|
||||
f"Maximum crossload hidden layer difference={max_crossload_hidden_diff:.3e};\n"
|
||||
f"Maximum conversion output difference={max_conversion_output_diff:.3e}; "
|
||||
f"Maximum conversion hidden layer difference={max_conversion_hidden_diff:.3e};\n"
|
||||
)
|
||||
if self._extra_commit_description:
|
||||
commit_descrition += "\n\n" + self._extra_commit_description
|
||||
|
||||
# sharded model -> adds all related files (index and .h5 shards)
|
||||
if os.path.exists(tf_weights_index_path):
|
||||
operations = [
|
||||
CommitOperationAdd(path_in_repo=TF2_WEIGHTS_INDEX_NAME, path_or_fileobj=tf_weights_index_path)
|
||||
]
|
||||
for shard_path in tf.io.gfile.glob(self._local_dir + "/tf_model-*.h5"):
|
||||
operations += [
|
||||
CommitOperationAdd(path_in_repo=os.path.basename(shard_path), path_or_fileobj=shard_path)
|
||||
]
|
||||
else:
|
||||
operations = [CommitOperationAdd(path_in_repo=TF2_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)]
|
||||
|
||||
hub_pr_url = create_commit(
|
||||
repo_id=self._model_name,
|
||||
operations=[CommitOperationAdd(path_in_repo=TF_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)],
|
||||
operations=operations,
|
||||
commit_message=commit_message,
|
||||
commit_description=commit_descrition,
|
||||
repo_type="model",
|
||||
|
@ -117,10 +117,17 @@ def load_pytorch_checkpoint_in_tf2_model(
|
||||
)
|
||||
raise
|
||||
|
||||
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
||||
logger.info(f"Loading PyTorch weights from {pt_path}")
|
||||
# Treats a single file as a collection of shards with 1 shard.
|
||||
if isinstance(pytorch_checkpoint_path, str):
|
||||
pytorch_checkpoint_path = [pytorch_checkpoint_path]
|
||||
|
||||
# Loads all shards into a single state dictionary
|
||||
pt_state_dict = {}
|
||||
for path in pytorch_checkpoint_path:
|
||||
pt_path = os.path.abspath(path)
|
||||
logger.info(f"Loading PyTorch weights from {pt_path}")
|
||||
pt_state_dict.update(torch.load(pt_path, map_location="cpu"))
|
||||
|
||||
pt_state_dict = torch.load(pt_path, map_location="cpu")
|
||||
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")
|
||||
|
||||
return load_pytorch_weights_in_tf2_model(
|
||||
|
@ -50,6 +50,7 @@ from .utils import (
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
TF2_WEIGHTS_INDEX_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
EntryNotFoundError,
|
||||
ModelOutput,
|
||||
@ -2157,11 +2158,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint in priority if from_pt
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
|
||||
# Load from a sharded PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||
# Load from a TF 2.0 checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)):
|
||||
# Load from a sharded PyTorch checkpoint
|
||||
# Load from a sharded TF 2.0 checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
# At this stage we don't have a weight file so we will raise an error.
|
||||
|
@ -27,7 +27,7 @@ from typing import List, Tuple
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from huggingface_hub import HfFolder, delete_repo, set_access_token
|
||||
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
@ -1966,6 +1966,16 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_checkpoint_sharding_local_from_pt(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
_ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded")
|
||||
model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
|
||||
# the model above is the same as the model below, just a sharded pytorch version.
|
||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
def test_shard_checkpoint(self):
|
||||
# This is the model we will use, total size 340,000 bytes.
|
||||
model = tf.keras.Sequential(
|
||||
|
Loading…
Reference in New Issue
Block a user