Mirror of https://github.com/huggingface/transformers.git
Allow soft dependencies in the namespace with ImportErrors at use (#7537)
* PoC on RAG
* Format class name/obj name
* Better name in message
* PoC on one TF model
* Add PyTorch and TF dummy objects + script
* Treat scikit-learn
* Bad copy pastes
* Typo
This commit is contained in:
parent
1a00f46c74
commit
28d183c90c
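The idea, in short: optional backends (PyTorch, TensorFlow, scikit-learn, faiss, datasets) no longer hide names from the namespace when they are missing; instead, generated dummy objects stand in and raise an informative ImportError the moment they are used. A minimal sketch of the mechanism, assuming PyTorch is absent (the `BertModel` dummy here is illustrative; real dummies are generated into `src/transformers/utils/dummy_pt_objects.py` by the new script at the bottom of this diff):

# Minimal sketch of the soft-dependency pattern this commit introduces.
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment.
"""


def is_torch_available():
    return False  # pretend PyTorch is missing, for illustration


def requires_pytorch(obj):
    # Resolve a readable name for the object, then fail loudly at *use* time.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    if not is_torch_available():
        raise ImportError(PYTORCH_IMPORT_ERROR.format(name))


# A generated dummy, shaped like the templates in utils/check_dummies.py:
class BertModel:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


print(BertModel.__name__)  # referencing the name is always fine
try:
    BertModel()  # instantiating raises the informative ImportError
except ImportError as e:
    print(e)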
.circleci/config.yml
@@ -248,6 +248,7 @@ jobs:
       - run: isort --check-only examples templates tests src utils
       - run: flake8 examples templates tests src utils
       - run: python utils/check_copies.py
+      - run: python utils/check_dummies.py
       - run: python utils/check_repo.py
   check_repository_consistency:
     working_directory: ~/transformers
Makefile (+2)
@@ -23,6 +23,7 @@ modified_only_fixup:

 extra_quality_checks:
 	python utils/check_copies.py
+	python utils/check_dummies.py
 	python utils/check_repo.py

 # this target runs checks on all files
@@ -46,6 +47,7 @@ fixup: modified_only_fixup extra_quality_checks

 fix-copies:
 	python utils/check_copies.py --fix_and_overwrite
+	python utils/check_dummies.py --fix_and_overwrite

 # Run tests for the library
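Net effect of the build changes above: CI and the `extra_quality_checks` target now run `utils/check_dummies.py` in check mode, which fails if the generated dummy files have drifted from the main `__init__.py`, while `fix-copies` passes `--fix_and_overwrite` to regenerate them in place (the script itself is added at the bottom of this diff).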
src/transformers/__init__.py
@@ -73,12 +73,13 @@ from .data import (
     SquadFeatures,
     SquadV1Processor,
     SquadV2Processor,
+    glue_compute_metrics,
     glue_convert_examples_to_features,
     glue_output_modes,
     glue_processors,
     glue_tasks_num_labels,
-    is_sklearn_available,
     squad_convert_examples_to_features,
+    xnli_compute_metrics,
     xnli_output_modes,
     xnli_processors,
     xnli_tasks_num_labels,
@@ -102,6 +103,7 @@ from .file_utils import (
     is_faiss_available,
     is_psutil_available,
     is_py3nvml_available,
+    is_sklearn_available,
     is_tf_available,
     is_torch_available,
     is_torch_tpu_available,
@@ -212,10 +214,6 @@ from .utils import logging
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-if is_sklearn_available():
-    from .data import glue_compute_metrics, xnli_compute_metrics
-
-
 # Modeling
 if is_torch_available():
     # Benchmarks
@@ -531,6 +529,8 @@ if is_torch_available():

     # Trainer
     from .trainer import EvalPrediction, Trainer, set_seed, torch_distributed_zero_first
+else:
+    from .utils.dummy_pt_objects import *

 # TensorFlow
 if is_tf_available():
@@ -753,6 +753,11 @@ if is_tf_available():
     # Trainer
     from .trainer_tf import TFTrainer

+else:
+    # Import the same objects as dummies to get them in the namespace.
+    # They will raise an import error if the user tries to instantiate / use them.
+    from .utils.dummy_tf_objects import *
+

 if not is_tf_available() and not is_torch_available():
     logger.warning(
src/transformers/data/__init__.py
@@ -2,7 +2,7 @@
 # There's no way to ignore "F401 '...' imported but unused" warnings in this
 # module, but to preserve other warnings. So, don't check this module at all.

-from .metrics import is_sklearn_available
+from .metrics import glue_compute_metrics, xnli_compute_metrics
 from .processors import (
     DataProcessor,
     InputExample,
@@ -21,7 +21,3 @@ from .processors import (
     xnli_processors,
     xnli_tasks_num_labels,
 )
-
-
-if is_sklearn_available():
-    from .metrics import glue_compute_metrics, xnli_compute_metrics
src/transformers/data/metrics/__init__.py
@@ -14,26 +14,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-try:
-    from sklearn.metrics import f1_score, matthews_corrcoef
-
-    from scipy.stats import pearsonr, spearmanr
-
-    _has_sklearn = True
-except (AttributeError, ImportError):
-    _has_sklearn = False
-
-
-def is_sklearn_available():
-    return _has_sklearn
-
-
-if _has_sklearn:
-
-    def simple_accuracy(preds, labels):
-        return (preds == labels).mean()
-
-    def acc_and_f1(preds, labels):
-        acc = simple_accuracy(preds, labels)
-        f1 = f1_score(y_true=labels, y_pred=preds)
-        return {
+from ...file_utils import is_sklearn_available, requires_sklearn
+
+
+if is_sklearn_available():
+    from sklearn.metrics import f1_score, matthews_corrcoef
+
+    from scipy.stats import pearsonr, spearmanr
+
+
+def simple_accuracy(preds, labels):
+    requires_sklearn(simple_accuracy)
+    return (preds == labels).mean()
+
+
+def acc_and_f1(preds, labels):
+    requires_sklearn(acc_and_f1)
+    acc = simple_accuracy(preds, labels)
+    f1 = f1_score(y_true=labels, y_pred=preds)
+    return {
@@ -42,7 +38,9 @@ if _has_sklearn:
         "acc_and_f1": (acc + f1) / 2,
     }

-    def pearson_and_spearman(preds, labels):
+
+def pearson_and_spearman(preds, labels):
+    requires_sklearn(pearson_and_spearman)
     pearson_corr = pearsonr(preds, labels)[0]
     spearman_corr = spearmanr(preds, labels)[0]
     return {
@@ -51,10 +49,10 @@ if _has_sklearn:
         "corr": (pearson_corr + spearman_corr) / 2,
     }

-    def glue_compute_metrics(task_name, preds, labels):
-        assert len(preds) == len(
-            labels
-        ), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
+
+def glue_compute_metrics(task_name, preds, labels):
+    requires_sklearn(glue_compute_metrics)
+    assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
     if task_name == "cola":
         return {"mcc": matthews_corrcoef(labels, preds)}
     elif task_name == "sst-2":
@@ -80,10 +78,10 @@ if _has_sklearn:
     else:
         raise KeyError(task_name)

-    def xnli_compute_metrics(task_name, preds, labels):
-        assert len(preds) == len(
-            labels
-        ), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
+
+def xnli_compute_metrics(task_name, preds, labels):
+    requires_sklearn(xnli_compute_metrics)
+    assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
     if task_name == "xnli":
         return {"acc": simple_accuracy(preds, labels)}
     else:
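The user-visible change, sketched under the assumption that scikit-learn is not installed: the metrics module now always imports cleanly, and the error moves from import time to call time (the `requires_sklearn` helper is added in `file_utils.py` below).

import numpy as np

preds = np.array([1, 0, 1])
labels = np.array([1, 1, 1])

# Always safe now: no sklearn/scipy import happens at module load.
from transformers.data.metrics import glue_compute_metrics

# With scikit-learn installed this returns {"acc": 0.666...}; without it,
# the call (not the import) raises:
#   ImportError: glue_compute_metrics requires the scikit-learn library ...
print(glue_compute_metrics("sst-2", preds, labels))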
src/transformers/file_utils.py
@@ -133,6 +133,15 @@ try:
 except ImportError:
     _faiss_available = False

+try:
+    import sklearn.metrics  # noqa: F401
+
+    import scipy.stats  # noqa: F401
+
+    _has_sklearn = True
+except (AttributeError, ImportError):
+    _has_sklearn = False
+

 default_cache_path = os.path.join(torch_cache_home, "transformers")

@@ -194,6 +203,88 @@ def is_faiss_available():
     return _faiss_available


+def is_sklearn_available():
+    return _has_sklearn
+
+
+DATASETS_IMPORT_ERROR = """
+{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
+```
+pip install datasets
+```
+In a notebook or a colab, you can install it by executing a cell with
+```
+!pip install datasets
+```
+then restarting your kernel.
+
+Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current
+working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
+that python file if that's the case.
+"""
+
+
+FAISS_IMPORT_ERROR = """
+{0} requires the faiss library but it was not found in your environment. Check out the instructions on the
+installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
+that match your environment.
+"""
+
+
+PYTORCH_IMPORT_ERROR = """
+{0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the
+installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
+"""
+
+
+SKLEARN_IMPORT_ERROR = """
+{0} requires the scikit-learn library but it was not found in your environment. You can install it with:
+```
+pip install -U scikit-learn
+```
+In a notebook or a colab, you can install it by executing a cell with
+```
+!pip install -U scikit-learn
+```
+"""
+
+
+TENSORFLOW_IMPORT_ERROR = """
+{0} requires the TensorFlow library but it was not found in your environment. Check out the instructions on the
+installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
+"""
+
+
+def requires_datasets(obj):
+    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+    if not is_datasets_available():
+        raise ImportError(DATASETS_IMPORT_ERROR.format(name))
+
+
+def requires_faiss(obj):
+    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+    if not is_faiss_available():
+        raise ImportError(FAISS_IMPORT_ERROR.format(name))
+
+
+def requires_pytorch(obj):
+    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+    if not is_torch_available():
+        raise ImportError(PYTORCH_IMPORT_ERROR.format(name))
+
+
+def requires_sklearn(obj):
+    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+    if not is_sklearn_available():
+        raise ImportError(SKLEARN_IMPORT_ERROR.format(name))
+
+
+def requires_tf(obj):
+    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+    if not is_tf_available():
+        raise ImportError(TENSORFLOW_IMPORT_ERROR.format(name))
+
+
 def add_start_docstrings(*docstr):
     def docstring_decorator(fn):
         fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
src/transformers/retrieval_rag.py
@@ -22,15 +22,23 @@ from typing import Iterable, List, Optional, Tuple
 import numpy as np

 from .configuration_rag import RagConfig
-from .file_utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url
+from .file_utils import (
+    cached_path,
+    is_datasets_available,
+    is_faiss_available,
+    is_remote_url,
+    requires_datasets,
+    requires_faiss,
+)
 from .tokenization_rag import RagTokenizer
 from .tokenization_utils_base import BatchEncoding
 from .utils import logging


-if is_datasets_available() and is_faiss_available():
+if is_datasets_available():
     from datasets import load_dataset

+if is_faiss_available():
     import faiss


@@ -273,6 +281,8 @@ class RagRetriever:
     _init_retrieval = True

     def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
+        requires_datasets(self)
+        requires_faiss(self)
         super().__init__()
         self.index = (
             LegacyIndex(
@@ -301,6 +311,8 @@ class RagRetriever:

     @classmethod
     def from_pretrained(cls, retriever_name_or_path, **kwargs):
+        requires_datasets(cls)
+        requires_faiss(cls)
         config = RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
         rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
         question_encoder_tokenizer = rag_tokenizer.question_encoder
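Net effect for RAG: `RagRetriever` is now always importable; constructing one (directly or via `from_pretrained`) without `datasets` or `faiss` installed fails immediately with the corresponding install instructions, instead of a late `NameError` when the previously gated imports were skipped.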
|
1807
src/transformers/utils/dummy_pt_objects.py
Normal file
1807
src/transformers/utils/dummy_pt_objects.py
Normal file
File diff suppressed because it is too large
Load Diff
1291
src/transformers/utils/dummy_tf_objects.py
Normal file
1291
src/transformers/utils/dummy_tf_objects.py
Normal file
File diff suppressed because it is too large
Load Diff
utils/check_dummies.py (new file, 199 lines)
@@ -0,0 +1,199 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re


# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_dummies.py
PATH_TO_TRANSFORMERS = "src/transformers"

_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")

DUMMY_CONSTANT = """
{0} = None
"""

DUMMY_PT_PRETRAINED_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)
"""

DUMMY_PT_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)
"""

DUMMY_PT_FUNCTION = """
def {0}(*args, **kwargs):
    requires_pytorch({0})
"""

DUMMY_TF_PRETRAINED_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
        requires_tf(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_tf(self)
"""

DUMMY_TF_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
        requires_tf(self)
"""

DUMMY_TF_FUNCTION = """
def {0}(*args, **kwargs):
    requires_tf({0})
"""


def read_init():
    """ Read the init and extracts PyTorch and TensorFlow objects. """
    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8") as f:
        lines = f.readlines()

    line_index = 0
    # Find where the PyTorch imports begin
    pt_objects = []
    while not lines[line_index].startswith("if is_torch_available():"):
        line_index += 1
    line_index += 1

    # Until we unindent, add PyTorch objects to the list
    while len(lines[line_index]) <= 1 or lines[line_index].startswith("    "):
        line = lines[line_index]
        search = _re_single_line_import.search(line)
        if search is not None:
            pt_objects += search.groups()[0].split(", ")
        elif line.startswith("        "):
            pt_objects.append(line[8:-2])
        line_index += 1

    # Find where the TF imports begin
    tf_objects = []
    while not lines[line_index].startswith("if is_tf_available():"):
        line_index += 1
    line_index += 1

    # Until we unindent, add TensorFlow objects to the list
    while len(lines[line_index]) <= 1 or lines[line_index].startswith("    "):
        line = lines[line_index]
        search = _re_single_line_import.search(line)
        if search is not None:
            tf_objects += search.groups()[0].split(", ")
        elif line.startswith("        "):
            tf_objects.append(line[8:-2])
        line_index += 1
    return pt_objects, tf_objects


def create_dummy_object(name, is_pytorch=True):
    """ Create the code for the dummy object corresponding to `name`."""
    _pretrained = [
        "Config",
        "ForCausalLM",
        "ForConditionalGeneration",
        "ForMaskedLM",
        "ForMultipleChoice",
        "ForQuestionAnswering",
        "ForSequenceClassification",
        "ForTokenClassification",
        "Model",
        "Tokenizer",
    ]
    if name.isupper():
        return DUMMY_CONSTANT.format(name)
    elif name.islower():
        return (DUMMY_PT_FUNCTION if is_pytorch else DUMMY_TF_FUNCTION).format(name)
    else:
        is_pretrained = False
        for part in _pretrained:
            if part in name:
                is_pretrained = True
                break
        if is_pretrained:
            template = DUMMY_PT_PRETRAINED_CLASS if is_pytorch else DUMMY_TF_PRETRAINED_CLASS
        else:
            template = DUMMY_PT_CLASS if is_pytorch else DUMMY_TF_CLASS
        return template.format(name)


def create_dummy_files():
    """ Create the content of the dummy files. """
    pt_objects, tf_objects = read_init()

    pt_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
    pt_dummies += "from ..file_utils import requires_pytorch\n\n"
    pt_dummies += "\n".join([create_dummy_object(o) for o in pt_objects])

    tf_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
    tf_dummies += "from ..file_utils import requires_tf\n\n"
    tf_dummies += "\n".join([create_dummy_object(o, False) for o in tf_objects])

    return pt_dummies, tf_dummies


def check_dummies(overwrite=False):
    """ Check if the dummy files are up to date and maybe `overwrite` with the right content. """
    pt_dummies, tf_dummies = create_dummy_files()
    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
    pt_file = os.path.join(path, "dummy_pt_objects.py")
    tf_file = os.path.join(path, "dummy_tf_objects.py")

    with open(pt_file, "r", encoding="utf-8") as f:
        actual_pt_dummies = f.read()
    with open(tf_file, "r", encoding="utf-8") as f:
        actual_tf_dummies = f.read()

    if pt_dummies != actual_pt_dummies:
        if overwrite:
            print("Updating transformers.utils.dummy_pt_objects.py as the main __init__ has new objects.")
            with open(pt_file, "w", encoding="utf-8") as f:
                f.write(pt_dummies)
        else:
            raise ValueError(
                "The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
                "Run `make fix-copies` to fix this.",
            )

    if tf_dummies != actual_tf_dummies:
        if overwrite:
            print("Updating transformers.utils.dummy_tf_objects.py as the main __init__ has new objects.")
            with open(tf_file, "w", encoding="utf-8") as f:
                f.write(tf_dummies)
        else:
            raise ValueError(
                "The main __init__ has objects that are not present in transformers.utils.dummy_tf_objects.py.",
                "Run `make fix-copies` to fix this.",
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    check_dummies(args.fix_and_overwrite)
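As a quick illustration of what the generator emits (a hypothetical REPL session; it assumes `utils/` is on `sys.path` so the module is importable):

from check_dummies import create_dummy_object

# Uppercase names become None constants:
print(create_dummy_object("MODEL_MAPPING"))
#   MODEL_MAPPING = None

# Lowercase names become guarded functions:
print(create_dummy_object("load_tf_weights_in_bert"))
#   def load_tf_weights_in_bert(*args, **kwargs):
#       requires_pytorch(load_tf_weights_in_bert)

# Model-like class names get the pretrained template with a from_pretrained guard:
print(create_dummy_object("BertForMaskedLM"))
#   class BertForMaskedLM:
#       def __init__(self, *args, **kwargs):
#           requires_pytorch(self)
#       ...

Run `python utils/check_dummies.py` from the repo root to check the generated files, or with `--fix_and_overwrite` (what `make fix-copies` does) to regenerate them.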