Allow soft dependencies in the namespace with ImportErrors at use (#7537)

* PoC on RAG

* Format class name/obj name

* Better name in message

* PoC on one TF model

* Add PyTorch and TF dummy objects + script

* Treat scikit-learn

* Bad copy pastes

* Typo
Sylvain Gugger 2020-10-05 09:12:04 -04:00 committed by GitHub
parent 1a00f46c74
commit 28d183c90c
10 changed files with 3476 additions and 74 deletions

.circleci/config.yml

@ -248,6 +248,7 @@ jobs:
- run: isort --check-only examples templates tests src utils
- run: flake8 examples templates tests src utils
- run: python utils/check_copies.py
- run: python utils/check_dummies.py
- run: python utils/check_repo.py
check_repository_consistency:
working_directory: ~/transformers

Makefile

@ -23,6 +23,7 @@ modified_only_fixup:
extra_quality_checks:
python utils/check_copies.py
python utils/check_dummies.py
python utils/check_repo.py
# this target runs checks on all files
@ -46,6 +47,7 @@ fixup: modified_only_fixup extra_quality_checks
fix-copies:
python utils/check_copies.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite
# Run tests for the library

src/transformers/__init__.py

@ -73,12 +73,13 @@ from .data import (
SquadFeatures,
SquadV1Processor,
SquadV2Processor,
glue_compute_metrics,
glue_convert_examples_to_features,
glue_output_modes,
glue_processors,
glue_tasks_num_labels,
is_sklearn_available,
squad_convert_examples_to_features,
xnli_compute_metrics,
xnli_output_modes,
xnli_processors,
xnli_tasks_num_labels,
@ -102,6 +103,7 @@ from .file_utils import (
is_faiss_available,
is_psutil_available,
is_py3nvml_available,
is_sklearn_available,
is_tf_available,
is_torch_available,
is_torch_tpu_available,
@ -212,10 +214,6 @@ from .utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_sklearn_available():
from .data import glue_compute_metrics, xnli_compute_metrics
# Modeling
if is_torch_available():
# Benchmarks
@ -531,6 +529,8 @@ if is_torch_available():
# Trainer
from .trainer import EvalPrediction, Trainer, set_seed, torch_distributed_zero_first
else:
from .utils.dummy_pt_objects import *
# TensorFlow
if is_tf_available():
@ -753,6 +753,11 @@ if is_tf_available():
# Trainer
from .trainer_tf import TFTrainer
else:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.
from .utils.dummy_tf_objects import *
if not is_tf_available() and not is_torch_available():
logger.warning(

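The `dummy_pt_objects` / `dummy_tf_objects` modules star-imported above hold placeholder definitions for every public object of the missing backend, so the names still exist under `transformers` and only fail when actually used. A minimal sketch of the behavior, assuming TensorFlow is not installed and using `TFBertModel` purely as an illustrative name:

```python
# Shape of one generated entry in dummy_tf_objects.py (see the templates in
# utils/check_dummies.py later in this diff); the class name is only an example.
from transformers.file_utils import requires_tf


class TFBertModel:
    def __init__(self, *args, **kwargs):
        requires_tf(self)  # raises ImportError with install instructions if TF is missing

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_tf(cls)


# Because the main __init__ does `from .utils.dummy_tf_objects import *`,
# `from transformers import TFBertModel` still succeeds without TensorFlow;
# the ImportError is only raised at use, e.g. on from_pretrained(...).
```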
src/transformers/data/__init__.py

@ -2,7 +2,7 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from .metrics import is_sklearn_available
from .metrics import glue_compute_metrics, xnli_compute_metrics
from .processors import (
DataProcessor,
InputExample,
@ -21,7 +21,3 @@ from .processors import (
xnli_processors,
xnli_tasks_num_labels,
)
if is_sklearn_available():
from .metrics import glue_compute_metrics, xnli_compute_metrics

src/transformers/data/metrics/__init__.py

@ -14,77 +14,75 @@
# See the License for the specific language governing permissions and
# limitations under the License.
try:
from ...file_utils import is_sklearn_available, requires_sklearn
if is_sklearn_available():
from sklearn.metrics import f1_score, matthews_corrcoef
from scipy.stats import pearsonr, spearmanr
_has_sklearn = True
except (AttributeError, ImportError):
_has_sklearn = False
def simple_accuracy(preds, labels):
requires_sklearn(simple_accuracy)
return (preds == labels).mean()
def is_sklearn_available():
return _has_sklearn
def acc_and_f1(preds, labels):
requires_sklearn(acc_and_f1)
acc = simple_accuracy(preds, labels)
f1 = f1_score(y_true=labels, y_pred=preds)
return {
"acc": acc,
"f1": f1,
"acc_and_f1": (acc + f1) / 2,
}
if _has_sklearn:
def pearson_and_spearman(preds, labels):
requires_sklearn(pearson_and_spearman)
pearson_corr = pearsonr(preds, labels)[0]
spearman_corr = spearmanr(preds, labels)[0]
return {
"pearson": pearson_corr,
"spearmanr": spearman_corr,
"corr": (pearson_corr + spearman_corr) / 2,
}
def simple_accuracy(preds, labels):
return (preds == labels).mean()
def acc_and_f1(preds, labels):
acc = simple_accuracy(preds, labels)
f1 = f1_score(y_true=labels, y_pred=preds)
return {
"acc": acc,
"f1": f1,
"acc_and_f1": (acc + f1) / 2,
}
def glue_compute_metrics(task_name, preds, labels):
requires_sklearn(glue_compute_metrics)
assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
if task_name == "cola":
return {"mcc": matthews_corrcoef(labels, preds)}
elif task_name == "sst-2":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "mrpc":
return acc_and_f1(preds, labels)
elif task_name == "sts-b":
return pearson_and_spearman(preds, labels)
elif task_name == "qqp":
return acc_and_f1(preds, labels)
elif task_name == "mnli":
return {"mnli/acc": simple_accuracy(preds, labels)}
elif task_name == "mnli-mm":
return {"mnli-mm/acc": simple_accuracy(preds, labels)}
elif task_name == "qnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "rte":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "wnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "hans":
return {"acc": simple_accuracy(preds, labels)}
else:
raise KeyError(task_name)
def pearson_and_spearman(preds, labels):
pearson_corr = pearsonr(preds, labels)[0]
spearman_corr = spearmanr(preds, labels)[0]
return {
"pearson": pearson_corr,
"spearmanr": spearman_corr,
"corr": (pearson_corr + spearman_corr) / 2,
}
def glue_compute_metrics(task_name, preds, labels):
assert len(preds) == len(
labels
), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
if task_name == "cola":
return {"mcc": matthews_corrcoef(labels, preds)}
elif task_name == "sst-2":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "mrpc":
return acc_and_f1(preds, labels)
elif task_name == "sts-b":
return pearson_and_spearman(preds, labels)
elif task_name == "qqp":
return acc_and_f1(preds, labels)
elif task_name == "mnli":
return {"mnli/acc": simple_accuracy(preds, labels)}
elif task_name == "mnli-mm":
return {"mnli-mm/acc": simple_accuracy(preds, labels)}
elif task_name == "qnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "rte":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "wnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "hans":
return {"acc": simple_accuracy(preds, labels)}
else:
raise KeyError(task_name)
def xnli_compute_metrics(task_name, preds, labels):
assert len(preds) == len(
labels
), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
if task_name == "xnli":
return {"acc": simple_accuracy(preds, labels)}
else:
raise KeyError(task_name)
def xnli_compute_metrics(task_name, preds, labels):
requires_sklearn(xnli_compute_metrics)
assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
if task_name == "xnli":
return {"acc": simple_accuracy(preds, labels)}
else:
raise KeyError(task_name)

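The net effect of this rewrite is that the metrics helpers are defined unconditionally and `requires_sklearn` defers the dependency check to call time. A small illustration of the intended behavior (task name and arrays are arbitrary):

```python
import numpy as np

# Importable even when scikit-learn / scipy are absent.
from transformers.data.metrics import glue_compute_metrics

preds = np.array([1, 0, 1, 1])
labels = np.array([1, 0, 0, 1])

try:
    print(glue_compute_metrics("mrpc", preds, labels))  # e.g. {'acc': 0.75, 'f1': 0.8, ...}
except ImportError as err:
    # Raised by requires_sklearn only if scikit-learn is not installed.
    print(err)
```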
src/transformers/file_utils.py

@ -133,6 +133,15 @@ try:
except ImportError:
_faiss_available = False
try:
import sklearn.metrics # noqa: F401
import scipy.stats # noqa: F401
_has_sklearn = True
except (AttributeError, ImportError):
_has_sklearn = False
default_cache_path = os.path.join(torch_cache_home, "transformers")
@ -194,6 +203,88 @@ def is_faiss_available():
return _faiss_available
def is_sklearn_available():
return _has_sklearn
DATASETS_IMPORT_ERROR = """
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
```
pip install datasets
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install datasets
```
then restarting your kernel.
Note that if you have a local folder named `datasets` or a local Python file named `datasets.py` in your current
working directory, Python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
that Python file if that's the case.
"""
FAISS_IMPORT_ERROR = """
{0} requires the faiss library but it was not found in your environment. Check out the instructions on the
installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
that match your environment.
"""
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""
SKLEARN_IMPORT_ERROR = """
{0} requires the scikit-learn library but it was not found in your environment. You can install it with:
```
pip install -U scikit-learn
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install -U scikit-learn
```
"""
TENSORFLOW_IMPORT_ERROR = """
{0} requires the TensorFlow library but it was not found in your environment. Check out the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
"""
def requires_datasets(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_datasets_available():
raise ImportError(DATASETS_IMPORT_ERROR.format(name))
def requires_faiss(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_faiss_available():
raise ImportError(FAISS_IMPORT_ERROR.format(name))
def requires_pytorch(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_torch_available():
raise ImportError(PYTORCH_IMPORT_ERROR.format(name))
def requires_sklearn(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_sklearn_available():
raise ImportError(SKLEARN_IMPORT_ERROR.format(name))
def requires_tf(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_tf_available():
raise ImportError(TENSORFLOW_IMPORT_ERROR.format(name))
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")

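All five `requires_*` helpers share the naming rule from the "Format class name/obj name" step of this commit: a callable reports its own `__name__`, an instance reports its class name, and that string is substituted into the matching error template. A quick standalone sketch of that rule (the names below are hypothetical):

```python
# Re-statement of the naming logic used by requires_datasets / requires_faiss /
# requires_pytorch / requires_sklearn / requires_tf; not imported from the library.
def _format_name(obj):
    return obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__


def some_metric(preds, labels):  # hypothetical function
    pass


class SomeRetriever:  # hypothetical class
    pass


print(_format_name(some_metric))      # -> "some_metric"  (callables keep their own name)
print(_format_name(SomeRetriever()))  # -> "SomeRetriever" (instances report their class)
print(_format_name(SomeRetriever))    # -> "SomeRetriever" (classes have __name__ too)
```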
src/transformers/retrieval_rag.py

@ -22,15 +22,23 @@ from typing import Iterable, List, Optional, Tuple
import numpy as np
from .configuration_rag import RagConfig
from .file_utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url
from .file_utils import (
cached_path,
is_datasets_available,
is_faiss_available,
is_remote_url,
requires_datasets,
requires_faiss,
)
from .tokenization_rag import RagTokenizer
from .tokenization_utils_base import BatchEncoding
from .utils import logging
if is_datasets_available() and is_faiss_available():
if is_datasets_available():
from datasets import load_dataset
if is_faiss_available():
import faiss
@ -273,6 +281,8 @@ class RagRetriever:
_init_retrieval = True
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
requires_datasets(self)
requires_faiss(self)
super().__init__()
self.index = (
LegacyIndex(
@ -301,6 +311,8 @@ class RagRetriever:
@classmethod
def from_pretrained(cls, retriever_name_or_path, **kwargs):
requires_datasets(cls)
requires_faiss(cls)
config = RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
question_encoder_tokenizer = rag_tokenizer.question_encoder

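With `requires_datasets(self)` and `requires_faiss(self)` in `__init__` (and the classmethod variants in `from_pretrained`), `RagRetriever` can sit in the namespace unconditionally and only complain when it is actually constructed. A hedged sketch of what a user would see in a PyTorch environment where faiss or datasets is missing, using an example checkpoint name:

```python
from transformers import RagRetriever  # the import itself succeeds

try:
    retriever = RagRetriever.from_pretrained("facebook/rag-token-nq")
except ImportError as err:
    # DATASETS_IMPORT_ERROR or FAISS_IMPORT_ERROR formatted with the class name, e.g.
    # "RagRetriever requires the faiss library but it was not found in your environment. ..."
    print(err)
```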
src/transformers/utils/dummy_pt_objects.py: file diff suppressed because it is too large

src/transformers/utils/dummy_tf_objects.py: file diff suppressed because it is too large

utils/check_dummies.py (new file, 199 lines)

@ -0,0 +1,199 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_dummies.py
PATH_TO_TRANSFORMERS = "src/transformers"
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
DUMMY_CONSTANT = """
{0} = None
"""
DUMMY_PT_PRETRAINED_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
"""
DUMMY_PT_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
"""
DUMMY_PT_FUNCTION = """
def {0}(*args, **kwargs):
requires_pytorch({0})
"""
DUMMY_TF_PRETRAINED_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
"""
DUMMY_TF_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_tf(self)
"""
DUMMY_TF_FUNCTION = """
def {0}(*args, **kwargs):
requires_tf({0})
"""
def read_init():
""" Read the init and exctracts PyTorch and TensorFlow objects. """
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8") as f:
lines = f.readlines()
line_index = 0
# Find where the PyTorch imports begin
pt_objects = []
while not lines[line_index].startswith("if is_torch_available():"):
line_index += 1
line_index += 1
# Until we unindent, add PyTorch objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith("    "):
line = lines[line_index]
search = _re_single_line_import.search(line)
if search is not None:
pt_objects += search.groups()[0].split(", ")
elif line.startswith("        "):
pt_objects.append(line[8:-2])
line_index += 1
# Find where the TF imports begin
tf_objects = []
while not lines[line_index].startswith("if is_tf_available():"):
line_index += 1
line_index += 1
# Until we unindent, add TensorFlow objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith("    "):
line = lines[line_index]
search = _re_single_line_import.search(line)
if search is not None:
tf_objects += search.groups()[0].split(", ")
elif line.startswith("        "):
tf_objects.append(line[8:-2])
line_index += 1
return pt_objects, tf_objects
def create_dummy_object(name, is_pytorch=True):
""" Create the code for the dummy object corresponding to `name`."""
_pretrained = [
"Config" "ForCausalLM",
"ForConditionalGeneration",
"ForMaskedLM",
"ForMultipleChoice",
"ForQuestionAnswering",
"ForSequenceClassification",
"ForTokenClassification",
"Model",
"Tokenizer",
]
if name.isupper():
return DUMMY_CONSTANT.format(name)
elif name.islower():
return (DUMMY_PT_FUNCTION if is_pytorch else DUMMY_TF_FUNCTION).format(name)
else:
is_pretrained = False
for part in _pretrained:
if part in name:
is_pretrained = True
break
if is_pretrained:
template = DUMMY_PT_PRETRAINED_CLASS if is_pytorch else DUMMY_TF_PRETRAINED_CLASS
else:
template = DUMMY_PT_CLASS if is_pytorch else DUMMY_TF_CLASS
return template.format(name)
def create_dummy_files():
""" Create the content of the dummy files. """
pt_objects, tf_objects = read_init()
pt_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
pt_dummies += "from ..file_utils import requires_pytorch\n\n"
pt_dummies += "\n".join([create_dummy_object(o) for o in pt_objects])
tf_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
tf_dummies += "from ..file_utils import requires_tf\n\n"
tf_dummies += "\n".join([create_dummy_object(o, False) for o in tf_objects])
return pt_dummies, tf_dummies
def check_dummies(overwrite=False):
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """
pt_dummies, tf_dummies = create_dummy_files()
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
pt_file = os.path.join(path, "dummy_pt_objects.py")
tf_file = os.path.join(path, "dummy_tf_objects.py")
with open(pt_file, "r", encoding="utf-8") as f:
actual_pt_dummies = f.read()
with open(tf_file, "r", encoding="utf-8") as f:
actual_tf_dummies = f.read()
if pt_dummies != actual_pt_dummies:
if overwrite:
print("Updating transformers.utils.dummy_pt_objects.py as the main __init__ has new objects.")
with open(pt_file, "w", encoding="utf-8") as f:
f.write(pt_dummies)
else:
raise ValueError(
"The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
"Run `make fix-copies` to fix this.",
)
if tf_dummies != actual_tf_dummies:
if overwrite:
print("Updating transformers.utils.dummy_tf_objects.py as the main __init__ has new objects.")
with open(tf_file, "w", encoding="utf-8") as f:
f.write(tf_dummies)
else:
raise ValueError(
"The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
"Run `make fix-copies` to fix this.",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
args = parser.parse_args()
check_dummies(args.fix_and_overwrite)
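A hedged usage sketch for the new script: CircleCI runs `python utils/check_dummies.py` to fail on stale dummy files, and `make fix-copies` passes `--fix_and_overwrite` to regenerate them. The same functions can also be driven from Python, which shows what a single generated dummy looks like (output comments are illustrative):

```python
import sys

# Run from the repository root, since PATH_TO_TRANSFORMERS is the relative path
# "src/transformers"; put utils/ on sys.path so check_dummies.py is importable.
sys.path.append("utils")
import check_dummies

# Lowercase names become dummy functions, everything else dummy classes; names
# containing "Model", "Tokenizer", "Config", ... also get a from_pretrained stub.
print(check_dummies.create_dummy_object("TFBertModel", is_pytorch=False))
print(check_dummies.create_dummy_object("glue_convert_examples_to_features"))

# Compare the checked-in dummy files against the main __init__ (raises ValueError
# if they are stale); overwrite=True rewrites them, as --fix_and_overwrite does.
check_dummies.check_dummies(overwrite=False)
```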