From 92dfceb12446404fe547b1271ecaf86d2498546b Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 27 Feb 2023 15:31:55 +0000 Subject: [PATCH] Inheritance-based framework detection (#21784) --- src/transformers/pipelines/base.py | 5 ++- src/transformers/utils/generic.py | 26 +++++++----- tests/utils/test_file_utils.py | 64 ++++++++++++++++++++---------- 3 files changed, 60 insertions(+), 35 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 842653e6fc4..3dca2d33d15 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -15,6 +15,7 @@ import collections import csv import importlib +import inspect import json import os import pickle @@ -269,7 +270,7 @@ def infer_framework_load_model( if isinstance(model, str): raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.") - framework = "tf" if model.__class__.__name__.startswith("TF") else "pt" + framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt" return framework, model @@ -342,7 +343,7 @@ def get_framework(model, revision: Optional[str] = None): except OSError: model = TFAutoModel.from_pretrained(model, revision=revision) - framework = "tf" if model.__class__.__name__.startswith("TF") else "pt" + framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt" return framework diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index d138e0c1d4e..21e9cf514f7 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -366,13 +366,14 @@ def can_return_loss(model_class): Args: model_class (`type`): The class of the model. """ - model_name = model_class.__name__ - if model_name.startswith("TF"): - signature = inspect.signature(model_class.call) - elif model_name.startswith("Flax"): - signature = inspect.signature(model_class.__call__) + base_classes = str(inspect.getmro(model_class)) + + if "keras.engine.training.Model" in base_classes: + signature = inspect.signature(model_class.call) # TensorFlow models + elif "torch.nn.modules.module.Module" in base_classes: + signature = inspect.signature(model_class.forward) # PyTorch models else: - signature = inspect.signature(model_class.forward) + signature = inspect.signature(model_class.__call__) # Flax models for p in signature.parameters: if p == "return_loss" and signature.parameters[p].default is True: @@ -389,12 +390,15 @@ def find_labels(model_class): model_class (`type`): The class of the model. """ model_name = model_class.__name__ - if model_name.startswith("TF"): - signature = inspect.signature(model_class.call) - elif model_name.startswith("Flax"): - signature = inspect.signature(model_class.__call__) + base_classes = str(inspect.getmro(model_class)) + + if "keras.engine.training.Model" in base_classes: + signature = inspect.signature(model_class.call) # TensorFlow models + elif "torch.nn.modules.module.Module" in base_classes: + signature = inspect.signature(model_class.forward) # PyTorch models else: - signature = inspect.signature(model_class.forward) + signature = inspect.signature(model_class.__call__) # Flax models + if "QuestionAnswering" in model_name: return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")] else: diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py index e7963bfa51a..1cbde0fb18c 100644 --- a/tests/utils/test_file_utils.py +++ b/tests/utils/test_file_utils.py @@ -21,10 +21,20 @@ import transformers # Try to import everything from transformers to ensure every object can be loaded. from transformers import * # noqa F406 -from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER +from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_flax, require_tf, require_torch from transformers.utils import ContextManagers, find_labels, is_flax_available, is_tf_available, is_torch_available +if is_torch_available(): + from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification + +if is_tf_available(): + from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification + +if is_flax_available(): + from transformers import FlaxBertForPreTraining, FlaxBertForQuestionAnswering, FlaxBertForSequenceClassification + + MODEL_ID = DUMMY_UNKNOWN_IDENTIFIER # An actual model hosted on huggingface.co @@ -85,29 +95,39 @@ class GenericUtilTests(unittest.TestCase): # The output should be wrapped with an English and French welcome and goodbye self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n") - def test_find_labels(self): - if is_torch_available(): - from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification + @require_torch + def test_find_labels_pt(self): + self.assertEqual(find_labels(BertForSequenceClassification), ["labels"]) + self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"]) + self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"]) - self.assertEqual(find_labels(BertForSequenceClassification), ["labels"]) - self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"]) - self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"]) + # find_labels works regardless of the class name (it detects the framework through inheritance) + class DummyModel(BertForSequenceClassification): + pass - if is_tf_available(): - from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification + self.assertEqual(find_labels(DummyModel), ["labels"]) - self.assertEqual(find_labels(TFBertForSequenceClassification), ["labels"]) - self.assertEqual(find_labels(TFBertForPreTraining), ["labels", "next_sentence_label"]) - self.assertEqual(find_labels(TFBertForQuestionAnswering), ["start_positions", "end_positions"]) + @require_tf + def test_find_labels_tf(self): + self.assertEqual(find_labels(TFBertForSequenceClassification), ["labels"]) + self.assertEqual(find_labels(TFBertForPreTraining), ["labels", "next_sentence_label"]) + self.assertEqual(find_labels(TFBertForQuestionAnswering), ["start_positions", "end_positions"]) - if is_flax_available(): - # Flax models don't have labels - from transformers import ( - FlaxBertForPreTraining, - FlaxBertForQuestionAnswering, - FlaxBertForSequenceClassification, - ) + # find_labels works regardless of the class name (it detects the framework through inheritance) + class DummyModel(TFBertForSequenceClassification): + pass - self.assertEqual(find_labels(FlaxBertForSequenceClassification), []) - self.assertEqual(find_labels(FlaxBertForPreTraining), []) - self.assertEqual(find_labels(FlaxBertForQuestionAnswering), []) + self.assertEqual(find_labels(DummyModel), ["labels"]) + + @require_flax + def test_find_labels_flax(self): + # Flax models don't have labels + self.assertEqual(find_labels(FlaxBertForSequenceClassification), []) + self.assertEqual(find_labels(FlaxBertForPreTraining), []) + self.assertEqual(find_labels(FlaxBertForQuestionAnswering), []) + + # find_labels works regardless of the class name (it detects the framework through inheritance) + class DummyModel(FlaxBertForSequenceClassification): + pass + + self.assertEqual(find_labels(DummyModel), [])