Fix model templates (#8595)

* First fixes

* Fix imports and add init

* Fix typo

* Move init to final dest

* Fix tokenization import

* More fixes

* Styling
This commit is contained in:
Sylvain Gugger 2020-11-17 10:35:38 -05:00 committed by GitHub
parent 042a6aa777
commit 36a19915ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 91 additions and 40 deletions

View File

@ -47,7 +47,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
path_to_transformer_root = (
Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent
)
path_to_cookiecutter = path_to_transformer_root / "templates" / "cookiecutter"
path_to_cookiecutter = path_to_transformer_root / "templates" / "adding_a_new_model"
# Execute cookiecutter
if not self._testing:
@ -75,9 +75,16 @@ class AddNewModelCommand(BaseTransformersCLICommand):
output_pytorch = "PyTorch" in pytorch_or_tensorflow
output_tensorflow = "TensorFlow" in pytorch_or_tensorflow
model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}"
os.makedirs(model_dir, exist_ok=True)
shutil.move(
f"{directory}/__init__.py",
f"{model_dir}/__init__.py",
)
shutil.move(
f"{directory}/configuration_{lowercase_model_name}.py",
f"{path_to_transformer_root}/src/transformers/configuration_{lowercase_model_name}.py",
f"{model_dir}/configuration_{lowercase_model_name}.py",
)
def remove_copy_lines(path):
@ -94,7 +101,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
shutil.move(
f"{directory}/modeling_{lowercase_model_name}.py",
f"{path_to_transformer_root}/src/transformers/modeling_{lowercase_model_name}.py",
f"{model_dir}/modeling_{lowercase_model_name}.py",
)
shutil.move(
@ -111,7 +118,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
shutil.move(
f"{directory}/modeling_tf_{lowercase_model_name}.py",
f"{path_to_transformer_root}/src/transformers/modeling_tf_{lowercase_model_name}.py",
f"{model_dir}/modeling_tf_{lowercase_model_name}.py",
)
shutil.move(
@ -129,7 +136,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
shutil.move(
f"{directory}/tokenization_{lowercase_model_name}.py",
f"{path_to_transformer_root}/src/transformers/tokenization_{lowercase_model_name}.py",
f"{model_dir}/tokenization_{lowercase_model_name}.py",
)
from os import fdopen, remove

View File

@ -21,6 +21,8 @@ from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings
from ...utils import logging
# Add modeling imports here
from ..albert.modeling_albert import (
AlbertForMaskedLM,
AlbertForMultipleChoice,
@ -228,8 +230,6 @@ from .configuration_auto import (
)
# Add modeling imports here
logger = logging.get_logger(__name__)

View File

@ -21,6 +21,8 @@ from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings
from ...utils import logging
# Add modeling imports here
from ..albert.modeling_tf_albert import (
TFAlbertForMaskedLM,
TFAlbertForMultipleChoice,
@ -175,8 +177,6 @@ from .configuration_auto import (
)
# Add modeling imports here
logger = logging.get_logger(__name__)

View File

@ -0,0 +1,43 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
{%- if cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" %}
from ...file_utils import is_tf_available, is_torch_available
{%- elif cookiecutter.generate_tensorflow_and_pytorch == "PyTorch" %}
from ...file_utils import is_torch_available
{%- elif cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow" %}
from ...file_utils import is_tf_available
{% endif %}
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "PyTorch") %}
if is_torch_available():
from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
{{cookiecutter.camelcase_modelname}}Layer,
{{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
)
{% endif %}
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow") %}
if is_tf_available():
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
TF{{cookiecutter.camelcase_modelname}}Layer,
TF{{cookiecutter.camelcase_modelname}}Model,
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% endif %}

View File

@ -14,8 +14,8 @@
# limitations under the License.
""" {{cookiecutter.modelname}} model configuration """
from .configuration_utils import PretrainedConfig
from .utils import logging
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)

View File

@ -17,15 +17,14 @@
import tensorflow as tf
from .activations_tf import get_tf_activation
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
from .file_utils import (
from ...activations_tf import get_tf_activation
from ...file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from .modeling_tf_outputs import (
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFMaskedLMOutput,
@ -34,7 +33,7 @@ from .modeling_tf_outputs import (
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
)
from .modeling_tf_utils import (
from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss,
TFPreTrainedModel,
@ -46,8 +45,9 @@ from .modeling_tf_utils import (
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
from .utils import logging
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
logger = logging.get_logger(__name__)

View File

@ -25,13 +25,12 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
from .file_utils import (
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from .modeling_outputs import (
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
@ -40,15 +39,16 @@ from .modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import (
from ...modeling_utils import (
PreTrainedModel,
SequenceSummary,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from .utils import logging
from .activations import ACT2FN
from ...utils import logging
from ...activations import ACT2FN
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
logger = logging.get_logger(__name__)

View File

@ -14,7 +14,7 @@
# To replace in: "src/transformers/__init__.py"
# Below: "if is_torch_available():" if generating PyTorch
# Replace with:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
from .models.{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
@ -30,7 +30,7 @@
# Below: "if is_tf_available():" if generating TensorFlow
# Replace with:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
from .models.{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
@ -44,14 +44,14 @@
# End.
# Below: "from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
# Below: "from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
# Replace with:
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
# End.
# To replace in: "src/transformers/configuration_auto.py"
# To replace in: "src/transformers/models/auto/configuration_auto.py"
# Below: "# Add configs here"
# Replace with:
("{{cookiecutter.lowercase_modelname}}", {{cookiecutter.camelcase_modelname}}Config),
@ -62,9 +62,9 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.u
{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP,
# End.
# Below: "from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
# Below: "from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
# Replace with:
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
# End.
# Below: "# Add full (and cased) model names here"
@ -83,7 +83,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.u
# Below: "# Add modeling imports here"
# Replace with:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
@ -138,7 +138,7 @@ from .modeling_{{cookiecutter.lowercase_modelname}} import (
# Below: "# Add modeling imports here"
# Replace with:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,

View File

@ -15,8 +15,9 @@
"""Tokenization classes for {{cookiecutter.modelname}}."""
{%- if cookiecutter.tokenizer_type == "Based on BERT" %}
from .tokenization_bert import BertTokenizer, BertTokenizerFast
from .utils import logging
from ...utils import logging
from ..bert.tokenization_bert import BertTokenizer
from ..bert.tokenization_bert_fast import BertTokenizerFast
logger = logging.get_logger(__name__)
@ -73,14 +74,14 @@ class {{cookiecutter.camelcase_modelname}}TokenizerFast(BertTokenizerFast):
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
{%- elif cookiecutter.tokenizer_type == "Standalone" %}
import warnings
from typing import List, Optional
from tokenizers import ByteLevelBPETokenizer
from .tokenization_utils import AddedToken, PreTrainedTokenizer
from .tokenization_utils_base import BatchEncoding
from .tokenization_utils_fast import PreTrainedTokenizerFast
from typing import List, Optional
from .utils import logging
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
logger = logging.get_logger(__name__)