Inheritance-based framework detection (#21784)

This commit is contained in:
Joao Gante 2023-02-27 15:31:55 +00:00 committed by GitHub
parent 7811bf7e73
commit 92dfceb124
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 35 deletions

View File

@@ -15,6 +15,7 @@
import collections
import csv
import importlib
import inspect
import json
import os
import pickle
@@ -269,7 +270,7 @@ def infer_framework_load_model(
if isinstance(model, str):
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
return framework, model
@@ -342,7 +343,7 @@ def get_framework(model, revision: Optional[str] = None):
except OSError:
model = TFAutoModel.from_pretrained(model, revision=revision)
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
return framework

View File

@@ -366,13 +366,14 @@ def can_return_loss(model_class):
Args:
model_class (`type`): The class of the model.
"""
model_name = model_class.__name__
if model_name.startswith("TF"):
signature = inspect.signature(model_class.call)
elif model_name.startswith("Flax"):
signature = inspect.signature(model_class.__call__)
base_classes = str(inspect.getmro(model_class))
if "keras.engine.training.Model" in base_classes:
signature = inspect.signature(model_class.call) # TensorFlow models
elif "torch.nn.modules.module.Module" in base_classes:
signature = inspect.signature(model_class.forward) # PyTorch models
else:
signature = inspect.signature(model_class.forward)
signature = inspect.signature(model_class.__call__) # Flax models
for p in signature.parameters:
if p == "return_loss" and signature.parameters[p].default is True:
@@ -389,12 +390,15 @@ def find_labels(model_class):
model_class (`type`): The class of the model.
"""
model_name = model_class.__name__
if model_name.startswith("TF"):
signature = inspect.signature(model_class.call)
elif model_name.startswith("Flax"):
signature = inspect.signature(model_class.__call__)
base_classes = str(inspect.getmro(model_class))
if "keras.engine.training.Model" in base_classes:
signature = inspect.signature(model_class.call) # TensorFlow models
elif "torch.nn.modules.module.Module" in base_classes:
signature = inspect.signature(model_class.forward) # PyTorch models
else:
signature = inspect.signature(model_class.forward)
signature = inspect.signature(model_class.__call__) # Flax models
if "QuestionAnswering" in model_name:
return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
else:

View File

@@ -21,10 +21,20 @@ import transformers
# Try to import everything from transformers to ensure every object can be loaded.
from transformers import * # noqa F406
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_flax, require_tf, require_torch
from transformers.utils import ContextManagers, find_labels, is_flax_available, is_tf_available, is_torch_available
if is_torch_available():
from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification
if is_tf_available():
from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification
if is_flax_available():
from transformers import FlaxBertForPreTraining, FlaxBertForQuestionAnswering, FlaxBertForSequenceClassification
MODEL_ID = DUMMY_UNKNOWN_IDENTIFIER
# An actual model hosted on huggingface.co
@@ -85,29 +95,39 @@ class GenericUtilTests(unittest.TestCase):
# The output should be wrapped with an English and French welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
def test_find_labels(self):
if is_torch_available():
from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification
@require_torch
def test_find_labels_pt(self):
self.assertEqual(find_labels(BertForSequenceClassification), ["labels"])
self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"])
self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"])
self.assertEqual(find_labels(BertForSequenceClassification), ["labels"])
self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"])
self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"])
# find_labels works regardless of the class name (it detects the framework through inheritance)
class DummyModel(BertForSequenceClassification):
pass
if is_tf_available():
from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification
self.assertEqual(find_labels(DummyModel), ["labels"])
self.assertEqual(find_labels(TFBertForSequenceClassification), ["labels"])
self.assertEqual(find_labels(TFBertForPreTraining), ["labels", "next_sentence_label"])
self.assertEqual(find_labels(TFBertForQuestionAnswering), ["start_positions", "end_positions"])
@require_tf
def test_find_labels_tf(self):
self.assertEqual(find_labels(TFBertForSequenceClassification), ["labels"])
self.assertEqual(find_labels(TFBertForPreTraining), ["labels", "next_sentence_label"])
self.assertEqual(find_labels(TFBertForQuestionAnswering), ["start_positions", "end_positions"])
if is_flax_available():
# Flax models don't have labels
from transformers import (
FlaxBertForPreTraining,
FlaxBertForQuestionAnswering,
FlaxBertForSequenceClassification,
)
# find_labels works regardless of the class name (it detects the framework through inheritance)
class DummyModel(TFBertForSequenceClassification):
pass
self.assertEqual(find_labels(FlaxBertForSequenceClassification), [])
self.assertEqual(find_labels(FlaxBertForPreTraining), [])
self.assertEqual(find_labels(FlaxBertForQuestionAnswering), [])
self.assertEqual(find_labels(DummyModel), ["labels"])
@require_flax
def test_find_labels_flax(self):
# Flax models don't have labels
self.assertEqual(find_labels(FlaxBertForSequenceClassification), [])
self.assertEqual(find_labels(FlaxBertForPreTraining), [])
self.assertEqual(find_labels(FlaxBertForQuestionAnswering), [])
# find_labels works regardless of the class name (it detects the framework through inheritance)
class DummyModel(FlaxBertForSequenceClassification):
pass
self.assertEqual(find_labels(DummyModel), [])