mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Big model table (#8774)
* First draft * Styling * With all changes staged * Update docs/source/index.rst Co-authored-by: Julien Chaumond <chaumond@gmail.com> * Styling Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
parent
90d5ab3bfe
commit
4821ea5aeb
@ -2,6 +2,15 @@
|
||||
|
||||
/* Colab dropdown */
|
||||
|
||||
table.center-aligned-table td {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
table.center-aligned-table th {
|
||||
text-align: center;
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.colab-dropdown {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
|
@ -35,6 +35,8 @@ Choose the right framework for every part of a model's lifetime:
|
||||
- Move a single model between TF2.0/PyTorch frameworks at will
|
||||
- Seamlessly pick the right framework for training, evaluation, production
|
||||
|
||||
Experimental support for Flax with a few models right now, expected to grow in the coming months.
|
||||
|
||||
Contents
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
@ -52,8 +54,8 @@ The documentation is organized in five parts:
|
||||
- **MODELS** for the classes and functions related to each model implemented in the library.
|
||||
- **INTERNAL HELPERS** for the classes and functions we use internally.
|
||||
|
||||
The library currently contains PyTorch and Tensorflow implementations, pre-trained model weights, usage scripts and
|
||||
conversion utilities for the following models:
|
||||
The library currently contains PyTorch, Tensorflow and Flax implementations, pretrained model weights, usage scripts
|
||||
and conversion utilities for the following models:
|
||||
|
||||
..
|
||||
This list is updated automatically from the README with `make fix-copies`. Do not update manually!
|
||||
@ -166,6 +168,95 @@ conversion utilities for the following models:
|
||||
34. `Other community models <https://huggingface.co/models>`__, contributed by the `community
|
||||
<https://huggingface.co/users>`__.
|
||||
|
||||
|
||||
The table below represents the current support in the library for each of those models, whether they have a Python
|
||||
tokenizer (called "slow"). A "fast" tokenizer backed by the 🤗 Tokenizers library, whether they have support in PyTorch,
|
||||
TensorFlow and/or Flax.
|
||||
|
||||
..
|
||||
This table is updated automatically from the auto modules with `make fix-copies`. Do not update manually!
|
||||
|
||||
.. rst-class:: center-aligned-table
|
||||
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Model | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax Support |
|
||||
+=============================+================+================+=================+====================+==============+
|
||||
| ALBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| BART | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| BERT | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Bert Generation | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Blenderbot | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| DeBERTa | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| DistilBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| FlauBERT | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| LayoutLM | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Marian | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| MobileBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| OpenAI GPT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Pegasus | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| RAG | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| T5 | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| mBART | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| mT5 | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Get started
|
||||
|
@ -90,6 +90,7 @@ from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||
from .models.auto import (
|
||||
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CONFIG_MAPPING,
|
||||
MODEL_NAMES_MAPPING,
|
||||
TOKENIZER_MAPPING,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
@ -880,6 +881,7 @@ else:
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
||||
from .models.bert import FlaxBertModel
|
||||
from .models.roberta import FlaxRobertaModel
|
||||
else:
|
||||
|
@ -2,8 +2,8 @@
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
from ...file_utils import is_tf_available, is_torch_available
|
||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
|
||||
from ...file_utils import is_flax_available, is_tf_available, is_torch_available
|
||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
|
||||
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
|
||||
|
||||
@ -57,3 +57,6 @@ if is_tf_available():
|
||||
TFAutoModelForTokenClassification,
|
||||
TFAutoModelWithLMHead,
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
||||
|
@ -36,7 +36,7 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
|
||||
for key, value, in pretrained_map.items()
|
||||
)
|
||||
|
||||
MODEL_MAPPING = OrderedDict(
|
||||
FLAX_MODEL_MAPPING = OrderedDict(
|
||||
[
|
||||
(RobertaConfig, FlaxRobertaModel),
|
||||
(BertConfig, FlaxBertModel),
|
||||
@ -79,13 +79,13 @@ class FlaxAutoModel(object):
|
||||
model = FlaxAutoModel.from_config(config)
|
||||
# E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
"""
|
||||
for config_class, model_class in MODEL_MAPPING.items():
|
||||
for config_class, model_class in FLAX_MODEL_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
return model_class(config)
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class {config.__class__} "
|
||||
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}."
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -173,11 +173,11 @@ class FlaxAutoModel(object):
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
for config_class, model_class in MODEL_MAPPING.items():
|
||||
for config_class, model_class in FLAX_MODEL_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class {config.__class__} "
|
||||
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}"
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}"
|
||||
)
|
||||
|
@ -2,6 +2,18 @@
|
||||
from ..file_utils import requires_flax
|
||||
|
||||
|
||||
FLAX_MODEL_MAPPING = None
|
||||
|
||||
|
||||
class FlaxAutoModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxBertModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
@ -250,20 +251,21 @@ def convert_to_rst(model_list, max_per_line=None):
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
""" Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """
|
||||
_start_prompt = " This list is updated automatically from the README"
|
||||
_end_prompt = ".. toctree::"
|
||||
with open(os.path.join(PATH_TO_DOCS, "index.rst"), "r", encoding="utf-8", newline="\n") as f:
|
||||
def _find_text_in_file(filename, start_prompt, end_prompt):
|
||||
"""
|
||||
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
|
||||
lines.
|
||||
"""
|
||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
# Find the start of the list.
|
||||
# Find the start prompt.
|
||||
start_index = 0
|
||||
while not lines[start_index].startswith(_start_prompt):
|
||||
while not lines[start_index].startswith(start_prompt):
|
||||
start_index += 1
|
||||
start_index += 1
|
||||
|
||||
end_index = start_index
|
||||
while not lines[end_index].startswith(_end_prompt):
|
||||
while not lines[end_index].startswith(end_prompt):
|
||||
end_index += 1
|
||||
end_index -= 1
|
||||
|
||||
@ -272,8 +274,16 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
while len(lines[end_index]) <= 1:
|
||||
end_index -= 1
|
||||
end_index += 1
|
||||
return "".join(lines[start_index:end_index]), start_index, end_index, lines
|
||||
|
||||
rst_list = "".join(lines[start_index:end_index])
|
||||
|
||||
def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
""" Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """
|
||||
rst_list, start_index, end_index, lines = _find_text_in_file(
|
||||
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
|
||||
start_prompt=" This list is updated automatically from the README",
|
||||
end_prompt="The table below represents the current support",
|
||||
)
|
||||
md_list = get_model_list()
|
||||
converted_list = convert_to_rst(md_list, max_per_line=max_per_line)
|
||||
|
||||
@ -283,7 +293,116 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
f.writelines(lines[:start_index] + [converted_list] + lines[end_index:])
|
||||
else:
|
||||
raise ValueError(
|
||||
"The model list in the README changed and the list in `index.rst` has not been updated. Run `make fix-copies` to fix this."
|
||||
"The model list in the README changed and the list in `index.rst` has not been updated. Run "
|
||||
"`make fix-copies` to fix this."
|
||||
)
|
||||
|
||||
|
||||
def _center_text(text, width):
|
||||
text_length = 2 if text == "✅" or text == "❌" else len(text)
|
||||
left_indent = (width - text_length) // 2
|
||||
right_indent = width - text_length - left_indent
|
||||
return " " * left_indent + text + " " * right_indent
|
||||
|
||||
|
||||
def get_model_table_from_auto_modules():
|
||||
"""Generates an up-to-date model table from the content of the auto modules."""
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"transformers",
|
||||
os.path.join(TRANSFORMERS_PATH, "__init__.py"),
|
||||
submodule_search_locations=[TRANSFORMERS_PATH],
|
||||
)
|
||||
transformers = spec.loader.load_module()
|
||||
|
||||
# Dictionary model names to config.
|
||||
model_name_to_config = {
|
||||
name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
|
||||
}
|
||||
# All tokenizer tuples.
|
||||
tokenizers = {
|
||||
name: transformers.TOKENIZER_MAPPING[config]
|
||||
for name, config in model_name_to_config.items()
|
||||
if config in transformers.TOKENIZER_MAPPING
|
||||
}
|
||||
# Model names that a slow/fast tokenizer.
|
||||
has_slow_tokenizers = [name for name, tok in tokenizers.items() if tok[0] is not None]
|
||||
has_fast_tokenizers = [name for name, tok in tokenizers.items() if tok[1] is not None]
|
||||
|
||||
# Model names that have a PyTorch implementation.
|
||||
has_pt_model = [name for name, config in model_name_to_config.items() if config in transformers.MODEL_MAPPING]
|
||||
# Some of the GenerationModel don't have a base model.
|
||||
has_pt_model.extend(
|
||||
[
|
||||
name
|
||||
for name, config in model_name_to_config.items()
|
||||
if config in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
]
|
||||
)
|
||||
# Special exception for RAG
|
||||
has_pt_model.append("RAG")
|
||||
|
||||
# Model names that have a TensorFlow implementation.
|
||||
has_tf_model = [name for name, config in model_name_to_config.items() if config in transformers.TF_MODEL_MAPPING]
|
||||
# Some of the GenerationModel don't have a base model.
|
||||
has_tf_model.extend(
|
||||
[
|
||||
name
|
||||
for name, config in model_name_to_config.items()
|
||||
if config in transformers.TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
]
|
||||
)
|
||||
|
||||
# Model names that have a Flax implementation.
|
||||
has_flax_model = [
|
||||
name for name, config in model_name_to_config.items() if config in transformers.FLAX_MODEL_MAPPING
|
||||
]
|
||||
|
||||
# Let's build that table!
|
||||
model_names = list(model_name_to_config.keys())
|
||||
model_names.sort()
|
||||
columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"]
|
||||
# We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side).
|
||||
widths = [len(c) + 2 for c in columns]
|
||||
widths[0] = max([len(name) for name in model_names]) + 2
|
||||
|
||||
# Rst table per se
|
||||
table = ".. rst-class:: center-aligned-table\n\n"
|
||||
table += "+" + "+".join(["-" * w for w in widths]) + "+\n"
|
||||
table += "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n"
|
||||
table += "+" + "+".join(["=" * w for w in widths]) + "+\n"
|
||||
|
||||
check = {True: "✅", False: "❌"}
|
||||
for name in model_names:
|
||||
line = [
|
||||
name,
|
||||
check[name in has_slow_tokenizers],
|
||||
check[name in has_fast_tokenizers],
|
||||
check[name in has_pt_model],
|
||||
check[name in has_tf_model],
|
||||
check[name in has_flax_model],
|
||||
]
|
||||
table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n"
|
||||
table += "+" + "+".join(["-" * w for w in widths]) + "+\n"
|
||||
return table
|
||||
|
||||
|
||||
def check_model_table(overwrite=False):
|
||||
""" Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`. """
|
||||
current_table, start_index, end_index, lines = _find_text_in_file(
|
||||
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
|
||||
start_prompt=" This table is updated automatically from the auto module",
|
||||
end_prompt=".. toctree::",
|
||||
)
|
||||
new_table = get_model_table_from_auto_modules()
|
||||
|
||||
if current_table != new_table:
|
||||
if overwrite:
|
||||
with open(os.path.join(PATH_TO_DOCS, "index.rst"), "w", encoding="utf-8", newline="\n") as f:
|
||||
f.writelines(lines[:start_index] + [new_table] + lines[end_index:])
|
||||
else:
|
||||
raise ValueError(
|
||||
"The model table in the `index.rst` has not been updated. Run `make fix-copies` to fix this."
|
||||
)
|
||||
|
||||
|
||||
@ -293,3 +412,4 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
check_copies(args.fix_and_overwrite)
|
||||
check_model_table(args.fix_and_overwrite)
|
||||
|
@ -126,6 +126,7 @@ def get_model_modules():
|
||||
"modeling_outputs",
|
||||
"modeling_retribert",
|
||||
"modeling_utils",
|
||||
"modeling_flax_auto",
|
||||
"modeling_flax_utils",
|
||||
"modeling_transfo_xl_utilities",
|
||||
"modeling_tf_auto",
|
||||
|
Loading…
Reference in New Issue
Block a user