mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Refactor AutoModel classes and add Flax Auto classes (#11027)
* Refactor AutoModel classes and add Flax Auto classes * Add new objects to the init * Fix hubconf and sort models * Fix TF tests * Missing coma * Update src/transformers/models/auto/auto_factory.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Fix init * Fix dummies * Other init to fix Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent
eb3479e7cf
commit
6c25f5228e
@ -189,3 +189,52 @@ FlaxAutoModel
|
||||
|
||||
.. autoclass:: transformers.FlaxAutoModel
|
||||
:members:
|
||||
|
||||
|
||||
FlaxAutoModelForPreTraining
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxAutoModelForPreTraining
|
||||
:members:
|
||||
|
||||
|
||||
FlaxAutoModelForMaskedLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxAutoModelForMaskedLM
|
||||
:members:
|
||||
|
||||
|
||||
FlaxAutoModelForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxAutoModelForSequenceClassification
|
||||
:members:
|
||||
|
||||
|
||||
FlaxAutoModelForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxAutoModelForQuestionAnswering
|
||||
:members:
|
||||
|
||||
|
||||
FlaxAutoModelForTokenClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxAutoModelForTokenClassification
|
||||
:members:
|
||||
|
||||
|
||||
FlaxAutoModelForMultipleChoice
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxAutoModelForMultipleChoice
|
||||
:members:
|
||||
|
||||
|
||||
FlaxAutoModelForNextSentencePrediction
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxAutoModelForNextSentencePrediction
|
||||
:members:
|
||||
|
38
hubconf.py
38
hubconf.py
@ -22,9 +22,10 @@ sys.path.append(SRC_DIR)
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelWithLMHead,
|
||||
AutoTokenizer,
|
||||
add_start_docstrings,
|
||||
)
|
||||
@ -86,22 +87,41 @@ def model(*args, **kwargs):
|
||||
return AutoModel.from_pretrained(*args, **kwargs)
|
||||
|
||||
|
||||
@add_start_docstrings(AutoModelWithLMHead.__doc__)
|
||||
def modelWithLMHead(*args, **kwargs):
|
||||
@add_start_docstrings(AutoModelForCausalLM.__doc__)
|
||||
def modelForCausalLM(*args, **kwargs):
|
||||
r"""
|
||||
# Using torch.hub !
|
||||
import torch
|
||||
|
||||
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased') # Download model and configuration from huggingface.co and cache.
|
||||
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased', output_attentions=True) # Update configuration during loading
|
||||
model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', 'gpt2') # Download model and configuration from huggingface.co and cache.
|
||||
model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', './test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', 'gpt2', output_attentions=True) # Update configuration during loading
|
||||
assert model.config.output_attentions == True
|
||||
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||
config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
|
||||
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
config = AutoConfig.from_pretrained('./tf_model/gpt_tf_model_config.json')
|
||||
model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', './tf_model/gpt_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
return AutoModelWithLMHead.from_pretrained(*args, **kwargs)
|
||||
return AutoModelForCausalLM.from_pretrained(*args, **kwargs)
|
||||
|
||||
|
||||
@add_start_docstrings(AutoModelForMaskedLM.__doc__)
|
||||
def modelForMaskedLM(*args, **kwargs):
|
||||
r"""
|
||||
# Using torch.hub !
|
||||
import torch
|
||||
|
||||
model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', 'bert-base-uncased') # Download model and configuration from huggingface.co and cache.
|
||||
model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', './test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', 'bert-base-uncased', output_attentions=True) # Update configuration during loading
|
||||
assert model.config.output_attentions == True
|
||||
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||
config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
|
||||
model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
|
||||
return AutoModelForMaskedLM.from_pretrained(*args, **kwargs)
|
||||
|
||||
|
||||
@add_start_docstrings(AutoModelForSequenceClassification.__doc__)
|
||||
|
@ -1300,7 +1300,26 @@ else:
|
||||
# FLAX-backed objects
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
|
||||
_import_structure["models.auto"].extend(["FLAX_MODEL_MAPPING", "FlaxAutoModel"])
|
||||
_import_structure["models.auto"].extend(
|
||||
[
|
||||
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
|
||||
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_MAPPING",
|
||||
"FlaxAutoModel",
|
||||
"FlaxAutoModelForMaskedLM",
|
||||
"FlaxAutoModelForMultipleChoice",
|
||||
"FlaxAutoModelForNextSentencePrediction",
|
||||
"FlaxAutoModelForPreTraining",
|
||||
"FlaxAutoModelForQuestionAnswering",
|
||||
"FlaxAutoModelForSequenceClassification",
|
||||
"FlaxAutoModelForTokenClassification",
|
||||
]
|
||||
)
|
||||
_import_structure["models.bert"].extend(
|
||||
[
|
||||
"FlaxBertForMaskedLM",
|
||||
@ -2410,7 +2429,24 @@ if TYPE_CHECKING:
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
||||
from .models.auto import (
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_MAPPING,
|
||||
FlaxAutoModel,
|
||||
FlaxAutoModelForMaskedLM,
|
||||
FlaxAutoModelForMultipleChoice,
|
||||
FlaxAutoModelForNextSentencePrediction,
|
||||
FlaxAutoModelForPreTraining,
|
||||
FlaxAutoModelForQuestionAnswering,
|
||||
FlaxAutoModelForSequenceClassification,
|
||||
FlaxAutoModelForTokenClassification,
|
||||
)
|
||||
from .models.bert import (
|
||||
FlaxBertForMaskedLM,
|
||||
FlaxBertForMultipleChoice,
|
||||
|
@ -82,7 +82,24 @@ if is_tf_available():
|
||||
]
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_auto"] = ["FLAX_MODEL_MAPPING", "FlaxAutoModel"]
|
||||
_import_structure["modeling_flax_auto"] = [
|
||||
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
|
||||
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_MAPPING",
|
||||
"FlaxAutoModel",
|
||||
"FlaxAutoModelForMaskedLM",
|
||||
"FlaxAutoModelForMultipleChoice",
|
||||
"FlaxAutoModelForNextSentencePrediction",
|
||||
"FlaxAutoModelForPreTraining",
|
||||
"FlaxAutoModelForQuestionAnswering",
|
||||
"FlaxAutoModelForSequenceClassification",
|
||||
"FlaxAutoModelForTokenClassification",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -145,7 +162,24 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
||||
from .modeling_flax_auto import (
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_MAPPING,
|
||||
FlaxAutoModel,
|
||||
FlaxAutoModelForMaskedLM,
|
||||
FlaxAutoModelForMultipleChoice,
|
||||
FlaxAutoModelForNextSentencePrediction,
|
||||
FlaxAutoModelForPreTraining,
|
||||
FlaxAutoModelForQuestionAnswering,
|
||||
FlaxAutoModelForSequenceClassification,
|
||||
FlaxAutoModelForTokenClassification,
|
||||
)
|
||||
|
||||
else:
|
||||
import importlib
|
||||
|
420
src/transformers/models/auto/auto_factory.py
Normal file
420
src/transformers/models/auto/auto_factory.py
Normal file
@ -0,0 +1,420 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Factory function to build auto-model classes."""
|
||||
|
||||
import functools
|
||||
import types
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
|
||||
|
||||
|
||||
CLASS_DOCSTRING = """
|
||||
This is a generic model class that will be instantiated as one of the model classes of the library when created
|
||||
with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the
|
||||
:meth:`~transformers.BaseAutoModelClass.from_config` class method.
|
||||
|
||||
This class cannot be instantiated directly using ``__init__()`` (throws an error).
|
||||
"""
|
||||
|
||||
FROM_CONFIG_DOCSTRING = """
|
||||
Instantiates one of the model classes of the library from a configuration.
|
||||
|
||||
Note:
|
||||
Loading a model from its configuration file does **not** load the model weights. It only affects the
|
||||
model's configuration. Use :meth:`~transformers.BaseAutoModelClass.from_pretrained` to load the model
|
||||
weights.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.PretrainedConfig`):
|
||||
The model class to instantiate is selected based on the configuration class:
|
||||
|
||||
List options
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import AutoConfig, BaseAutoModelClass
|
||||
>>> # Download configuration from huggingface.co and cache.
|
||||
>>> config = AutoConfig.from_pretrained('checkpoint_placeholder')
|
||||
>>> model = BaseAutoModelClass.from_config(config)
|
||||
"""
|
||||
|
||||
FROM_PRETRAINED_TORCH_DOCSTRING = """
|
||||
Instantiate one of the model classes of the library from a pretrained model.
|
||||
|
||||
The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
|
||||
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
|
||||
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:
|
||||
|
||||
List options
|
||||
|
||||
The model is set in evaluation mode by default using ``model.eval()`` (so for instance, dropout modules are
|
||||
deactivated). To train the model, you should first set it back in training mode with ``model.train()``
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In
|
||||
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
|
||||
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
model_args (additional positional arguments, `optional`):
|
||||
Will be passed along to the underlying model ``__init__()`` method.
|
||||
config (:class:`~transformers.PretrainedConfig`, `optional`):
|
||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||
be automatically loaded when:
|
||||
|
||||
- The model is a model provided by the library (loaded with the `model id` string of a pretrained
|
||||
model).
|
||||
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
|
||||
by supplying the save directory.
|
||||
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
||||
configuration JSON file named `config.json` is found in the directory.
|
||||
state_dict (`Dict[str, torch.Tensor]`, `optional`):
|
||||
A state dictionary to use instead of a state dictionary loaded from saved weights file.
|
||||
|
||||
This option can be used if you want to create a model from a pretrained configuration but load your own
|
||||
weights. In this case though, you should check if using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained` and
|
||||
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Load the model weights from a TensorFlow checkpoint save file (see docstring of
|
||||
``pretrained_model_name_or_path`` argument).
|
||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (:obj:`Dict[str, str], `optional`):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to only look at local files (e.g., not try downloading the model).
|
||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||
identifier allowed by git.
|
||||
kwargs (additional keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||
automatically loaded:
|
||||
|
||||
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
|
||||
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
|
||||
already been done)
|
||||
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
|
||||
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
|
||||
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
|
||||
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
|
||||
attribute will be passed to the underlying model's ``__init__`` function.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import AutoConfig, BaseAutoModelClass
|
||||
|
||||
>>> # Download model and configuration from huggingface.co and cache.
|
||||
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')
|
||||
|
||||
>>> # Update configuration during loading
|
||||
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
|
||||
>>> model.config.output_attentions
|
||||
True
|
||||
|
||||
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||
>>> config = AutoConfig.from_pretrained('./tf_model/shortcut_placeholder_tf_model_config.json')
|
||||
>>> model = BaseAutoModelClass.from_pretrained('./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
"""
|
||||
|
||||
FROM_PRETRAINED_TF_DOCSTRING = """
|
||||
Instantiate one of the model classes of the library from a pretrained model.
|
||||
|
||||
The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
|
||||
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
|
||||
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:
|
||||
|
||||
List options
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
|
||||
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the PyTorch model in a
|
||||
TensorFlow model using the provided conversion scripts and loading the TensorFlow model
|
||||
afterwards.
|
||||
model_args (additional positional arguments, `optional`):
|
||||
Will be passed along to the underlying model ``__init__()`` method.
|
||||
config (:class:`~transformers.PretrainedConfig`, `optional`):
|
||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||
be automatically loaded when:
|
||||
|
||||
- The model is a model provided by the library (loaded with the `model id` string of a pretrained
|
||||
model).
|
||||
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
|
||||
by supplying the save directory.
|
||||
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
||||
configuration JSON file named `config.json` is found in the directory.
|
||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Load the model weights from a PyTorch checkpoint save file (see docstring of
|
||||
``pretrained_model_name_or_path`` argument).
|
||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (:obj:`Dict[str, str], `optional`):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to only look at local files (e.g., not try downloading the model).
|
||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||
identifier allowed by git.
|
||||
kwargs (additional keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||
automatically loaded:
|
||||
|
||||
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
|
||||
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
|
||||
already been done)
|
||||
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
|
||||
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
|
||||
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
|
||||
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
|
||||
attribute will be passed to the underlying model's ``__init__`` function.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import AutoConfig, BaseAutoModelClass
|
||||
|
||||
>>> # Download model and configuration from huggingface.co and cache.
|
||||
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')
|
||||
|
||||
>>> # Update configuration during loading
|
||||
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
|
||||
>>> model.config.output_attentions
|
||||
True
|
||||
|
||||
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
|
||||
>>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json')
|
||||
>>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config)
|
||||
"""
|
||||
|
||||
FROM_PRETRAINED_FLAX_DOCSTRING = """
|
||||
Instantiate one of the model classes of the library from a pretrained model.
|
||||
|
||||
The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
|
||||
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
|
||||
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:
|
||||
|
||||
List options
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
|
||||
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the PyTorch model in a
|
||||
TensorFlow model using the provided conversion scripts and loading the TensorFlow model
|
||||
afterwards.
|
||||
model_args (additional positional arguments, `optional`):
|
||||
Will be passed along to the underlying model ``__init__()`` method.
|
||||
config (:class:`~transformers.PretrainedConfig`, `optional`):
|
||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||
be automatically loaded when:
|
||||
|
||||
- The model is a model provided by the library (loaded with the `model id` string of a pretrained
|
||||
model).
|
||||
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
|
||||
by supplying the save directory.
|
||||
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
||||
configuration JSON file named `config.json` is found in the directory.
|
||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Load the model weights from a PyTorch checkpoint save file (see docstring of
|
||||
``pretrained_model_name_or_path`` argument).
|
||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (:obj:`Dict[str, str], `optional`):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to only look at local files (e.g., not try downloading the model).
|
||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||
identifier allowed by git.
|
||||
kwargs (additional keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||
automatically loaded:
|
||||
|
||||
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
|
||||
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
|
||||
already been done)
|
||||
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
|
||||
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
|
||||
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
|
||||
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
|
||||
attribute will be passed to the underlying model's ``__init__`` function.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import AutoConfig, BaseAutoModelClass
|
||||
|
||||
>>> # Download model and configuration from huggingface.co and cache.
|
||||
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')
|
||||
|
||||
>>> # Update configuration during loading
|
||||
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
|
||||
>>> model.config.output_attentions
|
||||
True
|
||||
|
||||
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
|
||||
>>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json')
|
||||
>>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config)
|
||||
"""
|
||||
|
||||
|
||||
class _BaseAutoModelClass:
|
||||
# Base class for auto models.
|
||||
_model_mapping = None
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError(
|
||||
f"{self.__class__.__name__} is designed to be instantiated "
|
||||
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
f"`{self.__class__.__name__}.from_config(config)` methods."
|
||||
)
|
||||
|
||||
def from_config(cls, config, **kwargs):
|
||||
if type(config) in cls._model_mapping.keys():
|
||||
return cls._model_mapping[type(config)](config, **kwargs)
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||
)
|
||||
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
config = kwargs.pop("config", None)
|
||||
kwargs["_from_auto"] = True
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
if type(config) in cls._model_mapping.keys():
|
||||
return cls._model_mapping[type(config)].from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||
)
|
||||
|
||||
|
||||
def copy_func(f):
|
||||
""" Returns a copy of a function f."""
|
||||
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
|
||||
g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
|
||||
g = functools.update_wrapper(g, f)
|
||||
g.__kwdefaults__ = f.__kwdefaults__
|
||||
return g
|
||||
|
||||
|
||||
def insert_head_doc(docstring, head_doc=""):
|
||||
if len(head_doc) > 0:
|
||||
return docstring.replace(
|
||||
"one of the model classes of the library ",
|
||||
f"one of the model classes of the library (with a {head_doc} head) ",
|
||||
)
|
||||
return docstring.replace(
|
||||
"one of the model classes of the library ", "one of the base model classes of the library "
|
||||
)
|
||||
|
||||
|
||||
def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-cased", head_doc=""):
|
||||
# Create a new class with the right name from the base class
|
||||
new_class = types.new_class(name, (_BaseAutoModelClass,))
|
||||
new_class._model_mapping = model_mapping
|
||||
class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
|
||||
new_class.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
|
||||
|
||||
# Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
|
||||
# have a specific docstrings for them.
|
||||
from_config = copy_func(_BaseAutoModelClass.from_config)
|
||||
from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
|
||||
from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
|
||||
from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
|
||||
from_config.__doc__ = from_config_docstring
|
||||
from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
|
||||
new_class.from_config = classmethod(from_config)
|
||||
|
||||
if name.startswith("TF"):
|
||||
from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
|
||||
elif name.startswith("Flax"):
|
||||
from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
|
||||
else:
|
||||
from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
|
||||
from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
|
||||
from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
|
||||
from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
|
||||
from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
|
||||
shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
|
||||
from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
|
||||
from_pretrained.__doc__ = from_pretrained_docstring
|
||||
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
|
||||
new_class.from_pretrained = classmethod(from_pretrained)
|
||||
return new_class
|
@ -256,8 +256,8 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
|
||||
if config in config_to_class
|
||||
}
|
||||
lines = [
|
||||
f"{indent}- **{model_type}** -- :class:`~transformers.{cls_name}` ({MODEL_NAMES_MAPPING[model_type]} model)"
|
||||
for model_type, cls_name in model_type_to_name.items()
|
||||
f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)"
|
||||
for model_type in sorted(model_type_to_name.keys())
|
||||
]
|
||||
else:
|
||||
config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
|
||||
@ -265,8 +265,8 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
|
||||
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
|
||||
}
|
||||
lines = [
|
||||
f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{cls_name}` ({config_to_model_name[config_name]} model)"
|
||||
for config_name, cls_name in config_to_name.items()
|
||||
f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)"
|
||||
for config_name in sorted(config_to_name.keys())
|
||||
]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -17,11 +17,20 @@
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ..bert.modeling_flax_bert import FlaxBertModel
|
||||
from ..bert.modeling_flax_bert import (
|
||||
FlaxBertForMaskedLM,
|
||||
FlaxBertForMultipleChoice,
|
||||
FlaxBertForNextSentencePrediction,
|
||||
FlaxBertForPreTraining,
|
||||
FlaxBertForQuestionAnswering,
|
||||
FlaxBertForSequenceClassification,
|
||||
FlaxBertForTokenClassification,
|
||||
FlaxBertModel,
|
||||
)
|
||||
from ..roberta.modeling_flax_roberta import FlaxRobertaModel
|
||||
from .configuration_auto import AutoConfig, BertConfig, RobertaConfig
|
||||
from .auto_factory import auto_class_factory
|
||||
from .configuration_auto import BertConfig, RobertaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -29,140 +38,90 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
FLAX_MODEL_MAPPING = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
(RobertaConfig, FlaxRobertaModel),
|
||||
(BertConfig, FlaxBertModel),
|
||||
]
|
||||
)
|
||||
|
||||
FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for pre-training mapping
|
||||
(BertConfig, FlaxBertForPreTraining),
|
||||
]
|
||||
)
|
||||
|
||||
class FlaxAutoModel(object):
|
||||
r"""
|
||||
:class:`~transformers.FlaxAutoModel` is a generic model class that will be instantiated as one of the base model
|
||||
classes of the library when created with the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or the
|
||||
`FlaxAutoModel.from_config(config)` class methods.
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Masked LM mapping
|
||||
(BertConfig, FlaxBertForMaskedLM),
|
||||
]
|
||||
)
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
"""
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Sequence Classification mapping
|
||||
(BertConfig, FlaxBertForSequenceClassification),
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError(
|
||||
"FlaxAutoModel is designed to be instantiated "
|
||||
"using the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
"`FlaxAutoModel.from_config(config)` methods."
|
||||
)
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Question Answering mapping
|
||||
(BertConfig, FlaxBertForQuestionAnswering),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
r"""
|
||||
Instantiates one of the base model classes of the library from a configuration.
|
||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Token Classification mapping
|
||||
(BertConfig, FlaxBertForTokenClassification),
|
||||
]
|
||||
)
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.PretrainedConfig`):
|
||||
The model class to instantiate is selected based on the configuration class:
|
||||
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Multiple Choice mapping
|
||||
(BertConfig, FlaxBertForMultipleChoice),
|
||||
]
|
||||
)
|
||||
|
||||
- isInstance of `roberta` configuration class: :class:`~transformers.FlaxRobertaModel` (RoBERTa model)
|
||||
- isInstance of `bert` configuration class: :class:`~transformers.FlaxBertModel` (Bert model
|
||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
|
||||
[
|
||||
(BertConfig, FlaxBertForNextSentencePrediction),
|
||||
]
|
||||
)
|
||||
|
||||
Examples::
|
||||
FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
|
||||
|
||||
config = BertConfig.from_pretrained('bert-base-uncased')
|
||||
# Download configuration from huggingface.co and cache.
|
||||
model = FlaxAutoModel.from_config(config)
|
||||
# E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
"""
|
||||
for config_class, model_class in FLAX_MODEL_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
return model_class(config)
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class {config.__class__} "
|
||||
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}."
|
||||
)
|
||||
FlaxAutoModelForPreTraining = auto_class_factory(
|
||||
"FlaxAutoModelForPreTraining", FLAX_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r"""
|
||||
Instantiates one of the base model classes of the library from a pre-trained model configuration.
|
||||
FlaxAutoModelForMaskedLM = auto_class_factory(
|
||||
"FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
|
||||
)
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance based on the
|
||||
`model_type` property of the config object, or when it's missing, falling back to using pattern matching on the
|
||||
`pretrained_model_name_or_path` string.
|
||||
FlaxAutoModelForSequenceClassification = auto_class_factory(
|
||||
"AFlaxutoModelForSequenceClassification",
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
head_doc="sequence classification",
|
||||
)
|
||||
|
||||
The base model class to instantiate is selected as the first pattern matching in the
|
||||
`pretrained_model_name_or_path` string (in the following order):
|
||||
FlaxAutoModelForQuestionAnswering = auto_class_factory(
|
||||
"FlaxAutoModelForQuestionAnswering", FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering"
|
||||
)
|
||||
|
||||
- contains `roberta`: :class:`~transformers.FlaxRobertaModel` (RoBERTa model)
|
||||
- contains `bert`: :class:`~transformers.FlaxBertModel` (Bert model)
|
||||
FlaxAutoModelForTokenClassification = auto_class_factory(
|
||||
"FlaxAutoModelForTokenClassification", FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
|
||||
)
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) To
|
||||
train the model, you should first set it back in training mode with `model.train()`
|
||||
FlaxAutoModelForMultipleChoice = auto_class_factory(
|
||||
"AutoModelForMultipleChoice", FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice"
|
||||
)
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path: either:
|
||||
|
||||
- a string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. Valid
|
||||
model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a user or
|
||||
organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- a path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||
- a path or url to a `pytorch index checkpoint file` (e.g. `./pt_model/pytorch_model.bin`). In this
|
||||
case, ``from_pt`` should be set to True and a configuration object should be provided as ``config``
|
||||
argument.
|
||||
|
||||
model_args: (`optional`) Sequence of positional arguments:
|
||||
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
|
||||
|
||||
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||
be automatically loaded when:
|
||||
|
||||
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a
|
||||
pretrained model), or
|
||||
- the model was saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and is reloaded
|
||||
by supplying the save directory.
|
||||
- the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
||||
configuration JSON file named `config.json` is found in the directory.
|
||||
|
||||
cache_dir: (`optional`) string:
|
||||
Path to a directory in which a downloaded pre-trained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if
|
||||
they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
|
||||
|
||||
output_loading_info: (`optional`) boolean:
|
||||
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error
|
||||
messages.
|
||||
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||
These arguments will be passed to the configuration and the model.
|
||||
|
||||
Examples::
|
||||
|
||||
model = FlaxAutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from huggingface.co and cache.
|
||||
model = FlaxAutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
assert model.config.output_attention == True
|
||||
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
for config_class, model_class in FLAX_MODEL_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
return model_class.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, _from_auto=True, **kwargs
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class {config.__class__} "
|
||||
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}"
|
||||
)
|
||||
FlaxAutoModelForNextSentencePrediction = auto_class_factory(
|
||||
"FlaxAutoModelForNextSentencePrediction",
|
||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
head_doc="next sentence prediction",
|
||||
)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -11,6 +11,27 @@ class FlaxPreTrainedModel:
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING = None
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = None
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = None
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_PRETRAINING_MAPPING = None
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
FLAX_MODEL_MAPPING = None
|
||||
|
||||
|
||||
@ -23,6 +44,69 @@ class FlaxAutoModel:
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxAutoModelForMaskedLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxAutoModelForMultipleChoice:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxAutoModelForNextSentencePrediction:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxAutoModelForPreTraining:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxAutoModelForQuestionAnswering:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxAutoModelForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxAutoModelForTokenClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxBertForMaskedLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
Loading…
Reference in New Issue
Block a user