Big model table (#8774)

* First draft

* Styling

* With all changes staged

* Update docs/source/index.rst

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* Styling

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Sylvain Gugger 2020-11-25 12:02:15 -05:00 committed by GitHub
parent 90d5ab3bfe
commit 4821ea5aeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 257 additions and 19 deletions

View File

@ -2,6 +2,15 @@
/* Colab dropdown */
table.center-aligned-table td {
text-align: center;
}
table.center-aligned-table th {
text-align: center;
vertical-align: middle;
}
.colab-dropdown {
position: relative;
display: inline-block;

View File

@ -35,6 +35,8 @@ Choose the right framework for every part of a model's lifetime:
- Move a single model between TF2.0/PyTorch frameworks at will
- Seamlessly pick the right framework for training, evaluation, production
Experimental support for Flax with a few models right now, expected to grow in the coming months.
Contents
-----------------------------------------------------------------------------------------------------------------------
@ -52,8 +54,8 @@ The documentation is organized in five parts:
- **MODELS** for the classes and functions related to each model implemented in the library.
- **INTERNAL HELPERS** for the classes and functions we use internally.
The library currently contains PyTorch and Tensorflow implementations, pre-trained model weights, usage scripts and
conversion utilities for the following models:
The library currently contains PyTorch, Tensorflow and Flax implementations, pretrained model weights, usage scripts
and conversion utilities for the following models:
..
This list is updated automatically from the README with `make fix-copies`. Do not update manually!
@ -166,6 +168,95 @@ conversion utilities for the following models:
34. `Other community models <https://huggingface.co/models>`__, contributed by the `community
<https://huggingface.co/users>`__.
The table below represents the current support in the library for each of those models, whether they have a Python
tokenizer (called "slow"). A "fast" tokenizer backed by the 🤗 Tokenizers library, whether they have support in PyTorch,
TensorFlow and/or Flax.
..
This table is updated automatically from the auto modules with `make fix-copies`. Do not update manually!
.. rst-class:: center-aligned-table
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Model | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax Support |
+=============================+================+================+=================+====================+==============+
| ALBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BART | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BERT | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Bert Generation | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Blenderbot | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DeBERTa | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DistilBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| FlauBERT | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LayoutLM | ✅ | ✅ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Marian | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| MobileBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| OpenAI GPT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Pegasus | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| RAG | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| T5 | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| mBART | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| mT5 | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
.. toctree::
:maxdepth: 2
:caption: Get started

View File

@ -90,6 +90,7 @@ from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .models.auto import (
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
CONFIG_MAPPING,
MODEL_NAMES_MAPPING,
TOKENIZER_MAPPING,
AutoConfig,
AutoTokenizer,
@ -880,6 +881,7 @@ else:
if is_flax_available():
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
from .models.bert import FlaxBertModel
from .models.roberta import FlaxRobertaModel
else:

View File

@ -2,8 +2,8 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_tf_available, is_torch_available
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
from ...file_utils import is_flax_available, is_tf_available, is_torch_available
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
@ -57,3 +57,6 @@ if is_tf_available():
TFAutoModelForTokenClassification,
TFAutoModelWithLMHead,
)
if is_flax_available():
from .modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel

View File

@ -36,7 +36,7 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
for key, value, in pretrained_map.items()
)
MODEL_MAPPING = OrderedDict(
FLAX_MODEL_MAPPING = OrderedDict(
[
(RobertaConfig, FlaxRobertaModel),
(BertConfig, FlaxBertModel),
@ -79,13 +79,13 @@ class FlaxAutoModel(object):
model = FlaxAutoModel.from_config(config)
# E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
for config_class, model_class in MODEL_MAPPING.items():
for config_class, model_class in FLAX_MODEL_MAPPING.items():
if isinstance(config, config_class):
return model_class(config)
raise ValueError(
f"Unrecognized configuration class {config.__class__} "
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}."
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}."
)
@classmethod
@ -173,11 +173,11 @@ class FlaxAutoModel(object):
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
for config_class, model_class in MODEL_MAPPING.items():
for config_class, model_class in FLAX_MODEL_MAPPING.items():
if isinstance(config, config_class):
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} "
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}"
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}"
)

View File

@ -2,6 +2,18 @@
from ..file_utils import requires_flax
FLAX_MODEL_MAPPING = None
class FlaxAutoModel:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxBertModel:
def __init__(self, *args, **kwargs):
requires_flax(self)

View File

@ -15,6 +15,7 @@
import argparse
import glob
import importlib
import os
import re
import tempfile
@ -250,20 +251,21 @@ def convert_to_rst(model_list, max_per_line=None):
return "\n".join(result)
def check_model_list_copy(overwrite=False, max_per_line=119):
""" Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """
_start_prompt = " This list is updated automatically from the README"
_end_prompt = ".. toctree::"
with open(os.path.join(PATH_TO_DOCS, "index.rst"), "r", encoding="utf-8", newline="\n") as f:
def _find_text_in_file(filename, start_prompt, end_prompt):
"""
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
lines.
"""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Find the start of the list.
# Find the start prompt.
start_index = 0
while not lines[start_index].startswith(_start_prompt):
while not lines[start_index].startswith(start_prompt):
start_index += 1
start_index += 1
end_index = start_index
while not lines[end_index].startswith(_end_prompt):
while not lines[end_index].startswith(end_prompt):
end_index += 1
end_index -= 1
@ -272,8 +274,16 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
while len(lines[end_index]) <= 1:
end_index -= 1
end_index += 1
return "".join(lines[start_index:end_index]), start_index, end_index, lines
rst_list = "".join(lines[start_index:end_index])
def check_model_list_copy(overwrite=False, max_per_line=119):
""" Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """
rst_list, start_index, end_index, lines = _find_text_in_file(
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
start_prompt=" This list is updated automatically from the README",
end_prompt="The table below represents the current support",
)
md_list = get_model_list()
converted_list = convert_to_rst(md_list, max_per_line=max_per_line)
@ -283,7 +293,116 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
f.writelines(lines[:start_index] + [converted_list] + lines[end_index:])
else:
raise ValueError(
"The model list in the README changed and the list in `index.rst` has not been updated. Run `make fix-copies` to fix this."
"The model list in the README changed and the list in `index.rst` has not been updated. Run "
"`make fix-copies` to fix this."
)
def _center_text(text, width):
text_length = 2 if text == "" or text == "" else len(text)
left_indent = (width - text_length) // 2
right_indent = width - text_length - left_indent
return " " * left_indent + text + " " * right_indent
def get_model_table_from_auto_modules():
"""Generates an up-to-date model table from the content of the auto modules."""
# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
"transformers",
os.path.join(TRANSFORMERS_PATH, "__init__.py"),
submodule_search_locations=[TRANSFORMERS_PATH],
)
transformers = spec.loader.load_module()
# Dictionary model names to config.
model_name_to_config = {
name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
}
# All tokenizer tuples.
tokenizers = {
name: transformers.TOKENIZER_MAPPING[config]
for name, config in model_name_to_config.items()
if config in transformers.TOKENIZER_MAPPING
}
# Model names that a slow/fast tokenizer.
has_slow_tokenizers = [name for name, tok in tokenizers.items() if tok[0] is not None]
has_fast_tokenizers = [name for name, tok in tokenizers.items() if tok[1] is not None]
# Model names that have a PyTorch implementation.
has_pt_model = [name for name, config in model_name_to_config.items() if config in transformers.MODEL_MAPPING]
# Some of the GenerationModel don't have a base model.
has_pt_model.extend(
[
name
for name, config in model_name_to_config.items()
if config in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
]
)
# Special exception for RAG
has_pt_model.append("RAG")
# Model names that have a TensorFlow implementation.
has_tf_model = [name for name, config in model_name_to_config.items() if config in transformers.TF_MODEL_MAPPING]
# Some of the GenerationModel don't have a base model.
has_tf_model.extend(
[
name
for name, config in model_name_to_config.items()
if config in transformers.TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
]
)
# Model names that have a Flax implementation.
has_flax_model = [
name for name, config in model_name_to_config.items() if config in transformers.FLAX_MODEL_MAPPING
]
# Let's build that table!
model_names = list(model_name_to_config.keys())
model_names.sort()
columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"]
# We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side).
widths = [len(c) + 2 for c in columns]
widths[0] = max([len(name) for name in model_names]) + 2
# Rst table per se
table = ".. rst-class:: center-aligned-table\n\n"
table += "+" + "+".join(["-" * w for w in widths]) + "+\n"
table += "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n"
table += "+" + "+".join(["=" * w for w in widths]) + "+\n"
check = {True: "", False: ""}
for name in model_names:
line = [
name,
check[name in has_slow_tokenizers],
check[name in has_fast_tokenizers],
check[name in has_pt_model],
check[name in has_tf_model],
check[name in has_flax_model],
]
table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n"
table += "+" + "+".join(["-" * w for w in widths]) + "+\n"
return table
def check_model_table(overwrite=False):
""" Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`. """
current_table, start_index, end_index, lines = _find_text_in_file(
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
start_prompt=" This table is updated automatically from the auto module",
end_prompt=".. toctree::",
)
new_table = get_model_table_from_auto_modules()
if current_table != new_table:
if overwrite:
with open(os.path.join(PATH_TO_DOCS, "index.rst"), "w", encoding="utf-8", newline="\n") as f:
f.writelines(lines[:start_index] + [new_table] + lines[end_index:])
else:
raise ValueError(
"The model table in the `index.rst` has not been updated. Run `make fix-copies` to fix this."
)
@ -293,3 +412,4 @@ if __name__ == "__main__":
args = parser.parse_args()
check_copies(args.fix_and_overwrite)
check_model_table(args.fix_and_overwrite)

View File

@ -126,6 +126,7 @@ def get_model_modules():
"modeling_outputs",
"modeling_retribert",
"modeling_utils",
"modeling_flax_auto",
"modeling_flax_utils",
"modeling_transfo_xl_utilities",
"modeling_tf_auto",