Pickle auto models (#12654)

* PoC, it pickles!

* Remove old method.

* Apply to every auto object
This commit is contained in:
Sylvain Gugger 2021-07-12 11:15:54 -04:00 committed by GitHub
parent 379f649434
commit 9b3aab2cce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 230 additions and 125 deletions

View File

@ -1938,7 +1938,7 @@ class _LazyModule(ModuleType):
return importlib.import_module("." + module_name, self.__name__)
def __reduce__(self):
return (self.__class__, (self._name, self._import_structure))
return (self.__class__, (self._name, self.__file__, self._import_structure))
def copy_func(f):

View File

@ -14,8 +14,6 @@
# limitations under the License.
"""Factory function to build auto-model classes."""
import types
from ...configuration_utils import PretrainedConfig
from ...file_utils import copy_func
from ...utils import logging
@ -401,12 +399,12 @@ def insert_head_doc(docstring, head_doc=""):
)
def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-cased", head_doc=""):
def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""):
# Create a new class with the right name from the base class
new_class = types.new_class(name, (_BaseAutoModelClass,))
new_class._model_mapping = model_mapping
model_mapping = cls._model_mapping
name = cls.__name__
class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
new_class.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
# Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
# have a specific docstrings for them.
@ -416,7 +414,7 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
from_config.__doc__ = from_config_docstring
from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
new_class.from_config = classmethod(from_config)
cls.from_config = classmethod(from_config)
if name.startswith("TF"):
from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
@ -432,8 +430,8 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
from_pretrained.__doc__ = from_pretrained_docstring
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
new_class.from_pretrained = classmethod(from_pretrained)
return new_class
cls.from_pretrained = classmethod(from_pretrained)
return cls
def get_values(model_mapping):

View File

@ -308,7 +308,7 @@ from ..xlnet.modeling_xlnet import (
XLNetLMHeadModel,
XLNetModel,
)
from .auto_factory import auto_class_factory
from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import (
AlbertConfig,
BartConfig,
@ -780,64 +780,106 @@ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
)
AutoModel = auto_class_factory("AutoModel", MODEL_MAPPING)
class AutoModel(_BaseAutoModelClass):
_model_mapping = MODEL_MAPPING
AutoModel = auto_class_update(AutoModel)
class AutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_PRETRAINING_MAPPING
AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
AutoModelForPreTraining = auto_class_factory(
"AutoModelForPreTraining", MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
)
# Private on purpose, the public class will add the deprecation warnings.
_AutoModelWithLMHead = auto_class_factory(
"AutoModelWithLMHead", MODEL_WITH_LM_HEAD_MAPPING, head_doc="language modeling"
class _AutoModelWithLMHead(_BaseAutoModelClass):
_model_mapping = MODEL_WITH_LM_HEAD_MAPPING
_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
class AutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
class AutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASKED_LM_MAPPING
AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
AutoModelForSeq2SeqLM = auto_class_update(
AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)
AutoModelForCausalLM = auto_class_factory(
"AutoModelForCausalLM", MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
class AutoModelForSequenceClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
AutoModelForSequenceClassification = auto_class_update(
AutoModelForSequenceClassification, head_doc="sequence classification"
)
AutoModelForMaskedLM = auto_class_factory(
"AutoModelForMaskedLM", MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
)
AutoModelForSeq2SeqLM = auto_class_factory(
"AutoModelForSeq2SeqLM",
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
head_doc="sequence-to-sequence language modeling",
checkpoint_for_example="t5-base",
)
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
AutoModelForSequenceClassification = auto_class_factory(
"AutoModelForSequenceClassification", MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, head_doc="sequence classification"
)
AutoModelForQuestionAnswering = auto_class_factory(
"AutoModelForQuestionAnswering", MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering"
)
AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")
AutoModelForTableQuestionAnswering = auto_class_factory(
"AutoModelForTableQuestionAnswering",
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
AutoModelForTableQuestionAnswering = auto_class_update(
AutoModelForTableQuestionAnswering,
head_doc="table question answering",
checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
AutoModelForTokenClassification = auto_class_factory(
"AutoModelForTokenClassification", MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
class AutoModelForTokenClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")
class AutoModelForMultipleChoice(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING
AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
AutoModelForNextSentencePrediction = auto_class_update(
AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
AutoModelForMultipleChoice = auto_class_factory(
"AutoModelForMultipleChoice", MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice"
)
AutoModelForNextSentencePrediction = auto_class_factory(
"AutoModelForNextSentencePrediction",
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
head_doc="next sentence prediction",
)
class AutoModelForImageClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
AutoModelForImageClassification = auto_class_factory(
"AutoModelForImageClassification", MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, head_doc="image classification"
)
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
class AutoModelWithLMHead(_AutoModelWithLMHead):

View File

@ -73,7 +73,7 @@ from ..roberta.modeling_flax_roberta import (
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model
from .auto_factory import auto_class_factory
from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import (
BartConfig,
BertConfig,
@ -217,59 +217,89 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
]
)
FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
FlaxAutoModelForImageClassification = auto_class_factory(
"FlaxAutoModelForImageClassification",
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
head_doc="image classification modeling",
)
class FlaxAutoModel(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_MAPPING
FlaxAutoModelForCausalLM = auto_class_factory(
"FlaxAutoModelForCausalLM", FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
)
FlaxAutoModelForPreTraining = auto_class_factory(
"FlaxAutoModelForPreTraining", FLAX_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
)
FlaxAutoModel = auto_class_update(FlaxAutoModel)
FlaxAutoModelForMaskedLM = auto_class_factory(
"FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING
FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")
class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")
class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING
FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")
class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
FlaxAutoModelForSeq2SeqLM = auto_class_update(
FlaxAutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)
FlaxAutoModelForSeq2SeqLM = auto_class_factory(
"FlaxAutoModelForSeq2SeqLM",
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
head_doc="sequence-to-sequence language modeling",
class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
FlaxAutoModelForSequenceClassification = auto_class_update(
FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
)
FlaxAutoModelForSequenceClassification = auto_class_factory(
"FlaxAutoModelForSequenceClassification",
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
head_doc="sequence classification",
class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")
class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
FlaxAutoModelForTokenClassification = auto_class_update(
FlaxAutoModelForTokenClassification, head_doc="token classification"
)
FlaxAutoModelForQuestionAnswering = auto_class_factory(
"FlaxAutoModelForQuestionAnswering", FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering"
class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")
class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
FlaxAutoModelForNextSentencePrediction = auto_class_update(
FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
FlaxAutoModelForTokenClassification = auto_class_factory(
"FlaxAutoModelForTokenClassification", FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
)
FlaxAutoModelForMultipleChoice = auto_class_factory(
"AutoModelForMultipleChoice", FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice"
)
class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
FlaxAutoModelForNextSentencePrediction = auto_class_factory(
"FlaxAutoModelForNextSentencePrediction",
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
head_doc="next sentence prediction",
)
FlaxAutoModelForSeq2SeqLM = auto_class_factory(
"FlaxAutoModelForSeq2SeqLM",
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
head_doc="sequence-to-sequence language modeling",
FlaxAutoModelForImageClassification = auto_class_update(
FlaxAutoModelForImageClassification, head_doc="image classification"
)

View File

@ -189,7 +189,7 @@ from ..xlnet.modeling_tf_xlnet import (
TFXLNetLMHeadModel,
TFXLNetModel,
)
from .auto_factory import auto_class_factory
from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import (
AlbertConfig,
BartConfig,
@ -487,54 +487,89 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
)
TFAutoModel = auto_class_factory("TFAutoModel", TF_MODEL_MAPPING)
class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING
TFAutoModel = auto_class_update(TFAutoModel)
class TFAutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")
TFAutoModelForPreTraining = auto_class_factory(
"TFAutoModelForPreTraining", TF_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
)
# Private on purpose, the public class will add the deprecation warnings.
_TFAutoModelWithLMHead = auto_class_factory(
"TFAutoModelWithLMHead", TF_MODEL_WITH_LM_HEAD_MAPPING, head_doc="language modeling"
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
_model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING
_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")
class TFAutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING
TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")
class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
TFAutoModelForSeq2SeqLM = auto_class_update(
TFAutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)
TFAutoModelForCausalLM = auto_class_factory(
"TFAutoModelForCausalLM", TF_MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
TFAutoModelForSequenceClassification = auto_class_update(
TFAutoModelForSequenceClassification, head_doc="sequence classification"
)
TFAutoModelForMaskedLM = auto_class_factory(
"TFAutoModelForMaskedLM", TF_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")
class TFAutoModelForTokenClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
TFAutoModelForTokenClassification = auto_class_update(
TFAutoModelForTokenClassification, head_doc="token classification"
)
TFAutoModelForSeq2SeqLM = auto_class_factory(
"TFAutoModelForSeq2SeqLM",
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
head_doc="sequence-to-sequence language modeling",
checkpoint_for_example="t5-base",
)
TFAutoModelForSequenceClassification = auto_class_factory(
"TFAutoModelForSequenceClassification",
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
head_doc="sequence classification",
)
class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
TFAutoModelForQuestionAnswering = auto_class_factory(
"TFAutoModelForQuestionAnswering", TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering"
)
TFAutoModelForTokenClassification = auto_class_factory(
"TFAutoModelForTokenClassification", TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
)
TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")
TFAutoModelForMultipleChoice = auto_class_factory(
"TFAutoModelForMultipleChoice", TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice"
)
TFAutoModelForNextSentencePrediction = auto_class_factory(
"TFAutoModelForNextSentencePrediction",
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
head_doc="next sentence prediction",
class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
TFAutoModelForNextSentencePrediction = auto_class_update(
TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)