From 78c695eb624bc863ea165b6fb0a8850bfd9fcefa Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 8 Jun 2022 10:45:10 +0100
Subject: [PATCH] CLI: add stricter automatic checks to `pt-to-tf` (#17588)

* Stricter pt-to-tf checks; Update docker image for related tests

* check all attributes in the output

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 docker/transformers-all-latest-gpu/Dockerfile |   3 +-
 src/transformers/commands/pt_to_tf.py         | 111 ++++++++++++++----
 2 files changed, 93 insertions(+), 21 deletions(-)

diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile
index 8d63921ec02..ecc8474c045 100644
--- a/docker/transformers-all-latest-gpu/Dockerfile
+++ b/docker/transformers-all-latest-gpu/Dockerfile
@@ -4,7 +4,8 @@ LABEL maintainer="Hugging Face"
 ARG DEBIAN_FRONTEND=noninteractive
 
 RUN apt update
-RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg
+RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg git-lfs
+RUN git lfs install
 RUN python3 -m pip install --no-cache-dir --upgrade pip
 
 ARG REF=main
diff --git a/src/transformers/commands/pt_to_tf.py b/src/transformers/commands/pt_to_tf.py
index 778563363cf..7bc5333b109 100644
--- a/src/transformers/commands/pt_to_tf.py
+++ b/src/transformers/commands/pt_to_tf.py
@@ -14,13 +14,14 @@
 
 import os
 from argparse import ArgumentParser, Namespace
+from importlib import import_module
 
 import numpy as np
 from datasets import load_dataset
 from huggingface_hub import Repository, upload_file
 
-from .. import AutoFeatureExtractor, AutoModel, AutoTokenizer, TFAutoModel, is_tf_available, is_torch_available
+from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
 from ..utils import logging
 from . import BaseTransformersCLICommand
@@ -44,7 +45,7 @@ def convert_command_factory(args: Namespace):
 
     Returns: ServeCommand
     """
-    return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr)
+    return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr, args.new_weights)
 
 
 class PTtoTFCommand(BaseTransformersCLICommand):
@@ -78,13 +79,69 @@ class PTtoTFCommand(BaseTransformersCLICommand):
         train_parser.add_argument(
             "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
         )
+        train_parser.add_argument(
+            "--new-weights",
+            action="store_true",
+            help="Optional flag to create new TensorFlow weights, even if they already exist.",
+        )
         train_parser.set_defaults(func=convert_command_factory)
 
-    def __init__(self, model_name: str, local_dir: str, no_pr: bool, *args):
+    @staticmethod
+    def compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input):
+        """
+        Compares the TensorFlow and PyTorch models, given their inputs, returning a tuple with the maximum observed
+        difference and its source.
+        """
+        pt_outputs = pt_model(**pt_input, output_hidden_states=True)
+        tf_outputs = tf_model(**tf_input, output_hidden_states=True)
+
+        # 1. All output attributes must be the same
+        pt_out_attrs = set(pt_outputs.keys())
+        tf_out_attrs = set(tf_outputs.keys())
+        if pt_out_attrs != tf_out_attrs:
+            raise ValueError(
+                f"The model outputs have different attributes, aborting. (PyTorch: {pt_out_attrs}, TensorFlow:"
+                f" {tf_out_attrs})"
+            )
+
+        # 2. For each output attribute, ALL values must be the same
+        def _compare_pt_tf_models(pt_out, tf_out, attr_name=""):
+            max_difference = 0
+            max_difference_source = ""
+
+            # If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we dig in
+            # recursively, keeping the name of the attribute.
+            if isinstance(pt_out, torch.Tensor):
+                difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
+                if difference > max_difference:
+                    max_difference = difference
+                    max_difference_source = attr_name
+            else:
+                root_name = attr_name
+                for i, pt_item in enumerate(pt_out):
+                    # If it is a named attribute, we keep the name. Otherwise, just its index.
+                    if isinstance(pt_item, str):
+                        branch_name = root_name + pt_item
+                        tf_item = tf_out[pt_item]
+                        pt_item = pt_out[pt_item]
+                    else:
+                        branch_name = root_name + f"[{i}]"
+                        tf_item = tf_out[i]
+                    difference, difference_source = _compare_pt_tf_models(pt_item, tf_item, branch_name)
+                    if difference > max_difference:
+                        max_difference = difference
+                        max_difference_source = difference_source
+
+            return max_difference, max_difference_source
+
+        return _compare_pt_tf_models(pt_outputs, tf_outputs)
+
+    def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args):
         self._logger = logging.get_logger("transformers-cli/pt_to_tf")
         self._model_name = model_name
         self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
         self._no_pr = no_pr
+        self._new_weights = new_weights
 
     def get_text_inputs(self):
         tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
@@ -119,8 +176,25 @@ class PTtoTFCommand(BaseTransformersCLICommand):
         repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
         repo.git_pull()  # in case the repo already exists locally, but with an older commit
 
+        # Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
+        config = AutoConfig.from_pretrained(self._local_dir)
+        architectures = config.architectures
+        if architectures is None:  # No architecture defined -- use auto classes
+            pt_class = getattr(import_module("transformers"), "AutoModel")
+            tf_class = getattr(import_module("transformers"), "TFAutoModel")
+            self._logger.warn("No architecture detected, using AutoModel/TFAutoModel")
+        else:  # Architecture defined -- use it
+            if len(architectures) > 1:
+                raise ValueError(f"More than one architecture was found, aborting. (architectures = {architectures})")
+            self._logger.warn(f"Detected architecture: {architectures[0]}")
+            pt_class = getattr(import_module("transformers"), architectures[0])
+            try:
+                tf_class = getattr(import_module("transformers"), "TF" + architectures[0])
+            except AttributeError:
+                raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")
+
         # Load models and acquire a basic input for its modality.
-        pt_model = AutoModel.from_pretrained(self._local_dir)
+        pt_model = pt_class.from_pretrained(self._local_dir)
         main_input_name = pt_model.main_input_name
         if main_input_name == "input_ids":
             pt_input, tf_input = self.get_text_inputs()
@@ -130,7 +204,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
             pt_input, tf_input = self.get_audio_inputs()
         else:
             raise ValueError(f"Can't detect the model modality (`main_input_name` = {main_input_name})")
-        tf_from_pt_model = TFAutoModel.from_pretrained(self._local_dir, from_pt=True)
+        tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
 
         # Extra input requirements, in addition to the input modality
         if hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"):
@@ -139,27 +213,24 @@ class PTtoTFCommand(BaseTransformersCLICommand):
             tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
 
         # Confirms that cross loading PT weights into TF worked.
-        pt_last_hidden_state = pt_model(**pt_input).last_hidden_state.detach().numpy()
-        tf_from_pt_last_hidden_state = tf_from_pt_model(**tf_input).last_hidden_state.numpy()
-        crossload_diff = np.max(np.abs(pt_last_hidden_state - tf_from_pt_last_hidden_state))
+        crossload_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_from_pt_model, tf_input)
         if crossload_diff >= MAX_ERROR:
             raise ValueError(
-                "The cross-loaded TF model has different last hidden states, something went wrong! (max difference ="
-                f" {crossload_diff})"
+                "The cross-loaded TF model has different outputs, something went wrong! (max difference ="
+                f" {crossload_diff:.3e}, observed in {diff_source})"
             )
 
-        # Save the weights in a TF format (if they don't exist) and confirms that the results are still good
+        # Save the weights in a TF format (if needed) and confirm that the results are still good
         tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
-        if not os.path.exists(tf_weights_path):
+        if not os.path.exists(tf_weights_path) or self._new_weights:
             tf_from_pt_model.save_weights(tf_weights_path)
-        del tf_from_pt_model, pt_model  # will no longer be used, and may have a large memory footprint
-        tf_model = TFAutoModel.from_pretrained(self._local_dir)
-        tf_last_hidden_state = tf_model(**tf_input).last_hidden_state.numpy()
-        converted_diff = np.max(np.abs(pt_last_hidden_state - tf_last_hidden_state))
+        del tf_from_pt_model  # will no longer be used, and may have a large memory footprint
+        tf_model = tf_class.from_pretrained(self._local_dir)
+        converted_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input)
         if converted_diff >= MAX_ERROR:
             raise ValueError(
-                "The converted TF model has different last hidden states, something went wrong! (max difference ="
-                f" {converted_diff})"
+                "The converted TF model has different outputs, something went wrong! (max difference ="
+                f" {converted_diff:.3e}, observed in {diff_source})"
             )
 
         if not self._no_pr:
@@ -174,8 +245,8 @@ class PTtoTFCommand(BaseTransformersCLICommand):
                     create_pr=True,
                     pr_commit_summary="Add TF weights",
                     pr_commit_description=(
-                        f"Validated by the `pt_to_tf` CLI. Max crossload hidden state difference={crossload_diff:.3e};"
-                        f" Max converted hidden state difference={converted_diff:.3e}."
+                        f"Validated by the `pt_to_tf` CLI. Max crossload output difference={crossload_diff:.3e};"
+                        f" Max converted output difference={converted_diff:.3e}."
                     ),
                 )
                 self._logger.warn(f"PR open in {hub_pr_url}")
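
---
Usage note: with this change, running `transformers-cli pt-to-tf --model-name <model_name> --new-weights` regenerates and re-validates the TensorFlow weights even when the repository already contains a `tf_model.h5`; without `--new-weights`, existing TF weights are kept and only validated. (`--model-name` and `--local-dir` are registered in a part of `register_subcommand` not shown in these hunks.)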
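The stricter check itself boils down to a recursive walk over the (possibly nested) model outputs, tracking the largest absolute difference and the attribute path where it occurs. Below is a minimal, self-contained sketch of that idea, not the patch's code: plain numpy arrays stand in for torch/TF tensors, and `max_output_difference` is an illustrative name.

import numpy as np


def max_output_difference(pt_out, tf_out, attr_name=""):
    # Leaf: compare arrays directly and report where we are in the output tree.
    if isinstance(pt_out, np.ndarray):
        return float(np.max(np.abs(pt_out - tf_out))), attr_name
    # Named attributes (like a ModelOutput) keep their name; positional items
    # (like a tuple of per-layer hidden states) are tagged with their index.
    if isinstance(pt_out, dict):
        branches = ((attr_name + name, pt_out[name], tf_out[name]) for name in pt_out)
    else:
        branches = (
            (attr_name + f"[{i}]", pt_item, tf_item)
            for i, (pt_item, tf_item) in enumerate(zip(pt_out, tf_out))
        )
    max_difference, max_difference_source = 0.0, ""
    for branch_name, pt_item, tf_item in branches:
        difference, source = max_output_difference(pt_item, tf_item, branch_name)
        if difference > max_difference:
            max_difference, max_difference_source = difference, source
    return max_difference, max_difference_source


pt = {"last_hidden_state": np.zeros((2, 4)), "hidden_states": (np.zeros(3), np.ones(3))}
tf = {"last_hidden_state": np.zeros((2, 4)), "hidden_states": (np.zeros(3), np.full(3, 1.001))}
print(max_output_difference(pt, tf))  # -> (0.0010000000000000009, 'hidden_states[1]')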