CLI: tool to convert PT into TF weights and open hub PR (#17497)

Joao Gante 2022-06-01 18:52:07 +01:00 committed by GitHub
parent 3766df4fe1
commit ca1f1c8685
4 changed files with 208 additions and 3 deletions
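
For reference, the new subcommand is invoked as `transformers-cli pt-to-tf --model-name <owner/model>`. A minimal sketch of an equivalent programmatic invocation, mirroring the test added at the bottom of this commit (the model name is only an example):

import sys

import transformers.commands.transformers_cli as cli

# Equivalent to: transformers-cli pt-to-tf --model-name hf-internal-testing/tiny-random-gptj --no-pr
sys.argv = ["transformers-cli", "pt-to-tf", "--model-name", "hf-internal-testing/tiny-random-gptj", "--no-pr"]
cli.main()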

.circleci/config.yml

@@ -78,7 +78,8 @@ jobs:
           keys:
             - v0.4-torch_and_tf-{{ checksum "setup.py" }}
             - v0.4-{{ checksum "setup.py" }}
-      - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng
+      - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng git-lfs
+      - run: git lfs install
       - run: pip install --upgrade pip
       - run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,torch-speech,vision]
       - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
@@ -117,7 +118,8 @@ jobs:
           keys:
             - v0.4-torch_and_tf-{{ checksum "setup.py" }}
             - v0.4-{{ checksum "setup.py" }}
-      - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng
+      - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng git-lfs
+      - run: git lfs install
       - run: pip install --upgrade pip
       - run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,torch-speech,vision]
       - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html

src/transformers/commands/pt_to_tf.py

@@ -0,0 +1,186 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from argparse import ArgumentParser, Namespace

import numpy as np
from datasets import load_dataset
from huggingface_hub import Repository, upload_file

from .. import AutoFeatureExtractor, AutoModel, AutoTokenizer, TFAutoModel, is_tf_available, is_torch_available
from ..utils import logging
from . import BaseTransformersCLICommand


if is_tf_available():
    import tensorflow as tf

    # Disable TF32, so that the PT/TF numerical comparison is not skewed by lower-precision matmuls
    tf.config.experimental.enable_tensor_float_32_execution(False)

if is_torch_available():
    import torch


MAX_ERROR = 5e-5  # larger error tolerance than in our internal tests, to avoid flaky user-facing errors
TF_WEIGHTS_NAME = "tf_model.h5"


def convert_command_factory(args: Namespace):
    """
    Factory function used to convert a model's PyTorch checkpoint into a TensorFlow 2 checkpoint.

    Returns: PTtoTFCommand
    """
    return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr)


class PTtoTFCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformers-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        train_parser = parser.add_parser(
            "pt-to-tf",
            help=(
                "CLI tool to convert a transformers model from a PyTorch checkpoint to a TensorFlow checkpoint."
                " Can also be used to validate existing weights without opening PRs, with --no-pr."
            ),
        )
        train_parser.add_argument(
            "--model-name",
            type=str,
            required=True,
            help="The model name, including owner/organization, as seen on the hub.",
        )
        train_parser.add_argument(
            "--local-dir",
            type=str,
            default="",
            help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
        )
        train_parser.add_argument(
            "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
        )
        train_parser.set_defaults(func=convert_command_factory)

    def __init__(self, model_name: str, local_dir: str, no_pr: bool, *args):
        self._logger = logging.get_logger("transformers-cli/pt_to_tf")
        self._model_name = model_name
        self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
        self._no_pr = no_pr

    def get_text_inputs(self):
        tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
        sample_text = ["Hi there!", "I am a batch with more than one row and different input lengths."]
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        pt_input = tokenizer(sample_text, return_tensors="pt", padding=True, truncation=True)
        tf_input = tokenizer(sample_text, return_tensors="tf", padding=True, truncation=True)
        return pt_input, tf_input

    def get_audio_inputs(self):
        processor = AutoFeatureExtractor.from_pretrained(self._local_dir)
        num_samples = 2
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
        raw_samples = [x["array"] for x in speech_samples]
        pt_input = processor(raw_samples, return_tensors="pt", padding=True)
        tf_input = processor(raw_samples, return_tensors="tf", padding=True)
        return pt_input, tf_input

    def get_image_inputs(self):
        feature_extractor = AutoFeatureExtractor.from_pretrained(self._local_dir)
        num_samples = 2
        ds = load_dataset("cifar10", "plain_text", split="test")[:num_samples]["img"]
        pt_input = feature_extractor(images=ds, return_tensors="pt")
        tf_input = feature_extractor(images=ds, return_tensors="tf")
        return pt_input, tf_input

    def run(self):
        # Fetch remote data
        # TODO: implement a solution to pull a specific PR/commit, so we can use this CLI to validate pushes.
        repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
        repo.git_pull()  # in case the repo already exists locally, but with an older commit

        # Load the models and acquire a basic input compatible with the model's modality.
        pt_model = AutoModel.from_pretrained(self._local_dir)
        main_input_name = pt_model.main_input_name
        if main_input_name == "input_ids":
            pt_input, tf_input = self.get_text_inputs()
        elif main_input_name == "pixel_values":
            pt_input, tf_input = self.get_image_inputs()
        elif main_input_name == "input_features":
            pt_input, tf_input = self.get_audio_inputs()
        else:
            raise ValueError(f"Can't detect the model modality (`main_input_name` = {main_input_name})")
        tf_from_pt_model = TFAutoModel.from_pretrained(self._local_dir, from_pt=True)

        # Extra input requirements, in addition to the input modality
        if hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"):
            decoder_input_ids = np.asarray([[1], [1]], dtype=int) * pt_model.config.decoder_start_token_id
            pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
            tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})

        # Confirm that cross-loading the PT weights into TF worked.
        pt_last_hidden_state = pt_model(**pt_input).last_hidden_state.detach().numpy()
        tf_from_pt_last_hidden_state = tf_from_pt_model(**tf_input).last_hidden_state.numpy()
        crossload_diff = np.max(np.abs(pt_last_hidden_state - tf_from_pt_last_hidden_state))
        if crossload_diff >= MAX_ERROR:
            raise ValueError(
                "The cross-loaded TF model has different last hidden states, something went wrong! (max difference ="
                f" {crossload_diff})"
            )

        # Save the weights in a TF format (if they don't exist) and confirm that the results are still good
        tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
        if not os.path.exists(tf_weights_path):
            tf_from_pt_model.save_weights(tf_weights_path)
        del tf_from_pt_model, pt_model  # will no longer be used, and may have a large memory footprint
        tf_model = TFAutoModel.from_pretrained(self._local_dir)
        tf_last_hidden_state = tf_model(**tf_input).last_hidden_state.numpy()
        converted_diff = np.max(np.abs(pt_last_hidden_state - tf_last_hidden_state))
        if converted_diff >= MAX_ERROR:
            raise ValueError(
                "The converted TF model has different last hidden states, something went wrong! (max difference ="
                f" {converted_diff})"
            )

        if not self._no_pr:
            # TODO: remove try/except when the upload to PR feature is released
            # (https://github.com/huggingface/huggingface_hub/pull/884)
            try:
                self._logger.warn("Uploading the weights into a new PR...")
                hub_pr_url = upload_file(
                    path_or_fileobj=tf_weights_path,
                    path_in_repo=TF_WEIGHTS_NAME,
                    repo_id=self._model_name,
                    create_pr=True,
                    pr_commit_summary="Add TF weights",
                    pr_commit_description=(
                        f"Validated by the `pt_to_tf` CLI. Max crossload hidden state difference={crossload_diff:.3e};"
                        f" Max converted hidden state difference={converted_diff:.3e}."
                    ),
                )
                self._logger.warn(f"PR open in {hub_pr_url}")
            except TypeError:
                self._logger.warn(
                    f"You can now open a PR in https://huggingface.co/{self._model_name}/discussions, manually"
                    f" uploading the file in {tf_weights_path}"
                )

src/transformers/commands/transformers_cli.py

@@ -21,6 +21,7 @@ from .convert import ConvertCommand
 from .download import DownloadCommand
 from .env import EnvironmentCommand
 from .lfs import LfsCommands
+from .pt_to_tf import PTtoTFCommand
 from .run import RunCommand
 from .serving import ServeCommand
 from .user import UserCommands
@@ -40,6 +41,7 @@ def main():
     AddNewModelCommand.register_subcommand(commands_parser)
     AddNewModelLikeCommand.register_subcommand(commands_parser)
     LfsCommands.register_subcommand(commands_parser)
+    PTtoTFCommand.register_subcommand(commands_parser)

     # Let's go
     args = parser.parse_args()
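
For context, PTtoTFCommand plugs into the CLI through the same two-method contract as the commands registered above: a static register_subcommand that wires up argparse, and a run method. A hypothetical minimal command illustrating that pattern (not part of this commit):

from argparse import ArgumentParser

from transformers.commands import BaseTransformersCLICommand


class HelloCommand(BaseTransformersCLICommand):
    # Hypothetical toy command, for illustration only
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        hello_parser = parser.add_parser("hello", help="Toy subcommand (illustration only).")
        hello_parser.add_argument("--name", type=str, default="world")
        hello_parser.set_defaults(func=lambda args: HelloCommand(args.name))

    def __init__(self, name: str):
        self._name = name

    def run(self):
        print(f"Hello, {self._name}!")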

tests/utils/test_cli.py

@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import shutil
 import unittest
 from unittest.mock import patch

-from transformers.testing_utils import CaptureStd
+from transformers.testing_utils import CaptureStd, is_pt_tf_cross_test


 class CLITest(unittest.TestCase):
@@ -30,3 +32,16 @@ class CLITest(unittest.TestCase):
         self.assertIn("Python version", cs.out)
         self.assertIn("Platform", cs.out)
         self.assertIn("Using distributed or parallel set-up in script?", cs.out)
+
+    @is_pt_tf_cross_test
+    @patch(
+        "sys.argv", ["fakeprogrampath", "pt-to-tf", "--model-name", "hf-internal-testing/tiny-random-gptj", "--no-pr"]
+    )
+    def test_cli_pt_to_tf(self):
+        import transformers.commands.transformers_cli
+
+        shutil.rmtree("/tmp/hf-internal-testing/tiny-random-gptj", ignore_errors=True)  # cleans potential past runs
+        transformers.commands.transformers_cli.main()
+
+        # The original repo has no TF weights -- if they exist, they were created by the CLI
+        self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5"))