mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-07 23:00:08 +06:00
CLI: add stricter automatic checks to pt-to-tf
(#17588)
* Stricter pt-to-tf checks; Update docker image for related tests * check all attributes in the output Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
c6cea5a78c
commit
78c695eb62
@ -4,7 +4,8 @@ LABEL maintainer="Hugging Face"
|
|||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
RUN apt update
|
RUN apt update
|
||||||
RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg
|
RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg git-lfs
|
||||||
|
RUN git lfs install
|
||||||
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
||||||
|
|
||||||
ARG REF=main
|
ARG REF=main
|
||||||
|
@ -14,13 +14,14 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
from huggingface_hub import Repository, upload_file
|
from huggingface_hub import Repository, upload_file
|
||||||
|
|
||||||
from .. import AutoFeatureExtractor, AutoModel, AutoTokenizer, TFAutoModel, is_tf_available, is_torch_available
|
from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from . import BaseTransformersCLICommand
|
from . import BaseTransformersCLICommand
|
||||||
|
|
||||||
@ -44,7 +45,7 @@ def convert_command_factory(args: Namespace):
|
|||||||
|
|
||||||
Returns: ServeCommand
|
Returns: ServeCommand
|
||||||
"""
|
"""
|
||||||
return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr)
|
return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr, args.new_weights)
|
||||||
|
|
||||||
|
|
||||||
class PTtoTFCommand(BaseTransformersCLICommand):
|
class PTtoTFCommand(BaseTransformersCLICommand):
|
||||||
@ -78,13 +79,69 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
|||||||
train_parser.add_argument(
|
train_parser.add_argument(
|
||||||
"--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
|
"--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
|
||||||
)
|
)
|
||||||
|
train_parser.add_argument(
|
||||||
|
"--new-weights",
|
||||||
|
action="store_true",
|
||||||
|
help="Optional flag to create new TensorFlow weights, even if they already exist.",
|
||||||
|
)
|
||||||
train_parser.set_defaults(func=convert_command_factory)
|
train_parser.set_defaults(func=convert_command_factory)
|
||||||
|
|
||||||
def __init__(self, model_name: str, local_dir: str, no_pr: bool, *args):
|
@staticmethod
|
||||||
|
def compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input):
|
||||||
|
"""
|
||||||
|
Compares the TensorFlow and PyTorch models, given their inputs, returning a tuple with the maximum observed
|
||||||
|
difference and its source.
|
||||||
|
"""
|
||||||
|
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
|
||||||
|
tf_outputs = tf_model(**tf_input, output_hidden_states=True)
|
||||||
|
|
||||||
|
# 1. All output attributes must be the same
|
||||||
|
pt_out_attrs = set(pt_outputs.keys())
|
||||||
|
tf_out_attrs = set(tf_outputs.keys())
|
||||||
|
if pt_out_attrs != tf_out_attrs:
|
||||||
|
raise ValueError(
|
||||||
|
f"The model outputs have different attributes, aborting. (Pytorch: {pt_out_attrs}, TensorFlow:"
|
||||||
|
f" {tf_out_attrs})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. For each output attribute, ALL values must be the same
|
||||||
|
def _compate_pt_tf_models(pt_out, tf_out, attr_name=""):
|
||||||
|
max_difference = 0
|
||||||
|
max_difference_source = ""
|
||||||
|
|
||||||
|
# If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
|
||||||
|
# recursivelly, keeping the name of the attribute.
|
||||||
|
if isinstance(pt_out, (torch.Tensor)):
|
||||||
|
difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
|
||||||
|
if difference > max_difference:
|
||||||
|
max_difference = difference
|
||||||
|
max_difference_source = attr_name
|
||||||
|
else:
|
||||||
|
root_name = attr_name
|
||||||
|
for i, pt_item in enumerate(pt_out):
|
||||||
|
# If it is a named attribute, we keep the name. Otherwise, just its index.
|
||||||
|
if isinstance(pt_item, str):
|
||||||
|
branch_name = root_name + pt_item
|
||||||
|
tf_item = tf_out[pt_item]
|
||||||
|
pt_item = pt_out[pt_item]
|
||||||
|
else:
|
||||||
|
branch_name = root_name + f"[{i}]"
|
||||||
|
tf_item = tf_out[i]
|
||||||
|
difference, difference_source = _compate_pt_tf_models(pt_item, tf_item, branch_name)
|
||||||
|
if difference > max_difference:
|
||||||
|
max_difference = difference
|
||||||
|
max_difference_source = difference_source
|
||||||
|
|
||||||
|
return max_difference, max_difference_source
|
||||||
|
|
||||||
|
return _compate_pt_tf_models(pt_outputs, tf_outputs)
|
||||||
|
|
||||||
|
def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args):
|
||||||
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
|
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
|
||||||
self._model_name = model_name
|
self._model_name = model_name
|
||||||
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
|
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
|
||||||
self._no_pr = no_pr
|
self._no_pr = no_pr
|
||||||
|
self._new_weights = new_weights
|
||||||
|
|
||||||
def get_text_inputs(self):
|
def get_text_inputs(self):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
|
tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
|
||||||
@ -119,8 +176,25 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
|||||||
repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
|
repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
|
||||||
repo.git_pull() # in case the repo already exists locally, but with an older commit
|
repo.git_pull() # in case the repo already exists locally, but with an older commit
|
||||||
|
|
||||||
|
# Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
|
||||||
|
config = AutoConfig.from_pretrained(self._local_dir)
|
||||||
|
architectures = config.architectures
|
||||||
|
if architectures is None: # No architecture defined -- use auto classes
|
||||||
|
pt_class = getattr(import_module("transformers"), "AutoModel")
|
||||||
|
tf_class = getattr(import_module("transformers"), "TFAutoModel")
|
||||||
|
self._logger.warn("No detected architecture, using AutoModel/TFAutoModel")
|
||||||
|
else: # Architecture defined -- use it
|
||||||
|
if len(architectures) > 1:
|
||||||
|
raise ValueError(f"More than one architecture was found, aborting. (architectures = {architectures})")
|
||||||
|
self._logger.warn(f"Detected architecture: {architectures[0]}")
|
||||||
|
pt_class = getattr(import_module("transformers"), architectures[0])
|
||||||
|
try:
|
||||||
|
tf_class = getattr(import_module("transformers"), "TF" + architectures[0])
|
||||||
|
except AttributeError:
|
||||||
|
raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")
|
||||||
|
|
||||||
# Load models and acquire a basic input for its modality.
|
# Load models and acquire a basic input for its modality.
|
||||||
pt_model = AutoModel.from_pretrained(self._local_dir)
|
pt_model = pt_class.from_pretrained(self._local_dir)
|
||||||
main_input_name = pt_model.main_input_name
|
main_input_name = pt_model.main_input_name
|
||||||
if main_input_name == "input_ids":
|
if main_input_name == "input_ids":
|
||||||
pt_input, tf_input = self.get_text_inputs()
|
pt_input, tf_input = self.get_text_inputs()
|
||||||
@ -130,7 +204,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
|||||||
pt_input, tf_input = self.get_audio_inputs()
|
pt_input, tf_input = self.get_audio_inputs()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Can't detect the model modality (`main_input_name` = {main_input_name})")
|
raise ValueError(f"Can't detect the model modality (`main_input_name` = {main_input_name})")
|
||||||
tf_from_pt_model = TFAutoModel.from_pretrained(self._local_dir, from_pt=True)
|
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
|
||||||
|
|
||||||
# Extra input requirements, in addition to the input modality
|
# Extra input requirements, in addition to the input modality
|
||||||
if hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"):
|
if hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"):
|
||||||
@ -139,27 +213,24 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
|||||||
tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
|
tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
|
||||||
|
|
||||||
# Confirms that cross loading PT weights into TF worked.
|
# Confirms that cross loading PT weights into TF worked.
|
||||||
pt_last_hidden_state = pt_model(**pt_input).last_hidden_state.detach().numpy()
|
crossload_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_from_pt_model, tf_input)
|
||||||
tf_from_pt_last_hidden_state = tf_from_pt_model(**tf_input).last_hidden_state.numpy()
|
|
||||||
crossload_diff = np.max(np.abs(pt_last_hidden_state - tf_from_pt_last_hidden_state))
|
|
||||||
if crossload_diff >= MAX_ERROR:
|
if crossload_diff >= MAX_ERROR:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The cross-loaded TF model has different last hidden states, something went wrong! (max difference ="
|
"The cross-loaded TF model has different outputs, something went wrong! (max difference ="
|
||||||
f" {crossload_diff})"
|
f" {crossload_diff:.3e}, observed in {diff_source})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save the weights in a TF format (if they don't exist) and confirms that the results are still good
|
# Save the weights in a TF format (if needed) and confirms that the results are still good
|
||||||
tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
|
tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
|
||||||
if not os.path.exists(tf_weights_path):
|
if not os.path.exists(tf_weights_path) or self._new_weights:
|
||||||
tf_from_pt_model.save_weights(tf_weights_path)
|
tf_from_pt_model.save_weights(tf_weights_path)
|
||||||
del tf_from_pt_model, pt_model # will no longer be used, and may have a large memory footprint
|
del tf_from_pt_model # will no longer be used, and may have a large memory footprint
|
||||||
tf_model = TFAutoModel.from_pretrained(self._local_dir)
|
tf_model = tf_class.from_pretrained(self._local_dir)
|
||||||
tf_last_hidden_state = tf_model(**tf_input).last_hidden_state.numpy()
|
converted_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input)
|
||||||
converted_diff = np.max(np.abs(pt_last_hidden_state - tf_last_hidden_state))
|
|
||||||
if converted_diff >= MAX_ERROR:
|
if converted_diff >= MAX_ERROR:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The converted TF model has different last hidden states, something went wrong! (max difference ="
|
"The converted TF model has different outputs, something went wrong! (max difference ="
|
||||||
f" {converted_diff})"
|
f" {converted_diff:.3e}, observed in {diff_source})"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self._no_pr:
|
if not self._no_pr:
|
||||||
@ -174,8 +245,8 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
|||||||
create_pr=True,
|
create_pr=True,
|
||||||
pr_commit_summary="Add TF weights",
|
pr_commit_summary="Add TF weights",
|
||||||
pr_commit_description=(
|
pr_commit_description=(
|
||||||
f"Validated by the `pt_to_tf` CLI. Max crossload hidden state difference={crossload_diff:.3e};"
|
f"Validated by the `pt_to_tf` CLI. Max crossload output difference={crossload_diff:.3e};"
|
||||||
f" Max converted hidden state difference={converted_diff:.3e}."
|
f" Max converted output difference={converted_diff:.3e}."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._logger.warn(f"PR open in {hub_pr_url}")
|
self._logger.warn(f"PR open in {hub_pr_url}")
|
||||||
|
Loading…
Reference in New Issue
Block a user