From 6d4f8bd02a163ac711bdbec22045f8591ad8aa22 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 20 Oct 2020 07:45:48 -0400 Subject: [PATCH] Add Flax dummy objects (#7918) --- src/transformers/__init__.py | 5 ++ src/transformers/file_utils.py | 12 ++++ src/transformers/utils/dummy_flax_objects.py | 20 ++++++ utils/check_dummies.py | 70 ++++++++++++++++++-- 4 files changed, 102 insertions(+), 5 deletions(-) create mode 100644 src/transformers/utils/dummy_flax_objects.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 599ec72773a..11e6f5cdc0a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -841,6 +841,11 @@ else: if is_flax_available(): from .modeling_flax_bert import FlaxBertModel from .modeling_flax_roberta import FlaxRobertaModel +else: + # Import the same objects as dummies to get them in the namespace. + # They will raise an import error if the user tries to instantiate / use them. + from .utils.dummy_flax_objects import * + if not is_tf_available() and not is_torch_available(): logger.warning( diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index b4ca47077f9..642b6506b1d 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -356,6 +356,12 @@ installation page: https://www.tensorflow.org/install and follow the ones that m """ +FLAX_IMPORT_ERROR = """ +{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the +installation page: https://github.com/google/flax and follow the ones that match your environment. 
+""" + + def requires_datasets(obj): name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ if not is_datasets_available(): @@ -386,6 +392,12 @@ def requires_tf(obj): raise ImportError(TENSORFLOW_IMPORT_ERROR.format(name)) +def requires_flax(obj): + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + if not is_flax_available(): + raise ImportError(FLAX_IMPORT_ERROR.format(name)) + + def requires_tokenizers(obj): name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ if not is_tokenizers_available(): diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py new file mode 100644 index 00000000000..77e932652de --- /dev/null +++ b/src/transformers/utils/dummy_flax_objects.py @@ -0,0 +1,20 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..file_utils import requires_flax + + +class FlaxBertModel: + def __init__(self, *args, **kwargs): + requires_flax(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_flax(self) + + +class FlaxRobertaModel: + def __init__(self, *args, **kwargs): + requires_flax(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_flax(self) diff --git a/utils/check_dummies.py b/utils/check_dummies.py index ad1de4fa6ae..81adb416022 100644 --- a/utils/check_dummies.py +++ b/utils/check_dummies.py @@ -72,6 +72,28 @@ def {0}(*args, **kwargs): """ +DUMMY_FLAX_PRETRAINED_CLASS = """ +class {0}: + def __init__(self, *args, **kwargs): + requires_flax(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_flax(self) +""" + +DUMMY_FLAX_CLASS = """ +class {0}: + def __init__(self, *args, **kwargs): + requires_flax(self) +""" + +DUMMY_FLAX_FUNCTION = """ +def {0}(*args, **kwargs): + requires_flax({0}) +""" + + DUMMY_SENTENCEPIECE_PRETRAINED_CLASS = """ class {0}: def __init__(self, *args, **kwargs): @@ -120,6 +142,7 @@ def {0}(*args, 
**kwargs): DUMMY_PRETRAINED_CLASS = { "pt": DUMMY_PT_PRETRAINED_CLASS, "tf": DUMMY_TF_PRETRAINED_CLASS, + "flax": DUMMY_FLAX_PRETRAINED_CLASS, "sentencepiece": DUMMY_SENTENCEPIECE_PRETRAINED_CLASS, "tokenizers": DUMMY_TOKENIZERS_PRETRAINED_CLASS, } @@ -127,6 +150,7 @@ DUMMY_PRETRAINED_CLASS = { DUMMY_CLASS = { "pt": DUMMY_PT_CLASS, "tf": DUMMY_TF_CLASS, + "flax": DUMMY_FLAX_CLASS, "sentencepiece": DUMMY_SENTENCEPIECE_CLASS, "tokenizers": DUMMY_TOKENIZERS_CLASS, } @@ -134,6 +158,7 @@ DUMMY_CLASS = { DUMMY_FUNCTION = { "pt": DUMMY_PT_FUNCTION, "tf": DUMMY_TF_FUNCTION, + "flax": DUMMY_FLAX_FUNCTION, "sentencepiece": DUMMY_SENTENCEPIECE_FUNCTION, "tokenizers": DUMMY_TOKENIZERS_FUNCTION, } @@ -208,7 +233,24 @@ def read_init(): elif line.startswith(" "): tf_objects.append(line[8:-2]) line_index += 1 - return sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects + + # Find where the FLAX imports begin + flax_objects = [] + while not lines[line_index].startswith("if is_flax_available():"): + line_index += 1 + line_index += 1 + + # Until we unindent, add Flax objects to the list + while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "): + line = lines[line_index] + search = _re_single_line_import.search(line) + if search is not None: + flax_objects += search.groups()[0].split(", ") + elif line.startswith(" "): + flax_objects.append(line[8:-2]) + line_index += 1 + + return sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects def create_dummy_object(name, type="pt"): @@ -224,7 +266,7 @@ def create_dummy_object(name, type="pt"): "Model", "Tokenizer", ] - assert type in ["pt", "tf", "sentencepiece", "tokenizers"] + assert type in ["pt", "tf", "sentencepiece", "tokenizers", "flax"] if name.isupper(): return DUMMY_CONSTANT.format(name) elif name.islower(): @@ -244,7 +286,7 @@ def create_dummy_object(name, type="pt"): def create_dummy_files(): """ Create the content of the dummy files. 
""" - sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects = read_init() + sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects = read_init() sentencepiece_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" sentencepiece_dummies += "from ..file_utils import requires_sentencepiece\n\n" @@ -262,17 +304,22 @@ def create_dummy_files(): tf_dummies += "from ..file_utils import requires_tf\n\n" tf_dummies += "\n".join([create_dummy_object(o, type="tf") for o in tf_objects]) - return sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies + flax_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" + flax_dummies += "from ..file_utils import requires_flax\n\n" + flax_dummies += "\n".join([create_dummy_object(o, type="flax") for o in flax_objects]) + + return sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies def check_dummies(overwrite=False): """ Check if the dummy files are up to date and maybe `overwrite` with the right content. 
""" - sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies = create_dummy_files() + sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies = create_dummy_files() path = os.path.join(PATH_TO_TRANSFORMERS, "utils") sentencepiece_file = os.path.join(path, "dummy_sentencepiece_objects.py") tokenizers_file = os.path.join(path, "dummy_tokenizers_objects.py") pt_file = os.path.join(path, "dummy_pt_objects.py") tf_file = os.path.join(path, "dummy_tf_objects.py") + flax_file = os.path.join(path, "dummy_flax_objects.py") with open(sentencepiece_file, "r", encoding="utf-8") as f: actual_sentencepiece_dummies = f.read() @@ -282,6 +329,8 @@ def check_dummies(overwrite=False): actual_pt_dummies = f.read() with open(tf_file, "r", encoding="utf-8") as f: actual_tf_dummies = f.read() + with open(flax_file, "r", encoding="utf-8") as f: + actual_flax_dummies = f.read() if sentencepiece_dummies != actual_sentencepiece_dummies: if overwrite: @@ -327,6 +376,17 @@ def check_dummies(overwrite=False): "Run `make fix-copies` to fix this.", ) + if flax_dummies != actual_flax_dummies: + if overwrite: + print("Updating transformers.utils.dummy_flax_objects.py as the main __init__ has new objects.") + with open(flax_file, "w", encoding="utf-8") as f: + f.write(flax_dummies) + else: + raise ValueError( + "The main __init__ has objects that are not present in transformers.utils.dummy_flax_objects.py.", + "Run `make fix-copies` to fix this.", + ) + if __name__ == "__main__": parser = argparse.ArgumentParser()