CLI: convert sharded PT models (#17959)

* sharded conversion; add flag to control max hidden error

* better hidden name matching

* Add test: load TF from PT shards

* fix test (PT data must be local)
Joao Gante, 2022-06-30 16:51:03 +01:00, committed by GitHub
parent f25457b273
commit 91e1f24ef3
4 changed files with 102 additions and 38 deletions
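A minimal sketch of driving the updated converter programmatically, mirroring the argument list of the factory shown below (the model name is illustrative, and run() downloads real weights):

from argparse import Namespace
from transformers.commands.pt_to_tf import convert_command_factory

args = Namespace(
    model_name="hf-internal-testing/tiny-random-bert",  # illustrative repo id
    local_dir="",  # empty -> defaults to /tmp/{model_name}
    max_hidden_error=5e-5,  # the new flag added by this commit
    new_weights=False,
    no_pr=True,  # validate locally without opening a Hub PR
    push=False,
    extra_commit_description="",
)
convert_command_factory(args).run()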

src/transformers/commands/pt_to_tf.py

@@ -34,7 +34,7 @@ from .. import (
is_tf_available,
is_torch_available,
)
from ..utils import logging
from ..utils import TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
from . import BaseTransformersCLICommand
@@ -48,7 +48,6 @@ if is_torch_available():
MAX_ERROR = 5e-5 # larger error tolerance than in our internal tests, to avoid flaky user-facing errors
TF_WEIGHTS_NAME = "tf_model.h5"
def convert_command_factory(args: Namespace):
@@ -58,7 +57,13 @@ def convert_command_factory(args: Namespace):
Returns: ServeCommand
"""
return PTtoTFCommand(
args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push, args.extra_commit_description
args.model_name,
args.local_dir,
args.max_hidden_error,
args.new_weights,
args.no_pr,
args.push,
args.extra_commit_description,
)
@@ -90,6 +95,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
default="",
help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
)
train_parser.add_argument(
"--max-hidden-error",
type=float,
default=MAX_ERROR,
help=(
f"Maximum error tolerance for hidden layer outputs. Defaults to {MAX_ERROR}. If you suspect the hidden"
" layers outputs will be used for downstream applications, avoid increasing this tolerance."
),
)
train_parser.add_argument(
"--new-weights",
action="store_true",
@@ -112,14 +126,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
train_parser.set_defaults(func=convert_command_factory)
@staticmethod
def find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input):
def find_pt_tf_differences(pt_outputs, tf_outputs):
"""
Compares the TensorFlow and PyTorch models, given their inputs, returning a dictionary with all tensor
differences.
Compares the TensorFlow and PyTorch outputs, returning a dictionary with all tensor differences.
"""
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
tf_outputs = tf_model(**tf_input, output_hidden_states=True)
# 1. All output attributes must be the same
pt_out_attrs = set(pt_outputs.keys())
tf_out_attrs = set(tf_outputs.keys())
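A hedged sketch of the per-tensor comparison this helper reduces to (the real method also recurses into nested outputs such as tuples of hidden states; the function name here is illustrative):

import numpy as np

def tensor_difference(pt_tensor, tf_tensor):
    # Maximum absolute elementwise difference between the two frameworks' outputs
    return float(np.max(np.abs(pt_tensor.detach().numpy() - tf_tensor.numpy())))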
@@ -158,6 +168,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
self,
model_name: str,
local_dir: str,
max_hidden_error: float,
new_weights: bool,
no_pr: bool,
push: bool,
@@ -167,6 +178,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
self._model_name = model_name
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
self._max_hidden_error = max_hidden_error
self._new_weights = new_weights
self._no_pr = no_pr
self._push = push
@@ -260,34 +272,49 @@ class PTtoTFCommand(BaseTransformersCLICommand):
pt_model = pt_class.from_pretrained(self._local_dir)
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
pt_input, tf_input = self.get_inputs(pt_model, config)
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
del pt_model # will no longer be used, and may have a large memory footprint
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True)
# Confirms that cross loading PT weights into TF worked.
crossload_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_from_pt_model, tf_input)
max_crossload_diff = max(crossload_differences.values())
if max_crossload_diff > MAX_ERROR:
crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
output_differences = {k: v for k, v in crossload_differences.items() if "hidden" not in k}
hidden_differences = {k: v for k, v in crossload_differences.items() if "hidden" in k}
max_crossload_output_diff = max(output_differences.values())
max_crossload_hidden_diff = max(hidden_differences.values())
if max_crossload_output_diff > MAX_ERROR or max_crossload_hidden_diff > self._max_hidden_error:
raise ValueError(
"The cross-loaded TensorFlow model has different outputs, something went wrong! Exaustive list of"
f" maximum tensor differences above the error threshold ({MAX_ERROR}):\n"
+ "\n".join(
[f"{key}: {value:.3e}" for key, value in crossload_differences.items() if value > MAX_ERROR]
)
"The cross-loaded TensorFlow model has different outputs, something went wrong!\n"
+ f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n"
+ "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR])
+ f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n"
+ "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error])
)
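To make the bucketing above concrete: keys containing the substring "hidden" are checked against the user-configurable tolerance, all other keys against MAX_ERROR (the values below are made up):

differences = {"logits": 1.2e-6, "hidden_states_0": 3.4e-5, "hidden_states_1": 6.1e-5}
output_differences = {k: v for k, v in differences.items() if "hidden" not in k}  # {"logits": 1.2e-6}
hidden_differences = {k: v for k, v in differences.items() if "hidden" in k}  # both hidden_states_* entries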
# Save the weights in a TF format (if needed) and confirm that the results are still good
tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
if not os.path.exists(tf_weights_path) or self._new_weights:
tf_from_pt_model.save_weights(tf_weights_path)
tf_weights_path = os.path.join(self._local_dir, TF2_WEIGHTS_NAME)
tf_weights_index_path = os.path.join(self._local_dir, TF2_WEIGHTS_INDEX_NAME)
if (not os.path.exists(tf_weights_path) and not os.path.exists(tf_weights_index_path)) or self._new_weights:
tf_from_pt_model.save_pretrained(self._local_dir)
del tf_from_pt_model # will no longer be used, and may have a large memory footprint
tf_model = tf_class.from_pretrained(self._local_dir)
conversion_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input)
max_conversion_diff = max(conversion_differences.values())
if max_conversion_diff > MAX_ERROR:
tf_outputs = tf_model(**tf_input, output_hidden_states=True)
conversion_differences = self.find_pt_tf_differences(pt_outputs, tf_outputs)
output_differences = {k: v for k, v in conversion_differences.items() if "hidden" not in k}
hidden_differences = {k: v for k, v in conversion_differences.items() if "hidden" in k}
max_conversion_output_diff = max(output_differences.values())
max_conversion_hidden_diff = max(hidden_differences.values())
if max_conversion_output_diff > MAX_ERROR or max_conversion_hidden_diff > self._max_hidden_error:
raise ValueError(
"The converted TensorFlow model has different outputs, something went wrong! Exaustive list of maximum"
f" tensor differences above the error threshold ({MAX_ERROR}):\n"
+ "\n".join(
[f"{key}: {value:.3e}" for key, value in conversion_differences.items() if value > MAX_ERROR]
)
"The converted TensorFlow model has different outputs, something went wrong!\n"
+ f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n"
+ "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR])
+ f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n"
+ "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error])
)
commit_message = "Update TF weights" if self._new_weights else "Add TF weights"
@@ -300,16 +327,31 @@ class PTtoTFCommand(BaseTransformersCLICommand):
self._logger.warn("Uploading the weights into a new PR...")
commit_descrition = (
"Model converted by the [`transformers`' `pt_to_tf`"
" CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py)."
"\n\nAll converted model outputs and hidden layers were validated against its Pytorch counterpart."
f" Maximum crossload output difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
" CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py). "
"All converted model outputs and hidden layers were validated against its Pytorch counterpart.\n\n"
f"Maximum crossload output difference={max_crossload_output_diff:.3e}; "
f"Maximum crossload hidden layer difference={max_crossload_hidden_diff:.3e};\n"
f"Maximum conversion output difference={max_conversion_output_diff:.3e}; "
f"Maximum conversion hidden layer difference={max_conversion_hidden_diff:.3e};\n"
)
if self._extra_commit_description:
commit_descrition += "\n\n" + self._extra_commit_description
# sharded model -> adds all related files (index and .h5 shards)
if os.path.exists(tf_weights_index_path):
operations = [
CommitOperationAdd(path_in_repo=TF2_WEIGHTS_INDEX_NAME, path_or_fileobj=tf_weights_index_path)
]
for shard_path in tf.io.gfile.glob(self._local_dir + "/tf_model-*.h5"):
operations += [
CommitOperationAdd(path_in_repo=os.path.basename(shard_path), path_or_fileobj=shard_path)
]
else:
operations = [CommitOperationAdd(path_in_repo=TF2_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)]
hub_pr_url = create_commit(
repo_id=self._model_name,
operations=[CommitOperationAdd(path_in_repo=TF_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)],
operations=operations,
commit_message=commit_message,
commit_description=commit_descrition,
repo_type="model",
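For context, a sketch of the on-disk layout the sharded branch above expects (shard names follow the tf_model-XXXXX-of-XXXXX.h5 pattern matched by the glob; the directory and shard count are illustrative):

import os
import tensorflow as tf

local_dir = "/tmp/my-model"  # hypothetical local directory
# A sharded save_pretrained() writes an index plus numbered .h5 shards, e.g.:
#   tf_model.h5.index.json  (TF2_WEIGHTS_INDEX_NAME)
#   tf_model-00001-of-00002.h5
#   tf_model-00002-of-00002.h5
for shard_path in tf.io.gfile.glob(local_dir + "/tf_model-*.h5"):
    print(os.path.basename(shard_path))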

src/transformers/modeling_tf_pytorch_utils.py

@@ -117,10 +117,17 @@ def load_pytorch_checkpoint_in_tf2_model(
)
raise
pt_path = os.path.abspath(pytorch_checkpoint_path)
logger.info(f"Loading PyTorch weights from {pt_path}")
# Treat a single checkpoint file as a shard collection with a single shard.
if isinstance(pytorch_checkpoint_path, str):
pytorch_checkpoint_path = [pytorch_checkpoint_path]
# Loads all shards into a single state dictionary
pt_state_dict = {}
for path in pytorch_checkpoint_path:
pt_path = os.path.abspath(path)
logger.info(f"Loading PyTorch weights from {pt_path}")
pt_state_dict.update(torch.load(pt_path, map_location="cpu"))
pt_state_dict = torch.load(pt_path, map_location="cpu")
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")
return load_pytorch_weights_in_tf2_model(
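Callers can now pass a list of shard paths; a hedged sketch of deriving that list from a sharded checkpoint's index file (the pytorch_model.bin.index.json name and its weight_map layout are Hub conventions, stated here as assumptions):

import json
import os

def shard_paths_from_index(index_path):
    # weight_map maps each parameter name to the shard file that stores it
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    folder = os.path.dirname(index_path)
    return sorted({os.path.join(folder, fname) for fname in weight_map.values()})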

src/transformers/modeling_tf_utils.py

@@ -50,6 +50,7 @@ from .utils import (
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
EntryNotFoundError,
ModelOutput,
@@ -2157,11 +2158,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint in priority if from_pt
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
is_sharded = True
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
# Load from a TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)):
# Load from a sharded PyTorch checkpoint
# Load from a sharded TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
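The resolution order above, condensed as a comment sketch (the constant values are the usual ones in transformers.utils, noted here from memory):

# Priority when loading from a local directory:
# 1. WEIGHTS_NAME            (pytorch_model.bin)             if from_pt
# 2. WEIGHTS_INDEX_NAME      (pytorch_model.bin.index.json)  if from_pt -> is_sharded = True
# 3. TF2_WEIGHTS_NAME        (tf_model.h5)
# 4. TF2_WEIGHTS_INDEX_NAME  (tf_model.h5.index.json)        -> is_sharded = True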

tests/test_modeling_tf_common.py

@@ -27,7 +27,7 @@ from typing import List, Tuple
from datasets import Dataset
from huggingface_hub import HfFolder, delete_repo, set_access_token
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
from requests.exceptions import HTTPError
from transformers import is_tf_available, is_torch_available
from transformers.configuration_utils import PretrainedConfig
@@ -1966,6 +1966,16 @@ class UtilsFunctionsTest(unittest.TestCase):
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
@is_pt_tf_cross_test
def test_checkpoint_sharding_local_from_pt(self):
with tempfile.TemporaryDirectory() as tmp_dir:
_ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded")
model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
# The model above is the same as the model below, just a sharded PyTorch version.
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
def test_shard_checkpoint(self):
# This is the model we will use, total size 340,000 bytes.
model = tf.keras.Sequential(