mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-14 18:18:24 +06:00
Inheritance-based framework detection (#21784)
This commit is contained in:
parent
7811bf7e73
commit
92dfceb124
@ -15,6 +15,7 @@
|
|||||||
import collections
|
import collections
|
||||||
import csv
|
import csv
|
||||||
import importlib
|
import importlib
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@ -269,7 +270,7 @@ def infer_framework_load_model(
|
|||||||
if isinstance(model, str):
|
if isinstance(model, str):
|
||||||
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
|
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
|
||||||
|
|
||||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
|
||||||
return framework, model
|
return framework, model
|
||||||
|
|
||||||
|
|
||||||
@ -342,7 +343,7 @@ def get_framework(model, revision: Optional[str] = None):
|
|||||||
except OSError:
|
except OSError:
|
||||||
model = TFAutoModel.from_pretrained(model, revision=revision)
|
model = TFAutoModel.from_pretrained(model, revision=revision)
|
||||||
|
|
||||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
|
||||||
return framework
|
return framework
|
||||||
|
|
||||||
|
|
||||||
|
@ -366,13 +366,14 @@ def can_return_loss(model_class):
|
|||||||
Args:
|
Args:
|
||||||
model_class (`type`): The class of the model.
|
model_class (`type`): The class of the model.
|
||||||
"""
|
"""
|
||||||
model_name = model_class.__name__
|
base_classes = str(inspect.getmro(model_class))
|
||||||
if model_name.startswith("TF"):
|
|
||||||
signature = inspect.signature(model_class.call)
|
if "keras.engine.training.Model" in base_classes:
|
||||||
elif model_name.startswith("Flax"):
|
signature = inspect.signature(model_class.call) # TensorFlow models
|
||||||
signature = inspect.signature(model_class.__call__)
|
elif "torch.nn.modules.module.Module" in base_classes:
|
||||||
|
signature = inspect.signature(model_class.forward) # PyTorch models
|
||||||
else:
|
else:
|
||||||
signature = inspect.signature(model_class.forward)
|
signature = inspect.signature(model_class.__call__) # Flax models
|
||||||
|
|
||||||
for p in signature.parameters:
|
for p in signature.parameters:
|
||||||
if p == "return_loss" and signature.parameters[p].default is True:
|
if p == "return_loss" and signature.parameters[p].default is True:
|
||||||
@ -389,12 +390,15 @@ def find_labels(model_class):
|
|||||||
model_class (`type`): The class of the model.
|
model_class (`type`): The class of the model.
|
||||||
"""
|
"""
|
||||||
model_name = model_class.__name__
|
model_name = model_class.__name__
|
||||||
if model_name.startswith("TF"):
|
base_classes = str(inspect.getmro(model_class))
|
||||||
signature = inspect.signature(model_class.call)
|
|
||||||
elif model_name.startswith("Flax"):
|
if "keras.engine.training.Model" in base_classes:
|
||||||
signature = inspect.signature(model_class.__call__)
|
signature = inspect.signature(model_class.call) # TensorFlow models
|
||||||
|
elif "torch.nn.modules.module.Module" in base_classes:
|
||||||
|
signature = inspect.signature(model_class.forward) # PyTorch models
|
||||||
else:
|
else:
|
||||||
signature = inspect.signature(model_class.forward)
|
signature = inspect.signature(model_class.__call__) # Flax models
|
||||||
|
|
||||||
if "QuestionAnswering" in model_name:
|
if "QuestionAnswering" in model_name:
|
||||||
return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
|
return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
|
||||||
else:
|
else:
|
||||||
|
@ -21,10 +21,20 @@ import transformers
|
|||||||
|
|
||||||
# Try to import everything from transformers to ensure every object can be loaded.
|
# Try to import everything from transformers to ensure every object can be loaded.
|
||||||
from transformers import * # noqa F406
|
from transformers import * # noqa F406
|
||||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_flax, require_tf, require_torch
|
||||||
from transformers.utils import ContextManagers, find_labels, is_flax_available, is_tf_available, is_torch_available
|
from transformers.utils import ContextManagers, find_labels, is_flax_available, is_tf_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
from transformers import FlaxBertForPreTraining, FlaxBertForQuestionAnswering, FlaxBertForSequenceClassification
|
||||||
|
|
||||||
|
|
||||||
MODEL_ID = DUMMY_UNKNOWN_IDENTIFIER
|
MODEL_ID = DUMMY_UNKNOWN_IDENTIFIER
|
||||||
# An actual model hosted on huggingface.co
|
# An actual model hosted on huggingface.co
|
||||||
|
|
||||||
@ -85,29 +95,39 @@ class GenericUtilTests(unittest.TestCase):
|
|||||||
# The output should be wrapped with an English and French welcome and goodbye
|
# The output should be wrapped with an English and French welcome and goodbye
|
||||||
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
|
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
|
||||||
|
|
||||||
def test_find_labels(self):
|
@require_torch
|
||||||
if is_torch_available():
|
def test_find_labels_pt(self):
|
||||||
from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification
|
|
||||||
|
|
||||||
self.assertEqual(find_labels(BertForSequenceClassification), ["labels"])
|
self.assertEqual(find_labels(BertForSequenceClassification), ["labels"])
|
||||||
self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"])
|
self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"])
|
||||||
self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"])
|
self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"])
|
||||||
|
|
||||||
if is_tf_available():
|
# find_labels works regardless of the class name (it detects the framework through inheritance)
|
||||||
from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification
|
class DummyModel(BertForSequenceClassification):
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.assertEqual(find_labels(DummyModel), ["labels"])
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_find_labels_tf(self):
|
||||||
self.assertEqual(find_labels(TFBertForSequenceClassification), ["labels"])
|
self.assertEqual(find_labels(TFBertForSequenceClassification), ["labels"])
|
||||||
self.assertEqual(find_labels(TFBertForPreTraining), ["labels", "next_sentence_label"])
|
self.assertEqual(find_labels(TFBertForPreTraining), ["labels", "next_sentence_label"])
|
||||||
self.assertEqual(find_labels(TFBertForQuestionAnswering), ["start_positions", "end_positions"])
|
self.assertEqual(find_labels(TFBertForQuestionAnswering), ["start_positions", "end_positions"])
|
||||||
|
|
||||||
if is_flax_available():
|
# find_labels works regardless of the class name (it detects the framework through inheritance)
|
||||||
# Flax models don't have labels
|
class DummyModel(TFBertForSequenceClassification):
|
||||||
from transformers import (
|
pass
|
||||||
FlaxBertForPreTraining,
|
|
||||||
FlaxBertForQuestionAnswering,
|
|
||||||
FlaxBertForSequenceClassification,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
self.assertEqual(find_labels(DummyModel), ["labels"])
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
def test_find_labels_flax(self):
|
||||||
|
# Flax models don't have labels
|
||||||
self.assertEqual(find_labels(FlaxBertForSequenceClassification), [])
|
self.assertEqual(find_labels(FlaxBertForSequenceClassification), [])
|
||||||
self.assertEqual(find_labels(FlaxBertForPreTraining), [])
|
self.assertEqual(find_labels(FlaxBertForPreTraining), [])
|
||||||
self.assertEqual(find_labels(FlaxBertForQuestionAnswering), [])
|
self.assertEqual(find_labels(FlaxBertForQuestionAnswering), [])
|
||||||
|
|
||||||
|
# find_labels works regardless of the class name (it detects the framework through inheritance)
|
||||||
|
class DummyModel(FlaxBertForSequenceClassification):
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.assertEqual(find_labels(DummyModel), [])
|
||||||
|
Loading…
Reference in New Issue
Block a user