mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix model templates (#8595)
* First fixes * Fix imports and add init * Fix typo * Move init to final dest * Fix tokenization import * More fixes * Styling
This commit is contained in:
parent
042a6aa777
commit
36a19915ea
@ -47,7 +47,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
|
||||
path_to_transformer_root = (
|
||||
Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent
|
||||
)
|
||||
path_to_cookiecutter = path_to_transformer_root / "templates" / "cookiecutter"
|
||||
path_to_cookiecutter = path_to_transformer_root / "templates" / "adding_a_new_model"
|
||||
|
||||
# Execute cookiecutter
|
||||
if not self._testing:
|
||||
@ -75,9 +75,16 @@ class AddNewModelCommand(BaseTransformersCLICommand):
|
||||
output_pytorch = "PyTorch" in pytorch_or_tensorflow
|
||||
output_tensorflow = "TensorFlow" in pytorch_or_tensorflow
|
||||
|
||||
model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}"
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
shutil.move(
|
||||
f"{directory}/__init__.py",
|
||||
f"{model_dir}/__init__.py",
|
||||
)
|
||||
shutil.move(
|
||||
f"{directory}/configuration_{lowercase_model_name}.py",
|
||||
f"{path_to_transformer_root}/src/transformers/configuration_{lowercase_model_name}.py",
|
||||
f"{model_dir}/configuration_{lowercase_model_name}.py",
|
||||
)
|
||||
|
||||
def remove_copy_lines(path):
|
||||
@ -94,7 +101,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
|
||||
|
||||
shutil.move(
|
||||
f"{directory}/modeling_{lowercase_model_name}.py",
|
||||
f"{path_to_transformer_root}/src/transformers/modeling_{lowercase_model_name}.py",
|
||||
f"{model_dir}/modeling_{lowercase_model_name}.py",
|
||||
)
|
||||
|
||||
shutil.move(
|
||||
@ -111,7 +118,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
|
||||
|
||||
shutil.move(
|
||||
f"{directory}/modeling_tf_{lowercase_model_name}.py",
|
||||
f"{path_to_transformer_root}/src/transformers/modeling_tf_{lowercase_model_name}.py",
|
||||
f"{model_dir}/modeling_tf_{lowercase_model_name}.py",
|
||||
)
|
||||
|
||||
shutil.move(
|
||||
@ -129,7 +136,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
|
||||
|
||||
shutil.move(
|
||||
f"{directory}/tokenization_{lowercase_model_name}.py",
|
||||
f"{path_to_transformer_root}/src/transformers/tokenization_{lowercase_model_name}.py",
|
||||
f"{model_dir}/tokenization_{lowercase_model_name}.py",
|
||||
)
|
||||
|
||||
from os import fdopen, remove
|
||||
|
@ -21,6 +21,8 @@ from collections import OrderedDict
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...utils import logging
|
||||
|
||||
# Add modeling imports here
|
||||
from ..albert.modeling_albert import (
|
||||
AlbertForMaskedLM,
|
||||
AlbertForMultipleChoice,
|
||||
@ -228,8 +230,6 @@ from .configuration_auto import (
|
||||
)
|
||||
|
||||
|
||||
# Add modeling imports here
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
|
@ -21,6 +21,8 @@ from collections import OrderedDict
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...utils import logging
|
||||
|
||||
# Add modeling imports here
|
||||
from ..albert.modeling_tf_albert import (
|
||||
TFAlbertForMaskedLM,
|
||||
TFAlbertForMultipleChoice,
|
||||
@ -175,8 +177,6 @@ from .configuration_auto import (
|
||||
)
|
||||
|
||||
|
||||
# Add modeling imports here
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
|
@ -0,0 +1,43 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
{%- if cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" %}
|
||||
from ...file_utils import is_tf_available, is_torch_available
|
||||
{%- elif cookiecutter.generate_tensorflow_and_pytorch == "PyTorch" %}
|
||||
from ...file_utils import is_torch_available
|
||||
{%- elif cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow" %}
|
||||
from ...file_utils import is_tf_available
|
||||
{% endif %}
|
||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
|
||||
from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
|
||||
|
||||
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "PyTorch") %}
|
||||
if is_torch_available():
|
||||
from .modeling_{{cookiecutter.lowercase_modelname}} import (
|
||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
||||
{{cookiecutter.camelcase_modelname}}Layer,
|
||||
{{cookiecutter.camelcase_modelname}}Model,
|
||||
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
|
||||
)
|
||||
{% endif %}
|
||||
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow") %}
|
||||
if is_tf_available():
|
||||
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
|
||||
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
||||
TF{{cookiecutter.camelcase_modelname}}Layer,
|
||||
TF{{cookiecutter.camelcase_modelname}}Model,
|
||||
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||
)
|
||||
{% endif %}
|
@ -14,8 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" {{cookiecutter.modelname}} model configuration """
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .utils import logging
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -17,15 +17,14 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from .activations_tf import get_tf_activation
|
||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
|
||||
from .file_utils import (
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import (
|
||||
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
)
|
||||
from .modeling_tf_outputs import (
|
||||
from ...modeling_tf_outputs import (
|
||||
TFBaseModelOutput,
|
||||
TFBaseModelOutputWithPooling,
|
||||
TFMaskedLMOutput,
|
||||
@ -34,7 +33,7 @@ from .modeling_tf_outputs import (
|
||||
TFSequenceClassifierOutput,
|
||||
TFTokenClassifierOutput,
|
||||
)
|
||||
from .modeling_tf_utils import (
|
||||
from ...modeling_tf_utils import (
|
||||
TFMaskedLanguageModelingLoss,
|
||||
TFMultipleChoiceLoss,
|
||||
TFPreTrainedModel,
|
||||
@ -46,8 +45,9 @@ from .modeling_tf_utils import (
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from .tokenization_utils import BatchEncoding
|
||||
from .utils import logging
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -25,13 +25,12 @@ import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
|
||||
from .file_utils import (
|
||||
from ...file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
)
|
||||
from .modeling_outputs import (
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
MaskedLMOutput,
|
||||
@ -40,15 +39,16 @@ from .modeling_outputs import (
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from .modeling_utils import (
|
||||
from ...modeling_utils import (
|
||||
PreTrainedModel,
|
||||
SequenceSummary,
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
prune_linear_layer,
|
||||
)
|
||||
from .utils import logging
|
||||
from .activations import ACT2FN
|
||||
from ...utils import logging
|
||||
from ...activations import ACT2FN
|
||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -14,7 +14,7 @@
|
||||
# To replace in: "src/transformers/__init__.py"
|
||||
# Below: "if is_torch_available():" if generating PyTorch
|
||||
# Replace with:
|
||||
from .modeling_{{cookiecutter.lowercase_modelname}} import (
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||
@ -30,7 +30,7 @@
|
||||
|
||||
# Below: "if is_tf_available():" if generating TensorFlow
|
||||
# Replace with:
|
||||
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||
@ -44,14 +44,14 @@
|
||||
# End.
|
||||
|
||||
|
||||
# Below: "from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
|
||||
# Below: "from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
|
||||
# Replace with:
|
||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
|
||||
# End.
|
||||
|
||||
|
||||
|
||||
# To replace in: "src/transformers/configuration_auto.py"
|
||||
# To replace in: "src/transformers/models/auto/configuration_auto.py"
|
||||
# Below: "# Add configs here"
|
||||
# Replace with:
|
||||
("{{cookiecutter.lowercase_modelname}}", {{cookiecutter.camelcase_modelname}}Config),
|
||||
@ -62,9 +62,9 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.u
|
||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
# End.
|
||||
|
||||
# Below: "from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
|
||||
# Below: "from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
|
||||
# Replace with:
|
||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
|
||||
from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
|
||||
# End.
|
||||
|
||||
# Below: "# Add full (and cased) model names here"
|
||||
@ -83,7 +83,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.u
|
||||
# Below: "# Add modeling imports here"
|
||||
# Replace with:
|
||||
|
||||
from .modeling_{{cookiecutter.lowercase_modelname}} import (
|
||||
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
|
||||
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||
@ -138,7 +138,7 @@ from .modeling_{{cookiecutter.lowercase_modelname}} import (
|
||||
# Below: "# Add modeling imports here"
|
||||
# Replace with:
|
||||
|
||||
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
|
||||
from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
|
||||
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||
|
@ -15,8 +15,9 @@
|
||||
"""Tokenization classes for {{cookiecutter.modelname}}."""
|
||||
|
||||
{%- if cookiecutter.tokenizer_type == "Based on BERT" %}
|
||||
from .tokenization_bert import BertTokenizer, BertTokenizerFast
|
||||
from .utils import logging
|
||||
from ...utils import logging
|
||||
from ..bert.tokenization_bert import BertTokenizer
|
||||
from ..bert.tokenization_bert_fast import BertTokenizerFast
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -73,14 +74,14 @@ class {{cookiecutter.camelcase_modelname}}TokenizerFast(BertTokenizerFast):
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
{%- elif cookiecutter.tokenizer_type == "Standalone" %}
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
|
||||
from tokenizers import ByteLevelBPETokenizer
|
||||
|
||||
from .tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from .tokenization_utils_base import BatchEncoding
|
||||
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from typing import List, Optional
|
||||
from .utils import logging
|
||||
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
Loading…
Reference in New Issue
Block a user